1 //===- SampleContextTracker.cpp - Context-sensitive Profile Tracker -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the SampleContextTracker used by CSSPGO. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "llvm/Transforms/IPO/SampleContextTracker.h" 14 #include "llvm/ADT/StringMap.h" 15 #include "llvm/ADT/StringRef.h" 16 #include "llvm/IR/DebugInfoMetadata.h" 17 #include "llvm/IR/Instructions.h" 18 #include "llvm/ProfileData/SampleProf.h" 19 #include <map> 20 #include <queue> 21 #include <vector> 22 23 using namespace llvm; 24 using namespace sampleprof; 25 26 #define DEBUG_TYPE "sample-context-tracker" 27 28 namespace llvm { 29 30 ContextTrieNode *ContextTrieNode::getChildContext(const LineLocation &CallSite, 31 StringRef CalleeName) { 32 if (CalleeName.empty()) 33 return getChildContext(CallSite); 34 35 uint32_t Hash = nodeHash(CalleeName, CallSite); 36 auto It = AllChildContext.find(Hash); 37 if (It != AllChildContext.end()) 38 return &It->second; 39 return nullptr; 40 } 41 42 ContextTrieNode * 43 ContextTrieNode::getChildContext(const LineLocation &CallSite) { 44 // CSFDO-TODO: This could be slow, change AllChildContext so we can 45 // do point look up for child node by call site alone. 46 // CSFDO-TODO: Return the child with max count for indirect call 47 ContextTrieNode *ChildNodeRet = nullptr; 48 for (auto &It : AllChildContext) { 49 ContextTrieNode &ChildNode = It.second; 50 if (ChildNode.CallSiteLoc == CallSite) { 51 if (ChildNodeRet) 52 return nullptr; 53 else 54 ChildNodeRet = &ChildNode; 55 } 56 } 57 58 return ChildNodeRet; 59 } 60 61 ContextTrieNode &ContextTrieNode::moveToChildContext( 62 const LineLocation &CallSite, ContextTrieNode &&NodeToMove, 63 StringRef ContextStrToRemove, bool DeleteNode) { 64 uint32_t Hash = nodeHash(NodeToMove.getFuncName(), CallSite); 65 assert(!AllChildContext.count(Hash) && "Node to remove must exist"); 66 LineLocation OldCallSite = NodeToMove.CallSiteLoc; 67 ContextTrieNode &OldParentContext = *NodeToMove.getParentContext(); 68 AllChildContext[Hash] = NodeToMove; 69 ContextTrieNode &NewNode = AllChildContext[Hash]; 70 NewNode.CallSiteLoc = CallSite; 71 72 // Walk through nodes in the moved the subtree, and update 73 // FunctionSamples' context as for the context promotion. 74 // We also need to set new parant link for all children. 75 std::queue<ContextTrieNode *> NodeToUpdate; 76 NewNode.setParentContext(this); 77 NodeToUpdate.push(&NewNode); 78 79 while (!NodeToUpdate.empty()) { 80 ContextTrieNode *Node = NodeToUpdate.front(); 81 NodeToUpdate.pop(); 82 FunctionSamples *FSamples = Node->getFunctionSamples(); 83 84 if (FSamples) { 85 FSamples->getContext().promoteOnPath(ContextStrToRemove); 86 FSamples->getContext().setState(SyntheticContext); 87 LLVM_DEBUG(dbgs() << " Context promoted to: " << FSamples->getContext() 88 << "\n"); 89 } 90 91 for (auto &It : Node->getAllChildContext()) { 92 ContextTrieNode *ChildNode = &It.second; 93 ChildNode->setParentContext(Node); 94 NodeToUpdate.push(ChildNode); 95 } 96 } 97 98 // Original context no longer needed, destroy if requested. 99 if (DeleteNode) 100 OldParentContext.removeChildContext(OldCallSite, NewNode.getFuncName()); 101 102 return NewNode; 103 } 104 105 void ContextTrieNode::removeChildContext(const LineLocation &CallSite, 106 StringRef CalleeName) { 107 uint32_t Hash = nodeHash(CalleeName, CallSite); 108 // Note this essentially calls dtor and destroys that child context 109 AllChildContext.erase(Hash); 110 } 111 112 std::map<uint32_t, ContextTrieNode> &ContextTrieNode::getAllChildContext() { 113 return AllChildContext; 114 } 115 116 const StringRef ContextTrieNode::getFuncName() const { return FuncName; } 117 118 FunctionSamples *ContextTrieNode::getFunctionSamples() const { 119 return FuncSamples; 120 } 121 122 void ContextTrieNode::setFunctionSamples(FunctionSamples *FSamples) { 123 FuncSamples = FSamples; 124 } 125 126 LineLocation ContextTrieNode::getCallSiteLoc() const { return CallSiteLoc; } 127 128 ContextTrieNode *ContextTrieNode::getParentContext() const { 129 return ParentContext; 130 } 131 132 void ContextTrieNode::setParentContext(ContextTrieNode *Parent) { 133 ParentContext = Parent; 134 } 135 136 void ContextTrieNode::dump() { 137 dbgs() << "Node: " << FuncName << "\n" 138 << " Callsite: " << CallSiteLoc << "\n" 139 << " Children:\n"; 140 141 for (auto &It : AllChildContext) { 142 dbgs() << " Node: " << It.second.getFuncName() << "\n"; 143 } 144 } 145 146 uint32_t ContextTrieNode::nodeHash(StringRef ChildName, 147 const LineLocation &Callsite) { 148 // We still use child's name for child hash, this is 149 // because for children of root node, we don't have 150 // different line/discriminator, and we'll rely on name 151 // to differentiate children. 152 uint32_t NameHash = std::hash<std::string>{}(ChildName.str()); 153 uint32_t LocId = (Callsite.LineOffset << 16) | Callsite.Discriminator; 154 return NameHash + (LocId << 5) + LocId; 155 } 156 157 ContextTrieNode *ContextTrieNode::getOrCreateChildContext( 158 const LineLocation &CallSite, StringRef CalleeName, bool AllowCreate) { 159 uint32_t Hash = nodeHash(CalleeName, CallSite); 160 auto It = AllChildContext.find(Hash); 161 if (It != AllChildContext.end()) { 162 assert(It->second.getFuncName() == CalleeName && 163 "Hash collision for child context node"); 164 return &It->second; 165 } 166 167 if (!AllowCreate) 168 return nullptr; 169 170 AllChildContext[Hash] = ContextTrieNode(this, CalleeName, nullptr, CallSite); 171 return &AllChildContext[Hash]; 172 } 173 174 // Profiler tracker than manages profiles and its associated context 175 SampleContextTracker::SampleContextTracker( 176 StringMap<FunctionSamples> &Profiles) { 177 for (auto &FuncSample : Profiles) { 178 FunctionSamples *FSamples = &FuncSample.second; 179 SampleContext Context(FuncSample.first(), RawContext); 180 LLVM_DEBUG(dbgs() << "Tracking Context for function: " << Context << "\n"); 181 if (!Context.isBaseContext()) 182 FuncToCtxtProfileSet[Context.getName()].insert(FSamples); 183 ContextTrieNode *NewNode = getOrCreateContextPath(Context, true); 184 assert(!NewNode->getFunctionSamples() && 185 "New node can't have sample profile"); 186 NewNode->setFunctionSamples(FSamples); 187 } 188 } 189 190 FunctionSamples * 191 SampleContextTracker::getCalleeContextSamplesFor(const CallBase &Inst, 192 StringRef CalleeName) { 193 LLVM_DEBUG(dbgs() << "Getting callee context for instr: " << Inst << "\n"); 194 // CSFDO-TODO: We use CalleeName to differentiate indirect call 195 // We need to get sample for indirect callee too. 196 DILocation *DIL = Inst.getDebugLoc(); 197 if (!DIL) 198 return nullptr; 199 200 ContextTrieNode *CalleeContext = getCalleeContextFor(DIL, CalleeName); 201 if (CalleeContext) { 202 FunctionSamples *FSamples = CalleeContext->getFunctionSamples(); 203 LLVM_DEBUG(if (FSamples) { 204 dbgs() << " Callee context found: " << FSamples->getContext() << "\n"; 205 }); 206 return FSamples; 207 } 208 209 return nullptr; 210 } 211 212 FunctionSamples * 213 SampleContextTracker::getContextSamplesFor(const DILocation *DIL) { 214 assert(DIL && "Expect non-null location"); 215 216 ContextTrieNode *ContextNode = getContextFor(DIL); 217 if (!ContextNode) 218 return nullptr; 219 220 // We may have inlined callees during pre-LTO compilation, in which case 221 // we need to rely on the inline stack from !dbg to mark context profile 222 // as inlined, instead of `MarkContextSamplesInlined` during inlining. 223 // Sample profile loader walks through all instructions to get profile, 224 // which calls this function. So once that is done, all previously inlined 225 // context profile should be marked properly. 226 FunctionSamples *Samples = ContextNode->getFunctionSamples(); 227 if (Samples && ContextNode->getParentContext() != &RootContext) 228 Samples->getContext().setState(InlinedContext); 229 230 return Samples; 231 } 232 233 FunctionSamples * 234 SampleContextTracker::getContextSamplesFor(const SampleContext &Context) { 235 ContextTrieNode *Node = getContextFor(Context); 236 if (!Node) 237 return nullptr; 238 239 return Node->getFunctionSamples(); 240 } 241 242 FunctionSamples *SampleContextTracker::getBaseSamplesFor(const Function &Func, 243 bool MergeContext) { 244 StringRef CanonName = FunctionSamples::getCanonicalFnName(Func); 245 return getBaseSamplesFor(CanonName, MergeContext); 246 } 247 248 FunctionSamples *SampleContextTracker::getBaseSamplesFor(StringRef Name, 249 bool MergeContext) { 250 LLVM_DEBUG(dbgs() << "Getting base profile for function: " << Name << "\n"); 251 // Base profile is top-level node (child of root node), so try to retrieve 252 // existing top-level node for given function first. If it exists, it could be 253 // that we've merged base profile before, or there's actually context-less 254 // profile from the input (e.g. due to unreliable stack walking). 255 ContextTrieNode *Node = getTopLevelContextNode(Name); 256 if (MergeContext) { 257 LLVM_DEBUG(dbgs() << " Merging context profile into base profile: " << Name 258 << "\n"); 259 260 // We have profile for function under different contexts, 261 // create synthetic base profile and merge context profiles 262 // into base profile. 263 for (auto *CSamples : FuncToCtxtProfileSet[Name]) { 264 SampleContext &Context = CSamples->getContext(); 265 ContextTrieNode *FromNode = getContextFor(Context); 266 if (FromNode == Node) 267 continue; 268 269 // Skip inlined context profile and also don't re-merge any context 270 if (Context.hasState(InlinedContext) || Context.hasState(MergedContext)) 271 continue; 272 273 ContextTrieNode &ToNode = promoteMergeContextSamplesTree(*FromNode); 274 assert((!Node || Node == &ToNode) && "Expect only one base profile"); 275 Node = &ToNode; 276 } 277 } 278 279 // Still no profile even after merge/promotion (if allowed) 280 if (!Node) 281 return nullptr; 282 283 return Node->getFunctionSamples(); 284 } 285 286 void SampleContextTracker::markContextSamplesInlined( 287 const FunctionSamples *InlinedSamples) { 288 assert(InlinedSamples && "Expect non-null inlined samples"); 289 LLVM_DEBUG(dbgs() << "Marking context profile as inlined: " 290 << InlinedSamples->getContext() << "\n"); 291 InlinedSamples->getContext().setState(InlinedContext); 292 } 293 294 void SampleContextTracker::promoteMergeContextSamplesTree( 295 const Instruction &Inst, StringRef CalleeName) { 296 LLVM_DEBUG(dbgs() << "Promoting and merging context tree for instr: \n" 297 << Inst << "\n"); 298 // CSFDO-TODO: We also need to promote context profile from indirect 299 // calls. We won't have callee names from those from call instr. 300 if (CalleeName.empty()) 301 return; 302 303 // Get the caller context for the call instruction, we don't use callee 304 // name from call because there can be context from indirect calls too. 305 DILocation *DIL = Inst.getDebugLoc(); 306 ContextTrieNode *CallerNode = getContextFor(DIL); 307 if (!CallerNode) 308 return; 309 310 // Get the context that needs to be promoted 311 LineLocation CallSite(FunctionSamples::getOffset(DIL), 312 DIL->getBaseDiscriminator()); 313 ContextTrieNode *NodeToPromo = 314 CallerNode->getChildContext(CallSite, CalleeName); 315 if (!NodeToPromo) 316 return; 317 318 promoteMergeContextSamplesTree(*NodeToPromo); 319 } 320 321 ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree( 322 ContextTrieNode &NodeToPromo) { 323 // Promote the input node to be directly under root. This can happen 324 // when we decided to not inline a function under context represented 325 // by the input node. The promote and merge is then needed to reflect 326 // the context profile in the base (context-less) profile. 327 FunctionSamples *FromSamples = NodeToPromo.getFunctionSamples(); 328 assert(FromSamples && "Shouldn't promote a context without profile"); 329 LLVM_DEBUG(dbgs() << " Found context tree root to promote: " 330 << FromSamples->getContext() << "\n"); 331 332 StringRef ContextStrToRemove = FromSamples->getContext().getCallingContext(); 333 return promoteMergeContextSamplesTree(NodeToPromo, RootContext, 334 ContextStrToRemove); 335 } 336 337 void SampleContextTracker::dump() { 338 dbgs() << "Context Profile Tree:\n"; 339 std::queue<ContextTrieNode *> NodeQueue; 340 NodeQueue.push(&RootContext); 341 342 while (!NodeQueue.empty()) { 343 ContextTrieNode *Node = NodeQueue.front(); 344 NodeQueue.pop(); 345 Node->dump(); 346 347 for (auto &It : Node->getAllChildContext()) { 348 ContextTrieNode *ChildNode = &It.second; 349 NodeQueue.push(ChildNode); 350 } 351 } 352 } 353 354 ContextTrieNode * 355 SampleContextTracker::getContextFor(const SampleContext &Context) { 356 return getOrCreateContextPath(Context, false); 357 } 358 359 ContextTrieNode * 360 SampleContextTracker::getCalleeContextFor(const DILocation *DIL, 361 StringRef CalleeName) { 362 assert(DIL && "Expect non-null location"); 363 364 // CSSPGO-TODO: need to support indirect callee 365 if (CalleeName.empty()) 366 return nullptr; 367 368 ContextTrieNode *CallContext = getContextFor(DIL); 369 if (!CallContext) 370 return nullptr; 371 372 return CallContext->getChildContext( 373 LineLocation(FunctionSamples::getOffset(DIL), 374 DIL->getBaseDiscriminator()), 375 CalleeName); 376 } 377 378 ContextTrieNode *SampleContextTracker::getContextFor(const DILocation *DIL) { 379 assert(DIL && "Expect non-null location"); 380 SmallVector<std::pair<LineLocation, StringRef>, 10> S; 381 382 // Use C++ linkage name if possible. 383 const DILocation *PrevDIL = DIL; 384 for (DIL = DIL->getInlinedAt(); DIL; DIL = DIL->getInlinedAt()) { 385 StringRef Name = PrevDIL->getScope()->getSubprogram()->getLinkageName(); 386 if (Name.empty()) 387 Name = PrevDIL->getScope()->getSubprogram()->getName(); 388 S.push_back( 389 std::make_pair(LineLocation(FunctionSamples::getOffset(DIL), 390 DIL->getBaseDiscriminator()), Name)); 391 PrevDIL = DIL; 392 } 393 394 // Push root node, note that root node like main may only 395 // a name, but not linkage name. 396 StringRef RootName = PrevDIL->getScope()->getSubprogram()->getLinkageName(); 397 if (RootName.empty()) 398 RootName = PrevDIL->getScope()->getSubprogram()->getName(); 399 S.push_back(std::make_pair(LineLocation(0, 0), RootName)); 400 401 ContextTrieNode *ContextNode = &RootContext; 402 int I = S.size(); 403 while (--I >= 0 && ContextNode) { 404 LineLocation &CallSite = S[I].first; 405 StringRef &CalleeName = S[I].second; 406 ContextNode = ContextNode->getChildContext(CallSite, CalleeName); 407 } 408 409 if (I < 0) 410 return ContextNode; 411 412 return nullptr; 413 } 414 415 ContextTrieNode * 416 SampleContextTracker::getOrCreateContextPath(const SampleContext &Context, 417 bool AllowCreate) { 418 ContextTrieNode *ContextNode = &RootContext; 419 StringRef ContextRemain = Context; 420 StringRef ChildContext; 421 StringRef CalleeName; 422 LineLocation CallSiteLoc(0, 0); 423 424 while (ContextNode && !ContextRemain.empty()) { 425 auto ContextSplit = SampleContext::splitContextString(ContextRemain); 426 ChildContext = ContextSplit.first; 427 ContextRemain = ContextSplit.second; 428 LineLocation NextCallSiteLoc(0, 0); 429 SampleContext::decodeContextString(ChildContext, CalleeName, 430 NextCallSiteLoc); 431 432 // Create child node at parent line/disc location 433 if (AllowCreate) { 434 ContextNode = 435 ContextNode->getOrCreateChildContext(CallSiteLoc, CalleeName); 436 } else { 437 ContextNode = ContextNode->getChildContext(CallSiteLoc, CalleeName); 438 } 439 CallSiteLoc = NextCallSiteLoc; 440 } 441 442 assert((!AllowCreate || ContextNode) && 443 "Node must exist if creation is allowed"); 444 return ContextNode; 445 } 446 447 ContextTrieNode *SampleContextTracker::getTopLevelContextNode(StringRef FName) { 448 return RootContext.getChildContext(LineLocation(0, 0), FName); 449 } 450 451 ContextTrieNode &SampleContextTracker::addTopLevelContextNode(StringRef FName) { 452 assert(!getTopLevelContextNode(FName) && "Node to add must not exist"); 453 return *RootContext.getOrCreateChildContext(LineLocation(0, 0), FName); 454 } 455 456 void SampleContextTracker::mergeContextNode(ContextTrieNode &FromNode, 457 ContextTrieNode &ToNode, 458 StringRef ContextStrToRemove) { 459 FunctionSamples *FromSamples = FromNode.getFunctionSamples(); 460 FunctionSamples *ToSamples = ToNode.getFunctionSamples(); 461 if (FromSamples && ToSamples) { 462 // Merge/duplicate FromSamples into ToSamples 463 ToSamples->merge(*FromSamples); 464 ToSamples->getContext().setState(SyntheticContext); 465 FromSamples->getContext().setState(MergedContext); 466 } else if (FromSamples) { 467 // Transfer FromSamples from FromNode to ToNode 468 ToNode.setFunctionSamples(FromSamples); 469 FromSamples->getContext().setState(SyntheticContext); 470 FromSamples->getContext().promoteOnPath(ContextStrToRemove); 471 FromNode.setFunctionSamples(nullptr); 472 } 473 } 474 475 ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree( 476 ContextTrieNode &FromNode, ContextTrieNode &ToNodeParent, 477 StringRef ContextStrToRemove) { 478 assert(!ContextStrToRemove.empty() && "Context to remove can't be empty"); 479 480 // Ignore call site location if destination is top level under root 481 LineLocation NewCallSiteLoc = LineLocation(0, 0); 482 LineLocation OldCallSiteLoc = FromNode.getCallSiteLoc(); 483 ContextTrieNode &FromNodeParent = *FromNode.getParentContext(); 484 ContextTrieNode *ToNode = nullptr; 485 bool MoveToRoot = (&ToNodeParent == &RootContext); 486 if (!MoveToRoot) { 487 NewCallSiteLoc = OldCallSiteLoc; 488 } 489 490 // Locate destination node, create/move if not existing 491 ToNode = ToNodeParent.getChildContext(NewCallSiteLoc, FromNode.getFuncName()); 492 if (!ToNode) { 493 // Do not delete node to move from its parent here because 494 // caller is iterating over children of that parent node. 495 ToNode = &ToNodeParent.moveToChildContext( 496 NewCallSiteLoc, std::move(FromNode), ContextStrToRemove, false); 497 } else { 498 // Destination node exists, merge samples for the context tree 499 mergeContextNode(FromNode, *ToNode, ContextStrToRemove); 500 LLVM_DEBUG(dbgs() << " Context promoted and merged to: " 501 << ToNode->getFunctionSamples()->getContext() << "\n"); 502 503 // Recursively promote and merge children 504 for (auto &It : FromNode.getAllChildContext()) { 505 ContextTrieNode &FromChildNode = It.second; 506 promoteMergeContextSamplesTree(FromChildNode, *ToNode, 507 ContextStrToRemove); 508 } 509 510 // Remove children once they're all merged 511 FromNode.getAllChildContext().clear(); 512 } 513 514 // For root of subtree, remove itself from old parent too 515 if (MoveToRoot) 516 FromNodeParent.removeChildContext(OldCallSiteLoc, ToNode->getFuncName()); 517 518 return *ToNode; 519 } 520 521 } // namespace llvm 522