// xref: /llvm-project/llvm/unittests/ExecutionEngine/Orc/ReOptimizeLayerTest.cpp (revision 7fea5c034ca1e08403da39d64f20b08a6e7542bd)
1 #include "llvm/ExecutionEngine/Orc/ReOptimizeLayer.h"
2 #include "OrcTestCommon.h"
3 #include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h"
4 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
5 #include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h"
6 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
7 #include "llvm/ExecutionEngine/Orc/IRPartitionLayer.h"
8 #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
9 #include "llvm/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.h"
10 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
11 #include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
12 #include "llvm/ExecutionEngine/Orc/ObjectTransformLayer.h"
13 #include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h"
14 #include "llvm/IR/IRBuilder.h"
15 #include "llvm/Support/CodeGen.h"
16 #include "llvm/TargetParser/Host.h"
17 #include "llvm/Testing/Support/Error.h"
18 #include "gtest/gtest.h"
19 
20 using namespace llvm;
21 using namespace llvm::orc;
22 using namespace llvm::jitlink;
23 
24 class ReOptimizeLayerTest : public testing::Test {
25 public:
26   ~ReOptimizeLayerTest() {
27     if (ES)
28       if (auto Err = ES->endSession())
29         ES->reportError(std::move(Err));
30   }
31 
32 protected:
33   void SetUp() override {
34     auto JTMB = JITTargetMachineBuilder::detectHost();
35     // Bail out if we can not detect the host.
36     if (!JTMB) {
37       consumeError(JTMB.takeError());
38       GTEST_SKIP();
39     }
40 
41     auto EPC = SelfExecutorProcessControl::Create();
42     if (!EPC) {
43       consumeError(EPC.takeError());
44       GTEST_SKIP();
45     }
46 
47     auto DLOrErr = JTMB->getDefaultDataLayoutForTarget();
48     if (!DLOrErr) {
49       consumeError(DLOrErr.takeError());
50       GTEST_SKIP();
51     }
52     ES = std::make_unique<ExecutionSession>(std::move(*EPC));
53     JD = &ES->createBareJITDylib("main");
54     ObjLinkingLayer = std::make_unique<ObjectLinkingLayer>(
55         *ES, std::make_unique<InProcessMemoryManager>(16384));
56     DL = std::make_unique<DataLayout>(std::move(*DLOrErr));
57 
58     auto TM = JTMB->createTargetMachine();
59     if (!TM) {
60       consumeError(TM.takeError());
61       GTEST_SKIP();
62     }
63     auto CompileFunction =
64         std::make_unique<TMOwningSimpleCompiler>(std::move(*TM));
65     CompileLayer = std::make_unique<IRCompileLayer>(*ES, *ObjLinkingLayer,
66                                                     std::move(CompileFunction));
67   }
68 
69   Error addIRModule(ResourceTrackerSP RT, ThreadSafeModule TSM) {
70     assert(TSM && "Can not add null module");
71 
72     TSM.withModuleDo([&](Module &M) { M.setDataLayout(*DL); });
73 
74     return ROLayer->add(std::move(RT), std::move(TSM));
75   }
76 
77   JITDylib *JD{nullptr};
78   std::unique_ptr<ExecutionSession> ES;
79   std::unique_ptr<ObjectLinkingLayer> ObjLinkingLayer;
80   std::unique_ptr<IRCompileLayer> CompileLayer;
81   std::unique_ptr<ReOptimizeLayer> ROLayer;
82   std::unique_ptr<DataLayout> DL;
83 };
84 
85 static Function *createRetFunction(Module *M, StringRef Name,
86                                    uint32_t ReturnCode) {
87   Function *Result = Function::Create(
88       FunctionType::get(Type::getInt32Ty(M->getContext()), {}, false),
89       GlobalValue::ExternalLinkage, Name, M);
90 
91   BasicBlock *BB = BasicBlock::Create(M->getContext(), Name, Result);
92   IRBuilder<> Builder(M->getContext());
93   Builder.SetInsertPoint(BB);
94 
95   Value *RetValue = ConstantInt::get(M->getContext(), APInt(32, ReturnCode));
96   Builder.CreateRet(RetValue);
97   return Result;
98 }
99 
100 TEST_F(ReOptimizeLayerTest, BasicReOptimization) {
101   MangleAndInterner Mangle(*ES, *DL);
102 
103   auto &EPC = ES->getExecutorProcessControl();
104   EXPECT_THAT_ERROR(JD->define(absoluteSymbols(
105                         {{Mangle("__orc_rt_jit_dispatch"),
106                           {EPC.getJITDispatchInfo().JITDispatchFunction,
107                            JITSymbolFlags::Exported}},
108                          {Mangle("__orc_rt_jit_dispatch_ctx"),
109                           {EPC.getJITDispatchInfo().JITDispatchContext,
110                            JITSymbolFlags::Exported}},
111                          {Mangle("__orc_rt_reoptimize_tag"),
112                           {ExecutorAddr(), JITSymbolFlags::Exported}}})),
113                     Succeeded());
114 
115   auto RM = JITLinkRedirectableSymbolManager::Create(*ObjLinkingLayer, *JD);
116   EXPECT_THAT_ERROR(RM.takeError(), Succeeded());
117 
118   ROLayer = std::make_unique<ReOptimizeLayer>(*ES, *DL, *CompileLayer, **RM);
119   ROLayer->setReoptimizeFunc(
120       [&](ReOptimizeLayer &Parent,
121           ReOptimizeLayer::ReOptMaterializationUnitID MUID, unsigned CurVerison,
122           ResourceTrackerSP OldRT, ThreadSafeModule &TSM) {
123         TSM.withModuleDo([&](Module &M) {
124           for (auto &F : M) {
125             if (F.isDeclaration())
126               continue;
127             for (auto &B : F) {
128               for (auto &I : B) {
129                 if (ReturnInst *Ret = dyn_cast<ReturnInst>(&I)) {
130                   Value *RetValue =
131                       ConstantInt::get(M.getContext(), APInt(32, 53));
132                   Ret->setOperand(0, RetValue);
133                 }
134               }
135             }
136           }
137         });
138         return Error::success();
139       });
140   EXPECT_THAT_ERROR(ROLayer->reigsterRuntimeFunctions(*JD), Succeeded());
141 
142   ThreadSafeContext Ctx(std::make_unique<LLVMContext>());
143   auto M = std::make_unique<Module>("<main>", *Ctx.getContext());
144   M->setTargetTriple(sys::getProcessTriple());
145 
146   (void)createRetFunction(M.get(), "main", 42);
147 
148   EXPECT_THAT_ERROR(addIRModule(JD->getDefaultResourceTracker(),
149                                 ThreadSafeModule(std::move(M), std::move(Ctx))),
150                     Succeeded());
151 
152   auto Result = cantFail(ES->lookup({JD}, Mangle("main")));
153   auto FuncPtr = Result.getAddress().toPtr<int (*)()>();
154   for (size_t I = 0; I <= ReOptimizeLayer::CallCountThreshold; I++)
155     EXPECT_EQ(FuncPtr(), 42);
156   EXPECT_EQ(FuncPtr(), 53);
157 }
158