# Chapter 2: Adding a Simple New Transformation Operation

## Setting Up to Add New Transformations

Before defining a new transform operation, we need to choose where its implementation should be located. While MLIR encourages upstream contributions, it is not always possible or even desirable to modify the main Transform dialect, for example, if the transformation is specific to some out-of-tree dialect that is not itself available upstream.

The Transform dialect uses the dialect extension mechanism to allow additional operations to be injected without modifying the dialect itself. Dialect extensions are registered with the context and loaded when the dialect itself is loaded. Extension definition is straightforward:

```cpp
// In MyExtension.cpp.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"

// Define a new Transform dialect extension. This uses the CRTP idiom to
// identify extensions.
class MyExtension : public ::mlir::transform::TransformDialectExtension<MyExtension> {
public:
  // The extension must inherit the base constructors.
  using Base::Base;

  // This function initializes the extension, similarly to `initialize` in
  // dialect definitions. List individual operations and dependent dialects
  // here.
  void init();
};

void MyExtension::init() {
  // Similarly to dialects, an extension can declare a dependent dialect. This
  // dialect will be loaded along with the extension and, therefore, along with
  // the Transform dialect. Only declare as dependent the dialects that contain
  // the attributes or types used by transform operations. Do NOT declare as
  // dependent the dialects produced during the transformation.
  //
  // declareDependentDialect<MyDialect>();

  // When transformations are applied, they may produce new operations from
  // previously unloaded dialects. Typically, a pass would need to declare
  // itself dependent on the dialects containing such new operations. To avoid
  // confusion with the dialects the extension itself depends on, the Transform
  // dialect differentiates between:
  //   - dependent dialects, which are used by the transform operations, and
  //   - generated dialects, which contain the entities (attributes, operations,
  //     types) that may be produced by applying the transformation even when
  //     not present in the original payload IR.
  // In the following chapter, we will add operations that generate function
  // calls and structured control flow operations, so let's declare the
  // corresponding dialects as generated.
  declareGeneratedDialect<::mlir::scf::SCFDialect>();
  declareGeneratedDialect<::mlir::func::FuncDialect>();

  // Finally, we register the additional transform operations with the dialect.
  registerTransformOps<
    // TODO: list the operation classes.
  >();
}
```

The operations themselves can be defined using ODS, exactly in the same way as regular operations in a dialect.

```tablegen
// In MyExtension.td
#ifndef MY_EXTENSION
#define MY_EXTENSION

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// The mnemonic is prefixed with the dialect name automatically, so the full
// operation name is "transform.my.op".
def MyOp : Op<Transform_Dialect, "my.op", [
    // TODO: interfaces and traits here.
   ]> {
  let summary = "my transform op";
  // TODO: define the operation properties.
}

#endif // MY_EXTENSION
```

Similarly to dialects, we must use TableGen to generate the header and implementation of these operations. We can instruct CMake to do it as follows.

```cmake
# In CMakeLists.txt next to MyExtension.td.

# Tell TableGen to use MyExtension.td as input.
set(LLVM_TARGET_DEFINITIONS MyExtension.td)

# Ask TableGen to generate op declarations and definitions from ODS.
mlir_tablegen(MyExtension.h.inc -gen-op-decls)
mlir_tablegen(MyExtension.cpp.inc -gen-op-defs)

# Add a CMake target we can depend on to ensure the generation happens before
# the compilation.
add_public_tablegen_target(MyExtensionIncGen)

# Don't forget to generate the documentation; this will produce a
# MyExtension.md under Dialects.
add_mlir_doc(MyExtension MyExtension Dialects/ -gen-op-doc)
```

```cmake
# In CMakeLists.txt next to MyExtension.cpp
add_mlir_library(
  # Library called MyExtension.
  MyExtension

  # Built from the following source files.
  MyExtension.cpp

  # Make sure ODS declarations and definitions are generated before compiling
  # this.
  DEPENDS
  MyExtensionIncGen

  # Link in the transform dialect, and all generated dialects.
  LINK_LIBS PUBLIC
  MLIRTransformDialect
  MLIRFuncDialect
  MLIRSCFDialect
)
```

This will generate two files, `MyExtension.h.inc` and `MyExtension.cpp.inc`, which should be included in the declaration and the definition of the transform operations, respectively.

```c++
// In MyExtension.h.
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"

#define GET_OP_CLASSES
#include "MyExtension.h.inc"
```

```c++
// In MyExtension.cpp.
#include "MyExtension.h"

#define GET_OP_CLASSES
#include "MyExtension.cpp.inc"

// …
void MyExtension::init() {
  // …

  // Finally, we register the additional transform operations with the dialect.
  // List all operations generated from ODS. This call will perform additional
  // checks that the operations implement the transform and memory effect
  // interfaces required by the dialect interpreter and assert if they do not.
  registerTransformOps<
#define GET_OP_LIST
#include "MyExtension.cpp.inc"
  >();
}
```

## Defining a Transform Operation

With this setup, we are now ready to define the new transform operation to rewrite the function call. This is identical to defining a regular operation in a dialect. Note that the Transform dialect requires operations to implement the `TransformOpInterface`, as well as the `MemoryEffectsOpInterface` to indicate whether the operands are consumed or only read. Our operation can be defined along the following lines.

```tablegen
// In MyExtension.td.

// Define the new operation. By convention, prefix its name with the name of the
// dialect extension, "my.". The full operation name will be further prefixed
// with "transform.".
def ChangeCallTargetOp : Op<Transform_Dialect, "my.change_call_target",
    // Indicate that the operation implements the required TransformOpInterface
    // and MemoryEffectsOpInterface.
    [DeclareOpInterfaceMethods<TransformOpInterface>,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
  // Provide a brief and a full description. It is recommended that the latter
  // describes the effects on the operands and how the operation processes
  // various failure modes.
  let summary = "Changes the callee of a call operation to the specified one";
  let description = [{
    For each `func.call` payload operation associated with the handle, changes
    its callee to be the symbol whose name is provided as an attribute to this
    operation.

    Generates a silenceable failure if the operand is associated with payload
    operations that are not `func.call`. Only reads the operand.
  }];

  // The arguments include the handle to the payload operations and the
  // attribute that specifies the new callee. The handle must implement
  // TransformHandleTypeInterface.
  // We use a string attribute rather than a symbol reference because the
  // symbol may not exist in the transform IR, in which case verification of a
  // symbol reference would fail.
  let arguments = (ins
    TransformHandleTypeInterface:$call,
    StrAttr:$new_target);

  // The results are empty as the transformation does not produce any new
  // payload.
  let results = (outs);

  // Provide nice syntax.
  let assemblyFormat = "$call `,` $new_target attr-dict `:` type($call)";
}
```
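
With the assembly format above, a use of the operation in transform IR would look roughly as follows. This is a minimal sketch, not part of the original example: it assumes that the extension has been registered, that `transform.structured.match` from the upstream Linalg transform extension is available to obtain the handle, and that `microkernel` is a placeholder name for a function present in the payload.

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %root: !transform.any_op {transform.readonly}) {
    // Obtain a handle to all func.call operations nested under %root
    // (assumes the Linalg transform extension is loaded).
    %calls = transform.structured.match ops{["func.call"]} in %root
        : (!transform.any_op) -> !transform.op<"func.call">

    // Pieces of the assembly format, in order:
    //   $call       -> %calls
    //   $new_target -> "microkernel" (placeholder callee name)
    //   type($call) -> !transform.op<"func.call">
    transform.my.change_call_target %calls, "microkernel"
        : !transform.op<"func.call">
    transform.yield
  }
}
```

Any handle type implementing `TransformHandleTypeInterface` is accepted for `$call`, so `!transform.any_op` works equally well in place of the more specific `!transform.op<"func.call">`.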

To finalize the definition of the transform operation, we need to implement the
interface methods. The `TransformOpInterface` currently requires only one method
– `apply` – that performs the actual transformation. It is a good practice to
limit the body of the method to manipulation of the Transform dialect constructs
and have the actual transformation implemented as a standalone function so it
can be used from other places in the code. Similar to rewrite patterns, all IR
must be modified with the provided rewriter.

```c++
// In MyExtension.cpp

// Standalone function performing the actual IR change. The body shown here is
// one possible implementation of the helper called from `apply` below; it
// relies on the ODS-generated `setCallee` setter of `func.call`.
static void updateCallee(::mlir::func::CallOp call,
                         ::llvm::StringRef newTarget) {
  call.setCallee(newTarget);
}

// Implementation of our Transform dialect operation.
// This method returns a tri-state result that can be one of:
// - success when the transformation succeeded;
// - definite failure when the transformation failed in such a way that
//   following transformations are impossible or undesirable, typically it could
//   have left payload IR in an invalid state; it is expected that a diagnostic
//   is emitted immediately before returning the definite error;
// - silenceable failure when the transformation failed but following
//   transformations are still applicable, typically this means a precondition
//   for the transformation is not satisfied and the payload IR has not been
//   modified. The silenceable failure additionally carries a Diagnostic that
//   can be emitted to the user.
::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply(
    // The rewriter that should be used when modifying IR.
    ::mlir::transform::TransformRewriter &rewriter,
    // The list of payload IR entities that will be associated with the
    // transform IR values defined by this transform operation. In this case, it
    // can remain empty as there are no results.
    ::mlir::transform::TransformResults &results,
    // The transform application state. This object can be used to query the
    // current associations between transform IR values and payload IR entities.
    // It can also carry additional user-defined state.
    ::mlir::transform::TransformState &state) {

  // First, we need to obtain the list of payload operations that are associated
  // with the operand handle.
  auto payload = state.getPayloadOps(getCall());

  // Then, we iterate over the list of payload operations and call the actual
  // IR-mutating function. We also check the preconditions here.
  for (Operation *payloadOp : payload) {
    auto call = dyn_cast<::mlir::func::CallOp>(payloadOp);
    if (!call) {
      DiagnosedSilenceableFailure diag = emitSilenceableError()
          << "only applies to func.call payloads";
      diag.attachNote(payloadOp->getLoc()) << "offending payload";
      return diag;
    }

    updateCallee(call, getNewTarget());
  }

  // If everything went well, return success.
  return DiagnosedSilenceableFailure::success();
}
```
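
To make the effect of `apply` concrete, the following sketch shows a small payload before and after the transformation. The function names `@caller`, `@outlined` and `@microkernel` are placeholders chosen for illustration, and the handle passed to the transform operation is assumed to be associated with the single `func.call` below.

```mlir
// Payload IR before applying
//   transform.my.change_call_target %call, "microkernel" : !transform.any_op
func.func private @microkernel(f32) -> f32
func.func private @outlined(f32) -> f32

func.func @caller(%arg0: f32) -> f32 {
  %0 = func.call @outlined(%arg0) : (f32) -> f32
  return %0 : f32
}
```

```mlir
// Payload IR after the transformation: only the callee symbol has changed.
func.func private @microkernel(f32) -> f32
func.func private @outlined(f32) -> f32

func.func @caller(%arg0: f32) -> f32 {
  %0 = func.call @microkernel(%arg0) : (f32) -> f32
  return %0 : f32
}
```

If the handle were also associated with an operation that is not a `func.call`, for example the `return` above, `apply` would stop at that operation and report the silenceable failure "only applies to func.call payloads" with a note pointing at the offending payload. Note that the transformation itself does not check that the new callee exists or has a matching signature; in this sketch `@microkernel` is declared so that the resulting `func.call` still verifies.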

The implementation of the `MemoryEffectsOpInterface` must specify the effects this operation has on its operands (consumed or read-only) and on the payload IR (mutated or read-only). Transform dialect verifiers check that these side effects are declared and assert in debug builds if they are not.

```c++
// In MyExtension.cpp

void mlir::transform::ChangeCallTargetOp::getEffects(
    ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
  // Indicate that the `call` handle is only read by this operation because the
  // associated operation is not erased but rather modified in-place, so the
  // reference to it remains valid.
  onlyReadsHandle(getCall(), effects);

  // Indicate that the payload is modified by this operation.
  modifiesPayload(effects);
}
```
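
Because the handle is declared as only read, it can still be used after the operation. The sketch below illustrates this; it assumes that the upstream debug and Linalg transform extensions are registered alongside ours (they provide `transform.debug.emit_remark_at` and `transform.structured.match`), and `microkernel` is again a placeholder function name.

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %root: !transform.any_op {transform.readonly}) {
    %call = transform.structured.match ops{["func.call"]} in %root
        : (!transform.any_op) -> !transform.any_op

    // Mutates the payload in place; the handle itself is only read.
    transform.my.change_call_target %call, "microkernel" : !transform.any_op

    // The handle still points to the (now updated) call operations, so it
    // can be passed to further transform operations.
    transform.debug.emit_remark_at %call, "callee updated" : !transform.any_op
    transform.yield
  }
}
```

Had the operation instead declared that it consumes the handle, the second use of `%call` would be a use of an invalidated handle.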

## Registration and Usage

This is enough to define transform operations. The only remaining bit is providing the extension registration hook that can be called from the project’s `main`.

```c++
// In MyExtension.cpp (don't forget a declaration in MyExtension.h).

void registerMyExtension(::mlir::DialectRegistry &registry) {
  registry.addExtensions<MyExtension>();
}
```

After registering the extension, it becomes possible to use our new operation in the Transform dialect interpreter. The upstream testing pass can be used as is.

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %arg0: !transform.any_op,
      %arg1: !transform.op<"linalg.matmul">,
      %arg2: !transform.op<"linalg.elemwise_binary">) {
    // Since the %arg2 handle is associated with both elementwise operations,
    // we need to split it into two handles so we can target only the second
    // elementwise operation.
    %add, %max = transform.split_handle %arg2
        : (!transform.op<"linalg.elemwise_binary">)
        -> (!transform.any_op, !transform.any_op)

    // The actual tiling transformation takes tile sizes as attributes. It
    // produces a handle to the loop generated during tiling.
    %loop, %tiled = transform.structured.tile_using_forall %max
                    tile_sizes [8, 32]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // We can now fuse the other operations into the loop. Here, we fuse
    // operations one-by-one. This requires the operation that is being fused
    // to define the value used within the loop, so the order of such fusions
    // is important. We could also use `transform.merge_handles` to obtain
    // a single handle to all operations and give it to
    // `fuse_into_containing_op` that would take care of the ordering in this
    // case.
    %add_fused = transform.structured.fuse_into_containing_op %add into %loop
        : (!transform.any_op, !transform.any_op) -> !transform.any_op
    %matmul_fused = transform.structured.fuse_into_containing_op %arg1
                    into %loop
        : (!transform.op<"linalg.matmul">, !transform.any_op)
       -> !transform.any_op

    // Tile again to get the desired size. Note that this time this tiles the
    // "add" operation and fuses matmul into the loop, but doesn't affect the
    // "max" operation. This illustrates the precise targeting with the
    // transform dialect. Otherwise, it is difficult to differentiate "add" and
    // "max", both of which have the same kind.
    %loop_2, %tiled_2 = transform.structured.tile_using_forall %add_fused
                        tile_sizes [4, 4]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %matmul_fused_2 = transform.structured.fuse_into_containing_op %matmul_fused
                      into %loop_2
        : (!transform.any_op, !transform.any_op) -> !transform.any_op

    // Since outlining is currently only implemented for region-holding
    // operations such as loops, use tiling to size 1 to materialize the outer
    // loop that is going to be outlined.
    %outline_target, %_ = transform.structured.tile_using_forall %tiled_2 tile_sizes [1]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.structured.fuse_into_containing_op %matmul_fused_2 into %outline_target
        : (!transform.any_op, !transform.any_op) -> !transform.any_op
    %func, %call = transform.loop.outline %outline_target
                   {func_name = "outlined"}
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Rewrite the call target.
    transform.my.change_call_target %call, "microkernel" : !transform.any_op

    transform.yield
  }
}
```

## Appendix: Autogenerated Documentation

[include "Tutorials/transform/MyExtensionCh2.md"]