1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Instrumentation-based profile-guided optimization 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CodeGenPGO.h" 14 #include "CodeGenFunction.h" 15 #include "CoverageMappingGen.h" 16 #include "clang/AST/RecursiveASTVisitor.h" 17 #include "clang/AST/StmtVisitor.h" 18 #include "llvm/IR/Intrinsics.h" 19 #include "llvm/IR/MDBuilder.h" 20 #include "llvm/Support/CommandLine.h" 21 #include "llvm/Support/Endian.h" 22 #include "llvm/Support/FileSystem.h" 23 #include "llvm/Support/MD5.h" 24 #include <optional> 25 26 static llvm::cl::opt<bool> 27 EnableValueProfiling("enable-value-profiling", 28 llvm::cl::desc("Enable value profiling"), 29 llvm::cl::Hidden, llvm::cl::init(false)); 30 31 using namespace clang; 32 using namespace CodeGen; 33 34 void CodeGenPGO::setFuncName(StringRef Name, 35 llvm::GlobalValue::LinkageTypes Linkage) { 36 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 37 FuncName = llvm::getPGOFuncName( 38 Name, Linkage, CGM.getCodeGenOpts().MainFileName, 39 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version); 40 41 // If we're generating a profile, create a variable for the name. 42 if (CGM.getCodeGenOpts().hasProfileClangInstr()) 43 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName); 44 } 45 46 void CodeGenPGO::setFuncName(llvm::Function *Fn) { 47 setFuncName(Fn->getName(), Fn->getLinkage()); 48 // Create PGOFuncName meta data. 49 llvm::createPGOFuncNameMetadata(*Fn, FuncName); 50 } 51 52 /// The version of the PGO hash algorithm. 
53 enum PGOHashVersion : unsigned { 54 PGO_HASH_V1, 55 PGO_HASH_V2, 56 PGO_HASH_V3, 57 58 // Keep this set to the latest hash version. 59 PGO_HASH_LATEST = PGO_HASH_V3 60 }; 61 62 namespace { 63 /// Stable hasher for PGO region counters. 64 /// 65 /// PGOHash produces a stable hash of a given function's control flow. 66 /// 67 /// Changing the output of this hash will invalidate all previously generated 68 /// profiles -- i.e., don't do it. 69 /// 70 /// \note When this hash does eventually change (years?), we still need to 71 /// support old hashes. We'll need to pull in the version number from the 72 /// profile data format and use the matching hash function. 73 class PGOHash { 74 uint64_t Working; 75 unsigned Count; 76 PGOHashVersion HashVersion; 77 llvm::MD5 MD5; 78 79 static const int NumBitsPerType = 6; 80 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType; 81 static const unsigned TooBig = 1u << NumBitsPerType; 82 83 public: 84 /// Hash values for AST nodes. 85 /// 86 /// Distinct values for AST nodes that have region counters attached. 87 /// 88 /// These values must be stable. All new members must be added at the end, 89 /// and no members should be removed. Changing the enumeration value for an 90 /// AST node will affect the hash of every function that contains that node. 91 enum HashType : unsigned char { 92 None = 0, 93 LabelStmt = 1, 94 WhileStmt, 95 DoStmt, 96 ForStmt, 97 CXXForRangeStmt, 98 ObjCForCollectionStmt, 99 SwitchStmt, 100 CaseStmt, 101 DefaultStmt, 102 IfStmt, 103 CXXTryStmt, 104 CXXCatchStmt, 105 ConditionalOperator, 106 BinaryOperatorLAnd, 107 BinaryOperatorLOr, 108 BinaryConditionalOperator, 109 // The preceding values are available with PGO_HASH_V1. 
110 111 EndOfScope, 112 IfThenBranch, 113 IfElseBranch, 114 GotoStmt, 115 IndirectGotoStmt, 116 BreakStmt, 117 ContinueStmt, 118 ReturnStmt, 119 ThrowExpr, 120 UnaryOperatorLNot, 121 BinaryOperatorLT, 122 BinaryOperatorGT, 123 BinaryOperatorLE, 124 BinaryOperatorGE, 125 BinaryOperatorEQ, 126 BinaryOperatorNE, 127 // The preceding values are available since PGO_HASH_V2. 128 129 // Keep this last. It's for the static assert that follows. 130 LastHashType 131 }; 132 static_assert(LastHashType <= TooBig, "Too many types in HashType"); 133 134 PGOHash(PGOHashVersion HashVersion) 135 : Working(0), Count(0), HashVersion(HashVersion) {} 136 void combine(HashType Type); 137 uint64_t finalize(); 138 PGOHashVersion getHashVersion() const { return HashVersion; } 139 }; 140 const int PGOHash::NumBitsPerType; 141 const unsigned PGOHash::NumTypesPerWord; 142 const unsigned PGOHash::TooBig; 143 144 /// Get the PGO hash version used in the given indexed profile. 145 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader, 146 CodeGenModule &CGM) { 147 if (PGOReader->getVersion() <= 4) 148 return PGO_HASH_V1; 149 if (PGOReader->getVersion() <= 5) 150 return PGO_HASH_V2; 151 return PGO_HASH_V3; 152 } 153 154 /// A RecursiveASTVisitor that fills a map of statements to PGO counters. 155 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> { 156 using Base = RecursiveASTVisitor<MapRegionCounters>; 157 158 /// The next counter value to assign. 159 unsigned NextCounter; 160 /// The function hash. 161 PGOHash Hash; 162 /// The map of statements to counters. 163 llvm::DenseMap<const Stmt *, unsigned> &CounterMap; 164 /// The profile version. 
165 uint64_t ProfileVersion; 166 167 MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion, 168 llvm::DenseMap<const Stmt *, unsigned> &CounterMap) 169 : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap), 170 ProfileVersion(ProfileVersion) {} 171 172 // Blocks and lambdas are handled as separate functions, so we need not 173 // traverse them in the parent context. 174 bool TraverseBlockExpr(BlockExpr *BE) { return true; } 175 bool TraverseLambdaExpr(LambdaExpr *LE) { 176 // Traverse the captures, but not the body. 177 for (auto C : zip(LE->captures(), LE->capture_inits())) 178 TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C)); 179 return true; 180 } 181 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; } 182 183 bool VisitDecl(const Decl *D) { 184 switch (D->getKind()) { 185 default: 186 break; 187 case Decl::Function: 188 case Decl::CXXMethod: 189 case Decl::CXXConstructor: 190 case Decl::CXXDestructor: 191 case Decl::CXXConversion: 192 case Decl::ObjCMethod: 193 case Decl::Block: 194 case Decl::Captured: 195 CounterMap[D->getBody()] = NextCounter++; 196 break; 197 } 198 return true; 199 } 200 201 /// If \p S gets a fresh counter, update the counter mappings. Return the 202 /// V1 hash of \p S. 203 PGOHash::HashType updateCounterMappings(Stmt *S) { 204 auto Type = getHashType(PGO_HASH_V1, S); 205 if (Type != PGOHash::None) 206 CounterMap[S] = NextCounter++; 207 return Type; 208 } 209 210 /// The RHS of all logical operators gets a fresh counter in order to count 211 /// how many times the RHS evaluates to true or false, depending on the 212 /// semantics of the operator. This is only valid for ">= v7" of the profile 213 /// version so that we facilitate backward compatibility. 
214 bool VisitBinaryOperator(BinaryOperator *S) { 215 if (ProfileVersion >= llvm::IndexedInstrProf::Version7) 216 if (S->isLogicalOp() && 217 CodeGenFunction::isInstrumentedCondition(S->getRHS())) 218 CounterMap[S->getRHS()] = NextCounter++; 219 return Base::VisitBinaryOperator(S); 220 } 221 222 /// Include \p S in the function hash. 223 bool VisitStmt(Stmt *S) { 224 auto Type = updateCounterMappings(S); 225 if (Hash.getHashVersion() != PGO_HASH_V1) 226 Type = getHashType(Hash.getHashVersion(), S); 227 if (Type != PGOHash::None) 228 Hash.combine(Type); 229 return true; 230 } 231 232 bool TraverseIfStmt(IfStmt *If) { 233 // If we used the V1 hash, use the default traversal. 234 if (Hash.getHashVersion() == PGO_HASH_V1) 235 return Base::TraverseIfStmt(If); 236 237 // Otherwise, keep track of which branch we're in while traversing. 238 VisitStmt(If); 239 for (Stmt *CS : If->children()) { 240 if (!CS) 241 continue; 242 if (CS == If->getThen()) 243 Hash.combine(PGOHash::IfThenBranch); 244 else if (CS == If->getElse()) 245 Hash.combine(PGOHash::IfElseBranch); 246 TraverseStmt(CS); 247 } 248 Hash.combine(PGOHash::EndOfScope); 249 return true; 250 } 251 252 // If the statement type \p N is nestable, and its nesting impacts profile 253 // stability, define a custom traversal which tracks the end of the statement 254 // in the hash (provided we're not using the V1 hash). 255 #define DEFINE_NESTABLE_TRAVERSAL(N) \ 256 bool Traverse##N(N *S) { \ 257 Base::Traverse##N(S); \ 258 if (Hash.getHashVersion() != PGO_HASH_V1) \ 259 Hash.combine(PGOHash::EndOfScope); \ 260 return true; \ 261 } 262 263 DEFINE_NESTABLE_TRAVERSAL(WhileStmt) 264 DEFINE_NESTABLE_TRAVERSAL(DoStmt) 265 DEFINE_NESTABLE_TRAVERSAL(ForStmt) 266 DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt) 267 DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt) 268 DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt) 269 DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt) 270 271 /// Get version \p HashVersion of the PGO hash for \p S. 
272 PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) { 273 switch (S->getStmtClass()) { 274 default: 275 break; 276 case Stmt::LabelStmtClass: 277 return PGOHash::LabelStmt; 278 case Stmt::WhileStmtClass: 279 return PGOHash::WhileStmt; 280 case Stmt::DoStmtClass: 281 return PGOHash::DoStmt; 282 case Stmt::ForStmtClass: 283 return PGOHash::ForStmt; 284 case Stmt::CXXForRangeStmtClass: 285 return PGOHash::CXXForRangeStmt; 286 case Stmt::ObjCForCollectionStmtClass: 287 return PGOHash::ObjCForCollectionStmt; 288 case Stmt::SwitchStmtClass: 289 return PGOHash::SwitchStmt; 290 case Stmt::CaseStmtClass: 291 return PGOHash::CaseStmt; 292 case Stmt::DefaultStmtClass: 293 return PGOHash::DefaultStmt; 294 case Stmt::IfStmtClass: 295 return PGOHash::IfStmt; 296 case Stmt::CXXTryStmtClass: 297 return PGOHash::CXXTryStmt; 298 case Stmt::CXXCatchStmtClass: 299 return PGOHash::CXXCatchStmt; 300 case Stmt::ConditionalOperatorClass: 301 return PGOHash::ConditionalOperator; 302 case Stmt::BinaryConditionalOperatorClass: 303 return PGOHash::BinaryConditionalOperator; 304 case Stmt::BinaryOperatorClass: { 305 const BinaryOperator *BO = cast<BinaryOperator>(S); 306 if (BO->getOpcode() == BO_LAnd) 307 return PGOHash::BinaryOperatorLAnd; 308 if (BO->getOpcode() == BO_LOr) 309 return PGOHash::BinaryOperatorLOr; 310 if (HashVersion >= PGO_HASH_V2) { 311 switch (BO->getOpcode()) { 312 default: 313 break; 314 case BO_LT: 315 return PGOHash::BinaryOperatorLT; 316 case BO_GT: 317 return PGOHash::BinaryOperatorGT; 318 case BO_LE: 319 return PGOHash::BinaryOperatorLE; 320 case BO_GE: 321 return PGOHash::BinaryOperatorGE; 322 case BO_EQ: 323 return PGOHash::BinaryOperatorEQ; 324 case BO_NE: 325 return PGOHash::BinaryOperatorNE; 326 } 327 } 328 break; 329 } 330 } 331 332 if (HashVersion >= PGO_HASH_V2) { 333 switch (S->getStmtClass()) { 334 default: 335 break; 336 case Stmt::GotoStmtClass: 337 return PGOHash::GotoStmt; 338 case Stmt::IndirectGotoStmtClass: 339 return 
PGOHash::IndirectGotoStmt; 340 case Stmt::BreakStmtClass: 341 return PGOHash::BreakStmt; 342 case Stmt::ContinueStmtClass: 343 return PGOHash::ContinueStmt; 344 case Stmt::ReturnStmtClass: 345 return PGOHash::ReturnStmt; 346 case Stmt::CXXThrowExprClass: 347 return PGOHash::ThrowExpr; 348 case Stmt::UnaryOperatorClass: { 349 const UnaryOperator *UO = cast<UnaryOperator>(S); 350 if (UO->getOpcode() == UO_LNot) 351 return PGOHash::UnaryOperatorLNot; 352 break; 353 } 354 } 355 } 356 357 return PGOHash::None; 358 } 359 }; 360 361 /// A StmtVisitor that propagates the raw counts through the AST and 362 /// records the count at statements where the value may change. 363 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { 364 /// PGO state. 365 CodeGenPGO &PGO; 366 367 /// A flag that is set when the current count should be recorded on the 368 /// next statement, such as at the exit of a loop. 369 bool RecordNextStmtCount; 370 371 /// The count at the current location in the traversal. 372 uint64_t CurrentCount; 373 374 /// The map of statements to count values. 375 llvm::DenseMap<const Stmt *, uint64_t> &CountMap; 376 377 /// BreakContinueStack - Keep counts of breaks and continues inside loops. 378 struct BreakContinue { 379 uint64_t BreakCount = 0; 380 uint64_t ContinueCount = 0; 381 BreakContinue() = default; 382 }; 383 SmallVector<BreakContinue, 8> BreakContinueStack; 384 385 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap, 386 CodeGenPGO &PGO) 387 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {} 388 389 void RecordStmtCount(const Stmt *S) { 390 if (RecordNextStmtCount) { 391 CountMap[S] = CurrentCount; 392 RecordNextStmtCount = false; 393 } 394 } 395 396 /// Set and return the current count. 
397 uint64_t setCount(uint64_t Count) { 398 CurrentCount = Count; 399 return Count; 400 } 401 402 void VisitStmt(const Stmt *S) { 403 RecordStmtCount(S); 404 for (const Stmt *Child : S->children()) 405 if (Child) 406 this->Visit(Child); 407 } 408 409 void VisitFunctionDecl(const FunctionDecl *D) { 410 // Counter tracks entry to the function body. 411 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 412 CountMap[D->getBody()] = BodyCount; 413 Visit(D->getBody()); 414 } 415 416 // Skip lambda expressions. We visit these as FunctionDecls when we're 417 // generating them and aren't interested in the body when generating a 418 // parent context. 419 void VisitLambdaExpr(const LambdaExpr *LE) {} 420 421 void VisitCapturedDecl(const CapturedDecl *D) { 422 // Counter tracks entry to the capture body. 423 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 424 CountMap[D->getBody()] = BodyCount; 425 Visit(D->getBody()); 426 } 427 428 void VisitObjCMethodDecl(const ObjCMethodDecl *D) { 429 // Counter tracks entry to the method body. 430 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 431 CountMap[D->getBody()] = BodyCount; 432 Visit(D->getBody()); 433 } 434 435 void VisitBlockDecl(const BlockDecl *D) { 436 // Counter tracks entry to the block body. 
437 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 438 CountMap[D->getBody()] = BodyCount; 439 Visit(D->getBody()); 440 } 441 442 void VisitReturnStmt(const ReturnStmt *S) { 443 RecordStmtCount(S); 444 if (S->getRetValue()) 445 Visit(S->getRetValue()); 446 CurrentCount = 0; 447 RecordNextStmtCount = true; 448 } 449 450 void VisitCXXThrowExpr(const CXXThrowExpr *E) { 451 RecordStmtCount(E); 452 if (E->getSubExpr()) 453 Visit(E->getSubExpr()); 454 CurrentCount = 0; 455 RecordNextStmtCount = true; 456 } 457 458 void VisitGotoStmt(const GotoStmt *S) { 459 RecordStmtCount(S); 460 CurrentCount = 0; 461 RecordNextStmtCount = true; 462 } 463 464 void VisitLabelStmt(const LabelStmt *S) { 465 RecordNextStmtCount = false; 466 // Counter tracks the block following the label. 467 uint64_t BlockCount = setCount(PGO.getRegionCount(S)); 468 CountMap[S] = BlockCount; 469 Visit(S->getSubStmt()); 470 } 471 472 void VisitBreakStmt(const BreakStmt *S) { 473 RecordStmtCount(S); 474 assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); 475 BreakContinueStack.back().BreakCount += CurrentCount; 476 CurrentCount = 0; 477 RecordNextStmtCount = true; 478 } 479 480 void VisitContinueStmt(const ContinueStmt *S) { 481 RecordStmtCount(S); 482 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); 483 BreakContinueStack.back().ContinueCount += CurrentCount; 484 CurrentCount = 0; 485 RecordNextStmtCount = true; 486 } 487 488 void VisitWhileStmt(const WhileStmt *S) { 489 RecordStmtCount(S); 490 uint64_t ParentCount = CurrentCount; 491 492 BreakContinueStack.push_back(BreakContinue()); 493 // Visit the body region first so the break/continue adjustments can be 494 // included when visiting the condition. 495 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 496 CountMap[S->getBody()] = CurrentCount; 497 Visit(S->getBody()); 498 uint64_t BackedgeCount = CurrentCount; 499 500 // ...then go back and propagate counts through the condition. 
The count 501 // at the start of the condition is the sum of the incoming edges, 502 // the backedge from the end of the loop body, and the edges from 503 // continue statements. 504 BreakContinue BC = BreakContinueStack.pop_back_val(); 505 uint64_t CondCount = 506 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 507 CountMap[S->getCond()] = CondCount; 508 Visit(S->getCond()); 509 setCount(BC.BreakCount + CondCount - BodyCount); 510 RecordNextStmtCount = true; 511 } 512 513 void VisitDoStmt(const DoStmt *S) { 514 RecordStmtCount(S); 515 uint64_t LoopCount = PGO.getRegionCount(S); 516 517 BreakContinueStack.push_back(BreakContinue()); 518 // The count doesn't include the fallthrough from the parent scope. Add it. 519 uint64_t BodyCount = setCount(LoopCount + CurrentCount); 520 CountMap[S->getBody()] = BodyCount; 521 Visit(S->getBody()); 522 uint64_t BackedgeCount = CurrentCount; 523 524 BreakContinue BC = BreakContinueStack.pop_back_val(); 525 // The count at the start of the condition is equal to the count at the 526 // end of the body, plus any continues. 527 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount); 528 CountMap[S->getCond()] = CondCount; 529 Visit(S->getCond()); 530 setCount(BC.BreakCount + CondCount - LoopCount); 531 RecordNextStmtCount = true; 532 } 533 534 void VisitForStmt(const ForStmt *S) { 535 RecordStmtCount(S); 536 if (S->getInit()) 537 Visit(S->getInit()); 538 539 uint64_t ParentCount = CurrentCount; 540 541 BreakContinueStack.push_back(BreakContinue()); 542 // Visit the body region first. (This is basically the same as a while 543 // loop; see further comments in VisitWhileStmt.) 
544 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 545 CountMap[S->getBody()] = BodyCount; 546 Visit(S->getBody()); 547 uint64_t BackedgeCount = CurrentCount; 548 BreakContinue BC = BreakContinueStack.pop_back_val(); 549 550 // The increment is essentially part of the body but it needs to include 551 // the count for all the continue statements. 552 if (S->getInc()) { 553 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); 554 CountMap[S->getInc()] = IncCount; 555 Visit(S->getInc()); 556 } 557 558 // ...then go back and propagate counts through the condition. 559 uint64_t CondCount = 560 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 561 if (S->getCond()) { 562 CountMap[S->getCond()] = CondCount; 563 Visit(S->getCond()); 564 } 565 setCount(BC.BreakCount + CondCount - BodyCount); 566 RecordNextStmtCount = true; 567 } 568 569 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { 570 RecordStmtCount(S); 571 if (S->getInit()) 572 Visit(S->getInit()); 573 Visit(S->getLoopVarStmt()); 574 Visit(S->getRangeStmt()); 575 Visit(S->getBeginStmt()); 576 Visit(S->getEndStmt()); 577 578 uint64_t ParentCount = CurrentCount; 579 BreakContinueStack.push_back(BreakContinue()); 580 // Visit the body region first. (This is basically the same as a while 581 // loop; see further comments in VisitWhileStmt.) 582 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 583 CountMap[S->getBody()] = BodyCount; 584 Visit(S->getBody()); 585 uint64_t BackedgeCount = CurrentCount; 586 BreakContinue BC = BreakContinueStack.pop_back_val(); 587 588 // The increment is essentially part of the body but it needs to include 589 // the count for all the continue statements. 590 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); 591 CountMap[S->getInc()] = IncCount; 592 Visit(S->getInc()); 593 594 // ...then go back and propagate counts through the condition. 
595 uint64_t CondCount = 596 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 597 CountMap[S->getCond()] = CondCount; 598 Visit(S->getCond()); 599 setCount(BC.BreakCount + CondCount - BodyCount); 600 RecordNextStmtCount = true; 601 } 602 603 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { 604 RecordStmtCount(S); 605 Visit(S->getElement()); 606 uint64_t ParentCount = CurrentCount; 607 BreakContinueStack.push_back(BreakContinue()); 608 // Counter tracks the body of the loop. 609 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 610 CountMap[S->getBody()] = BodyCount; 611 Visit(S->getBody()); 612 uint64_t BackedgeCount = CurrentCount; 613 BreakContinue BC = BreakContinueStack.pop_back_val(); 614 615 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount - 616 BodyCount); 617 RecordNextStmtCount = true; 618 } 619 620 void VisitSwitchStmt(const SwitchStmt *S) { 621 RecordStmtCount(S); 622 if (S->getInit()) 623 Visit(S->getInit()); 624 Visit(S->getCond()); 625 CurrentCount = 0; 626 BreakContinueStack.push_back(BreakContinue()); 627 Visit(S->getBody()); 628 // If the switch is inside a loop, add the continue counts. 629 BreakContinue BC = BreakContinueStack.pop_back_val(); 630 if (!BreakContinueStack.empty()) 631 BreakContinueStack.back().ContinueCount += BC.ContinueCount; 632 // Counter tracks the exit block of the switch. 633 setCount(PGO.getRegionCount(S)); 634 RecordNextStmtCount = true; 635 } 636 637 void VisitSwitchCase(const SwitchCase *S) { 638 RecordNextStmtCount = false; 639 // Counter for this particular case. This counts only jumps from the 640 // switch header and does not include fallthrough from the case before 641 // this one. 642 uint64_t CaseCount = PGO.getRegionCount(S); 643 setCount(CurrentCount + CaseCount); 644 // We need the count without fallthrough in the mapping, so it's more useful 645 // for branch probabilities. 
646 CountMap[S] = CaseCount; 647 RecordNextStmtCount = true; 648 Visit(S->getSubStmt()); 649 } 650 651 void VisitIfStmt(const IfStmt *S) { 652 RecordStmtCount(S); 653 654 if (S->isConsteval()) { 655 const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse(); 656 if (Stm) 657 Visit(Stm); 658 return; 659 } 660 661 uint64_t ParentCount = CurrentCount; 662 if (S->getInit()) 663 Visit(S->getInit()); 664 Visit(S->getCond()); 665 666 // Counter tracks the "then" part of an if statement. The count for 667 // the "else" part, if it exists, will be calculated from this counter. 668 uint64_t ThenCount = setCount(PGO.getRegionCount(S)); 669 CountMap[S->getThen()] = ThenCount; 670 Visit(S->getThen()); 671 uint64_t OutCount = CurrentCount; 672 673 uint64_t ElseCount = ParentCount - ThenCount; 674 if (S->getElse()) { 675 setCount(ElseCount); 676 CountMap[S->getElse()] = ElseCount; 677 Visit(S->getElse()); 678 OutCount += CurrentCount; 679 } else 680 OutCount += ElseCount; 681 setCount(OutCount); 682 RecordNextStmtCount = true; 683 } 684 685 void VisitCXXTryStmt(const CXXTryStmt *S) { 686 RecordStmtCount(S); 687 Visit(S->getTryBlock()); 688 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) 689 Visit(S->getHandler(I)); 690 // Counter tracks the continuation block of the try statement. 691 setCount(PGO.getRegionCount(S)); 692 RecordNextStmtCount = true; 693 } 694 695 void VisitCXXCatchStmt(const CXXCatchStmt *S) { 696 RecordNextStmtCount = false; 697 // Counter tracks the catch statement's handler block. 698 uint64_t CatchCount = setCount(PGO.getRegionCount(S)); 699 CountMap[S] = CatchCount; 700 Visit(S->getHandlerBlock()); 701 } 702 703 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) { 704 RecordStmtCount(E); 705 uint64_t ParentCount = CurrentCount; 706 Visit(E->getCond()); 707 708 // Counter tracks the "true" part of a conditional operator. The 709 // count in the "false" part will be calculated from this counter. 
710 uint64_t TrueCount = setCount(PGO.getRegionCount(E)); 711 CountMap[E->getTrueExpr()] = TrueCount; 712 Visit(E->getTrueExpr()); 713 uint64_t OutCount = CurrentCount; 714 715 uint64_t FalseCount = setCount(ParentCount - TrueCount); 716 CountMap[E->getFalseExpr()] = FalseCount; 717 Visit(E->getFalseExpr()); 718 OutCount += CurrentCount; 719 720 setCount(OutCount); 721 RecordNextStmtCount = true; 722 } 723 724 void VisitBinLAnd(const BinaryOperator *E) { 725 RecordStmtCount(E); 726 uint64_t ParentCount = CurrentCount; 727 Visit(E->getLHS()); 728 // Counter tracks the right hand side of a logical and operator. 729 uint64_t RHSCount = setCount(PGO.getRegionCount(E)); 730 CountMap[E->getRHS()] = RHSCount; 731 Visit(E->getRHS()); 732 setCount(ParentCount + RHSCount - CurrentCount); 733 RecordNextStmtCount = true; 734 } 735 736 void VisitBinLOr(const BinaryOperator *E) { 737 RecordStmtCount(E); 738 uint64_t ParentCount = CurrentCount; 739 Visit(E->getLHS()); 740 // Counter tracks the right hand side of a logical or operator. 741 uint64_t RHSCount = setCount(PGO.getRegionCount(E)); 742 CountMap[E->getRHS()] = RHSCount; 743 Visit(E->getRHS()); 744 setCount(ParentCount + RHSCount - CurrentCount); 745 RecordNextStmtCount = true; 746 } 747 }; 748 } // end anonymous namespace 749 750 void PGOHash::combine(HashType Type) { 751 // Check that we never combine 0 and only have six bits. 752 assert(Type && "Hash is invalid: unexpected type 0"); 753 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types"); 754 755 // Pass through MD5 if enough work has built up. 756 if (Count && Count % NumTypesPerWord == 0) { 757 using namespace llvm::support; 758 uint64_t Swapped = 759 endian::byte_swap<uint64_t, llvm::endianness::little>(Working); 760 MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped))); 761 Working = 0; 762 } 763 764 // Accumulate the current type. 
765 ++Count; 766 Working = Working << NumBitsPerType | Type; 767 } 768 769 uint64_t PGOHash::finalize() { 770 // Use Working as the hash directly if we never used MD5. 771 if (Count <= NumTypesPerWord) 772 // No need to byte swap here, since none of the math was endian-dependent. 773 // This number will be byte-swapped as required on endianness transitions, 774 // so we will see the same value on the other side. 775 return Working; 776 777 // Check for remaining work in Working. 778 if (Working) { 779 // Keep the buggy behavior from v1 and v2 for backward-compatibility. This 780 // is buggy because it converts a uint64_t into an array of uint8_t. 781 if (HashVersion < PGO_HASH_V3) { 782 MD5.update({(uint8_t)Working}); 783 } else { 784 using namespace llvm::support; 785 uint64_t Swapped = 786 endian::byte_swap<uint64_t, llvm::endianness::little>(Working); 787 MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped))); 788 } 789 } 790 791 // Finalize the MD5 and return the hash. 792 llvm::MD5::MD5Result Result; 793 MD5.final(Result); 794 return Result.low(); 795 } 796 797 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) { 798 const Decl *D = GD.getDecl(); 799 if (!D->hasBody()) 800 return; 801 802 // Skip CUDA/HIP kernel launch stub functions. 803 if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice && 804 D->hasAttr<CUDAGlobalAttr>()) 805 return; 806 807 bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr(); 808 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 809 if (!InstrumentRegions && !PGOReader) 810 return; 811 if (D->isImplicit()) 812 return; 813 // Constructors and destructors may be represented by several functions in IR. 814 // If so, instrument only base variant, others are implemented by delegation 815 // to the base one, it would be counted twice otherwise. 
816 if (CGM.getTarget().getCXXABI().hasConstructorVariants()) { 817 if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D)) 818 if (GD.getCtorType() != Ctor_Base && 819 CodeGenFunction::IsConstructorDelegationValid(CCD)) 820 return; 821 } 822 if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base) 823 return; 824 825 CGM.ClearUnusedCoverageMapping(D); 826 if (Fn->hasFnAttribute(llvm::Attribute::NoProfile)) 827 return; 828 if (Fn->hasFnAttribute(llvm::Attribute::SkipProfile)) 829 return; 830 831 setFuncName(Fn); 832 833 mapRegionCounters(D); 834 if (CGM.getCodeGenOpts().CoverageMapping) 835 emitCounterRegionMapping(D); 836 if (PGOReader) { 837 SourceManager &SM = CGM.getContext().getSourceManager(); 838 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation())); 839 computeRegionCounts(D); 840 applyFunctionAttributes(PGOReader, Fn); 841 } 842 } 843 844 void CodeGenPGO::mapRegionCounters(const Decl *D) { 845 // Use the latest hash version when inserting instrumentation, but use the 846 // version in the indexed profile if we're reading PGO data. 
// Hash/format versions default to the latest; a loaded profile overrides them.
  PGOHashVersion HashVersion = PGO_HASH_LATEST;
  uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
  if (auto *PGOReader = CGM.getPGOReader()) {
    HashVersion = getPGOHashVersion(PGOReader, CGM);
    ProfileVersion = PGOReader->getVersion();
  }

  // Assign a counter index to every interesting statement of D and hash the
  // control-flow shape while doing so.
  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
  MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
  // At least one counter (the entry counter) must have been handed out.
  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
  NumRegionCounters = Walker.NextCounter;
  FunctionHash = Walker.Hash.finalize();
}

/// Decide whether coverage mapping should be skipped for \p D.
///
/// Returns true when D has no body, when the CUDA host/device attributes say
/// it does not belong to the current compilation side, or when its body
/// starts inside a system header.
bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
  if (!D->getBody())
    return true;

  // Skip host-only functions in the CUDA device compilation and device-only
  // functions in the host compilation. Just roughly filter them out based on
  // the function attributes. If there are effectively host-only or device-only
  // ones, their coverage mapping may still be generated.
  // Note that the host compilation also skips __global__ kernels.
  if (CGM.getLangOpts().CUDA &&
      ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
        !D->hasAttr<CUDAGlobalAttr>()) ||
       (!CGM.getLangOpts().CUDAIsDevice &&
        (D->hasAttr<CUDAGlobalAttr>() ||
         (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
    return true;

  // Don't map the functions in system headers.
  const auto &SM = CGM.getContext().getSourceManager();
  auto Loc = D->getBody()->getBeginLoc();
  return SM.isInSystemHeader(Loc);
}

/// Emit the coverage mapping for \p D, using the counters assigned in
/// RegionCounterMap, and register it with the module's coverage mapping.
/// Nothing is recorded when D is skipped or the produced mapping is empty.
void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts(), RegionCounterMap.get());
  MappingGen.emitCounterMapping(D, OS);
  // Make sure everything written to OS has reached CoverageMapping.
  OS.flush();

  if (CoverageMapping.empty())
    return;

  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping);
}

/// Emit a counter-less coverage mapping for \p D. CoverageMappingGen is built
/// without a RegionCounterMap here.
///
/// The PGO function name is only set, from \p Name and \p Linkage, once a
/// non-empty mapping has been produced.
void
CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
                                    llvm::GlobalValue::LinkageTypes Linkage) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts());
  MappingGen.emitEmptyMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  setFuncName(Name, Linkage);
  // NOTE(review): the trailing `false` presumably marks this record as
  // belonging to an unused function; confirm against addFunctionMappingRecord.
  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
}

/// Fill StmtCountMap by running ComputeRegionCounts over \p D.
///
/// Only function, Objective-C method, block and captured decls are visited.
/// For any other kind of decl, the freshly reset map stays empty.
void CodeGenPGO::computeRegionCounts(const Decl *D) {
  StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
  ComputeRegionCounts Walker(*StmtCountMap, *this);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.VisitFunctionDecl(FD);
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.VisitObjCMethodDecl(MD);
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.VisitBlockDecl(BD);
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
}

/// Set \p Fn's entry count from the loaded profile counts.
/// Does nothing when no region counts are available. \p PGOReader is not
/// consulted in this body.
void
CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
                                    llvm::Function *Fn) {
  if (!haveRegionCounts())
    return;

  // NOTE(review): a null statement presumably resolves to the function-entry
  // counter; confirm in getRegionCount.
  uint64_t FunctionCount = getRegionCount(nullptr);
  Fn->setEntryCount(FunctionCount);
}

/// Emit an instrprof increment for the counter assigned to \p S.
///
/// Without \p StepV, llvm.instrprof.increment is called with the first four
/// arguments. With \p StepV, llvm.instrprof.increment.step is called with
/// StepV as a fifth argument. Nothing is emitted without a counter map or an
/// active insertion block.
void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
                                      llvm::Value *StepV) {
  if (!RegionCounterMap || !Builder.GetInsertBlock())
    return;

  // NOTE(review): DenseMap::operator[] inserts a 0 entry when S was never
  // mapped, so an unmapped statement silently increments counter 0.
  unsigned Counter = (*RegionCounterMap)[S];

  llvm::Value *Args[] = {FuncNameVar,
                         Builder.getInt64(FunctionHash),
                         Builder.getInt32(NumRegionCounters),
                         Builder.getInt32(Counter), StepV};
  if (!StepV)
    // StepV is null here, so only the first four entries are passed.
    Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
                       ArrayRef(Args, 4));
  else
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
        ArrayRef(Args));
}

/// When clang instrumentation is enabled, record the value of
/// -enable-value-profiling as the "EnableValueProfiling" module flag, using
/// the Warning merge behavior.
void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
  if (CGM.getCodeGenOpts().hasProfileClangInstr())
    M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
                    uint32_t(EnableValueProfiling));
}

// This method either inserts a call to the profile run-time during
// instrumentation or puts profile data into metadata for PGO use.
// It does nothing unless -enable-value-profiling is given. It also ignores
// null or constant values, and calls made without an active insertion block.
983 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind, 984 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) { 985 986 if (!EnableValueProfiling) 987 return; 988 989 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock()) 990 return; 991 992 if (isa<llvm::Constant>(ValuePtr)) 993 return; 994 995 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr(); 996 if (InstrumentValueSites && RegionCounterMap) { 997 auto BuilderInsertPoint = Builder.saveIP(); 998 Builder.SetInsertPoint(ValueSite); 999 llvm::Value *Args[5] = { 1000 FuncNameVar, 1001 Builder.getInt64(FunctionHash), 1002 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()), 1003 Builder.getInt32(ValueKind), 1004 Builder.getInt32(NumValueSites[ValueKind]++) 1005 }; 1006 Builder.CreateCall( 1007 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args); 1008 Builder.restoreIP(BuilderInsertPoint); 1009 return; 1010 } 1011 1012 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 1013 if (PGOReader && haveRegionCounts()) { 1014 // We record the top most called three functions at each call site. 1015 // Profile metadata contains "VP" string identifying this metadata 1016 // as value profiling data, then a uint32_t value for the value profiling 1017 // kind, a uint64_t value for the total number of times the call is 1018 // executed, followed by the function hash and execution count (uint64_t) 1019 // pairs for each function. 
1020 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind)) 1021 return; 1022 1023 llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord, 1024 (llvm::InstrProfValueKind)ValueKind, 1025 NumValueSites[ValueKind]); 1026 1027 NumValueSites[ValueKind]++; 1028 } 1029 } 1030 1031 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader, 1032 bool IsInMainFile) { 1033 CGM.getPGOStats().addVisited(IsInMainFile); 1034 RegionCounts.clear(); 1035 llvm::Expected<llvm::InstrProfRecord> RecordExpected = 1036 PGOReader->getInstrProfRecord(FuncName, FunctionHash); 1037 if (auto E = RecordExpected.takeError()) { 1038 auto IPE = std::get<0>(llvm::InstrProfError::take(std::move(E))); 1039 if (IPE == llvm::instrprof_error::unknown_function) 1040 CGM.getPGOStats().addMissing(IsInMainFile); 1041 else if (IPE == llvm::instrprof_error::hash_mismatch) 1042 CGM.getPGOStats().addMismatched(IsInMainFile); 1043 else if (IPE == llvm::instrprof_error::malformed) 1044 // TODO: Consider a more specific warning for this case. 1045 CGM.getPGOStats().addMismatched(IsInMainFile); 1046 return; 1047 } 1048 ProfRecord = 1049 std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get())); 1050 RegionCounts = ProfRecord->Counts; 1051 } 1052 1053 /// Calculate what to divide by to scale weights. 1054 /// 1055 /// Given the maximum weight, calculate a divisor that will scale all the 1056 /// weights to strictly less than UINT32_MAX. 1057 static uint64_t calculateWeightScale(uint64_t MaxWeight) { 1058 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1; 1059 } 1060 1061 /// Scale an individual branch weight (and add 1). 1062 /// 1063 /// Scale a 64-bit weight down to 32-bits using \c Scale. 1064 /// 1065 /// According to Laplace's Rule of Succession, it is better to compute the 1066 /// weight based on the count plus 1, so universally add 1 to the value. 
1067 /// 1068 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no 1069 /// greater than \c Weight. 1070 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) { 1071 assert(Scale && "scale by 0?"); 1072 uint64_t Scaled = Weight / Scale + 1; 1073 assert(Scaled <= UINT32_MAX && "overflow 32-bits"); 1074 return Scaled; 1075 } 1076 1077 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount, 1078 uint64_t FalseCount) const { 1079 // Check for empty weights. 1080 if (!TrueCount && !FalseCount) 1081 return nullptr; 1082 1083 // Calculate how to scale down to 32-bits. 1084 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount)); 1085 1086 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1087 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale), 1088 scaleBranchWeight(FalseCount, Scale)); 1089 } 1090 1091 llvm::MDNode * 1092 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const { 1093 // We need at least two elements to create meaningful weights. 1094 if (Weights.size() < 2) 1095 return nullptr; 1096 1097 // Check for empty weights. 1098 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end()); 1099 if (MaxWeight == 0) 1100 return nullptr; 1101 1102 // Calculate how to scale down to 32-bits. 
1103 uint64_t Scale = calculateWeightScale(MaxWeight); 1104 1105 SmallVector<uint32_t, 16> ScaledWeights; 1106 ScaledWeights.reserve(Weights.size()); 1107 for (uint64_t W : Weights) 1108 ScaledWeights.push_back(scaleBranchWeight(W, Scale)); 1109 1110 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1111 return MDHelper.createBranchWeights(ScaledWeights); 1112 } 1113 1114 llvm::MDNode * 1115 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond, 1116 uint64_t LoopCount) const { 1117 if (!PGO.haveRegionCounts()) 1118 return nullptr; 1119 std::optional<uint64_t> CondCount = PGO.getStmtCount(Cond); 1120 if (!CondCount || *CondCount == 0) 1121 return nullptr; 1122 return createProfileWeights(LoopCount, 1123 std::max(*CondCount, LoopCount) - LoopCount); 1124 } 1125