xref: /llvm-project/mlir/include/mlir/TableGen/Operator.h (revision e768b076e3b7ed38485a29244a0b989076e4b131)
1 //===- Operator.h - Operator class ------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Operator wrapper to simplify using TableGen Record defining a MLIR Op.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TABLEGEN_OPERATOR_H_
14 #define MLIR_TABLEGEN_OPERATOR_H_
15 
16 #include "mlir/Support/LLVM.h"
17 #include "mlir/TableGen/Argument.h"
18 #include "mlir/TableGen/Attribute.h"
19 #include "mlir/TableGen/Builder.h"
20 #include "mlir/TableGen/Dialect.h"
21 #include "mlir/TableGen/Property.h"
22 #include "mlir/TableGen/Region.h"
23 #include "mlir/TableGen/Successor.h"
24 #include "mlir/TableGen/Trait.h"
25 #include "mlir/TableGen/Type.h"
26 #include "llvm/ADT/PointerUnion.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringMap.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/Support/SMLoc.h"
31 
32 namespace llvm {
33 class DefInit;
34 class Record;
35 class StringInit;
36 } // namespace llvm
37 
38 namespace mlir {
39 namespace tblgen {
40 
41 /// This class represents an inferred result type. The result type can be
42 /// inferred from an argument or result type. If it is inferred from another
43 /// result type, that type must be buildable or inferred from yet another type.
44 class InferredResultType {
45 public:
46   InferredResultType(int index, std::string transformer)
47       : index(index), transformer(std::move(transformer)) {}
48 
49   /// Returns true if result type is inferred from an argument type.
50   bool isArg() const { return isArgIndex(index); }
51   /// Return the mapped argument or result index.
52   int getIndex() const { return index; }
53   /// If the type is inferred from a result, return the result index.
54   int getResultIndex() const { return unmapResultIndex(index); }
55 
56   // Mapping from result index to combined argument and result index.
57   // Arguments are indexed to match getArg index, while the result indexes are
58   // mapped to avoid overlap.
59   static int mapResultIndex(int i) { return -1 - i; }
60   static int unmapResultIndex(int i) { return -i - 1; }
61   static bool isResultIndex(int i) { return i < 0; }
62   static bool isArgIndex(int i) { return i >= 0; }
63 
64   StringRef getTransformer() const { return transformer; }
65 
66 private:
67   /// The index of the source argument or result.
68   int index;
69 
70   /// The transfer to apply to the type to obtain the inferred type.
71   std::string transformer;
72 };
73 
74 /// Wrapper class that contains a MLIR op's information (e.g., operands,
75 /// attributes) defined in TableGen and provides helper methods for
76 /// accessing them.
77 class Operator {
78 public:
79   explicit Operator(const llvm::Record &def);
80   explicit Operator(const llvm::Record *def) : Operator(*def) {}
81 
82   /// Returns this op's dialect name.
83   StringRef getDialectName() const;
84 
85   /// Returns the operation name. The name will follow the "<dialect>.<op-name>"
86   /// format if its dialect name is not empty.
87   std::string getOperationName() const;
88 
89   /// Returns this op's C++ class name.
90   StringRef getCppClassName() const;
91 
92   /// Returns this op's C++ class name prefixed with namespaces.
93   std::string getQualCppClassName() const;
94 
95   /// Returns this op's C++ namespace.
96   StringRef getCppNamespace() const;
97 
98   /// Returns the name of op's adaptor C++ class.
99   std::string getAdaptorName() const;
100 
101   /// Returns the name of op's generic adaptor C++ class.
102   std::string getGenericAdaptorName() const;
103 
104   /// Check invariants (like no duplicated or conflicted names) and abort the
105   /// process if any invariant is broken.
106   void assertInvariants() const;
107 
108   /// A class used to represent the decorators of an operator variable, i.e.
109   /// argument or result.
110   struct VariableDecorator {
111   public:
112     explicit VariableDecorator(const llvm::Record *def) : def(def) {}
113     const llvm::Record &getDef() const { return *def; }
114 
115   protected:
116     /// The TableGen definition of this decorator.
117     const llvm::Record *def;
118   };
119 
120   /// A utility iterator over a list of variable decorators.
121   struct VariableDecoratorIterator
122       : public llvm::mapped_iterator<const llvm::Init *const *,
123                                      VariableDecorator (*)(
124                                          const llvm::Init *)> {
125     /// Initializes the iterator to the specified iterator.
126     VariableDecoratorIterator(const llvm::Init *const *it)
127         : llvm::mapped_iterator<const llvm::Init *const *,
128                                 VariableDecorator (*)(const llvm::Init *)>(
129               it, &unwrap) {}
130     static VariableDecorator unwrap(const llvm::Init *init);
131   };
132   using var_decorator_iterator = VariableDecoratorIterator;
133   using var_decorator_range = llvm::iterator_range<VariableDecoratorIterator>;
134 
135   using value_iterator = NamedTypeConstraint *;
136   using const_value_iterator = const NamedTypeConstraint *;
137   using value_range = llvm::iterator_range<value_iterator>;
138   using const_value_range = llvm::iterator_range<const_value_iterator>;
139 
140   /// Returns true if this op has variable length operands or results.
141   bool isVariadic() const;
142 
143   /// Returns true if default builders should not be generated.
144   bool skipDefaultBuilders() const;
145 
146   /// Op result iterators.
147   const_value_iterator result_begin() const;
148   const_value_iterator result_end() const;
149   const_value_range getResults() const;
150 
151   /// Returns the number of results this op produces.
152   int getNumResults() const;
153 
154   /// Returns the op result at the given `index`.
155   NamedTypeConstraint &getResult(int index) { return results[index]; }
156   const NamedTypeConstraint &getResult(int index) const {
157     return results[index];
158   }
159 
160   /// Returns the `index`-th result's type constraint.
161   TypeConstraint getResultTypeConstraint(int index) const;
162   /// Returns the `index`-th result's name.
163   StringRef getResultName(int index) const;
164   /// Returns the `index`-th result's decorators.
165   var_decorator_range getResultDecorators(int index) const;
166 
167   /// Returns the number of variable length results in this operation.
168   unsigned getNumVariableLengthResults() const;
169 
170   /// Op attribute iterators.
171   using const_attribute_iterator = const NamedAttribute *;
172   const_attribute_iterator attribute_begin() const;
173   const_attribute_iterator attribute_end() const;
174   llvm::iterator_range<const_attribute_iterator> getAttributes() const;
175   using attribute_iterator = NamedAttribute *;
176   attribute_iterator attribute_begin();
177   attribute_iterator attribute_end();
178   llvm::iterator_range<attribute_iterator> getAttributes();
179 
180   int getNumAttributes() const { return attributes.size(); }
181   int getNumNativeAttributes() const { return numNativeAttributes; }
182 
183   /// Op attribute accessors.
184   NamedAttribute &getAttribute(int index) { return attributes[index]; }
185   const NamedAttribute &getAttribute(int index) const {
186     return attributes[index];
187   }
188 
189   /// Op operand iterators.
190   const_value_iterator operand_begin() const;
191   const_value_iterator operand_end() const;
192   const_value_range getOperands() const;
193 
194   // Op properties iterators.
195   using const_property_iterator = const NamedProperty *;
196   const_property_iterator properties_begin() const {
197     return properties.begin();
198   }
199   const_property_iterator properties_end() const { return properties.end(); }
200   llvm::iterator_range<const_property_iterator> getProperties() const {
201     return properties;
202   }
203   using property_iterator = NamedProperty *;
204   property_iterator properties_begin() { return properties.begin(); }
205   property_iterator properties_end() { return properties.end(); }
206   llvm::iterator_range<property_iterator> getProperties() { return properties; }
207   int getNumCoreAttributes() const { return properties.size(); }
208 
209   // Op properties accessors.
210   NamedProperty &getProperty(int index) { return properties[index]; }
211   const NamedProperty &getProperty(int index) const {
212     return properties[index];
213   }
214 
215   int getNumOperands() const { return operands.size(); }
216   NamedTypeConstraint &getOperand(int index) { return operands[index]; }
217   const NamedTypeConstraint &getOperand(int index) const {
218     return operands[index];
219   }
220 
221   /// Returns the number of variadic operands in this operation.
222   unsigned getNumVariableLengthOperands() const;
223 
224   /// Returns the total number of arguments.
225   int getNumArgs() const { return arguments.size(); }
226 
227   /// Returns true of the operation has a single variadic arg.
228   bool hasSingleVariadicArg() const;
229 
230   /// Returns true if the operation has a single variadic result.
231   bool hasSingleVariadicResult() const {
232     return getNumResults() == 1 && getResult(0).isVariadic();
233   }
234 
235   /// Returns true of the operation has no variadic regions.
236   bool hasNoVariadicRegions() const { return getNumVariadicRegions() == 0; }
237 
238   using arg_iterator = const Argument *;
239   using arg_range = llvm::iterator_range<arg_iterator>;
240 
241   /// Op argument (attribute or operand) iterators.
242   arg_iterator arg_begin() const;
243   arg_iterator arg_end() const;
244   arg_range getArgs() const;
245 
246   /// Op argument (attribute or operand) accessors.
247   Argument getArg(int index) const;
248   StringRef getArgName(int index) const;
249   var_decorator_range getArgDecorators(int index) const;
250 
251   /// Returns the trait wrapper for the given MLIR C++ `trait`.
252   const Trait *getTrait(llvm::StringRef trait) const;
253 
254   /// Regions.
255   using const_region_iterator = const NamedRegion *;
256   const_region_iterator region_begin() const;
257   const_region_iterator region_end() const;
258   llvm::iterator_range<const_region_iterator> getRegions() const;
259 
260   /// Returns the number of regions.
261   unsigned getNumRegions() const;
262   /// Returns the `index`-th region.
263   const NamedRegion &getRegion(unsigned index) const;
264 
265   /// Returns the number of variadic regions in this operation.
266   unsigned getNumVariadicRegions() const;
267 
268   /// Successors.
269   using const_successor_iterator = const NamedSuccessor *;
270   const_successor_iterator successor_begin() const;
271   const_successor_iterator successor_end() const;
272   llvm::iterator_range<const_successor_iterator> getSuccessors() const;
273 
274   /// Returns the number of successors.
275   unsigned getNumSuccessors() const;
276   /// Returns the `index`-th successor.
277   const NamedSuccessor &getSuccessor(unsigned index) const;
278 
279   /// Returns the number of variadic successors in this operation.
280   unsigned getNumVariadicSuccessors() const;
281 
282   /// Trait.
283   using const_trait_iterator = const Trait *;
284   const_trait_iterator trait_begin() const;
285   const_trait_iterator trait_end() const;
286   llvm::iterator_range<const_trait_iterator> getTraits() const;
287 
288   ArrayRef<SMLoc> getLoc() const;
289 
290   /// Query functions for the documentation of the operator.
291   bool hasDescription() const;
292   StringRef getDescription() const;
293   bool hasSummary() const;
294   StringRef getSummary() const;
295 
296   /// Query functions for the assembly format of the operator.
297   bool hasAssemblyFormat() const;
298   StringRef getAssemblyFormat() const;
299 
300   /// Returns this op's extra class declaration code.
301   StringRef getExtraClassDeclaration() const;
302 
303   /// Returns this op's extra class definition code.
304   StringRef getExtraClassDefinition() const;
305 
306   /// Returns the Tablegen definition this operator was constructed from.
307   /// TODO: do not expose the TableGen record, this is a temporary solution to
308   /// OpEmitter requiring a Record because Operator does not provide enough
309   /// methods.
310   const llvm::Record &getDef() const;
311 
312   /// Returns the dialect of the op.
313   const Dialect &getDialect() const { return dialect; }
314 
315   /// Prints the contents in this operator to the given `os`. This is used for
316   /// debugging purposes.
317   void print(llvm::raw_ostream &os) const;
318 
319   /// Return whether all the result types are known.
320   bool allResultTypesKnown() const { return allResultsHaveKnownTypes; };
321 
322   ///  Return all arguments or type constraints with same type as result[index].
323   /// Requires: all result types are known.
324   const InferredResultType &getInferredResultType(int index) const;
325 
326   /// Pair consisting kind of argument and index into operands or attributes.
327   struct OperandOrAttribute {
328     enum class Kind { Operand, Attribute };
329     OperandOrAttribute(Kind kind, int index) {
330       packed = (index << 1) | (kind == Kind::Attribute);
331     }
332     int operandOrAttributeIndex() const { return (packed >> 1); }
333     Kind kind() { return (packed & 0x1) ? Kind::Attribute : Kind::Operand; }
334 
335   private:
336     int packed;
337   };
338 
339   /// Returns the OperandOrAttribute corresponding to the index.
340   OperandOrAttribute getArgToOperandOrAttribute(int index) const;
341 
342   /// Returns the builders of this operation.
343   ArrayRef<Builder> getBuilders() const { return builders; }
344 
345   /// Returns the getter name for the accessor of `name`.
346   std::string getGetterName(StringRef name) const;
347 
348   /// Returns the setter name for the accessor of `name`.
349   std::string getSetterName(StringRef name) const;
350 
351   /// Returns the remove name for the accessor of `name`.
352   std::string getRemoverName(StringRef name) const;
353 
354   bool hasFolder() const;
355 
356   /// Whether to generate the `readProperty`/`writeProperty` methods for
357   /// bytecode emission.
358   bool useCustomPropertiesEncoding() const;
359 
360 private:
361   /// Populates the vectors containing operands, attributes, results and traits.
362   void populateOpStructure();
363 
364   /// Populates type inference info (mostly equality) with input a mapping from
365   /// names to indices for arguments and results.
366   void populateTypeInferenceInfo(
367       const llvm::StringMap<int> &argumentsAndResultsIndex);
368 
369   /// The dialect of this op.
370   Dialect dialect;
371 
372   /// The unqualified C++ class name of the op.
373   StringRef cppClassName;
374 
375   /// The C++ namespace for this op.
376   StringRef cppNamespace;
377 
378   /// The operands of the op.
379   SmallVector<NamedTypeConstraint, 4> operands;
380 
381   /// The attributes of the op.  Contains native attributes (corresponding to
382   /// the actual stored attributed of the operation) followed by derived
383   /// attributes (corresponding to dynamic properties of the operation that are
384   /// computed upon request).
385   SmallVector<NamedAttribute, 4> attributes;
386 
387   /// The properties of the op.
388   SmallVector<NamedProperty, 4> properties;
389 
390   /// The arguments of the op (operands and native attributes).
391   SmallVector<Argument, 4> arguments;
392 
393   /// The results of the op.
394   SmallVector<NamedTypeConstraint, 4> results;
395 
396   /// The successors of this op.
397   SmallVector<NamedSuccessor, 0> successors;
398 
399   /// The traits of the op.
400   SmallVector<Trait, 4> traits;
401 
402   /// The regions of this op.
403   SmallVector<NamedRegion, 1> regions;
404 
405   /// The argument with the same type as the result.
406   SmallVector<InferredResultType> resultTypeMapping;
407 
408   /// Map from argument to attribute or operand number.
409   SmallVector<OperandOrAttribute, 4> attrOrOperandMapping;
410 
411   /// The builders of this operator.
412   SmallVector<Builder> builders;
413 
414   /// The number of native attributes stored in the leading positions of
415   /// `attributes`.
416   int numNativeAttributes;
417 
418   /// The TableGen definition of this op.
419   const llvm::Record &def;
420 
421   /// Whether the type of all results are known.
422   bool allResultsHaveKnownTypes;
423 };
424 
425 } // namespace tblgen
426 } // namespace mlir
427 
428 #endif // MLIR_TABLEGEN_OPERATOR_H_
429