# Chapter 6: Lowering to LLVM and Code Generation

[TOC]

In the [previous chapter](Ch-5.md), we introduced the
[dialect conversion](../../DialectConversion.md) framework and partially lowered
many of the `Toy` operations to affine loop nests for optimization. In this
chapter, we will finally lower to LLVM for code generation.

## Lowering to LLVM

For this lowering, we will again use the dialect conversion framework to perform
the heavy lifting. However, this time, we will be performing a full conversion
to the [LLVM dialect](../../Dialects/LLVM.md). Thankfully, we have already
lowered all but one of the `toy` operations, with the last being `toy.print`.
Before going over the conversion to LLVM, let's lower the `toy.print` operation.
We will lower this operation to a non-affine loop nest that invokes `printf` for
each element. Note that, because the dialect conversion framework supports
[transitive lowering](../../../getting_started/Glossary.md/#transitive-lowering),
we don't need to directly emit operations in the LLVM dialect. By transitive
lowering, we mean that the conversion framework may apply multiple patterns to
fully legalize an operation. In this example, we are generating a structured
loop nest instead of the branch-form in the LLVM dialect. As long as we then
have a lowering from the loop operations to LLVM, the lowering will still
succeed.
During lowering we can get, or build, the declaration for printf as follows:

```c++
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
                                           ModuleOp module) {
  auto *context = module.getContext();
  if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
    return SymbolRefAttr::get(context, "printf");

  // Create a function declaration for printf, the signature is:
  //   * `i32 (i8*, ...)`
  auto llvmI32Ty = IntegerType::get(context, 32);
  auto llvmI8PtrTy =
      LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
  auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
                                                /*isVarArg=*/true);

  // Insert the printf function into the body of the parent module.
  PatternRewriter::InsertionGuard insertGuard(rewriter);
  rewriter.setInsertionPointToStart(module.getBody());
  rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
  return SymbolRefAttr::get(context, "printf");
}
```

Now that the lowering for the printf operation has been defined, we can specify
the components necessary for the lowering. These are largely the same as the
components defined in the [previous chapter](Ch-5.md).

### Conversion Target

For this conversion, aside from the top-level module, we will be lowering
everything to the LLVM dialect.

```c++
  mlir::ConversionTarget target(getContext());
  target.addLegalDialect<mlir::LLVM::LLVMDialect>();
  target.addLegalOp<mlir::ModuleOp>();
```
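
As a convenience, MLIR also provides `mlir::LLVMConversionTarget`, which marks
the LLVM dialect as legal when it is constructed. A minimal equivalent of the
snippet above using that helper:

```c++
  // LLVMConversionTarget already registers the LLVM dialect as legal, so only
  // the builtin module operation needs to be marked legal explicitly.
  mlir::LLVMConversionTarget target(getContext());
  target.addLegalOp<mlir::ModuleOp>();
```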

### Type Converter

This lowering will also transform the MemRef types which are currently being
operated on into a representation in LLVM. To perform this conversion, we use a
TypeConverter as part of the lowering. This converter specifies how one type
maps to another. This is necessary now that we are performing more complicated
lowerings involving block arguments. Given that we don't have any
Toy-dialect-specific types that need to be lowered, the default converter is
enough for our use case.

```c++
  LLVMTypeConverter typeConverter(&getContext());
```
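
If the Toy dialect did define its own types, we could still start from the
default converter and register extra rules on top of it. The sketch below is
purely illustrative: `ToyCustomType` is a hypothetical type used only for this
example, and the chosen target type is arbitrary.

```c++
  // Hypothetical: map a custom Toy type to an LLVM-compatible integer type.
  typeConverter.addConversion([](ToyCustomType type) -> mlir::Type {
    return mlir::IntegerType::get(type.getContext(), 64);
  });
```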

### Conversion Patterns

Now that the conversion target has been defined, we need to provide the patterns
used for lowering. At this point in the compilation process, we have a
combination of `toy`, `affine`, `arith`, `memref`, and `func` operations.
Luckily, the `affine`, `arith`, `memref`, and `func` dialects already provide
the sets of patterns needed to transform them into the LLVM dialect. These
patterns allow for lowering the IR in multiple stages by relying on
[transitive lowering](../../../getting_started/Glossary.md/#transitive-lowering).

```c++
  mlir::RewritePatternSet patterns(&getContext());
  mlir::populateAffineToStdConversionPatterns(patterns);
  mlir::populateSCFToControlFlowConversionPatterns(patterns);
  mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
  mlir::populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
  mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
  mlir::populateFuncToLLVMConversionPatterns(typeConverter, patterns);

  // The only remaining operation to lower from the `toy` dialect is the
  // PrintOp.
  patterns.add<PrintOpLowering>(&getContext());
```
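
Because these patterns introduce operations from dialects that may not have been
loaded yet (for example, the `toy.print` lowering creates `scf` loops before
they are converted further), the pass hosting them declares those dependencies
up front. A sketch of that hook, roughly as it appears in the Ch6 pass
definition:

```c++
  /// Declare the dialects this lowering may generate so the context loads them
  /// before pattern application begins.
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    registry.insert<mlir::LLVM::LLVMDialect, mlir::scf::SCFDialect>();
  }
```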

### Full Lowering

We want to completely lower to LLVM, so we use a full conversion. This ensures
that only legal operations will remain after the conversion.

```c++
  mlir::ModuleOp module = getOperation();
  if (mlir::failed(mlir::applyFullConversion(module, target, std::move(patterns))))
    signalPassFailure();
```
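
Within the `toyc.cpp` driver, this lowering runs as the final step of the pass
pipeline assembled in the previous chapters. Below is a sketch of how it is
registered, assuming the surrounding `context` and `module` from the driver; the
`createLowerToLLVMPass()` factory and error handling follow the Ch6 sources.

```c++
  mlir::PassManager pm(&context);
  // ... canonicalization, inlining, shape inference, and the affine lowering
  // from Chapter 5 are added first ...
  pm.addPass(mlir::toy::createLowerToLLVMPass());
  if (mlir::failed(pm.run(module)))
    return 4;
```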

Looking back at our current working example:

```mlir
toy.func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
  %3 = toy.mul %2, %2 : tensor<3x2xf64>
  toy.print %3 : tensor<3x2xf64>
  toy.return
}
```

We can now lower down to the LLVM dialect, which produces the following code:

```mlir
llvm.func @free(!llvm<"i8*">)
llvm.func @printf(!llvm<"i8*">, ...) -> i32
llvm.func @malloc(i64) -> !llvm<"i8*">
llvm.func @main() {
  %0 = llvm.mlir.constant(1.000000e+00 : f64) : f64
  %1 = llvm.mlir.constant(2.000000e+00 : f64) : f64

  ...

^bb16:
  %221 = llvm.extractvalue %25[0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %222 = llvm.mlir.constant(0 : index) : i64
  %223 = llvm.mlir.constant(2 : index) : i64
  %224 = llvm.mul %214, %223 : i64
  %225 = llvm.add %222, %224 : i64
  %226 = llvm.mlir.constant(1 : index) : i64
  %227 = llvm.mul %219, %226 : i64
  %228 = llvm.add %225, %227 : i64
  %229 = llvm.getelementptr %221[%228] : (!llvm<"double*">, i64) -> !llvm<"double*">
  %230 = llvm.load %229 : !llvm<"double*">
  %231 = llvm.call @printf(%207, %230) : (!llvm<"i8*">, f64) -> i32
  %232 = llvm.add %219, %218 : i64
  llvm.br ^bb15(%232 : i64)

  ...

^bb18:
  %235 = llvm.extractvalue %65[0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %236 = llvm.bitcast %235 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%236) : (!llvm<"i8*">) -> ()
  %237 = llvm.extractvalue %45[0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %238 = llvm.bitcast %237 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%238) : (!llvm<"i8*">) -> ()
  %239 = llvm.extractvalue %25[0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %240 = llvm.bitcast %239 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%240) : (!llvm<"i8*">) -> ()
  llvm.return
}
```

See [LLVM IR Target](../../TargetLLVMIR.md) for
more in-depth details on lowering to the LLVM dialect.

## CodeGen: Getting Out of MLIR

At this point we are right at the cusp of code generation. We can generate code
in the LLVM dialect, so now we just need to export to LLVM IR and set up a JIT
to run it.

### Emitting LLVM IR

Now that our module consists only of operations in the LLVM dialect, we can
export to LLVM IR. To do this programmatically, we can invoke the following
utility, providing a fresh LLVM IR context to own the result:

```c++
  llvm::LLVMContext llvmContext;
  std::unique_ptr<llvm::Module> llvmModule =
      mlir::translateModuleToLLVMIR(module, llvmContext);
  if (!llvmModule)
    /* ... an error was encountered ... */
```
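
Note that the translation only succeeds if the LLVM IR translation interfaces
have been registered with the MLIR context beforehand; otherwise it returns a
null module. The Ch6 driver performs this registration just before calling the
utility:

```c++
  // Make the translations for the builtin and LLVM dialects available on the
  // context that owns `module`.
  mlir::registerBuiltinDialectTranslation(*module->getContext());
  mlir::registerLLVMDialectTranslation(*module->getContext());
```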

Exporting our module to LLVM IR generates:

```llvm
define void @main() {
  ...

102:
  %103 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %104 = mul i64 %96, 2
  %105 = add i64 0, %104
  %106 = mul i64 %100, 1
  %107 = add i64 %105, %106
  %108 = getelementptr double, double* %103, i64 %107
  %109 = load double, double* %108
  %110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109)
  %111 = add i64 %100, 1
  br label %99

  ...

115:
  %116 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %24, 0
  %117 = bitcast double* %116 to i8*
  call void @free(i8* %117)
  %118 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %16, 0
  %119 = bitcast double* %118 to i8*
  call void @free(i8* %119)
  %120 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %121 = bitcast double* %120 to i8*
  call void @free(i8* %121)
  ret void
}
```

If we enable optimization on the generated LLVM IR, we can trim this down quite
a bit:

```llvm
define void @main() {
  %0 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.000000e+00)
  %1 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.600000e+01)
  %putchar = tail call i32 @putchar(i32 10)
  %2 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 4.000000e+00)
  %3 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 2.500000e+01)
  %putchar.1 = tail call i32 @putchar(i32 10)
  %4 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 9.000000e+00)
  %5 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 3.600000e+01)
  %putchar.2 = tail call i32 @putchar(i32 10)
  ret void
}
```

The full code listing for dumping LLVM IR can be found in
`examples/toy/Ch6/toyc.cpp` in the `dumpLLVMIR()` function:

```c++
int dumpLLVMIR(mlir::ModuleOp module) {
  // Register the translation from the builtin and LLVM dialects to LLVM IR
  // with the MLIR context before running the translation.
  mlir::registerBuiltinDialectTranslation(*module->getContext());
  mlir::registerLLVMDialectTranslation(*module->getContext());

  // Translate the module, which contains the LLVM dialect, to LLVM IR. Use a
  // fresh LLVM IR context. (Note that LLVM is not thread-safe and any
  // concurrent use of a context requires external locking.)
  llvm::LLVMContext llvmContext;
  auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext);
  if (!llvmModule) {
    llvm::errs() << "Failed to emit LLVM IR\n";
    return -1;
  }

  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());

  /// Optionally run an optimization pipeline over the llvm module.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);
  if (auto err = optPipeline(llvmModule.get())) {
    llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
    return -1;
  }
  llvm::errs() << *llvmModule << "\n";
  return 0;
}
```

### Setting up a JIT

Setting up a JIT to run the module containing the LLVM dialect can be done using
the `mlir::ExecutionEngine` infrastructure. This is a utility wrapper around
LLVM's JIT that accepts `.mlir` as input. The full code listing for setting up
the JIT can be found in `examples/toy/Ch6/toyc.cpp` in the `runJit()` function:

```c++
int runJit(mlir::ModuleOp module) {
  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  // Register the translation from MLIR to LLVM IR, which must happen before we
  // can JIT-compile.
  mlir::registerBuiltinDialectTranslation(*module->getContext());
  mlir::registerLLVMDialectTranslation(*module->getContext());

  // An optimization pipeline to use within the execution engine.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);

  // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
  // the module.
  mlir::ExecutionEngineOptions engineOptions;
  engineOptions.transformer = optPipeline;
  auto maybeEngine = mlir::ExecutionEngine::create(module, engineOptions);
  assert(maybeEngine && "failed to construct an execution engine");
  auto &engine = maybeEngine.get();

  // Invoke the JIT-compiled function.
  auto invocationResult = engine->invoke("main");
  if (invocationResult) {
    llvm::errs() << "JIT invocation failed\n";
    return -1;
  }

  return 0;
}
```

You can play around with it from the build directory:

```shell
$ echo 'def main() { print([[1, 2], [3, 4]]); }' | ./bin/toyc-ch6 -emit=jit
1.000000 2.000000
3.000000 4.000000
```

You can also play with `-emit=mlir`, `-emit=mlir-affine`, `-emit=mlir-llvm`, and
`-emit=llvm` to compare the various levels of IR involved. Also try options like
[`--mlir-print-ir-after-all`](../../PassManagement.md/#ir-printing) to track the
evolution of the IR throughout the pipeline.
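
For example, you can inspect the LLVM dialect form of the same program that was
JIT-compiled above:

```shell
$ echo 'def main() { print([[1, 2], [3, 4]]); }' | ./bin/toyc-ch6 -emit=mlir-llvm
```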

The example code used throughout this section can be found in
`test/Examples/Toy/Ch6/llvm-lowering.mlir`.

So far, we have worked with primitive data types. In the
[next chapter](Ch-7.md), we will add a composite `struct` type.
335