xref: /llvm-project/llvm/lib/ExecutionEngine/Orc/ReOptimizeLayer.cpp (revision 188ede28e046c911cb8e604fd1adc2b5cc1f264b)
1*188ede28SSunho Kim #include "llvm/ExecutionEngine/Orc/ReOptimizeLayer.h"
2*188ede28SSunho Kim #include "llvm/ExecutionEngine/Orc/Mangling.h"
3*188ede28SSunho Kim 
4*188ede28SSunho Kim using namespace llvm;
5*188ede28SSunho Kim using namespace orc;
6*188ede28SSunho Kim 
7*188ede28SSunho Kim bool ReOptimizeLayer::ReOptMaterializationUnitState::tryStartReoptimize() {
8*188ede28SSunho Kim   std::unique_lock<std::mutex> Lock(Mutex);
9*188ede28SSunho Kim   if (Reoptimizing)
10*188ede28SSunho Kim     return false;
11*188ede28SSunho Kim 
12*188ede28SSunho Kim   Reoptimizing = true;
13*188ede28SSunho Kim   return true;
14*188ede28SSunho Kim }
15*188ede28SSunho Kim 
16*188ede28SSunho Kim void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeSucceeded() {
17*188ede28SSunho Kim   std::unique_lock<std::mutex> Lock(Mutex);
18*188ede28SSunho Kim   assert(Reoptimizing && "Tried to mark unstarted reoptimization as done");
19*188ede28SSunho Kim   Reoptimizing = false;
20*188ede28SSunho Kim   CurVersion++;
21*188ede28SSunho Kim }
22*188ede28SSunho Kim 
23*188ede28SSunho Kim void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeFailed() {
24*188ede28SSunho Kim   std::unique_lock<std::mutex> Lock(Mutex);
25*188ede28SSunho Kim   assert(Reoptimizing && "Tried to mark unstarted reoptimization as done");
26*188ede28SSunho Kim   Reoptimizing = false;
27*188ede28SSunho Kim }
28*188ede28SSunho Kim 
29*188ede28SSunho Kim Error ReOptimizeLayer::reigsterRuntimeFunctions(JITDylib &PlatformJD) {
30*188ede28SSunho Kim   ExecutionSession::JITDispatchHandlerAssociationMap WFs;
31*188ede28SSunho Kim   using ReoptimizeSPSSig = shared::SPSError(uint64_t, uint32_t);
32*188ede28SSunho Kim   WFs[Mangle("__orc_rt_reoptimize_tag")] =
33*188ede28SSunho Kim       ES.wrapAsyncWithSPS<ReoptimizeSPSSig>(this,
34*188ede28SSunho Kim                                             &ReOptimizeLayer::rt_reoptimize);
35*188ede28SSunho Kim   return ES.registerJITDispatchHandlers(PlatformJD, std::move(WFs));
36*188ede28SSunho Kim }
37*188ede28SSunho Kim 
38*188ede28SSunho Kim void ReOptimizeLayer::emit(std::unique_ptr<MaterializationResponsibility> R,
39*188ede28SSunho Kim                            ThreadSafeModule TSM) {
40*188ede28SSunho Kim   auto &JD = R->getTargetJITDylib();
41*188ede28SSunho Kim 
42*188ede28SSunho Kim   bool HasNonCallable = false;
43*188ede28SSunho Kim   for (auto &KV : R->getSymbols()) {
44*188ede28SSunho Kim     auto &Flags = KV.second;
45*188ede28SSunho Kim     if (!Flags.isCallable())
46*188ede28SSunho Kim       HasNonCallable = true;
47*188ede28SSunho Kim   }
48*188ede28SSunho Kim 
49*188ede28SSunho Kim   if (HasNonCallable) {
50*188ede28SSunho Kim     BaseLayer.emit(std::move(R), std::move(TSM));
51*188ede28SSunho Kim     return;
52*188ede28SSunho Kim   }
53*188ede28SSunho Kim 
54*188ede28SSunho Kim   auto &MUState = createMaterializationUnitState(TSM);
55*188ede28SSunho Kim 
56*188ede28SSunho Kim   if (auto Err = R->withResourceKeyDo([&](ResourceKey Key) {
57*188ede28SSunho Kim         registerMaterializationUnitResource(Key, MUState);
58*188ede28SSunho Kim       })) {
59*188ede28SSunho Kim     ES.reportError(std::move(Err));
60*188ede28SSunho Kim     R->failMaterialization();
61*188ede28SSunho Kim     return;
62*188ede28SSunho Kim   }
63*188ede28SSunho Kim 
64*188ede28SSunho Kim   if (auto Err =
65*188ede28SSunho Kim           ProfilerFunc(*this, MUState.getID(), MUState.getCurVersion(), TSM)) {
66*188ede28SSunho Kim     ES.reportError(std::move(Err));
67*188ede28SSunho Kim     R->failMaterialization();
68*188ede28SSunho Kim     return;
69*188ede28SSunho Kim   }
70*188ede28SSunho Kim 
71*188ede28SSunho Kim   auto InitialDests =
72*188ede28SSunho Kim       emitMUImplSymbols(MUState, MUState.getCurVersion(), JD, std::move(TSM));
73*188ede28SSunho Kim   if (!InitialDests) {
74*188ede28SSunho Kim     ES.reportError(InitialDests.takeError());
75*188ede28SSunho Kim     R->failMaterialization();
76*188ede28SSunho Kim     return;
77*188ede28SSunho Kim   }
78*188ede28SSunho Kim 
79*188ede28SSunho Kim   RSManager.emitRedirectableSymbols(std::move(R), std::move(*InitialDests));
80*188ede28SSunho Kim }
81*188ede28SSunho Kim 
82*188ede28SSunho Kim Error ReOptimizeLayer::reoptimizeIfCallFrequent(ReOptimizeLayer &Parent,
83*188ede28SSunho Kim                                                 ReOptMaterializationUnitID MUID,
84*188ede28SSunho Kim                                                 unsigned CurVersion,
85*188ede28SSunho Kim                                                 ThreadSafeModule &TSM) {
86*188ede28SSunho Kim   return TSM.withModuleDo([&](Module &M) -> Error {
87*188ede28SSunho Kim     Type *I64Ty = Type::getInt64Ty(M.getContext());
88*188ede28SSunho Kim     GlobalVariable *Counter = new GlobalVariable(
89*188ede28SSunho Kim         M, I64Ty, false, GlobalValue::InternalLinkage,
90*188ede28SSunho Kim         Constant::getNullValue(I64Ty), "__orc_reopt_counter");
91*188ede28SSunho Kim     auto ArgBufferConst = createReoptimizeArgBuffer(M, MUID, CurVersion);
92*188ede28SSunho Kim     if (auto Err = ArgBufferConst.takeError())
93*188ede28SSunho Kim       return Err;
94*188ede28SSunho Kim     GlobalVariable *ArgBuffer =
95*188ede28SSunho Kim         new GlobalVariable(M, (*ArgBufferConst)->getType(), true,
96*188ede28SSunho Kim                            GlobalValue::InternalLinkage, (*ArgBufferConst));
97*188ede28SSunho Kim     for (auto &F : M) {
98*188ede28SSunho Kim       if (F.isDeclaration())
99*188ede28SSunho Kim         continue;
100*188ede28SSunho Kim       auto &BB = F.getEntryBlock();
101*188ede28SSunho Kim       auto *IP = &*BB.getFirstInsertionPt();
102*188ede28SSunho Kim       IRBuilder<> IRB(IP);
103*188ede28SSunho Kim       Value *Threshold = ConstantInt::get(I64Ty, CallCountThreshold, true);
104*188ede28SSunho Kim       Value *Cnt = IRB.CreateLoad(I64Ty, Counter);
105*188ede28SSunho Kim       // Use EQ to prevent further reoptimize calls.
106*188ede28SSunho Kim       Value *Cmp = IRB.CreateICmpEQ(Cnt, Threshold);
107*188ede28SSunho Kim       Value *Added = IRB.CreateAdd(Cnt, ConstantInt::get(I64Ty, 1));
108*188ede28SSunho Kim       (void)IRB.CreateStore(Added, Counter);
109*188ede28SSunho Kim       Instruction *SplitTerminator = SplitBlockAndInsertIfThen(Cmp, IP, false);
110*188ede28SSunho Kim       createReoptimizeCall(M, *SplitTerminator, ArgBuffer);
111*188ede28SSunho Kim     }
112*188ede28SSunho Kim     return Error::success();
113*188ede28SSunho Kim   });
114*188ede28SSunho Kim }
115*188ede28SSunho Kim 
116*188ede28SSunho Kim Expected<SymbolMap>
117*188ede28SSunho Kim ReOptimizeLayer::emitMUImplSymbols(ReOptMaterializationUnitState &MUState,
118*188ede28SSunho Kim                                    uint32_t Version, JITDylib &JD,
119*188ede28SSunho Kim                                    ThreadSafeModule TSM) {
120*188ede28SSunho Kim   DenseMap<SymbolStringPtr, SymbolStringPtr> RenamedMap;
121*188ede28SSunho Kim   cantFail(TSM.withModuleDo([&](Module &M) -> Error {
122*188ede28SSunho Kim     MangleAndInterner Mangle(ES, M.getDataLayout());
123*188ede28SSunho Kim     for (auto &F : M)
124*188ede28SSunho Kim       if (!F.isDeclaration()) {
125*188ede28SSunho Kim         std::string NewName =
126*188ede28SSunho Kim             (F.getName() + ".__def__." + Twine(Version)).str();
127*188ede28SSunho Kim         RenamedMap[Mangle(F.getName())] = Mangle(NewName);
128*188ede28SSunho Kim         F.setName(NewName);
129*188ede28SSunho Kim       }
130*188ede28SSunho Kim     return Error::success();
131*188ede28SSunho Kim   }));
132*188ede28SSunho Kim 
133*188ede28SSunho Kim   auto RT = JD.createResourceTracker();
134*188ede28SSunho Kim   if (auto Err =
135*188ede28SSunho Kim           JD.define(std::make_unique<BasicIRLayerMaterializationUnit>(
136*188ede28SSunho Kim                         BaseLayer, *getManglingOptions(), std::move(TSM)),
137*188ede28SSunho Kim                     RT))
138*188ede28SSunho Kim     return Err;
139*188ede28SSunho Kim   MUState.setResourceTracker(RT);
140*188ede28SSunho Kim 
141*188ede28SSunho Kim   SymbolLookupSet LookupSymbols;
142*188ede28SSunho Kim   for (auto [K, V] : RenamedMap)
143*188ede28SSunho Kim     LookupSymbols.add(V);
144*188ede28SSunho Kim 
145*188ede28SSunho Kim   auto ImplSymbols =
146*188ede28SSunho Kim       ES.lookup({{&JD, JITDylibLookupFlags::MatchAllSymbols}}, LookupSymbols,
147*188ede28SSunho Kim                 LookupKind::Static, SymbolState::Resolved);
148*188ede28SSunho Kim   if (auto Err = ImplSymbols.takeError())
149*188ede28SSunho Kim     return Err;
150*188ede28SSunho Kim 
151*188ede28SSunho Kim   SymbolMap Result;
152*188ede28SSunho Kim   for (auto [K, V] : RenamedMap)
153*188ede28SSunho Kim     Result[K] = (*ImplSymbols)[V];
154*188ede28SSunho Kim 
155*188ede28SSunho Kim   return Result;
156*188ede28SSunho Kim }
157*188ede28SSunho Kim 
158*188ede28SSunho Kim void ReOptimizeLayer::rt_reoptimize(SendErrorFn SendResult,
159*188ede28SSunho Kim                                     ReOptMaterializationUnitID MUID,
160*188ede28SSunho Kim                                     uint32_t CurVersion) {
161*188ede28SSunho Kim   auto &MUState = getMaterializationUnitState(MUID);
162*188ede28SSunho Kim   if (CurVersion < MUState.getCurVersion() || !MUState.tryStartReoptimize()) {
163*188ede28SSunho Kim     SendResult(Error::success());
164*188ede28SSunho Kim     return;
165*188ede28SSunho Kim   }
166*188ede28SSunho Kim 
167*188ede28SSunho Kim   ThreadSafeModule TSM = cloneToNewContext(MUState.getThreadSafeModule());
168*188ede28SSunho Kim   auto OldRT = MUState.getResourceTracker();
169*188ede28SSunho Kim   auto &JD = OldRT->getJITDylib();
170*188ede28SSunho Kim 
171*188ede28SSunho Kim   if (auto Err = ReOptFunc(*this, MUID, CurVersion + 1, OldRT, TSM)) {
172*188ede28SSunho Kim     ES.reportError(std::move(Err));
173*188ede28SSunho Kim     MUState.reoptimizeFailed();
174*188ede28SSunho Kim     SendResult(Error::success());
175*188ede28SSunho Kim     return;
176*188ede28SSunho Kim   }
177*188ede28SSunho Kim 
178*188ede28SSunho Kim   auto SymbolDests =
179*188ede28SSunho Kim       emitMUImplSymbols(MUState, CurVersion + 1, JD, std::move(TSM));
180*188ede28SSunho Kim   if (!SymbolDests) {
181*188ede28SSunho Kim     ES.reportError(SymbolDests.takeError());
182*188ede28SSunho Kim     MUState.reoptimizeFailed();
183*188ede28SSunho Kim     SendResult(Error::success());
184*188ede28SSunho Kim     return;
185*188ede28SSunho Kim   }
186*188ede28SSunho Kim 
187*188ede28SSunho Kim   if (auto Err = RSManager.redirect(JD, std::move(*SymbolDests))) {
188*188ede28SSunho Kim     ES.reportError(std::move(Err));
189*188ede28SSunho Kim     MUState.reoptimizeFailed();
190*188ede28SSunho Kim     SendResult(Error::success());
191*188ede28SSunho Kim     return;
192*188ede28SSunho Kim   }
193*188ede28SSunho Kim 
194*188ede28SSunho Kim   MUState.reoptimizeSucceeded();
195*188ede28SSunho Kim   SendResult(Error::success());
196*188ede28SSunho Kim }
197*188ede28SSunho Kim 
198*188ede28SSunho Kim Expected<Constant *> ReOptimizeLayer::createReoptimizeArgBuffer(
199*188ede28SSunho Kim     Module &M, ReOptMaterializationUnitID MUID, uint32_t CurVersion) {
200*188ede28SSunho Kim   size_t ArgBufferSize = SPSReoptimizeArgList::size(MUID, CurVersion);
201*188ede28SSunho Kim   std::vector<char> ArgBuffer(ArgBufferSize);
202*188ede28SSunho Kim   shared::SPSOutputBuffer OB(ArgBuffer.data(), ArgBuffer.size());
203*188ede28SSunho Kim   if (!SPSReoptimizeArgList::serialize(OB, MUID, CurVersion))
204*188ede28SSunho Kim     return make_error<StringError>("Could not serealize args list",
205*188ede28SSunho Kim                                    inconvertibleErrorCode());
206*188ede28SSunho Kim   return ConstantDataArray::get(M.getContext(), ArrayRef(ArgBuffer));
207*188ede28SSunho Kim }
208*188ede28SSunho Kim 
209*188ede28SSunho Kim void ReOptimizeLayer::createReoptimizeCall(Module &M, Instruction &IP,
210*188ede28SSunho Kim                                            GlobalVariable *ArgBuffer) {
211*188ede28SSunho Kim   GlobalVariable *DispatchCtx =
212*188ede28SSunho Kim       M.getGlobalVariable("__orc_rt_jit_dispatch_ctx");
213*188ede28SSunho Kim   if (!DispatchCtx)
214*188ede28SSunho Kim     DispatchCtx = new GlobalVariable(M, PointerType::get(M.getContext(), 0),
215*188ede28SSunho Kim                                      false, GlobalValue::ExternalLinkage,
216*188ede28SSunho Kim                                      nullptr, "__orc_rt_jit_dispatch_ctx");
217*188ede28SSunho Kim   GlobalVariable *ReoptimizeTag =
218*188ede28SSunho Kim       M.getGlobalVariable("__orc_rt_reoptimize_tag");
219*188ede28SSunho Kim   if (!ReoptimizeTag)
220*188ede28SSunho Kim     ReoptimizeTag = new GlobalVariable(M, PointerType::get(M.getContext(), 0),
221*188ede28SSunho Kim                                        false, GlobalValue::ExternalLinkage,
222*188ede28SSunho Kim                                        nullptr, "__orc_rt_reoptimize_tag");
223*188ede28SSunho Kim   Function *DispatchFunc = M.getFunction("__orc_rt_jit_dispatch");
224*188ede28SSunho Kim   if (!DispatchFunc) {
225*188ede28SSunho Kim     std::vector<Type *> Args = {PointerType::get(M.getContext(), 0),
226*188ede28SSunho Kim                                 PointerType::get(M.getContext(), 0),
227*188ede28SSunho Kim                                 PointerType::get(M.getContext(), 0),
228*188ede28SSunho Kim                                 IntegerType::get(M.getContext(), 64)};
229*188ede28SSunho Kim     FunctionType *FuncTy =
230*188ede28SSunho Kim         FunctionType::get(Type::getVoidTy(M.getContext()), Args, false);
231*188ede28SSunho Kim     DispatchFunc = Function::Create(FuncTy, GlobalValue::ExternalLinkage,
232*188ede28SSunho Kim                                     "__orc_rt_jit_dispatch", &M);
233*188ede28SSunho Kim   }
234*188ede28SSunho Kim   size_t ArgBufferSizeConst =
235*188ede28SSunho Kim       SPSReoptimizeArgList::size(ReOptMaterializationUnitID{}, uint32_t{});
236*188ede28SSunho Kim   Constant *ArgBufferSize = ConstantInt::get(
237*188ede28SSunho Kim       IntegerType::get(M.getContext(), 64), ArgBufferSizeConst, false);
238*188ede28SSunho Kim   IRBuilder<> IRB(&IP);
239*188ede28SSunho Kim   (void)IRB.CreateCall(DispatchFunc,
240*188ede28SSunho Kim                        {DispatchCtx, ReoptimizeTag, ArgBuffer, ArgBufferSize});
241*188ede28SSunho Kim }
242*188ede28SSunho Kim 
243*188ede28SSunho Kim ReOptimizeLayer::ReOptMaterializationUnitState &
244*188ede28SSunho Kim ReOptimizeLayer::createMaterializationUnitState(const ThreadSafeModule &TSM) {
245*188ede28SSunho Kim   std::unique_lock<std::mutex> Lock(Mutex);
246*188ede28SSunho Kim   ReOptMaterializationUnitID MUID = NextID;
247*188ede28SSunho Kim   MUStates.emplace(MUID,
248*188ede28SSunho Kim                    ReOptMaterializationUnitState(MUID, cloneToNewContext(TSM)));
249*188ede28SSunho Kim   ++NextID;
250*188ede28SSunho Kim   return MUStates.at(MUID);
251*188ede28SSunho Kim }
252*188ede28SSunho Kim 
253*188ede28SSunho Kim ReOptimizeLayer::ReOptMaterializationUnitState &
254*188ede28SSunho Kim ReOptimizeLayer::getMaterializationUnitState(ReOptMaterializationUnitID MUID) {
255*188ede28SSunho Kim   std::unique_lock<std::mutex> Lock(Mutex);
256*188ede28SSunho Kim   return MUStates.at(MUID);
257*188ede28SSunho Kim }
258*188ede28SSunho Kim 
259*188ede28SSunho Kim void ReOptimizeLayer::registerMaterializationUnitResource(
260*188ede28SSunho Kim     ResourceKey Key, ReOptMaterializationUnitState &State) {
261*188ede28SSunho Kim   std::unique_lock<std::mutex> Lock(Mutex);
262*188ede28SSunho Kim   MUResources[Key].insert(State.getID());
263*188ede28SSunho Kim }
264*188ede28SSunho Kim 
265*188ede28SSunho Kim Error ReOptimizeLayer::handleRemoveResources(JITDylib &JD, ResourceKey K) {
266*188ede28SSunho Kim   std::unique_lock<std::mutex> Lock(Mutex);
267*188ede28SSunho Kim   for (auto MUID : MUResources[K])
268*188ede28SSunho Kim     MUStates.erase(MUID);
269*188ede28SSunho Kim 
270*188ede28SSunho Kim   MUResources.erase(K);
271*188ede28SSunho Kim   return Error::success();
272*188ede28SSunho Kim }
273*188ede28SSunho Kim 
274*188ede28SSunho Kim void ReOptimizeLayer::handleTransferResources(JITDylib &JD, ResourceKey DstK,
275*188ede28SSunho Kim                                               ResourceKey SrcK) {
276*188ede28SSunho Kim   std::unique_lock<std::mutex> Lock(Mutex);
277*188ede28SSunho Kim   MUResources[DstK].insert(MUResources[SrcK].begin(), MUResources[SrcK].end());
278*188ede28SSunho Kim   MUResources.erase(SrcK);
279*188ede28SSunho Kim }
280