//===- PGOCtxProfLowering.cpp - Contextual PGO Instr. Lowering ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//

#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/IR/Analysis.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/Support/CommandLine.h"
#include <utility>

using namespace llvm;

#define DEBUG_TYPE "ctx-instr-lower"

static cl::list<std::string> ContextRoots(
    "profile-context-root", cl::Hidden,
    cl::desc(
        "A function name, assumed to be global, which will be treated as the "
        "root of an interesting graph, which will be profiled independently "
        "from other similar graphs."));

bool PGOCtxProfLoweringPass::isCtxIRPGOInstrEnabled() {
  return !ContextRoots.empty();
}

// The names of symbols we expect in compiler-rt. Using a namespace for
// readability.
namespace CompilerRtAPINames {
static auto StartCtx = "__llvm_ctx_profile_start_context";
static auto ReleaseCtx = "__llvm_ctx_profile_release_context";
static auto GetCtx = "__llvm_ctx_profile_get_context";
static auto ExpectedCalleeTLS = "__llvm_ctx_profile_expected_callee";
static auto CallsiteTLS = "__llvm_ctx_profile_callsite";
} // namespace CompilerRtAPINames

namespace {
// The lowering logic and state.
class CtxInstrumentationLowerer final {
  Module &M;
  ModuleAnalysisManager &MAM;
  Type *ContextNodeTy = nullptr;
  Type *ContextRootTy = nullptr;

  DenseMap<const Function *, Constant *> ContextRootMap;
  Function *StartCtx = nullptr;
  Function *GetCtx = nullptr;
  Function *ReleaseCtx = nullptr;
  GlobalVariable *ExpectedCalleeTLS = nullptr;
  GlobalVariable *CallsiteInfoTLS = nullptr;

public:
  CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM);
  // Return true if lowering happened (i.e. a change was made).
  bool lowerFunction(Function &F);
};

// llvm.instrprof.increment[.step] captures the total number of counters as one
// of its parameters, and llvm.instrprof.callsite captures the total number of
// callsites. Those values are the same for instances of those intrinsics in
// this function. Find the first instance of each and return them.
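// For reference, the textual form of these intrinsics is roughly (see the
// LangRef; the callsite operand names shown here are only illustrative):
//   call void @llvm.instrprof.increment(ptr <name>, i64 <hash>,
//                                       i32 <num-counters>, i32 <index>)
//   call void @llvm.instrprof.callsite(ptr <name>, i64 <hash>,
//                                      i32 <num-callsites>, i32 <index>,
//                                      ptr <callee>)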
std::pair<uint32_t, uint32_t> getNumCountersAndCallsites(const Function &F) {
  uint32_t NumCounters = 0;
  uint32_t NumCallsites = 0;
  for (const auto &BB : F) {
    for (const auto &I : BB) {
      if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(&I)) {
        uint32_t V =
            static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
        assert((!NumCounters || V == NumCounters) &&
               "expected all llvm.instrprof.increment[.step] intrinsics to "
               "have the same total nr of counters parameter");
        NumCounters = V;
      } else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(&I)) {
        uint32_t V =
            static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
        assert((!NumCallsites || V == NumCallsites) &&
               "expected all llvm.instrprof.callsite intrinsics to have the "
               "same total nr of callsites parameter");
        NumCallsites = V;
      }
#if NDEBUG
      if (NumCounters && NumCallsites)
        return std::make_pair(NumCounters, NumCallsites);
#endif
    }
  }
  return {NumCounters, NumCallsites};
}
} // namespace

// set up tie-in with compiler-rt.
// NOTE!!!
// These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
                                                     ModuleAnalysisManager &MAM)
    : M(M), MAM(MAM) {
  auto *PointerTy = PointerType::get(M.getContext(), 0);
  auto *SanitizerMutexType = Type::getInt8Ty(M.getContext());
  auto *I32Ty = Type::getInt32Ty(M.getContext());
  auto *I64Ty = Type::getInt64Ty(M.getContext());

  // The ContextRoot type
  ContextRootTy =
      StructType::get(M.getContext(), {
                                          PointerTy,          /*FirstNode*/
                                          PointerTy,          /*FirstMemBlock*/
                                          PointerTy,          /*CurrentMem*/
                                          SanitizerMutexType, /*Taken*/
                                      });
  // The Context header.
  ContextNodeTy = StructType::get(M.getContext(), {
                                                      I64Ty,     /*Guid*/
                                                      PointerTy, /*Next*/
                                                      I32Ty,     /*NumCounters*/
                                                      I32Ty,     /*NumCallsites*/
                                                  });

  // Define a global for each entrypoint. We'll reuse the entrypoint's name as
  // prefix. We assume the entrypoint names to be unique.
  for (const auto &Fname : ContextRoots) {
    if (const auto *F = M.getFunction(Fname)) {
      if (F->isDeclaration())
        continue;
      auto *G = M.getOrInsertGlobal(Fname + "_ctx_root", ContextRootTy);
      cast<GlobalVariable>(G)->setInitializer(
          Constant::getNullValue(ContextRootTy));
      ContextRootMap.insert(std::make_pair(F, G));
      for (const auto &BB : *F)
        for (const auto &I : BB)
          if (const auto *CB = dyn_cast<CallBase>(&I))
            if (CB->isMustTailCall()) {
              M.getContext().emitError(
                  "The function " + Fname +
                  " was indicated as a context root, but it features musttail "
                  "calls, which is not supported.");
            }
    }
  }

  // Declare the functions we will call.
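  // The IR-level prototypes below mirror the compiler-rt entry points from the
  // header mentioned in the NOTE above; roughly:
  //   ContextNode *__llvm_ctx_profile_start_context(ContextRoot *, GUID,
  //                                                 uint32_t NumCounters,
  //                                                 uint32_t NumCallsites);
  //   ContextNode *__llvm_ctx_profile_get_context(void *Callee, GUID,
  //                                               uint32_t NumCounters,
  //                                               uint32_t NumCallsites);
  //   void __llvm_ctx_profile_release_context(ContextRoot *);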
  StartCtx = cast<Function>(
      M.getOrInsertFunction(
           CompilerRtAPINames::StartCtx,
           FunctionType::get(PointerTy,
                             {PointerTy, /*ContextRoot*/
                              I64Ty,     /*Guid*/
                              I32Ty,     /*NumCounters*/
                              I32Ty /*NumCallsites*/},
                             false))
          .getCallee());
  GetCtx = cast<Function>(
      M.getOrInsertFunction(CompilerRtAPINames::GetCtx,
                            FunctionType::get(PointerTy,
                                              {PointerTy, /*Callee*/
                                               I64Ty,     /*Guid*/
                                               I32Ty,     /*NumCounters*/
                                               I32Ty /*NumCallsites*/},
                                              false))
          .getCallee());
  ReleaseCtx = cast<Function>(
      M.getOrInsertFunction(CompilerRtAPINames::ReleaseCtx,
                            FunctionType::get(Type::getVoidTy(M.getContext()),
                                              {
                                                  PointerTy, /*ContextRoot*/
                                              },
                                              false))
          .getCallee());

  // Declare the TLSes we will need to use.
  CallsiteInfoTLS =
      new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
                         nullptr, CompilerRtAPINames::CallsiteTLS);
  CallsiteInfoTLS->setThreadLocal(true);
  CallsiteInfoTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
  ExpectedCalleeTLS =
      new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
                         nullptr, CompilerRtAPINames::ExpectedCalleeTLS);
  ExpectedCalleeTLS->setThreadLocal(true);
  ExpectedCalleeTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
}

PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M,
                                              ModuleAnalysisManager &MAM) {
  CtxInstrumentationLowerer Lowerer(M, MAM);
  bool Changed = false;
  for (auto &F : M)
    Changed |= Lowerer.lowerFunction(F);
  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}

bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
  if (F.isDeclaration())
    return false;
  auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);

  Value *Guid = nullptr;
  auto [NumCounters, NumCallsites] = getNumCountersAndCallsites(F);

  Value *Context = nullptr;
  Value *RealContext = nullptr;

  StructType *ThisContextType = nullptr;
  Value *TheRootContext = nullptr;
  Value *ExpectedCalleeTLSAddr = nullptr;
  Value *CallsiteInfoTLSAddr = nullptr;

  auto &Head = F.getEntryBlock();
  for (auto &I : Head) {
    // Find the increment intrinsic in the entry basic block.
    if (auto *Mark = dyn_cast<InstrProfIncrementInst>(&I)) {
      assert(Mark->getIndex()->isZero());

      IRBuilder<> Builder(Mark);

      Guid = Builder.getInt64(
          AssignGUIDPass::getGUID(cast<Function>(*Mark->getNameValue())));
      // The type of the context of this function is now knowable since we
      // have NumCallsites and NumCounters. We declare it here because it's
      // more convenient - we have the Builder.
      ThisContextType = StructType::get(
          F.getContext(),
          {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NumCounters),
           ArrayType::get(Builder.getPtrTy(), NumCallsites)});
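      // That is: the ContextNodeTy header, followed by NumCounters i64
      // counters, then NumCallsites sub-context pointers (the GEPs below use
      // struct indices 1 and 2, respectively).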
      // Figure out which way we obtain the context object for this function -
      // if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the
      // former case, we also set TheRootContext since we need to release it
      // at the end (plus it can be used to know if we have an entrypoint or a
      // regular function).
      auto Iter = ContextRootMap.find(&F);
      if (Iter != ContextRootMap.end()) {
        TheRootContext = Iter->second;
        Context = Builder.CreateCall(
            StartCtx, {TheRootContext, Guid, Builder.getInt32(NumCounters),
                       Builder.getInt32(NumCallsites)});
        ORE.emit(
            [&] { return OptimizationRemark(DEBUG_TYPE, "Entrypoint", &F); });
      } else {
        Context = Builder.CreateCall(
            GetCtx, {&F, Guid, Builder.getInt32(NumCounters),
                     Builder.getInt32(NumCallsites)});
        ORE.emit([&] {
          return OptimizationRemark(DEBUG_TYPE, "RegularFunction", &F);
        });
      }
      // The context could be scratch.
      auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty());
      if (NumCallsites > 0) {
        // Figure out which index of the TLS 2-element buffers to use.
        // Scratch context => we use index == 1. Real contexts => index == 0.
        auto *Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1));
        // The GEPs corresponding to that index, in the respective TLS.
        ExpectedCalleeTLSAddr = Builder.CreateGEP(
            PointerType::getUnqual(F.getContext()),
            Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {Index});
        CallsiteInfoTLSAddr = Builder.CreateGEP(
            Builder.getInt32Ty(),
            Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {Index});
      }
      // Because the context pointer may have LSB set (to indicate scratch),
      // clear it for the value we use as base address for the counter vector.
      // This way, if later we want to have "real" (not clobbered) buffers
      // acting as scratch, the lowering (at least this part of it that deals
      // with counters) stays the same.
      RealContext = Builder.CreateIntToPtr(
          Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2)),
          PointerType::getUnqual(F.getContext()));
      I.eraseFromParent();
      break;
    }
  }
  if (!Context) {
    ORE.emit([&] {
      return OptimizationRemarkMissed(DEBUG_TYPE, "Skip", &F)
             << "Function doesn't have instrumentation, skipping";
    });
    return false;
  }

  bool ContextWasReleased = false;
  for (auto &BB : F) {
    for (auto &I : llvm::make_early_inc_range(BB)) {
      if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(&I)) {
        IRBuilder<> Builder(Instr);
        switch (Instr->getIntrinsicID()) {
        case llvm::Intrinsic::instrprof_increment:
        case llvm::Intrinsic::instrprof_increment_step: {
          // Increments (or increment-steps) are just a typical
          // load - increment - store in the RealContext.
          auto *AsStep = cast<InstrProfIncrementInst>(Instr);
          auto *GEP = Builder.CreateGEP(
              ThisContextType, RealContext,
              {Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()});
          Builder.CreateStore(
              Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(), GEP),
                                AsStep->getStep()),
              GEP);
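          // For an increment at counter index I, the lowered sequence amounts
          // to, roughly (illustrative IR):
          //   %gep = getelementptr %ThisContextType, ptr %RealContext,
          //                        i32 0, i32 1, i32 I
          //   %old = load i64, ptr %gep
          //   %new = add i64 %old, <step>
          //   store i64 %new, ptr %gep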
        } break;
        case llvm::Intrinsic::instrprof_callsite:
          // Callsite lowering: write the called value in the expected callee
          // TLS. We treat the TLS as volatile because of signal handlers, and
          // to avoid these stores being moved away from the callsite they
          // decorate.
          auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Instr);
          Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr,
                              true);
          // Write the GEP of the slot in the sub-contexts portion of the
          // context in TLS. Now, here, we use the actual Context value - as
          // returned from compiler-rt - which may have the LSB set if the
          // Context was scratch. Since the header of the context object and
          // then the values are all 8-aligned (or, really, insofar as we care,
          // they are even) - if the context is scratch (meaning, an odd
          // value), so will the GEP be. This is important because this is then
          // visible to compiler-rt, which will produce scratch contexts for
          // callers that have a scratch context.
          Builder.CreateStore(
              Builder.CreateGEP(ThisContextType, Context,
                                {Builder.getInt32(0), Builder.getInt32(2),
                                 CSIntrinsic->getIndex()}),
              CallsiteInfoTLSAddr, true);
          break;
        }
        I.eraseFromParent();
      } else if (TheRootContext && isa<ReturnInst>(I)) {
        // Remember to release the context if we are an entrypoint.
        IRBuilder<> Builder(&I);
        Builder.CreateCall(ReleaseCtx, {TheRootContext});
        ContextWasReleased = true;
      }
    }
  }
  // FIXME: This would happen if the entrypoint tailcalls. One way to fix this
  // would be to disallow it (so this then stays as an error); another is to
  // detect that case and then either emit a wrapper or disallow the tail
  // call. This only affects instrumentation, when we want to detect the call
  // graph.
  if (TheRootContext && !ContextWasReleased)
    F.getContext().emitError(
        "[ctx_prof] An entrypoint was instrumented but it has no `ret` "
        "instructions above which to release the context: " +
        F.getName());
  return true;
}