# Quickstart tutorial to adding MLIR graph rewrites

This document presents a quickstart to adding graph rewrites. We start by
defining an operation, then show multiple ways to define the rewrite using
patterns, as well as defining the rewrite using a graph walker (note: using
patterns and the rewrite engine is preferred; the walker is shown for
demonstration purposes only).

See [MLIR specification](../LangRef.md) for more information about MLIR, the
structure of the IR, operations, etc. See
[Table-driven Operation Definition](../DefiningDialects/Operations.md) and
[Declarative Rewrite Rule](../DeclarativeRewrites.md) for a detailed explanation
of all available mechanisms for defining operations and rewrites in a
table-driven manner.

## Adding operation

An operation in MLIR is specified using a definition in a
[TableGen](https://llvm.org/docs/TableGen/index.html) file. TableGen is a
modeling tool used to specify the ops; the C++ code to interact with these
operations is generated from it. To define an operation, one needs to specify:

* The operation name. This name is a unique identifier of the operation within
  MLIR. Most operations are within a dialect, so for example one could have
  `tfl.add` to represent the add operation in the TensorFlow Lite dialect.
  Instead of repeating the dialect in the op definition, a base class for the
  op dialect is commonly created that prepends the dialect namespace to a
  given op name.
* The traits of the operation. These allow you to specify properties of the
  operation, such as whether it has side effects or whether the verifier
  should check that the operand and result types are the same. These are
  backed by C++ traits that perform the verification.
* The arguments of the operation. These are the input operands (values
  produced at runtime by other ops) and attributes (compile-time constant
  values that affect the behavior of the op) that are the inputs to, or
  define the behavior of, the operation. The input operands may be named;
  the attributes must be named.
* The result(s) of the operation. Again, these may be named or not.
* Documentation of the operation. This includes a one-line summary as well as
  a longer human-readable description of the operation.
* Dialect-specific information. Additional information that is only used by
  dialect-specific drivers can be added to the operation definition. It is
  ignored by the main op and doc generators, but could be used in, say, the
  translation from a dialect to another representation.
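
To make the traits bullet more concrete, here is a deliberately simplified,
standalone model of a verification trait. It loosely mirrors the CRTP shape of
MLIR's C++ op traits (a trait class templated on the concrete op that supplies a
`verifyTrait` hook), but `ToyOp`, `SameOperandsAndResultTypeTrait`, and
`ToyLeakyReluOp` are hypothetical names invented for this sketch, not MLIR API.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for an operation: only the names of its operand and
// result types. MLIR's real Operation class is far richer.
struct ToyOp {
  std::vector<std::string> operandTypes;
  std::vector<std::string> resultTypes;
};

// CRTP trait in the spirit of MLIR's op traits: the trait supplies a
// verifyTrait hook that the concrete op's verifier can call.
template <typename ConcreteOp>
struct SameOperandsAndResultTypeTrait {
  static bool verifyTrait(const ToyOp &op) {
    // Gather all operand and result types, then require them to be identical.
    std::vector<std::string> all(op.operandTypes);
    all.insert(all.end(), op.resultTypes.begin(), op.resultTypes.end());
    for (const std::string &type : all)
      if (type != all.front())
        return false;
    return true;
  }
};

// A concrete op opts in to the check simply by inheriting from the trait.
struct ToyLeakyReluOp : SameOperandsAndResultTypeTrait<ToyLeakyReluOp> {
  static bool verify(const ToyOp &op) { return verifyTrait(op); }
};
```

The point of the design is that the check is written once in the trait and
reused by every op that lists it, which is what listing a trait in the
TableGen definition buys you.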

```tablegen
def TFL_LeakyReluOp: TFL_Op<TFL_Dialect, "leaky_relu",
                            [NoMemoryEffect, SameOperandsAndResultType]>,
                     Results<(outs Tensor)> {
  let arguments = (ins
    F32Tensor:$x,
    // Slope of the activation function at x < 0.
    F32Attr:$alpha
  );

  let summary = "Leaky ReLU operator";
  let description = [{
    Element-wise Leaky ReLU operator
      x -> x >= 0 ? x : (alpha * x)
  }];

  // TFLite specific attribute that is used when generating the output
  // flatbuffer.
  let hasOptions = 1;
}
```

Note that in the above, the results and the arguments are specified in
different ways: the results through a class in the def's parent list
(`Results<...>`) and the arguments through `let`. Either form can be used for
both.

<!-- TODO: Define a style convention. -->

Operations can also have custom parsers, printers, builders, verifiers, constant
folders, or canonicalizers. These require implementing additional C++ methods
that are invoked to provide the extra functionality.
For example, if an operation is marked to have a folder
(`let hasFolder = 1;`), the `fold` method also needs to be implemented, e.g.:

```c++
OpFoldResult SpecificOp::fold(ArrayRef<Attribute> constOperands) {
  // `constOperands` holds the constant value of each operand, or a null
  // attribute for operands that are not known constants.
  if (unable_to_fold)
    return {};  // A null OpFoldResult signals that folding failed.
  // ... compute the folded result `val` (an Attribute or an existing Value) ...
  return val;
}
```

(Recent MLIR versions generate the signature
`OpFoldResult fold(FoldAdaptor adaptor)` instead, where the adaptor exposes the
constant operands through named accessors.)

## Adding patterns

There are multiple forms of graph rewrite that can be performed in MLIR. One of
the most common is DAG tile to DAG tile rewrite. Patterns provide a concise way
to express this transformation as a pair: a source pattern to match and a
resultant pattern. There are C++ classes to represent this transformation, as
well as TableGen patterns from which these classes can be generated.

### TableGen patterns

Let us continue with LeakyRelu. To map from TensorFlow's `LeakyRelu` to
TensorFlow Lite's `LeakyRelu`:

```tablegen
def : Pat<(TF_LeakyReluOp $arg, F32Attr:$a), (TFL_LeakyReluOp $arg, $a)>;
```

The pattern is specified by instantiating a `Pat` with a source and result DAG.
The arguments in the source pattern are captured and can be used in the result
pattern. This is a simple pattern as we have a 1:1 mapping and the attribute
does not need to be transformed (i.e., both ops have a floating point attribute
for alpha).
The names of the attributes specified in the pattern are for
matching/referencing only and need not match the original attribute names in
the op definition, but the order of the DAG arguments does need to match.

To specify a pattern, both the source and resultant ops need to be defined using
TableGen.

If this were a more advanced pattern whose result the current framework could
not express, one could use a general native code fallback. This consists of
defining a pattern as well as adding a C++ function to perform the replacement:

```tablegen
def createTFLLeakyRelu : NativeCodeCall<
    "createTFLLeakyRelu($_builder, $0.getDefiningOp(), $1, $2)">;

def : Pat<(TF_LeakyReluOp:$old_value $arg, F32Attr:$a),
          (createTFLLeakyRelu $old_value, $arg, $a)>;
```

```c++
// Called from the NativeCodeCall above: `$_builder` is the PatternRewriter and
// $0/$1/$2 are the matched op's result, its operand, and its alpha attribute.
static Value createTFLLeakyRelu(PatternRewriter &rewriter, Operation *op,
                                Value operand, Attribute attr) {
  return rewriter.create<mlir::TFL::LeakyReluOp>(
      op->getLoc(), operand.getType(), /*arg=*/operand,
      /*alpha=*/llvm::cast<FloatAttr>(attr));
}
```

This allows for arbitrarily complex builders. On the source pattern side, one
can express multi-op patterns with constraints on input operands and
attributes.
Constraints that span multiple operands or attributes cannot be attached to a
single argument in the source pattern; instead, they can be supplied as the
optional third argument of `Pat`, a list of additional constraints (see
[Declarative Rewrite Rule](../DeclarativeRewrites.md)).

### Register the pattern

The file containing the patterns needs to be processed with `mlir-tblgen`
`-gen-rewriters` at build time. It can be invoked with the following
configuration in CMake:

```cmake
set(LLVM_TARGET_DEFINITIONS <name-of-the-td-file>)
mlir_tablegen(<name-of-the-generated-inc-file> -gen-rewriters)
add_public_tablegen_target(<name-of-the-cmake-target>)
```

Then you can `#include` the generated file in any C++ implementation file you
like. (You will also need to make sure the library depends on the CMake target
defined above.) The generated file will have a
`populateWithGenerated(RewritePatternSet &patterns)` function that you can use
to collect all the generated patterns inside `patterns` and then use `patterns`
in any pass you would like.

### Simple C++ `matchAndRewrite` style specifications

Many simple rewrites can be expressed with a `matchAndRewrite` style of
pattern, e.g. when converting a multiply by a power of two into a shift.
For these cases, you can define the pattern as a simple function:

```c++
static LogicalResult
convertTFLeakyRelu(TF::LeakyReluOp op, PatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
      op, op->getResult(0).getType(), op->getOperand(0),
      /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
  return success();
}

void populateRewrites(RewritePatternSet &patternSet) {
  // Add the function-style pattern to a pattern set.
  patternSet.add(convertTFLeakyRelu);
}
```

ODS provides a simple way to define a function-style canonicalization for your
operation. In the TableGen definition of the op, specify
`let hasCanonicalizeMethod = 1;` and then implement the `canonicalize` method in
your .cpp file:

```c++
// Example from the CIRCT project which has a variadic integer multiply.
LogicalResult circt::MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
  auto inputs = op.inputs();
  APInt value;

  // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
  if (inputs.size() == 2 && matchPattern(inputs.back(), m_RConstant(value)) &&
      value.isPowerOf2()) {
    auto shift = rewriter.create<rtl::ConstantOp>(op.getLoc(), op.getType(),
                                                  value.exactLogBase2());
    auto shlOp =
        rewriter.create<comb::ShlOp>(op.getLoc(), inputs[0], shift);
    rewriter.replaceOpWithNewOp<MulOp>(op, op.getType(),
                                       ArrayRef<Value>(shlOp));
    return success();
  }

  return failure();
}
```

However, you may want the full generality of canonicalization patterns; for
that, you can specify an arbitrary list of `RewritePattern`s (with
`let hasCanonicalizer = 1;` and an implementation of
`getCanonicalizationPatterns`).

### Fully general C++ `RewritePattern` specifications

In case ODS patterns and `matchAndRewrite`-style functions are not sufficient,
you can also specify rewrites as a general set of `RewritePattern`s:

```c++
/// Multi-step rewrite using "match" and "rewrite". This allows for separating
/// the concerns of matching and rewriting.
struct ConvertTFLeakyRelu : public RewritePattern {
  ConvertTFLeakyRelu(MLIRContext *context)
      : RewritePattern("tf.LeakyRelu", /*benefit=*/1, context) {}

  LogicalResult match(Operation *op) const override {
    return success();
  }

  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
        op, op->getResult(0).getType(), op->getOperand(0),
        /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
  }
};

/// Single-step rewrite with "matchAndRewrite". This allows for performing the
/// rewrite immediately upon a successful match. (Named differently from the
/// struct above so that both can live in the same file.)
struct ConvertTFLeakyReluSingleStep : public RewritePattern {
  ConvertTFLeakyReluSingleStep(MLIRContext *context)
      : RewritePattern("tf.LeakyRelu", /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
        op, op->getResult(0).getType(), op->getOperand(0),
        /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
    return success();
  }
};
```

In the C++ rewrites, the static benefit of the rewrite pattern is specified at
construction.
For patterns produced by the pattern generator, the benefit is instead computed
by a simple heuristic based on the number of ops matched and replaced.

The rule above does not need to capture any operands or attributes while
matching. When a pattern does, note that in the multi-step form `match` only
reports success or failure, so anything computed while matching has to be
recomputed in `rewrite`. A single-step rewrite with `matchAndRewrite` avoids
this: it can directly use any values computed while matching.

## Testing

MLIR uses the [lit](https://llvm.org/docs/CommandGuide/lit.html) (LLVM
Integrated Tester) tool for testing. Testing is performed by creating an input
IR file, running a transformation on it, and then verifying the output IR. C++
unit tests are the exception rather than the rule; IR transformations are the
core testing mechanism. This results in fewer binaries that need to be built
(and linked) and forces a focus on the IR representation as an important piece.

For the legalization transform above, we would have a test (probably as part of
the legalization pass test in TensorFlow Lite) such as:

```mlir
// RUN: mlir-opt -tfl-legalize-tf %s | FileCheck %s

func.func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
  %2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
  return %2 : tensor<1xf32>

// CHECK-LABEL: @LeakyRelu
// CHECK: %{{.*}} = "tfl.leaky_relu"(%arg0) {alpha = 1.000000e-01 : f32} : (tensor<1xf32>) -> tensor<1xf32>
}
```

The RUN command at the top runs the `mlir-opt` binary (a compiler-writer tool
for exercising the registered passes) on the current file, invoking the pass
this transform was added to, and verifies its output using `FileCheck`.
`FileCheck` is a textual output verifier: it uses the CHECK directives to verify
that the expected output is produced.

There can be multiple RUN commands with different corresponding CHECK prefixes.
In addition, a single file can hold multiple independent tests separated by
`// -----` when `mlir-opt` is invoked with the `-split-input-file` flag. This is
especially useful for error testing.

This results in very simple, directed testing without the need to work around
constant propagation or other unrelated optimization passes.
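
To illustrate what splitting on `// -----` means, here is a minimal, standalone
sketch of carving one test file into independent chunks. It is not
`mlir-opt`'s implementation; `splitInputFile` is a hypothetical helper, and it
assumes a separator is any line whose content, ignoring surrounding spaces and
tabs, is exactly `// -----`.

```cpp
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper: split a test buffer into chunks at "// -----" lines.
// Each chunk would then be processed as if it were its own input file.
std::vector<std::string> splitInputFile(const std::string &buffer) {
  std::vector<std::string> chunks(1);
  std::istringstream in(buffer);
  std::string line;
  while (std::getline(in, line)) {
    // Trim surrounding whitespace before comparing against the marker.
    std::string::size_type begin = line.find_first_not_of(" \t");
    std::string::size_type end = line.find_last_not_of(" \t");
    std::string trimmed = begin == std::string::npos
                              ? std::string()
                              : line.substr(begin, end - begin + 1);
    if (trimmed == "// -----") {
      chunks.emplace_back();  // Start a new, independent chunk.
      continue;
    }
    chunks.back() += line + "\n";
  }
  return chunks;
}
```

Because each chunk is handled separately, an expected error in one test case
does not stop the remaining cases in the file from being checked.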

## Adding optimization pass

Optimization passes that do not fit, or are difficult to specify in, the above
structure can be specified as general iterations across modules/functions. See
[Writing a Pass](../PassManagement.md) for a general overview and introduction to
optimization passes in MLIR.
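
As a rough picture of the walker-based alternative mentioned in the
introduction, the sketch below performs the same LeakyRelu legalization by
iterating over every op of a function. It uses a toy IR defined inline;
`ToyOperation`, `ToyFunction`, `walk`, and `legalizeLeakyRelu` are hypothetical
names for illustration only, not MLIR's `Operation` or pass APIs.

```cpp
#include <functional>
#include <string>
#include <vector>

// Toy IR (hypothetical, not MLIR's API): a function is a flat list of ops,
// each with a name, operand names, and a float "alpha" attribute.
struct ToyOperation {
  std::string name;
  std::vector<std::string> operands;
  float alpha = 0.0f;
};

struct ToyFunction {
  std::vector<ToyOperation> body;

  // Visit every op in program order, loosely mirroring a walk over the ops
  // nested in a function.
  void walk(const std::function<void(ToyOperation &)> &callback) {
    for (ToyOperation &op : body)
      callback(op);
  }
};

// Walker-based legalization: rewrite every tf.LeakyRelu in place into
// tfl.leaky_relu, keeping its operand and alpha attribute unchanged.
// Returns the number of ops rewritten.
int legalizeLeakyRelu(ToyFunction &func) {
  int rewritten = 0;
  func.walk([&](ToyOperation &op) {
    if (op.name != "tf.LeakyRelu")
      return;
    op.name = "tfl.leaky_relu";
    ++rewritten;
  });
  return rewritten;
}
```

Here the pass itself owns the traversal and the matching logic, which the
pattern-based approaches above leave to the rewrite driver.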