//===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the execution engine for MLIR modules based on the
// LLVM Orc JIT engine.
//
//===----------------------------------------------------------------------===//
13 #include "mlir/ExecutionEngine/ExecutionEngine.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/Support/FileUtilities.h"
17 #include "mlir/Target/LLVMIR/Export.h"
18 
19 #include "llvm/ExecutionEngine/JITEventListener.h"
20 #include "llvm/ExecutionEngine/ObjectCache.h"
21 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
22 #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
23 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
24 #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
25 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
26 #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
27 #include "llvm/IR/IRBuilder.h"
28 #include "llvm/MC/TargetRegistry.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/Error.h"
31 #include "llvm/Support/ToolOutputFile.h"
32 #include "llvm/TargetParser/Host.h"
33 #include "llvm/TargetParser/SubtargetFeature.h"
34 
35 #define DEBUG_TYPE "execution-engine"
36 
37 using namespace mlir;
38 using llvm::dbgs;
39 using llvm::Error;
40 using llvm::errs;
41 using llvm::Expected;
42 using llvm::LLVMContext;
43 using llvm::MemoryBuffer;
44 using llvm::MemoryBufferRef;
45 using llvm::Module;
46 using llvm::SectionMemoryManager;
47 using llvm::StringError;
48 using llvm::Triple;
49 using llvm::orc::DynamicLibrarySearchGenerator;
50 using llvm::orc::ExecutionSession;
51 using llvm::orc::IRCompileLayer;
52 using llvm::orc::JITTargetMachineBuilder;
53 using llvm::orc::MangleAndInterner;
54 using llvm::orc::RTDyldObjectLinkingLayer;
55 using llvm::orc::SymbolMap;
56 using llvm::orc::ThreadSafeModule;
57 using llvm::orc::TMOwningSimpleCompiler;
58 
/// Wrap a string into an llvm::StringError.
static Error makeStringError(const Twine &message) {
  return llvm::make_error<StringError>(message.str(),
                                       llvm::inconvertibleErrorCode());
}

void SimpleObjectCache::notifyObjectCompiled(const Module *m,
                                             MemoryBufferRef objBuffer) {
  cachedObjects[m->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy(
      objBuffer.getBuffer(), objBuffer.getBufferIdentifier());
}

std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *m) {
  auto i = cachedObjects.find(m->getModuleIdentifier());
  if (i == cachedObjects.end()) {
    LLVM_DEBUG(dbgs() << "No object for " << m->getModuleIdentifier()
                      << " in cache. Compiling.\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "Object for " << m->getModuleIdentifier()
                    << " loaded from cache.\n");
  return MemoryBuffer::getMemBuffer(i->second->getMemBufferRef());
}

void SimpleObjectCache::dumpToObjectFile(StringRef outputFilename) {
  // Set up the output file.
  std::string errorMessage;
  auto file = openOutputFile(outputFilename, &errorMessage);
  if (!file) {
    llvm::errs() << errorMessage << "\n";
    return;
  }

  // Dump the object generated for a single module to the output file.
  assert(cachedObjects.size() == 1 && "Expected only one object entry.");
  auto &cachedObject = cachedObjects.begin()->second;
  file->os() << cachedObject->getBuffer();
  file->keep();
}

bool SimpleObjectCache::isEmpty() { return cachedObjects.empty(); }

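// Note (illustrative): dumping requires the engine to have been created with
// `ExecutionEngineOptions::enableObjectDump` set to true, which is what
// instantiates the SimpleObjectCache above; the dump itself is then a single
// call, e.g. `engine->dumpToObjectFile("module.o")`.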
void ExecutionEngine::dumpToObjectFile(StringRef filename) {
  if (cache == nullptr) {
    llvm::errs() << "cannot dump ExecutionEngine object code to file: "
                    "object cache is disabled\n";
    return;
  }
  // Compilation is lazy and does not populate the object cache unless
  // requested. If an object dump is requested before the cache is populated,
  // force compilation manually.
  if (cache->isEmpty()) {
    for (std::string &functionName : functionNames) {
      auto result = lookupPacked(functionName);
      if (!result) {
        llvm::errs() << "Could not compile " << functionName << ":\n  "
                     << result.takeError() << "\n";
        return;
      }
    }
  }
  cache->dumpToObjectFile(filename);
}

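// Usage sketch for registerSymbols (the `hostPrint` function is hypothetical;
// the pattern mirrors the runtime symbol map built in ExecutionEngine::create
// below): expose a host-process function to JIT'd code under a fixed name.
//
//   engine->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
//     llvm::orc::SymbolMap symbolMap;
//     symbolMap[interner("host_print")] = {
//         llvm::orc::ExecutorAddr::fromPtr(&hostPrint),
//         llvm::JITSymbolFlags::Exported};
//     return symbolMap;
//   });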
void ExecutionEngine::registerSymbols(
    llvm::function_ref<SymbolMap(MangleAndInterner)> symbolMap) {
  auto &mainJitDylib = jit->getMainJITDylib();
  cantFail(mainJitDylib.define(
      absoluteSymbols(symbolMap(llvm::orc::MangleAndInterner(
          mainJitDylib.getExecutionSession(), jit->getDataLayout())))));
}

void ExecutionEngine::setupTargetTripleAndDataLayout(Module *llvmModule,
                                                     llvm::TargetMachine *tm) {
  llvmModule->setDataLayout(tm->createDataLayout());
  llvmModule->setTargetTriple(tm->getTargetTriple().getTriple());
}

static std::string makePackedFunctionName(StringRef name) {
  return "_mlir_" + name.str();
}

// For each function in the LLVM module, define an interface function that wraps
// all the arguments of the original function and all its results into an i8**
// pointer to provide a unified invocation interface.
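// For example (a sketch; `foo` is a hypothetical function): for
// `float foo(int32_t, int64_t)`, the generated `_mlir_foo` behaves roughly
// like
//
//   void _mlir_foo(void **args) {
//     float result = foo(*(int32_t *)args[0], *(int64_t *)args[1]);
//     *(float *)args[2] = result;
//   }
//
// i.e. the caller packs a pointer to each argument, followed by a pointer to
// storage for the result, into a single array.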
static void packFunctionArguments(Module *module) {
  auto &ctx = module->getContext();
  llvm::IRBuilder<> builder(ctx);
  DenseSet<llvm::Function *> interfaceFunctions;
  for (auto &func : module->getFunctionList()) {
    if (func.isDeclaration()) {
      continue;
    }
    if (interfaceFunctions.count(&func)) {
      continue;
    }

    // Given a function `foo(<...>)`, define the interface function
    // `_mlir_foo(i8**)`.
    auto *newType =
        llvm::FunctionType::get(builder.getVoidTy(), builder.getPtrTy(),
                                /*isVarArg=*/false);
    auto newName = makePackedFunctionName(func.getName());
    auto funcCst = module->getOrInsertFunction(newName, newType);
    llvm::Function *interfaceFunc = cast<llvm::Function>(funcCst.getCallee());
    interfaceFunctions.insert(interfaceFunc);

    // Extract the arguments from the type-erased argument list and cast them to
    // the proper types.
    auto *bb = llvm::BasicBlock::Create(ctx);
    bb->insertInto(interfaceFunc);
    builder.SetInsertPoint(bb);
    llvm::Value *argList = interfaceFunc->arg_begin();
    SmallVector<llvm::Value *, 8> args;
    args.reserve(llvm::size(func.args()));
    for (auto [index, arg] : llvm::enumerate(func.args())) {
      llvm::Value *argIndex = llvm::Constant::getIntegerValue(
          builder.getInt64Ty(), APInt(64, index));
      llvm::Value *argPtrPtr =
          builder.CreateGEP(builder.getPtrTy(), argList, argIndex);
      llvm::Value *argPtr = builder.CreateLoad(builder.getPtrTy(), argPtrPtr);
      llvm::Type *argTy = arg.getType();
      llvm::Value *load = builder.CreateLoad(argTy, argPtr);
      args.push_back(load);
    }

    // Call the implementation function with the extracted arguments.
    llvm::Value *result = builder.CreateCall(&func, args);

    // Assuming the result is one value, potentially of type `void`.
    if (!result->getType()->isVoidTy()) {
      llvm::Value *retIndex = llvm::Constant::getIntegerValue(
          builder.getInt64Ty(), APInt(64, llvm::size(func.args())));
      llvm::Value *retPtrPtr =
          builder.CreateGEP(builder.getPtrTy(), argList, retIndex);
      llvm::Value *retPtr = builder.CreateLoad(builder.getPtrTy(), retPtrPtr);
      builder.CreateStore(result, retPtr);
    }

    // The interface function returns void.
    builder.CreateRetVoid();
  }
}

ExecutionEngine::ExecutionEngine(bool enableObjectDump,
                                 bool enableGDBNotificationListener,
                                 bool enablePerfNotificationListener)
    : cache(enableObjectDump ? new SimpleObjectCache() : nullptr),
      functionNames(),
      gdbListener(enableGDBNotificationListener
                      ? llvm::JITEventListener::createGDBRegistrationListener()
                      : nullptr),
      perfListener(nullptr) {
  if (enablePerfNotificationListener) {
    if (auto *listener = llvm::JITEventListener::createPerfJITEventListener())
      perfListener = listener;
    else if (auto *listener =
                 llvm::JITEventListener::createIntelJITEventListener())
      perfListener = listener;
  }
}

ExecutionEngine::~ExecutionEngine() {
  // Execute the global destructors from the module being processed.
  // TODO: Allow JIT deinitialize for AArch64. Currently there's a bug causing
  // a crash on AArch64; see related issue #71963.
  if (jit && !jit->getTargetTriple().isAArch64())
    llvm::consumeError(jit->deinitialize(jit->getMainJITDylib()));
  // Run all dynamic library destroy callbacks to prepare for the shutdown.
  for (LibraryDestroyFn destroy : destroyFns)
    destroy();
}

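// Usage sketch (assumes `module` is an MLIR ModuleOp already lowered to the
// LLVM dialect; everything else is illustrative):
//
//   ExecutionEngineOptions options;
//   options.enableObjectDump = true;
//   auto expectedEngine = ExecutionEngine::create(module, options);
//   if (!expectedEngine)
//     /* handle the llvm::Error */;
//   std::unique_ptr<ExecutionEngine> engine = std::move(*expectedEngine);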
Expected<std::unique_ptr<ExecutionEngine>>
ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options,
                        std::unique_ptr<llvm::TargetMachine> tm) {
  auto engine = std::make_unique<ExecutionEngine>(
      options.enableObjectDump, options.enableGDBNotificationListener,
      options.enablePerfNotificationListener);

  // Remember all entry-points if object dumping is enabled.
  if (options.enableObjectDump) {
    for (auto funcOp : m->getRegion(0).getOps<LLVM::LLVMFuncOp>()) {
      StringRef funcName = funcOp.getSymName();
      engine->functionNames.push_back(funcName.str());
    }
  }

  std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext);
  auto llvmModule = options.llvmModuleBuilder
                        ? options.llvmModuleBuilder(m, *ctx)
                        : translateModuleToLLVMIR(m, *ctx);
  if (!llvmModule)
    return makeStringError("could not convert to LLVM IR");

  // If no valid TargetMachine was passed, create a default TM ignoring any
  // input arguments from the user.
  if (!tm) {
    auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
    if (!tmBuilderOrError)
      return tmBuilderOrError.takeError();

    auto tmOrError = tmBuilderOrError->createTargetMachine();
    if (!tmOrError)
      return tmOrError.takeError();
    tm = std::move(tmOrError.get());
  }

  // TODO: Currently, the LLVM module created above has no triple associated
  // with it. Instead, the triple is extracted from the TargetMachine, which is
  // either based on the host defaults or command line arguments when specified
  // (set-up by callers of this method). It could also be passed to the
  // translation or dialect conversion instead of this.
  setupTargetTripleAndDataLayout(llvmModule.get(), tm.get());
  packFunctionArguments(llvmModule.get());

  auto dataLayout = llvmModule->getDataLayout();

  // Use absolute library paths so that gdb can find the symbol table.
  SmallVector<SmallString<256>, 4> sharedLibPaths;
  transform(
      options.sharedLibPaths, std::back_inserter(sharedLibPaths),
      [](StringRef libPath) {
        SmallString<256> absPath(libPath.begin(), libPath.end());
        cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
        return absPath;
      });

  // If a shared library implements custom execution layer library init and
  // destroy functions, use them to register the library. Otherwise, load the
  // library as a JITDylib below.
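  // A library opting into this mechanism exports an init and a destroy entry
  // point under the names given by kLibraryInitFnName / kLibraryDestroyFnName
  // (string constants declared in the ExecutionEngine header). A sketch, with
  // signatures inferred from the LibraryInitFn / LibraryDestroyFn calls below;
  // `libraryInit`, `libraryDestroy`, and `myRuntimeFn` are placeholder names:
  //
  //   extern "C" void libraryInit(llvm::StringMap<void *> &exportSymbols) {
  //     exportSymbols["my_runtime_fn"] = reinterpret_cast<void *>(&myRuntimeFn);
  //   }
  //   extern "C" void libraryDestroy() {}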
  llvm::StringMap<void *> exportSymbols;
  SmallVector<LibraryDestroyFn> destroyFns;
  SmallVector<StringRef> jitDyLibPaths;

  for (auto &libPath : sharedLibPaths) {
    auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(
        libPath.str().str().c_str());
    void *initSym = lib.getAddressOfSymbol(kLibraryInitFnName);
    void *destroySim = lib.getAddressOfSymbol(kLibraryDestroyFnName);

    // Library does not provide callbacks; rely on symbol visibility.
    if (!initSym || !destroySim) {
      jitDyLibPaths.push_back(libPath);
      continue;
    }

    auto initFn = reinterpret_cast<LibraryInitFn>(initSym);
    initFn(exportSymbols);

    auto destroyFn = reinterpret_cast<LibraryDestroyFn>(destroySim);
    destroyFns.push_back(destroyFn);
  }
  engine->destroyFns = std::move(destroyFns);

  // Callback to create the object layer with symbol resolution to the current
  // process and dynamically linked libraries.
  auto objectLinkingLayerCreator = [&](ExecutionSession &session,
                                       const Triple &tt) {
    auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>(
        session, [sectionMemoryMapper = options.sectionMemoryMapper]() {
          return std::make_unique<SectionMemoryManager>(sectionMemoryMapper);
        });

    // Register JIT event listeners if they are enabled.
    if (engine->gdbListener)
      objectLayer->registerJITEventListener(*engine->gdbListener);
    if (engine->perfListener)
      objectLayer->registerJITEventListener(*engine->perfListener);

    // COFF format binaries (Windows) need special handling to deal with
    // exported symbol visibility.
    // cf llvm/lib/ExecutionEngine/Orc/LLJIT.cpp LLJIT::createObjectLinkingLayer
    llvm::Triple targetTriple(llvm::Twine(llvmModule->getTargetTriple()));
    if (targetTriple.isOSBinFormatCOFF()) {
      objectLayer->setOverrideObjectFlagsWithResponsibilityFlags(true);
      objectLayer->setAutoClaimResponsibilityForObjectSymbols(true);
    }

    // Resolve symbols from shared libraries.
    for (auto &libPath : jitDyLibPaths) {
      auto mb = llvm::MemoryBuffer::getFile(libPath);
      if (!mb) {
        errs() << "Failed to create MemoryBuffer for: " << libPath
               << "\nError: " << mb.getError().message() << "\n";
        continue;
      }
      auto &jd = session.createBareJITDylib(std::string(libPath));
      auto loaded = DynamicLibrarySearchGenerator::Load(
          libPath.str().c_str(), dataLayout.getGlobalPrefix());
      if (!loaded) {
        errs() << "Could not load " << libPath << ":\n  " << loaded.takeError()
               << "\n";
        continue;
      }
      jd.addGenerator(std::move(*loaded));
      cantFail(objectLayer->add(jd, std::move(mb.get())));
    }

    return objectLayer;
  };

  // Callback to inspect the cache and recompile on demand. This follows Lang's
  // LLJITWithObjectCache example.
  auto compileFunctionCreator = [&](JITTargetMachineBuilder jtmb)
      -> Expected<std::unique_ptr<IRCompileLayer::IRCompiler>> {
    if (options.jitCodeGenOptLevel)
      jtmb.setCodeGenOptLevel(*options.jitCodeGenOptLevel);
    return std::make_unique<TMOwningSimpleCompiler>(std::move(tm),
                                                    engine->cache.get());
  };

  // Create the LLJIT by calling the LLJITBuilder with the two callbacks above.
  auto jit =
      cantFail(llvm::orc::LLJITBuilder()
                   .setCompileFunctionCreator(compileFunctionCreator)
                   .setObjectLinkingLayerCreator(objectLinkingLayerCreator)
                   .setDataLayout(dataLayout)
                   .create());

  // Add a ThreadSafeModule to the engine and return.
  ThreadSafeModule tsm(std::move(llvmModule), std::move(ctx));
  if (options.transformer)
    cantFail(tsm.withModuleDo(
        [&](llvm::Module &module) { return options.transformer(&module); }));
  cantFail(jit->addIRModule(std::move(tsm)));
  engine->jit = std::move(jit);

  // Resolve symbols that are statically linked in the current process.
  llvm::orc::JITDylib &mainJD = engine->jit->getMainJITDylib();
  mainJD.addGenerator(
      cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(
          dataLayout.getGlobalPrefix())));

  // Build a runtime symbol map from the exported symbols and register them.
  auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
    auto symbolMap = llvm::orc::SymbolMap();
    for (auto &exportSymbol : exportSymbols)
      symbolMap[interner(exportSymbol.getKey())] = {
          llvm::orc::ExecutorAddr::fromPtr(exportSymbol.getValue()),
          llvm::JITSymbolFlags::Exported};
    return symbolMap;
  };
  engine->registerSymbols(runtimeSymbolMap);

  // Execute the global constructors from the module being processed.
  // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
  // crash on AArch64; see related issue #71963.
  if (!engine->jit->getTargetTriple().isAArch64())
    cantFail(engine->jit->initialize(engine->jit->getMainJITDylib()));

  return std::move(engine);
}

Expected<void (*)(void **)>
ExecutionEngine::lookupPacked(StringRef name) const {
  auto result = lookup(makePackedFunctionName(name));
  if (!result)
    return result.takeError();
  return reinterpret_cast<void (*)(void **)>(result.get());
}

Expected<void *> ExecutionEngine::lookup(StringRef name) const {
  auto expectedSymbol = jit->lookup(name);

  // JIT lookup may return an Error referring to strings stored internally by
  // the JIT. If the Error outlives the ExecutionEngine, it would hold a
  // dangling reference, which is currently caught by an assertion inside the
  // JIT thanks to hand-rolled reference counting. Rewrap the error message
  // into a string before returning. Alternatively, ORC JIT should consider
  // copying the string into the error message.
  if (!expectedSymbol) {
    std::string errorMessage;
    llvm::raw_string_ostream os(errorMessage);
    llvm::handleAllErrors(expectedSymbol.takeError(),
                          [&os](llvm::ErrorInfoBase &ei) { ei.log(os); });
    return makeStringError(errorMessage);
  }

  if (void *fptr = expectedSymbol->toPtr<void *>())
    return fptr;
  return makeStringError("looked up function is null");
}

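// Usage sketch (hypothetical JIT'd function `int32_t foo(int32_t)`): pack a
// pointer to each argument, followed by a pointer to the result slot, into one
// array and pass it to the packed interface generated by
// packFunctionArguments.
//
//   int32_t arg = 42, result = 0;
//   llvm::SmallVector<void *> packedArgs = {&arg, &result};
//   if (llvm::Error err = engine->invokePacked("foo", packedArgs))
//     /* handle the error */;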
Error ExecutionEngine::invokePacked(StringRef name,
                                    MutableArrayRef<void *> args) {
  auto expectedFPtr = lookupPacked(name);
  if (!expectedFPtr)
    return expectedFPtr.takeError();
  auto fptr = *expectedFPtr;

  (*fptr)(args.data());

  return Error::success();
}