1 //===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the 'dialect' abstraction. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_IR_DIALECT_H 14 #define MLIR_IR_DIALECT_H 15 16 #include "mlir/IR/DialectRegistry.h" 17 #include "mlir/IR/OperationSupport.h" 18 #include "mlir/Support/TypeID.h" 19 20 namespace mlir { 21 class DialectAsmParser; 22 class DialectAsmPrinter; 23 class DialectInterface; 24 class OpBuilder; 25 class Type; 26 27 //===----------------------------------------------------------------------===// 28 // Dialect 29 //===----------------------------------------------------------------------===// 30 31 /// Dialects are groups of MLIR operations, types and attributes, as well as 32 /// behavior associated with the entire group. For example, hooks into other 33 /// systems for constant folding, interfaces, default named types for asm 34 /// printing, etc. 35 /// 36 /// Instances of the dialect object are loaded in a specific MLIRContext. 37 /// 38 class Dialect { 39 public: 40 /// Type for a callback provided by the dialect to parse a custom operation. 41 /// This is used for the dialect to provide an alternative way to parse custom 42 /// operations, including unregistered ones. 43 using ParseOpHook = 44 function_ref<ParseResult(OpAsmParser &parser, OperationState &result)>; 45 46 virtual ~Dialect(); 47 48 /// Utility function that returns if the given string is a valid dialect 49 /// namespace 50 static bool isValidNamespace(StringRef str); 51 52 MLIRContext *getContext() const { return context; } 53 54 StringRef getNamespace() const { return name; } 55 56 /// Returns the unique identifier that corresponds to this dialect. 57 TypeID getTypeID() const { return dialectID; } 58 59 /// Returns true if this dialect allows for unregistered operations, i.e. 60 /// operations prefixed with the dialect namespace but not registered with 61 /// addOperation. 62 bool allowsUnknownOperations() const { return unknownOpsAllowed; } 63 64 /// Return true if this dialect allows for unregistered types, i.e., types 65 /// prefixed with the dialect namespace but not registered with addType. 66 /// These are represented with OpaqueType. 67 bool allowsUnknownTypes() const { return unknownTypesAllowed; } 68 69 /// Register dialect-wide canonicalization patterns. This method should only 70 /// be used to register canonicalization patterns that do not conceptually 71 /// belong to any single operation in the dialect. (In that case, use the op's 72 /// canonicalizer.) E.g., canonicalization patterns for op interfaces should 73 /// be registered here. 74 virtual void getCanonicalizationPatterns(RewritePatternSet &results) const {} 75 76 /// Registered hook to materialize a single constant operation from a given 77 /// attribute value with the desired resultant type. This method should use 78 /// the provided builder to create the operation without changing the 79 /// insertion position. The generated operation is expected to be constant 80 /// like, i.e. single result, zero operands, non side-effecting, etc. On 81 /// success, this hook should return the value generated to represent the 82 /// constant value. Otherwise, it should return null on failure. 83 virtual Operation *materializeConstant(OpBuilder &builder, Attribute value, 84 Type type, Location loc) { 85 return nullptr; 86 } 87 88 //===--------------------------------------------------------------------===// 89 // Parsing Hooks 90 //===--------------------------------------------------------------------===// 91 92 /// Parse an attribute registered to this dialect. If 'type' is nonnull, it 93 /// refers to the expected type of the attribute. 94 virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const; 95 96 /// Print an attribute registered to this dialect. Note: The type of the 97 /// attribute need not be printed by this method as it is always printed by 98 /// the caller. 99 virtual void printAttribute(Attribute, DialectAsmPrinter &) const { 100 llvm_unreachable("dialect has no registered attribute printing hook"); 101 } 102 103 /// Parse a type registered to this dialect. 104 virtual Type parseType(DialectAsmParser &parser) const; 105 106 /// Print a type registered to this dialect. 107 virtual void printType(Type, DialectAsmPrinter &) const { 108 llvm_unreachable("dialect has no registered type printing hook"); 109 } 110 111 /// Return the hook to parse an operation registered to this dialect, if any. 112 /// By default this will lookup for registered operations and return the 113 /// `parse()` method registered on the RegisteredOperationName. Dialects can 114 /// override this behavior and handle unregistered operations as well. 115 virtual std::optional<ParseOpHook> 116 getParseOperationHook(StringRef opName) const; 117 118 /// Print an operation registered to this dialect. 119 /// This hook is invoked for registered operation which don't override the 120 /// `print()` method to define their own custom assembly. 121 virtual llvm::unique_function<void(Operation *, OpAsmPrinter &printer)> 122 getOperationPrinter(Operation *op) const; 123 124 //===--------------------------------------------------------------------===// 125 // Verification Hooks 126 //===--------------------------------------------------------------------===// 127 128 /// Verify an attribute from this dialect on the argument at 'argIndex' for 129 /// the region at 'regionIndex' on the given operation. Returns failure if 130 /// the verification failed, success otherwise. This hook may optionally be 131 /// invoked from any operation containing a region. 132 virtual LogicalResult verifyRegionArgAttribute(Operation *, 133 unsigned regionIndex, 134 unsigned argIndex, 135 NamedAttribute); 136 137 /// Verify an attribute from this dialect on the result at 'resultIndex' for 138 /// the region at 'regionIndex' on the given operation. Returns failure if 139 /// the verification failed, success otherwise. This hook may optionally be 140 /// invoked from any operation containing a region. 141 virtual LogicalResult verifyRegionResultAttribute(Operation *, 142 unsigned regionIndex, 143 unsigned resultIndex, 144 NamedAttribute); 145 146 /// Verify an attribute from this dialect on the given operation. Returns 147 /// failure if the verification failed, success otherwise. 148 virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute) { 149 return success(); 150 } 151 152 //===--------------------------------------------------------------------===// 153 // Interfaces 154 //===--------------------------------------------------------------------===// 155 156 /// Lookup an interface for the given ID if one is registered, otherwise 157 /// nullptr. 158 DialectInterface *getRegisteredInterface(TypeID interfaceID) { 159 #ifndef NDEBUG 160 handleUseOfUndefinedPromisedInterface(getTypeID(), interfaceID); 161 #endif 162 163 auto it = registeredInterfaces.find(interfaceID); 164 return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr; 165 } 166 template <typename InterfaceT> 167 InterfaceT *getRegisteredInterface() { 168 #ifndef NDEBUG 169 handleUseOfUndefinedPromisedInterface(getTypeID(), 170 InterfaceT::getInterfaceID(), 171 llvm::getTypeName<InterfaceT>()); 172 #endif 173 174 return static_cast<InterfaceT *>( 175 getRegisteredInterface(InterfaceT::getInterfaceID())); 176 } 177 178 /// Lookup an op interface for the given ID if one is registered, otherwise 179 /// nullptr. 180 virtual void *getRegisteredInterfaceForOp(TypeID interfaceID, 181 OperationName opName) { 182 return nullptr; 183 } 184 template <typename InterfaceT> 185 typename InterfaceT::Concept * 186 getRegisteredInterfaceForOp(OperationName opName) { 187 return static_cast<typename InterfaceT::Concept *>( 188 getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), opName)); 189 } 190 191 /// Register a dialect interface with this dialect instance. 192 void addInterface(std::unique_ptr<DialectInterface> interface); 193 194 /// Register a set of dialect interfaces with this dialect instance. 195 template <typename... Args> 196 void addInterfaces() { 197 (addInterface(std::make_unique<Args>(this)), ...); 198 } 199 template <typename InterfaceT, typename... Args> 200 InterfaceT &addInterface(Args &&...args) { 201 InterfaceT *interface = new InterfaceT(this, std::forward<Args>(args)...); 202 addInterface(std::unique_ptr<DialectInterface>(interface)); 203 return *interface; 204 } 205 206 /// Declare that the given interface will be implemented, but has a delayed 207 /// registration. The promised interface type can be an interface of any type 208 /// not just a dialect interface, i.e. it may also be an 209 /// AttributeInterface/OpInterface/TypeInterface/etc. 210 template <typename InterfaceT, typename ConcreteT> 211 void declarePromisedInterface() { 212 unresolvedPromisedInterfaces.insert( 213 {TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()}); 214 } 215 216 // Declare the same interface for multiple types. 217 // Example: 218 // declarePromisedInterfaces<FunctionOpInterface, MyFuncType1, MyFuncType2>() 219 template <typename InterfaceT, typename... ConcreteT> 220 void declarePromisedInterfaces() { 221 (declarePromisedInterface<InterfaceT, ConcreteT>(), ...); 222 } 223 224 /// Checks if the given interface, which is attempting to be used, is a 225 /// promised interface of this dialect that has yet to be implemented. If so, 226 /// emits a fatal error. `interfaceName` is an optional string that contains a 227 /// more user readable name for the interface (such as the class name). 228 void handleUseOfUndefinedPromisedInterface(TypeID interfaceRequestorID, 229 TypeID interfaceID, 230 StringRef interfaceName = "") { 231 if (unresolvedPromisedInterfaces.count( 232 {interfaceRequestorID, interfaceID})) { 233 llvm::report_fatal_error( 234 "checking for an interface (`" + interfaceName + 235 "`) that was promised by dialect '" + getNamespace() + 236 "' but never implemented. This is generally an indication " 237 "that the dialect extension implementing the interface was never " 238 "registered."); 239 } 240 } 241 242 /// Checks if the given interface, which is attempting to be attached to a 243 /// construct owned by this dialect, is a promised interface of this dialect 244 /// that has yet to be implemented. If so, it resolves the interface promise. 245 void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceRequestorID, 246 TypeID interfaceID) { 247 unresolvedPromisedInterfaces.erase({interfaceRequestorID, interfaceID}); 248 } 249 250 /// Checks if a promise has been made for the interface/requestor pair. 251 bool hasPromisedInterface(TypeID interfaceRequestorID, 252 TypeID interfaceID) const { 253 return unresolvedPromisedInterfaces.count( 254 {interfaceRequestorID, interfaceID}); 255 } 256 257 /// Checks if a promise has been made for the interface/requestor pair. 258 template <typename ConcreteT, typename InterfaceT> 259 bool hasPromisedInterface() const { 260 return hasPromisedInterface(TypeID::get<ConcreteT>(), 261 InterfaceT::getInterfaceID()); 262 } 263 264 protected: 265 /// The constructor takes a unique namespace for this dialect as well as the 266 /// context to bind to. 267 /// Note: The namespace must not contain '.' characters. 268 /// Note: All operations belonging to this dialect must have names starting 269 /// with the namespace followed by '.'. 270 /// Example: 271 /// - "tf" for the TensorFlow ops like "tf.add". 272 Dialect(StringRef name, MLIRContext *context, TypeID id); 273 274 /// This method is used by derived classes to add their operations to the set. 275 /// 276 template <typename... Args> 277 void addOperations() { 278 // This initializer_list argument pack expansion is essentially equal to 279 // using a fold expression with a comma operator. Clang however, refuses 280 // to compile a fold expression with a depth of more than 256 by default. 281 // There seem to be no such limitations for initializer_list. 282 (void)std::initializer_list<int>{ 283 0, (RegisteredOperationName::insert<Args>(*this), 0)...}; 284 } 285 286 /// Register a set of type classes with this dialect. 287 template <typename... Args> 288 void addTypes() { 289 // This initializer_list argument pack expansion is essentially equal to 290 // using a fold expression with a comma operator. Clang however, refuses 291 // to compile a fold expression with a depth of more than 256 by default. 292 // There seem to be no such limitations for initializer_list. 293 (void)std::initializer_list<int>{0, (addType<Args>(), 0)...}; 294 } 295 296 /// Register a type instance with this dialect. 297 /// The use of this method is in general discouraged in favor of 298 /// 'addTypes<CustomType>()'. 299 void addType(TypeID typeID, AbstractType &&typeInfo); 300 301 /// Register a set of attribute classes with this dialect. 302 template <typename... Args> 303 void addAttributes() { 304 // This initializer_list argument pack expansion is essentially equal to 305 // using a fold expression with a comma operator. Clang however, refuses 306 // to compile a fold expression with a depth of more than 256 by default. 307 // There seem to be no such limitations for initializer_list. 308 (void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...}; 309 } 310 311 /// Register an attribute instance with this dialect. 312 /// The use of this method is in general discouraged in favor of 313 /// 'addAttributes<CustomAttr>()'. 314 void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo); 315 316 /// Enable support for unregistered operations. 317 void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; } 318 319 /// Enable support for unregistered types. 320 void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; } 321 322 private: 323 Dialect(const Dialect &) = delete; 324 void operator=(Dialect &) = delete; 325 326 /// Register an attribute instance with this dialect. 327 template <typename T> 328 void addAttribute() { 329 // Add this attribute to the dialect and register it with the uniquer. 330 addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this)); 331 detail::AttributeUniquer::registerAttribute<T>(context); 332 } 333 334 /// Register a type instance with this dialect. 335 template <typename T> 336 void addType() { 337 // Add this type to the dialect and register it with the uniquer. 338 addType(T::getTypeID(), AbstractType::get<T>(*this)); 339 detail::TypeUniquer::registerType<T>(context); 340 } 341 342 /// The namespace of this dialect. 343 StringRef name; 344 345 /// The unique identifier of the derived Op class, this is used in the context 346 /// to allow registering multiple times the same dialect. 347 TypeID dialectID; 348 349 /// This is the context that owns this Dialect object. 350 MLIRContext *context; 351 352 /// Flag that specifies whether this dialect supports unregistered operations, 353 /// i.e. operations prefixed with the dialect namespace but not registered 354 /// with addOperation. 355 bool unknownOpsAllowed = false; 356 357 /// Flag that specifies whether this dialect allows unregistered types, i.e. 358 /// types prefixed with the dialect namespace but not registered with addType. 359 /// These types are represented with OpaqueType. 360 bool unknownTypesAllowed = false; 361 362 /// A collection of registered dialect interfaces. 363 DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces; 364 365 /// A set of interfaces that the dialect (or its constructs, i.e. 366 /// Attributes/Operations/Types/etc.) has promised to implement, but has yet 367 /// to provide an implementation for. 368 DenseSet<std::pair<TypeID, TypeID>> unresolvedPromisedInterfaces; 369 370 friend class DialectRegistry; 371 friend class MLIRContext; 372 }; 373 374 } // namespace mlir 375 376 namespace llvm { 377 /// Provide isa functionality for Dialects. 378 template <typename T> 379 struct isa_impl<T, ::mlir::Dialect, 380 std::enable_if_t<std::is_base_of<::mlir::Dialect, T>::value>> { 381 static inline bool doit(const ::mlir::Dialect &dialect) { 382 return mlir::TypeID::get<T>() == dialect.getTypeID(); 383 } 384 }; 385 template <typename T> 386 struct isa_impl< 387 T, ::mlir::Dialect, 388 std::enable_if_t<std::is_base_of<::mlir::DialectInterface, T>::value>> { 389 static inline bool doit(const ::mlir::Dialect &dialect) { 390 return const_cast<::mlir::Dialect &>(dialect).getRegisteredInterface<T>(); 391 } 392 }; 393 template <typename T> 394 struct cast_retty_impl<T, ::mlir::Dialect *> { 395 using ret_type = T *; 396 }; 397 template <typename T> 398 struct cast_retty_impl<T, ::mlir::Dialect> { 399 using ret_type = T &; 400 }; 401 402 template <typename T> 403 struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> { 404 template <typename To> 405 static std::enable_if_t<std::is_base_of<::mlir::Dialect, To>::value, To &> 406 doitImpl(::mlir::Dialect &dialect) { 407 return static_cast<To &>(dialect); 408 } 409 template <typename To> 410 static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value, 411 To &> 412 doitImpl(::mlir::Dialect &dialect) { 413 return *dialect.getRegisteredInterface<To>(); 414 } 415 416 static auto &doit(::mlir::Dialect &dialect) { return doitImpl<T>(dialect); } 417 }; 418 template <class T> 419 struct cast_convert_val<T, ::mlir::Dialect *, ::mlir::Dialect *> { 420 static auto doit(::mlir::Dialect *dialect) { 421 return &cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect>::doit( 422 *dialect); 423 } 424 }; 425 426 } // namespace llvm 427 428 #endif 429