xref: /llvm-project/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h (revision 56b29074fe924243640547a9fec35bef0942b210)
1 //===- TransformDialect.h - Transform Dialect Definition --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
10 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
11 
#include "mlir/IR/Dialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h"
#include <functional>
#include <optional>
#include <utility>
19 
20 namespace mlir {
21 namespace transform {
22 
23 namespace detail {
24 /// Concrete base class for CRTP TransformDialectDataBase. Must not be used
25 /// directly.
26 class TransformDialectDataBase {
27 public:
28   virtual ~TransformDialectDataBase() = default;
29 
30   /// Returns the dynamic type ID of the subclass.
31   TypeID getTypeID() const { return typeID; }
32 
33 protected:
34   /// Must be called by the subclass with the appropriate type ID.
35   explicit TransformDialectDataBase(TypeID typeID, MLIRContext *ctx)
36       : typeID(typeID), ctx(ctx) {}
37 
38   /// Return the MLIR context.
39   MLIRContext *getContext() const { return ctx; }
40 
41 private:
42   /// The type ID of the subclass.
43   const TypeID typeID;
44 
45   /// The MLIR context.
46   MLIRContext *ctx;
47 };
48 } // namespace detail
49 
50 /// Base class for additional data owned by the Transform dialect. Extensions
51 /// may communicate with each other using this data. The data object is
52 /// identified by the TypeID of the specific data subclass, querying the data of
53 /// the same subclass returns a reference to the same object. When a Transform
54 /// dialect extension is initialized, it can populate the data in the specific
55 /// subclass. When a Transform op is applied, it can read (but not mutate) the
56 /// data in the specific subclass, including the data provided by other
57 /// extensions.
58 ///
59 /// This follows CRTP: derived classes must list themselves as template
60 /// argument.
61 template <typename DerivedTy>
62 class TransformDialectData : public detail::TransformDialectDataBase {
63 protected:
64   /// Forward the TypeID of the derived class to the base.
65   TransformDialectData(MLIRContext *ctx)
66       : TransformDialectDataBase(TypeID::get<DerivedTy>(), ctx) {}
67 };
68 
69 #ifndef NDEBUG
70 namespace detail {
71 /// Asserts that the operations provided as template arguments implement the
72 /// TransformOpInterface and MemoryEffectsOpInterface. This must be a dynamic
73 /// assertion since interface implementations may be registered at runtime.
74 void checkImplementsTransformOpInterface(StringRef name, MLIRContext *context);
75 
76 /// Asserts that the type provided as template argument implements the
77 /// TransformHandleTypeInterface. This must be a dynamic assertion since
78 /// interface implementations may be registered at runtime.
79 void checkImplementsTransformHandleTypeInterface(TypeID typeID,
80                                                  MLIRContext *context);
81 } // namespace detail
82 #endif // NDEBUG
83 } // namespace transform
84 } // namespace mlir
85 
86 #include "mlir/Dialect/Transform/IR/TransformDialect.h.inc"
87 
88 namespace mlir {
89 namespace transform {
90 
91 /// Base class for extensions of the Transform dialect that supports injecting
92 /// operations into the Transform dialect at load time. Concrete extensions are
93 /// expected to derive this class and register operations in the constructor.
94 /// They can be registered with the DialectRegistry and automatically applied
95 /// to the Transform dialect when it is loaded.
96 ///
97 /// Derived classes are expected to define a `void init()` function in which
98 /// they can call various protected methods of the base class to register
99 /// extension operations and declare their dependencies.
100 ///
101 /// By default, the extension is configured both for construction of the
102 /// Transform IR and for its application to some payload. If only the
103 /// construction is desired, the extension can be switched to "build-only" mode
104 /// that avoids loading the dialects that are only necessary for transforming
105 /// the payload. To perform the switch, the extension must be wrapped into the
106 /// `BuildOnly` class template (see below) when it is registered, as in:
107 ///
108 ///    dialectRegistry.addExtension<BuildOnly<MyTransformDialectExt>>();
109 ///
110 /// instead of:
111 ///
112 ///    dialectRegistry.addExtension<MyTransformDialectExt>();
113 ///
114 /// Derived classes must reexport the constructor of this class or otherwise
115 /// forward its boolean argument to support this behavior.
116 template <typename DerivedTy, typename... ExtraDialects>
117 class TransformDialectExtension
118     : public DialectExtension<DerivedTy, TransformDialect, ExtraDialects...> {
119   using Initializer = std::function<void(TransformDialect *)>;
120   using DialectLoader = std::function<void(MLIRContext *)>;
121 
122 public:
123   /// Extension application hook. Actually loads the dependent dialects and
124   /// registers the additional operations. Not expected to be called directly.
125   void apply(MLIRContext *context, TransformDialect *transformDialect,
126              ExtraDialects *...) const final {
127     for (const DialectLoader &loader : dialectLoaders)
128       loader(context);
129 
130     // Only load generated dialects if the user intends to apply
131     // transformations specified by the extension.
132     if (!buildOnly)
133       for (const DialectLoader &loader : generatedDialectLoaders)
134         loader(context);
135 
136     for (const Initializer &init : initializers)
137       init(transformDialect);
138   }
139 
140 protected:
141   using Base = TransformDialectExtension<DerivedTy, ExtraDialects...>;
142 
143   /// Extension constructor. The argument indicates whether to skip generated
144   /// dialects when applying the extension.
145   explicit TransformDialectExtension(bool buildOnly = false)
146       : buildOnly(buildOnly) {
147     static_cast<DerivedTy *>(this)->init();
148   }
149 
150   /// Registers a custom initialization step to be performed when the extension
151   /// is applied to the dialect while loading. This is discouraged in favor of
152   /// more specific calls `declareGeneratedDialect`, `addDialectDataInitializer`
153   /// etc. `Func` must be convertible to the `void (MLIRContext *)` form. It
154   /// will be called during the extension initialization and given the current
155   /// MLIR context. This may be used to attach additional interfaces that cannot
156   /// be attached elsewhere.
157   template <typename Func>
158   void addCustomInitializationStep(Func &&func) {
159     std::function<void(MLIRContext *)> initializer = func;
160     dialectLoaders.push_back(
161         [init = std::move(initializer)](MLIRContext *ctx) { init(ctx); });
162   }
163 
164   /// Registers the given function as one of the initializers for the
165   /// dialect-owned data of the kind specified as template argument. The
166   /// function must be convertible to the `void (DataTy &)` form. It will be
167   /// called during the extension initialization and will be given a mutable
168   /// reference to `DataTy`. The callback is expected to append data to the
169   /// given storage, and is not allowed to remove or destructively mutate the
170   /// existing data. The order in which callbacks from different extensions are
171   /// executed is unspecified so the callbacks may not rely on data being
172   /// already present. `DataTy` must be a class deriving `TransformDialectData`.
173   template <typename DataTy, typename Func>
174   void addDialectDataInitializer(Func &&func) {
175     static_assert(std::is_base_of_v<detail::TransformDialectDataBase, DataTy>,
176                   "only classes deriving TransformDialectData are accepted");
177 
178     std::function<void(DataTy &)> initializer = func;
179     initializers.push_back(
180         [init = std::move(initializer)](TransformDialect *transformDialect) {
181           init(transformDialect->getOrCreateExtraData<DataTy>());
182         });
183   }
184 
185   /// Hook for derived classes to inject constructor behavior.
186   void init() {}
187 
188   /// Injects the operations into the Transform dialect. The operations must
189   /// implement the TransformOpInterface and MemoryEffectsOpInterface, and the
190   /// implementations must be already available when the operation is injected.
191   template <typename... OpTys>
192   void registerTransformOps() {
193     initializers.push_back([](TransformDialect *transformDialect) {
194       transformDialect->addOperationsChecked<OpTys...>();
195     });
196   }
197 
198   /// Injects the types into the Transform dialect. The types must implement
199   /// the TransformHandleTypeInterface and the implementation must be already
200   /// available when the type is injected. Furthermore, the types must provide
201   /// a `getMnemonic` static method returning an object convertible to
202   /// `StringRef` that is unique across all injected types.
203   template <typename... TypeTys>
204   void registerTypes() {
205     initializers.push_back([](TransformDialect *transformDialect) {
206       transformDialect->addTypesChecked<TypeTys...>();
207     });
208   }
209 
210   /// Declares that this Transform dialect extension depends on the dialect
211   /// provided as template parameter. When the Transform dialect is loaded,
212   /// dependent dialects will be loaded as well. This is intended for dialects
213   /// that contain attributes and types used in creation and canonicalization of
214   /// the injected operations, similarly to how the dialect definition may list
215   /// dependent dialects. This is *not* intended for dialects entities from
216   /// which may be produced when applying the transformations specified by ops
217   /// registered by this extension.
218   template <typename DialectTy>
219   void declareDependentDialect() {
220     dialectLoaders.push_back(
221         [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
222   }
223 
224   /// Declares that the transformations associated with the operations
225   /// registered by this dialect extension may produce operations from the
226   /// dialect provided as template parameter while processing payload IR that
227   /// does not contain the operations from said dialect. This is similar to
228   /// dependent dialects of a pass. These dialects will be loaded along with the
229   /// transform dialect unless the extension is in the build-only mode.
230   template <typename DialectTy>
231   void declareGeneratedDialect() {
232     generatedDialectLoaders.push_back(
233         [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
234   }
235 
236 private:
237   /// Callbacks performing extension initialization, e.g., registering ops,
238   /// types and defining the additional data.
239   SmallVector<Initializer> initializers;
240 
241   /// Callbacks loading the dependent dialects, i.e. the dialect needed for the
242   /// extension ops.
243   SmallVector<DialectLoader> dialectLoaders;
244 
245   /// Callbacks loading the generated dialects, i.e. the dialects produced when
246   /// applying the transformations.
247   SmallVector<DialectLoader> generatedDialectLoaders;
248 
249   /// Indicates that the extension is in build-only mode.
250   bool buildOnly;
251 };
252 
253 template <typename OpTy>
254 void TransformDialect::addOperationIfNotRegistered() {
255   std::optional<RegisteredOperationName> opName =
256       RegisteredOperationName::lookup(TypeID::get<OpTy>(), getContext());
257   if (!opName) {
258     addOperations<OpTy>();
259 #ifndef NDEBUG
260     StringRef name = OpTy::getOperationName();
261     detail::checkImplementsTransformOpInterface(name, getContext());
262 #endif // NDEBUG
263     return;
264   }
265 
266   if (LLVM_LIKELY(opName->getTypeID() == TypeID::get<OpTy>()))
267     return;
268 
269   reportDuplicateOpRegistration(OpTy::getOperationName());
270 }
271 
272 template <typename Type>
273 void TransformDialect::addTypeIfNotRegistered() {
274   // Use the address of the parse method as a proxy for identifying whether we
275   // are registering the same type class for the same mnemonic.
276   StringRef mnemonic = Type::getMnemonic();
277   auto [it, inserted] = typeParsingHooks.try_emplace(mnemonic, Type::parse);
278   if (!inserted) {
279     const ExtensionTypeParsingHook &parsingHook = it->getValue();
280     if (parsingHook != &Type::parse)
281       reportDuplicateTypeRegistration(mnemonic);
282     else
283       return;
284   }
285   typePrintingHooks.try_emplace(
286       TypeID::get<Type>(), +[](mlir::Type type, AsmPrinter &printer) {
287         printer << Type::getMnemonic();
288         cast<Type>(type).print(printer);
289       });
290   addTypes<Type>();
291 
292 #ifndef NDEBUG
293   detail::checkImplementsTransformHandleTypeInterface(TypeID::get<Type>(),
294                                                       getContext());
295 #endif // NDEBUG
296 }
297 
298 template <typename DataTy>
299 DataTy &TransformDialect::getOrCreateExtraData() {
300   TypeID typeID = TypeID::get<DataTy>();
301   auto [it, inserted] = extraData.try_emplace(typeID);
302   if (inserted)
303     it->getSecond() = std::make_unique<DataTy>(getContext());
304   return static_cast<DataTy &>(*it->getSecond());
305 }
306 
/// Registration-time wrapper that default-constructs the wrapped Transform
/// dialect extension with its build-only flag set to `true`, so that the
/// dialects it declares as generated are not loaded when it is applied.
template <typename ExtensionTy>
class BuildOnly : public ExtensionTy {
public:
  BuildOnly() : ExtensionTy(/*buildOnly=*/true) {}
};
314 
315 } // namespace transform
316 } // namespace mlir
317 
318 #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
319