1 #include "llvm/ExecutionEngine/Orc/ReOptimizeLayer.h" 2 #include "llvm/ExecutionEngine/Orc/Mangling.h" 3 4 using namespace llvm; 5 using namespace orc; 6 7 bool ReOptimizeLayer::ReOptMaterializationUnitState::tryStartReoptimize() { 8 std::unique_lock<std::mutex> Lock(Mutex); 9 if (Reoptimizing) 10 return false; 11 12 Reoptimizing = true; 13 return true; 14 } 15 16 void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeSucceeded() { 17 std::unique_lock<std::mutex> Lock(Mutex); 18 assert(Reoptimizing && "Tried to mark unstarted reoptimization as done"); 19 Reoptimizing = false; 20 CurVersion++; 21 } 22 23 void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeFailed() { 24 std::unique_lock<std::mutex> Lock(Mutex); 25 assert(Reoptimizing && "Tried to mark unstarted reoptimization as done"); 26 Reoptimizing = false; 27 } 28 29 Error ReOptimizeLayer::reigsterRuntimeFunctions(JITDylib &PlatformJD) { 30 ExecutionSession::JITDispatchHandlerAssociationMap WFs; 31 using ReoptimizeSPSSig = shared::SPSError(uint64_t, uint32_t); 32 WFs[Mangle("__orc_rt_reoptimize_tag")] = 33 ES.wrapAsyncWithSPS<ReoptimizeSPSSig>(this, 34 &ReOptimizeLayer::rt_reoptimize); 35 return ES.registerJITDispatchHandlers(PlatformJD, std::move(WFs)); 36 } 37 38 void ReOptimizeLayer::emit(std::unique_ptr<MaterializationResponsibility> R, 39 ThreadSafeModule TSM) { 40 auto &JD = R->getTargetJITDylib(); 41 42 bool HasNonCallable = false; 43 for (auto &KV : R->getSymbols()) { 44 auto &Flags = KV.second; 45 if (!Flags.isCallable()) 46 HasNonCallable = true; 47 } 48 49 if (HasNonCallable) { 50 BaseLayer.emit(std::move(R), std::move(TSM)); 51 return; 52 } 53 54 auto &MUState = createMaterializationUnitState(TSM); 55 56 if (auto Err = R->withResourceKeyDo([&](ResourceKey Key) { 57 registerMaterializationUnitResource(Key, MUState); 58 })) { 59 ES.reportError(std::move(Err)); 60 R->failMaterialization(); 61 return; 62 } 63 64 if (auto Err = 65 ProfilerFunc(*this, MUState.getID(), MUState.getCurVersion(), TSM)) { 66 ES.reportError(std::move(Err)); 67 R->failMaterialization(); 68 return; 69 } 70 71 auto InitialDests = 72 emitMUImplSymbols(MUState, MUState.getCurVersion(), JD, std::move(TSM)); 73 if (!InitialDests) { 74 ES.reportError(InitialDests.takeError()); 75 R->failMaterialization(); 76 return; 77 } 78 79 RSManager.emitRedirectableSymbols(std::move(R), std::move(*InitialDests)); 80 } 81 82 Error ReOptimizeLayer::reoptimizeIfCallFrequent(ReOptimizeLayer &Parent, 83 ReOptMaterializationUnitID MUID, 84 unsigned CurVersion, 85 ThreadSafeModule &TSM) { 86 return TSM.withModuleDo([&](Module &M) -> Error { 87 Type *I64Ty = Type::getInt64Ty(M.getContext()); 88 GlobalVariable *Counter = new GlobalVariable( 89 M, I64Ty, false, GlobalValue::InternalLinkage, 90 Constant::getNullValue(I64Ty), "__orc_reopt_counter"); 91 auto ArgBufferConst = createReoptimizeArgBuffer(M, MUID, CurVersion); 92 if (auto Err = ArgBufferConst.takeError()) 93 return Err; 94 GlobalVariable *ArgBuffer = 95 new GlobalVariable(M, (*ArgBufferConst)->getType(), true, 96 GlobalValue::InternalLinkage, (*ArgBufferConst)); 97 for (auto &F : M) { 98 if (F.isDeclaration()) 99 continue; 100 auto &BB = F.getEntryBlock(); 101 auto *IP = &*BB.getFirstInsertionPt(); 102 IRBuilder<> IRB(IP); 103 Value *Threshold = ConstantInt::get(I64Ty, CallCountThreshold, true); 104 Value *Cnt = IRB.CreateLoad(I64Ty, Counter); 105 // Use EQ to prevent further reoptimize calls. 106 Value *Cmp = IRB.CreateICmpEQ(Cnt, Threshold); 107 Value *Added = IRB.CreateAdd(Cnt, ConstantInt::get(I64Ty, 1)); 108 (void)IRB.CreateStore(Added, Counter); 109 Instruction *SplitTerminator = SplitBlockAndInsertIfThen(Cmp, IP, false); 110 createReoptimizeCall(M, *SplitTerminator, ArgBuffer); 111 } 112 return Error::success(); 113 }); 114 } 115 116 Expected<SymbolMap> 117 ReOptimizeLayer::emitMUImplSymbols(ReOptMaterializationUnitState &MUState, 118 uint32_t Version, JITDylib &JD, 119 ThreadSafeModule TSM) { 120 DenseMap<SymbolStringPtr, SymbolStringPtr> RenamedMap; 121 cantFail(TSM.withModuleDo([&](Module &M) -> Error { 122 MangleAndInterner Mangle(ES, M.getDataLayout()); 123 for (auto &F : M) 124 if (!F.isDeclaration()) { 125 std::string NewName = 126 (F.getName() + ".__def__." + Twine(Version)).str(); 127 RenamedMap[Mangle(F.getName())] = Mangle(NewName); 128 F.setName(NewName); 129 } 130 return Error::success(); 131 })); 132 133 auto RT = JD.createResourceTracker(); 134 if (auto Err = 135 JD.define(std::make_unique<BasicIRLayerMaterializationUnit>( 136 BaseLayer, *getManglingOptions(), std::move(TSM)), 137 RT)) 138 return Err; 139 MUState.setResourceTracker(RT); 140 141 SymbolLookupSet LookupSymbols; 142 for (auto [K, V] : RenamedMap) 143 LookupSymbols.add(V); 144 145 auto ImplSymbols = 146 ES.lookup({{&JD, JITDylibLookupFlags::MatchAllSymbols}}, LookupSymbols, 147 LookupKind::Static, SymbolState::Resolved); 148 if (auto Err = ImplSymbols.takeError()) 149 return Err; 150 151 SymbolMap Result; 152 for (auto [K, V] : RenamedMap) 153 Result[K] = (*ImplSymbols)[V]; 154 155 return Result; 156 } 157 158 void ReOptimizeLayer::rt_reoptimize(SendErrorFn SendResult, 159 ReOptMaterializationUnitID MUID, 160 uint32_t CurVersion) { 161 auto &MUState = getMaterializationUnitState(MUID); 162 if (CurVersion < MUState.getCurVersion() || !MUState.tryStartReoptimize()) { 163 SendResult(Error::success()); 164 return; 165 } 166 167 ThreadSafeModule TSM = cloneToNewContext(MUState.getThreadSafeModule()); 168 auto OldRT = MUState.getResourceTracker(); 169 auto &JD = OldRT->getJITDylib(); 170 171 if (auto Err = ReOptFunc(*this, MUID, CurVersion + 1, OldRT, TSM)) { 172 ES.reportError(std::move(Err)); 173 MUState.reoptimizeFailed(); 174 SendResult(Error::success()); 175 return; 176 } 177 178 auto SymbolDests = 179 emitMUImplSymbols(MUState, CurVersion + 1, JD, std::move(TSM)); 180 if (!SymbolDests) { 181 ES.reportError(SymbolDests.takeError()); 182 MUState.reoptimizeFailed(); 183 SendResult(Error::success()); 184 return; 185 } 186 187 if (auto Err = RSManager.redirect(JD, std::move(*SymbolDests))) { 188 ES.reportError(std::move(Err)); 189 MUState.reoptimizeFailed(); 190 SendResult(Error::success()); 191 return; 192 } 193 194 MUState.reoptimizeSucceeded(); 195 SendResult(Error::success()); 196 } 197 198 Expected<Constant *> ReOptimizeLayer::createReoptimizeArgBuffer( 199 Module &M, ReOptMaterializationUnitID MUID, uint32_t CurVersion) { 200 size_t ArgBufferSize = SPSReoptimizeArgList::size(MUID, CurVersion); 201 std::vector<char> ArgBuffer(ArgBufferSize); 202 shared::SPSOutputBuffer OB(ArgBuffer.data(), ArgBuffer.size()); 203 if (!SPSReoptimizeArgList::serialize(OB, MUID, CurVersion)) 204 return make_error<StringError>("Could not serealize args list", 205 inconvertibleErrorCode()); 206 return ConstantDataArray::get(M.getContext(), ArrayRef(ArgBuffer)); 207 } 208 209 void ReOptimizeLayer::createReoptimizeCall(Module &M, Instruction &IP, 210 GlobalVariable *ArgBuffer) { 211 GlobalVariable *DispatchCtx = 212 M.getGlobalVariable("__orc_rt_jit_dispatch_ctx"); 213 if (!DispatchCtx) 214 DispatchCtx = new GlobalVariable(M, PointerType::get(M.getContext(), 0), 215 false, GlobalValue::ExternalLinkage, 216 nullptr, "__orc_rt_jit_dispatch_ctx"); 217 GlobalVariable *ReoptimizeTag = 218 M.getGlobalVariable("__orc_rt_reoptimize_tag"); 219 if (!ReoptimizeTag) 220 ReoptimizeTag = new GlobalVariable(M, PointerType::get(M.getContext(), 0), 221 false, GlobalValue::ExternalLinkage, 222 nullptr, "__orc_rt_reoptimize_tag"); 223 Function *DispatchFunc = M.getFunction("__orc_rt_jit_dispatch"); 224 if (!DispatchFunc) { 225 std::vector<Type *> Args = {PointerType::get(M.getContext(), 0), 226 PointerType::get(M.getContext(), 0), 227 PointerType::get(M.getContext(), 0), 228 IntegerType::get(M.getContext(), 64)}; 229 FunctionType *FuncTy = 230 FunctionType::get(Type::getVoidTy(M.getContext()), Args, false); 231 DispatchFunc = Function::Create(FuncTy, GlobalValue::ExternalLinkage, 232 "__orc_rt_jit_dispatch", &M); 233 } 234 size_t ArgBufferSizeConst = 235 SPSReoptimizeArgList::size(ReOptMaterializationUnitID{}, uint32_t{}); 236 Constant *ArgBufferSize = ConstantInt::get( 237 IntegerType::get(M.getContext(), 64), ArgBufferSizeConst, false); 238 IRBuilder<> IRB(&IP); 239 (void)IRB.CreateCall(DispatchFunc, 240 {DispatchCtx, ReoptimizeTag, ArgBuffer, ArgBufferSize}); 241 } 242 243 ReOptimizeLayer::ReOptMaterializationUnitState & 244 ReOptimizeLayer::createMaterializationUnitState(const ThreadSafeModule &TSM) { 245 std::unique_lock<std::mutex> Lock(Mutex); 246 ReOptMaterializationUnitID MUID = NextID; 247 MUStates.emplace(MUID, 248 ReOptMaterializationUnitState(MUID, cloneToNewContext(TSM))); 249 ++NextID; 250 return MUStates.at(MUID); 251 } 252 253 ReOptimizeLayer::ReOptMaterializationUnitState & 254 ReOptimizeLayer::getMaterializationUnitState(ReOptMaterializationUnitID MUID) { 255 std::unique_lock<std::mutex> Lock(Mutex); 256 return MUStates.at(MUID); 257 } 258 259 void ReOptimizeLayer::registerMaterializationUnitResource( 260 ResourceKey Key, ReOptMaterializationUnitState &State) { 261 std::unique_lock<std::mutex> Lock(Mutex); 262 MUResources[Key].insert(State.getID()); 263 } 264 265 Error ReOptimizeLayer::handleRemoveResources(JITDylib &JD, ResourceKey K) { 266 std::unique_lock<std::mutex> Lock(Mutex); 267 for (auto MUID : MUResources[K]) 268 MUStates.erase(MUID); 269 270 MUResources.erase(K); 271 return Error::success(); 272 } 273 274 void ReOptimizeLayer::handleTransferResources(JITDylib &JD, ResourceKey DstK, 275 ResourceKey SrcK) { 276 std::unique_lock<std::mutex> Lock(Mutex); 277 MUResources[DstK].insert(MUResources[SrcK].begin(), MUResources[SrcK].end()); 278 MUResources.erase(SrcK); 279 } 280