xref: /llvm-project/mlir/include/mlir/Interfaces/FunctionInterfaces.h (revision 34a35a8b244243f5a4ad5d531007bccfeaa0b02e)
1 //===- FunctionInterfaces.h - Function-like op utilities --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines support types for Operations that represent function-like
10 // constructs to use.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_IR_FUNCTIONINTERFACES_H
15 #define MLIR_IR_FUNCTIONINTERFACES_H
16 
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/OpDefinition.h"
20 #include "mlir/IR/SymbolTable.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Interfaces/CallInterfaces.h"
23 #include "llvm/ADT/BitVector.h"
24 #include "llvm/ADT/SmallString.h"
25 
26 namespace mlir {
27 class FunctionOpInterface;
28 
29 namespace function_interface_impl {
30 
31 /// Returns the dictionary attribute corresponding to the argument at 'index'.
32 /// If there are no argument attributes at 'index', a null attribute is
33 /// returned.
34 DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index);
35 
36 /// Returns the dictionary attribute corresponding to the result at 'index'.
37 /// If there are no result attributes at 'index', a null attribute is
38 /// returned.
39 DictionaryAttr getResultAttrDict(FunctionOpInterface op, unsigned index);
40 
41 /// Return all of the attributes for the argument at 'index'.
42 ArrayRef<NamedAttribute> getArgAttrs(FunctionOpInterface op, unsigned index);
43 
44 /// Return all of the attributes for the result at 'index'.
45 ArrayRef<NamedAttribute> getResultAttrs(FunctionOpInterface op, unsigned index);
46 
47 /// Set all of the argument or result attribute dictionaries for a function. The
48 /// size of `attrs` is expected to match the number of arguments/results of the
49 /// given `op`.
50 void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs);
51 void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef<Attribute> attrs);
52 void setAllResultAttrDicts(FunctionOpInterface op,
53                            ArrayRef<DictionaryAttr> attrs);
54 void setAllResultAttrDicts(FunctionOpInterface op, ArrayRef<Attribute> attrs);
55 
56 /// Insert the specified arguments and update the function type attribute.
57 void insertFunctionArguments(FunctionOpInterface op,
58                              ArrayRef<unsigned> argIndices, TypeRange argTypes,
59                              ArrayRef<DictionaryAttr> argAttrs,
60                              ArrayRef<Location> argLocs,
61                              unsigned originalNumArgs, Type newType);
62 
63 /// Insert the specified results and update the function type attribute.
64 void insertFunctionResults(FunctionOpInterface op,
65                            ArrayRef<unsigned> resultIndices,
66                            TypeRange resultTypes,
67                            ArrayRef<DictionaryAttr> resultAttrs,
68                            unsigned originalNumResults, Type newType);
69 
70 /// Erase the specified arguments and update the function type attribute.
71 void eraseFunctionArguments(FunctionOpInterface op, const BitVector &argIndices,
72                             Type newType);
73 
74 /// Erase the specified results and update the function type attribute.
75 void eraseFunctionResults(FunctionOpInterface op,
76                           const BitVector &resultIndices, Type newType);
77 
78 /// Set a FunctionOpInterface operation's type signature.
79 void setFunctionType(FunctionOpInterface op, Type newType);
80 
81 //===----------------------------------------------------------------------===//
82 // Function Argument Attribute.
83 //===----------------------------------------------------------------------===//
84 
85 /// Set the attributes held by the argument at 'index'.
86 void setArgAttrs(FunctionOpInterface op, unsigned index,
87                  ArrayRef<NamedAttribute> attributes);
88 void setArgAttrs(FunctionOpInterface op, unsigned index,
89                  DictionaryAttr attributes);
90 
91 /// If the an attribute exists with the specified name, change it to the new
92 /// value. Otherwise, add a new attribute with the specified name/value.
93 template <typename ConcreteType>
setArgAttr(ConcreteType op,unsigned index,StringAttr name,Attribute value)94 void setArgAttr(ConcreteType op, unsigned index, StringAttr name,
95                 Attribute value) {
96   NamedAttrList attributes(op.getArgAttrDict(index));
97   Attribute oldValue = attributes.set(name, value);
98 
99   // If the attribute changed, then set the new arg attribute list.
100   if (value != oldValue)
101     op.setArgAttrs(index, attributes.getDictionary(value.getContext()));
102 }
103 
104 /// Remove the attribute 'name' from the argument at 'index'. Returns the
105 /// removed attribute, or nullptr if `name` was not a valid attribute.
106 template <typename ConcreteType>
removeArgAttr(ConcreteType op,unsigned index,StringAttr name)107 Attribute removeArgAttr(ConcreteType op, unsigned index, StringAttr name) {
108   // Build an attribute list and remove the attribute at 'name'.
109   NamedAttrList attributes(op.getArgAttrDict(index));
110   Attribute removedAttr = attributes.erase(name);
111 
112   // If the attribute was removed, then update the argument dictionary.
113   if (removedAttr)
114     op.setArgAttrs(index, attributes.getDictionary(removedAttr.getContext()));
115   return removedAttr;
116 }
117 
118 //===----------------------------------------------------------------------===//
119 // Function Result Attribute.
120 //===----------------------------------------------------------------------===//
121 
122 /// Set the attributes held by the result at 'index'.
123 void setResultAttrs(FunctionOpInterface op, unsigned index,
124                     ArrayRef<NamedAttribute> attributes);
125 void setResultAttrs(FunctionOpInterface op, unsigned index,
126                     DictionaryAttr attributes);
127 
128 /// If the an attribute exists with the specified name, change it to the new
129 /// value. Otherwise, add a new attribute with the specified name/value.
130 template <typename ConcreteType>
setResultAttr(ConcreteType op,unsigned index,StringAttr name,Attribute value)131 void setResultAttr(ConcreteType op, unsigned index, StringAttr name,
132                    Attribute value) {
133   NamedAttrList attributes(op.getResultAttrDict(index));
134   Attribute oldAttr = attributes.set(name, value);
135 
136   // If the attribute changed, then set the new arg attribute list.
137   if (oldAttr != value)
138     op.setResultAttrs(index, attributes.getDictionary(value.getContext()));
139 }
140 
141 /// Remove the attribute 'name' from the result at 'index'.
142 template <typename ConcreteType>
removeResultAttr(ConcreteType op,unsigned index,StringAttr name)143 Attribute removeResultAttr(ConcreteType op, unsigned index, StringAttr name) {
144   // Build an attribute list and remove the attribute at 'name'.
145   NamedAttrList attributes(op.getResultAttrDict(index));
146   Attribute removedAttr = attributes.erase(name);
147 
148   // If the attribute was removed, then update the result dictionary.
149   if (removedAttr)
150     op.setResultAttrs(index,
151                       attributes.getDictionary(removedAttr.getContext()));
152   return removedAttr;
153 }
154 
155 /// This function defines the internal implementation of the `verifyTrait`
156 /// method on FunctionOpInterface::Trait.
157 template <typename ConcreteOp>
verifyTrait(ConcreteOp op)158 LogicalResult verifyTrait(ConcreteOp op) {
159   if (failed(op.verifyType()))
160     return failure();
161 
162   if (ArrayAttr allArgAttrs = op.getAllArgAttrs()) {
163     unsigned numArgs = op.getNumArguments();
164     if (allArgAttrs.size() != numArgs) {
165       return op.emitOpError()
166              << "expects argument attribute array to have the same number of "
167                 "elements as the number of function arguments, got "
168              << allArgAttrs.size() << ", but expected " << numArgs;
169     }
170     for (unsigned i = 0; i != numArgs; ++i) {
171       DictionaryAttr argAttrs =
172           llvm::dyn_cast_or_null<DictionaryAttr>(allArgAttrs[i]);
173       if (!argAttrs) {
174         return op.emitOpError() << "expects argument attribute dictionary "
175                                    "to be a DictionaryAttr, but got `"
176                                 << allArgAttrs[i] << "`";
177       }
178 
179       // Verify that all of the argument attributes are dialect attributes, i.e.
180       // that they contain a dialect prefix in their name.  Call the dialect, if
181       // registered, to verify the attributes themselves.
182       for (auto attr : argAttrs) {
183         if (!attr.getName().strref().contains('.'))
184           return op.emitOpError("arguments may only have dialect attributes");
185         if (Dialect *dialect = attr.getNameDialect()) {
186           if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0,
187                                                        /*argIndex=*/i, attr)))
188             return failure();
189         }
190       }
191     }
192   }
193   if (ArrayAttr allResultAttrs = op.getAllResultAttrs()) {
194     unsigned numResults = op.getNumResults();
195     if (allResultAttrs.size() != numResults) {
196       return op.emitOpError()
197              << "expects result attribute array to have the same number of "
198                 "elements as the number of function results, got "
199              << allResultAttrs.size() << ", but expected " << numResults;
200     }
201     for (unsigned i = 0; i != numResults; ++i) {
202       DictionaryAttr resultAttrs =
203           llvm::dyn_cast_or_null<DictionaryAttr>(allResultAttrs[i]);
204       if (!resultAttrs) {
205         return op.emitOpError() << "expects result attribute dictionary "
206                                    "to be a DictionaryAttr, but got `"
207                                 << allResultAttrs[i] << "`";
208       }
209 
210       // Verify that all of the result attributes are dialect attributes, i.e.
211       // that they contain a dialect prefix in their name.  Call the dialect, if
212       // registered, to verify the attributes themselves.
213       for (auto attr : resultAttrs) {
214         if (!attr.getName().strref().contains('.'))
215           return op.emitOpError("results may only have dialect attributes");
216         if (Dialect *dialect = attr.getNameDialect()) {
217           if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0,
218                                                           /*resultIndex=*/i,
219                                                           attr)))
220             return failure();
221         }
222       }
223     }
224   }
225 
226   // Check that the op has exactly one region for the body.
227   if (op->getNumRegions() != 1)
228     return op.emitOpError("expects one region");
229 
230   return op.verifyBody();
231 }
232 } // namespace function_interface_impl
233 } // namespace mlir
234 
235 //===----------------------------------------------------------------------===//
236 // Tablegen Interface Declarations
237 //===----------------------------------------------------------------------===//
238 
239 #include "mlir/Interfaces/FunctionInterfaces.h.inc"
240 
241 #endif // MLIR_IR_FUNCTIONINTERFACES_H
242