xref: /llvm-project/mlir/lib/Interfaces/FunctionInterfaces.cpp (revision 5262865aac683b72f3e66de7a122e0c455ab6b9b)
1 //===- FunctionSupport.cpp - Utility types for function-like ops ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/FunctionInterfaces.h"
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // Tablegen Interface Definitions
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Interfaces/FunctionInterfaces.cpp.inc"
18 
19 //===----------------------------------------------------------------------===//
20 // Function Arguments and Results.
21 //===----------------------------------------------------------------------===//
22 
23 static bool isEmptyAttrDict(Attribute attr) {
24   return llvm::cast<DictionaryAttr>(attr).empty();
25 }
26 
27 DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op,
28                                                        unsigned index) {
29   ArrayAttr attrs = op.getArgAttrsAttr();
30   DictionaryAttr argAttrs =
31       attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr();
32   return argAttrs;
33 }
34 
35 DictionaryAttr
36 function_interface_impl::getResultAttrDict(FunctionOpInterface op,
37                                            unsigned index) {
38   ArrayAttr attrs = op.getResAttrsAttr();
39   DictionaryAttr resAttrs =
40       attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr();
41   return resAttrs;
42 }
43 
44 ArrayRef<NamedAttribute>
45 function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) {
46   auto argDict = getArgAttrDict(op, index);
47   return argDict ? argDict.getValue() : std::nullopt;
48 }
49 
50 ArrayRef<NamedAttribute>
51 function_interface_impl::getResultAttrs(FunctionOpInterface op,
52                                         unsigned index) {
53   auto resultDict = getResultAttrDict(op, index);
54   return resultDict ? resultDict.getValue() : std::nullopt;
55 }
56 
57 /// Get either the argument or result attributes array.
58 template <bool isArg>
59 static ArrayAttr getArgResAttrs(FunctionOpInterface op) {
60   if constexpr (isArg)
61     return op.getArgAttrsAttr();
62   else
63     return op.getResAttrsAttr();
64 }
65 
66 /// Set either the argument or result attributes array.
67 template <bool isArg>
68 static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) {
69   if constexpr (isArg)
70     op.setArgAttrsAttr(attrs);
71   else
72     op.setResAttrsAttr(attrs);
73 }
74 
75 /// Erase either the argument or result attributes array.
76 template <bool isArg>
77 static void removeArgResAttrs(FunctionOpInterface op) {
78   if constexpr (isArg)
79     op.removeArgAttrsAttr();
80   else
81     op.removeResAttrsAttr();
82 }
83 
84 /// Set all of the argument or result attribute dictionaries for a function.
85 template <bool isArg>
86 static void setAllArgResAttrDicts(FunctionOpInterface op,
87                                   ArrayRef<Attribute> attrs) {
88   if (llvm::all_of(attrs, isEmptyAttrDict))
89     removeArgResAttrs<isArg>(op);
90   else
91     setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), attrs));
92 }
93 
94 void function_interface_impl::setAllArgAttrDicts(
95     FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
96   setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
97 }
98 
99 void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op,
100                                                  ArrayRef<Attribute> attrs) {
101   auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
102     return !attr ? DictionaryAttr::get(op->getContext()) : attr;
103   });
104   setAllArgResAttrDicts</*isArg=*/true>(op, llvm::to_vector<8>(wrappedAttrs));
105 }
106 
107 void function_interface_impl::setAllResultAttrDicts(
108     FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
109   setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
110 }
111 
112 void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op,
113                                                     ArrayRef<Attribute> attrs) {
114   auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
115     return !attr ? DictionaryAttr::get(op->getContext()) : attr;
116   });
117   setAllArgResAttrDicts</*isArg=*/false>(op, llvm::to_vector<8>(wrappedAttrs));
118 }
119 
120 /// Update the given index into an argument or result attribute dictionary.
121 template <bool isArg>
122 static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices,
123                               unsigned index, DictionaryAttr attrs) {
124   ArrayAttr allAttrs = getArgResAttrs<isArg>(op);
125   if (!allAttrs) {
126     if (attrs.empty())
127       return;
128 
129     // If this attribute is not empty, we need to create a new attribute array.
130     SmallVector<Attribute, 8> newAttrs(numTotalIndices,
131                                        DictionaryAttr::get(op->getContext()));
132     newAttrs[index] = attrs;
133     setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
134     return;
135   }
136   // Check to see if the attribute is different from what we already have.
137   if (allAttrs[index] == attrs)
138     return;
139 
140   // If it is, check to see if the attribute array would now contain only empty
141   // dictionaries.
142   ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
143   if (attrs.empty() &&
144       llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
145       llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict))
146     return removeArgResAttrs<isArg>(op);
147 
148   // Otherwise, create a new attribute array with the updated dictionary.
149   SmallVector<Attribute, 8> newAttrs(rawAttrArray);
150   newAttrs[index] = attrs;
151   setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
152 }
153 
154 void function_interface_impl::setArgAttrs(FunctionOpInterface op,
155                                           unsigned index,
156                                           ArrayRef<NamedAttribute> attributes) {
157   assert(index < op.getNumArguments() && "invalid argument number");
158   return setArgResAttrDict</*isArg=*/true>(
159       op, op.getNumArguments(), index,
160       DictionaryAttr::get(op->getContext(), attributes));
161 }
162 
163 void function_interface_impl::setArgAttrs(FunctionOpInterface op,
164                                           unsigned index,
165                                           DictionaryAttr attributes) {
166   return setArgResAttrDict</*isArg=*/true>(
167       op, op.getNumArguments(), index,
168       attributes ? attributes : DictionaryAttr::get(op->getContext()));
169 }
170 
171 void function_interface_impl::setResultAttrs(
172     FunctionOpInterface op, unsigned index,
173     ArrayRef<NamedAttribute> attributes) {
174   assert(index < op.getNumResults() && "invalid result number");
175   return setArgResAttrDict</*isArg=*/false>(
176       op, op.getNumResults(), index,
177       DictionaryAttr::get(op->getContext(), attributes));
178 }
179 
180 void function_interface_impl::setResultAttrs(FunctionOpInterface op,
181                                              unsigned index,
182                                              DictionaryAttr attributes) {
183   assert(index < op.getNumResults() && "invalid result number");
184   return setArgResAttrDict</*isArg=*/false>(
185       op, op.getNumResults(), index,
186       attributes ? attributes : DictionaryAttr::get(op->getContext()));
187 }
188 
189 void function_interface_impl::insertFunctionArguments(
190     FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
191     ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs,
192     unsigned originalNumArgs, Type newType) {
193   assert(argIndices.size() == argTypes.size());
194   assert(argIndices.size() == argAttrs.size() || argAttrs.empty());
195   assert(argIndices.size() == argLocs.size());
196   if (argIndices.empty())
197     return;
198 
199   // There are 3 things that need to be updated:
200   // - Function type.
201   // - Arg attrs.
202   // - Block arguments of entry block.
203   Block &entry = op->getRegion(0).front();
204 
205   // Update the argument attributes of the function.
206   ArrayAttr oldArgAttrs = op.getArgAttrsAttr();
207   if (oldArgAttrs || !argAttrs.empty()) {
208     SmallVector<DictionaryAttr, 4> newArgAttrs;
209     newArgAttrs.reserve(originalNumArgs + argIndices.size());
210     unsigned oldIdx = 0;
211     auto migrate = [&](unsigned untilIdx) {
212       if (!oldArgAttrs) {
213         newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx);
214       } else {
215         auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>();
216         newArgAttrs.append(oldArgAttrRange.begin() + oldIdx,
217                            oldArgAttrRange.begin() + untilIdx);
218       }
219       oldIdx = untilIdx;
220     };
221     for (unsigned i = 0, e = argIndices.size(); i < e; ++i) {
222       migrate(argIndices[i]);
223       newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]);
224     }
225     migrate(originalNumArgs);
226     setAllArgAttrDicts(op, newArgAttrs);
227   }
228 
229   // Update the function type and any entry block arguments.
230   op.setFunctionTypeAttr(TypeAttr::get(newType));
231   for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
232     entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
233 }
234 
235 void function_interface_impl::insertFunctionResults(
236     FunctionOpInterface op, ArrayRef<unsigned> resultIndices,
237     TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs,
238     unsigned originalNumResults, Type newType) {
239   assert(resultIndices.size() == resultTypes.size());
240   assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
241   if (resultIndices.empty())
242     return;
243 
244   // There are 2 things that need to be updated:
245   // - Function type.
246   // - Result attrs.
247 
248   // Update the result attributes of the function.
249   ArrayAttr oldResultAttrs = op.getResAttrsAttr();
250   if (oldResultAttrs || !resultAttrs.empty()) {
251     SmallVector<DictionaryAttr, 4> newResultAttrs;
252     newResultAttrs.reserve(originalNumResults + resultIndices.size());
253     unsigned oldIdx = 0;
254     auto migrate = [&](unsigned untilIdx) {
255       if (!oldResultAttrs) {
256         newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx);
257       } else {
258         auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>();
259         newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx,
260                               oldResultAttrsRange.begin() + untilIdx);
261       }
262       oldIdx = untilIdx;
263     };
264     for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) {
265       migrate(resultIndices[i]);
266       newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{}
267                                                    : resultAttrs[i]);
268     }
269     migrate(originalNumResults);
270     setAllResultAttrDicts(op, newResultAttrs);
271   }
272 
273   // Update the function type.
274   op.setFunctionTypeAttr(TypeAttr::get(newType));
275 }
276 
277 void function_interface_impl::eraseFunctionArguments(
278     FunctionOpInterface op, const BitVector &argIndices, Type newType) {
279   // There are 3 things that need to be updated:
280   // - Function type.
281   // - Arg attrs.
282   // - Block arguments of entry block.
283   Block &entry = op->getRegion(0).front();
284 
285   // Update the argument attributes of the function.
286   if (ArrayAttr argAttrs = op.getArgAttrsAttr()) {
287     SmallVector<DictionaryAttr, 4> newArgAttrs;
288     newArgAttrs.reserve(argAttrs.size());
289     for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
290       if (!argIndices[i])
291         newArgAttrs.emplace_back(llvm::cast<DictionaryAttr>(argAttrs[i]));
292     setAllArgAttrDicts(op, newArgAttrs);
293   }
294 
295   // Update the function type and any entry block arguments.
296   op.setFunctionTypeAttr(TypeAttr::get(newType));
297   entry.eraseArguments(argIndices);
298 }
299 
300 void function_interface_impl::eraseFunctionResults(
301     FunctionOpInterface op, const BitVector &resultIndices, Type newType) {
302   // There are 2 things that need to be updated:
303   // - Function type.
304   // - Result attrs.
305 
306   // Update the result attributes of the function.
307   if (ArrayAttr resAttrs = op.getResAttrsAttr()) {
308     SmallVector<DictionaryAttr, 4> newResultAttrs;
309     newResultAttrs.reserve(resAttrs.size());
310     for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
311       if (!resultIndices[i])
312         newResultAttrs.emplace_back(llvm::cast<DictionaryAttr>(resAttrs[i]));
313     setAllResultAttrDicts(op, newResultAttrs);
314   }
315 
316   // Update the function type.
317   op.setFunctionTypeAttr(TypeAttr::get(newType));
318 }
319 
320 //===----------------------------------------------------------------------===//
321 // Function type signature.
322 //===----------------------------------------------------------------------===//
323 
324 void function_interface_impl::setFunctionType(FunctionOpInterface op,
325                                               Type newType) {
326   unsigned oldNumArgs = op.getNumArguments();
327   unsigned oldNumResults = op.getNumResults();
328   op.setFunctionTypeAttr(TypeAttr::get(newType));
329   unsigned newNumArgs = op.getNumArguments();
330   unsigned newNumResults = op.getNumResults();
331 
332   // Functor used to update the argument and result attributes of the function.
333   auto emptyDict = DictionaryAttr::get(op.getContext());
334   auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) {
335     constexpr bool isArgVal = std::is_same_v<decltype(isArg), std::true_type>;
336 
337     if (oldCount == newCount)
338       return;
339     // The new type has no arguments/results, just drop the attribute.
340     if (newCount == 0)
341       return removeArgResAttrs<isArgVal>(op);
342     ArrayAttr attrs = getArgResAttrs<isArgVal>(op);
343     if (!attrs)
344       return;
345 
346     // The new type has less arguments/results, take the first N attributes.
347     if (newCount < oldCount)
348       return setAllArgResAttrDicts<isArgVal>(
349           op, attrs.getValue().take_front(newCount));
350 
351     // Otherwise, the new type has more arguments/results. Initialize the new
352     // arguments/results with empty dictionary attributes.
353     SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
354     newAttrs.resize(newCount, emptyDict);
355     setAllArgResAttrDicts<isArgVal>(op, newAttrs);
356   };
357 
358   // Update the argument and result attributes.
359   updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs);
360   updateAttrFn(std::false_type{}, oldNumResults, newNumResults);
361 }
362