xref: /netbsd-src/external/apache2/llvm/dist/clang/lib/CodeGen/CodeGenPGO.cpp (revision e038c9c4676b0f19b1b7dd08a940c6ed64a6d5ae)
17330f729Sjoerg //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
27330f729Sjoerg //
37330f729Sjoerg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47330f729Sjoerg // See https://llvm.org/LICENSE.txt for license information.
57330f729Sjoerg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67330f729Sjoerg //
77330f729Sjoerg //===----------------------------------------------------------------------===//
87330f729Sjoerg //
97330f729Sjoerg // Instrumentation-based profile-guided optimization
107330f729Sjoerg //
117330f729Sjoerg //===----------------------------------------------------------------------===//
127330f729Sjoerg 
137330f729Sjoerg #include "CodeGenPGO.h"
147330f729Sjoerg #include "CodeGenFunction.h"
157330f729Sjoerg #include "CoverageMappingGen.h"
167330f729Sjoerg #include "clang/AST/RecursiveASTVisitor.h"
177330f729Sjoerg #include "clang/AST/StmtVisitor.h"
187330f729Sjoerg #include "llvm/IR/Intrinsics.h"
197330f729Sjoerg #include "llvm/IR/MDBuilder.h"
20*e038c9c4Sjoerg #include "llvm/Support/CommandLine.h"
217330f729Sjoerg #include "llvm/Support/Endian.h"
227330f729Sjoerg #include "llvm/Support/FileSystem.h"
237330f729Sjoerg #include "llvm/Support/MD5.h"
247330f729Sjoerg 
257330f729Sjoerg static llvm::cl::opt<bool>
267330f729Sjoerg     EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
277330f729Sjoerg                          llvm::cl::desc("Enable value profiling"),
287330f729Sjoerg                          llvm::cl::Hidden, llvm::cl::init(false));
297330f729Sjoerg 
307330f729Sjoerg using namespace clang;
317330f729Sjoerg using namespace CodeGen;
327330f729Sjoerg 
setFuncName(StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)337330f729Sjoerg void CodeGenPGO::setFuncName(StringRef Name,
347330f729Sjoerg                              llvm::GlobalValue::LinkageTypes Linkage) {
357330f729Sjoerg   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
367330f729Sjoerg   FuncName = llvm::getPGOFuncName(
377330f729Sjoerg       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
387330f729Sjoerg       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
397330f729Sjoerg 
407330f729Sjoerg   // If we're generating a profile, create a variable for the name.
417330f729Sjoerg   if (CGM.getCodeGenOpts().hasProfileClangInstr())
427330f729Sjoerg     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
437330f729Sjoerg }
447330f729Sjoerg 
setFuncName(llvm::Function * Fn)457330f729Sjoerg void CodeGenPGO::setFuncName(llvm::Function *Fn) {
467330f729Sjoerg   setFuncName(Fn->getName(), Fn->getLinkage());
477330f729Sjoerg   // Create PGOFuncName meta data.
487330f729Sjoerg   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
497330f729Sjoerg }
507330f729Sjoerg 
/// Revisions of the PGO structural hash algorithm. getPGOHashVersion() picks
/// one based on the version of the indexed profile being read.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,
  PGO_HASH_V3 = 2,

  // Must always alias the newest revision listed above.
  PGO_HASH_LATEST = PGO_HASH_V3
};
607330f729Sjoerg 
617330f729Sjoerg namespace {
627330f729Sjoerg /// Stable hasher for PGO region counters.
637330f729Sjoerg ///
647330f729Sjoerg /// PGOHash produces a stable hash of a given function's control flow.
657330f729Sjoerg ///
667330f729Sjoerg /// Changing the output of this hash will invalidate all previously generated
677330f729Sjoerg /// profiles -- i.e., don't do it.
687330f729Sjoerg ///
697330f729Sjoerg /// \note  When this hash does eventually change (years?), we still need to
707330f729Sjoerg /// support old hashes.  We'll need to pull in the version number from the
717330f729Sjoerg /// profile data format and use the matching hash function.
727330f729Sjoerg class PGOHash {
737330f729Sjoerg   uint64_t Working;
747330f729Sjoerg   unsigned Count;
757330f729Sjoerg   PGOHashVersion HashVersion;
767330f729Sjoerg   llvm::MD5 MD5;
777330f729Sjoerg 
787330f729Sjoerg   static const int NumBitsPerType = 6;
797330f729Sjoerg   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
807330f729Sjoerg   static const unsigned TooBig = 1u << NumBitsPerType;
817330f729Sjoerg 
827330f729Sjoerg public:
837330f729Sjoerg   /// Hash values for AST nodes.
847330f729Sjoerg   ///
857330f729Sjoerg   /// Distinct values for AST nodes that have region counters attached.
867330f729Sjoerg   ///
877330f729Sjoerg   /// These values must be stable.  All new members must be added at the end,
887330f729Sjoerg   /// and no members should be removed.  Changing the enumeration value for an
897330f729Sjoerg   /// AST node will affect the hash of every function that contains that node.
907330f729Sjoerg   enum HashType : unsigned char {
917330f729Sjoerg     None = 0,
927330f729Sjoerg     LabelStmt = 1,
937330f729Sjoerg     WhileStmt,
947330f729Sjoerg     DoStmt,
957330f729Sjoerg     ForStmt,
967330f729Sjoerg     CXXForRangeStmt,
977330f729Sjoerg     ObjCForCollectionStmt,
987330f729Sjoerg     SwitchStmt,
997330f729Sjoerg     CaseStmt,
1007330f729Sjoerg     DefaultStmt,
1017330f729Sjoerg     IfStmt,
1027330f729Sjoerg     CXXTryStmt,
1037330f729Sjoerg     CXXCatchStmt,
1047330f729Sjoerg     ConditionalOperator,
1057330f729Sjoerg     BinaryOperatorLAnd,
1067330f729Sjoerg     BinaryOperatorLOr,
1077330f729Sjoerg     BinaryConditionalOperator,
1087330f729Sjoerg     // The preceding values are available with PGO_HASH_V1.
1097330f729Sjoerg 
1107330f729Sjoerg     EndOfScope,
1117330f729Sjoerg     IfThenBranch,
1127330f729Sjoerg     IfElseBranch,
1137330f729Sjoerg     GotoStmt,
1147330f729Sjoerg     IndirectGotoStmt,
1157330f729Sjoerg     BreakStmt,
1167330f729Sjoerg     ContinueStmt,
1177330f729Sjoerg     ReturnStmt,
1187330f729Sjoerg     ThrowExpr,
1197330f729Sjoerg     UnaryOperatorLNot,
1207330f729Sjoerg     BinaryOperatorLT,
1217330f729Sjoerg     BinaryOperatorGT,
1227330f729Sjoerg     BinaryOperatorLE,
1237330f729Sjoerg     BinaryOperatorGE,
1247330f729Sjoerg     BinaryOperatorEQ,
1257330f729Sjoerg     BinaryOperatorNE,
126*e038c9c4Sjoerg     // The preceding values are available since PGO_HASH_V2.
1277330f729Sjoerg 
1287330f729Sjoerg     // Keep this last.  It's for the static assert that follows.
1297330f729Sjoerg     LastHashType
1307330f729Sjoerg   };
1317330f729Sjoerg   static_assert(LastHashType <= TooBig, "Too many types in HashType");
1327330f729Sjoerg 
PGOHash(PGOHashVersion HashVersion)1337330f729Sjoerg   PGOHash(PGOHashVersion HashVersion)
1347330f729Sjoerg       : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
1357330f729Sjoerg   void combine(HashType Type);
1367330f729Sjoerg   uint64_t finalize();
getHashVersion() const1377330f729Sjoerg   PGOHashVersion getHashVersion() const { return HashVersion; }
1387330f729Sjoerg };
1397330f729Sjoerg const int PGOHash::NumBitsPerType;
1407330f729Sjoerg const unsigned PGOHash::NumTypesPerWord;
1417330f729Sjoerg const unsigned PGOHash::TooBig;
1427330f729Sjoerg 
1437330f729Sjoerg /// Get the PGO hash version used in the given indexed profile.
getPGOHashVersion(llvm::IndexedInstrProfReader * PGOReader,CodeGenModule & CGM)1447330f729Sjoerg static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
1457330f729Sjoerg                                         CodeGenModule &CGM) {
1467330f729Sjoerg   if (PGOReader->getVersion() <= 4)
1477330f729Sjoerg     return PGO_HASH_V1;
148*e038c9c4Sjoerg   if (PGOReader->getVersion() <= 5)
1497330f729Sjoerg     return PGO_HASH_V2;
150*e038c9c4Sjoerg   return PGO_HASH_V3;
1517330f729Sjoerg }
1527330f729Sjoerg 
1537330f729Sjoerg /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
1547330f729Sjoerg struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
1557330f729Sjoerg   using Base = RecursiveASTVisitor<MapRegionCounters>;
1567330f729Sjoerg 
1577330f729Sjoerg   /// The next counter value to assign.
1587330f729Sjoerg   unsigned NextCounter;
1597330f729Sjoerg   /// The function hash.
1607330f729Sjoerg   PGOHash Hash;
1617330f729Sjoerg   /// The map of statements to counters.
1627330f729Sjoerg   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
163*e038c9c4Sjoerg   /// The profile version.
164*e038c9c4Sjoerg   uint64_t ProfileVersion;
1657330f729Sjoerg 
MapRegionCounters__anona4a975a40111::MapRegionCounters166*e038c9c4Sjoerg   MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
1677330f729Sjoerg                     llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
168*e038c9c4Sjoerg       : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
169*e038c9c4Sjoerg         ProfileVersion(ProfileVersion) {}
1707330f729Sjoerg 
1717330f729Sjoerg   // Blocks and lambdas are handled as separate functions, so we need not
1727330f729Sjoerg   // traverse them in the parent context.
TraverseBlockExpr__anona4a975a40111::MapRegionCounters1737330f729Sjoerg   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
TraverseLambdaExpr__anona4a975a40111::MapRegionCounters1747330f729Sjoerg   bool TraverseLambdaExpr(LambdaExpr *LE) {
1757330f729Sjoerg     // Traverse the captures, but not the body.
176*e038c9c4Sjoerg     for (auto C : zip(LE->captures(), LE->capture_inits()))
1777330f729Sjoerg       TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
1787330f729Sjoerg     return true;
1797330f729Sjoerg   }
TraverseCapturedStmt__anona4a975a40111::MapRegionCounters1807330f729Sjoerg   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
1817330f729Sjoerg 
VisitDecl__anona4a975a40111::MapRegionCounters1827330f729Sjoerg   bool VisitDecl(const Decl *D) {
1837330f729Sjoerg     switch (D->getKind()) {
1847330f729Sjoerg     default:
1857330f729Sjoerg       break;
1867330f729Sjoerg     case Decl::Function:
1877330f729Sjoerg     case Decl::CXXMethod:
1887330f729Sjoerg     case Decl::CXXConstructor:
1897330f729Sjoerg     case Decl::CXXDestructor:
1907330f729Sjoerg     case Decl::CXXConversion:
1917330f729Sjoerg     case Decl::ObjCMethod:
1927330f729Sjoerg     case Decl::Block:
1937330f729Sjoerg     case Decl::Captured:
1947330f729Sjoerg       CounterMap[D->getBody()] = NextCounter++;
1957330f729Sjoerg       break;
1967330f729Sjoerg     }
1977330f729Sjoerg     return true;
1987330f729Sjoerg   }
1997330f729Sjoerg 
2007330f729Sjoerg   /// If \p S gets a fresh counter, update the counter mappings. Return the
2017330f729Sjoerg   /// V1 hash of \p S.
updateCounterMappings__anona4a975a40111::MapRegionCounters2027330f729Sjoerg   PGOHash::HashType updateCounterMappings(Stmt *S) {
2037330f729Sjoerg     auto Type = getHashType(PGO_HASH_V1, S);
2047330f729Sjoerg     if (Type != PGOHash::None)
2057330f729Sjoerg       CounterMap[S] = NextCounter++;
2067330f729Sjoerg     return Type;
2077330f729Sjoerg   }
2087330f729Sjoerg 
209*e038c9c4Sjoerg   /// The RHS of all logical operators gets a fresh counter in order to count
210*e038c9c4Sjoerg   /// how many times the RHS evaluates to true or false, depending on the
211*e038c9c4Sjoerg   /// semantics of the operator. This is only valid for ">= v7" of the profile
212*e038c9c4Sjoerg   /// version so that we facilitate backward compatibility.
VisitBinaryOperator__anona4a975a40111::MapRegionCounters213*e038c9c4Sjoerg   bool VisitBinaryOperator(BinaryOperator *S) {
214*e038c9c4Sjoerg     if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
215*e038c9c4Sjoerg       if (S->isLogicalOp() &&
216*e038c9c4Sjoerg           CodeGenFunction::isInstrumentedCondition(S->getRHS()))
217*e038c9c4Sjoerg         CounterMap[S->getRHS()] = NextCounter++;
218*e038c9c4Sjoerg     return Base::VisitBinaryOperator(S);
219*e038c9c4Sjoerg   }
220*e038c9c4Sjoerg 
2217330f729Sjoerg   /// Include \p S in the function hash.
VisitStmt__anona4a975a40111::MapRegionCounters2227330f729Sjoerg   bool VisitStmt(Stmt *S) {
2237330f729Sjoerg     auto Type = updateCounterMappings(S);
2247330f729Sjoerg     if (Hash.getHashVersion() != PGO_HASH_V1)
2257330f729Sjoerg       Type = getHashType(Hash.getHashVersion(), S);
2267330f729Sjoerg     if (Type != PGOHash::None)
2277330f729Sjoerg       Hash.combine(Type);
2287330f729Sjoerg     return true;
2297330f729Sjoerg   }
2307330f729Sjoerg 
TraverseIfStmt__anona4a975a40111::MapRegionCounters2317330f729Sjoerg   bool TraverseIfStmt(IfStmt *If) {
2327330f729Sjoerg     // If we used the V1 hash, use the default traversal.
2337330f729Sjoerg     if (Hash.getHashVersion() == PGO_HASH_V1)
2347330f729Sjoerg       return Base::TraverseIfStmt(If);
2357330f729Sjoerg 
2367330f729Sjoerg     // Otherwise, keep track of which branch we're in while traversing.
2377330f729Sjoerg     VisitStmt(If);
2387330f729Sjoerg     for (Stmt *CS : If->children()) {
2397330f729Sjoerg       if (!CS)
2407330f729Sjoerg         continue;
2417330f729Sjoerg       if (CS == If->getThen())
2427330f729Sjoerg         Hash.combine(PGOHash::IfThenBranch);
2437330f729Sjoerg       else if (CS == If->getElse())
2447330f729Sjoerg         Hash.combine(PGOHash::IfElseBranch);
2457330f729Sjoerg       TraverseStmt(CS);
2467330f729Sjoerg     }
2477330f729Sjoerg     Hash.combine(PGOHash::EndOfScope);
2487330f729Sjoerg     return true;
2497330f729Sjoerg   }
2507330f729Sjoerg 
2517330f729Sjoerg // If the statement type \p N is nestable, and its nesting impacts profile
2527330f729Sjoerg // stability, define a custom traversal which tracks the end of the statement
2537330f729Sjoerg // in the hash (provided we're not using the V1 hash).
2547330f729Sjoerg #define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
2557330f729Sjoerg   bool Traverse##N(N *S) {                                                     \
2567330f729Sjoerg     Base::Traverse##N(S);                                                      \
2577330f729Sjoerg     if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
2587330f729Sjoerg       Hash.combine(PGOHash::EndOfScope);                                       \
2597330f729Sjoerg     return true;                                                               \
2607330f729Sjoerg   }
2617330f729Sjoerg 
2627330f729Sjoerg   DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
DEFINE_NESTABLE_TRAVERSAL__anona4a975a40111::MapRegionCounters2637330f729Sjoerg   DEFINE_NESTABLE_TRAVERSAL(DoStmt)
2647330f729Sjoerg   DEFINE_NESTABLE_TRAVERSAL(ForStmt)
2657330f729Sjoerg   DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
2667330f729Sjoerg   DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
2677330f729Sjoerg   DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
2687330f729Sjoerg   DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
2697330f729Sjoerg 
2707330f729Sjoerg   /// Get version \p HashVersion of the PGO hash for \p S.
2717330f729Sjoerg   PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
2727330f729Sjoerg     switch (S->getStmtClass()) {
2737330f729Sjoerg     default:
2747330f729Sjoerg       break;
2757330f729Sjoerg     case Stmt::LabelStmtClass:
2767330f729Sjoerg       return PGOHash::LabelStmt;
2777330f729Sjoerg     case Stmt::WhileStmtClass:
2787330f729Sjoerg       return PGOHash::WhileStmt;
2797330f729Sjoerg     case Stmt::DoStmtClass:
2807330f729Sjoerg       return PGOHash::DoStmt;
2817330f729Sjoerg     case Stmt::ForStmtClass:
2827330f729Sjoerg       return PGOHash::ForStmt;
2837330f729Sjoerg     case Stmt::CXXForRangeStmtClass:
2847330f729Sjoerg       return PGOHash::CXXForRangeStmt;
2857330f729Sjoerg     case Stmt::ObjCForCollectionStmtClass:
2867330f729Sjoerg       return PGOHash::ObjCForCollectionStmt;
2877330f729Sjoerg     case Stmt::SwitchStmtClass:
2887330f729Sjoerg       return PGOHash::SwitchStmt;
2897330f729Sjoerg     case Stmt::CaseStmtClass:
2907330f729Sjoerg       return PGOHash::CaseStmt;
2917330f729Sjoerg     case Stmt::DefaultStmtClass:
2927330f729Sjoerg       return PGOHash::DefaultStmt;
2937330f729Sjoerg     case Stmt::IfStmtClass:
2947330f729Sjoerg       return PGOHash::IfStmt;
2957330f729Sjoerg     case Stmt::CXXTryStmtClass:
2967330f729Sjoerg       return PGOHash::CXXTryStmt;
2977330f729Sjoerg     case Stmt::CXXCatchStmtClass:
2987330f729Sjoerg       return PGOHash::CXXCatchStmt;
2997330f729Sjoerg     case Stmt::ConditionalOperatorClass:
3007330f729Sjoerg       return PGOHash::ConditionalOperator;
3017330f729Sjoerg     case Stmt::BinaryConditionalOperatorClass:
3027330f729Sjoerg       return PGOHash::BinaryConditionalOperator;
3037330f729Sjoerg     case Stmt::BinaryOperatorClass: {
3047330f729Sjoerg       const BinaryOperator *BO = cast<BinaryOperator>(S);
3057330f729Sjoerg       if (BO->getOpcode() == BO_LAnd)
3067330f729Sjoerg         return PGOHash::BinaryOperatorLAnd;
3077330f729Sjoerg       if (BO->getOpcode() == BO_LOr)
3087330f729Sjoerg         return PGOHash::BinaryOperatorLOr;
309*e038c9c4Sjoerg       if (HashVersion >= PGO_HASH_V2) {
3107330f729Sjoerg         switch (BO->getOpcode()) {
3117330f729Sjoerg         default:
3127330f729Sjoerg           break;
3137330f729Sjoerg         case BO_LT:
3147330f729Sjoerg           return PGOHash::BinaryOperatorLT;
3157330f729Sjoerg         case BO_GT:
3167330f729Sjoerg           return PGOHash::BinaryOperatorGT;
3177330f729Sjoerg         case BO_LE:
3187330f729Sjoerg           return PGOHash::BinaryOperatorLE;
3197330f729Sjoerg         case BO_GE:
3207330f729Sjoerg           return PGOHash::BinaryOperatorGE;
3217330f729Sjoerg         case BO_EQ:
3227330f729Sjoerg           return PGOHash::BinaryOperatorEQ;
3237330f729Sjoerg         case BO_NE:
3247330f729Sjoerg           return PGOHash::BinaryOperatorNE;
3257330f729Sjoerg         }
3267330f729Sjoerg       }
3277330f729Sjoerg       break;
3287330f729Sjoerg     }
3297330f729Sjoerg     }
3307330f729Sjoerg 
331*e038c9c4Sjoerg     if (HashVersion >= PGO_HASH_V2) {
3327330f729Sjoerg       switch (S->getStmtClass()) {
3337330f729Sjoerg       default:
3347330f729Sjoerg         break;
3357330f729Sjoerg       case Stmt::GotoStmtClass:
3367330f729Sjoerg         return PGOHash::GotoStmt;
3377330f729Sjoerg       case Stmt::IndirectGotoStmtClass:
3387330f729Sjoerg         return PGOHash::IndirectGotoStmt;
3397330f729Sjoerg       case Stmt::BreakStmtClass:
3407330f729Sjoerg         return PGOHash::BreakStmt;
3417330f729Sjoerg       case Stmt::ContinueStmtClass:
3427330f729Sjoerg         return PGOHash::ContinueStmt;
3437330f729Sjoerg       case Stmt::ReturnStmtClass:
3447330f729Sjoerg         return PGOHash::ReturnStmt;
3457330f729Sjoerg       case Stmt::CXXThrowExprClass:
3467330f729Sjoerg         return PGOHash::ThrowExpr;
3477330f729Sjoerg       case Stmt::UnaryOperatorClass: {
3487330f729Sjoerg         const UnaryOperator *UO = cast<UnaryOperator>(S);
3497330f729Sjoerg         if (UO->getOpcode() == UO_LNot)
3507330f729Sjoerg           return PGOHash::UnaryOperatorLNot;
3517330f729Sjoerg         break;
3527330f729Sjoerg       }
3537330f729Sjoerg       }
3547330f729Sjoerg     }
3557330f729Sjoerg 
3567330f729Sjoerg     return PGOHash::None;
3577330f729Sjoerg   }
3587330f729Sjoerg };
3597330f729Sjoerg 
3607330f729Sjoerg /// A StmtVisitor that propagates the raw counts through the AST and
3617330f729Sjoerg /// records the count at statements where the value may change.
3627330f729Sjoerg struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
3637330f729Sjoerg   /// PGO state.
3647330f729Sjoerg   CodeGenPGO &PGO;
3657330f729Sjoerg 
3667330f729Sjoerg   /// A flag that is set when the current count should be recorded on the
3677330f729Sjoerg   /// next statement, such as at the exit of a loop.
3687330f729Sjoerg   bool RecordNextStmtCount;
3697330f729Sjoerg 
3707330f729Sjoerg   /// The count at the current location in the traversal.
3717330f729Sjoerg   uint64_t CurrentCount;
3727330f729Sjoerg 
3737330f729Sjoerg   /// The map of statements to count values.
3747330f729Sjoerg   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
3757330f729Sjoerg 
3767330f729Sjoerg   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
3777330f729Sjoerg   struct BreakContinue {
3787330f729Sjoerg     uint64_t BreakCount;
3797330f729Sjoerg     uint64_t ContinueCount;
BreakContinue__anona4a975a40111::ComputeRegionCounts::BreakContinue3807330f729Sjoerg     BreakContinue() : BreakCount(0), ContinueCount(0) {}
3817330f729Sjoerg   };
3827330f729Sjoerg   SmallVector<BreakContinue, 8> BreakContinueStack;
3837330f729Sjoerg 
ComputeRegionCounts__anona4a975a40111::ComputeRegionCounts3847330f729Sjoerg   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
3857330f729Sjoerg                       CodeGenPGO &PGO)
3867330f729Sjoerg       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
3877330f729Sjoerg 
RecordStmtCount__anona4a975a40111::ComputeRegionCounts3887330f729Sjoerg   void RecordStmtCount(const Stmt *S) {
3897330f729Sjoerg     if (RecordNextStmtCount) {
3907330f729Sjoerg       CountMap[S] = CurrentCount;
3917330f729Sjoerg       RecordNextStmtCount = false;
3927330f729Sjoerg     }
3937330f729Sjoerg   }
3947330f729Sjoerg 
3957330f729Sjoerg   /// Set and return the current count.
setCount__anona4a975a40111::ComputeRegionCounts3967330f729Sjoerg   uint64_t setCount(uint64_t Count) {
3977330f729Sjoerg     CurrentCount = Count;
3987330f729Sjoerg     return Count;
3997330f729Sjoerg   }
4007330f729Sjoerg 
VisitStmt__anona4a975a40111::ComputeRegionCounts4017330f729Sjoerg   void VisitStmt(const Stmt *S) {
4027330f729Sjoerg     RecordStmtCount(S);
4037330f729Sjoerg     for (const Stmt *Child : S->children())
4047330f729Sjoerg       if (Child)
4057330f729Sjoerg         this->Visit(Child);
4067330f729Sjoerg   }
4077330f729Sjoerg 
VisitFunctionDecl__anona4a975a40111::ComputeRegionCounts4087330f729Sjoerg   void VisitFunctionDecl(const FunctionDecl *D) {
4097330f729Sjoerg     // Counter tracks entry to the function body.
4107330f729Sjoerg     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
4117330f729Sjoerg     CountMap[D->getBody()] = BodyCount;
4127330f729Sjoerg     Visit(D->getBody());
4137330f729Sjoerg   }
4147330f729Sjoerg 
4157330f729Sjoerg   // Skip lambda expressions. We visit these as FunctionDecls when we're
4167330f729Sjoerg   // generating them and aren't interested in the body when generating a
4177330f729Sjoerg   // parent context.
VisitLambdaExpr__anona4a975a40111::ComputeRegionCounts4187330f729Sjoerg   void VisitLambdaExpr(const LambdaExpr *LE) {}
4197330f729Sjoerg 
VisitCapturedDecl__anona4a975a40111::ComputeRegionCounts4207330f729Sjoerg   void VisitCapturedDecl(const CapturedDecl *D) {
4217330f729Sjoerg     // Counter tracks entry to the capture body.
4227330f729Sjoerg     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
4237330f729Sjoerg     CountMap[D->getBody()] = BodyCount;
4247330f729Sjoerg     Visit(D->getBody());
4257330f729Sjoerg   }
4267330f729Sjoerg 
VisitObjCMethodDecl__anona4a975a40111::ComputeRegionCounts4277330f729Sjoerg   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
4287330f729Sjoerg     // Counter tracks entry to the method body.
4297330f729Sjoerg     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
4307330f729Sjoerg     CountMap[D->getBody()] = BodyCount;
4317330f729Sjoerg     Visit(D->getBody());
4327330f729Sjoerg   }
4337330f729Sjoerg 
VisitBlockDecl__anona4a975a40111::ComputeRegionCounts4347330f729Sjoerg   void VisitBlockDecl(const BlockDecl *D) {
4357330f729Sjoerg     // Counter tracks entry to the block body.
4367330f729Sjoerg     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
4377330f729Sjoerg     CountMap[D->getBody()] = BodyCount;
4387330f729Sjoerg     Visit(D->getBody());
4397330f729Sjoerg   }
4407330f729Sjoerg 
VisitReturnStmt__anona4a975a40111::ComputeRegionCounts4417330f729Sjoerg   void VisitReturnStmt(const ReturnStmt *S) {
4427330f729Sjoerg     RecordStmtCount(S);
4437330f729Sjoerg     if (S->getRetValue())
4447330f729Sjoerg       Visit(S->getRetValue());
4457330f729Sjoerg     CurrentCount = 0;
4467330f729Sjoerg     RecordNextStmtCount = true;
4477330f729Sjoerg   }
4487330f729Sjoerg 
VisitCXXThrowExpr__anona4a975a40111::ComputeRegionCounts4497330f729Sjoerg   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
4507330f729Sjoerg     RecordStmtCount(E);
4517330f729Sjoerg     if (E->getSubExpr())
4527330f729Sjoerg       Visit(E->getSubExpr());
4537330f729Sjoerg     CurrentCount = 0;
4547330f729Sjoerg     RecordNextStmtCount = true;
4557330f729Sjoerg   }
4567330f729Sjoerg 
VisitGotoStmt__anona4a975a40111::ComputeRegionCounts4577330f729Sjoerg   void VisitGotoStmt(const GotoStmt *S) {
4587330f729Sjoerg     RecordStmtCount(S);
4597330f729Sjoerg     CurrentCount = 0;
4607330f729Sjoerg     RecordNextStmtCount = true;
4617330f729Sjoerg   }
4627330f729Sjoerg 
VisitLabelStmt__anona4a975a40111::ComputeRegionCounts4637330f729Sjoerg   void VisitLabelStmt(const LabelStmt *S) {
4647330f729Sjoerg     RecordNextStmtCount = false;
4657330f729Sjoerg     // Counter tracks the block following the label.
4667330f729Sjoerg     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
4677330f729Sjoerg     CountMap[S] = BlockCount;
4687330f729Sjoerg     Visit(S->getSubStmt());
4697330f729Sjoerg   }
4707330f729Sjoerg 
VisitBreakStmt__anona4a975a40111::ComputeRegionCounts4717330f729Sjoerg   void VisitBreakStmt(const BreakStmt *S) {
4727330f729Sjoerg     RecordStmtCount(S);
4737330f729Sjoerg     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
4747330f729Sjoerg     BreakContinueStack.back().BreakCount += CurrentCount;
4757330f729Sjoerg     CurrentCount = 0;
4767330f729Sjoerg     RecordNextStmtCount = true;
4777330f729Sjoerg   }
4787330f729Sjoerg 
VisitContinueStmt__anona4a975a40111::ComputeRegionCounts4797330f729Sjoerg   void VisitContinueStmt(const ContinueStmt *S) {
4807330f729Sjoerg     RecordStmtCount(S);
4817330f729Sjoerg     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
4827330f729Sjoerg     BreakContinueStack.back().ContinueCount += CurrentCount;
4837330f729Sjoerg     CurrentCount = 0;
4847330f729Sjoerg     RecordNextStmtCount = true;
4857330f729Sjoerg   }
4867330f729Sjoerg 
VisitWhileStmt__anona4a975a40111::ComputeRegionCounts4877330f729Sjoerg   void VisitWhileStmt(const WhileStmt *S) {
4887330f729Sjoerg     RecordStmtCount(S);
4897330f729Sjoerg     uint64_t ParentCount = CurrentCount;
4907330f729Sjoerg 
4917330f729Sjoerg     BreakContinueStack.push_back(BreakContinue());
4927330f729Sjoerg     // Visit the body region first so the break/continue adjustments can be
4937330f729Sjoerg     // included when visiting the condition.
4947330f729Sjoerg     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
4957330f729Sjoerg     CountMap[S->getBody()] = CurrentCount;
4967330f729Sjoerg     Visit(S->getBody());
4977330f729Sjoerg     uint64_t BackedgeCount = CurrentCount;
4987330f729Sjoerg 
4997330f729Sjoerg     // ...then go back and propagate counts through the condition. The count
5007330f729Sjoerg     // at the start of the condition is the sum of the incoming edges,
5017330f729Sjoerg     // the backedge from the end of the loop body, and the edges from
5027330f729Sjoerg     // continue statements.
5037330f729Sjoerg     BreakContinue BC = BreakContinueStack.pop_back_val();
5047330f729Sjoerg     uint64_t CondCount =
5057330f729Sjoerg         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
5067330f729Sjoerg     CountMap[S->getCond()] = CondCount;
5077330f729Sjoerg     Visit(S->getCond());
5087330f729Sjoerg     setCount(BC.BreakCount + CondCount - BodyCount);
5097330f729Sjoerg     RecordNextStmtCount = true;
5107330f729Sjoerg   }
5117330f729Sjoerg 
VisitDoStmt__anona4a975a40111::ComputeRegionCounts5127330f729Sjoerg   void VisitDoStmt(const DoStmt *S) {
5137330f729Sjoerg     RecordStmtCount(S);
5147330f729Sjoerg     uint64_t LoopCount = PGO.getRegionCount(S);
5157330f729Sjoerg 
5167330f729Sjoerg     BreakContinueStack.push_back(BreakContinue());
5177330f729Sjoerg     // The count doesn't include the fallthrough from the parent scope. Add it.
5187330f729Sjoerg     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
5197330f729Sjoerg     CountMap[S->getBody()] = BodyCount;
5207330f729Sjoerg     Visit(S->getBody());
5217330f729Sjoerg     uint64_t BackedgeCount = CurrentCount;
5227330f729Sjoerg 
5237330f729Sjoerg     BreakContinue BC = BreakContinueStack.pop_back_val();
5247330f729Sjoerg     // The count at the start of the condition is equal to the count at the
5257330f729Sjoerg     // end of the body, plus any continues.
5267330f729Sjoerg     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
5277330f729Sjoerg     CountMap[S->getCond()] = CondCount;
5287330f729Sjoerg     Visit(S->getCond());
5297330f729Sjoerg     setCount(BC.BreakCount + CondCount - LoopCount);
5307330f729Sjoerg     RecordNextStmtCount = true;
5317330f729Sjoerg   }
5327330f729Sjoerg 
VisitForStmt__anona4a975a40111::ComputeRegionCounts5337330f729Sjoerg   void VisitForStmt(const ForStmt *S) {
5347330f729Sjoerg     RecordStmtCount(S);
5357330f729Sjoerg     if (S->getInit())
5367330f729Sjoerg       Visit(S->getInit());
5377330f729Sjoerg 
5387330f729Sjoerg     uint64_t ParentCount = CurrentCount;
5397330f729Sjoerg 
5407330f729Sjoerg     BreakContinueStack.push_back(BreakContinue());
5417330f729Sjoerg     // Visit the body region first. (This is basically the same as a while
5427330f729Sjoerg     // loop; see further comments in VisitWhileStmt.)
5437330f729Sjoerg     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
5447330f729Sjoerg     CountMap[S->getBody()] = BodyCount;
5457330f729Sjoerg     Visit(S->getBody());
5467330f729Sjoerg     uint64_t BackedgeCount = CurrentCount;
5477330f729Sjoerg     BreakContinue BC = BreakContinueStack.pop_back_val();
5487330f729Sjoerg 
5497330f729Sjoerg     // The increment is essentially part of the body but it needs to include
5507330f729Sjoerg     // the count for all the continue statements.
5517330f729Sjoerg     if (S->getInc()) {
5527330f729Sjoerg       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
5537330f729Sjoerg       CountMap[S->getInc()] = IncCount;
5547330f729Sjoerg       Visit(S->getInc());
5557330f729Sjoerg     }
5567330f729Sjoerg 
5577330f729Sjoerg     // ...then go back and propagate counts through the condition.
5587330f729Sjoerg     uint64_t CondCount =
5597330f729Sjoerg         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
5607330f729Sjoerg     if (S->getCond()) {
5617330f729Sjoerg       CountMap[S->getCond()] = CondCount;
5627330f729Sjoerg       Visit(S->getCond());
5637330f729Sjoerg     }
5647330f729Sjoerg     setCount(BC.BreakCount + CondCount - BodyCount);
5657330f729Sjoerg     RecordNextStmtCount = true;
5667330f729Sjoerg   }
5677330f729Sjoerg 
VisitCXXForRangeStmt__anona4a975a40111::ComputeRegionCounts5687330f729Sjoerg   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
5697330f729Sjoerg     RecordStmtCount(S);
5707330f729Sjoerg     if (S->getInit())
5717330f729Sjoerg       Visit(S->getInit());
5727330f729Sjoerg     Visit(S->getLoopVarStmt());
5737330f729Sjoerg     Visit(S->getRangeStmt());
5747330f729Sjoerg     Visit(S->getBeginStmt());
5757330f729Sjoerg     Visit(S->getEndStmt());
5767330f729Sjoerg 
5777330f729Sjoerg     uint64_t ParentCount = CurrentCount;
5787330f729Sjoerg     BreakContinueStack.push_back(BreakContinue());
5797330f729Sjoerg     // Visit the body region first. (This is basically the same as a while
5807330f729Sjoerg     // loop; see further comments in VisitWhileStmt.)
5817330f729Sjoerg     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
5827330f729Sjoerg     CountMap[S->getBody()] = BodyCount;
5837330f729Sjoerg     Visit(S->getBody());
5847330f729Sjoerg     uint64_t BackedgeCount = CurrentCount;
5857330f729Sjoerg     BreakContinue BC = BreakContinueStack.pop_back_val();
5867330f729Sjoerg 
5877330f729Sjoerg     // The increment is essentially part of the body but it needs to include
5887330f729Sjoerg     // the count for all the continue statements.
5897330f729Sjoerg     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
5907330f729Sjoerg     CountMap[S->getInc()] = IncCount;
5917330f729Sjoerg     Visit(S->getInc());
5927330f729Sjoerg 
5937330f729Sjoerg     // ...then go back and propagate counts through the condition.
5947330f729Sjoerg     uint64_t CondCount =
5957330f729Sjoerg         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
5967330f729Sjoerg     CountMap[S->getCond()] = CondCount;
5977330f729Sjoerg     Visit(S->getCond());
5987330f729Sjoerg     setCount(BC.BreakCount + CondCount - BodyCount);
5997330f729Sjoerg     RecordNextStmtCount = true;
6007330f729Sjoerg   }
6017330f729Sjoerg 
VisitObjCForCollectionStmt__anona4a975a40111::ComputeRegionCounts6027330f729Sjoerg   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
6037330f729Sjoerg     RecordStmtCount(S);
6047330f729Sjoerg     Visit(S->getElement());
6057330f729Sjoerg     uint64_t ParentCount = CurrentCount;
6067330f729Sjoerg     BreakContinueStack.push_back(BreakContinue());
6077330f729Sjoerg     // Counter tracks the body of the loop.
6087330f729Sjoerg     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
6097330f729Sjoerg     CountMap[S->getBody()] = BodyCount;
6107330f729Sjoerg     Visit(S->getBody());
6117330f729Sjoerg     uint64_t BackedgeCount = CurrentCount;
6127330f729Sjoerg     BreakContinue BC = BreakContinueStack.pop_back_val();
6137330f729Sjoerg 
6147330f729Sjoerg     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
6157330f729Sjoerg              BodyCount);
6167330f729Sjoerg     RecordNextStmtCount = true;
6177330f729Sjoerg   }
6187330f729Sjoerg 
VisitSwitchStmt__anona4a975a40111::ComputeRegionCounts6197330f729Sjoerg   void VisitSwitchStmt(const SwitchStmt *S) {
6207330f729Sjoerg     RecordStmtCount(S);
6217330f729Sjoerg     if (S->getInit())
6227330f729Sjoerg       Visit(S->getInit());
6237330f729Sjoerg     Visit(S->getCond());
6247330f729Sjoerg     CurrentCount = 0;
6257330f729Sjoerg     BreakContinueStack.push_back(BreakContinue());
6267330f729Sjoerg     Visit(S->getBody());
6277330f729Sjoerg     // If the switch is inside a loop, add the continue counts.
6287330f729Sjoerg     BreakContinue BC = BreakContinueStack.pop_back_val();
6297330f729Sjoerg     if (!BreakContinueStack.empty())
6307330f729Sjoerg       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
6317330f729Sjoerg     // Counter tracks the exit block of the switch.
6327330f729Sjoerg     setCount(PGO.getRegionCount(S));
6337330f729Sjoerg     RecordNextStmtCount = true;
6347330f729Sjoerg   }
6357330f729Sjoerg 
VisitSwitchCase__anona4a975a40111::ComputeRegionCounts6367330f729Sjoerg   void VisitSwitchCase(const SwitchCase *S) {
6377330f729Sjoerg     RecordNextStmtCount = false;
6387330f729Sjoerg     // Counter for this particular case. This counts only jumps from the
6397330f729Sjoerg     // switch header and does not include fallthrough from the case before
6407330f729Sjoerg     // this one.
6417330f729Sjoerg     uint64_t CaseCount = PGO.getRegionCount(S);
6427330f729Sjoerg     setCount(CurrentCount + CaseCount);
6437330f729Sjoerg     // We need the count without fallthrough in the mapping, so it's more useful
6447330f729Sjoerg     // for branch probabilities.
6457330f729Sjoerg     CountMap[S] = CaseCount;
6467330f729Sjoerg     RecordNextStmtCount = true;
6477330f729Sjoerg     Visit(S->getSubStmt());
6487330f729Sjoerg   }
6497330f729Sjoerg 
VisitIfStmt__anona4a975a40111::ComputeRegionCounts6507330f729Sjoerg   void VisitIfStmt(const IfStmt *S) {
6517330f729Sjoerg     RecordStmtCount(S);
6527330f729Sjoerg     uint64_t ParentCount = CurrentCount;
6537330f729Sjoerg     if (S->getInit())
6547330f729Sjoerg       Visit(S->getInit());
6557330f729Sjoerg     Visit(S->getCond());
6567330f729Sjoerg 
6577330f729Sjoerg     // Counter tracks the "then" part of an if statement. The count for
6587330f729Sjoerg     // the "else" part, if it exists, will be calculated from this counter.
6597330f729Sjoerg     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
6607330f729Sjoerg     CountMap[S->getThen()] = ThenCount;
6617330f729Sjoerg     Visit(S->getThen());
6627330f729Sjoerg     uint64_t OutCount = CurrentCount;
6637330f729Sjoerg 
6647330f729Sjoerg     uint64_t ElseCount = ParentCount - ThenCount;
6657330f729Sjoerg     if (S->getElse()) {
6667330f729Sjoerg       setCount(ElseCount);
6677330f729Sjoerg       CountMap[S->getElse()] = ElseCount;
6687330f729Sjoerg       Visit(S->getElse());
6697330f729Sjoerg       OutCount += CurrentCount;
6707330f729Sjoerg     } else
6717330f729Sjoerg       OutCount += ElseCount;
6727330f729Sjoerg     setCount(OutCount);
6737330f729Sjoerg     RecordNextStmtCount = true;
6747330f729Sjoerg   }
6757330f729Sjoerg 
VisitCXXTryStmt__anona4a975a40111::ComputeRegionCounts6767330f729Sjoerg   void VisitCXXTryStmt(const CXXTryStmt *S) {
6777330f729Sjoerg     RecordStmtCount(S);
6787330f729Sjoerg     Visit(S->getTryBlock());
6797330f729Sjoerg     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
6807330f729Sjoerg       Visit(S->getHandler(I));
6817330f729Sjoerg     // Counter tracks the continuation block of the try statement.
6827330f729Sjoerg     setCount(PGO.getRegionCount(S));
6837330f729Sjoerg     RecordNextStmtCount = true;
6847330f729Sjoerg   }
6857330f729Sjoerg 
VisitCXXCatchStmt__anona4a975a40111::ComputeRegionCounts6867330f729Sjoerg   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
6877330f729Sjoerg     RecordNextStmtCount = false;
6887330f729Sjoerg     // Counter tracks the catch statement's handler block.
6897330f729Sjoerg     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
6907330f729Sjoerg     CountMap[S] = CatchCount;
6917330f729Sjoerg     Visit(S->getHandlerBlock());
6927330f729Sjoerg   }
6937330f729Sjoerg 
VisitAbstractConditionalOperator__anona4a975a40111::ComputeRegionCounts6947330f729Sjoerg   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
6957330f729Sjoerg     RecordStmtCount(E);
6967330f729Sjoerg     uint64_t ParentCount = CurrentCount;
6977330f729Sjoerg     Visit(E->getCond());
6987330f729Sjoerg 
6997330f729Sjoerg     // Counter tracks the "true" part of a conditional operator. The
7007330f729Sjoerg     // count in the "false" part will be calculated from this counter.
7017330f729Sjoerg     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
7027330f729Sjoerg     CountMap[E->getTrueExpr()] = TrueCount;
7037330f729Sjoerg     Visit(E->getTrueExpr());
7047330f729Sjoerg     uint64_t OutCount = CurrentCount;
7057330f729Sjoerg 
7067330f729Sjoerg     uint64_t FalseCount = setCount(ParentCount - TrueCount);
7077330f729Sjoerg     CountMap[E->getFalseExpr()] = FalseCount;
7087330f729Sjoerg     Visit(E->getFalseExpr());
7097330f729Sjoerg     OutCount += CurrentCount;
7107330f729Sjoerg 
7117330f729Sjoerg     setCount(OutCount);
7127330f729Sjoerg     RecordNextStmtCount = true;
7137330f729Sjoerg   }
7147330f729Sjoerg 
VisitBinLAnd__anona4a975a40111::ComputeRegionCounts7157330f729Sjoerg   void VisitBinLAnd(const BinaryOperator *E) {
7167330f729Sjoerg     RecordStmtCount(E);
7177330f729Sjoerg     uint64_t ParentCount = CurrentCount;
7187330f729Sjoerg     Visit(E->getLHS());
7197330f729Sjoerg     // Counter tracks the right hand side of a logical and operator.
7207330f729Sjoerg     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
7217330f729Sjoerg     CountMap[E->getRHS()] = RHSCount;
7227330f729Sjoerg     Visit(E->getRHS());
7237330f729Sjoerg     setCount(ParentCount + RHSCount - CurrentCount);
7247330f729Sjoerg     RecordNextStmtCount = true;
7257330f729Sjoerg   }
7267330f729Sjoerg 
VisitBinLOr__anona4a975a40111::ComputeRegionCounts7277330f729Sjoerg   void VisitBinLOr(const BinaryOperator *E) {
7287330f729Sjoerg     RecordStmtCount(E);
7297330f729Sjoerg     uint64_t ParentCount = CurrentCount;
7307330f729Sjoerg     Visit(E->getLHS());
7317330f729Sjoerg     // Counter tracks the right hand side of a logical or operator.
7327330f729Sjoerg     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
7337330f729Sjoerg     CountMap[E->getRHS()] = RHSCount;
7347330f729Sjoerg     Visit(E->getRHS());
7357330f729Sjoerg     setCount(ParentCount + RHSCount - CurrentCount);
7367330f729Sjoerg     RecordNextStmtCount = true;
7377330f729Sjoerg   }
7387330f729Sjoerg };
7397330f729Sjoerg } // end anonymous namespace
7407330f729Sjoerg 
combine(HashType Type)7417330f729Sjoerg void PGOHash::combine(HashType Type) {
7427330f729Sjoerg   // Check that we never combine 0 and only have six bits.
7437330f729Sjoerg   assert(Type && "Hash is invalid: unexpected type 0");
7447330f729Sjoerg   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
7457330f729Sjoerg 
7467330f729Sjoerg   // Pass through MD5 if enough work has built up.
7477330f729Sjoerg   if (Count && Count % NumTypesPerWord == 0) {
7487330f729Sjoerg     using namespace llvm::support;
7497330f729Sjoerg     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
7507330f729Sjoerg     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
7517330f729Sjoerg     Working = 0;
7527330f729Sjoerg   }
7537330f729Sjoerg 
7547330f729Sjoerg   // Accumulate the current type.
7557330f729Sjoerg   ++Count;
7567330f729Sjoerg   Working = Working << NumBitsPerType | Type;
7577330f729Sjoerg }
7587330f729Sjoerg 
finalize()7597330f729Sjoerg uint64_t PGOHash::finalize() {
7607330f729Sjoerg   // Use Working as the hash directly if we never used MD5.
7617330f729Sjoerg   if (Count <= NumTypesPerWord)
7627330f729Sjoerg     // No need to byte swap here, since none of the math was endian-dependent.
7637330f729Sjoerg     // This number will be byte-swapped as required on endianness transitions,
7647330f729Sjoerg     // so we will see the same value on the other side.
7657330f729Sjoerg     return Working;
7667330f729Sjoerg 
7677330f729Sjoerg   // Check for remaining work in Working.
768*e038c9c4Sjoerg   if (Working) {
769*e038c9c4Sjoerg     // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
770*e038c9c4Sjoerg     // is buggy because it converts a uint64_t into an array of uint8_t.
771*e038c9c4Sjoerg     if (HashVersion < PGO_HASH_V3) {
772*e038c9c4Sjoerg       MD5.update({(uint8_t)Working});
773*e038c9c4Sjoerg     } else {
774*e038c9c4Sjoerg       using namespace llvm::support;
775*e038c9c4Sjoerg       uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
776*e038c9c4Sjoerg       MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
777*e038c9c4Sjoerg     }
778*e038c9c4Sjoerg   }
7797330f729Sjoerg 
7807330f729Sjoerg   // Finalize the MD5 and return the hash.
7817330f729Sjoerg   llvm::MD5::MD5Result Result;
7827330f729Sjoerg   MD5.final(Result);
7837330f729Sjoerg   return Result.low();
7847330f729Sjoerg }
7857330f729Sjoerg 
assignRegionCounters(GlobalDecl GD,llvm::Function * Fn)7867330f729Sjoerg void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
7877330f729Sjoerg   const Decl *D = GD.getDecl();
7887330f729Sjoerg   if (!D->hasBody())
7897330f729Sjoerg     return;
7907330f729Sjoerg 
791*e038c9c4Sjoerg   // Skip CUDA/HIP kernel launch stub functions.
792*e038c9c4Sjoerg   if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
793*e038c9c4Sjoerg       D->hasAttr<CUDAGlobalAttr>())
794*e038c9c4Sjoerg     return;
795*e038c9c4Sjoerg 
7967330f729Sjoerg   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
7977330f729Sjoerg   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
7987330f729Sjoerg   if (!InstrumentRegions && !PGOReader)
7997330f729Sjoerg     return;
8007330f729Sjoerg   if (D->isImplicit())
8017330f729Sjoerg     return;
8027330f729Sjoerg   // Constructors and destructors may be represented by several functions in IR.
8037330f729Sjoerg   // If so, instrument only base variant, others are implemented by delegation
8047330f729Sjoerg   // to the base one, it would be counted twice otherwise.
8057330f729Sjoerg   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
8067330f729Sjoerg     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
8077330f729Sjoerg       if (GD.getCtorType() != Ctor_Base &&
8087330f729Sjoerg           CodeGenFunction::IsConstructorDelegationValid(CCD))
8097330f729Sjoerg         return;
8107330f729Sjoerg   }
8117330f729Sjoerg   if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
8127330f729Sjoerg     return;
8137330f729Sjoerg 
8147330f729Sjoerg   CGM.ClearUnusedCoverageMapping(D);
815*e038c9c4Sjoerg   if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
816*e038c9c4Sjoerg     return;
817*e038c9c4Sjoerg 
8187330f729Sjoerg   setFuncName(Fn);
8197330f729Sjoerg 
8207330f729Sjoerg   mapRegionCounters(D);
8217330f729Sjoerg   if (CGM.getCodeGenOpts().CoverageMapping)
8227330f729Sjoerg     emitCounterRegionMapping(D);
8237330f729Sjoerg   if (PGOReader) {
8247330f729Sjoerg     SourceManager &SM = CGM.getContext().getSourceManager();
8257330f729Sjoerg     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
8267330f729Sjoerg     computeRegionCounts(D);
8277330f729Sjoerg     applyFunctionAttributes(PGOReader, Fn);
8287330f729Sjoerg   }
8297330f729Sjoerg }
8307330f729Sjoerg 
mapRegionCounters(const Decl * D)8317330f729Sjoerg void CodeGenPGO::mapRegionCounters(const Decl *D) {
8327330f729Sjoerg   // Use the latest hash version when inserting instrumentation, but use the
8337330f729Sjoerg   // version in the indexed profile if we're reading PGO data.
8347330f729Sjoerg   PGOHashVersion HashVersion = PGO_HASH_LATEST;
835*e038c9c4Sjoerg   uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
836*e038c9c4Sjoerg   if (auto *PGOReader = CGM.getPGOReader()) {
8377330f729Sjoerg     HashVersion = getPGOHashVersion(PGOReader, CGM);
838*e038c9c4Sjoerg     ProfileVersion = PGOReader->getVersion();
839*e038c9c4Sjoerg   }
8407330f729Sjoerg 
8417330f729Sjoerg   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
842*e038c9c4Sjoerg   MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap);
8437330f729Sjoerg   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
8447330f729Sjoerg     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
8457330f729Sjoerg   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
8467330f729Sjoerg     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
8477330f729Sjoerg   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
8487330f729Sjoerg     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
8497330f729Sjoerg   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
8507330f729Sjoerg     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
8517330f729Sjoerg   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
8527330f729Sjoerg   NumRegionCounters = Walker.NextCounter;
8537330f729Sjoerg   FunctionHash = Walker.Hash.finalize();
8547330f729Sjoerg }
8557330f729Sjoerg 
skipRegionMappingForDecl(const Decl * D)8567330f729Sjoerg bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
8577330f729Sjoerg   if (!D->getBody())
8587330f729Sjoerg     return true;
8597330f729Sjoerg 
860*e038c9c4Sjoerg   // Skip host-only functions in the CUDA device compilation and device-only
861*e038c9c4Sjoerg   // functions in the host compilation. Just roughly filter them out based on
862*e038c9c4Sjoerg   // the function attributes. If there are effectively host-only or device-only
863*e038c9c4Sjoerg   // ones, their coverage mapping may still be generated.
864*e038c9c4Sjoerg   if (CGM.getLangOpts().CUDA &&
865*e038c9c4Sjoerg       ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
866*e038c9c4Sjoerg         !D->hasAttr<CUDAGlobalAttr>()) ||
867*e038c9c4Sjoerg        (!CGM.getLangOpts().CUDAIsDevice &&
868*e038c9c4Sjoerg         (D->hasAttr<CUDAGlobalAttr>() ||
869*e038c9c4Sjoerg          (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
870*e038c9c4Sjoerg     return true;
871*e038c9c4Sjoerg 
8727330f729Sjoerg   // Don't map the functions in system headers.
8737330f729Sjoerg   const auto &SM = CGM.getContext().getSourceManager();
8747330f729Sjoerg   auto Loc = D->getBody()->getBeginLoc();
8757330f729Sjoerg   return SM.isInSystemHeader(Loc);
8767330f729Sjoerg }
8777330f729Sjoerg 
emitCounterRegionMapping(const Decl * D)8787330f729Sjoerg void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
8797330f729Sjoerg   if (skipRegionMappingForDecl(D))
8807330f729Sjoerg     return;
8817330f729Sjoerg 
8827330f729Sjoerg   std::string CoverageMapping;
8837330f729Sjoerg   llvm::raw_string_ostream OS(CoverageMapping);
8847330f729Sjoerg   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
8857330f729Sjoerg                                 CGM.getContext().getSourceManager(),
8867330f729Sjoerg                                 CGM.getLangOpts(), RegionCounterMap.get());
8877330f729Sjoerg   MappingGen.emitCounterMapping(D, OS);
8887330f729Sjoerg   OS.flush();
8897330f729Sjoerg 
8907330f729Sjoerg   if (CoverageMapping.empty())
8917330f729Sjoerg     return;
8927330f729Sjoerg 
8937330f729Sjoerg   CGM.getCoverageMapping()->addFunctionMappingRecord(
8947330f729Sjoerg       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
8957330f729Sjoerg }
8967330f729Sjoerg 
8977330f729Sjoerg void
emitEmptyCounterMapping(const Decl * D,StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)8987330f729Sjoerg CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
8997330f729Sjoerg                                     llvm::GlobalValue::LinkageTypes Linkage) {
9007330f729Sjoerg   if (skipRegionMappingForDecl(D))
9017330f729Sjoerg     return;
9027330f729Sjoerg 
9037330f729Sjoerg   std::string CoverageMapping;
9047330f729Sjoerg   llvm::raw_string_ostream OS(CoverageMapping);
9057330f729Sjoerg   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
9067330f729Sjoerg                                 CGM.getContext().getSourceManager(),
9077330f729Sjoerg                                 CGM.getLangOpts());
9087330f729Sjoerg   MappingGen.emitEmptyMapping(D, OS);
9097330f729Sjoerg   OS.flush();
9107330f729Sjoerg 
9117330f729Sjoerg   if (CoverageMapping.empty())
9127330f729Sjoerg     return;
9137330f729Sjoerg 
9147330f729Sjoerg   setFuncName(Name, Linkage);
9157330f729Sjoerg   CGM.getCoverageMapping()->addFunctionMappingRecord(
9167330f729Sjoerg       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
9177330f729Sjoerg }
9187330f729Sjoerg 
computeRegionCounts(const Decl * D)9197330f729Sjoerg void CodeGenPGO::computeRegionCounts(const Decl *D) {
9207330f729Sjoerg   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
9217330f729Sjoerg   ComputeRegionCounts Walker(*StmtCountMap, *this);
9227330f729Sjoerg   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
9237330f729Sjoerg     Walker.VisitFunctionDecl(FD);
9247330f729Sjoerg   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
9257330f729Sjoerg     Walker.VisitObjCMethodDecl(MD);
9267330f729Sjoerg   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
9277330f729Sjoerg     Walker.VisitBlockDecl(BD);
9287330f729Sjoerg   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
9297330f729Sjoerg     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
9307330f729Sjoerg }
9317330f729Sjoerg 
9327330f729Sjoerg void
applyFunctionAttributes(llvm::IndexedInstrProfReader * PGOReader,llvm::Function * Fn)9337330f729Sjoerg CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
9347330f729Sjoerg                                     llvm::Function *Fn) {
9357330f729Sjoerg   if (!haveRegionCounts())
9367330f729Sjoerg     return;
9377330f729Sjoerg 
9387330f729Sjoerg   uint64_t FunctionCount = getRegionCount(nullptr);
9397330f729Sjoerg   Fn->setEntryCount(FunctionCount);
9407330f729Sjoerg }
9417330f729Sjoerg 
emitCounterIncrement(CGBuilderTy & Builder,const Stmt * S,llvm::Value * StepV)9427330f729Sjoerg void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
9437330f729Sjoerg                                       llvm::Value *StepV) {
9447330f729Sjoerg   if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
9457330f729Sjoerg     return;
9467330f729Sjoerg   if (!Builder.GetInsertBlock())
9477330f729Sjoerg     return;
9487330f729Sjoerg 
9497330f729Sjoerg   unsigned Counter = (*RegionCounterMap)[S];
9507330f729Sjoerg   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
9517330f729Sjoerg 
9527330f729Sjoerg   llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
9537330f729Sjoerg                          Builder.getInt64(FunctionHash),
9547330f729Sjoerg                          Builder.getInt32(NumRegionCounters),
9557330f729Sjoerg                          Builder.getInt32(Counter), StepV};
9567330f729Sjoerg   if (!StepV)
9577330f729Sjoerg     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
9587330f729Sjoerg                        makeArrayRef(Args, 4));
9597330f729Sjoerg   else
9607330f729Sjoerg     Builder.CreateCall(
9617330f729Sjoerg         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
9627330f729Sjoerg         makeArrayRef(Args));
9637330f729Sjoerg }
9647330f729Sjoerg 
setValueProfilingFlag(llvm::Module & M)965*e038c9c4Sjoerg void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
966*e038c9c4Sjoerg   if (CGM.getCodeGenOpts().hasProfileClangInstr())
967*e038c9c4Sjoerg     M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
968*e038c9c4Sjoerg                     uint32_t(EnableValueProfiling));
969*e038c9c4Sjoerg }
970*e038c9c4Sjoerg 
9717330f729Sjoerg // This method either inserts a call to the profile run-time during
9727330f729Sjoerg // instrumentation or puts profile data into metadata for PGO use.
valueProfile(CGBuilderTy & Builder,uint32_t ValueKind,llvm::Instruction * ValueSite,llvm::Value * ValuePtr)9737330f729Sjoerg void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
9747330f729Sjoerg     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
9757330f729Sjoerg 
9767330f729Sjoerg   if (!EnableValueProfiling)
9777330f729Sjoerg     return;
9787330f729Sjoerg 
9797330f729Sjoerg   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
9807330f729Sjoerg     return;
9817330f729Sjoerg 
9827330f729Sjoerg   if (isa<llvm::Constant>(ValuePtr))
9837330f729Sjoerg     return;
9847330f729Sjoerg 
9857330f729Sjoerg   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
9867330f729Sjoerg   if (InstrumentValueSites && RegionCounterMap) {
9877330f729Sjoerg     auto BuilderInsertPoint = Builder.saveIP();
9887330f729Sjoerg     Builder.SetInsertPoint(ValueSite);
9897330f729Sjoerg     llvm::Value *Args[5] = {
9907330f729Sjoerg         llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
9917330f729Sjoerg         Builder.getInt64(FunctionHash),
9927330f729Sjoerg         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
9937330f729Sjoerg         Builder.getInt32(ValueKind),
9947330f729Sjoerg         Builder.getInt32(NumValueSites[ValueKind]++)
9957330f729Sjoerg     };
9967330f729Sjoerg     Builder.CreateCall(
9977330f729Sjoerg         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
9987330f729Sjoerg     Builder.restoreIP(BuilderInsertPoint);
9997330f729Sjoerg     return;
10007330f729Sjoerg   }
10017330f729Sjoerg 
10027330f729Sjoerg   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
10037330f729Sjoerg   if (PGOReader && haveRegionCounts()) {
10047330f729Sjoerg     // We record the top most called three functions at each call site.
10057330f729Sjoerg     // Profile metadata contains "VP" string identifying this metadata
10067330f729Sjoerg     // as value profiling data, then a uint32_t value for the value profiling
10077330f729Sjoerg     // kind, a uint64_t value for the total number of times the call is
10087330f729Sjoerg     // executed, followed by the function hash and execution count (uint64_t)
10097330f729Sjoerg     // pairs for each function.
10107330f729Sjoerg     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
10117330f729Sjoerg       return;
10127330f729Sjoerg 
10137330f729Sjoerg     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
10147330f729Sjoerg                             (llvm::InstrProfValueKind)ValueKind,
10157330f729Sjoerg                             NumValueSites[ValueKind]);
10167330f729Sjoerg 
10177330f729Sjoerg     NumValueSites[ValueKind]++;
10187330f729Sjoerg   }
10197330f729Sjoerg }
10207330f729Sjoerg 
loadRegionCounts(llvm::IndexedInstrProfReader * PGOReader,bool IsInMainFile)10217330f729Sjoerg void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
10227330f729Sjoerg                                   bool IsInMainFile) {
10237330f729Sjoerg   CGM.getPGOStats().addVisited(IsInMainFile);
10247330f729Sjoerg   RegionCounts.clear();
10257330f729Sjoerg   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
10267330f729Sjoerg       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
10277330f729Sjoerg   if (auto E = RecordExpected.takeError()) {
10287330f729Sjoerg     auto IPE = llvm::InstrProfError::take(std::move(E));
10297330f729Sjoerg     if (IPE == llvm::instrprof_error::unknown_function)
10307330f729Sjoerg       CGM.getPGOStats().addMissing(IsInMainFile);
10317330f729Sjoerg     else if (IPE == llvm::instrprof_error::hash_mismatch)
10327330f729Sjoerg       CGM.getPGOStats().addMismatched(IsInMainFile);
10337330f729Sjoerg     else if (IPE == llvm::instrprof_error::malformed)
10347330f729Sjoerg       // TODO: Consider a more specific warning for this case.
10357330f729Sjoerg       CGM.getPGOStats().addMismatched(IsInMainFile);
10367330f729Sjoerg     return;
10377330f729Sjoerg   }
10387330f729Sjoerg   ProfRecord =
10397330f729Sjoerg       std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
10407330f729Sjoerg   RegionCounts = ProfRecord->Counts;
10417330f729Sjoerg }
10427330f729Sjoerg 
/// Calculate the divisor used to scale weights into 32 bits.
///
/// Returns 1 when MaxWeight already fits below UINT32_MAX. Otherwise it
/// returns a divisor such that MaxWeight / Scale is strictly less than
/// UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
10507330f729Sjoerg 
/// Scale a 64-bit branch weight down to 32 bits with \c Scale, then add 1.
///
/// Following Laplace's Rule of Succession, weights are based on count + 1,
/// so 1 is always added after scaling.
///
/// \pre \c Scale came from \a calculateWeightScale() applied to a weight at
/// least as large as \c Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Result = Weight / Scale + 1;
  assert(Result <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Result);
}
10667330f729Sjoerg 
createProfileWeights(uint64_t TrueCount,uint64_t FalseCount) const10677330f729Sjoerg llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1068*e038c9c4Sjoerg                                                     uint64_t FalseCount) const {
10697330f729Sjoerg   // Check for empty weights.
10707330f729Sjoerg   if (!TrueCount && !FalseCount)
10717330f729Sjoerg     return nullptr;
10727330f729Sjoerg 
10737330f729Sjoerg   // Calculate how to scale down to 32-bits.
10747330f729Sjoerg   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
10757330f729Sjoerg 
10767330f729Sjoerg   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
10777330f729Sjoerg   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
10787330f729Sjoerg                                       scaleBranchWeight(FalseCount, Scale));
10797330f729Sjoerg }
10807330f729Sjoerg 
10817330f729Sjoerg llvm::MDNode *
createProfileWeights(ArrayRef<uint64_t> Weights) const1082*e038c9c4Sjoerg CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
10837330f729Sjoerg   // We need at least two elements to create meaningful weights.
10847330f729Sjoerg   if (Weights.size() < 2)
10857330f729Sjoerg     return nullptr;
10867330f729Sjoerg 
10877330f729Sjoerg   // Check for empty weights.
10887330f729Sjoerg   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
10897330f729Sjoerg   if (MaxWeight == 0)
10907330f729Sjoerg     return nullptr;
10917330f729Sjoerg 
10927330f729Sjoerg   // Calculate how to scale down to 32-bits.
10937330f729Sjoerg   uint64_t Scale = calculateWeightScale(MaxWeight);
10947330f729Sjoerg 
10957330f729Sjoerg   SmallVector<uint32_t, 16> ScaledWeights;
10967330f729Sjoerg   ScaledWeights.reserve(Weights.size());
10977330f729Sjoerg   for (uint64_t W : Weights)
10987330f729Sjoerg     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
10997330f729Sjoerg 
11007330f729Sjoerg   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
11017330f729Sjoerg   return MDHelper.createBranchWeights(ScaledWeights);
11027330f729Sjoerg }
11037330f729Sjoerg 
1104*e038c9c4Sjoerg llvm::MDNode *
createProfileWeightsForLoop(const Stmt * Cond,uint64_t LoopCount) const1105*e038c9c4Sjoerg CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1106*e038c9c4Sjoerg                                              uint64_t LoopCount) const {
11077330f729Sjoerg   if (!PGO.haveRegionCounts())
11087330f729Sjoerg     return nullptr;
11097330f729Sjoerg   Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1110*e038c9c4Sjoerg   if (!CondCount || *CondCount == 0)
11117330f729Sjoerg     return nullptr;
11127330f729Sjoerg   return createProfileWeights(LoopCount,
11137330f729Sjoerg                               std::max(*CondCount, LoopCount) - LoopCount);
11147330f729Sjoerg }
1115