xref: /llvm-project/mlir/include/mlir/IR/Dialect.h (revision 9ce8f4b70b31b031ac9b4818a268bfc8c67a7a8e)
1 //===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the 'dialect' abstraction.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_DIALECT_H
14 #define MLIR_IR_DIALECT_H
15 
16 #include "mlir/IR/DialectRegistry.h"
17 #include "mlir/IR/OperationSupport.h"
18 #include "mlir/Support/TypeID.h"
19 
20 namespace mlir {
21 class DialectAsmParser;
22 class DialectAsmPrinter;
23 class DialectInterface;
24 class OpBuilder;
25 class Type;
26 
27 //===----------------------------------------------------------------------===//
28 // Dialect
29 //===----------------------------------------------------------------------===//
30 
31 /// Dialects are groups of MLIR operations, types and attributes, as well as
32 /// behavior associated with the entire group.  For example, hooks into other
33 /// systems for constant folding, interfaces, default named types for asm
34 /// printing, etc.
35 ///
36 /// Instances of the dialect object are loaded in a specific MLIRContext.
37 ///
38 class Dialect {
39 public:
40   /// Type for a callback provided by the dialect to parse a custom operation.
41   /// This is used for the dialect to provide an alternative way to parse custom
42   /// operations, including unregistered ones.
43   using ParseOpHook =
44       function_ref<ParseResult(OpAsmParser &parser, OperationState &result)>;
45 
46   virtual ~Dialect();
47 
48   /// Utility function that returns if the given string is a valid dialect
49   /// namespace
50   static bool isValidNamespace(StringRef str);
51 
52   MLIRContext *getContext() const { return context; }
53 
54   StringRef getNamespace() const { return name; }
55 
56   /// Returns the unique identifier that corresponds to this dialect.
57   TypeID getTypeID() const { return dialectID; }
58 
59   /// Returns true if this dialect allows for unregistered operations, i.e.
60   /// operations prefixed with the dialect namespace but not registered with
61   /// addOperation.
62   bool allowsUnknownOperations() const { return unknownOpsAllowed; }
63 
64   /// Return true if this dialect allows for unregistered types, i.e., types
65   /// prefixed with the dialect namespace but not registered with addType.
66   /// These are represented with OpaqueType.
67   bool allowsUnknownTypes() const { return unknownTypesAllowed; }
68 
69   /// Register dialect-wide canonicalization patterns. This method should only
70   /// be used to register canonicalization patterns that do not conceptually
71   /// belong to any single operation in the dialect. (In that case, use the op's
72   /// canonicalizer.) E.g., canonicalization patterns for op interfaces should
73   /// be registered here.
74   virtual void getCanonicalizationPatterns(RewritePatternSet &results) const {}
75 
76   /// Registered hook to materialize a single constant operation from a given
77   /// attribute value with the desired resultant type. This method should use
78   /// the provided builder to create the operation without changing the
79   /// insertion position. The generated operation is expected to be constant
80   /// like, i.e. single result, zero operands, non side-effecting, etc. On
81   /// success, this hook should return the value generated to represent the
82   /// constant value. Otherwise, it should return null on failure.
83   virtual Operation *materializeConstant(OpBuilder &builder, Attribute value,
84                                          Type type, Location loc) {
85     return nullptr;
86   }
87 
88   //===--------------------------------------------------------------------===//
89   // Parsing Hooks
90   //===--------------------------------------------------------------------===//
91 
92   /// Parse an attribute registered to this dialect. If 'type' is nonnull, it
93   /// refers to the expected type of the attribute.
94   virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const;
95 
96   /// Print an attribute registered to this dialect. Note: The type of the
97   /// attribute need not be printed by this method as it is always printed by
98   /// the caller.
99   virtual void printAttribute(Attribute, DialectAsmPrinter &) const {
100     llvm_unreachable("dialect has no registered attribute printing hook");
101   }
102 
103   /// Parse a type registered to this dialect.
104   virtual Type parseType(DialectAsmParser &parser) const;
105 
106   /// Print a type registered to this dialect.
107   virtual void printType(Type, DialectAsmPrinter &) const {
108     llvm_unreachable("dialect has no registered type printing hook");
109   }
110 
111   /// Return the hook to parse an operation registered to this dialect, if any.
112   /// By default this will lookup for registered operations and return the
113   /// `parse()` method registered on the RegisteredOperationName. Dialects can
114   /// override this behavior and handle unregistered operations as well.
115   virtual std::optional<ParseOpHook>
116   getParseOperationHook(StringRef opName) const;
117 
118   /// Print an operation registered to this dialect.
119   /// This hook is invoked for registered operation which don't override the
120   /// `print()` method to define their own custom assembly.
121   virtual llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
122   getOperationPrinter(Operation *op) const;
123 
124   //===--------------------------------------------------------------------===//
125   // Verification Hooks
126   //===--------------------------------------------------------------------===//
127 
128   /// Verify an attribute from this dialect on the argument at 'argIndex' for
129   /// the region at 'regionIndex' on the given operation. Returns failure if
130   /// the verification failed, success otherwise. This hook may optionally be
131   /// invoked from any operation containing a region.
132   virtual LogicalResult verifyRegionArgAttribute(Operation *,
133                                                  unsigned regionIndex,
134                                                  unsigned argIndex,
135                                                  NamedAttribute);
136 
137   /// Verify an attribute from this dialect on the result at 'resultIndex' for
138   /// the region at 'regionIndex' on the given operation. Returns failure if
139   /// the verification failed, success otherwise. This hook may optionally be
140   /// invoked from any operation containing a region.
141   virtual LogicalResult verifyRegionResultAttribute(Operation *,
142                                                     unsigned regionIndex,
143                                                     unsigned resultIndex,
144                                                     NamedAttribute);
145 
146   /// Verify an attribute from this dialect on the given operation. Returns
147   /// failure if the verification failed, success otherwise.
148   virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute) {
149     return success();
150   }
151 
152   //===--------------------------------------------------------------------===//
153   // Interfaces
154   //===--------------------------------------------------------------------===//
155 
156   /// Lookup an interface for the given ID if one is registered, otherwise
157   /// nullptr.
158   DialectInterface *getRegisteredInterface(TypeID interfaceID) {
159 #ifndef NDEBUG
160     handleUseOfUndefinedPromisedInterface(getTypeID(), interfaceID);
161 #endif
162 
163     auto it = registeredInterfaces.find(interfaceID);
164     return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
165   }
166   template <typename InterfaceT>
167   InterfaceT *getRegisteredInterface() {
168 #ifndef NDEBUG
169     handleUseOfUndefinedPromisedInterface(getTypeID(),
170                                           InterfaceT::getInterfaceID(),
171                                           llvm::getTypeName<InterfaceT>());
172 #endif
173 
174     return static_cast<InterfaceT *>(
175         getRegisteredInterface(InterfaceT::getInterfaceID()));
176   }
177 
178   /// Lookup an op interface for the given ID if one is registered, otherwise
179   /// nullptr.
180   virtual void *getRegisteredInterfaceForOp(TypeID interfaceID,
181                                             OperationName opName) {
182     return nullptr;
183   }
184   template <typename InterfaceT>
185   typename InterfaceT::Concept *
186   getRegisteredInterfaceForOp(OperationName opName) {
187     return static_cast<typename InterfaceT::Concept *>(
188         getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), opName));
189   }
190 
191   /// Register a dialect interface with this dialect instance.
192   void addInterface(std::unique_ptr<DialectInterface> interface);
193 
194   /// Register a set of dialect interfaces with this dialect instance.
195   template <typename... Args>
196   void addInterfaces() {
197     (addInterface(std::make_unique<Args>(this)), ...);
198   }
199   template <typename InterfaceT, typename... Args>
200   InterfaceT &addInterface(Args &&...args) {
201     InterfaceT *interface = new InterfaceT(this, std::forward<Args>(args)...);
202     addInterface(std::unique_ptr<DialectInterface>(interface));
203     return *interface;
204   }
205 
206   /// Declare that the given interface will be implemented, but has a delayed
207   /// registration. The promised interface type can be an interface of any type
208   /// not just a dialect interface, i.e. it may also be an
209   /// AttributeInterface/OpInterface/TypeInterface/etc.
210   template <typename InterfaceT, typename ConcreteT>
211   void declarePromisedInterface() {
212     unresolvedPromisedInterfaces.insert(
213         {TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()});
214   }
215 
216   // Declare the same interface for multiple types.
217   // Example:
218   // declarePromisedInterfaces<FunctionOpInterface, MyFuncType1, MyFuncType2>()
219   template <typename InterfaceT, typename... ConcreteT>
220   void declarePromisedInterfaces() {
221     (declarePromisedInterface<InterfaceT, ConcreteT>(), ...);
222   }
223 
224   /// Checks if the given interface, which is attempting to be used, is a
225   /// promised interface of this dialect that has yet to be implemented. If so,
226   /// emits a fatal error. `interfaceName` is an optional string that contains a
227   /// more user readable name for the interface (such as the class name).
228   void handleUseOfUndefinedPromisedInterface(TypeID interfaceRequestorID,
229                                              TypeID interfaceID,
230                                              StringRef interfaceName = "") {
231     if (unresolvedPromisedInterfaces.count(
232             {interfaceRequestorID, interfaceID})) {
233       llvm::report_fatal_error(
234           "checking for an interface (`" + interfaceName +
235           "`) that was promised by dialect '" + getNamespace() +
236           "' but never implemented. This is generally an indication "
237           "that the dialect extension implementing the interface was never "
238           "registered.");
239     }
240   }
241 
242   /// Checks if the given interface, which is attempting to be attached to a
243   /// construct owned by this dialect, is a promised interface of this dialect
244   /// that has yet to be implemented. If so, it resolves the interface promise.
245   void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceRequestorID,
246                                                   TypeID interfaceID) {
247     unresolvedPromisedInterfaces.erase({interfaceRequestorID, interfaceID});
248   }
249 
250   /// Checks if a promise has been made for the interface/requestor pair.
251   bool hasPromisedInterface(TypeID interfaceRequestorID,
252                             TypeID interfaceID) const {
253     return unresolvedPromisedInterfaces.count(
254         {interfaceRequestorID, interfaceID});
255   }
256 
257   /// Checks if a promise has been made for the interface/requestor pair.
258   template <typename ConcreteT, typename InterfaceT>
259   bool hasPromisedInterface() const {
260     return hasPromisedInterface(TypeID::get<ConcreteT>(),
261                                 InterfaceT::getInterfaceID());
262   }
263 
264 protected:
265   /// The constructor takes a unique namespace for this dialect as well as the
266   /// context to bind to.
267   /// Note: The namespace must not contain '.' characters.
268   /// Note: All operations belonging to this dialect must have names starting
269   ///       with the namespace followed by '.'.
270   /// Example:
271   ///       - "tf" for the TensorFlow ops like "tf.add".
272   Dialect(StringRef name, MLIRContext *context, TypeID id);
273 
274   /// This method is used by derived classes to add their operations to the set.
275   ///
276   template <typename... Args>
277   void addOperations() {
278     // This initializer_list argument pack expansion is essentially equal to
279     // using a fold expression with a comma operator. Clang however, refuses
280     // to compile a fold expression with a depth of more than 256 by default.
281     // There seem to be no such limitations for initializer_list.
282     (void)std::initializer_list<int>{
283         0, (RegisteredOperationName::insert<Args>(*this), 0)...};
284   }
285 
286   /// Register a set of type classes with this dialect.
287   template <typename... Args>
288   void addTypes() {
289     // This initializer_list argument pack expansion is essentially equal to
290     // using a fold expression with a comma operator. Clang however, refuses
291     // to compile a fold expression with a depth of more than 256 by default.
292     // There seem to be no such limitations for initializer_list.
293     (void)std::initializer_list<int>{0, (addType<Args>(), 0)...};
294   }
295 
296   /// Register a type instance with this dialect.
297   /// The use of this method is in general discouraged in favor of
298   /// 'addTypes<CustomType>()'.
299   void addType(TypeID typeID, AbstractType &&typeInfo);
300 
301   /// Register a set of attribute classes with this dialect.
302   template <typename... Args>
303   void addAttributes() {
304     // This initializer_list argument pack expansion is essentially equal to
305     // using a fold expression with a comma operator. Clang however, refuses
306     // to compile a fold expression with a depth of more than 256 by default.
307     // There seem to be no such limitations for initializer_list.
308     (void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...};
309   }
310 
311   /// Register an attribute instance with this dialect.
312   /// The use of this method is in general discouraged in favor of
313   /// 'addAttributes<CustomAttr>()'.
314   void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
315 
316   /// Enable support for unregistered operations.
317   void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
318 
319   /// Enable support for unregistered types.
320   void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; }
321 
322 private:
323   Dialect(const Dialect &) = delete;
324   void operator=(Dialect &) = delete;
325 
326   /// Register an attribute instance with this dialect.
327   template <typename T>
328   void addAttribute() {
329     // Add this attribute to the dialect and register it with the uniquer.
330     addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this));
331     detail::AttributeUniquer::registerAttribute<T>(context);
332   }
333 
334   /// Register a type instance with this dialect.
335   template <typename T>
336   void addType() {
337     // Add this type to the dialect and register it with the uniquer.
338     addType(T::getTypeID(), AbstractType::get<T>(*this));
339     detail::TypeUniquer::registerType<T>(context);
340   }
341 
342   /// The namespace of this dialect.
343   StringRef name;
344 
345   /// The unique identifier of the derived Op class, this is used in the context
346   /// to allow registering multiple times the same dialect.
347   TypeID dialectID;
348 
349   /// This is the context that owns this Dialect object.
350   MLIRContext *context;
351 
352   /// Flag that specifies whether this dialect supports unregistered operations,
353   /// i.e. operations prefixed with the dialect namespace but not registered
354   /// with addOperation.
355   bool unknownOpsAllowed = false;
356 
357   /// Flag that specifies whether this dialect allows unregistered types, i.e.
358   /// types prefixed with the dialect namespace but not registered with addType.
359   /// These types are represented with OpaqueType.
360   bool unknownTypesAllowed = false;
361 
362   /// A collection of registered dialect interfaces.
363   DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces;
364 
365   /// A set of interfaces that the dialect (or its constructs, i.e.
366   /// Attributes/Operations/Types/etc.) has promised to implement, but has yet
367   /// to provide an implementation for.
368   DenseSet<std::pair<TypeID, TypeID>> unresolvedPromisedInterfaces;
369 
370   friend class DialectRegistry;
371   friend class MLIRContext;
372 };
373 
374 } // namespace mlir
375 
376 namespace llvm {
377 /// Provide isa functionality for Dialects.
378 template <typename T>
379 struct isa_impl<T, ::mlir::Dialect,
380                 std::enable_if_t<std::is_base_of<::mlir::Dialect, T>::value>> {
381   static inline bool doit(const ::mlir::Dialect &dialect) {
382     return mlir::TypeID::get<T>() == dialect.getTypeID();
383   }
384 };
385 template <typename T>
386 struct isa_impl<
387     T, ::mlir::Dialect,
388     std::enable_if_t<std::is_base_of<::mlir::DialectInterface, T>::value>> {
389   static inline bool doit(const ::mlir::Dialect &dialect) {
390     return const_cast<::mlir::Dialect &>(dialect).getRegisteredInterface<T>();
391   }
392 };
393 template <typename T>
394 struct cast_retty_impl<T, ::mlir::Dialect *> {
395   using ret_type = T *;
396 };
397 template <typename T>
398 struct cast_retty_impl<T, ::mlir::Dialect> {
399   using ret_type = T &;
400 };
401 
402 template <typename T>
403 struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> {
404   template <typename To>
405   static std::enable_if_t<std::is_base_of<::mlir::Dialect, To>::value, To &>
406   doitImpl(::mlir::Dialect &dialect) {
407     return static_cast<To &>(dialect);
408   }
409   template <typename To>
410   static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value,
411                           To &>
412   doitImpl(::mlir::Dialect &dialect) {
413     return *dialect.getRegisteredInterface<To>();
414   }
415 
416   static auto &doit(::mlir::Dialect &dialect) { return doitImpl<T>(dialect); }
417 };
418 template <class T>
419 struct cast_convert_val<T, ::mlir::Dialect *, ::mlir::Dialect *> {
420   static auto doit(::mlir::Dialect *dialect) {
421     return &cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect>::doit(
422         *dialect);
423   }
424 };
425 
426 } // namespace llvm
427 
428 #endif
429