xref: /llvm-project/llvm/lib/ExecutionEngine/Orc/ReOptimizeLayer.cpp (revision 188ede28e046c911cb8e604fd1adc2b5cc1f264b)
1 #include "llvm/ExecutionEngine/Orc/ReOptimizeLayer.h"
2 #include "llvm/ExecutionEngine/Orc/Mangling.h"
3 
4 using namespace llvm;
5 using namespace orc;
6 
7 bool ReOptimizeLayer::ReOptMaterializationUnitState::tryStartReoptimize() {
8   std::unique_lock<std::mutex> Lock(Mutex);
9   if (Reoptimizing)
10     return false;
11 
12   Reoptimizing = true;
13   return true;
14 }
15 
16 void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeSucceeded() {
17   std::unique_lock<std::mutex> Lock(Mutex);
18   assert(Reoptimizing && "Tried to mark unstarted reoptimization as done");
19   Reoptimizing = false;
20   CurVersion++;
21 }
22 
23 void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeFailed() {
24   std::unique_lock<std::mutex> Lock(Mutex);
25   assert(Reoptimizing && "Tried to mark unstarted reoptimization as done");
26   Reoptimizing = false;
27 }
28 
29 Error ReOptimizeLayer::reigsterRuntimeFunctions(JITDylib &PlatformJD) {
30   ExecutionSession::JITDispatchHandlerAssociationMap WFs;
31   using ReoptimizeSPSSig = shared::SPSError(uint64_t, uint32_t);
32   WFs[Mangle("__orc_rt_reoptimize_tag")] =
33       ES.wrapAsyncWithSPS<ReoptimizeSPSSig>(this,
34                                             &ReOptimizeLayer::rt_reoptimize);
35   return ES.registerJITDispatchHandlers(PlatformJD, std::move(WFs));
36 }
37 
38 void ReOptimizeLayer::emit(std::unique_ptr<MaterializationResponsibility> R,
39                            ThreadSafeModule TSM) {
40   auto &JD = R->getTargetJITDylib();
41 
42   bool HasNonCallable = false;
43   for (auto &KV : R->getSymbols()) {
44     auto &Flags = KV.second;
45     if (!Flags.isCallable())
46       HasNonCallable = true;
47   }
48 
49   if (HasNonCallable) {
50     BaseLayer.emit(std::move(R), std::move(TSM));
51     return;
52   }
53 
54   auto &MUState = createMaterializationUnitState(TSM);
55 
56   if (auto Err = R->withResourceKeyDo([&](ResourceKey Key) {
57         registerMaterializationUnitResource(Key, MUState);
58       })) {
59     ES.reportError(std::move(Err));
60     R->failMaterialization();
61     return;
62   }
63 
64   if (auto Err =
65           ProfilerFunc(*this, MUState.getID(), MUState.getCurVersion(), TSM)) {
66     ES.reportError(std::move(Err));
67     R->failMaterialization();
68     return;
69   }
70 
71   auto InitialDests =
72       emitMUImplSymbols(MUState, MUState.getCurVersion(), JD, std::move(TSM));
73   if (!InitialDests) {
74     ES.reportError(InitialDests.takeError());
75     R->failMaterialization();
76     return;
77   }
78 
79   RSManager.emitRedirectableSymbols(std::move(R), std::move(*InitialDests));
80 }
81 
82 Error ReOptimizeLayer::reoptimizeIfCallFrequent(ReOptimizeLayer &Parent,
83                                                 ReOptMaterializationUnitID MUID,
84                                                 unsigned CurVersion,
85                                                 ThreadSafeModule &TSM) {
86   return TSM.withModuleDo([&](Module &M) -> Error {
87     Type *I64Ty = Type::getInt64Ty(M.getContext());
88     GlobalVariable *Counter = new GlobalVariable(
89         M, I64Ty, false, GlobalValue::InternalLinkage,
90         Constant::getNullValue(I64Ty), "__orc_reopt_counter");
91     auto ArgBufferConst = createReoptimizeArgBuffer(M, MUID, CurVersion);
92     if (auto Err = ArgBufferConst.takeError())
93       return Err;
94     GlobalVariable *ArgBuffer =
95         new GlobalVariable(M, (*ArgBufferConst)->getType(), true,
96                            GlobalValue::InternalLinkage, (*ArgBufferConst));
97     for (auto &F : M) {
98       if (F.isDeclaration())
99         continue;
100       auto &BB = F.getEntryBlock();
101       auto *IP = &*BB.getFirstInsertionPt();
102       IRBuilder<> IRB(IP);
103       Value *Threshold = ConstantInt::get(I64Ty, CallCountThreshold, true);
104       Value *Cnt = IRB.CreateLoad(I64Ty, Counter);
105       // Use EQ to prevent further reoptimize calls.
106       Value *Cmp = IRB.CreateICmpEQ(Cnt, Threshold);
107       Value *Added = IRB.CreateAdd(Cnt, ConstantInt::get(I64Ty, 1));
108       (void)IRB.CreateStore(Added, Counter);
109       Instruction *SplitTerminator = SplitBlockAndInsertIfThen(Cmp, IP, false);
110       createReoptimizeCall(M, *SplitTerminator, ArgBuffer);
111     }
112     return Error::success();
113   });
114 }
115 
116 Expected<SymbolMap>
117 ReOptimizeLayer::emitMUImplSymbols(ReOptMaterializationUnitState &MUState,
118                                    uint32_t Version, JITDylib &JD,
119                                    ThreadSafeModule TSM) {
120   DenseMap<SymbolStringPtr, SymbolStringPtr> RenamedMap;
121   cantFail(TSM.withModuleDo([&](Module &M) -> Error {
122     MangleAndInterner Mangle(ES, M.getDataLayout());
123     for (auto &F : M)
124       if (!F.isDeclaration()) {
125         std::string NewName =
126             (F.getName() + ".__def__." + Twine(Version)).str();
127         RenamedMap[Mangle(F.getName())] = Mangle(NewName);
128         F.setName(NewName);
129       }
130     return Error::success();
131   }));
132 
133   auto RT = JD.createResourceTracker();
134   if (auto Err =
135           JD.define(std::make_unique<BasicIRLayerMaterializationUnit>(
136                         BaseLayer, *getManglingOptions(), std::move(TSM)),
137                     RT))
138     return Err;
139   MUState.setResourceTracker(RT);
140 
141   SymbolLookupSet LookupSymbols;
142   for (auto [K, V] : RenamedMap)
143     LookupSymbols.add(V);
144 
145   auto ImplSymbols =
146       ES.lookup({{&JD, JITDylibLookupFlags::MatchAllSymbols}}, LookupSymbols,
147                 LookupKind::Static, SymbolState::Resolved);
148   if (auto Err = ImplSymbols.takeError())
149     return Err;
150 
151   SymbolMap Result;
152   for (auto [K, V] : RenamedMap)
153     Result[K] = (*ImplSymbols)[V];
154 
155   return Result;
156 }
157 
158 void ReOptimizeLayer::rt_reoptimize(SendErrorFn SendResult,
159                                     ReOptMaterializationUnitID MUID,
160                                     uint32_t CurVersion) {
161   auto &MUState = getMaterializationUnitState(MUID);
162   if (CurVersion < MUState.getCurVersion() || !MUState.tryStartReoptimize()) {
163     SendResult(Error::success());
164     return;
165   }
166 
167   ThreadSafeModule TSM = cloneToNewContext(MUState.getThreadSafeModule());
168   auto OldRT = MUState.getResourceTracker();
169   auto &JD = OldRT->getJITDylib();
170 
171   if (auto Err = ReOptFunc(*this, MUID, CurVersion + 1, OldRT, TSM)) {
172     ES.reportError(std::move(Err));
173     MUState.reoptimizeFailed();
174     SendResult(Error::success());
175     return;
176   }
177 
178   auto SymbolDests =
179       emitMUImplSymbols(MUState, CurVersion + 1, JD, std::move(TSM));
180   if (!SymbolDests) {
181     ES.reportError(SymbolDests.takeError());
182     MUState.reoptimizeFailed();
183     SendResult(Error::success());
184     return;
185   }
186 
187   if (auto Err = RSManager.redirect(JD, std::move(*SymbolDests))) {
188     ES.reportError(std::move(Err));
189     MUState.reoptimizeFailed();
190     SendResult(Error::success());
191     return;
192   }
193 
194   MUState.reoptimizeSucceeded();
195   SendResult(Error::success());
196 }
197 
198 Expected<Constant *> ReOptimizeLayer::createReoptimizeArgBuffer(
199     Module &M, ReOptMaterializationUnitID MUID, uint32_t CurVersion) {
200   size_t ArgBufferSize = SPSReoptimizeArgList::size(MUID, CurVersion);
201   std::vector<char> ArgBuffer(ArgBufferSize);
202   shared::SPSOutputBuffer OB(ArgBuffer.data(), ArgBuffer.size());
203   if (!SPSReoptimizeArgList::serialize(OB, MUID, CurVersion))
204     return make_error<StringError>("Could not serealize args list",
205                                    inconvertibleErrorCode());
206   return ConstantDataArray::get(M.getContext(), ArrayRef(ArgBuffer));
207 }
208 
209 void ReOptimizeLayer::createReoptimizeCall(Module &M, Instruction &IP,
210                                            GlobalVariable *ArgBuffer) {
211   GlobalVariable *DispatchCtx =
212       M.getGlobalVariable("__orc_rt_jit_dispatch_ctx");
213   if (!DispatchCtx)
214     DispatchCtx = new GlobalVariable(M, PointerType::get(M.getContext(), 0),
215                                      false, GlobalValue::ExternalLinkage,
216                                      nullptr, "__orc_rt_jit_dispatch_ctx");
217   GlobalVariable *ReoptimizeTag =
218       M.getGlobalVariable("__orc_rt_reoptimize_tag");
219   if (!ReoptimizeTag)
220     ReoptimizeTag = new GlobalVariable(M, PointerType::get(M.getContext(), 0),
221                                        false, GlobalValue::ExternalLinkage,
222                                        nullptr, "__orc_rt_reoptimize_tag");
223   Function *DispatchFunc = M.getFunction("__orc_rt_jit_dispatch");
224   if (!DispatchFunc) {
225     std::vector<Type *> Args = {PointerType::get(M.getContext(), 0),
226                                 PointerType::get(M.getContext(), 0),
227                                 PointerType::get(M.getContext(), 0),
228                                 IntegerType::get(M.getContext(), 64)};
229     FunctionType *FuncTy =
230         FunctionType::get(Type::getVoidTy(M.getContext()), Args, false);
231     DispatchFunc = Function::Create(FuncTy, GlobalValue::ExternalLinkage,
232                                     "__orc_rt_jit_dispatch", &M);
233   }
234   size_t ArgBufferSizeConst =
235       SPSReoptimizeArgList::size(ReOptMaterializationUnitID{}, uint32_t{});
236   Constant *ArgBufferSize = ConstantInt::get(
237       IntegerType::get(M.getContext(), 64), ArgBufferSizeConst, false);
238   IRBuilder<> IRB(&IP);
239   (void)IRB.CreateCall(DispatchFunc,
240                        {DispatchCtx, ReoptimizeTag, ArgBuffer, ArgBufferSize});
241 }
242 
243 ReOptimizeLayer::ReOptMaterializationUnitState &
244 ReOptimizeLayer::createMaterializationUnitState(const ThreadSafeModule &TSM) {
245   std::unique_lock<std::mutex> Lock(Mutex);
246   ReOptMaterializationUnitID MUID = NextID;
247   MUStates.emplace(MUID,
248                    ReOptMaterializationUnitState(MUID, cloneToNewContext(TSM)));
249   ++NextID;
250   return MUStates.at(MUID);
251 }
252 
253 ReOptimizeLayer::ReOptMaterializationUnitState &
254 ReOptimizeLayer::getMaterializationUnitState(ReOptMaterializationUnitID MUID) {
255   std::unique_lock<std::mutex> Lock(Mutex);
256   return MUStates.at(MUID);
257 }
258 
259 void ReOptimizeLayer::registerMaterializationUnitResource(
260     ResourceKey Key, ReOptMaterializationUnitState &State) {
261   std::unique_lock<std::mutex> Lock(Mutex);
262   MUResources[Key].insert(State.getID());
263 }
264 
265 Error ReOptimizeLayer::handleRemoveResources(JITDylib &JD, ResourceKey K) {
266   std::unique_lock<std::mutex> Lock(Mutex);
267   for (auto MUID : MUResources[K])
268     MUStates.erase(MUID);
269 
270   MUResources.erase(K);
271   return Error::success();
272 }
273 
274 void ReOptimizeLayer::handleTransferResources(JITDylib &JD, ResourceKey DstK,
275                                               ResourceKey SrcK) {
276   std::unique_lock<std::mutex> Lock(Mutex);
277   MUResources[DstK].insert(MUResources[SrcK].begin(), MUResources[SrcK].end());
278   MUResources.erase(SrcK);
279 }
280