# Quickstart tutorial to adding MLIR graph rewrite

This document presents a quickstart to adding graph rewrites. We start by
defining an operation, then show multiple ways to define the rewrite using
patterns, as well as how to define the rewrite using a graph walker (note:
using patterns and the rewrite engine is preferred; the walker is shown for
demonstration purposes only).

See [MLIR specification](../LangRef.md) for more information about MLIR, the
structure of the IR, operations, etc. See
[Table-driven Operation Definition](../DefiningDialects/Operations.md) and
[Declarative Rewrite Rule](../DeclarativeRewrites.md) for the detailed explanation
of all available mechanisms for defining operations and rewrites in a
table-driven manner.

## Adding operation

An operation in MLIR is specified using a definition in a
[TableGen](https://llvm.org/docs/TableGen/index.html) file. TableGen is a
modeling tool from which both the specification of ops and the C++ code to
interact with these operations are generated. To define an operation one needs
to specify:

*   The operation name. This name is a unique identifier of the operation within
    MLIR. Most operations are within a dialect, so for example one could have
    `tfl.add` to represent the add operation in the TensorFlow Lite dialect.
    Instead of repeating the dialect in the op definition, a base class for the
    op dialect is commonly created that prepends the dialect namespace given an
    op name.
*   The traits of the operation, such as whether it has side effects or whether
    it should be verified that the operands and result types are the same.
    These are backed by C++ traits that perform the verification.
*   The arguments of the operation. These are the input operands (values at
    runtime produced by other ops) and attributes (compile-time known constant
    values that affect the behavior of the op) that are the inputs of/define the
    behavior of the operation. The input operands may be named; the attributes
    must be named.
*   The result(s) of the operation. These may again be named or not.
*   Documentation of the operation. This includes a one-line summary as well as
    a longer human-readable description of the operation.
*   Dialect-specific information. Additional information could be added to the
    operation definition that is only used by dialect-specific drivers. This
    information is ignored by the main op and doc generators, but could be used
    in, say, the translation from a dialect to another representation.

```tablegen
def TFL_LeakyReluOp: TFL_Op<TFL_Dialect, "leaky_relu",
                            [NoMemoryEffect, SameValueType]>,
                     Results<(outs Tensor)> {
  let arguments = (ins
    F32Tensor:$x,
    // Slope of the activation function at x < 0.
    F32Attr:$alpha
  );

  let summary = "Leaky ReLU operator";
  let description = [{
    Element-wise Leaky ReLU operator
      x -> x >= 0 ? x : (alpha * x)
  }];

  // TFLite specific attribute that is used when generating the output
  // flatbuffer.
  let hasOptions = 1;
}
```
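
Once processed by `mlir-tblgen`, this definition yields a C++ op class. A
hedged sketch of creating the op through a generated builder follows; the exact
builder signature is assumed from the declaration above, and the helper is
hypothetical:

```c++
// Hypothetical helper showing use of the generated TFL::LeakyReluOp class.
Value buildLeakyRelu(OpBuilder &builder, Location loc, Value x) {
  // One of the generated builders takes (result type, operand, alpha attr).
  return builder.create<mlir::TFL::LeakyReluOp>(
      loc, x.getType(), x, /*alpha=*/builder.getF32FloatAttr(0.2f));
}
```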

Note that in the TableGen definition above, the result type and the inputs are
specified in different ways: one by way of a trait and the other by way of
`let`. It is possible to specify both in either way.
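
For example, the result could equally be declared with `let` inside the body,
as in this sketch that elides the unchanged fields:

```tablegen
def TFL_LeakyReluOp: TFL_Op<TFL_Dialect, "leaky_relu",
                            [NoMemoryEffect, SameValueType]> {
  let results = (outs Tensor:$y);
  // arguments, summary, description, and hasOptions as in the definition above.
}
```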

<!-- TODO: Define a style convention. -->

Operations can also have a custom parser, printer, builder, verifier, constant
folder, or canonicalizer. These require specifying additional C++ methods to
invoke for additional functionality. For example, if an operation is marked as
having a folder, the constant folder also needs to be added, e.g.,:

```c++
OpFoldResult SpecificOp::fold(ArrayRef<Attribute> constOperands) {
  if (unable_to_fold)
    return {};
  ...
  return val;
}
```
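
As a more concrete, hedged sketch, consider a hypothetical addition-like op
whose fold returns its left operand when the right operand is a known zero
constant (`MyAddOp` is illustrative, not an existing op):

```c++
OpFoldResult MyAddOp::fold(ArrayRef<Attribute> constOperands) {
  // Each entry of `constOperands` is non-null only if the corresponding
  // operand is a known compile-time constant.
  if (auto rhs = constOperands[1].dyn_cast_or_null<IntegerAttr>())
    if (rhs.getValue().isZero())
      return getOperand(0); // x + 0 -> x
  return {}; // Signal that nothing was folded.
}
```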

## Adding patterns

There are multiple forms of graph rewrite that can be performed in MLIR. One of
the most common is DAG tile to DAG tile rewrite. Patterns provide a concise way
to express this transformation as a pair of a source pattern to match and a
resultant pattern. There are both C++ classes to represent this transformation,
as well as patterns in TableGen from which these can be generated.

### TableGen patterns

Let us continue with LeakyRelu. To map from TensorFlow's `LeakyRelu` to
TensorFlow Lite's `LeakyRelu`:

```tablegen
def : Pat<(TF_LeakyReluOp $arg, F32Attr:$a), (TFL_LeakyReluOp $arg, $a)>;
```

The pattern is specified by instantiating a `Pat` with a source and result DAG.
The arguments in the source pattern are captured and can be used in the result
pattern. This is a simple pattern as we have a 1:1 mapping and the attribute
does not need to be transformed (e.g., both have a floating point attribute for
alpha). The names of the attributes specified in the pattern are for
matching/referencing and need not match the original attribute names in the op
definition, but the order of arguments of the DAGs does need to match.

To specify a pattern, both the source and resultant ops need to be defined using
TableGen.

If this were a more advanced pattern that the current framework could not
express as a destination, then one could use a general native code fallback
method. This consists of defining a pattern as well as adding a C++ function to
perform the replacement:

```tablegen
def createTFLLeakyRelu : NativeCodeCall<
    "createTFLLeakyRelu($_builder, $0.getDefiningOp(), $1, $2)">;

def : Pat<(TF_LeakyReluOp:$old_value, $arg, F32Attr:$a),
          (createTFLLeakyRelu $old_value, $arg, $a)>;
```

```c++
static Value createTFLLeakyRelu(PatternRewriter &rewriter, Operation *op,
                                Value operand, Attribute attr) {
  return rewriter.create<mlir::TFL::LeakyReluOp>(
      op->getLoc(), operand.getType(), /*arg=*/operand,
      /*alpha=*/attr.cast<FloatAttr>());
}
```

This allows for arbitrarily complex builders. On the input pattern side, one
can express multi-op patterns with constraints on input operands and
attributes. But input patterns cannot yet express constraints across multiple
operands/attributes.
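
For example, a constraint on a single attribute can be expressed directly on
the input side. The following sketch only fires the rewrite for a non-negative
`alpha` (`NonNegativeF32Attr` is a hypothetical constraint, not an existing
definition):

```tablegen
def NonNegativeF32Attr : AttrConstraint<
    CPred<"$_self.cast<FloatAttr>().getValueAsDouble() >= 0.0">,
    "non-negative 32-bit float attribute">;

def : Pat<(TF_LeakyReluOp $arg, NonNegativeF32Attr:$a),
          (TFL_LeakyReluOp $arg, $a)>;
```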

### Register the pattern

The file containing the patterns needs to be processed using `mlir-tblgen`
`-gen-rewriters` during compilation time. It can be invoked with the following
configuration in CMake:

```cmake
set(LLVM_TARGET_DEFINITIONS <name-of-the-td-file>)
mlir_tablegen(<name-of-the-generated-inc-file> -gen-rewriters)
add_public_tablegen_target(<name-of-the-cmake-target>)
```

Then you can `#include` the generated file in any C++ implementation file you
like. (You will also need to make sure the library depends on the CMake target
defined in the above.) The generated file will have a
`populateWithGenerated(RewritePatternSet &patterns)` function that you can use
to collect all the generated patterns inside `patterns` and then use `patterns`
in any pass you would like.
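
For example, a pass could apply the generated patterns greedily, as in the
following minimal sketch (the pass and the name of the generated include file
are hypothetical):

```c++
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Pulls in the patterns generated by `mlir-tblgen -gen-rewriters`.
#include "MyGeneratedPatterns.inc"

void MyLegalizePass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  populateWithGenerated(patterns);
  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}
```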

### Simple C++ `matchAndRewrite` style specifications

Many simple rewrites can be expressed with a `matchAndRewrite` style of
pattern, e.g. when converting a multiply by a power of two into a shift. For
these cases, you can define the pattern as a simple function:

```c++
static LogicalResult
convertTFLeakyRelu(TFLeakyReluOp op, PatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
      op, op->getResult(0).getType(), op->getOperand(0),
      /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
  return success();
}

void populateRewrites(RewritePatternSet &patternSet) {
  // Add it to a pattern set.
  patternSet.add(convertTFLeakyRelu);
}
```

ODS provides a simple way to define a function-style canonicalization for your
operation. In the TableGen definition of the op, specify
`let hasCanonicalizeMethod = 1;` and then implement the `canonicalize` method in
your .cpp file:

```c++
// Example from the CIRCT project which has a variadic integer multiply.
LogicalResult circt::MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
  auto inputs = op.inputs();
  APInt value;

  // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
  if (inputs.size() == 2 && matchPattern(inputs.back(), m_RConstant(value)) &&
      value.isPowerOf2()) {
    auto shift = rewriter.create<rtl::ConstantOp>(op.getLoc(), op.getType(),
                                                  value.exactLogBase2());
    auto shlOp = rewriter.create<comb::ShlOp>(op.getLoc(), inputs[0], shift);
    rewriter.replaceOpWithNewOp<MulOp>(op, op.getType(),
                                       ArrayRef<Value>(shlOp));
    return success();
  }

  return failure();
}
```

However, you may want the full generality of canonicalization patterns; for
that, you can specify an arbitrary list of `RewritePattern`s.

### Fully general C++ `RewritePattern` specifications

In case ODS patterns and `matchAndRewrite`-style functions are not sufficient,
you can also specify rewrites as a general set of `RewritePattern`s:

```c++
/// Multi-step rewrite using "match" and "rewrite". This allows for separating
/// the concerns of matching and rewriting.
struct ConvertTFLeakyRelu : public RewritePattern {
  ConvertTFLeakyRelu(MLIRContext *context)
      : RewritePattern("tf.LeakyRelu", 1, context) {}

  LogicalResult match(Operation *op) const override {
    return success();
  }

  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
        op, op->getResult(0).getType(), op->getOperand(0),
        /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
  }
};

/// Single-step rewrite with "matchAndRewrite". This allows for performing the
/// rewrite immediately upon a successful match. (This is an alternative to the
/// multi-step pattern above; define only one of the two in real code.)
struct ConvertTFLeakyRelu : public RewritePattern {
  ConvertTFLeakyRelu(MLIRContext *context)
      : RewritePattern("tf.LeakyRelu", 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
        op, op->getResult(0).getType(), op->getOperand(0),
        /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
    return success();
  }
};
```

In the C++ rewrite, the static benefit of the rewrite pattern is specified at
construction, while in the pattern generator a simple heuristic is currently
employed based around the number of ops matched and replaced.
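
Registering such a pattern could look like the following sketch (the helper
function is hypothetical); the benefit is the `1` passed to the
`RewritePattern` constructor above:

```c++
void populateTFLeakyReluPatterns(MLIRContext *context,
                                 RewritePatternSet &patterns) {
  // The benefit was fixed when the ConvertTFLeakyRelu constructor ran;
  // add<> simply forwards the context to that constructor.
  patterns.add<ConvertTFLeakyRelu>(context);
}
```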

The above rule did not capture the matching operands/attributes, but in general
the `match` function in a multi-step rewrite may populate and return a
`PatternState` (or a class derived from one) to pass information extracted
during matching to the rewrite. A single-step rewrite with the `matchAndRewrite`
function has the benefit of being able to directly use any values created when
matching, removing the need for `PatternState`.

## Testing

MLIR uses the [lit](https://llvm.org/docs/CommandGuide/lit.html) (LLVM
Integrated Testing) tool for performing testing. Testing is performed by way of
creating an input IR file, running a transformation, and then verifying the
output IR. C++ unit tests are the exception, with IR transformation serving as
the core testing mechanism. This results in fewer binaries that need to be
built (and linked) and forces a focus on the representation as an important
piece.

For the legalization transform above we would have a test (probably as part of
the legalization pass test in TensorFlow Lite) such as:

```mlir
// RUN: mlir-opt -tfl-legalize-tf %s | FileCheck %s

func.func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
  %2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
  return %2 : tensor<1xf32>

// CHECK-LABEL: LeakyRelu
// CHECK:  %0 = "tfl.leaky_relu"(%arg0) {alpha = 1.000000e-01 : f32} : (tensor<1xf32>) -> tensor<1xf32>
}
```

The RUN command at the top results in running the `mlir-opt` binary (which is a
compiler-writer tool to exercise different registered passes) on the current
file to invoke the optimization pass this transform was added to, and then
verifying its output using `FileCheck`. `FileCheck` is a textual output
verifier. In particular, it uses the CHECK expressions to verify that the given
output is produced.

There can be multiple RUN commands with different corresponding CHECK prefixes.
In addition, multiple independent tests can be separated by `// -----`, with
`mlir-opt` invoked with the `-split-input-file` flag. This is especially useful
for error testing.
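
A sketch of what such a split test could look like (the two cases are
hypothetical variations of the test above):

```mlir
// RUN: mlir-opt -tfl-legalize-tf -split-input-file %s | FileCheck %s

// CHECK-LABEL: LeakyRelu0
func.func @LeakyRelu0(%arg0: tensor<1xf32>) -> tensor<1xf32> {
  %0 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
  // CHECK: "tfl.leaky_relu"
  return %0 : tensor<1xf32>
}

// -----

// CHECK-LABEL: LeakyRelu1
func.func @LeakyRelu1(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %0 = "tf.LeakyRelu"(%arg0) {alpha = 0.2 : f32} : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: "tfl.leaky_relu"
  return %0 : tensor<4xf32>
}
```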

This results in very simple, directed testing without the need to work around
constant propagation or other, unrelated, optimization passes.

## Adding optimization pass

Optimization passes that do not fit/are difficult to specify in the above
structure can be specified as general iterations across modules/functions. See
[Writing a Pass](../PassManagement.md) for a general overview and introduction to
optimization passes in MLIR.
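
As a minimal sketch of what such a pass can look like (the pass name and its
body are hypothetical), the pass walks the IR directly instead of going through
the pattern driver:

```c++
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {
struct MyOptimizationPass
    : public PassWrapper<MyOptimizationPass, OperationPass<func::FuncOp>> {
  void runOnOperation() override {
    // Visit every operation nested under the function.
    getOperation().walk([](Operation *op) {
      // ... analysis or in-place transformation of `op` ...
    });
  }
};
} // namespace
```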
307