xref: /llvm-project/llvm/unittests/ExecutionEngine/Orc/ReOptimizeLayerTest.cpp (revision e5f7e73d90dd8ea7b1fa0e4e77ae11eabf398da9)
1 #include "llvm/ExecutionEngine/Orc/ReOptimizeLayer.h"
2 #include "OrcTestCommon.h"
3 #include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h"
4 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
5 #include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h"
6 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
7 #include "llvm/ExecutionEngine/Orc/IRPartitionLayer.h"
8 #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
9 #include "llvm/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.h"
10 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
11 #include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
12 #include "llvm/ExecutionEngine/Orc/ObjectTransformLayer.h"
13 #include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h"
14 #include "llvm/IR/IRBuilder.h"
15 #include "llvm/Support/CodeGen.h"
16 #include "llvm/TargetParser/Host.h"
17 #include "llvm/Testing/Support/Error.h"
18 #include "gtest/gtest.h"
19 
20 using namespace llvm;
21 using namespace llvm::orc;
22 using namespace llvm::jitlink;
23 
24 class ReOptimizeLayerTest : public testing::Test {
25 public:
26   ~ReOptimizeLayerTest() {
27     if (ES)
28       if (auto Err = ES->endSession())
29         ES->reportError(std::move(Err));
30   }
31 
32 protected:
33   void SetUp() override {
34     auto JTMB = JITTargetMachineBuilder::detectHost();
35     // Bail out if we can not detect the host.
36     if (!JTMB) {
37       consumeError(JTMB.takeError());
38       GTEST_SKIP();
39     }
40 
41     // COFF-ARM64 is not supported yet
42     auto Triple = JTMB->getTargetTriple();
43     if (Triple.isOSBinFormatCOFF() && Triple.isAArch64())
44       GTEST_SKIP();
45 
46     auto EPC = SelfExecutorProcessControl::Create();
47     if (!EPC) {
48       consumeError(EPC.takeError());
49       GTEST_SKIP();
50     }
51 
52     auto DLOrErr = JTMB->getDefaultDataLayoutForTarget();
53     if (!DLOrErr) {
54       consumeError(DLOrErr.takeError());
55       GTEST_SKIP();
56     }
57     ES = std::make_unique<ExecutionSession>(std::move(*EPC));
58     JD = &ES->createBareJITDylib("main");
59     ObjLinkingLayer = std::make_unique<ObjectLinkingLayer>(
60         *ES, std::make_unique<InProcessMemoryManager>(16384));
61     DL = std::make_unique<DataLayout>(std::move(*DLOrErr));
62 
63     auto TM = JTMB->createTargetMachine();
64     if (!TM) {
65       consumeError(TM.takeError());
66       GTEST_SKIP();
67     }
68     auto CompileFunction =
69         std::make_unique<TMOwningSimpleCompiler>(std::move(*TM));
70     CompileLayer = std::make_unique<IRCompileLayer>(*ES, *ObjLinkingLayer,
71                                                     std::move(CompileFunction));
72   }
73 
74   Error addIRModule(ResourceTrackerSP RT, ThreadSafeModule TSM) {
75     assert(TSM && "Can not add null module");
76 
77     TSM.withModuleDo([&](Module &M) { M.setDataLayout(*DL); });
78 
79     return ROLayer->add(std::move(RT), std::move(TSM));
80   }
81 
82   JITDylib *JD{nullptr};
83   std::unique_ptr<ExecutionSession> ES;
84   std::unique_ptr<ObjectLinkingLayer> ObjLinkingLayer;
85   std::unique_ptr<IRCompileLayer> CompileLayer;
86   std::unique_ptr<ReOptimizeLayer> ROLayer;
87   std::unique_ptr<DataLayout> DL;
88 };
89 
90 static Function *createRetFunction(Module *M, StringRef Name,
91                                    uint32_t ReturnCode) {
92   Function *Result = Function::Create(
93       FunctionType::get(Type::getInt32Ty(M->getContext()), {}, false),
94       GlobalValue::ExternalLinkage, Name, M);
95 
96   BasicBlock *BB = BasicBlock::Create(M->getContext(), Name, Result);
97   IRBuilder<> Builder(M->getContext());
98   Builder.SetInsertPoint(BB);
99 
100   Value *RetValue = ConstantInt::get(M->getContext(), APInt(32, ReturnCode));
101   Builder.CreateRet(RetValue);
102   return Result;
103 }
104 
105 TEST_F(ReOptimizeLayerTest, BasicReOptimization) {
106   MangleAndInterner Mangle(*ES, *DL);
107 
108   auto &EPC = ES->getExecutorProcessControl();
109   EXPECT_THAT_ERROR(JD->define(absoluteSymbols(
110                         {{Mangle("__orc_rt_jit_dispatch"),
111                           {EPC.getJITDispatchInfo().JITDispatchFunction,
112                            JITSymbolFlags::Exported}},
113                          {Mangle("__orc_rt_jit_dispatch_ctx"),
114                           {EPC.getJITDispatchInfo().JITDispatchContext,
115                            JITSymbolFlags::Exported}},
116                          {Mangle("__orc_rt_reoptimize_tag"),
117                           {ExecutorAddr(), JITSymbolFlags::Exported}}})),
118                     Succeeded());
119 
120   auto RM = JITLinkRedirectableSymbolManager::Create(*ObjLinkingLayer, *JD);
121   EXPECT_THAT_ERROR(RM.takeError(), Succeeded());
122 
123   ROLayer = std::make_unique<ReOptimizeLayer>(*ES, *DL, *CompileLayer, **RM);
124   ROLayer->setReoptimizeFunc(
125       [&](ReOptimizeLayer &Parent,
126           ReOptimizeLayer::ReOptMaterializationUnitID MUID, unsigned CurVerison,
127           ResourceTrackerSP OldRT, ThreadSafeModule &TSM) {
128         TSM.withModuleDo([&](Module &M) {
129           for (auto &F : M) {
130             if (F.isDeclaration())
131               continue;
132             for (auto &B : F) {
133               for (auto &I : B) {
134                 if (ReturnInst *Ret = dyn_cast<ReturnInst>(&I)) {
135                   Value *RetValue =
136                       ConstantInt::get(M.getContext(), APInt(32, 53));
137                   Ret->setOperand(0, RetValue);
138                 }
139               }
140             }
141           }
142         });
143         return Error::success();
144       });
145   EXPECT_THAT_ERROR(ROLayer->reigsterRuntimeFunctions(*JD), Succeeded());
146 
147   ThreadSafeContext Ctx(std::make_unique<LLVMContext>());
148   auto M = std::make_unique<Module>("<main>", *Ctx.getContext());
149   M->setTargetTriple(sys::getProcessTriple());
150 
151   (void)createRetFunction(M.get(), "main", 42);
152 
153   EXPECT_THAT_ERROR(addIRModule(JD->getDefaultResourceTracker(),
154                                 ThreadSafeModule(std::move(M), std::move(Ctx))),
155                     Succeeded());
156 
157   auto Result = cantFail(ES->lookup({JD}, Mangle("main")));
158   auto FuncPtr = Result.getAddress().toPtr<int (*)()>();
159   for (size_t I = 0; I <= ReOptimizeLayer::CallCountThreshold; I++)
160     EXPECT_EQ(FuncPtr(), 42);
161   EXPECT_EQ(FuncPtr(), 53);
162 }
163