xref: /llvm-project/llvm/unittests/ExecutionEngine/MCJIT/MCJITTestBase.h (revision 0a1aa6cda2758b0926a95f87d39ffefb1cb90200)
1 //===- MCJITTestBase.h - Common base class for MCJIT Unit tests -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This class implements common functionality required by the MCJIT unit tests,
10 // as well as logic to skip tests on unsupported architectures and operating
11 // systems.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef LLVM_UNITTESTS_EXECUTIONENGINE_MCJIT_MCJITTESTBASE_H
16 #define LLVM_UNITTESTS_EXECUTIONENGINE_MCJIT_MCJITTESTBASE_H
17 
18 #include "MCJITTestAPICommon.h"
19 #include "llvm/Config/config.h"
20 #include "llvm/ExecutionEngine/ExecutionEngine.h"
21 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
22 #include "llvm/IR/Function.h"
23 #include "llvm/IR/IRBuilder.h"
24 #include "llvm/IR/LLVMContext.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/IR/Type.h"
27 #include "llvm/Support/CodeGen.h"
28 
29 namespace llvm {
30 
31 /// Helper class that can build very simple Modules
32 class TrivialModuleBuilder {
33 protected:
34   LLVMContext Context;
35   IRBuilder<> Builder;
36   std::string BuilderTriple;
37 
TrivialModuleBuilder(const std::string & Triple)38   TrivialModuleBuilder(const std::string &Triple)
39     : Builder(Context), BuilderTriple(Triple) {}
40 
41   Module *createEmptyModule(StringRef Name = StringRef()) {
42     Module * M = new Module(Name, Context);
43     M->setTargetTriple(Triple::normalize(BuilderTriple));
44     return M;
45   }
46 
startFunction(Module * M,FunctionType * FT,StringRef Name)47   Function *startFunction(Module *M, FunctionType *FT, StringRef Name) {
48     Function *Result =
49         Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);
50 
51     BasicBlock *BB = BasicBlock::Create(Context, Name, Result);
52     Builder.SetInsertPoint(BB);
53 
54     return Result;
55   }
56 
endFunctionWithRet(Function * Func,Value * RetValue)57   void endFunctionWithRet(Function *Func, Value *RetValue) {
58     Builder.CreateRet(RetValue);
59   }
60 
61   // Inserts a simple function that invokes Callee and takes the same arguments:
62   //    int Caller(...) { return Callee(...); }
insertSimpleCallFunction(Module * M,Function * Callee)63   Function *insertSimpleCallFunction(Module *M, Function *Callee) {
64     Function *Result = startFunction(M, Callee->getFunctionType(), "caller");
65 
66     SmallVector<Value*, 1> CallArgs;
67 
68     for (Argument &A : Result->args())
69       CallArgs.push_back(&A);
70 
71     Value *ReturnCode = Builder.CreateCall(Callee, CallArgs);
72     Builder.CreateRet(ReturnCode);
73     return Result;
74   }
75 
76   // Inserts a function named 'main' that returns a uint32_t:
77   //    int32_t main() { return X; }
78   // where X is given by returnCode
insertMainFunction(Module * M,uint32_t returnCode)79   Function *insertMainFunction(Module *M, uint32_t returnCode) {
80     Function *Result = startFunction(
81         M, FunctionType::get(Type::getInt32Ty(Context), {}, false), "main");
82 
83     Value *ReturnVal = ConstantInt::get(Context, APInt(32, returnCode));
84     endFunctionWithRet(Result, ReturnVal);
85 
86     return Result;
87   }
88 
89   // Inserts a function
90   //    int32_t add(int32_t a, int32_t b) { return a + b; }
91   // in the current module and returns a pointer to it.
92   Function *insertAddFunction(Module *M, StringRef Name = "add") {
93     Function *Result = startFunction(
94         M,
95         FunctionType::get(
96             Type::getInt32Ty(Context),
97             {Type::getInt32Ty(Context), Type::getInt32Ty(Context)}, false),
98         Name);
99 
100     Function::arg_iterator args = Result->arg_begin();
101     Value *Arg1 = &*args;
102     Value *Arg2 = &*++args;
103     Value *AddResult = Builder.CreateAdd(Arg1, Arg2);
104 
105     endFunctionWithRet(Result, AddResult);
106 
107     return Result;
108   }
109 
110   // Inserts a declaration to a function defined elsewhere
insertExternalReferenceToFunction(Module * M,FunctionType * FTy,StringRef Name)111   Function *insertExternalReferenceToFunction(Module *M, FunctionType *FTy,
112                                               StringRef Name) {
113     Function *Result =
114         Function::Create(FTy, GlobalValue::ExternalLinkage, Name, M);
115     return Result;
116   }
117 
118   // Inserts an declaration to a function defined elsewhere
insertExternalReferenceToFunction(Module * M,Function * Func)119   Function *insertExternalReferenceToFunction(Module *M, Function *Func) {
120     Function *Result = Function::Create(Func->getFunctionType(),
121                                         GlobalValue::ExternalLinkage,
122                                         Func->getName(), M);
123     return Result;
124   }
125 
126   // Inserts a global variable of type int32
127   // FIXME: make this a template function to support any type
insertGlobalInt32(Module * M,StringRef name,int32_t InitialValue)128   GlobalVariable *insertGlobalInt32(Module *M,
129                                     StringRef name,
130                                     int32_t InitialValue) {
131     Type *GlobalTy = Type::getInt32Ty(Context);
132     Constant *IV = ConstantInt::get(Context, APInt(32, InitialValue));
133     GlobalVariable *Global = new GlobalVariable(*M,
134                                                 GlobalTy,
135                                                 false,
136                                                 GlobalValue::ExternalLinkage,
137                                                 IV,
138                                                 name);
139     return Global;
140   }
141 
142   // Inserts a function
143   //   int32_t recursive_add(int32_t num) {
144   //     if (num == 0) {
145   //       return num;
146   //     } else {
147   //       int32_t recursive_param = num - 1;
148   //       return num + Helper(recursive_param);
149   //     }
150   //   }
151   // NOTE: if Helper is left as the default parameter, Helper == recursive_add.
152   Function *insertAccumulateFunction(Module *M,
153                                      Function *Helper = nullptr,
154                                      StringRef Name = "accumulate") {
155     Function *Result =
156         startFunction(M,
157                       FunctionType::get(Type::getInt32Ty(Context),
158                                         {Type::getInt32Ty(Context)}, false),
159                       Name);
160     if (!Helper)
161       Helper = Result;
162 
163     BasicBlock *BaseCase = BasicBlock::Create(Context, "", Result);
164     BasicBlock *RecursiveCase = BasicBlock::Create(Context, "", Result);
165 
166     // if (num == 0)
167     Value *Param = &*Result->arg_begin();
168     Value *Zero = ConstantInt::get(Context, APInt(32, 0));
169     Builder.CreateCondBr(Builder.CreateICmpEQ(Param, Zero),
170                          BaseCase, RecursiveCase);
171 
172     //   return num;
173     Builder.SetInsertPoint(BaseCase);
174     Builder.CreateRet(Param);
175 
176     //   int32_t recursive_param = num - 1;
177     //   return Helper(recursive_param);
178     Builder.SetInsertPoint(RecursiveCase);
179     Value *One = ConstantInt::get(Context, APInt(32, 1));
180     Value *RecursiveParam = Builder.CreateSub(Param, One);
181     Value *RecursiveReturn = Builder.CreateCall(Helper, RecursiveParam);
182     Value *Accumulator = Builder.CreateAdd(Param, RecursiveReturn);
183     Builder.CreateRet(Accumulator);
184 
185     return Result;
186   }
187 
188   // Populates Modules A and B:
189   // Module A { Extern FB1, Function FA which calls FB1 },
190   // Module B { Extern FA, Function FB1, Function FB2 which calls FA },
createCrossModuleRecursiveCase(std::unique_ptr<Module> & A,Function * & FA,std::unique_ptr<Module> & B,Function * & FB1,Function * & FB2)191   void createCrossModuleRecursiveCase(std::unique_ptr<Module> &A, Function *&FA,
192                                       std::unique_ptr<Module> &B,
193                                       Function *&FB1, Function *&FB2) {
194     // Define FB1 in B.
195     B.reset(createEmptyModule("B"));
196     FB1 = insertAccumulateFunction(B.get(), nullptr, "FB1");
197 
198     // Declare FB1 in A (as an external).
199     A.reset(createEmptyModule("A"));
200     Function *FB1Extern = insertExternalReferenceToFunction(A.get(), FB1);
201 
202     // Define FA in A (with a call to FB1).
203     FA = insertAccumulateFunction(A.get(), FB1Extern, "FA");
204 
205     // Declare FA in B (as an external)
206     Function *FAExtern = insertExternalReferenceToFunction(B.get(), FA);
207 
208     // Define FB2 in B (with a call to FA)
209     FB2 = insertAccumulateFunction(B.get(), FAExtern, "FB2");
210   }
211 
212   // Module A { Function FA },
213   // Module B { Extern FA, Function FB which calls FA },
214   // Module C { Extern FB, Function FC which calls FB },
215   void
createThreeModuleChainedCallsCase(std::unique_ptr<Module> & A,Function * & FA,std::unique_ptr<Module> & B,Function * & FB,std::unique_ptr<Module> & C,Function * & FC)216   createThreeModuleChainedCallsCase(std::unique_ptr<Module> &A, Function *&FA,
217                                     std::unique_ptr<Module> &B, Function *&FB,
218                                     std::unique_ptr<Module> &C, Function *&FC) {
219     A.reset(createEmptyModule("A"));
220     FA = insertAddFunction(A.get());
221 
222     B.reset(createEmptyModule("B"));
223     Function *FAExtern_in_B = insertExternalReferenceToFunction(B.get(), FA);
224     FB = insertSimpleCallFunction(B.get(), FAExtern_in_B);
225 
226     C.reset(createEmptyModule("C"));
227     Function *FBExtern_in_C = insertExternalReferenceToFunction(C.get(), FB);
228     FC = insertSimpleCallFunction(C.get(), FBExtern_in_C);
229   }
230 
231   // Module A { Function FA },
232   // Populates Modules A and B:
233   // Module B { Function FB }
createTwoModuleCase(std::unique_ptr<Module> & A,Function * & FA,std::unique_ptr<Module> & B,Function * & FB)234   void createTwoModuleCase(std::unique_ptr<Module> &A, Function *&FA,
235                            std::unique_ptr<Module> &B, Function *&FB) {
236     A.reset(createEmptyModule("A"));
237     FA = insertAddFunction(A.get());
238 
239     B.reset(createEmptyModule("B"));
240     FB = insertAddFunction(B.get());
241   }
242 
243   // Module A { Function FA },
244   // Module B { Extern FA, Function FB which calls FA }
createTwoModuleExternCase(std::unique_ptr<Module> & A,Function * & FA,std::unique_ptr<Module> & B,Function * & FB)245   void createTwoModuleExternCase(std::unique_ptr<Module> &A, Function *&FA,
246                                  std::unique_ptr<Module> &B, Function *&FB) {
247     A.reset(createEmptyModule("A"));
248     FA = insertAddFunction(A.get());
249 
250     B.reset(createEmptyModule("B"));
251     Function *FAExtern_in_B = insertExternalReferenceToFunction(B.get(), FA);
252     FB = insertSimpleCallFunction(B.get(), FAExtern_in_B);
253   }
254 
255   // Module A { Function FA },
256   // Module B { Extern FA, Function FB which calls FA },
257   // Module C { Extern FB, Function FC which calls FA },
createThreeModuleCase(std::unique_ptr<Module> & A,Function * & FA,std::unique_ptr<Module> & B,Function * & FB,std::unique_ptr<Module> & C,Function * & FC)258   void createThreeModuleCase(std::unique_ptr<Module> &A, Function *&FA,
259                              std::unique_ptr<Module> &B, Function *&FB,
260                              std::unique_ptr<Module> &C, Function *&FC) {
261     A.reset(createEmptyModule("A"));
262     FA = insertAddFunction(A.get());
263 
264     B.reset(createEmptyModule("B"));
265     Function *FAExtern_in_B = insertExternalReferenceToFunction(B.get(), FA);
266     FB = insertSimpleCallFunction(B.get(), FAExtern_in_B);
267 
268     C.reset(createEmptyModule("C"));
269     Function *FAExtern_in_C = insertExternalReferenceToFunction(C.get(), FA);
270     FC = insertSimpleCallFunction(C.get(), FAExtern_in_C);
271   }
272 };
273 
274 class MCJITTestBase : public MCJITTestAPICommon, public TrivialModuleBuilder {
275 protected:
MCJITTestBase()276   MCJITTestBase()
277       : TrivialModuleBuilder(HostTriple), OptLevel(CodeGenOptLevel::None),
278         CodeModel(CodeModel::Small), MArch(""), MM(new SectionMemoryManager) {
279     // The architectures below are known to be compatible with MCJIT as they
280     // are copied from test/ExecutionEngine/MCJIT/lit.local.cfg and should be
281     // kept in sync.
282     SupportedArchs.push_back(Triple::aarch64);
283     SupportedArchs.push_back(Triple::arm);
284     SupportedArchs.push_back(Triple::mips);
285     SupportedArchs.push_back(Triple::mipsel);
286     SupportedArchs.push_back(Triple::mips64);
287     SupportedArchs.push_back(Triple::mips64el);
288     SupportedArchs.push_back(Triple::x86);
289     SupportedArchs.push_back(Triple::x86_64);
290 
291     // Some architectures have sub-architectures in which tests will fail, like
292     // ARM. These two vectors will define if they do have sub-archs (to avoid
293     // extra work for those who don't), and if so, if they are listed to work
294     HasSubArchs.push_back(Triple::arm);
295     SupportedSubArchs.push_back("armv6");
296     SupportedSubArchs.push_back("armv7");
297 
298     UnsupportedEnvironments.push_back(Triple::Cygnus);
299   }
300 
createJIT(std::unique_ptr<Module> M)301   void createJIT(std::unique_ptr<Module> M) {
302 
303     // Due to the EngineBuilder constructor, it is required to have a Module
304     // in order to construct an ExecutionEngine (i.e. MCJIT)
305     assert(M != 0 && "a non-null Module must be provided to create MCJIT");
306 
307     EngineBuilder EB(std::move(M));
308     std::string Error;
309     TheJIT.reset(EB.setEngineKind(EngineKind::JIT)
310                      .setMCJITMemoryManager(std::move(MM))
311                      .setErrorStr(&Error)
312                      .setOptLevel(CodeGenOptLevel::None)
313                      .setMArch(MArch)
314                      .setMCPU(sys::getHostCPUName())
315                      //.setMAttrs(MAttrs)
316                      .create());
317     // At this point, we cannot modify the module any more.
318     assert(TheJIT.get() != NULL && "error creating MCJIT with EngineBuilder");
319   }
320 
321   CodeGenOptLevel OptLevel;
322   CodeModel::Model CodeModel;
323   StringRef MArch;
324   SmallVector<std::string, 1> MAttrs;
325   std::unique_ptr<ExecutionEngine> TheJIT;
326   std::unique_ptr<RTDyldMemoryManager> MM;
327 
328   std::unique_ptr<Module> M;
329 };
330 
331 } // namespace llvm
332 
333 #endif // LLVM_UNITTESTS_EXECUTIONENGINE_MCJIT_MCJITTESTBASE_H
334