# Chapter 6: Lowering to LLVM and Code Generation

[TOC]

In the [previous chapter](Ch-5.md), we introduced the
[dialect conversion](../../DialectConversion.md) framework and partially lowered
many of the `Toy` operations to affine loop nests for optimization. In this
chapter, we will finally lower to LLVM for code generation.

## Lowering to LLVM

For this lowering, we will again use the dialect conversion framework to perform
the heavy lifting. However, this time, we will be performing a full conversion
to the [LLVM dialect](../../Dialects/LLVM.md). Thankfully, we have already
lowered all but one of the `toy` operations, with the last being `toy.print`.
Before going over the conversion to LLVM, let's lower the `toy.print` operation.
We will lower this operation to a non-affine loop nest that invokes `printf` for
each element. Note that, because the dialect conversion framework supports
[transitive lowering](../../../getting_started/Glossary.md/#transitive-lowering),
we don't need to directly emit operations in the LLVM dialect. By transitive
lowering, we mean that the conversion framework may apply multiple patterns to
fully legalize an operation. In this example, we are generating a structured
loop nest instead of the branch form used by the LLVM dialect. As long as we
then have a lowering from the loop operations to LLVM, the lowering will still
succeed.

During lowering, we can get or build the declaration for `printf` as follows:

```c++
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
                                           ModuleOp module) {
  auto *context = module.getContext();
  if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
    return SymbolRefAttr::get(context, "printf");

  // Create a function declaration for printf, the signature is:
  //   * `i32 (i8*, ...)`
  auto llvmI32Ty = IntegerType::get(context, 32);
  auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
  auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
                                                /*isVarArg=*/true);

  // Insert the printf function into the body of the parent module.
  PatternRewriter::InsertionGuard insertGuard(rewriter);
  rewriter.setInsertionPointToStart(module.getBody());
  rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
  return SymbolRefAttr::get(context, "printf");
}
```

Now that the lowering for `toy.print` has been defined, we can specify the
components necessary for the full lowering. These are largely the same as the
components defined in the [previous chapter](Ch-5.md).

### Conversion Target

For this conversion, aside from the top-level module, we will be lowering
everything to the LLVM dialect.

```c++
  mlir::ConversionTarget target(getContext());
  target.addLegalDialect<mlir::LLVM::LLVMDialect>();
  target.addLegalOp<mlir::ModuleOp>();
```

### Type Converter

This lowering will also transform the MemRef types that are currently being
operated on into a representation in LLVM. To perform this conversion, we use a
TypeConverter as part of the lowering. This converter specifies how one type
maps to another. This is necessary now that we are performing more complicated
lowerings involving block arguments. Given that we don't have any
Toy-dialect-specific types that need to be lowered, the default converter is
enough for our use case.
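To make the type mapping more concrete, here is a minimal plain-C++ sketch, not MLIR code, of the struct that a 2-D `memref` becomes. It assumes the four-field layout that appears in this chapter's listings, `{ double*, i64, [2 x i64], [2 x i64] }` (data pointer, offset, sizes, strides). Current MLIR releases use a descriptor with an additional aligned pointer, so treat the field list as illustrative. The names `MemRefDescriptor2D`, `makeRowMajor`, `linearIndex`, `loadElement`, and `printLikeToy` are hypothetical. `linearIndex` performs the same `offset + i * stride0 + j * stride1` arithmetic as the `llvm.mul`/`llvm.add` chain that feeds `llvm.getelementptr` in the lowered output. `printLikeToy` imitates the lowered `toy.print` loop nest, assuming a `"%f "` format string per element and a newline per row.

```c++
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <string>

// Hypothetical C++ mirror of the 2-D memref struct shown in this chapter's
// listings: { double*, i64, [2 x i64], [2 x i64] }.
struct MemRefDescriptor2D {
  double *data;       // base pointer of the underlying buffer
  int64_t offset;     // element offset of element (0, 0) from `data`
  int64_t sizes[2];   // extent of each dimension
  int64_t strides[2]; // elements to skip when an index grows by one
};

// Describe a contiguous row-major buffer of shape rows x cols.
MemRefDescriptor2D makeRowMajor(double *data, int64_t rows, int64_t cols) {
  return MemRefDescriptor2D{data, 0, {rows, cols}, {cols, 1}};
}

// Same arithmetic as the llvm.mul/llvm.add chain that feeds
// llvm.getelementptr: offset + i * strides[0] + j * strides[1].
int64_t linearIndex(const MemRefDescriptor2D &m, int64_t i, int64_t j) {
  assert(i >= 0 && i < m.sizes[0] && j >= 0 && j < m.sizes[1]);
  return m.offset + i * m.strides[0] + j * m.strides[1];
}

// Load one element through the descriptor.
double loadElement(const MemRefDescriptor2D &m, int64_t i, int64_t j) {
  return m.data[linearIndex(m, i, j)];
}

// Imitates the lowered toy.print loop nest: "%f " per element (assumed
// format), then a newline after each row. Collects the text instead of
// printing it so the result can be inspected.
std::string printLikeToy(const MemRefDescriptor2D &m) {
  std::string out;
  char buf[64];
  for (int64_t i = 0; i < m.sizes[0]; ++i) {
    for (int64_t j = 0; j < m.sizes[1]; ++j) {
      std::snprintf(buf, sizeof(buf), "%f ", loadElement(m, i, j));
      out += buf;
    }
    out += '\n';
  }
  return out;
}
```

For example, wrapping the six-element buffer `{1, 16, 4, 25, 9, 36}` with `makeRowMajor(buf, 3, 2)` gives strides `{2, 1}`, the same index constants that appear in the `^bb16` block of the LLVM dialect output for the working example, and `printLikeToy` then produces three lines, the first being `1.000000 16.000000 `. Back in the pass, the default converter itself is created as follows: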
```c++
  LLVMTypeConverter typeConverter(&getContext());
```

### Conversion Patterns

Now that the conversion target has been defined, we need to provide the patterns
used for lowering. At this point in the compilation process, we have a
combination of `toy`, `affine`, `arith`, `memref`, and `func` operations.
Luckily, these upstream dialects already provide the set of patterns needed to
transform them into the LLVM dialect. These patterns allow for lowering the IR
in multiple stages by relying on
[transitive lowering](../../../getting_started/Glossary.md/#transitive-lowering):
for example, `affine` loops are first lowered to `scf` loops, those are lowered
to `cf` branches, and the branches are finally lowered to LLVM.

```c++
  mlir::RewritePatternSet patterns(&getContext());
  mlir::populateAffineToStdConversionPatterns(patterns);
  mlir::populateSCFToControlFlowConversionPatterns(patterns);
  mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
  mlir::populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
  mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
  mlir::populateFuncToLLVMConversionPatterns(typeConverter, patterns);

  // The only remaining operation to lower from the `toy` dialect is the
  // PrintOp.
  patterns.add<PrintOpLowering>(&getContext());
```

### Full Lowering

We want to completely lower to LLVM, so we use a full conversion
(`applyFullConversion`). This ensures that only legal operations will remain
after the conversion; if any operation cannot be legalized, the conversion
fails.
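The difference between a full and a partial conversion, and the transitive lowering it relies on, can be illustrated with a hypothetical toy model in plain C++. This is not MLIR's API: operations are just name strings, legality is decided by dialect prefix (standing in for `addLegalDialect`), and the pattern table is a simplified, made-up stand-in for the populated pattern sets. `applyFullConversionModel` keeps rewriting until every operation is legal and reports failure as soon as it meets an illegal operation that no pattern can rewrite, which is the behavior that distinguishes a full conversion.

```c++
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy model of a full dialect conversion (not MLIR's API).
// A "pattern" rewrites one op name into a list of replacement op names.
using Patterns = std::map<std::string, std::vector<std::string>>;

// Dialect prefix of an op name, e.g. "llvm" for "llvm.call".
std::string dialectOf(const std::string &op) {
  return op.substr(0, op.find('.'));
}

// Rewrites `ops` until every op belongs to a legal dialect. Returns false if
// an illegal op has no pattern (a full conversion fails in that case, whereas
// a partial conversion would leave the op in place) or if the rewriting does
// not settle within `maxRounds` rounds. Unlike MLIR, this model does not roll
// back rounds that were already applied when it fails.
bool applyFullConversionModel(std::vector<std::string> &ops,
                              const std::set<std::string> &legalDialects,
                              const Patterns &patterns, int maxRounds = 16) {
  assert(maxRounds > 0 && "need at least one round");
  for (int round = 0; round < maxRounds; ++round) {
    bool changed = false;
    std::vector<std::string> next;
    for (const std::string &op : ops) {
      if (legalDialects.count(dialectOf(op))) {
        next.push_back(op); // already legal: keep as-is
        continue;
      }
      auto it = patterns.find(op);
      if (it == patterns.end())
        return false; // illegal and no pattern applies: conversion fails
      // Replacement ops may themselves be illegal; they are revisited in the
      // next round, which is what makes the lowering transitive.
      next.insert(next.end(), it->second.begin(), it->second.end());
      changed = true;
    }
    ops = std::move(next);
    if (!changed)
      return true; // every op is legal
  }
  return false;
}
```

With a table that maps `toy.print` to `scf.for` plus `llvm.call`, `scf.for` to `cf` branches, and each `cf` branch to its `llvm` counterpart, a module containing `toy.print` becomes fully legal after three rounds of rewriting, while a leftover op with no pattern (say, a stray `toy.transpose`) makes the whole conversion fail. In the real pass, the conversion is applied to the module as follows: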
```c++
  mlir::ModuleOp module = getOperation();
  if (mlir::failed(
          mlir::applyFullConversion(module, target, std::move(patterns))))
    signalPassFailure();
```

Looking back at our current working example:

```mlir
toy.func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
  %3 = toy.mul %2, %2 : tensor<3x2xf64>
  toy.print %3 : tensor<3x2xf64>
  toy.return
}
```

We can now lower down to the LLVM dialect, which produces the following code:

```mlir
llvm.func @free(!llvm<"i8*">)
llvm.func @printf(!llvm<"i8*">, ...) -> i32
llvm.func @malloc(i64) -> !llvm<"i8*">
llvm.func @main() {
  %0 = llvm.mlir.constant(1.000000e+00 : f64) : f64
  %1 = llvm.mlir.constant(2.000000e+00 : f64) : f64

  ...

^bb16:
  %221 = llvm.extractvalue %25[0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %222 = llvm.mlir.constant(0 : index) : i64
  %223 = llvm.mlir.constant(2 : index) : i64
  %224 = llvm.mul %214, %223 : i64
  %225 = llvm.add %222, %224 : i64
  %226 = llvm.mlir.constant(1 : index) : i64
  %227 = llvm.mul %219, %226 : i64
  %228 = llvm.add %225, %227 : i64
  %229 = llvm.getelementptr %221[%228] : (!llvm<"double*">, i64) -> !llvm<"double*">
  %230 = llvm.load %229 : !llvm<"double*">
  %231 = llvm.call @printf(%207, %230) : (!llvm<"i8*">, f64) -> i32
  %232 = llvm.add %219, %218 : i64
  llvm.br ^bb15(%232 : i64)

  ...

^bb18:
  %235 = llvm.extractvalue %65[0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %236 = llvm.bitcast %235 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%236) : (!llvm<"i8*">) -> ()
  %237 = llvm.extractvalue %45[0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %238 = llvm.bitcast %237 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%238) : (!llvm<"i8*">) -> ()
  %239 = llvm.extractvalue %25[0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %240 = llvm.bitcast %239 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%240) : (!llvm<"i8*">) -> ()
  llvm.return
}
```

See [LLVM IR Target](../../TargetLLVMIR.md) for more in-depth details on
lowering to the LLVM dialect.

## CodeGen: Getting Out of MLIR

At this point we are right at the cusp of code generation. We can generate code
in the LLVM dialect, so now we just need to export to LLVM IR and set up a JIT
to run it.

### Emitting LLVM IR

Now that our module consists only of operations in the LLVM dialect, we can
export it to LLVM IR. To do this programmatically, we can invoke the following
utility:

```c++
  llvm::LLVMContext llvmContext;
  std::unique_ptr<llvm::Module> llvmModule =
      mlir::translateModuleToLLVMIR(module, llvmContext);
  if (!llvmModule) {
    /* ... an error was encountered ... */
  }
```

Exporting our module to LLVM IR generates:

```llvm
define void @main() {
  ...

102:
  %103 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %104 = mul i64 %96, 2
  %105 = add i64 0, %104
  %106 = mul i64 %100, 1
  %107 = add i64 %105, %106
  %108 = getelementptr double, double* %103, i64 %107
  %109 = load double, double* %108
  %110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109)
  %111 = add i64 %100, 1
  br label %99

  ...

115:
  %116 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %24, 0
  %117 = bitcast double* %116 to i8*
  call void @free(i8* %117)
  %118 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %16, 0
  %119 = bitcast double* %118 to i8*
  call void @free(i8* %119)
  %120 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %121 = bitcast double* %120 to i8*
  call void @free(i8* %121)
  ret void
}
```

If we enable optimization on the generated LLVM IR, we can trim this down quite
a bit:

```llvm
define void @main() {
  %0 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.000000e+00)
  %1 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.600000e+01)
  %putchar = tail call i32 @putchar(i32 10)
  %2 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 4.000000e+00)
  %3 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 2.500000e+01)
  %putchar.1 = tail call i32 @putchar(i32 10)
  %4 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 9.000000e+00)
  %5 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 3.600000e+01)
  %putchar.2 = tail call i32 @putchar(i32 10)
  ret void
}
```

The full code listing for dumping LLVM IR can be found in
`examples/toy/Ch6/toyc.cpp` in the `dumpLLVMIR()` function:

```c++
int dumpLLVMIR(mlir::ModuleOp module) {
  // Register the translation to LLVM IR with the MLIR context.
  mlir::registerLLVMDialectTranslation(*module->getContext());

  // Translate the module, that contains the LLVM dialect, to LLVM IR. Use a
  // fresh LLVM IR context. (Note that LLVM is not thread-safe and any
  // concurrent use of a context requires external locking.)
  llvm::LLVMContext llvmContext;
  auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext);
  if (!llvmModule) {
    llvm::errs() << "Failed to emit LLVM IR\n";
    return -1;
  }

  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());

  // Optionally run an optimization pipeline over the LLVM module.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);
  if (auto err = optPipeline(llvmModule.get())) {
    llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
    return -1;
  }
  llvm::errs() << *llvmModule << "\n";
  return 0;
}
```

### Setting up a JIT

Setting up a JIT to run the module containing the LLVM dialect can be done using
the `mlir::ExecutionEngine` infrastructure. This is a utility wrapper around
LLVM's JIT that accepts an MLIR module in the LLVM dialect as input. The full
code listing for setting up the JIT can be found in `Ch6/toyc.cpp` in the
`runJit()` function:

```c++
int runJit(mlir::ModuleOp module) {
  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  // Register the translation from MLIR to LLVM IR, which must happen before we
  // can JIT-compile.
  mlir::registerLLVMDialectTranslation(*module->getContext());

  // An optimization pipeline to use within the execution engine.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);

  // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
  // the module.
  mlir::ExecutionEngineOptions engineOptions;
  engineOptions.transformer = optPipeline;
  auto maybeEngine = mlir::ExecutionEngine::create(module, engineOptions);
  assert(maybeEngine && "failed to construct an execution engine");
  auto &engine = maybeEngine.get();

  // Invoke the JIT-compiled function.
  auto invocationResult = engine->invoke("main");
  if (invocationResult) {
    llvm::errs() << "JIT invocation failed\n";
    return -1;
  }

  return 0;
}
```

You can play around with it from the build directory:

```shell
$ echo 'def main() { print([[1, 2], [3, 4]]); }' | ./bin/toyc-ch6 -emit=jit
1.000000 2.000000
3.000000 4.000000
```

You can also play with `-emit=mlir`, `-emit=mlir-affine`, `-emit=mlir-llvm`, and
`-emit=llvm` to compare the various levels of IR involved. Also try options like
[`--mlir-print-ir-after-all`](../../PassManagement.md/#ir-printing) to track the
evolution of the IR throughout the pipeline.

The example code used throughout this section can be found in
`test/Examples/Toy/Ch6/llvm-lowering.mlir`.

So far, we have worked with primitive data types. In the
[next chapter](Ch-7.md), we will add a composite `struct` type.