1 //===- TransformDialect.h - Transform Dialect Definition --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H 10 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H 11 12 #include "mlir/IR/Dialect.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/Support/LLVM.h" 15 #include "mlir/Support/TypeID.h" 16 #include "llvm/ADT/DenseMap.h" 17 #include "llvm/ADT/StringMap.h" 18 #include <optional> 19 20 namespace mlir { 21 namespace transform { 22 23 namespace detail { 24 /// Concrete base class for CRTP TransformDialectDataBase. Must not be used 25 /// directly. 26 class TransformDialectDataBase { 27 public: 28 virtual ~TransformDialectDataBase() = default; 29 30 /// Returns the dynamic type ID of the subclass. 31 TypeID getTypeID() const { return typeID; } 32 33 protected: 34 /// Must be called by the subclass with the appropriate type ID. 35 explicit TransformDialectDataBase(TypeID typeID, MLIRContext *ctx) 36 : typeID(typeID), ctx(ctx) {} 37 38 /// Return the MLIR context. 39 MLIRContext *getContext() const { return ctx; } 40 41 private: 42 /// The type ID of the subclass. 43 const TypeID typeID; 44 45 /// The MLIR context. 46 MLIRContext *ctx; 47 }; 48 } // namespace detail 49 50 /// Base class for additional data owned by the Transform dialect. Extensions 51 /// may communicate with each other using this data. The data object is 52 /// identified by the TypeID of the specific data subclass, querying the data of 53 /// the same subclass returns a reference to the same object. When a Transform 54 /// dialect extension is initialized, it can populate the data in the specific 55 /// subclass. When a Transform op is applied, it can read (but not mutate) the 56 /// data in the specific subclass, including the data provided by other 57 /// extensions. 58 /// 59 /// This follows CRTP: derived classes must list themselves as template 60 /// argument. 61 template <typename DerivedTy> 62 class TransformDialectData : public detail::TransformDialectDataBase { 63 protected: 64 /// Forward the TypeID of the derived class to the base. 65 TransformDialectData(MLIRContext *ctx) 66 : TransformDialectDataBase(TypeID::get<DerivedTy>(), ctx) {} 67 }; 68 69 #ifndef NDEBUG 70 namespace detail { 71 /// Asserts that the operations provided as template arguments implement the 72 /// TransformOpInterface and MemoryEffectsOpInterface. This must be a dynamic 73 /// assertion since interface implementations may be registered at runtime. 74 void checkImplementsTransformOpInterface(StringRef name, MLIRContext *context); 75 76 /// Asserts that the type provided as template argument implements the 77 /// TransformHandleTypeInterface. This must be a dynamic assertion since 78 /// interface implementations may be registered at runtime. 79 void checkImplementsTransformHandleTypeInterface(TypeID typeID, 80 MLIRContext *context); 81 } // namespace detail 82 #endif // NDEBUG 83 } // namespace transform 84 } // namespace mlir 85 86 #include "mlir/Dialect/Transform/IR/TransformDialect.h.inc" 87 88 namespace mlir { 89 namespace transform { 90 91 /// Base class for extensions of the Transform dialect that supports injecting 92 /// operations into the Transform dialect at load time. Concrete extensions are 93 /// expected to derive this class and register operations in the constructor. 94 /// They can be registered with the DialectRegistry and automatically applied 95 /// to the Transform dialect when it is loaded. 96 /// 97 /// Derived classes are expected to define a `void init()` function in which 98 /// they can call various protected methods of the base class to register 99 /// extension operations and declare their dependencies. 100 /// 101 /// By default, the extension is configured both for construction of the 102 /// Transform IR and for its application to some payload. If only the 103 /// construction is desired, the extension can be switched to "build-only" mode 104 /// that avoids loading the dialects that are only necessary for transforming 105 /// the payload. To perform the switch, the extension must be wrapped into the 106 /// `BuildOnly` class template (see below) when it is registered, as in: 107 /// 108 /// dialectRegistry.addExtension<BuildOnly<MyTransformDialectExt>>(); 109 /// 110 /// instead of: 111 /// 112 /// dialectRegistry.addExtension<MyTransformDialectExt>(); 113 /// 114 /// Derived classes must reexport the constructor of this class or otherwise 115 /// forward its boolean argument to support this behavior. 116 template <typename DerivedTy, typename... ExtraDialects> 117 class TransformDialectExtension 118 : public DialectExtension<DerivedTy, TransformDialect, ExtraDialects...> { 119 using Initializer = std::function<void(TransformDialect *)>; 120 using DialectLoader = std::function<void(MLIRContext *)>; 121 122 public: 123 /// Extension application hook. Actually loads the dependent dialects and 124 /// registers the additional operations. Not expected to be called directly. 125 void apply(MLIRContext *context, TransformDialect *transformDialect, 126 ExtraDialects *...) const final { 127 for (const DialectLoader &loader : dialectLoaders) 128 loader(context); 129 130 // Only load generated dialects if the user intends to apply 131 // transformations specified by the extension. 132 if (!buildOnly) 133 for (const DialectLoader &loader : generatedDialectLoaders) 134 loader(context); 135 136 for (const Initializer &init : initializers) 137 init(transformDialect); 138 } 139 140 protected: 141 using Base = TransformDialectExtension<DerivedTy, ExtraDialects...>; 142 143 /// Extension constructor. The argument indicates whether to skip generated 144 /// dialects when applying the extension. 145 explicit TransformDialectExtension(bool buildOnly = false) 146 : buildOnly(buildOnly) { 147 static_cast<DerivedTy *>(this)->init(); 148 } 149 150 /// Registers a custom initialization step to be performed when the extension 151 /// is applied to the dialect while loading. This is discouraged in favor of 152 /// more specific calls `declareGeneratedDialect`, `addDialectDataInitializer` 153 /// etc. `Func` must be convertible to the `void (MLIRContext *)` form. It 154 /// will be called during the extension initialization and given the current 155 /// MLIR context. This may be used to attach additional interfaces that cannot 156 /// be attached elsewhere. 157 template <typename Func> 158 void addCustomInitializationStep(Func &&func) { 159 std::function<void(MLIRContext *)> initializer = func; 160 dialectLoaders.push_back( 161 [init = std::move(initializer)](MLIRContext *ctx) { init(ctx); }); 162 } 163 164 /// Registers the given function as one of the initializers for the 165 /// dialect-owned data of the kind specified as template argument. The 166 /// function must be convertible to the `void (DataTy &)` form. It will be 167 /// called during the extension initialization and will be given a mutable 168 /// reference to `DataTy`. The callback is expected to append data to the 169 /// given storage, and is not allowed to remove or destructively mutate the 170 /// existing data. The order in which callbacks from different extensions are 171 /// executed is unspecified so the callbacks may not rely on data being 172 /// already present. `DataTy` must be a class deriving `TransformDialectData`. 173 template <typename DataTy, typename Func> 174 void addDialectDataInitializer(Func &&func) { 175 static_assert(std::is_base_of_v<detail::TransformDialectDataBase, DataTy>, 176 "only classes deriving TransformDialectData are accepted"); 177 178 std::function<void(DataTy &)> initializer = func; 179 initializers.push_back( 180 [init = std::move(initializer)](TransformDialect *transformDialect) { 181 init(transformDialect->getOrCreateExtraData<DataTy>()); 182 }); 183 } 184 185 /// Hook for derived classes to inject constructor behavior. 186 void init() {} 187 188 /// Injects the operations into the Transform dialect. The operations must 189 /// implement the TransformOpInterface and MemoryEffectsOpInterface, and the 190 /// implementations must be already available when the operation is injected. 191 template <typename... OpTys> 192 void registerTransformOps() { 193 initializers.push_back([](TransformDialect *transformDialect) { 194 transformDialect->addOperationsChecked<OpTys...>(); 195 }); 196 } 197 198 /// Injects the types into the Transform dialect. The types must implement 199 /// the TransformHandleTypeInterface and the implementation must be already 200 /// available when the type is injected. Furthermore, the types must provide 201 /// a `getMnemonic` static method returning an object convertible to 202 /// `StringRef` that is unique across all injected types. 203 template <typename... TypeTys> 204 void registerTypes() { 205 initializers.push_back([](TransformDialect *transformDialect) { 206 transformDialect->addTypesChecked<TypeTys...>(); 207 }); 208 } 209 210 /// Declares that this Transform dialect extension depends on the dialect 211 /// provided as template parameter. When the Transform dialect is loaded, 212 /// dependent dialects will be loaded as well. This is intended for dialects 213 /// that contain attributes and types used in creation and canonicalization of 214 /// the injected operations, similarly to how the dialect definition may list 215 /// dependent dialects. This is *not* intended for dialects entities from 216 /// which may be produced when applying the transformations specified by ops 217 /// registered by this extension. 218 template <typename DialectTy> 219 void declareDependentDialect() { 220 dialectLoaders.push_back( 221 [](MLIRContext *context) { context->loadDialect<DialectTy>(); }); 222 } 223 224 /// Declares that the transformations associated with the operations 225 /// registered by this dialect extension may produce operations from the 226 /// dialect provided as template parameter while processing payload IR that 227 /// does not contain the operations from said dialect. This is similar to 228 /// dependent dialects of a pass. These dialects will be loaded along with the 229 /// transform dialect unless the extension is in the build-only mode. 230 template <typename DialectTy> 231 void declareGeneratedDialect() { 232 generatedDialectLoaders.push_back( 233 [](MLIRContext *context) { context->loadDialect<DialectTy>(); }); 234 } 235 236 private: 237 /// Callbacks performing extension initialization, e.g., registering ops, 238 /// types and defining the additional data. 239 SmallVector<Initializer> initializers; 240 241 /// Callbacks loading the dependent dialects, i.e. the dialect needed for the 242 /// extension ops. 243 SmallVector<DialectLoader> dialectLoaders; 244 245 /// Callbacks loading the generated dialects, i.e. the dialects produced when 246 /// applying the transformations. 247 SmallVector<DialectLoader> generatedDialectLoaders; 248 249 /// Indicates that the extension is in build-only mode. 250 bool buildOnly; 251 }; 252 253 template <typename OpTy> 254 void TransformDialect::addOperationIfNotRegistered() { 255 std::optional<RegisteredOperationName> opName = 256 RegisteredOperationName::lookup(TypeID::get<OpTy>(), getContext()); 257 if (!opName) { 258 addOperations<OpTy>(); 259 #ifndef NDEBUG 260 StringRef name = OpTy::getOperationName(); 261 detail::checkImplementsTransformOpInterface(name, getContext()); 262 #endif // NDEBUG 263 return; 264 } 265 266 if (LLVM_LIKELY(opName->getTypeID() == TypeID::get<OpTy>())) 267 return; 268 269 reportDuplicateOpRegistration(OpTy::getOperationName()); 270 } 271 272 template <typename Type> 273 void TransformDialect::addTypeIfNotRegistered() { 274 // Use the address of the parse method as a proxy for identifying whether we 275 // are registering the same type class for the same mnemonic. 276 StringRef mnemonic = Type::getMnemonic(); 277 auto [it, inserted] = typeParsingHooks.try_emplace(mnemonic, Type::parse); 278 if (!inserted) { 279 const ExtensionTypeParsingHook &parsingHook = it->getValue(); 280 if (parsingHook != &Type::parse) 281 reportDuplicateTypeRegistration(mnemonic); 282 else 283 return; 284 } 285 typePrintingHooks.try_emplace( 286 TypeID::get<Type>(), +[](mlir::Type type, AsmPrinter &printer) { 287 printer << Type::getMnemonic(); 288 cast<Type>(type).print(printer); 289 }); 290 addTypes<Type>(); 291 292 #ifndef NDEBUG 293 detail::checkImplementsTransformHandleTypeInterface(TypeID::get<Type>(), 294 getContext()); 295 #endif // NDEBUG 296 } 297 298 template <typename DataTy> 299 DataTy &TransformDialect::getOrCreateExtraData() { 300 TypeID typeID = TypeID::get<DataTy>(); 301 auto [it, inserted] = extraData.try_emplace(typeID); 302 if (inserted) 303 it->getSecond() = std::make_unique<DataTy>(getContext()); 304 return static_cast<DataTy &>(*it->getSecond()); 305 } 306 307 /// A wrapper for transform dialect extensions that forces them to be 308 /// constructed in the build-only mode. 309 template <typename DerivedTy> 310 class BuildOnly : public DerivedTy { 311 public: 312 BuildOnly() : DerivedTy(/*buildOnly=*/true) {} 313 }; 314 315 } // namespace transform 316 } // namespace mlir 317 318 #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H 319