# Chapter 2: Adding a Simple New Transformation Operation

## Setting Up to Add New Transformations

Before defining a new transform operation, we need to choose where its implementation should be located. While MLIR encourages upstream contributions, it is not always possible or even desirable to modify the main Transform dialect, for example, if the transformation is specific to some out-of-tree dialect that is not itself available upstream.

The Transform dialect uses the dialect extension mechanism to allow additional operations to be injected without modifying the dialect itself. Dialect extensions are registered with the context and loaded when the dialect itself is loaded. Extension definition is straightforward:

```cpp
// In MyExtension.cpp.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"

// Define a new Transform dialect extension. This uses the CRTP idiom to
// identify extensions.
class MyExtension : public ::mlir::transform::TransformDialectExtension<MyExtension> {
public:
  // The extension must inherit the base constructor.
  using Base::Base;

  // This function initializes the extension, similarly to `initialize` in
  // dialect definitions. List individual operations and dependent dialects
  // here.
  void init();
};

void MyExtension::init() {
  // Similarly to dialects, an extension can declare a dependent dialect. This
  // dialect will be loaded along with the extension and, therefore, along with
  // the Transform dialect. Only declare as dependent the dialects that contain
  // the attributes or types used by transform operations. Do NOT declare as
  // dependent the dialects produced during the transformation.
  //
  // declareDependentDialect<MyDialect>();

  // When transformations are applied, they may produce new operations from
  // previously unloaded dialects. Typically, a pass would need to declare
  // itself dependent on the dialects containing such new operations. To avoid
  // confusion with the dialects the extension itself depends on, the Transform
  // dialect differentiates between:
  // - dependent dialects, which are used by the transform operations, and
  // - generated dialects, which contain the entities (attributes, operations,
  //   types) that may be produced by applying the transformation even when
  //   not present in the original payload IR.
  // In the following chapter, we will add operations that generate function
  // calls and structured control flow operations, so let's declare the
  // corresponding dialects as generated.
  declareGeneratedDialect<::mlir::scf::SCFDialect>();
  declareGeneratedDialect<::mlir::func::FuncDialect>();

  // Finally, we register the additional transform operations with the dialect.
  registerTransformOps<
    // TODO: list the operation classes.
  >();
}
```

The operations themselves can be defined using ODS, in exactly the same way as regular operations in a dialect.

```tablegen
// In MyExtension.td
#ifndef MY_EXTENSION
#define MY_EXTENSION

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// The "transform." prefix is added automatically from the dialect name, so
// the mnemonic only contains the extension prefix and the operation name.
def MyOp : Op<Transform_Dialect, "my.op", [
    // TODO: interfaces and traits here.
  ]> {
  let summary = "my transform op";
  // TODO: define the operation properties.
}

#endif // MY_EXTENSION
```

Similarly to dialects, we must use TableGen to generate the header and implementation of these operations. We can instruct CMake to do so as follows.

```cmake
# In CMakeLists.txt next to MyExtension.td.

# Tell TableGen to use MyExtension.td as input.
set(LLVM_TARGET_DEFINITIONS MyExtension.td)

# Ask TableGen to generate op declarations and definitions from ODS.
mlir_tablegen(MyExtension.h.inc -gen-op-decls)
mlir_tablegen(MyExtension.cpp.inc -gen-op-defs)

# Add a CMake target we can depend on to ensure the generation happens before the compilation.
add_public_tablegen_target(MyExtensionIncGen)

# Don't forget to generate the documentation; this will produce a MyExtension.md under
# Dialects.
add_mlir_doc(MyExtension MyExtension Dialects/ -gen-op-doc)
```

```cmake
# In CMakeLists.txt next to MyExtension.cpp
add_mlir_library(
  # Library called MyExtension.
  MyExtension

  # Built from the following source files.
  MyExtension.cpp

  # Make sure ODS declarations and definitions are generated before compiling
  # this.
  DEPENDS
  MyExtensionIncGen

  # Link in the Transform dialect and all generated dialects.
  LINK_LIBS PUBLIC
  MLIRTransformDialect
  MLIRFuncDialect
  MLIRSCFDialect
)
```

This will generate two files, `MyExtension.h.inc` and `MyExtension.cpp.inc`, which are meant to be included in the declaration and the definition of the transform operations, respectively.

```c++
// In MyExtension.h.
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"

#define GET_OP_CLASSES
#include "MyExtension.h.inc"
```

```c++
// In MyExtension.cpp.

#define GET_OP_CLASSES
#include "MyExtension.cpp.inc"

// …
void MyExtension::init() {
  // …

  // Finally, we register the additional transform operations with the dialect.
  // List all operations generated from ODS. This call will perform additional
  // checks that the operations implement the transform and memory effect
  // interfaces required by the dialect interpreter and assert if they do not.
  registerTransformOps<
#define GET_OP_LIST
#include "MyExtension.cpp.inc"
  >();
}
```

## Defining a Transform Operation

With this setup, we are now ready to define the new transform operation to rewrite the function call. This is identical to defining a regular operation in a dialect. Note that the Transform dialect requires operations to implement the `TransformOpInterface` as well as the `MemoryEffectsOpInterface`, the latter indicating whether the operands are consumed or only read. Our operation can be defined along the following lines.

```tablegen
// In MyExtension.td.

// Define the new operation. By convention, prefix its name with the name of the
// dialect extension, "my.". The full operation name will be further prefixed
// with "transform.".
def ChangeCallTargetOp : Op<Transform_Dialect, "my.change_call_target",
    // Indicate that the operation implements the required TransformOpInterface
    // and MemoryEffectsOpInterface.
    [DeclareOpInterfaceMethods<TransformOpInterface>,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
  // Provide a brief and a full description. It is recommended that the latter
  // describes the effects on the operands and how the operation processes
  // various failure modes.
  let summary = "Changes the callee of a call operation to the specified one";
  let description = [{
    For each `func.call` payload operation associated with the handle, changes
    its callee to be the symbol whose name is provided as an attribute to this
    operation.

    Generates a silenceable failure if the operand is associated with payload
    operations that are not `func.call`. Only reads the operand.
  }];

  // The arguments include the handle to the payload operations and the
  // attribute that specifies the new callee. The handle must implement
  // TransformHandleTypeInterface.
  // We use a string attribute rather than a symbol reference because the
  // symbol may not exist in the transform IR, so symbol verification would fail.
  let arguments = (ins
    TransformHandleTypeInterface:$call,
    StrAttr:$new_target);

  // The results are empty as the transformation does not produce any new
  // payload.
  let results = (outs);

  // Provide nice syntax.
  let assemblyFormat = "$call `,` $new_target attr-dict `:` type($call)";
}
```

To finalize the definition of the transform operation, we need to implement the interface methods. The `TransformOpInterface` currently requires only one method, `apply`, which performs the actual transformation. It is good practice to limit the body of this method to the manipulation of Transform dialect constructs and to implement the actual transformation as a standalone function, so that it can be used from other places in the code. As with rewrite patterns, all IR must be modified through the provided rewriter.

```c++
// In MyExtension.cpp

// The actual IR-mutating function, kept separate from the transform operation
// so that it can be reused from other places in the code. The modification is
// performed through the rewriter so that any attached listener is notified.
static void updateCallee(::mlir::RewriterBase &rewriter,
                         ::mlir::func::CallOp call,
                         ::llvm::StringRef newTarget) {
  rewriter.modifyOpInPlace(call, [&]() { call.setCallee(newTarget); });
}

// Implementation of our Transform dialect operation.
// The `apply` method returns a tri-state result that can be one of:
// - success when the transformation succeeded;
// - definite failure when the transformation failed in such a way that
//   following transformations are impossible or undesirable, typically because
//   it could have left the payload IR in an invalid state; a diagnostic is
//   expected to be emitted immediately before returning the definite failure;
// - silenceable failure when the transformation failed but following
//   transformations are still applicable, typically because a precondition
//   for the transformation is not satisfied and the payload IR has not been
//   modified. The silenceable failure additionally carries a Diagnostic that
//   can be emitted to the user.
::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply(
    // The rewriter that should be used when modifying IR.
    ::mlir::transform::TransformRewriter &rewriter,
    // The list of payload IR entities that will be associated with the
    // transform IR values defined by this transform operation. In this case, it
    // can remain empty as there are no results.
    ::mlir::transform::TransformResults &results,
    // The transform application state. This object can be used to query the
    // current associations between transform IR values and payload IR entities.
    // It can also carry additional user-defined state.
    ::mlir::transform::TransformState &state) {

  // First, we need to obtain the list of payload operations that are associated
  // with the operand handle.
  auto payload = state.getPayloadOps(getCall());

  // Then, we check the precondition on every payload operation before
  // modifying any of them, so that a silenceable failure leaves the payload IR
  // untouched.
  for (Operation *payloadOp : payload) {
    if (!isa<::mlir::func::CallOp>(payloadOp)) {
      DiagnosedSilenceableFailure diag = emitSilenceableError()
                                         << "only applies to func.call payloads";
      diag.attachNote(payloadOp->getLoc()) << "offending payload";
      return diag;
    }
  }

  // Finally, we call the actual IR-mutating function on each payload operation.
  for (Operation *payloadOp : payload)
    updateCallee(rewriter, cast<::mlir::func::CallOp>(payloadOp),
                 getNewTarget());

  // If everything went well, return success.
  return DiagnosedSilenceableFailure::success();
}
```

The implementation of the `MemoryEffectsOpInterface` must specify the effects this operation has on its operands (consumed or only read) and on the payload IR (modified or only read). Transform dialect verifiers check that these effects are declared and assert in debug builds if they are not.
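To illustrate why the distinction between consuming and only reading a handle matters, here is a minimal standalone sketch built on a deliberately simplified model of the interpreter. `ToyTransformState`, `HandleEffect`, and their methods are hypothetical names invented for this illustration; they are not part of MLIR. Each handle maps to a list of payload operation names. An operation that only reads a handle leaves it usable, while an operation that consumes it invalidates it, so any later use of that handle is reported as a failure, roughly mirroring how the real interpreter can diagnose uses of consumed handles.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy model, NOT the MLIR API: it only mimics how a handle
// behaves after a transform operation reads or consumes it.
enum class HandleEffect { OnlyRead, Consumed };

class ToyTransformState {
public:
  // Associates `handle` with a list of payload operation names.
  void setPayload(const std::string &handle, std::vector<std::string> ops) {
    assert(!handle.empty() && "handle name must not be empty");
    payload_[handle] = std::move(ops);
  }

  // Returns the payload associated with `handle`, or std::nullopt if the
  // handle was never set or has been invalidated by a consuming operation.
  std::optional<std::vector<std::string>>
  getPayloadOps(const std::string &handle) const {
    auto it = payload_.find(handle);
    if (it == payload_.end())
      return std::nullopt;
    return it->second;
  }

  // Simulates applying a transform operation to `handle` with the declared
  // effect. Returns false, modeling a definite failure, if the handle is no
  // longer valid. A consuming operation invalidates the handle afterwards.
  bool apply(const std::string &handle, HandleEffect effect) {
    auto it = payload_.find(handle);
    if (it == payload_.end())
      return false;
    if (effect == HandleEffect::Consumed)
      payload_.erase(it);
    return true;
  }

private:
  std::map<std::string, std::vector<std::string>> payload_;
};
```

In this model, any number of read-only operations can use the same handle, but nothing can use it after a consuming one. This is why `ChangeCallTargetOp` declares its handle as only read: it modifies each `func.call` in place rather than erasing it, so the handle still refers to valid payload after the operation runs.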
```c++
// In MyExtension.cpp

void mlir::transform::ChangeCallTargetOp::getEffects(
    ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
  // Indicate that the `call` handle is only read by this operation because the
  // associated operation is not erased but rather modified in-place, so the
  // reference to it remains valid.
  onlyReadsHandle(getCall(), effects);

  // Indicate that the payload is modified by this operation.
  modifiesPayload(effects);
}
```

## Registration and Usage

This is enough to define transform operations. The only remaining piece is the extension registration hook, which can be called from the project’s `main`.

```c++
// In MyExtension.cpp (don't forget a declaration in MyExtension.h).

void registerMyExtension(::mlir::DialectRegistry &registry) {
  registry.addExtensions<MyExtension>();
}
```

After registering the extension, it becomes possible to use our new operation in the Transform dialect interpreter. The upstream testing pass can be used as is.

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %arg0: !transform.any_op,
      %arg1: !transform.op<"linalg.matmul">,
      %arg2: !transform.op<"linalg.elemwise_binary">) {
    // Since the %arg2 handle is associated with both elementwise operations,
    // we need to split it into two handles so we can target only the second
    // elementwise operation.
    %add, %max = transform.split_handle %arg2
        : (!transform.op<"linalg.elemwise_binary">)
        -> (!transform.any_op, !transform.any_op)

    // The actual tiling transformation takes tile sizes as attributes. It
    // produces a handle to the loop generated during tiling.
    %loop, %tiled = transform.structured.tile_using_forall %max
                    tile_sizes [8, 32]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // We can now fuse the other operations into the loop. Here, we fuse
    // operations one by one. This requires the operation that is being fused
    // to define the value used within the loop, so the order of such fusions
    // is important. We could also use `transform.merge_handles` to obtain
    // a single handle to all operations and give it to
    // `fuse_into_containing_op`, which would take care of the ordering in this
    // case.
    %add_fused = transform.structured.fuse_into_containing_op %add into %loop
        : (!transform.any_op, !transform.any_op) -> !transform.any_op
    %matmul_fused = transform.structured.fuse_into_containing_op %arg1
                    into %loop
        : (!transform.op<"linalg.matmul">, !transform.any_op)
        -> !transform.any_op

    // Tile again to get the desired size. Note that this time this tiles the
    // "add" operation and fuses matmul into the loop, but doesn't affect the
    // "max" operation. This illustrates the precise targeting with the
    // Transform dialect. Otherwise, it is difficult to differentiate "add" and
    // "max", which both have the same kind.
    %loop_2, %tiled_2 = transform.structured.tile_using_forall %add_fused
                        tile_sizes [4, 4]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %matmul_fused_2 = transform.structured.fuse_into_containing_op %matmul_fused
                      into %loop_2
        : (!transform.any_op, !transform.any_op) -> !transform.any_op

    // Since outlining is currently only implemented for region-holding
    // operations such as loops, use tiling to size 1 to materialize the outer
    // loop that is going to be outlined.
    %outline_target, %_ = transform.structured.tile_using_forall %tiled_2 tile_sizes [1]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.structured.fuse_into_containing_op %matmul_fused_2 into %outline_target
        : (!transform.any_op, !transform.any_op) -> !transform.any_op
    %func, %call = transform.loop.outline %outline_target
                   {func_name = "outlined"}
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Rewrite the call target.
    transform.my.change_call_target %call, "microkernel" : !transform.any_op

    transform.yield
  }
}
```

## Appendix: Autogenerated Documentation

[include "Tutorials/transform/MyExtensionCh2.md"]