1 //===--- LLJITWithThinLTOSummaries.cpp - Module summaries as LLJIT input --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // In this example we will use a module summary index file produced for ThinLTO
10 // to (A) find the module that defines the main entry point and (B) find all
11 // extra modules that we need. We will do this in five steps:
12 //
13 // (1) Read the index file and parse the module summary index.
14 // (2) Find the path of the module that defines "main".
15 // (3) Parse the main module and create a matching LLJIT.
16 // (4) Add all modules to the LLJIT that are covered by the index.
17 // (5) Look up and run the JIT'd function.
18 //
// The index file name must be passed in as a command line argument. See this
// test for instructions on creating the index file:
21 //
22 //       llvm/test/Examples/OrcV2Examples/lljit-with-thinlto-summaries.test
23 //
24 // If you use "build" as the build directory, you can run the test from the root
25 // of the monorepo like this:
26 //
27 // > build/bin/llvm-lit -a \
28 //       llvm/test/Examples/OrcV2Examples/lljit-with-thinlto-summaries.test
29 //
30 //===----------------------------------------------------------------------===//
31 
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/StringRef.h"
34 #include "llvm/Bitcode/BitcodeReader.h"
35 #include "llvm/ExecutionEngine/Orc/Core.h"
36 #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
37 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
38 #include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
39 #include "llvm/ExecutionEngine/Orc/TargetProcess/TargetExecutionUtils.h"
40 #include "llvm/IR/GlobalValue.h"
41 #include "llvm/IR/LLVMContext.h"
42 #include "llvm/IR/ModuleSummaryIndex.h"
43 #include "llvm/Support/CommandLine.h"
44 #include "llvm/Support/Error.h"
45 #include "llvm/Support/InitLLVM.h"
46 #include "llvm/Support/MemoryBuffer.h"
47 #include "llvm/Support/TargetSelect.h"
48 #include "llvm/Support/raw_ostream.h"
49 
50 #include <string>
51 #include <system_error>
52 #include <vector>
53 
54 using namespace llvm;
55 using namespace llvm::orc;
56 
// Path of the module summary index file. It is taken from the tool's
// positional command line argument; when no argument is given the option
// holds its initial value "-".
cl::opt<std::string> IndexFile{cl::desc("<module summary index>"),
                               cl::Positional, cl::init("-")};
60 
61 // Describe a fail state that is caused by the given ModuleSummaryIndex
62 // providing multiple definitions of the given global value name. It will dump
63 // name and GUID for the global value and list the paths of the modules covered
64 // by the index.
65 class DuplicateDefinitionInSummary
66     : public ErrorInfo<DuplicateDefinitionInSummary> {
67 public:
68   static char ID;
69 
DuplicateDefinitionInSummary(std::string GlobalValueName,ValueInfo VI)70   DuplicateDefinitionInSummary(std::string GlobalValueName, ValueInfo VI)
71       : GlobalValueName(std::move(GlobalValueName)) {
72     ModulePaths.reserve(VI.getSummaryList().size());
73     for (const auto &S : VI.getSummaryList())
74       ModulePaths.push_back(S->modulePath().str());
75     llvm::sort(ModulePaths);
76   }
77 
log(raw_ostream & OS) const78   void log(raw_ostream &OS) const override {
79     OS << "Duplicate symbol for global value '" << GlobalValueName
80        << "' (GUID: " << GlobalValue::getGUID(GlobalValueName) << ") in:\n";
81     for (const std::string &Path : ModulePaths) {
82       OS << "    " << Path << "\n";
83     }
84   }
85 
convertToErrorCode() const86   std::error_code convertToErrorCode() const override {
87     return inconvertibleErrorCode();
88   }
89 
90 private:
91   std::string GlobalValueName;
92   std::vector<std::string> ModulePaths;
93 };
94 
95 // Describe a fail state where the given global value name was not found in the
96 // given ModuleSummaryIndex. It will dump name and GUID for the global value and
97 // list the paths of the modules covered by the index.
98 class DefinitionNotFoundInSummary
99     : public ErrorInfo<DefinitionNotFoundInSummary> {
100 public:
101   static char ID;
102 
DefinitionNotFoundInSummary(std::string GlobalValueName,ModuleSummaryIndex & Index)103   DefinitionNotFoundInSummary(std::string GlobalValueName,
104                               ModuleSummaryIndex &Index)
105       : GlobalValueName(std::move(GlobalValueName)) {
106     ModulePaths.reserve(Index.modulePaths().size());
107     for (const auto &Entry : Index.modulePaths())
108       ModulePaths.push_back(Entry.first().str());
109     llvm::sort(ModulePaths);
110   }
111 
log(raw_ostream & OS) const112   void log(raw_ostream &OS) const override {
113     OS << "No symbol for global value '" << GlobalValueName
114        << "' (GUID: " << GlobalValue::getGUID(GlobalValueName) << ") in:\n";
115     for (const std::string &Path : ModulePaths) {
116       OS << "    " << Path << "\n";
117     }
118   }
119 
convertToErrorCode() const120   std::error_code convertToErrorCode() const override {
121     return llvm::inconvertibleErrorCode();
122   }
123 
124 private:
125   std::string GlobalValueName;
126   std::vector<std::string> ModulePaths;
127 };
128 
// Out-of-line definitions of the static ID members declared by the two error
// classes above.
// NOTE(review): ErrorInfo presumably identifies each error class by the
// address of its ID rather than by the stored value -- confirm against
// llvm/Support/Error.h.
char DuplicateDefinitionInSummary::ID = 0;
char DefinitionNotFoundInSummary::ID = 0;
131 
132 // Lookup the a function in the ModuleSummaryIndex and return the path of the
133 // module that defines it. Paths in the ModuleSummaryIndex are relative to the
134 // build directory of the covered modules.
getMainModulePath(StringRef FunctionName,ModuleSummaryIndex & Index)135 Expected<StringRef> getMainModulePath(StringRef FunctionName,
136                                       ModuleSummaryIndex &Index) {
137   // Summaries use unmangled names.
138   GlobalValue::GUID G = GlobalValue::getGUID(FunctionName);
139   ValueInfo VI = Index.getValueInfo(G);
140 
141   // We need a unique definition, otherwise don't try further.
142   if (!VI || VI.getSummaryList().empty())
143     return make_error<DefinitionNotFoundInSummary>(FunctionName.str(), Index);
144   if (VI.getSummaryList().size() > 1)
145     return make_error<DuplicateDefinitionInSummary>(FunctionName.str(), VI);
146 
147   GlobalValueSummary *S = VI.getSummaryList().front()->getBaseObject();
148   if (!isa<FunctionSummary>(S))
149     return createStringError(inconvertibleErrorCode(),
150                              "Entry point is not a function: " + FunctionName);
151 
152   // Return a reference. ModuleSummaryIndex owns the module paths.
153   return S->modulePath();
154 }
155 
156 // Parse the bitcode module from the given path into a ThreadSafeModule.
loadModule(StringRef Path,orc::ThreadSafeContext TSCtx)157 Expected<ThreadSafeModule> loadModule(StringRef Path,
158                                       orc::ThreadSafeContext TSCtx) {
159   outs() << "About to load module: " << Path << "\n";
160 
161   Expected<std::unique_ptr<MemoryBuffer>> BitcodeBuffer =
162       errorOrToExpected(MemoryBuffer::getFile(Path));
163   if (!BitcodeBuffer)
164     return BitcodeBuffer.takeError();
165 
166   MemoryBufferRef BitcodeBufferRef = (**BitcodeBuffer).getMemBufferRef();
167   Expected<std::unique_ptr<Module>> M =
168       parseBitcodeFile(BitcodeBufferRef, *TSCtx.getContext());
169   if (!M)
170     return M.takeError();
171 
172   return ThreadSafeModule(std::move(*M), std::move(TSCtx));
173 }
174 
main(int Argc,char * Argv[])175 int main(int Argc, char *Argv[]) {
176   InitLLVM X(Argc, Argv);
177 
178   InitializeNativeTarget();
179   InitializeNativeTargetAsmPrinter();
180 
181   cl::ParseCommandLineOptions(Argc, Argv, "LLJITWithThinLTOSummaries");
182 
183   ExitOnError ExitOnErr;
184   ExitOnErr.setBanner(std::string(Argv[0]) + ": ");
185 
186   // (1) Read the index file and parse the module summary index.
187   std::unique_ptr<MemoryBuffer> SummaryBuffer =
188       ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(IndexFile)));
189 
190   std::unique_ptr<ModuleSummaryIndex> SummaryIndex =
191       ExitOnErr(getModuleSummaryIndex(SummaryBuffer->getMemBufferRef()));
192 
193   // (2) Find the path of the module that defines "main".
194   std::string MainFunctionName = "main";
195   StringRef MainModulePath =
196       ExitOnErr(getMainModulePath(MainFunctionName, *SummaryIndex));
197 
198   // (3) Parse the main module and create a matching LLJIT.
199   ThreadSafeContext TSCtx(std::make_unique<LLVMContext>());
200   ThreadSafeModule MainModule = ExitOnErr(loadModule(MainModulePath, TSCtx));
201 
202   auto Builder = LLJITBuilder();
203 
204   MainModule.withModuleDo([&](Module &M) {
205     if (M.getTargetTriple().empty()) {
206       Builder.setJITTargetMachineBuilder(
207           ExitOnErr(JITTargetMachineBuilder::detectHost()));
208     } else {
209       Builder.setJITTargetMachineBuilder(
210           JITTargetMachineBuilder(Triple(M.getTargetTriple())));
211     }
212     if (!M.getDataLayout().getStringRepresentation().empty())
213       Builder.setDataLayout(M.getDataLayout());
214   });
215 
216   auto J = ExitOnErr(Builder.create());
217 
218   // (4) Add all modules to the LLJIT that are covered by the index.
219   JITDylib &JD = J->getMainJITDylib();
220 
221   for (const auto &Entry : SummaryIndex->modulePaths()) {
222     StringRef Path = Entry.first();
223     ThreadSafeModule M = (Path == MainModulePath)
224                              ? std::move(MainModule)
225                              : ExitOnErr(loadModule(Path, TSCtx));
226     ExitOnErr(J->addIRModule(JD, std::move(M)));
227   }
228 
229   // (5) Look up and run the JIT'd function.
230   auto MainAddr = ExitOnErr(J->lookup(MainFunctionName));
231 
232   using MainFnPtr = int (*)(int, char *[]);
233   auto *MainFunction = MainAddr.toPtr<MainFnPtr>();
234 
235   int Result = runAsMain(MainFunction, {}, MainModulePath);
236   outs() << "'" << MainFunctionName << "' finished with exit code: " << Result
237          << "\n";
238 
239   return 0;
240 }
241