# Quickstart tutorial to adding MLIR graph rewrite

This document will present a quickstart to adding graph rewrites. We shall start
by defining an operation, showing multiple ways to define the rewrite using
patterns, as well as defining the rewrite using a graph walker (note: using
patterns and the rewrite engine is preferred; showing the walker is for
demonstration purposes).

See [MLIR specification](../LangRef.md) for more information about MLIR, the
structure of the IR, operations, etc. See
[Table-driven Operation Definition](../DefiningDialects/Operations.md) and
[Declarative Rewrite Rule](../DeclarativeRewrites.md) for the detailed
explanation of all available mechanisms for defining operations and rewrites in
a table-driven manner.

## Adding operation

An operation in MLIR is specified using a definition in a
[TableGen](https://llvm.org/docs/TableGen/index.html) file. TableGen is a
modeling tool used to specify the ops, and from which the C++ code to interact
with these operations is generated. To define an operation one needs to
specify:

*   The operation name. This name is a unique identifier of the operation
    within MLIR. Most operations are within a dialect, so for example one could
    have `tfl.add` to represent the add operation in the TensorFlow Lite
    dialect. Instead of repeating the dialect in the op definition, a base
    class for the op dialect is commonly created that prepends the dialect
    namespace given an op name.
*   The traits of the operation. These specify properties of the operation,
    such as whether it has side effects or whether it should be verified that
    the operand and result types are the same. These are backed by C++ traits
    that perform the verification.
*   The arguments of the operation. These are the input operands (values at
    runtime produced by other ops) and attributes (compile-time-known constant
    values that affect the behavior of the op) that are the inputs of/define
    the behavior of the operation. The input operands may be named; the
    attributes must be named.
*   The result(s) of the operation. These may again be named or not.
*   Documentation of the operation. This includes a one-line summary as well
    as a longer, human-readable description of the operation.
*   Dialect-specific information. Additional information can be added to the
    operation definition that is only used by dialect-specific drivers. This
    is ignored by the main op and doc generators, but could be used in, say,
    the translation from a dialect to another representation.

```tablegen
def TFL_LeakyReluOp: TFL_Op<TFL_Dialect, "leaky_relu",
                            [NoMemoryEffect, SameValueType]>,
                     Results<(outs Tensor)> {
  let arguments = (ins
    F32Tensor:$x,
    // Slope of the activation function at x < 0.
    F32Attr:$alpha
  );

  let summary = "Leaky ReLU operator";
  let description = [{
    Element-wise Leaky ReLU operator
      x -> x >= 0 ? x : (alpha * x)
  }];

  // TFLite specific attribute that is used when generating the output
  // flatbuffer.
  let hasOptions = 1;
}
```

Note in the above that the result types and inputs are specified in different
ways, one by way of trait and the other by way of `let`. It is possible to
specify both in either way.

<!-- TODO: Define a style convention. -->
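For a sense of how the generated C++ class is used, here is a minimal sketch of
creating this op from C++. The helper function and its arguments are
illustrative only, and the exact generated builder signature depends on the op
definition:

```c++
// Minimal sketch (hypothetical helper): creating a tfl.leaky_relu via the
// ODS-generated C++ class. The generated builder here is assumed to take the
// result type, the `x` operand, and the `alpha` attribute, matching the
// definition above.
Value buildLeakyRelu(OpBuilder &builder, Location loc, Value input) {
  FloatAttr alpha = builder.getF32FloatAttr(0.2f);
  auto op = builder.create<mlir::TFL::LeakyReluOp>(
      loc, input.getType(), /*x=*/input, /*alpha=*/alpha);
  return op.getResult();
}
```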
Operations can also have custom parser, printer, builder, verifier, constant
folder, or canonicalizer. These require specifying additional C++ methods to
invoke for additional functionality. For example, if an operation is marked to
have a folder, the constant folder also needs to be added, e.g.:

```c++
OpFoldResult SpecificOp::fold(ArrayRef<Attribute> constOperands) {
  // Returning an empty OpFoldResult signals that the op cannot be folded.
  if (unableToFold)
    return {};
  // ... compute the folded value `val` from `constOperands` ...
  return val;
}
```

## Adding patterns

There are multiple forms of graph rewrite that can be performed in MLIR. One of
the most common is DAG tile to DAG tile rewrite. Patterns provide a concise way
to express this transformation as a pair of a source pattern to match and a
resultant pattern. There are both C++ classes to represent this transformation,
as well as patterns in TableGen from which these can be generated.

### TableGen patterns

Let us continue with LeakyRelu. To map from TensorFlow's `LeakyRelu` to
TensorFlow Lite's `LeakyRelu`:

```tablegen
def : Pat<(TF_LeakyReluOp $arg, F32Attr:$a), (TFL_LeakyReluOp $arg, $a)>
```

The pattern is specified by instantiating a `Pat` with a source and result DAG.
The arguments in the source pattern are captured and can be used in the result
pattern. This is a simple pattern as we have a 1:1 mapping and the attribute
does not need to be transformed (e.g., both have a floating point attribute for
alpha). The names of the attributes specified in the pattern are for
matching/referencing and need not match the original attribute names in the op
definition, but the order of arguments of the dags does need to match.

To specify a pattern, both the source and resultant ops need to be defined
using TableGen.

If this were a more advanced pattern that the current framework could not
express as a destination pattern, then one could use a general native code
fallback method. This consists of defining a pattern as well as adding a C++
function to perform the replacement:

```tablegen
def createTFLLeakyRelu : NativeCodeCall<
    "createTFLLeakyRelu($_builder, $0.getDefiningOp(), $1, $2)">;

def : Pat<(TF_LeakyReluOp:$old_value, $arg, F32Attr:$a),
          (createTFLLeakyRelu $old_value, $arg, $a)>;
```

```c++
static Value createTFLLeakyRelu(PatternRewriter &rewriter, Operation *op,
                                Value operand, Attribute attr) {
  return rewriter.create<mlir::TFL::LeakyReluOp>(
      op->getLoc(), operand.getType(), /*arg=*/operand,
      /*alpha=*/attr.cast<FloatAttr>());
}
```

This allows for arbitrarily complex builders. On the input pattern side, one
can express multi-op patterns with constraints on input operands and
attributes. But input patterns cannot yet express constraints across multiple
operands/attributes.

### Register the pattern

The file containing the patterns needs to be processed using `mlir-tblgen` with
`-gen-rewriters` at compile time. It can be invoked with the following
configuration in CMake:

```cmake
set(LLVM_TARGET_DEFINITIONS <name-of-the-td-file>)
mlir_tablegen(<name-of-the-generated-inc-file> -gen-rewriters)
add_public_tablegen_target(<name-of-the-cmake-target>)
```

Then you can `#include` the generated file in any C++ implementation file you
like. (You will also need to make sure the library depends on the CMake target
defined above.) The generated file will have a
`populateWithGenerated(RewritePatternSet &patterns)` function that you can use
to collect all the generated patterns inside `patterns` and then use `patterns`
in any pass you would like.
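As a sketch of that last step, a pass might pull in the generated patterns as
follows (the `.inc` file name and the `LegalizeTF` pass class are hypothetical
stand-ins for whatever your build defines):

```c++
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical name for the file generated by -gen-rewriters above.
#include "TFLLegalizePatterns.inc"

void LegalizeTF::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  // Collect all patterns generated from the TableGen file into `patterns`.
  populateWithGenerated(patterns);
  // Apply the collected patterns greedily to the current operation.
  if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                          std::move(patterns))))
    signalPassFailure();
}
```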
### Simple C++ `matchAndRewrite` style specifications

Many simple rewrites can be expressed with a `matchAndRewrite` style of
pattern, e.g. when converting a multiply by a power of two into a shift. For
these cases, you can define the pattern as a simple function:

```c++
static LogicalResult
convertTFLeakyRelu(TFLeakyReluOp op, PatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
      op, op->getResult(0).getType(), op->getOperand(0),
      /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
  return success();
}

void populateRewrites(RewritePatternSet &patternSet) {
  // Add it to a pattern set.
  patternSet.add(convertTFLeakyRelu);
}
```

ODS provides a simple way to define a function-style canonicalization for your
operation. In the TableGen definition of the op, specify
`let hasCanonicalizeMethod = 1;` and then implement the `canonicalize` method
in your .cpp file:

```c++
// Example from the CIRCT project which has a variadic integer multiply.
LogicalResult circt::MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
  auto inputs = op.inputs();
  APInt value;

  // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
  if (inputs.size() == 2 && matchPattern(inputs.back(), m_RConstant(value)) &&
      value.isPowerOf2()) {
    auto shift = rewriter.create<rtl::ConstantOp>(op.getLoc(), op.getType(),
                                                  value.exactLogBase2());
    auto shlOp = rewriter.create<comb::ShlOp>(op.getLoc(), inputs[0], shift);
    rewriter.replaceOpWithNewOp<MulOp>(op, op.getType(),
                                       ArrayRef<Value>(shlOp));
    return success();
  }

  return failure();
}
```

However, you may want the full generality of canonicalization patterns; for
that, you can specify an arbitrary list of `RewritePattern`s.

### Fully general C++ `RewritePattern` specifications

In case ODS patterns and `matchAndRewrite`-style functions are not sufficient,
you can also specify rewrites as a general set of `RewritePattern`s:

```c++
/// Multi-step rewrite using "match" and "rewrite". This allows for separating
/// the concerns of matching and rewriting.
struct ConvertTFLeakyRelu : public RewritePattern {
  ConvertTFLeakyRelu(MLIRContext *context)
      : RewritePattern("tf.LeakyRelu", 1, context) {}

  LogicalResult match(Operation *op) const override { return success(); }

  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
        op, op->getResult(0).getType(), op->getOperand(0),
        /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
  }
};

/// Single-step rewrite with "matchAndRewrite". This allows for performing the
/// rewrite immediately upon a successful match.
struct ConvertTFLeakyRelu : public RewritePattern {
  ConvertTFLeakyRelu(MLIRContext *context)
      : RewritePattern("tf.LeakyRelu", 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
        op, op->getResult(0).getType(), op->getOperand(0),
        /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
    return success();
  }
};
```

In the C++ rewrite, the static benefit of the rewrite pattern is specified at
construction, while in the pattern generator a simple heuristic is currently
employed, based around the number of ops matched and replaced.
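Such patterns are registered the same way as the function-style patterns
earlier; for example, a sketch using the class just defined (the populate
function name is illustrative):

```c++
void populateTFLeakyReluPatterns(MLIRContext *context,
                                 RewritePatternSet &patterns) {
  // Arguments after the pattern type are forwarded to the pattern's
  // constructor. The benefit passed at construction (1 above) lets the
  // driver prefer higher-benefit patterns when several match.
  patterns.add<ConvertTFLeakyRelu>(context);
}
```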
The above rule did not capture the matching operands/attributes, but in general
the `match` function in a multi-step rewrite may populate and return a
`PatternState` (or a class derived from one) to pass information extracted
during matching to the rewrite. A single-step rewrite with the
`matchAndRewrite` function has the benefit of being able to directly use any
values created when matching, removing the need for a `PatternState`.

## Testing

MLIR uses the [lit](https://llvm.org/docs/CommandGuide/lit.html) (LLVM
Integrated Testing) tool for performing testing. Testing is performed by way of
creating an input IR file, running a transformation, and then verifying the
output IR. C++ unit tests are the exception, with the IR transformation serving
as the core testing mechanism. This results in fewer binaries that need to be
built (and linked) and forces a focus on the representation as an important
piece.

For the legalization transform above we would have a test (probably as part of
the legalization pass test in TensorFlow Lite) such as:

```mlir
// RUN: mlir-opt -tfl-legalize-tf %s | FileCheck %s

func.func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
  %2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
  return %2 : tensor<1xf32>

// CHECK-LABEL: LeakyRelu
// CHECK: %0 = "tfl.leaky_relu"(%arg0) {alpha = 1.000000e-01 : f32} : (tensor<1xf32>) -> tensor<1xf32>
}
```

The RUN command at the top results in running the `mlir-opt` binary (which is a
compiler-writer tool to exercise different registered passes) to invoke the
optimization pass this transform was added to on the current file, and to
verify its output using `FileCheck`. `FileCheck` is a textual output verifier.
In particular, it uses the CHECK expressions to verify that the given output is
produced.

There can be multiple RUN commands with different corresponding CHECK prefixes.
In addition, there can be multiple independent tests separated by `// -----`,
with `mlir-opt` invoked with the `-split-input-file` flag. This is especially
useful for error testing.

This results in very simple, directed testing without the need to work around
constant propagation or other, unrelated, optimization passes.

## Adding optimization pass

Optimization passes that do not fit/are difficult to specify in the above
structure can be specified as general iterations across modules/functions. See
[Writing a Pass](../PassManagement.md) for a general overview and introduction
to optimization passes in MLIR.
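As mentioned in the introduction, a rewrite can also be phrased as a plain walk
over the IR instead of going through the pattern driver (shown purely for
demonstration; patterns and the rewrite engine are preferred). A minimal
sketch, assuming the op classes from the earlier examples, the usual
single-operand/single-result generated accessors, and a generated builder
taking the result type, operand, and alpha attribute:

```c++
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"

struct LegalizeTFLeakyReluPass
    : public PassWrapper<LegalizeTFLeakyReluPass,
                         OperationPass<func::FuncOp>> {
  void runOnOperation() override {
    // Default walks are post-order, so it is safe to erase the matched op
    // from within the callback.
    getOperation().walk([](TF::LeakyReluOp op) {
      OpBuilder builder(op);
      auto newOp = builder.create<TFL::LeakyReluOp>(
          op.getLoc(), op.getType(), op.getOperand(),
          /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
      op.getResult().replaceAllUsesWith(newOp.getResult());
      op.erase();
    });
  }
};
```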