1 //===- PGOCtxProfFlattening.cpp - Contextual Instr. Flattening ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Flattens the contextual profile and lowers it to MD_prof. 10 // This should happen after all IPO (which is assumed to have maintained the 11 // contextual profile) happened. Flattening consists of summing the values at 12 // the same index of the counters belonging to all the contexts of a function. 13 // The lowering consists of materializing the counter values to function 14 // entrypoint counts and branch probabilities. 15 // 16 // This pass also removes contextual instrumentation, which has been kept around 17 // to facilitate its functionality. 18 // 19 //===----------------------------------------------------------------------===// 20 21 #include "llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h" 22 #include "llvm/ADT/STLExtras.h" 23 #include "llvm/ADT/ScopeExit.h" 24 #include "llvm/Analysis/CtxProfAnalysis.h" 25 #include "llvm/Analysis/ProfileSummaryInfo.h" 26 #include "llvm/IR/Analysis.h" 27 #include "llvm/IR/CFG.h" 28 #include "llvm/IR/Dominators.h" 29 #include "llvm/IR/Instructions.h" 30 #include "llvm/IR/IntrinsicInst.h" 31 #include "llvm/IR/Module.h" 32 #include "llvm/IR/PassManager.h" 33 #include "llvm/IR/ProfileSummary.h" 34 #include "llvm/ProfileData/ProfileCommon.h" 35 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h" 36 #include "llvm/Transforms/Scalar/DCE.h" 37 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 38 #include <deque> 39 40 using namespace llvm; 41 42 namespace { 43 44 class ProfileAnnotator final { 45 class BBInfo; 46 struct EdgeInfo { 47 BBInfo *const Src; 48 BBInfo *const Dest; 49 std::optional<uint64_t> Count; 50 51 explicit EdgeInfo(BBInfo &Src, BBInfo &Dest) : Src(&Src), Dest(&Dest) {} 52 }; 53 54 class BBInfo { 55 std::optional<uint64_t> Count; 56 // OutEdges is dimensioned to match the number of terminator operands. 57 // Entries in the vector match the index in the terminator operand list. In 58 // some cases - see `shouldExcludeEdge` and its implementation - an entry 59 // will be nullptr. 60 // InEdges doesn't have the above constraint. 61 SmallVector<EdgeInfo *> OutEdges; 62 SmallVector<EdgeInfo *> InEdges; 63 size_t UnknownCountOutEdges = 0; 64 size_t UnknownCountInEdges = 0; 65 66 // Pass AssumeAllKnown when we try to propagate counts from edges to BBs - 67 // because all the edge counters must be known. 68 // Return std::nullopt if there were no edges to sum. The user can decide 69 // how to interpret that. 70 std::optional<uint64_t> getEdgeSum(const SmallVector<EdgeInfo *> &Edges, 71 bool AssumeAllKnown) const { 72 std::optional<uint64_t> Sum; 73 for (const auto *E : Edges) { 74 // `Edges` may be `OutEdges`, case in which `E` could be nullptr. 75 if (E) { 76 if (!Sum.has_value()) 77 Sum = 0; 78 *Sum += (AssumeAllKnown ? *E->Count : E->Count.value_or(0U)); 79 } 80 } 81 return Sum; 82 } 83 84 bool computeCountFrom(const SmallVector<EdgeInfo *> &Edges) { 85 assert(!Count.has_value()); 86 Count = getEdgeSum(Edges, true); 87 return Count.has_value(); 88 } 89 90 void setSingleUnknownEdgeCount(SmallVector<EdgeInfo *> &Edges) { 91 uint64_t KnownSum = getEdgeSum(Edges, false).value_or(0U); 92 uint64_t EdgeVal = *Count > KnownSum ? *Count - KnownSum : 0U; 93 EdgeInfo *E = nullptr; 94 for (auto *I : Edges) 95 if (I && !I->Count.has_value()) { 96 E = I; 97 #ifdef NDEBUG 98 break; 99 #else 100 assert((!E || E == I) && 101 "Expected exactly one edge to have an unknown count, " 102 "found a second one"); 103 continue; 104 #endif 105 } 106 assert(E && "Expected exactly one edge to have an unknown count"); 107 assert(!E->Count.has_value()); 108 E->Count = EdgeVal; 109 assert(E->Src->UnknownCountOutEdges > 0); 110 assert(E->Dest->UnknownCountInEdges > 0); 111 --E->Src->UnknownCountOutEdges; 112 --E->Dest->UnknownCountInEdges; 113 } 114 115 public: 116 BBInfo(size_t NumInEdges, size_t NumOutEdges, std::optional<uint64_t> Count) 117 : Count(Count) { 118 // For in edges, we just want to pre-allocate enough space, since we know 119 // it at this stage. For out edges, we will insert edges at the indices 120 // corresponding to positions in this BB's terminator instruction, so we 121 // construct a default (nullptr values)-initialized vector. A nullptr edge 122 // corresponds to those that are excluded (see shouldExcludeEdge). 123 InEdges.reserve(NumInEdges); 124 OutEdges.resize(NumOutEdges); 125 } 126 127 bool tryTakeCountFromKnownOutEdges(const BasicBlock &BB) { 128 if (!UnknownCountOutEdges) { 129 return computeCountFrom(OutEdges); 130 } 131 return false; 132 } 133 134 bool tryTakeCountFromKnownInEdges(const BasicBlock &BB) { 135 if (!UnknownCountInEdges) { 136 return computeCountFrom(InEdges); 137 } 138 return false; 139 } 140 141 void addInEdge(EdgeInfo &Info) { 142 InEdges.push_back(&Info); 143 ++UnknownCountInEdges; 144 } 145 146 // For the out edges, we care about the position we place them in, which is 147 // the position in terminator instruction's list (at construction). Later, 148 // we build branch_weights metadata with edge frequency values matching 149 // these positions. 150 void addOutEdge(size_t Index, EdgeInfo &Info) { 151 OutEdges[Index] = &Info; 152 ++UnknownCountOutEdges; 153 } 154 155 bool hasCount() const { return Count.has_value(); } 156 157 uint64_t getCount() const { return *Count; } 158 159 bool trySetSingleUnknownInEdgeCount() { 160 if (UnknownCountInEdges == 1) { 161 setSingleUnknownEdgeCount(InEdges); 162 return true; 163 } 164 return false; 165 } 166 167 bool trySetSingleUnknownOutEdgeCount() { 168 if (UnknownCountOutEdges == 1) { 169 setSingleUnknownEdgeCount(OutEdges); 170 return true; 171 } 172 return false; 173 } 174 size_t getNumOutEdges() const { return OutEdges.size(); } 175 176 uint64_t getEdgeCount(size_t Index) const { 177 if (auto *E = OutEdges[Index]) 178 return *E->Count; 179 return 0U; 180 } 181 }; 182 183 Function &F; 184 const SmallVectorImpl<uint64_t> &Counters; 185 // To be accessed through getBBInfo() after construction. 186 std::map<const BasicBlock *, BBInfo> BBInfos; 187 std::vector<EdgeInfo> EdgeInfos; 188 InstrProfSummaryBuilder &PB; 189 190 // This is an adaptation of PGOUseFunc::populateCounters. 191 // FIXME(mtrofin): look into factoring the code to share one implementation. 192 void propagateCounterValues(const SmallVectorImpl<uint64_t> &Counters) { 193 bool KeepGoing = true; 194 while (KeepGoing) { 195 KeepGoing = false; 196 for (const auto &BB : F) { 197 auto &Info = getBBInfo(BB); 198 if (!Info.hasCount()) 199 KeepGoing |= Info.tryTakeCountFromKnownOutEdges(BB) || 200 Info.tryTakeCountFromKnownInEdges(BB); 201 if (Info.hasCount()) { 202 KeepGoing |= Info.trySetSingleUnknownOutEdgeCount(); 203 KeepGoing |= Info.trySetSingleUnknownInEdgeCount(); 204 } 205 } 206 } 207 } 208 // The only criteria for exclusion is faux suspend -> exit edges in presplit 209 // coroutines. The API serves for readability, currently. 210 bool shouldExcludeEdge(const BasicBlock &Src, const BasicBlock &Dest) const { 211 return llvm::isPresplitCoroSuspendExitEdge(Src, Dest); 212 } 213 214 BBInfo &getBBInfo(const BasicBlock &BB) { return BBInfos.find(&BB)->second; } 215 216 const BBInfo &getBBInfo(const BasicBlock &BB) const { 217 return BBInfos.find(&BB)->second; 218 } 219 220 // validation function after we propagate the counters: all BBs and edges' 221 // counters must have a value. 222 bool allCountersAreAssigned() const { 223 for (const auto &BBInfo : BBInfos) 224 if (!BBInfo.second.hasCount()) 225 return false; 226 for (const auto &EdgeInfo : EdgeInfos) 227 if (!EdgeInfo.Count.has_value()) 228 return false; 229 return true; 230 } 231 232 /// Check that all paths from the entry basic block that use edges with 233 /// non-zero counts arrive at a basic block with no successors (i.e. "exit") 234 bool allTakenPathsExit() const { 235 std::deque<const BasicBlock *> Worklist; 236 DenseSet<const BasicBlock *> Visited; 237 Worklist.push_back(&F.getEntryBlock()); 238 bool HitExit = false; 239 while (!Worklist.empty()) { 240 const auto *BB = Worklist.front(); 241 Worklist.pop_front(); 242 if (!Visited.insert(BB).second) 243 continue; 244 if (succ_size(BB) == 0) { 245 if (isa<UnreachableInst>(BB->getTerminator())) 246 return false; 247 HitExit = true; 248 continue; 249 } 250 if (succ_size(BB) == 1) { 251 Worklist.push_back(BB->getUniqueSuccessor()); 252 continue; 253 } 254 const auto &BBInfo = getBBInfo(*BB); 255 bool HasAWayOut = false; 256 for (auto I = 0U; I < BB->getTerminator()->getNumSuccessors(); ++I) { 257 const auto *Succ = BB->getTerminator()->getSuccessor(I); 258 if (!shouldExcludeEdge(*BB, *Succ)) { 259 if (BBInfo.getEdgeCount(I) > 0) { 260 HasAWayOut = true; 261 Worklist.push_back(Succ); 262 } 263 } 264 } 265 if (!HasAWayOut) 266 return false; 267 } 268 return HitExit; 269 } 270 271 bool allNonColdSelectsHaveProfile() const { 272 for (const auto &BB : F) { 273 if (getBBInfo(BB).getCount() > 0) { 274 for (const auto &I : BB) { 275 if (const auto *SI = dyn_cast<SelectInst>(&I)) { 276 if (!SI->getMetadata(LLVMContext::MD_prof)) { 277 return false; 278 } 279 } 280 } 281 } 282 } 283 return true; 284 } 285 286 public: 287 ProfileAnnotator(Function &F, const SmallVectorImpl<uint64_t> &Counters, 288 InstrProfSummaryBuilder &PB) 289 : F(F), Counters(Counters), PB(PB) { 290 assert(!F.isDeclaration()); 291 assert(!Counters.empty()); 292 size_t NrEdges = 0; 293 for (const auto &BB : F) { 294 std::optional<uint64_t> Count; 295 if (auto *Ins = CtxProfAnalysis::getBBInstrumentation( 296 const_cast<BasicBlock &>(BB))) { 297 auto Index = Ins->getIndex()->getZExtValue(); 298 assert(Index < Counters.size() && 299 "The index must be inside the counters vector by construction - " 300 "tripping this assertion indicates a bug in how the contextual " 301 "profile is managed by IPO transforms"); 302 (void)Index; 303 Count = Counters[Ins->getIndex()->getZExtValue()]; 304 } else if (isa<UnreachableInst>(BB.getTerminator())) { 305 // The program presumably didn't crash. 306 Count = 0; 307 } 308 auto [It, Ins] = 309 BBInfos.insert({&BB, {pred_size(&BB), succ_size(&BB), Count}}); 310 (void)Ins; 311 assert(Ins && "We iterate through the function's BBs, no reason to " 312 "insert one more than once"); 313 NrEdges += llvm::count_if(successors(&BB), [&](const auto *Succ) { 314 return !shouldExcludeEdge(BB, *Succ); 315 }); 316 } 317 // Pre-allocate the vector, we want references to its contents to be stable. 318 EdgeInfos.reserve(NrEdges); 319 for (const auto &BB : F) { 320 auto &Info = getBBInfo(BB); 321 for (auto I = 0U; I < BB.getTerminator()->getNumSuccessors(); ++I) { 322 const auto *Succ = BB.getTerminator()->getSuccessor(I); 323 if (!shouldExcludeEdge(BB, *Succ)) { 324 auto &EI = EdgeInfos.emplace_back(getBBInfo(BB), getBBInfo(*Succ)); 325 Info.addOutEdge(I, EI); 326 getBBInfo(*Succ).addInEdge(EI); 327 } 328 } 329 } 330 assert(EdgeInfos.capacity() == NrEdges && 331 "The capacity of EdgeInfos should have stayed unchanged it was " 332 "populated, because we need pointers to its contents to be stable"); 333 } 334 335 void setProfileForSelectInstructions(BasicBlock &BB, const BBInfo &BBInfo) { 336 if (BBInfo.getCount() == 0) 337 return; 338 339 for (auto &I : BB) { 340 if (auto *SI = dyn_cast<SelectInst>(&I)) { 341 if (auto *Step = CtxProfAnalysis::getSelectInstrumentation(*SI)) { 342 auto Index = Step->getIndex()->getZExtValue(); 343 assert(Index < Counters.size() && 344 "The index of the step instruction must be inside the " 345 "counters vector by " 346 "construction - tripping this assertion indicates a bug in " 347 "how the contextual profile is managed by IPO transforms"); 348 auto TotalCount = BBInfo.getCount(); 349 auto TrueCount = Counters[Index]; 350 auto FalseCount = 351 (TotalCount > TrueCount ? TotalCount - TrueCount : 0U); 352 setProfMetadata(F.getParent(), SI, {TrueCount, FalseCount}, 353 std::max(TrueCount, FalseCount)); 354 PB.addInternalCount(TrueCount); 355 PB.addInternalCount(FalseCount); 356 } 357 } 358 } 359 } 360 361 /// Assign branch weights and function entry count. Also update the PSI 362 /// builder. 363 void assignProfileData() { 364 assert(!Counters.empty()); 365 propagateCounterValues(Counters); 366 F.setEntryCount(Counters[0]); 367 PB.addEntryCount(Counters[0]); 368 369 for (auto &BB : F) { 370 const auto &BBInfo = getBBInfo(BB); 371 setProfileForSelectInstructions(BB, BBInfo); 372 if (succ_size(&BB) < 2) 373 continue; 374 auto *Term = BB.getTerminator(); 375 SmallVector<uint64_t, 2> EdgeCounts(Term->getNumSuccessors(), 0); 376 uint64_t MaxCount = 0; 377 378 for (unsigned SuccIdx = 0, Size = BBInfo.getNumOutEdges(); SuccIdx < Size; 379 ++SuccIdx) { 380 uint64_t EdgeCount = BBInfo.getEdgeCount(SuccIdx); 381 if (EdgeCount > MaxCount) 382 MaxCount = EdgeCount; 383 EdgeCounts[SuccIdx] = EdgeCount; 384 PB.addInternalCount(EdgeCount); 385 } 386 387 if (MaxCount != 0) 388 setProfMetadata(F.getParent(), Term, EdgeCounts, MaxCount); 389 } 390 assert(allCountersAreAssigned() && 391 "[ctx-prof] Expected all counters have been assigned."); 392 assert(allTakenPathsExit() && 393 "[ctx-prof] Encountered a BB with more than one successor, where " 394 "all outgoing edges have a 0 count. This occurs in non-exiting " 395 "functions (message pumps, usually) which are not supported in the " 396 "contextual profiling case"); 397 assert(allNonColdSelectsHaveProfile() && 398 "[ctx-prof] All non-cold select instructions were expected to have " 399 "a profile."); 400 } 401 }; 402 403 [[maybe_unused]] bool areAllBBsReachable(const Function &F, 404 FunctionAnalysisManager &FAM) { 405 auto &DT = FAM.getResult<DominatorTreeAnalysis>(const_cast<Function &>(F)); 406 return llvm::all_of( 407 F, [&](const BasicBlock &BB) { return DT.isReachableFromEntry(&BB); }); 408 } 409 410 void clearColdFunctionProfile(Function &F) { 411 for (auto &BB : F) 412 BB.getTerminator()->setMetadata(LLVMContext::MD_prof, nullptr); 413 F.setEntryCount(0U); 414 } 415 416 void removeInstrumentation(Function &F) { 417 for (auto &BB : F) 418 for (auto &I : llvm::make_early_inc_range(BB)) 419 if (isa<InstrProfCntrInstBase>(I)) 420 I.eraseFromParent(); 421 } 422 423 } // namespace 424 425 PreservedAnalyses PGOCtxProfFlatteningPass::run(Module &M, 426 ModuleAnalysisManager &MAM) { 427 // Ensure in all cases the instrumentation is removed: if this module had no 428 // roots, the contextual profile would evaluate to false, but there would 429 // still be instrumentation. 430 // Note: in such cases we leave as-is any other profile info (if present - 431 // e.g. synthetic weights, etc) because it wouldn't interfere with the 432 // contextual - based one (which would be in other modules) 433 auto OnExit = llvm::make_scope_exit([&]() { 434 for (auto &F : M) 435 removeInstrumentation(F); 436 }); 437 auto &CtxProf = MAM.getResult<CtxProfAnalysis>(M); 438 if (!CtxProf) 439 return PreservedAnalyses::none(); 440 441 const auto FlattenedProfile = CtxProf.flatten(); 442 443 InstrProfSummaryBuilder PB(ProfileSummaryBuilder::DefaultCutoffs); 444 for (auto &F : M) { 445 if (F.isDeclaration()) 446 continue; 447 448 assert(areAllBBsReachable( 449 F, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M) 450 .getManager()) && 451 "Function has unreacheable basic blocks. The expectation was that " 452 "DCE was run before."); 453 454 auto It = FlattenedProfile.find(AssignGUIDPass::getGUID(F)); 455 // If this function didn't appear in the contextual profile, it's cold. 456 if (It == FlattenedProfile.end()) 457 clearColdFunctionProfile(F); 458 else { 459 ProfileAnnotator S(F, It->second, PB); 460 S.assignProfileData(); 461 } 462 } 463 464 auto &PSI = MAM.getResult<ProfileSummaryAnalysis>(M); 465 466 M.setProfileSummary(PB.getSummary()->getMD(M.getContext()), 467 ProfileSummary::Kind::PSK_Instr); 468 PSI.refresh(); 469 return PreservedAnalyses::none(); 470 } 471