xref: /llvm-project/mlir/lib/ExecutionEngine/JitRunner.cpp (revision 0a1aa6cda2758b0926a95f87d39ffefb1cb90200)
1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a library that provides a shared implementation for command line
10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
11 // IR before JIT-compiling and executing the latter.
12 //
13 // The translation can be customized by providing an MLIR to MLIR
14 // transformation.
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/JitRunner.h"
18 
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/ExecutionEngine/ExecutionEngine.h"
21 #include "mlir/ExecutionEngine/OptUtils.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/Parser/Parser.h"
25 #include "mlir/Support/FileUtilities.h"
26 #include "mlir/Tools/ParseUtilities.h"
27 
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
30 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
31 #include "llvm/IR/IRBuilder.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/LegacyPassNameParser.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/FileUtilities.h"
37 #include "llvm/Support/SourceMgr.h"
38 #include "llvm/Support/StringSaver.h"
39 #include "llvm/Support/ToolOutputFile.h"
40 #include <cstdint>
41 #include <numeric>
42 #include <optional>
43 #include <utility>
44 
45 #define DEBUG_TYPE "jit-runner"
46 
47 using namespace mlir;
48 using llvm::Error;
49 
50 namespace {
/// This options struct prevents the need for global static initializers, and
/// is only initialized if the JITRunner is invoked.
///
/// Every member is an llvm::cl option, so an instance must be constructed
/// before the command line is parsed for the options to be recognized.
struct Options {
  /// Positional argument naming the input file; defaults to "-".
  llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
                                           llvm::cl::desc("<input file>"),
                                           llvm::cl::init("-")};
  /// Name of the function to execute (-e); defaults to "main".
  llvm::cl::opt<std::string> mainFuncName{
      "e", llvm::cl::desc("The function to be called"),
      llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
  /// Result type of the entry point (-entry-point-result); defaults to "f32".
  /// The value selects which compile-and-execute helper is used.
  llvm::cl::opt<std::string> mainFuncType{
      "entry-point-result",
      llvm::cl::desc("Textual description of the function type to be called"),
      llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};

  /// Category grouping the optimization-related flags below.
  llvm::cl::OptionCategory optFlags{"opt-like flags"};

  // CLI variables for -On options.
  llvm::cl::opt<bool> optO0{"O0",
                            llvm::cl::desc("Run opt passes and codegen at O0"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO1{"O1",
                            llvm::cl::desc("Run opt passes and codegen at O1"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO2{"O2",
                            llvm::cl::desc("Run opt passes and codegen at O2"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO3{"O3",
                            llvm::cl::desc("Run opt passes and codegen at O3"),
                            llvm::cl::cat(optFlags)};

  /// Comma-separated target features (-mattr) added to the target machine
  /// builder.
  llvm::cl::list<std::string> mAttrs{
      "mattr", llvm::cl::MiscFlags::CommaSeparated,
      llvm::cl::desc("Target specific attributes (-mattr=help for details)"),
      llvm::cl::value_desc("a1,+a2,-a3,..."), llvm::cl::cat(optFlags)};

  /// Architecture name (-march) applied to the target triple when non-empty.
  llvm::cl::opt<std::string> mArch{
      "march",
      llvm::cl::desc("Architecture to generate code for (see --version)")};

  /// Category grouping the linking-related options below.
  llvm::cl::OptionCategory clOptionsCategory{"linking options"};
  /// Comma-separated list of shared libraries (-shared-libs) handed to the
  /// execution engine.
  llvm::cl::list<std::string> clSharedLibs{
      "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
      llvm::cl::MiscFlags::CommaSeparated, llvm::cl::cat(clOptionsCategory)};

  /// CLI variables for debugging.
  /// When set, the JIT-compiled object is dumped to a file.
  llvm::cl::opt<bool> dumpObjectFile{
      "dump-object-file",
      llvm::cl::desc("Dump JITted-compiled object to file specified with "
                     "-object-filename (<input file>.o by default).")};

  /// Destination of the object dump; when empty, "<input file>.o" is used.
  llvm::cl::opt<std::string> objectFilename{
      "object-filename",
      llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};

  /// Hidden flag: only report whether an LLJIT instance can be created, then
  /// exit.
  llvm::cl::opt<bool> hostSupportsJit{"host-supports-jit",
                                      llvm::cl::desc("Report host JIT support"),
                                      llvm::cl::Hidden};

  /// When set, the parser does not wrap the input in an implicit top-level
  /// module; defaults to false.
  llvm::cl::opt<bool> noImplicitModule{
      "no-implicit-module",
      llvm::cl::desc(
          "Disable implicit addition of a top-level module op during parsing"),
      llvm::cl::init(false)};
};
115 
/// Bundle of optional customization hooks forwarded from the runner entry
/// point to the compile-and-execute helpers. Any of the callables may be left
/// unset.
struct CompileAndExecuteConfig {
  /// LLVM module transformer that is passed to ExecutionEngine.
  std::function<llvm::Error(llvm::Module *)> transformer;

  /// A custom function that is passed to ExecutionEngine. It processes MLIR
  /// module and creates LLVM IR module.
  llvm::function_ref<std::unique_ptr<llvm::Module>(Operation *,
                                                   llvm::LLVMContext &)>
      llvmModuleBuilder;

  /// A custom function intended to supply additional symbols to register with
  /// the ExecutionEngine at runtime.
  llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
      runtimeSymbolMap;
};
131 
132 } // namespace
133 
parseMLIRInput(StringRef inputFilename,bool insertImplicitModule,MLIRContext * context)134 static OwningOpRef<Operation *> parseMLIRInput(StringRef inputFilename,
135                                                bool insertImplicitModule,
136                                                MLIRContext *context) {
137   // Set up the input file.
138   std::string errorMessage;
139   auto file = openInputFile(inputFilename, &errorMessage);
140   if (!file) {
141     llvm::errs() << errorMessage << "\n";
142     return nullptr;
143   }
144 
145   auto sourceMgr = std::make_shared<llvm::SourceMgr>();
146   sourceMgr->AddNewSourceBuffer(std::move(file), SMLoc());
147   OwningOpRef<Operation *> module =
148       parseSourceFileForTool(sourceMgr, context, insertImplicitModule);
149   if (!module)
150     return nullptr;
151   if (!module.get()->hasTrait<OpTrait::SymbolTable>()) {
152     llvm::errs() << "Error: top-level op must be a symbol table.\n";
153     return nullptr;
154   }
155   return module;
156 }
157 
makeStringError(const Twine & message)158 static inline Error makeStringError(const Twine &message) {
159   return llvm::make_error<llvm::StringError>(message.str(),
160                                              llvm::inconvertibleErrorCode());
161 }
162 
getCommandLineOptLevel(Options & options)163 static std::optional<unsigned> getCommandLineOptLevel(Options &options) {
164   std::optional<unsigned> optLevel;
165   SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
166       options.optO0, options.optO1, options.optO2, options.optO3};
167 
168   // Determine if there is an optimization flag present.
169   for (unsigned j = 0; j < 4; ++j) {
170     auto &flag = optFlags[j].get();
171     if (flag) {
172       optLevel = j;
173       break;
174     }
175   }
176   return optLevel;
177 }
178 
179 // JIT-compile the given module and run "entryPoint" with "args" as arguments.
180 static Error
compileAndExecute(Options & options,Operation * module,StringRef entryPoint,CompileAndExecuteConfig config,void ** args,std::unique_ptr<llvm::TargetMachine> tm=nullptr)181 compileAndExecute(Options &options, Operation *module, StringRef entryPoint,
182                   CompileAndExecuteConfig config, void **args,
183                   std::unique_ptr<llvm::TargetMachine> tm = nullptr) {
184   std::optional<llvm::CodeGenOptLevel> jitCodeGenOptLevel;
185   if (auto clOptLevel = getCommandLineOptLevel(options))
186     jitCodeGenOptLevel = static_cast<llvm::CodeGenOptLevel>(*clOptLevel);
187 
188   SmallVector<StringRef, 4> sharedLibs(options.clSharedLibs.begin(),
189                                        options.clSharedLibs.end());
190 
191   mlir::ExecutionEngineOptions engineOptions;
192   engineOptions.llvmModuleBuilder = config.llvmModuleBuilder;
193   if (config.transformer)
194     engineOptions.transformer = config.transformer;
195   engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel;
196   engineOptions.sharedLibPaths = sharedLibs;
197   engineOptions.enableObjectDump = true;
198   auto expectedEngine =
199       mlir::ExecutionEngine::create(module, engineOptions, std::move(tm));
200   if (!expectedEngine)
201     return expectedEngine.takeError();
202 
203   auto engine = std::move(*expectedEngine);
204 
205   auto expectedFPtr = engine->lookupPacked(entryPoint);
206   if (!expectedFPtr)
207     return expectedFPtr.takeError();
208 
209   if (options.dumpObjectFile)
210     engine->dumpToObjectFile(options.objectFilename.empty()
211                                  ? options.inputFilename + ".o"
212                                  : options.objectFilename);
213 
214   void (*fptr)(void **) = *expectedFPtr;
215   (*fptr)(args);
216 
217   return Error::success();
218 }
219 
compileAndExecuteVoidFunction(Options & options,Operation * module,StringRef entryPoint,CompileAndExecuteConfig config,std::unique_ptr<llvm::TargetMachine> tm)220 static Error compileAndExecuteVoidFunction(
221     Options &options, Operation *module, StringRef entryPoint,
222     CompileAndExecuteConfig config, std::unique_ptr<llvm::TargetMachine> tm) {
223   auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(
224       SymbolTable::lookupSymbolIn(module, entryPoint));
225   if (!mainFunction || mainFunction.empty())
226     return makeStringError("entry point not found");
227 
228   auto resultType = dyn_cast<LLVM::LLVMVoidType>(
229       mainFunction.getFunctionType().getReturnType());
230   if (!resultType)
231     return makeStringError("expected void function");
232 
233   void *empty = nullptr;
234   return compileAndExecute(options, module, entryPoint, std::move(config),
235                            &empty, std::move(tm));
236 }
237 
/// Checks that the result type of `mainFunction` matches the C++ type `Type`.
/// Only declared here; the behavior is provided by the explicit
/// specializations for the supported result types.
template <typename Type>
Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
240 template <>
checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction)241 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
242   auto resultType = dyn_cast<IntegerType>(
243       cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
244           .getReturnType());
245   if (!resultType || resultType.getWidth() != 32)
246     return makeStringError("only single i32 function result supported");
247   return Error::success();
248 }
249 template <>
checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction)250 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
251   auto resultType = dyn_cast<IntegerType>(
252       cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
253           .getReturnType());
254   if (!resultType || resultType.getWidth() != 64)
255     return makeStringError("only single i64 function result supported");
256   return Error::success();
257 }
258 template <>
checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction)259 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
260   if (!isa<Float32Type>(
261           cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
262               .getReturnType()))
263     return makeStringError("only single f32 function result supported");
264   return Error::success();
265 }
266 template <typename Type>
compileAndExecuteSingleReturnFunction(Options & options,Operation * module,StringRef entryPoint,CompileAndExecuteConfig config,std::unique_ptr<llvm::TargetMachine> tm)267 Error compileAndExecuteSingleReturnFunction(
268     Options &options, Operation *module, StringRef entryPoint,
269     CompileAndExecuteConfig config, std::unique_ptr<llvm::TargetMachine> tm) {
270   auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(
271       SymbolTable::lookupSymbolIn(module, entryPoint));
272   if (!mainFunction || mainFunction.isExternal())
273     return makeStringError("entry point not found");
274 
275   if (cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
276           .getNumParams() != 0)
277     return makeStringError("function inputs not supported");
278 
279   if (Error error = checkCompatibleReturnType<Type>(mainFunction))
280     return error;
281 
282   Type res;
283   struct {
284     void *data;
285   } data;
286   data.data = &res;
287   if (auto error =
288           compileAndExecute(options, module, entryPoint, std::move(config),
289                             (void **)&data, std::move(tm)))
290     return error;
291 
292   // Intentional printing of the output so we can test.
293   llvm::outs() << res << '\n';
294 
295   return Error::success();
296 }
297 
298 /// Entry point for all CPU runners. Expects the common argc/argv arguments for
299 /// standard C++ main functions.
JitRunnerMain(int argc,char ** argv,const DialectRegistry & registry,JitRunnerConfig config)300 int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry &registry,
301                         JitRunnerConfig config) {
302   llvm::ExitOnError exitOnErr;
303 
304   // Create the options struct containing the command line options for the
305   // runner. This must come before the command line options are parsed.
306   Options options;
307   llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
308 
309   if (options.hostSupportsJit) {
310     auto j = llvm::orc::LLJITBuilder().create();
311     if (j)
312       llvm::outs() << "true\n";
313     else {
314       llvm::outs() << "false\n";
315       exitOnErr(j.takeError());
316     }
317     return 0;
318   }
319 
320   std::optional<unsigned> optLevel = getCommandLineOptLevel(options);
321   SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
322       options.optO0, options.optO1, options.optO2, options.optO3};
323 
324   MLIRContext context(registry);
325 
326   auto m = parseMLIRInput(options.inputFilename, !options.noImplicitModule,
327                           &context);
328   if (!m) {
329     llvm::errs() << "could not parse the input IR\n";
330     return 1;
331   }
332 
333   JitRunnerOptions runnerOptions{options.mainFuncName, options.mainFuncType};
334   if (config.mlirTransformer)
335     if (failed(config.mlirTransformer(m.get(), runnerOptions)))
336       return EXIT_FAILURE;
337 
338   auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
339   if (!tmBuilderOrError) {
340     llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
341     return EXIT_FAILURE;
342   }
343 
344   // Configure TargetMachine builder based on the command line options
345   llvm::SubtargetFeatures features;
346   if (!options.mAttrs.empty()) {
347     for (StringRef attr : options.mAttrs)
348       features.AddFeature(attr);
349     tmBuilderOrError->addFeatures(features.getFeatures());
350   }
351 
352   if (!options.mArch.empty()) {
353     tmBuilderOrError->getTargetTriple().setArchName(options.mArch);
354   }
355 
356   // Build TargetMachine
357   auto tmOrError = tmBuilderOrError->createTargetMachine();
358 
359   if (!tmOrError) {
360     llvm::errs() << "Failed to create a TargetMachine for the host\n";
361     exitOnErr(tmOrError.takeError());
362   }
363 
364   LLVM_DEBUG({
365     llvm::dbgs() << "  JITTargetMachineBuilder is "
366                  << llvm::orc::JITTargetMachineBuilderPrinter(*tmBuilderOrError,
367                                                               "\n");
368   });
369 
370   CompileAndExecuteConfig compileAndExecuteConfig;
371   if (optLevel) {
372     compileAndExecuteConfig.transformer = mlir::makeOptimizingTransformer(
373         *optLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get());
374   }
375   compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
376   compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;
377 
378   // Get the function used to compile and execute the module.
379   using CompileAndExecuteFnT =
380       Error (*)(Options &, Operation *, StringRef, CompileAndExecuteConfig,
381                 std::unique_ptr<llvm::TargetMachine> tm);
382   auto compileAndExecuteFn =
383       StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
384           .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
385           .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
386           .Case("f32", compileAndExecuteSingleReturnFunction<float>)
387           .Case("void", compileAndExecuteVoidFunction)
388           .Default(nullptr);
389 
390   Error error = compileAndExecuteFn
391                     ? compileAndExecuteFn(
392                           options, m.get(), options.mainFuncName.getValue(),
393                           compileAndExecuteConfig, std::move(tmOrError.get()))
394                     : makeStringError("unsupported function type");
395 
396   int exitCode = EXIT_SUCCESS;
397   llvm::handleAllErrors(std::move(error),
398                         [&exitCode](const llvm::ErrorInfoBase &info) {
399                           llvm::errs() << "Error: ";
400                           info.log(llvm::errs());
401                           llvm::errs() << '\n';
402                           exitCode = EXIT_FAILURE;
403                         });
404 
405   return exitCode;
406 }
407