# Chapter 2: Adding a Simple New Transformation Operation

## Setting Up to Add New Transformations

Before defining a new transform operation, we need to choose where its implementation should be located. While MLIR encourages upstream contributions, it is not always possible or even desirable to modify the main Transform dialect, for example, if the transformation is specific to some out-of-tree dialect that is not itself available upstream.

The Transform dialect uses the dialect extension mechanism to allow additional operations to be injected without modifying the dialect itself. Dialect extensions are registered with the context and loaded when the dialect itself is loaded. Extension definition is straightforward:

```cpp
// In MyExtension.cpp.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"

// Define a new transform dialect extension. This uses the CRTP idiom to identify
// extensions.
class MyExtension : public ::mlir::transform::TransformDialectExtension<MyExtension> {
public:
  // The extension must inherit the base constructor.
  using Base::Base;

  // This function initializes the extension, similarly to `initialize` in dialect
  // definitions. List individual operations and dependent dialects here.
  void init();
};

void MyExtension::init() {
  // Similarly to dialects, an extension can declare a dependent dialect. This dialect
  // will be loaded along with the extension and, therefore, along with the Transform
  // dialect. Only declare as dependent the dialects that contain the attributes or
  // types used by transform operations. Do NOT declare as dependent the dialects
  // produced during the transformation.
  // declareDependentDialect<MyDialect>();

  // When transformations are applied, they may produce new operations from previously
  // unloaded dialects. Typically, a pass would need to declare itself dependent on
  // the dialects containing such new operations. To avoid confusion with the dialects
  // the extension itself depends on, the Transform dialect differentiates between:
  //   - dependent dialects, which are used by the transform operations, and
  //   - generated dialects, which contain the entities (attributes, operations,
  //     types) that may be produced by applying the transformation even when not
  //     present in the original payload IR.
  // In the following chapter, we will add operations that generate function calls
  // and structured control flow operations, so let's declare the corresponding
  // dialects as generated.
  declareGeneratedDialect<::mlir::scf::SCFDialect>();
  declareGeneratedDialect<::mlir::func::FuncDialect>();

  // Finally, we register the additional transform operations with the dialect.
  registerTransformOps<
    // TODO: list the operation classes.
  >();
}
```
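
The registration flow described above can be made more concrete with a simplified, standalone C++ model. In this model, an extension is stored in a registry and only initialized once the dialect it extends is loaded. All names here (`ToyDialect`, `ToyExtension`, `ToyRegistry`, `MyToyExtension`) are hypothetical. This is not MLIR's `DialectRegistry` or `TransformDialectExtension` API; it only mirrors their general shape, as an assumption for illustration.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a dialect: it only records the names of the
// operations registered with it.
struct ToyDialect {
  std::string name;
  std::vector<std::string> ops;
};

// Hypothetical CRTP base, loosely mirroring the shape of
// TransformDialectExtension<Derived>: the derived class supplies init().
template <typename Derived>
struct ToyExtension {
  void apply(ToyDialect &dialect) {
    static_cast<Derived *>(this)->init(dialect);
  }
};

// Hypothetical registry: extensions are kept as deferred callbacks keyed by
// the name of the dialect they extend and only run when that dialect loads.
class ToyRegistry {
public:
  template <typename Ext>
  void addExtension(const std::string &dialectName) {
    pending[dialectName].push_back([](ToyDialect &dialect) {
      Ext extension;
      extension.apply(dialect);
    });
  }

  // Creates the dialect on first load and runs every pending extension on it;
  // later loads return the existing instance without re-running extensions.
  ToyDialect &loadDialect(const std::string &dialectName) {
    auto it = loaded.find(dialectName);
    if (it != loaded.end())
      return *it->second;
    auto dialect = std::make_unique<ToyDialect>();
    dialect->name = dialectName;
    for (auto &callback : pending[dialectName])
      callback(*dialect);
    ToyDialect &result = *dialect;
    loaded[dialectName] = std::move(dialect);
    return result;
  }

private:
  std::map<std::string, std::vector<std::function<void(ToyDialect &)>>> pending;
  std::map<std::string, std::unique_ptr<ToyDialect>> loaded;
};

// Hypothetical extension injecting one op name into the "transform" dialect
// without touching the dialect's own definition.
struct MyToyExtension : ToyExtension<MyToyExtension> {
  void init(ToyDialect &dialect) {
    dialect.ops.push_back("transform.my.change_call_target");
  }
};
```

In this model, loading `"transform"` a second time returns the already-initialized instance, so the extension's `init` runs only once.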

The operations themselves can be defined using ODS, in exactly the same way as regular operations in a dialect.

```tablegen
// In MyExtension.td
#ifndef MY_EXTENSION
#define MY_EXTENSION

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// The dialect prefix "transform." is added automatically, so only "my.op" is
// spelled out here.
def MyOp : Op<Transform_Dialect, "my.op", [
    // TODO: interfaces and traits here.
   ]> {
  let summary = "my transform op";
  // TODO: define the operation properties.
}

#endif // MY_EXTENSION
```

As with dialects, we must use TableGen to generate the header and implementation of these operations. We can instruct CMake to do so as follows.

```cmake
# In CMakeLists.txt next to MyExtension.td.

# Tell TableGen to use MyExtension.td as input.
set(LLVM_TARGET_DEFINITIONS MyExtension.td)

# Ask TableGen to generate op declarations and definitions from ODS.
mlir_tablegen(MyExtension.h.inc -gen-op-decls)
mlir_tablegen(MyExtension.cpp.inc -gen-op-defs)

# Add a CMake target we can depend on to ensure the generation happens before the compilation.
add_public_tablegen_target(MyExtensionIncGen)

# Don't forget to generate the documentation: this will produce a MyExtension.md under
# Dialects.
add_mlir_doc(MyExtension MyExtension Dialects/ -gen-op-doc)
```

```cmake
# In CMakeLists.txt next to MyExtension.cpp
add_mlir_library(
  # Library called MyExtension.
  MyExtension

  # Built from the following source files.
  MyExtension.cpp

  # Make sure ODS declarations and definitions are generated before compiling this.
  DEPENDS
  MyExtensionIncGen

  # Link in the Transform dialect and all generated dialects.
  LINK_LIBS PUBLIC
  MLIRTransformDialect
  MLIRFuncDialect
  MLIRSCFDialect
)
```

This will generate two files, `MyExtension.h.inc` and `MyExtension.cpp.inc`, which are meant to be included in the declaration and definition of the transform operations, respectively.

```c++
// In MyExtension.h.
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"

#define GET_OP_CLASSES
#include "MyExtension.h.inc"
```

```c++
// In MyExtension.cpp.
// The op class declarations must be visible before the generated definitions.
#include "MyExtension.h"

#define GET_OP_CLASSES
#include "MyExtension.cpp.inc"

// …
void MyExtension::init() {
  // …

  // Finally, we register the additional transform operations with the dialect. List all
  // operations generated from ODS. This call will perform additional checks that the
  // operations implement the transform and memory effect interfaces required by the
  // dialect interpreter and assert if they do not.
  registerTransformOps<
#define GET_OP_LIST
#include "MyExtension.cpp.inc"
  >();
}
```
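
`registerTransformOps` takes the operation classes as a template parameter pack. This is why the `GET_OP_LIST` expansion, a comma-separated list of generated class names, can be pasted between its angle brackets. The standalone sketch below models that idea with a C++17 fold expression: each class is checked for the two required interfaces and then registered by name. The marker types `TransformOpTag` and `MemoryEffectsTag`, the function `registerToyOps`, and the use of `std::is_base_of` are illustrative assumptions, not how MLIR implements the check.

```cpp
#include <cassert>
#include <string>
#include <type_traits>
#include <vector>

// Hypothetical marker types standing in for the two interfaces that the
// Transform dialect expects every registered transform op to implement.
struct TransformOpTag {};
struct MemoryEffectsTag {};

// Hypothetical op classes; getOperationName mimics the ODS-generated accessor.
struct ToyChangeCallTargetOp : TransformOpTag, MemoryEffectsTag {
  static std::string getOperationName() {
    return "transform.my.change_call_target";
  }
};
struct ToyIncompleteOp : TransformOpTag {  // no MemoryEffectsTag
  static std::string getOperationName() { return "transform.my.incomplete"; }
};

// True if Op derives from both marker types.
template <typename Op>
constexpr bool implementsRequiredInterfaces() {
  return std::is_base_of_v<TransformOpTag, Op> &&
         std::is_base_of_v<MemoryEffectsTag, Op>;
}

// Registers every op class of the pack in order. If any class lacks a required
// interface, nothing is registered and false is returned (the MLIR call
// described above asserts instead).
template <typename... Ops>
bool registerToyOps(std::vector<std::string> &registry) {
  if (!(implementsRequiredInterfaces<Ops>() && ...))
    return false;
  (registry.push_back(Ops::getOperationName()), ...);
  return true;
}
```

The model returns `false` rather than asserting only so that the rejection path is easy to observe.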

## Defining a Transform Operation

With this setup, we are now ready to define the new transform operation to rewrite the function call. This is identical to defining a regular operation in a dialect. Note that the Transform dialect requires operations to implement the `TransformOpInterface` as well as the `MemoryEffectsOpInterface` to indicate whether the operands are consumed or only read. Our operation can be defined along the following lines.

```tablegen
// In MyExtension.td.

// Define the new operation. By convention, prefix its name with the name of the dialect
// extension, "my.". The full operation name will be further prefixed with "transform.".
def ChangeCallTargetOp : Op<Transform_Dialect, "my.change_call_target",
    // Indicate that the operation implements the required TransformOpInterface and
    // MemoryEffectsOpInterface.
    [DeclareOpInterfaceMethods<TransformOpInterface>,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
  // Provide a brief and a full description. It is recommended that the latter describes
  // the effects on the operands and how the operation processes various failure modes.
  let summary = "Changes the callee of a call operation to the specified one";
  let description = [{
    For each `func.call` payload operation associated with the handle, changes its
    callee to be the symbol whose name is provided as an attribute to this operation.

    Generates a silenceable failure if the operand is associated with payload operations
    that are not `func.call`.
    Only reads the operand.
  }];

  // The arguments include the handle to the payload operations and the attribute that
  // specifies the new callee. The handle must implement TransformHandleTypeInterface.
  // We use a string attribute rather than a symbol reference because the symbol lives
  // in the payload IR and may not exist in the transform IR, where symbol verification
  // would fail.
  let arguments = (ins
    TransformHandleTypeInterface:$call,
    StrAttr:$new_target);

  // The results are empty as the transformation does not produce any new payload.
  let results = (outs);

  // Provide nice syntax.
  let assemblyFormat = "$call `,` $new_target attr-dict `:` type($call)";
}
```
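
The declarative `assemblyFormat` above fixes the textual form of the operation. The hypothetical helper below is not generated code. It concatenates the pieces in the order the format lists them, after the operation name, for the case of an empty attribute dictionary. The result matches the form `transform.my.change_call_target %call, "microkernel" : !transform.any_op`.

```cpp
#include <cassert>
#include <string>

// Spells out, for an empty attribute dictionary, the token order declared by
//   assemblyFormat = "$call `,` $new_target attr-dict `:` type($call)"
// with the op name printed first. The real printer and parser are generated by
// ODS; this hypothetical helper only illustrates the resulting text.
std::string printChangeCallTarget(const std::string &callValue,
                                  const std::string &newTarget,
                                  const std::string &callType) {
  return "transform.my.change_call_target " + callValue + ", \"" + newTarget +
         "\" : " + callType;
}
```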

To finalize the definition of the transform operation, we need to implement the
interface methods. The `TransformOpInterface` currently requires only one method,
`apply`, that performs the actual transformation. It is good practice to
limit the body of the method to manipulation of the Transform dialect constructs
and have the actual transformation implemented as a standalone function so it
can be used from other places in the code. As with rewrite patterns, all IR
must be modified with the provided rewriter.
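
The split between a thin `apply` method and a reusable transformation function, and the rule that IR changes go through the rewriter, can be sketched with toy types. `ToyCallOp`, `ToyRewriter`, and `updateCalleeToy` are hypothetical. In MLIR, the rewriter role is played by the `TransformRewriter` passed to `apply`. Routing changes through it is what allows listeners attached to the rewriter to observe payload modifications.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical payload operation: a call with a mutable callee name.
struct ToyCallOp {
  std::string callee;
};

// Hypothetical rewriter: in-place changes go through modifyInPlace, which
// records the modified op so that an observer can react to it.
class ToyRewriter {
public:
  void modifyInPlace(ToyCallOp &op, const std::function<void()> &mutation) {
    mutation();
    modifiedOps.push_back(&op);
  }

  std::vector<const ToyCallOp *> modifiedOps;
};

// Standalone transformation: it knows about payload ops and the rewriter, but
// nothing about handles or transform state, so other code can reuse it.
void updateCalleeToy(ToyRewriter &rewriter, ToyCallOp &call,
                     const std::string &newTarget) {
  rewriter.modifyInPlace(call, [&] { call.callee = newTarget; });
}
```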

```c++
// In MyExtension.cpp

// Standalone implementation of the transformation so that it can be reused
// outside of the transform operation. The callee is updated through the
// rewriter so that listeners attached to it are notified of the in-place
// modification.
static void updateCallee(::mlir::RewriterBase &rewriter,
                         ::mlir::func::CallOp call,
                         ::llvm::StringRef newTarget) {
  rewriter.modifyOpInPlace(call, [&]() { call.setCallee(newTarget); });
}

// Implementation of our transform dialect operation.
// This operation returns a tri-state result that can be one of:
// - success when the transformation succeeded;
// - definite failure when the transformation failed in such a way that
//   following transformations are impossible or undesirable, typically it could
//   have left payload IR in an invalid state; it is expected that a diagnostic
//   is emitted immediately before returning the definite failure;
// - silenceable failure when the transformation failed but following
//   transformations are still applicable, typically this means a precondition
//   for the transformation is not satisfied and the payload IR has not been
//   modified. The silenceable failure additionally carries a Diagnostic that
//   can be emitted to the user.
::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply(
    // The rewriter that should be used when modifying IR.
    ::mlir::transform::TransformRewriter &rewriter,
    // The list of payload IR entities that will be associated with the
    // transform IR values defined by this transform operation. In this case, it
    // can remain empty as there are no results.
    ::mlir::transform::TransformResults &results,
    // The transform application state. This object can be used to query the
    // current associations between transform IR values and payload IR entities.
    // It can also carry additional user-defined state.
    ::mlir::transform::TransformState &state) {

  // First, we need to obtain the list of payload operations that are associated with
  // the operand handle.
  auto payload = state.getPayloadOps(getCall());

  // Then, we iterate over the list of payload operations and call the actual
  // IR-mutating function. We also check the preconditions here.
  for (Operation *payloadOp : payload) {
    auto call = dyn_cast<::mlir::func::CallOp>(payloadOp);
    if (!call) {
      DiagnosedSilenceableFailure diag = emitSilenceableError()
          << "only applies to func.call payloads";
      diag.attachNote(payloadOp->getLoc()) << "offending payload";
      return diag;
    }

    updateCallee(rewriter, call, getNewTarget());
  }

  // If everything went well, return success.
  return DiagnosedSilenceableFailure::success();
}
```
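
The tri-state contract can be modeled in isolation. The sketch below uses a simplified, hypothetical `ToyResult` type rather than MLIR's `DiagnosedSilenceableFailure`, and a hypothetical `changeCallTargets` function over toy payload ops. The loop in the listing above updates each call as soon as that call has been checked. This sketch instead validates every payload op before modifying any of them. That ordering is a design choice made here so that a silenceable failure really does leave the payload unmodified, even when only a later op violates the precondition.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for a tri-state transform result; it is
// not MLIR's DiagnosedSilenceableFailure.
enum class ToyStatus { Success, SilenceableFailure, DefiniteFailure };

struct ToyResult {
  ToyStatus status;
  std::string diagnostic;  // empty on success
};

// Hypothetical payload op: an operation name such as "func.call" and a callee.
struct ToyPayloadOp {
  std::string name;
  std::string callee;
};

// Changes the callee of every payload op. All preconditions are checked before
// anything is modified, so a silenceable failure leaves the payload untouched.
// DefiniteFailure is unused here: it is reserved for cases where the payload
// may already have been left in an invalid state.
ToyResult changeCallTargets(std::vector<ToyPayloadOp> &payload,
                            const std::string &newTarget) {
  for (const ToyPayloadOp &op : payload) {
    if (op.name != "func.call")
      return {ToyStatus::SilenceableFailure,
              "only applies to func.call payloads, offending payload: " +
                  op.name};
  }
  for (ToyPayloadOp &op : payload)
    op.callee = newTarget;
  return {ToyStatus::Success, ""};
}
```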

The implementation of the `MemoryEffectsOpInterface` must specify the effects this operation has on its operands (consumed or readonly) and on the payload IR (mutates or readonly). Transform dialect verifiers will check that these side effects are declared and assert in debug builds if they are not.

```c++
// In MyExtension.cpp

void mlir::transform::ChangeCallTargetOp::getEffects(
    ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
  // Indicate that the `call` handle is only read by this operation because the
  // associated operation is not erased but rather modified in-place, so the
  // reference to it remains valid.
  onlyReadsHandle(getCall(), effects);

  // Indicate that the payload is modified by this operation.
  modifiesPayload(effects);
}
```
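
The practical difference between reading and consuming a handle is whether the handle remains usable by later transform operations. The toy model below tracks that with a hypothetical `ToyHandleState` class. It assumes each handle is tracked independently. The real interpreter also has to account for other handles that refer to the same or nested payload operations, which is not modeled here.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical bookkeeping for transform handles. An op either only reads a
// handle, which stays usable afterwards, or consumes it, which invalidates it.
// Aliasing between handles is deliberately not modeled.
class ToyHandleState {
public:
  void define(const std::string &handle) { valid.insert(handle); }

  // Reading succeeds only while the handle has not been consumed.
  bool read(const std::string &handle) const { return valid.count(handle) != 0; }

  // Consuming succeeds once; afterwards the handle is invalid.
  bool consume(const std::string &handle) {
    if (!read(handle))
      return false;
    valid.erase(handle);
    return true;
  }

private:
  std::set<std::string> valid;
};
```

Because `ChangeCallTargetOp` modifies the call in place and declares `onlyReadsHandle`, `%call` corresponds to the repeated `read` case here: it can still be used by operations that follow.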

## Registration and Usage

This is enough to define transform operations. The only remaining bit is providing the extension registration hook that can be called from the project’s `main`.

```c++
// In MyExtension.cpp (don't forget a declaration in MyExtension.h).

void registerMyExtension(::mlir::DialectRegistry &registry) {
  registry.addExtensions<MyExtension>();
}
```

After registering the extension, it becomes possible to use our new operation in the transform dialect interpreter. The upstream testing pass can be used as is.

```mlir
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op,
     %arg1: !transform.op<"linalg.matmul">,
     %arg2: !transform.op<"linalg.elemwise_binary">):
  // Since the %arg2 handle is associated with both elementwise operations,
  // we need to split it into two handles so we can target only the second
  // elementwise operation.
  %add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">)
      -> (!transform.any_op, !transform.any_op)

  // The actual tiling transformation takes tile sizes as attributes. It produces a
  // handle to the loop generated during tiling and a handle to the tiled operation.
  %loop, %tiled = transform.structured.tile_using_forall %max tile_sizes [8, 32]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

  // We can now fuse the other operations into the loop. Here, we fuse
  // operations one by one. This requires the operation that is being fused
  // to define the value used within the loop, so the order of such fusions
  // is important. We could also use "transform.merge_handles" to obtain
  // a single handle to all operations and give it to `fuse_into_containing_op`,
  // which would take care of the ordering in this case.
  %add_fused = transform.structured.fuse_into_containing_op %add into %loop
      : (!transform.any_op, !transform.any_op) -> !transform.any_op
  %matmul_fused = transform.structured.fuse_into_containing_op %arg1 into %loop
      : (!transform.op<"linalg.matmul">, !transform.any_op) -> !transform.any_op

  // Tile again to get the desired size. Note that this time this tiles the
  // "add" operation and fuses matmul into the loop, but doesn't affect the
  // "max" operation. This illustrates the precise targeting with the transform
  // dialect. Otherwise, it is difficult to differentiate "add" and "max", both
  // of which have the same kind.
  %loop_2, %tiled_2 = transform.structured.tile_using_forall %add_fused tile_sizes [4, 4]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
  %matmul_fused_2 = transform.structured.fuse_into_containing_op %matmul_fused into %loop_2
      : (!transform.any_op, !transform.any_op) -> !transform.any_op

  // Since outlining is currently only implemented for region-holding operations
  // such as loops, use tiling to size 1 to materialize the outer loop that is
  // going to be outlined.
  %outline_target, %_ = transform.structured.tile_using_forall %tiled_2 tile_sizes [1]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
  transform.structured.fuse_into_containing_op %matmul_fused_2 into %outline_target
      : (!transform.any_op, !transform.any_op) -> !transform.any_op
  %func, %call = transform.loop.outline %outline_target {func_name = "outlined"}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

  // Rewrite the call target.
  transform.my.change_call_target %call, "microkernel" : !transform.any_op

  transform.yield
}
```
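
In the sequence above, `%arg2` is associated with two payload operations, and `transform.split_handle` yields one handle per operation. The sketch below models handles as plain vectors of hypothetical payload op names to illustrate that bookkeeping; `ToyHandle` and `splitHandle` are invented for illustration. The model simply fails when the number of payload operations differs from the number of requested results. It does not attempt to reproduce the upstream operation's exact behavior in that case.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model: a handle is the ordered list of payload op names it is
// associated with.
using ToyHandle = std::vector<std::string>;

// Splits a handle into `numResults` handles, each associated with exactly one
// payload op, preserving order. Returns std::nullopt if the counts differ.
std::optional<std::vector<ToyHandle>> splitHandle(const ToyHandle &handle,
                                                  std::size_t numResults) {
  if (handle.size() != numResults)
    return std::nullopt;
  std::vector<ToyHandle> results;
  for (const std::string &op : handle)
    results.push_back(ToyHandle{op});
  return results;
}
```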

## Appendix: Autogenerated Documentation

[include "Tutorials/transform/MyExtensionCh2.md"]