# Chapter 2: Adding a Simple New Transformation Operation

## Setting Up to Add New Transformations

Before defining a new transform operation, we need to choose where its implementation should be located. While MLIR encourages upstream contributions, it is not always possible or even desirable to modify the main Transform dialect, for example, if the transformation is specific to some out-of-tree dialect that is not itself available upstream.

The Transform dialect uses the dialect extension mechanism to allow additional operations to be injected without modifying the dialect itself. Dialect extensions are registered with the context and loaded when the dialect itself is loaded. Extension definition is straightforward:

```cpp
// In MyExtension.cpp.
// The Func and SCF headers provide the dialect classes declared as generated below.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"

// Define a new transform dialect extension. This uses the CRTP idiom to identify
// extensions.
class MyExtension : public ::mlir::transform::TransformDialectExtension<MyExtension> {
public:
  // The extension must inherit the base constructor.
  using Base::Base;

  // This function initializes the extension, similarly to `initialize` in dialect
  // definitions. List individual operations and dependent dialects here.
  void init();
};

void MyExtension::init() {
  // Similarly to dialects, an extension can declare a dependent dialect. This dialect
  // will be loaded along with the extension and, therefore, along with the Transform
  // dialect. Only declare as dependent the dialects that contain the attributes or
  // types used by transform operations. Do NOT declare as dependent the dialects
  // produced during the transformation.
  // declareDependentDialect<MyDialect>();

  // When transformations are applied, they may produce new operations from previously
  // unloaded dialects. Typically, a pass would need to declare itself dependent on
  // the dialects containing such new operations. To avoid confusion with the dialects
  // the extension itself depends on, the Transform dialect differentiates between:
  //   - dependent dialects, which are used by the transform operations, and
  //   - generated dialects, which contain the entities (attributes, operations,
  //     types) that may be produced by applying the transformation even when not
  //     present in the original payload IR.
  // In the following chapter, we will add operations that generate function calls
  // and structured control flow operations, so let's declare the corresponding
  // dialects as generated.
  declareGeneratedDialect<::mlir::scf::SCFDialect>();
  declareGeneratedDialect<::mlir::func::FuncDialect>();

  // Finally, we register the additional transform operations with the dialect.
  registerTransformOps<
    // TODO: list the operation classes.
  >();
}
```

The operations themselves can be defined using ODS, in exactly the same way as regular operations in a dialect.

```tablegen
// In MyExtension.td
#ifndef MY_EXTENSION
#define MY_EXTENSION

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// The mnemonic omits the "transform." prefix: the dialect adds it automatically.
def MyOp : Op<Transform_Dialect, "my.op", [
    // TODO: interfaces and traits here.
   ]> {
  let summary = "my transform op";
  // TODO: define the operation properties.
}

#endif // MY_EXTENSION
```

Similarly to dialects, we must use TableGen to generate the header and implementation of these operations. We can instruct CMake to do it as follows.

```sh
# In CMakeLists.txt next to MyExtension.td.

# Tell TableGen to use MyExtension.td as input.
set(LLVM_TARGET_DEFINITIONS MyExtension.td)

# Ask TableGen to generate op declarations and definitions from ODS.
mlir_tablegen(MyExtension.h.inc -gen-op-decls)
mlir_tablegen(MyExtension.cpp.inc -gen-op-defs)

# Add a CMake target we can depend on to ensure the generation happens before the compilation.
add_public_tablegen_target(MyExtensionIncGen)

# Don't forget to generate the documentation; this will produce a MyExtension.md under
# Dialects.
add_mlir_doc(MyExtension MyExtension Dialects/ -gen-op-doc)
```

```sh
# In CMakeLists.txt next to MyExtension.cpp
add_mlir_library(
  # Library called MyExtension.
  MyExtension

  # Built from the following source files.
  MyExtension.cpp

  # Make sure ODS declarations and definitions are generated before compiling this.
  DEPENDS
  MyExtensionIncGen

  # Link in the transform dialect, and all generated dialects.
  LINK_LIBS PUBLIC
  MLIRTransformDialect
  MLIRFuncDialect
  MLIRSCFDialect
)
```

This will generate two files, `MyExtension.h.inc` and `MyExtension.cpp.inc`, which are meant to be included into the declaration and the definition of the transform operations, respectively.

```c++
// In MyExtension.h.
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"

#define GET_OP_CLASSES
#include "MyExtension.h.inc"
```

```c++
// In MyExtension.cpp.
// The generated definitions need the op declarations from the header above.
#include "MyExtension.h"

#define GET_OP_CLASSES
#include "MyExtension.cpp.inc"

// …
void MyExtension::init() {
  // …

  // Finally, we register the additional transform operations with the dialect. List all
  // operations generated from ODS. This call will perform additional checks that the
  // operations implement the transform and memory effect interfaces required by the
  // dialect interpreter and assert if they do not.
  registerTransformOps<
#define GET_OP_LIST
#include "MyExtension.cpp.inc"
  >();
}
```

## Defining a Transform Operation

With this setup, we are now ready to define the new transform operation to rewrite the function call. This is identical to defining a regular operation in a dialect. Note that the Transform dialect requires operations to implement the `TransformOpInterface` as well as `MemoryEffectsOpInterface` to indicate whether the operands are consumed or only read. Our operation can be defined along the following lines.

```tablegen
// In MyExtension.td.

// Define the new operation. By convention, prefix its name with the name of the dialect
// extension, "my.". The full operation name will be further prefixed with "transform.".
def ChangeCallTargetOp : Op<Transform_Dialect, "my.change_call_target",
    // Indicate that the operation implements the required TransformOpInterface and
    // MemoryEffectsOpInterface.
    [DeclareOpInterfaceMethods<TransformOpInterface>,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
  // Provide a brief and a full description. It is recommended that the latter describes
  // the effects on the operands and how the operation processes various failure modes.
  let summary = "Changes the callee of a call operation to the specified one";
  let description = [{
    For each `func.call` payload operation associated with the handle, changes its
    callee to be the symbol whose name is provided as an attribute to this operation.

    Generates a silenceable failure if the operand is associated with payload operations
    that are not `func.call`.
    Only reads the operand.
  }];

  // The arguments include the handle to the payload operations and the attribute that
  // specifies the new callee. The handle must implement TransformHandleTypeInterface.
  // We use a string attribute rather than a symbol reference because the symbol may
  // not exist in the transform IR, in which case symbol verification would fail.
  let arguments = (ins
    TransformHandleTypeInterface:$call,
    StrAttr:$new_target);

  // The results are empty as the transformation does not produce any new payload.
  let results = (outs);

  // Provide nice syntax.
  let assemblyFormat = "$call `,` $new_target attr-dict `:` type($call)";
}
```

To finalize the definition of the transform operation, we need to implement the interface methods. The `TransformOpInterface` currently requires only one method, `apply`, that performs the actual transformation. It is good practice to limit the body of the method to manipulation of the Transform dialect constructs and have the actual transformation implemented as a standalone function so it can be used from other places in the code. Similar to rewrite patterns, all IR must be modified with the provided rewriter.

```c++
// In MyExtension.cpp

// Implementation of our transform dialect operation.
// This operation returns a tri-state result that can be one of:
// - success when the transformation succeeded;
// - definite failure when the transformation failed in such a way that
//   following transformations are impossible or undesirable, typically it could
//   have left payload IR in an invalid state; it is expected that a diagnostic
//   is emitted immediately before returning the definite error;
// - silenceable failure when the transformation failed but following
//   transformations are still applicable, typically this means a precondition
//   for the transformation is not satisfied and the payload IR has not been
//   modified. The silenceable failure additionally carries a Diagnostic that
//   can be emitted to the user.
::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply(
    // The rewriter that should be used when modifying IR.
    ::mlir::transform::TransformRewriter &rewriter,
    // The list of payload IR entities that will be associated with the
    // transform IR values defined by this transform operation. In this case, it
    // can remain empty as there are no results.
    ::mlir::transform::TransformResults &results,
    // The transform application state. This object can be used to query the
    // current associations between transform IR values and payload IR entities.
    // It can also carry additional user-defined state.
    ::mlir::transform::TransformState &state) {

  // First, we need to obtain the list of payload operations that are associated with
  // the operand handle.
  auto payload = state.getPayloadOps(getCall());

  // Then, we iterate over the list of payload operations and call the actual
  // IR-mutating function. We also check the preconditions here.
  for (Operation *payloadOp : payload) {
    auto call = dyn_cast<::mlir::func::CallOp>(payloadOp);
    if (!call) {
      DiagnosedSilenceableFailure diag = emitSilenceableError()
          << "only applies to func.call payloads";
      diag.attachNote(payloadOp->getLoc()) << "offending payload";
      return diag;
    }

    // `updateCallee` is the standalone function that performs the actual rewrite;
    // its definition is not shown here.
    updateCallee(call, getNewTarget());
  }

  // If everything went well, return success.
  return DiagnosedSilenceableFailure::success();
}
```

The implementation of the `MemoryEffectsOpInterface` must specify the effects this operation has on its operands (consumed or read-only) and on the payload IR (mutated or read-only). Transform dialect verifiers check that these side effects are declared and assert in debug builds if they are not.
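
To make the consumed versus read-only distinction more concrete, here is a toy model in plain C++. It is only a sketch and deliberately does not use the MLIR API: `HandleEffect`, `ToyTransformOp` and `findUsesAfterConsume` are hypothetical names invented for this illustration. It also assumes a simplified rule where consuming a handle invalidates only that same handle, whereas the real dialect tracks the associated payload operations as well.

```cpp
// Toy model of handle effects. NOT the MLIR API: every name below is hypothetical.
#include <cassert>
#include <set>
#include <string>
#include <vector>

// How a transform operation treats the handle it takes as operand.
enum class HandleEffect { Read, Consume };

// A transform operation reduced to its name, the handle it uses, and its effect
// on that handle.
struct ToyTransformOp {
  std::string name;
  std::string handle;
  HandleEffect effect;
};

// Walks the operations in order and returns the names of those that use a handle
// after an earlier operation has consumed it.
std::vector<std::string>
findUsesAfterConsume(const std::vector<ToyTransformOp> &ops) {
  std::set<std::string> consumed;
  std::vector<std::string> offenders;
  for (const ToyTransformOp &op : ops) {
    // Using a handle that was already consumed is an error in this model.
    if (consumed.count(op.handle))
      offenders.push_back(op.name);
    // A consuming operation invalidates the handle for everything that follows.
    if (op.effect == HandleEffect::Consume)
      consumed.insert(op.handle);
  }
  return offenders;
}
```

In this model, an operation that only reads `%call` leaves the handle usable by later operations, while one that consumes it causes every later use to be reported. For `transform.my.change_call_target`, the call is modified in place rather than erased, so the handle is only read while the payload is modified, which is what the actual interface implementation below declares.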

```c++
// In MyExtension.cpp

void mlir::transform::ChangeCallTargetOp::getEffects(
    ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
  // Indicate that the `call` handle is only read by this operation because the
  // associated operation is not erased but rather modified in-place, so the
  // reference to it remains valid.
  onlyReadsHandle(getCall(), effects);

  // Indicate that the payload is modified by this operation.
  modifiesPayload(effects);
}
```

## Registration and Usage

This is enough to define transform operations. The only remaining bit is providing the extension registration hook that can be called from the project’s `main`.

```c++
// In TransformDialect.cpp (don't forget a declaration in TransformDialect.h);

void registerMyExtension(::mlir::DialectRegistry &registry) {
  registry.addExtensions<MyExtension>();
}
```

After registering the extension, it becomes possible to use our new operation in the transform dialect interpreter. The upstream testing pass can be used as is.

```mlir
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op,
     %arg1: !transform.op<"linalg.matmul">,
     %arg2: !transform.op<"linalg.elemwise_binary">):
  // Since the %arg2 handle is associated with both elementwise operations,
  // we need to split it into two handles so we can target only the second
  // elementwise operation.
  %add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">)
      -> (!transform.any_op, !transform.any_op)

  // The actual tiling transformation takes tile sizes as attributes. It produces a
  // handle to the loop generated during tiling.
  %loop, %tiled = transform.structured.tile_using_forall %max tile_sizes [8, 32]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

  // We can now fuse the other operations into the loop. Here, we fuse
  // operations one-by-one. This requires the operation that is being fused
  // to define the value used within the loop, so the order of such fusions
  // is important. We could also use "transform.merge_handles" to obtain
  // a single handle to all operations and give it to `fuse_into_containing_op`
  // that would take care of the ordering in this case.
  %add_fused = transform.structured.fuse_into_containing_op %add into %loop
      : (!transform.any_op, !transform.any_op) -> !transform.any_op
  %matmul_fused = transform.structured.fuse_into_containing_op %arg1 into %loop
      : (!transform.op<"linalg.matmul">, !transform.any_op) -> !transform.any_op

  // Tile again to get the desired size. Note that this time this tiles the
  // "add" operation and fuses matmul into the loop, but doesn't affect the
  // "max" operation. This illustrates the precise targeting with the transform
  // dialect. Otherwise, it is difficult to differentiate "add" and "max", both
  // of which are of the same kind.
  %loop_2, %tiled_2 = transform.structured.tile_using_forall %add_fused tile_sizes [4, 4]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
  %matmul_fused_2 = transform.structured.fuse_into_containing_op %matmul_fused into %loop_2
      : (!transform.any_op, !transform.any_op) -> !transform.any_op

  // Since outlining is currently only implemented for region-holding operations
  // such as loops, use tiling to size 1 to materialize the outer loop that is
  // going to be outlined.
  %outline_target, %_ = transform.structured.tile_using_forall %tiled_2 tile_sizes [1]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
  transform.structured.fuse_into_containing_op %matmul_fused_2 into %outline_target
      : (!transform.any_op, !transform.any_op) -> !transform.any_op
  %func, %call = transform.loop.outline %outline_target {func_name = "outlined"}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

  // Rewrite the call target.
  transform.my.change_call_target %call, "microkernel" : !transform.any_op

  transform.yield
}
```

## Appendix: Autogenerated Documentation

[include "Tutorials/transform/MyExtensionCh2.md"]