xref: /llvm-project/mlir/docs/Tutorials/Toy/Ch-5.md (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
# Chapter 5: Partial Lowering to Lower-Level Dialects for Optimization

[TOC]

At this point, we are eager to generate actual code and see our Toy language
take life. We will use LLVM to generate code, but just showing the LLVM builder
interface here wouldn't be very exciting. Instead, we will show how to perform
progressive lowering through a mix of dialects coexisting in the same function.

To make it more interesting, in this chapter we will consider that we want to
reuse existing optimizations implemented in a dialect optimizing affine
transformations: `Affine`. This dialect is tailored to the computation-heavy
part of the program and is limited: it doesn't support representing our
`toy.print` builtin, for instance, nor should it! Instead, we can target
`Affine` for the computation-heavy part of Toy, and in the
[next chapter](Ch-6.md) directly target the `LLVM IR` dialect for lowering
`print`. As part of this lowering, we will be lowering from the
[TensorType](../../Dialects/Builtin.md/#rankedtensortype) that `Toy` operates on
to the [MemRefType](../../Dialects/Builtin.md/#memreftype) that is indexed via
an affine loop nest. Tensors represent an abstract value-typed sequence of data,
meaning that they don't live in any memory. MemRefs, on the other hand,
represent lower-level buffer access, as they are concrete references to a region
of memory.

# Dialect Conversions

MLIR has many different dialects, so it is important to have a unified framework
for [converting](../../../getting_started/Glossary.md/#conversion) between them.
This is where the `DialectConversion` framework comes into play. This framework
allows for transforming a set of *illegal* operations to a set of *legal* ones.
To use this framework, we need to provide two things (and an optional third):

*   A [Conversion Target](../../DialectConversion.md/#conversion-target)

    -   This is the formal specification of what operations or dialects are
        legal for the conversion. Operations that aren't legal will require
        rewrite patterns to perform
        [legalization](../../../getting_started/Glossary.md/#legalization).

*   A set of
    [Rewrite Patterns](../../DialectConversion.md/#rewrite-pattern-specification)

    -   This is the set of [patterns](../QuickstartRewrites.md) used to convert
        *illegal* operations into a set of zero or more *legal* ones.

*   Optionally, a [Type Converter](../../DialectConversion.md/#type-conversion).

    -   If provided, this is used to convert the types of block arguments. We
        won't be needing this for our conversion.

## Conversion Target

For our purposes, we want to convert the compute-intensive `Toy` operations into
a combination of operations from the `Affine`, `Arith`, `Func`, and `MemRef`
dialects for further optimization. To start off the lowering, we first define
our conversion target:

```c++
void ToyToAffineLoweringPass::runOnOperation() {
  // The first thing to define is the conversion target. This will define the
  // final target for this lowering.
  mlir::ConversionTarget target(getContext());

  // We define the specific operations, or dialects, that are legal targets for
  // this lowering. In our case, we are lowering to a combination of the
  // `Affine`, `Arith`, `Func`, and `MemRef` dialects.
  target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
                         func::FuncDialect, memref::MemRefDialect>();

  // We also define the Toy dialect as illegal so that the conversion will fail
  // if any of these operations are *not* converted. Given that we actually want
  // a partial lowering, we explicitly mark the Toy operations that we don't
  // want to lower, `toy.print`, as *legal*. `toy.print` will still need its
  // operands to be updated though (as we convert from TensorType to
  // MemRefType), so we only treat it as `legal` if its operands are legal.
  target.addIllegalDialect<ToyDialect>();
  target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
    return llvm::none_of(op->getOperandTypes(),
                         [](Type type) { return type.isa<TensorType>(); });
  });
  ...
}
```

Above, we first set the toy dialect to illegal, and then the print operation as
legal. We could have done this the other way around. Individual operations
always take precedence over the (more generic) dialect definitions, so the order
doesn't matter. See `ConversionTarget::getOpInfo` for the details.

## Conversion Patterns

After the conversion target has been defined, we can define how to convert the
*illegal* operations into *legal* ones. Similarly to the canonicalization
framework introduced in [chapter 3](Ch-3.md), the
[`DialectConversion` framework](../../DialectConversion.md) also uses
[RewritePatterns](../QuickstartRewrites.md) to perform the conversion logic.
These patterns may be the `RewritePatterns` seen before, or a new type of
pattern specific to the conversion framework: `ConversionPattern`.
`ConversionPatterns` are different from traditional `RewritePatterns` in that
they accept an additional `operands` parameter containing operands that have
been remapped/replaced. This is used when dealing with type conversions, as the
pattern will want to operate on values of the new type but match against the
old. For our lowering, this invariant will be useful as it translates from the
[TensorType](../../Dialects/Builtin.md/#rankedtensortype) currently being
operated on to the [MemRefType](../../Dialects/Builtin.md/#memreftype). Let's
look at a snippet of lowering the `toy.transpose` operation:

```c++
/// Lower the `toy.transpose` operation to an affine loop nest.
struct TransposeOpLowering : public mlir::ConversionPattern {
  TransposeOpLowering(mlir::MLIRContext *ctx)
      : mlir::ConversionPattern(TransposeOp::getOperationName(), 1, ctx) {}

  /// Match and rewrite the given `toy.transpose` operation, with the given
  /// operands that have been remapped from `tensor<...>` to `memref<...>`.
  llvm::LogicalResult
  matchAndRewrite(mlir::Operation *op, ArrayRef<mlir::Value> operands,
                  mlir::ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();

    // Call to a helper function that will lower the current operation to a set
    // of affine loops. We provide a functor that operates on the remapped
    // operands, as well as the loop induction variables for the innermost
    // loop body.
    lowerOpToLoops(
        op, operands, rewriter,
        [loc](mlir::PatternRewriter &rewriter,
              ArrayRef<mlir::Value> memRefOperands,
              ArrayRef<mlir::Value> loopIvs) {
          // Generate an adaptor for the remapped operands of the TransposeOp.
          // This allows for using the nice named accessors that are
          // automatically generated by the ODS framework.
          TransposeOpAdaptor transposeAdaptor(memRefOperands);
          mlir::Value input = transposeAdaptor.input();

          // Transpose the elements by generating a load from the reverse
          // indices.
          SmallVector<mlir::Value, 2> reverseIvs(llvm::reverse(loopIvs));
          return rewriter.create<mlir::AffineLoadOp>(loc, input, reverseIvs);
        });
    return success();
  }
};
```

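To make the reversed-index trick concrete, here is a plain C++ sketch of the computation the generated loop nest performs. This is illustrative only, not the MLIR API; the fixed 2x3 shape and the `transpose2x3` name are hypothetical. The loops walk the *output* buffer and each element is read from the input at the reversed induction variables, which is exactly what the `affine.load` on `reverseIvs` expresses.

```cpp
#include <array>
#include <cstddef>

using Mat2x3 = std::array<std::array<double, 3>, 2>;
using Mat3x2 = std::array<std::array<double, 2>, 3>;

// Model of the lowered `toy.transpose`: iterate over the output memref and
// load from the reversed indices of the input, (i, j) -> (j, i).
Mat3x2 transpose2x3(const Mat2x3 &input) {
  Mat3x2 output{};
  for (std::size_t i = 0; i < 3; ++i)
    for (std::size_t j = 0; j < 2; ++j)
      output[i][j] = input[j][i]; // affine.load %input[%j, %i]
  return output;
}
```

Reversing the induction variables rather than the loop order keeps the stores to the output buffer in row-major order, which is the layout the subsequent loop nests expect.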
Now we can prepare the list of patterns to use during the lowering process:

```c++
void ToyToAffineLoweringPass::runOnOperation() {
  ...

  // Now that the conversion target has been defined, we just need to provide
  // the set of patterns that will lower the Toy operations.
  mlir::RewritePatternSet patterns(&getContext());
  patterns.add<..., TransposeOpLowering>(&getContext());

  ...
```

## Partial Lowering

Once the patterns have been defined, we can perform the actual lowering. The
`DialectConversion` framework provides several different modes of lowering, but,
for our purposes, we will perform a partial lowering, as we will not convert
`toy.print` at this time.

```c++
void ToyToAffineLoweringPass::runOnOperation() {
  ...

  // With the target and rewrite patterns defined, we can now attempt the
  // conversion. The conversion will signal failure if any of our *illegal*
  // operations were not converted successfully.
  if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
                                                std::move(patterns))))
    signalPassFailure();
}
```

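The difference between partial and full conversion can be summed up in a toy model. The sketch below is illustrative only, not the MLIR API; `MiniConverter` and its members are hypothetical names. Partial conversion rewrites every op that has a pattern and then succeeds as long as no *illegal* op survives; ops explicitly marked legal (like `toy.print` here) are allowed to remain.

```cpp
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical model of partial-conversion semantics: ops are just names,
// patterns map an illegal op name to its lowered replacement.
struct MiniConverter {
  std::set<std::string> legalOps;              // the conversion target
  std::map<std::string, std::string> patterns; // illegal op -> lowered op

  bool applyPartialConversion(std::vector<std::string> &ops) const {
    // Rewrite every op that has a matching pattern.
    for (auto &op : ops) {
      auto it = patterns.find(op);
      if (it != patterns.end())
        op = it->second;
    }
    // Succeed only if every remaining op is legal; an unconverted illegal
    // op corresponds to signalPassFailure() above.
    for (const auto &op : ops)
      if (!legalOps.count(op))
        return false;
    return true;
  }
};
```

With `legalOps = {"affine.for", "toy.print"}` and patterns lowering `toy.transpose` and `toy.mul`, a module containing all three Toy ops converts successfully; remove `toy.print` from the legal set and the same conversion fails, which is why we marked it (dynamically) legal in the conversion target.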
### Design Considerations With Partial Lowering

Before diving into the result of our lowering, this is a good time to discuss
potential design considerations when it comes to partial lowering. In our
lowering, we transform from a value-type, TensorType, to an allocated
(buffer-like) type, MemRefType. However, given that we do not lower the
`toy.print` operation, we need to temporarily bridge these two worlds. There are
many ways to go about this, each with their own tradeoffs:

*   Generate `load` operations from the buffer

    One option is to generate `load` operations from the buffer type to
    materialize an instance of the value type. This allows for the definition of
    the `toy.print` operation to remain unchanged. The downside to this approach
    is that the optimizations on the `affine` dialect are limited, because the
    `load` will actually involve a full copy that is only visible *after* our
    optimizations have been performed.

*   Generate a new version of `toy.print` that operates on the lowered type

    Another option would be to have another, lowered, variant of `toy.print`
    that operates on the lowered type. The benefit of this option is that there
    is no hidden, unnecessary copy to the optimizer. The downside is that
    another operation definition is needed that may duplicate many aspects of
    the first. Defining a base class in
    [ODS](../../DefiningDialects/Operations.md) may simplify this, but you still
    need to treat these operations separately.

*   Update `toy.print` to allow for operating on the lowered type

    A third option is to update the current definition of `toy.print` to allow
    for operating on the lowered type. The benefit of this approach is that it
    is simple, does not introduce an additional hidden copy, and does not
    require another operation definition. The downside to this option is that it
    requires mixing abstraction levels in the `Toy` dialect.

For the sake of simplicity, we will use the third option for this lowering. This
involves updating the type constraints on the PrintOp in the operation
definition file:

```tablegen
def PrintOp : Toy_Op<"print"> {
  ...

  // The print operation takes an input tensor to print.
  // We also allow a F64MemRef to enable interop during partial lowering.
  let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
}
```

## Complete Toy Example

Let's take a concrete example:

```mlir
toy.func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
  %3 = toy.mul %2, %2 : tensor<3x2xf64>
  toy.print %3 : tensor<3x2xf64>
  toy.return
}
```

With affine lowering added to our pipeline, we can now generate:

```mlir
func.func @main() {
  %cst = arith.constant 1.000000e+00 : f64
  %cst_0 = arith.constant 2.000000e+00 : f64
  %cst_1 = arith.constant 3.000000e+00 : f64
  %cst_2 = arith.constant 4.000000e+00 : f64
  %cst_3 = arith.constant 5.000000e+00 : f64
  %cst_4 = arith.constant 6.000000e+00 : f64

  // Allocating buffers for the inputs and outputs.
  %0 = memref.alloc() : memref<3x2xf64>
  %1 = memref.alloc() : memref<3x2xf64>
  %2 = memref.alloc() : memref<2x3xf64>

  // Initialize the input buffer with the constant values.
  affine.store %cst, %2[0, 0] : memref<2x3xf64>
  affine.store %cst_0, %2[0, 1] : memref<2x3xf64>
  affine.store %cst_1, %2[0, 2] : memref<2x3xf64>
  affine.store %cst_2, %2[1, 0] : memref<2x3xf64>
  affine.store %cst_3, %2[1, 1] : memref<2x3xf64>
  affine.store %cst_4, %2[1, 2] : memref<2x3xf64>

  // Load the transpose value from the input buffer and store it into the
  // next input buffer.
  affine.for %arg0 = 0 to 3 {
    affine.for %arg1 = 0 to 2 {
      %3 = affine.load %2[%arg1, %arg0] : memref<2x3xf64>
      affine.store %3, %1[%arg0, %arg1] : memref<3x2xf64>
    }
  }

  // Multiply and store into the output buffer.
  affine.for %arg0 = 0 to 3 {
    affine.for %arg1 = 0 to 2 {
      %3 = affine.load %1[%arg0, %arg1] : memref<3x2xf64>
      %4 = affine.load %1[%arg0, %arg1] : memref<3x2xf64>
      %5 = arith.mulf %3, %4 : f64
      affine.store %5, %0[%arg0, %arg1] : memref<3x2xf64>
    }
  }

  // Print the value held by the buffer.
  toy.print %0 : memref<3x2xf64>
  memref.dealloc %2 : memref<2x3xf64>
  memref.dealloc %1 : memref<3x2xf64>
  memref.dealloc %0 : memref<3x2xf64>
  return
}
```

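As a sanity check on what this IR computes, here is a plain C++ sketch of the same two loop nests: one transposes into an intermediate buffer (`%1`), the other squares each element into the output buffer (`%0`). Illustrative only; the `transposeThenMul` name and fixed shapes are hypothetical, specific to this example.

```cpp
#include <array>
#include <cstddef>

using Mat2x3 = std::array<std::array<double, 3>, 2>;
using Mat3x2 = std::array<std::array<double, 2>, 3>;

Mat3x2 transposeThenMul(const Mat2x3 &in) {
  Mat3x2 tmp{}, out{};
  // First loop nest: transpose into the intermediate buffer (%1).
  for (std::size_t i = 0; i < 3; ++i)
    for (std::size_t j = 0; j < 2; ++j)
      tmp[i][j] = in[j][i];
  // Second loop nest: arith.mulf of each element with itself into %0.
  for (std::size_t i = 0; i < 3; ++i)
    for (std::size_t j = 0; j < 2; ++j)
      out[i][j] = tmp[i][j] * tmp[i][j];
  return out;
}
```

For the constant above, `[[1, 2, 3], [4, 5, 6]]`, this yields `[[1, 16], [4, 25], [9, 36]]`, the values that `toy.print` ultimately receives.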
## Taking Advantage of Affine Optimization

Our naive lowering is correct, but it leaves a lot to be desired with regard to
efficiency. For example, the lowering of `toy.mul` has generated some redundant
loads. Let's look at how adding a few existing optimizations to the pipeline can
help clean this up. Adding the `LoopFusion` and `AffineScalarReplacement` passes
to the pipeline gives the following result:

```mlir
func.func @main() {
  %cst = arith.constant 1.000000e+00 : f64
  %cst_0 = arith.constant 2.000000e+00 : f64
  %cst_1 = arith.constant 3.000000e+00 : f64
  %cst_2 = arith.constant 4.000000e+00 : f64
  %cst_3 = arith.constant 5.000000e+00 : f64
  %cst_4 = arith.constant 6.000000e+00 : f64

  // Allocating buffers for the inputs and outputs.
  %0 = memref.alloc() : memref<3x2xf64>
  %1 = memref.alloc() : memref<2x3xf64>

  // Initialize the input buffer with the constant values.
  affine.store %cst, %1[0, 0] : memref<2x3xf64>
  affine.store %cst_0, %1[0, 1] : memref<2x3xf64>
  affine.store %cst_1, %1[0, 2] : memref<2x3xf64>
  affine.store %cst_2, %1[1, 0] : memref<2x3xf64>
  affine.store %cst_3, %1[1, 1] : memref<2x3xf64>
  affine.store %cst_4, %1[1, 2] : memref<2x3xf64>

  affine.for %arg0 = 0 to 3 {
    affine.for %arg1 = 0 to 2 {
      // Load the transpose value from the input buffer.
      %2 = affine.load %1[%arg1, %arg0] : memref<2x3xf64>

      // Multiply and store into the output buffer.
      %3 = arith.mulf %2, %2 : f64
      affine.store %3, %0[%arg0, %arg1] : memref<3x2xf64>
    }
  }

  // Print the value held by the buffer.
  toy.print %0 : memref<3x2xf64>
  memref.dealloc %1 : memref<2x3xf64>
  memref.dealloc %0 : memref<3x2xf64>
  return
}
```

Here, we can see that a redundant allocation was removed, the two loop nests
were fused, and some unnecessary `load`s were removed. You can build `toyc-ch5`
and try it yourself: `toyc-ch5 test/Examples/Toy/Ch5/affine-lowering.mlir
-emit=mlir-affine`. We can also check our optimizations by adding `-opt`.

In this chapter we explored some aspects of partial lowering, with the intent to
optimize. In the [next chapter](Ch-6.md) we will continue the discussion about
dialect conversion by targeting LLVM for code generation.