//===- toyc.cpp - The Toy Compiler ----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the entry point for the Toy compiler.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
#include "toy/AST.h"
#include "toy/Dialect.h"
#include "toy/Lexer.h"
#include "toy/MLIRGen.h"
#include "toy/Parser.h"
#include "toy/Passes.h"

#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Transforms/Passes.h"

#include "llvm/ADT/StringRef.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <memory>
#include <string>
51 #include <system_error> 52 #include <utility> 53 54 using namespace toy; 55 namespace cl = llvm::cl; 56 57 static cl::opt<std::string> inputFilename(cl::Positional, 58 cl::desc("<input toy file>"), 59 cl::init("-"), 60 cl::value_desc("filename")); 61 62 namespace { 63 enum InputType { Toy, MLIR }; 64 } // namespace 65 static cl::opt<enum InputType> inputType( 66 "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), 67 cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), 68 cl::values(clEnumValN(MLIR, "mlir", 69 "load the input file as an MLIR file"))); 70 71 namespace { 72 enum Action { 73 None, 74 DumpAST, 75 DumpMLIR, 76 DumpMLIRAffine, 77 DumpMLIRLLVM, 78 DumpLLVMIR, 79 RunJIT 80 }; 81 } // namespace 82 static cl::opt<enum Action> emitAction( 83 "emit", cl::desc("Select the kind of output desired"), 84 cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), 85 cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")), 86 cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine", 87 "output the MLIR dump after affine lowering")), 88 cl::values(clEnumValN(DumpMLIRLLVM, "mlir-llvm", 89 "output the MLIR dump after llvm lowering")), 90 cl::values(clEnumValN(DumpLLVMIR, "llvm", "output the LLVM IR dump")), 91 cl::values( 92 clEnumValN(RunJIT, "jit", 93 "JIT the code and run it by invoking the main function"))); 94 95 static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations")); 96 97 /// Returns a Toy AST resulting from parsing the file or a nullptr on error. 
98 std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) { 99 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = 100 llvm::MemoryBuffer::getFileOrSTDIN(filename); 101 if (std::error_code ec = fileOrErr.getError()) { 102 llvm::errs() << "Could not open input file: " << ec.message() << "\n"; 103 return nullptr; 104 } 105 auto buffer = fileOrErr.get()->getBuffer(); 106 LexerBuffer lexer(buffer.begin(), buffer.end(), std::string(filename)); 107 Parser parser(lexer); 108 return parser.parseModule(); 109 } 110 111 int loadMLIR(mlir::MLIRContext &context, 112 mlir::OwningOpRef<mlir::ModuleOp> &module) { 113 // Handle '.toy' input to the compiler. 114 if (inputType != InputType::MLIR && 115 !llvm::StringRef(inputFilename).ends_with(".mlir")) { 116 auto moduleAST = parseInputFile(inputFilename); 117 if (!moduleAST) 118 return 6; 119 module = mlirGen(context, *moduleAST); 120 return !module ? 1 : 0; 121 } 122 123 // Otherwise, the input is '.mlir'. 124 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = 125 llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); 126 if (std::error_code ec = fileOrErr.getError()) { 127 llvm::errs() << "Could not open input file: " << ec.message() << "\n"; 128 return -1; 129 } 130 131 // Parse the input mlir. 132 llvm::SourceMgr sourceMgr; 133 sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); 134 module = mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, &context); 135 if (!module) { 136 llvm::errs() << "Error can't load file " << inputFilename << "\n"; 137 return 3; 138 } 139 return 0; 140 } 141 142 int loadAndProcessMLIR(mlir::MLIRContext &context, 143 mlir::OwningOpRef<mlir::ModuleOp> &module) { 144 if (int error = loadMLIR(context, module)) 145 return error; 146 147 mlir::PassManager pm(module.get()->getName()); 148 // Apply any generic pass manager command line options and run the pipeline. 
149 if (mlir::failed(mlir::applyPassManagerCLOptions(pm))) 150 return 4; 151 152 // Check to see what granularity of MLIR we are compiling to. 153 bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine; 154 bool isLoweringToLLVM = emitAction >= Action::DumpMLIRLLVM; 155 156 if (enableOpt || isLoweringToAffine) { 157 // Inline all functions into main and then delete them. 158 pm.addPass(mlir::createInlinerPass()); 159 160 // Now that there is only one function, we can infer the shapes of each of 161 // the operations. 162 mlir::OpPassManager &optPM = pm.nest<mlir::toy::FuncOp>(); 163 optPM.addPass(mlir::toy::createShapeInferencePass()); 164 optPM.addPass(mlir::createCanonicalizerPass()); 165 optPM.addPass(mlir::createCSEPass()); 166 } 167 168 if (isLoweringToAffine) { 169 // Partially lower the toy dialect. 170 pm.addPass(mlir::toy::createLowerToAffinePass()); 171 172 // Add a few cleanups post lowering. 173 mlir::OpPassManager &optPM = pm.nest<mlir::func::FuncOp>(); 174 optPM.addPass(mlir::createCanonicalizerPass()); 175 optPM.addPass(mlir::createCSEPass()); 176 177 // Add optimizations if enabled. 178 if (enableOpt) { 179 optPM.addPass(mlir::affine::createLoopFusionPass()); 180 optPM.addPass(mlir::affine::createAffineScalarReplacementPass()); 181 } 182 } 183 184 if (isLoweringToLLVM) { 185 // Finish lowering the toy IR to the LLVM dialect. 186 pm.addPass(mlir::toy::createLowerToLLVMPass()); 187 // This is necessary to have line tables emitted and basic 188 // debugger working. In the future we will add proper debug information 189 // emission directly from our frontend. 
190 pm.addPass(mlir::LLVM::createDIScopeForLLVMFuncOpPass()); 191 } 192 193 if (mlir::failed(pm.run(*module))) 194 return 4; 195 return 0; 196 } 197 198 int dumpAST() { 199 if (inputType == InputType::MLIR) { 200 llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; 201 return 5; 202 } 203 204 auto moduleAST = parseInputFile(inputFilename); 205 if (!moduleAST) 206 return 1; 207 208 dump(*moduleAST); 209 return 0; 210 } 211 212 int dumpLLVMIR(mlir::ModuleOp module) { 213 // Register the translation to LLVM IR with the MLIR context. 214 mlir::registerBuiltinDialectTranslation(*module->getContext()); 215 mlir::registerLLVMDialectTranslation(*module->getContext()); 216 217 // Convert the module to LLVM IR in a new LLVM IR context. 218 llvm::LLVMContext llvmContext; 219 auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext); 220 if (!llvmModule) { 221 llvm::errs() << "Failed to emit LLVM IR\n"; 222 return -1; 223 } 224 225 // Initialize LLVM targets. 226 llvm::InitializeNativeTarget(); 227 llvm::InitializeNativeTargetAsmPrinter(); 228 229 // Configure the LLVM Module 230 auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); 231 if (!tmBuilderOrError) { 232 llvm::errs() << "Could not create JITTargetMachineBuilder\n"; 233 return -1; 234 } 235 236 auto tmOrError = tmBuilderOrError->createTargetMachine(); 237 if (!tmOrError) { 238 llvm::errs() << "Could not create TargetMachine\n"; 239 return -1; 240 } 241 mlir::ExecutionEngine::setupTargetTripleAndDataLayout(llvmModule.get(), 242 tmOrError.get().get()); 243 244 /// Optionally run an optimization pipeline over the llvm module. 245 auto optPipeline = mlir::makeOptimizingTransformer( 246 /*optLevel=*/enableOpt ? 
3 : 0, /*sizeLevel=*/0, 247 /*targetMachine=*/nullptr); 248 if (auto err = optPipeline(llvmModule.get())) { 249 llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; 250 return -1; 251 } 252 llvm::errs() << *llvmModule << "\n"; 253 return 0; 254 } 255 256 int runJit(mlir::ModuleOp module) { 257 // Initialize LLVM targets. 258 llvm::InitializeNativeTarget(); 259 llvm::InitializeNativeTargetAsmPrinter(); 260 261 // Register the translation from MLIR to LLVM IR, which must happen before we 262 // can JIT-compile. 263 mlir::registerBuiltinDialectTranslation(*module->getContext()); 264 mlir::registerLLVMDialectTranslation(*module->getContext()); 265 266 // An optimization pipeline to use within the execution engine. 267 auto optPipeline = mlir::makeOptimizingTransformer( 268 /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, 269 /*targetMachine=*/nullptr); 270 271 // Create an MLIR execution engine. The execution engine eagerly JIT-compiles 272 // the module. 273 mlir::ExecutionEngineOptions engineOptions; 274 engineOptions.transformer = optPipeline; 275 auto maybeEngine = mlir::ExecutionEngine::create(module, engineOptions); 276 assert(maybeEngine && "failed to construct an execution engine"); 277 auto &engine = maybeEngine.get(); 278 279 // Invoke the JIT-compiled function. 280 auto invocationResult = engine->invokePacked("main"); 281 if (invocationResult) { 282 llvm::errs() << "JIT invocation failed\n"; 283 return -1; 284 } 285 286 return 0; 287 } 288 289 int main(int argc, char **argv) { 290 // Register any command line options. 291 mlir::registerAsmPrinterCLOptions(); 292 mlir::registerMLIRContextCLOptions(); 293 mlir::registerPassManagerCLOptions(); 294 295 cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); 296 297 if (emitAction == Action::DumpAST) 298 return dumpAST(); 299 300 // If we aren't dumping the AST, then we are compiling with/to MLIR. 
301 mlir::DialectRegistry registry; 302 mlir::func::registerAllExtensions(registry); 303 mlir::LLVM::registerInlinerInterface(registry); 304 305 mlir::MLIRContext context(registry); 306 // Load our Dialect in this MLIR Context. 307 context.getOrLoadDialect<mlir::toy::ToyDialect>(); 308 309 mlir::OwningOpRef<mlir::ModuleOp> module; 310 if (int error = loadAndProcessMLIR(context, module)) 311 return error; 312 313 // If we aren't exporting to non-mlir, then we are done. 314 bool isOutputingMLIR = emitAction <= Action::DumpMLIRLLVM; 315 if (isOutputingMLIR) { 316 module->dump(); 317 return 0; 318 } 319 320 // Check to see if we are compiling to LLVM IR. 321 if (emitAction == Action::DumpLLVMIR) 322 return dumpLLVMIR(*module); 323 324 // Otherwise, we must be running the jit. 325 if (emitAction == Action::RunJIT) 326 return runJit(*module); 327 328 llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n"; 329 return -1; 330 } 331