# Quickstart tutorial to adding MLIR graph rewrite

This document presents a quickstart to adding graph rewrites. We start by
defining an operation, then show multiple ways to define the rewrite using
patterns, as well as defining the rewrite using a graph walker (note: using
patterns and the rewrite engine is preferred; the walker is shown for
demonstration purposes only).

See [MLIR specification](../LangRef.md) for more information about MLIR, the
structure of the IR, operations, etc. See
[Table-driven Operation Definition](../DefiningDialects/Operations.md) and
[Declarative Rewrite Rule](../DeclarativeRewrites.md) for a detailed explanation
of all available mechanisms for defining operations and rewrites in a
table-driven manner.

## Adding operation

An operation in MLIR is specified using a definition in a
[TableGen](https://llvm.org/docs/TableGen/index.html) file. TableGen is a
modeling tool in which the ops are specified and from which the C++ code to
interact with these operations is generated. To define an operation one needs
to specify:

*   The operation name. This name is a unique identifier of the operation within
    MLIR. Most operations are within a dialect, so for example one could have
    `tfl.add` to represent the add operation in the TensorFlow Lite dialect.
    Instead of repeating the dialect in the op definition, a base class for the
    op dialect is commonly created that prepends the dialect namespace given an
    op name.
*   The traits of the operation. These allow you to specify traits of the
    operation, such as whether it has side effects or whether it should be
    verified that the operands and result types are the same. These are backed
    by C++ traits that perform the verification.
*   The arguments of the operation. These are the input operands (values at
    runtime produced by other ops) and attributes (compile-time constant values
    that affect the behavior of the op) that are the inputs to, and define the
    behavior of, the operation. The input operands may be named; the attributes
    must be named.
*   The result(s) of the operation. These may again be named or not.
*   Documentation of the operation. This includes a one-line summary as well as
    a longer human-readable description of the operation.
*   Dialect-specific information. Additional information could be added to the
    operation definition that is only used by dialect-specific drivers. This
    information is ignored by the main op and doc generators, but could be used
    in, say, the translation from a dialect to another representation.

```tablegen
def TFL_LeakyReluOp: TFL_Op<TFL_Dialect, "leaky_relu",
                            [NoMemoryEffect, SameValueType]>,
                     Results<(outs Tensor)> {
  let arguments = (ins
    F32Tensor:$x,
    // Slope of the activation function at x < 0.
    F32Attr:$alpha
  );

  let summary = "Leaky ReLU operator";
  let description = [{
    Element-wise Leaky ReLU operator
      x -> x >= 0 ? x : (alpha * x)
  }];

  // TFLite specific attribute that is used when generating the output
  // flatbuffer.
  let hasOptions = 1;
}
```
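The element-wise semantics stated in the `description` can be written down as a small standalone C++ reference function, which is handy when checking a lowering by hand. This is only a sketch of the math, assuming plain `float` values held in a `std::vector`; `leakyRelu` is a hypothetical helper, not something generated from the op definition or provided by MLIR or TensorFlow Lite.

```cpp
#include <vector>

// Reference semantics of the op description:
//   x -> x >= 0 ? x : (alpha * x)
// `leakyRelu` is a hypothetical helper for illustration only.
std::vector<float> leakyRelu(const std::vector<float> &input, float alpha) {
  std::vector<float> output;
  output.reserve(input.size());
  for (float x : input)
    output.push_back(x >= 0.0f ? x : alpha * x);
  return output;
}
```

For example, with `alpha = 0.5` the input `{-2, 0, 3}` maps to `{-1, 0, 3}`.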

Note that in the above, the results and the inputs are specified in different
ways: the results through `Results<...>` in the op's parent class list and the
inputs through `let arguments`. Either way can be used for both.

<!-- TODO: Define a style convention. -->

Operations can also have a custom parser, printer, builder, verifier, constant
folder, or canonicalizer. These require specifying additional C++ methods to
invoke for the additional functionality. For example, if an operation is marked
to have a folder (`let hasFolder = 1;`), the folder method also needs to be
implemented, e.g.:

```c++
// Recent MLIR versions declare this as `fold(FoldAdaptor adaptor)` instead.
OpFoldResult SpecificOp::fold(ArrayRef<Attribute> constOperands) {
  // constOperands[i] holds the constant value of operand i, or a null
  // Attribute if that operand is not a known constant.
  if (unable_to_fold) // placeholder for the op-specific check
    return {};        // an empty result means "not folded"
  ....                // compute `val` from the operands
  return val;         // an Attribute (constant) or an existing Value
}
```
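The folder contract (return an empty result when folding is not possible) can be illustrated without any MLIR headers. In this toy model, which is not MLIR's API, `std::nullopt` in `constOperands` plays the role of a null `Attribute` (operand not known to be constant), and returning `std::nullopt` plays the role of `return {};`. The folded op, a two-operand integer add behind a hypothetical `foldAdd` function, is an assumption made purely for illustration.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Toy model of a constant folder for a hypothetical two-operand integer add.
// - An entry of `constOperands` is std::nullopt when that operand is not a
//   known constant (MLIR would pass a null Attribute).
// - Returning std::nullopt means "unable to fold" (MLIR: `return {};`).
// Overflow handling is omitted in this sketch.
std::optional<int64_t>
foldAdd(const std::vector<std::optional<int64_t>> &constOperands) {
  if (constOperands.size() != 2 || !constOperands[0] || !constOperands[1])
    return std::nullopt; // not foldable: the op is left unchanged
  return *constOperands[0] + *constOperands[1];
}
```

Folding only succeeds when every operand it needs is a known constant; otherwise the caller keeps the original op.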

## Adding patterns

There are multiple forms of graph rewrite that can be performed in MLIR. One of
the most common is DAG tile to DAG tile rewrite. Patterns provide a concise way
to express this transformation as a pair of a source pattern to match and a
resultant pattern. MLIR provides both C++ classes to represent this
transformation and TableGen patterns from which these classes can be generated.

### TableGen patterns

Let us continue with LeakyRelu. To map from TensorFlow's `LeakyRelu` to
TensorFlow Lite's `LeakyRelu`:

```tablegen
def : Pat<(TF_LeakyReluOp $arg, F32Attr:$a), (TFL_LeakyReluOp $arg, $a)>;
```

The pattern is specified by instantiating a `Pat` with a source and result DAG.
The arguments in the source pattern are captured and can be used in the result
pattern. This is a simple pattern as we have a 1:1 mapping and the attribute
does not need to be transformed (e.g., both have a floating point attribute for
alpha). The names of the attributes specified in the pattern are only for
matching/referencing and need not match the original attribute names in the op
definition, but the order of the arguments in each DAG does need to match the
op definition.
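Conceptually, such a pattern means "match the source op, bind its arguments by position, and build the result op from the bound values". The sketch below models that outside MLIR with a hypothetical `ToyOp` struct and a hypothetical `rewriteLeakyRelu` function. It is not the code `mlir-tblgen` generates; it only illustrates that `$arg` and `$a` are bound by position, which is why their names in the pattern are free to differ from the op definition.

```cpp
#include <optional>
#include <string>
#include <vector>

// Toy stand-in for an operation: a name, operands (by SSA name) and a single
// float attribute. This is NOT MLIR's Operation class.
struct ToyOp {
  std::string name;
  std::vector<std::string> operands;
  float alpha;
};

// Roughly what (TF_LeakyReluOp $arg, F32Attr:$a) -> (TFL_LeakyReluOp $arg, $a)
// expresses: match the source op, bind its arguments by position, and build
// the result op from the bound values in the same positions.
std::optional<ToyOp> rewriteLeakyRelu(const ToyOp &op) {
  if (op.name != "tf.LeakyRelu" || op.operands.size() != 1)
    return std::nullopt; // the source pattern does not match
  const std::string &arg = op.operands[0]; // bound as $arg (position 0)
  float a = op.alpha;                      // bound as $a (first attribute)
  return ToyOp{"tfl.leaky_relu", {arg}, a};
}
```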

To specify a pattern, both the source and resultant ops need to be defined using
TableGen.

If this were a more advanced pattern that the current framework could not
express as the result pattern, one could use a general native code fallback
method. This consists of defining a pattern as well as adding a C++ function to
perform the replacement:

```tablegen
def createTFLLeakyRelu : NativeCodeCall<
    "createTFLLeakyRelu($_builder, $0.getDefiningOp(), $1, $2)">;

def : Pat<(TF_LeakyReluOp:$old_value $arg, F32Attr:$a),
          (createTFLLeakyRelu $old_value, $arg, $a)>;
```

```c++
// $_builder expands to the PatternRewriter; $0, $1 and $2 are the values bound
// to $old_value, $arg and $a, in that order.
static Value createTFLLeakyRelu(PatternRewriter &rewriter, Operation *op,
                                Value operand, Attribute attr) {
  return rewriter.create<mlir::TFL::LeakyReluOp>(
      op->getLoc(), operand.getType(), /*arg=*/operand,
      /*alpha=*/attr.cast<FloatAttr>());
}
```

This allows for arbitrarily complex builders. On the source pattern side, one
can express multi-op patterns with constraints on input operands and
attributes. Constraints that span multiple operands/attributes cannot be
attached to a single argument; they are instead supplied as additional
constraints of the pattern (see
[Declarative Rewrite Rule](../DeclarativeRewrites.md)).

### Register the pattern

The file containing the patterns needs to be processed using `mlir-tblgen`
`-gen-rewriters` at build time. It can be invoked with the following
configuration in CMake:

```cmake
set(LLVM_TARGET_DEFINITIONS <name-of-the-td-file>)
mlir_tablegen(<name-of-the-generated-inc-file> -gen-rewriters)
add_public_tablegen_target(<name-of-the-cmake-target>)
```

Then you can `#include` the generated file in any C++ implementation file you
like. (You will also need to make sure the library depends on the CMake target
defined above.) The generated file will have a
`populateWithGenerated(RewritePatternSet &patterns)` function that you can use
to collect all the generated patterns inside `patterns` and then use `patterns`
in any pass you would like.

### Simple C++ `matchAndRewrite` style specifications

Many simple rewrites can be expressed with a `matchAndRewrite` style of
pattern, e.g. when converting a multiply by a power of two into a shift. For
these cases, you can define the pattern as a simple function:

```c++
static LogicalResult
convertTFLeakyRelu(TFLeakyReluOp op, PatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
      op, op->getResult(0).getType(), op->getOperand(0),
      /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
  return success();
}

void populateRewrites(RewritePatternSet &patternSet) {
  // Add it to a pattern set.
  patternSet.add(convertTFLeakyRelu);
}
```

ODS provides a simple way to define a function-style canonicalization for your
operation. In the TableGen definition of the op, specify
`let hasCanonicalizeMethod = 1;` and then implement the `canonicalize` method in
your .cpp file:

```c++
// Example from the CIRCT project which has a variadic integer multiply.
LogicalResult circt::MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
  auto inputs = op.inputs();
  APInt value;

  // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
  if (inputs.size() == 2 && matchPattern(inputs.back(), m_RConstant(value)) &&
      value.isPowerOf2()) {
    auto shift = rewriter.create<rtl::ConstantOp>(op.getLoc(), op.getType(),
                                                  value.exactLogBase2());
    auto shlOp =
        rewriter.create<comb::ShlOp>(op.getLoc(), inputs[0], shift);
    rewriter.replaceOpWithNewOp<MulOp>(op, op.getType(),
                                       ArrayRef<Value>(shlOp));
    return success();
  }

  // Not a multiply by a power of two: leave the op unchanged.
  return failure();
}
```
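The `mul(x, c) -> shl(x, log2(c))` rewrite above relies on `APInt::isPowerOf2` and `APInt::exactLogBase2`. The underlying arithmetic can be checked in isolation with plain `uint64_t`; the sketch below is a hypothetical stand-in (the names `exactLog2` and `mulOrShift` are made up for illustration), not CIRCT or LLVM code. For unsigned values the shift agrees with the multiply even when the product wraps modulo 2^64.

```cpp
#include <cstdint>
#include <optional>

// Returns log2(c) if c is a power of two, i.e. the shift amount that makes
// `x * c` equal to `x << log2(c)` for unsigned (wrapping) arithmetic.
std::optional<unsigned> exactLog2(uint64_t c) {
  if (c == 0 || (c & (c - 1)) != 0)
    return std::nullopt; // zero or not a power of two: pattern does not apply
  unsigned shift = 0;
  while ((c >>= 1) != 0)
    ++shift;
  return shift;
}

// Applies the rewrite mul(x, c) -> shl(x, log2(c)) when it is legal, falling
// back to the original multiply otherwise.
uint64_t mulOrShift(uint64_t x, uint64_t c) {
  if (std::optional<unsigned> shift = exactLog2(c))
    return x << *shift;
  return x * c;
}
```

As in the canonicalizer, the pattern only fires for a power-of-two constant; any other constant keeps the multiply.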

If you need the full generality of canonicalization patterns, you can instead
specify an arbitrary list of `RewritePattern`s (via `let hasCanonicalizer = 1;`
and a `getCanonicalizationPatterns` implementation).

### Fully general C++ `RewritePattern` specifications

In case ODS patterns and `matchAndRewrite`-style functions are not sufficient,
you can also specify rewrites as a general set of `RewritePattern`s:

```c++
/// Multi-step rewrite using "match" and "rewrite". This allows for separating
/// the concerns of matching and rewriting.
struct ConvertTFLeakyReluMultiStep : public RewritePattern {
  ConvertTFLeakyReluMultiStep(MLIRContext *context)
      : RewritePattern("tf.LeakyRelu", /*benefit=*/1, context) {}

  // The root op name already restricts this pattern to "tf.LeakyRelu", so
  // there is nothing else to check.
  LogicalResult match(Operation *op) const override {
    return success();
  }

  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
        op, op->getResult(0).getType(), op->getOperand(0),
        /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
  }
};

/// Single-step rewrite with "matchAndRewrite". This allows for performing the
/// rewrite immediately upon a successful match.
struct ConvertTFLeakyReluSingleStep : public RewritePattern {
  ConvertTFLeakyReluSingleStep(MLIRContext *context)
      : RewritePattern("tf.LeakyRelu", /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
        op, op->getResult(0).getType(), op->getOperand(0),
        /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
    return success();
  }
};
```

In the C++ rewrite, the static benefit of the rewrite pattern is specified at
construction (the `/*benefit=*/1` argument above). In the TableGen pattern
generator, a simple heuristic based on the number of ops matched and replaced
is currently employed instead.
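The role of the static benefit can be sketched with a toy driver that tries patterns in decreasing order of benefit and stops at the first one that succeeds. This is a simplified model under stated assumptions (ops are plain strings, patterns are `std::function`s); `ToyPattern` and `applyBestPattern` are hypothetical names, and MLIR's real drivers do considerably more, for example the greedy driver keeps applying patterns until a fixed point is reached.

```cpp
#include <algorithm>
#include <functional>
#include <string>
#include <vector>

// Toy pattern: a static benefit plus a function that either rewrites the
// (string-typed) op in place and returns true, or leaves it alone.
struct ToyPattern {
  unsigned benefit;
  std::function<bool(std::string &)> matchAndRewrite;
};

// Tries patterns from highest to lowest benefit and stops at the first one
// that succeeds. The vector is taken by value so sorting does not affect the
// caller; std::stable_sort keeps insertion order among equal benefits.
bool applyBestPattern(std::vector<ToyPattern> patterns, std::string &op) {
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const ToyPattern &a, const ToyPattern &b) {
                     return a.benefit > b.benefit;
                   });
  for (const ToyPattern &pattern : patterns)
    if (pattern.matchAndRewrite(op))
      return true;
  return false;
}
```

When two patterns both match the same op, the one constructed with the higher benefit wins, regardless of the order in which they were added.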

The above rule did not capture the matching operands/attributes. In general,
`match` in a multi-step rewrite only reports success or failure, so anything
computed while matching has to be recomputed in `rewrite`. A single-step
rewrite with the `matchAndRewrite` function has the benefit of being able to
directly use any values created when matching.

## Testing

MLIR uses the [lit](https://llvm.org/docs/CommandGuide/lit.html) (LLVM
Integrated Tester) tool for performing testing. Testing is performed by
creating an input IR file, running a transformation on it, and then verifying
the output IR. C++ unit tests are the exception; IR transformations serve as
the core testing mechanism. This results in fewer binaries that need to be
built (and linked) and forces a focus on the IR representation as an important
piece.

For the legalization transform above we would have a test (probably as part of
the legalization pass test in TensorFlow Lite) such as:

```mlir
// RUN: mlir-opt -tfl-legalize-tf %s | FileCheck %s

// CHECK-LABEL: func @LeakyRelu
func.func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
  // CHECK: "tfl.leaky_relu"(%arg0) {alpha = 1.000000e-01 : f32} : (tensor<1xf32>) -> tensor<1xf32>
  %0 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
  return %0 : tensor<1xf32>
}
```

The RUN command at the top runs the `mlir-opt` binary (a compiler-writer tool
for exercising the different registered passes) on the current file to invoke
the optimization pass that this transform was added to, and then verifies its
output using `FileCheck`. `FileCheck` is a textual output verifier: it uses the
CHECK directives to verify that the expected output is produced.

A test file can contain multiple RUN commands with different corresponding
CHECK prefixes. It can also contain multiple independent tests separated by
`// -----`, with `mlir-opt` invoked with the `-split-input-file` flag. This is
especially useful for error testing.
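The splitting behind `-split-input-file` can be pictured as cutting the file into chunks at `// -----` lines, each of which is then processed as an independent test. The function below is a simplified, hypothetical `splitInputFile` helper written only to illustrate that idea; it is not `mlir-opt`'s implementation, which differs in its details.

```cpp
#include <sstream>
#include <string>
#include <vector>

// Splits a test file into independent chunks at lines consisting exactly of
// the `// -----` marker. The marker lines themselves are dropped.
std::vector<std::string> splitInputFile(const std::string &contents) {
  std::vector<std::string> chunks(1);
  std::istringstream stream(contents);
  std::string line;
  while (std::getline(stream, line)) {
    if (line == "// -----") {
      chunks.emplace_back(); // start a new, independent test case
      continue;
    }
    chunks.back() += line + "\n";
  }
  return chunks;
}
```

Because each chunk is handled separately, an expected error in one test case does not prevent the others from being checked.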

This results in very simple, directed testing without the need to work around
constant propagation or other, unrelated, optimization passes.

## Adding optimization pass

Optimization passes that do not fit, or are difficult to specify in, the above
structure can be specified as general iterations across modules/functions. See
[Writing a Pass](../PassManagement.md) for a general overview and introduction to
optimization passes in MLIR.
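As a rough, MLIR-independent sketch of the graph-walker approach mentioned in the introduction, the code below walks a toy op tree in post-order (nested ops before their parent, similar in spirit to the default order of MLIR's `Operation::walk`) and lets a callback rewrite matching ops in place. `ToyOp` and `walk` here are hypothetical stand-ins, not MLIR classes; a real pass would use the pass infrastructure described in the linked document.

```cpp
#include <functional>
#include <string>
#include <vector>

// Toy IR node: an op with a name and nested ops (standing in for regions).
struct ToyOp {
  std::string name;
  std::vector<ToyOp> body;
};

// Visits every nested op before its parent (post-order) and calls `callback`
// on each one, which may modify the op in place.
void walk(ToyOp &op, const std::function<void(ToyOp &)> &callback) {
  for (ToyOp &nested : op.body)
    walk(nested, callback);
  callback(op);
}
```

Usage: walking a module and renaming every `tf.LeakyRelu` to `tfl.leaky_relu` is a single `walk` call with a callback that checks `op.name`; unlike the pattern-based approach, nothing here handles benefits, legality, or iteration to a fixed point.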