xref: /openbsd-src/gnu/llvm/clang/lib/CodeGen/CodeGenPGO.cpp (revision 12c855180aad702bbcca06e0398d774beeafb155)
1e5dd7070Spatrick //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2e5dd7070Spatrick //
3e5dd7070Spatrick // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e5dd7070Spatrick // See https://llvm.org/LICENSE.txt for license information.
5e5dd7070Spatrick // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e5dd7070Spatrick //
7e5dd7070Spatrick //===----------------------------------------------------------------------===//
8e5dd7070Spatrick //
9e5dd7070Spatrick // Instrumentation-based profile-guided optimization
10e5dd7070Spatrick //
11e5dd7070Spatrick //===----------------------------------------------------------------------===//
12e5dd7070Spatrick 
13e5dd7070Spatrick #include "CodeGenPGO.h"
14e5dd7070Spatrick #include "CodeGenFunction.h"
15e5dd7070Spatrick #include "CoverageMappingGen.h"
16e5dd7070Spatrick #include "clang/AST/RecursiveASTVisitor.h"
17e5dd7070Spatrick #include "clang/AST/StmtVisitor.h"
18e5dd7070Spatrick #include "llvm/IR/Intrinsics.h"
19e5dd7070Spatrick #include "llvm/IR/MDBuilder.h"
20e5dd7070Spatrick #include "llvm/Support/CommandLine.h"
21e5dd7070Spatrick #include "llvm/Support/Endian.h"
22e5dd7070Spatrick #include "llvm/Support/FileSystem.h"
23e5dd7070Spatrick #include "llvm/Support/MD5.h"
24*12c85518Srobert #include <optional>
25e5dd7070Spatrick 
26e5dd7070Spatrick static llvm::cl::opt<bool>
27*12c85518Srobert     EnableValueProfiling("enable-value-profiling",
28e5dd7070Spatrick                          llvm::cl::desc("Enable value profiling"),
29e5dd7070Spatrick                          llvm::cl::Hidden, llvm::cl::init(false));
30e5dd7070Spatrick 
31e5dd7070Spatrick using namespace clang;
32e5dd7070Spatrick using namespace CodeGen;
33e5dd7070Spatrick 
setFuncName(StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)34e5dd7070Spatrick void CodeGenPGO::setFuncName(StringRef Name,
35e5dd7070Spatrick                              llvm::GlobalValue::LinkageTypes Linkage) {
36e5dd7070Spatrick   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
37e5dd7070Spatrick   FuncName = llvm::getPGOFuncName(
38e5dd7070Spatrick       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
39e5dd7070Spatrick       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
40e5dd7070Spatrick 
41e5dd7070Spatrick   // If we're generating a profile, create a variable for the name.
42e5dd7070Spatrick   if (CGM.getCodeGenOpts().hasProfileClangInstr())
43e5dd7070Spatrick     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
44e5dd7070Spatrick }
45e5dd7070Spatrick 
setFuncName(llvm::Function * Fn)46e5dd7070Spatrick void CodeGenPGO::setFuncName(llvm::Function *Fn) {
47e5dd7070Spatrick   setFuncName(Fn->getName(), Fn->getLinkage());
48e5dd7070Spatrick   // Create PGOFuncName meta data.
49e5dd7070Spatrick   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
50e5dd7070Spatrick }
51e5dd7070Spatrick 
/// The version of the PGO hash algorithm.
///
/// Code in this file compares versions relationally (e.g.
/// `HashVersion >= PGO_HASH_V2`), so the values must stay in ascending order.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,
  PGO_HASH_V3 = 2,

  // Keep this set to the latest hash version.
  PGO_HASH_LATEST = PGO_HASH_V3
};
61e5dd7070Spatrick 
62e5dd7070Spatrick namespace {
63e5dd7070Spatrick /// Stable hasher for PGO region counters.
64e5dd7070Spatrick ///
65e5dd7070Spatrick /// PGOHash produces a stable hash of a given function's control flow.
66e5dd7070Spatrick ///
67e5dd7070Spatrick /// Changing the output of this hash will invalidate all previously generated
68e5dd7070Spatrick /// profiles -- i.e., don't do it.
69e5dd7070Spatrick ///
70e5dd7070Spatrick /// \note  When this hash does eventually change (years?), we still need to
71e5dd7070Spatrick /// support old hashes.  We'll need to pull in the version number from the
72e5dd7070Spatrick /// profile data format and use the matching hash function.
73e5dd7070Spatrick class PGOHash {
74e5dd7070Spatrick   uint64_t Working;
75e5dd7070Spatrick   unsigned Count;
76e5dd7070Spatrick   PGOHashVersion HashVersion;
77e5dd7070Spatrick   llvm::MD5 MD5;
78e5dd7070Spatrick 
79e5dd7070Spatrick   static const int NumBitsPerType = 6;
80e5dd7070Spatrick   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
81e5dd7070Spatrick   static const unsigned TooBig = 1u << NumBitsPerType;
82e5dd7070Spatrick 
83e5dd7070Spatrick public:
84e5dd7070Spatrick   /// Hash values for AST nodes.
85e5dd7070Spatrick   ///
86e5dd7070Spatrick   /// Distinct values for AST nodes that have region counters attached.
87e5dd7070Spatrick   ///
88e5dd7070Spatrick   /// These values must be stable.  All new members must be added at the end,
89e5dd7070Spatrick   /// and no members should be removed.  Changing the enumeration value for an
90e5dd7070Spatrick   /// AST node will affect the hash of every function that contains that node.
91e5dd7070Spatrick   enum HashType : unsigned char {
92e5dd7070Spatrick     None = 0,
93e5dd7070Spatrick     LabelStmt = 1,
94e5dd7070Spatrick     WhileStmt,
95e5dd7070Spatrick     DoStmt,
96e5dd7070Spatrick     ForStmt,
97e5dd7070Spatrick     CXXForRangeStmt,
98e5dd7070Spatrick     ObjCForCollectionStmt,
99e5dd7070Spatrick     SwitchStmt,
100e5dd7070Spatrick     CaseStmt,
101e5dd7070Spatrick     DefaultStmt,
102e5dd7070Spatrick     IfStmt,
103e5dd7070Spatrick     CXXTryStmt,
104e5dd7070Spatrick     CXXCatchStmt,
105e5dd7070Spatrick     ConditionalOperator,
106e5dd7070Spatrick     BinaryOperatorLAnd,
107e5dd7070Spatrick     BinaryOperatorLOr,
108e5dd7070Spatrick     BinaryConditionalOperator,
109e5dd7070Spatrick     // The preceding values are available with PGO_HASH_V1.
110e5dd7070Spatrick 
111e5dd7070Spatrick     EndOfScope,
112e5dd7070Spatrick     IfThenBranch,
113e5dd7070Spatrick     IfElseBranch,
114e5dd7070Spatrick     GotoStmt,
115e5dd7070Spatrick     IndirectGotoStmt,
116e5dd7070Spatrick     BreakStmt,
117e5dd7070Spatrick     ContinueStmt,
118e5dd7070Spatrick     ReturnStmt,
119e5dd7070Spatrick     ThrowExpr,
120e5dd7070Spatrick     UnaryOperatorLNot,
121e5dd7070Spatrick     BinaryOperatorLT,
122e5dd7070Spatrick     BinaryOperatorGT,
123e5dd7070Spatrick     BinaryOperatorLE,
124e5dd7070Spatrick     BinaryOperatorGE,
125e5dd7070Spatrick     BinaryOperatorEQ,
126e5dd7070Spatrick     BinaryOperatorNE,
127ec727ea7Spatrick     // The preceding values are available since PGO_HASH_V2.
128e5dd7070Spatrick 
129e5dd7070Spatrick     // Keep this last.  It's for the static assert that follows.
130e5dd7070Spatrick     LastHashType
131e5dd7070Spatrick   };
132e5dd7070Spatrick   static_assert(LastHashType <= TooBig, "Too many types in HashType");
133e5dd7070Spatrick 
PGOHash(PGOHashVersion HashVersion)134e5dd7070Spatrick   PGOHash(PGOHashVersion HashVersion)
135*12c85518Srobert       : Working(0), Count(0), HashVersion(HashVersion) {}
136e5dd7070Spatrick   void combine(HashType Type);
137e5dd7070Spatrick   uint64_t finalize();
getHashVersion() const138e5dd7070Spatrick   PGOHashVersion getHashVersion() const { return HashVersion; }
139e5dd7070Spatrick };
140e5dd7070Spatrick const int PGOHash::NumBitsPerType;
141e5dd7070Spatrick const unsigned PGOHash::NumTypesPerWord;
142e5dd7070Spatrick const unsigned PGOHash::TooBig;
143e5dd7070Spatrick 
144e5dd7070Spatrick /// Get the PGO hash version used in the given indexed profile.
getPGOHashVersion(llvm::IndexedInstrProfReader * PGOReader,CodeGenModule & CGM)145e5dd7070Spatrick static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
146e5dd7070Spatrick                                         CodeGenModule &CGM) {
147e5dd7070Spatrick   if (PGOReader->getVersion() <= 4)
148e5dd7070Spatrick     return PGO_HASH_V1;
149ec727ea7Spatrick   if (PGOReader->getVersion() <= 5)
150e5dd7070Spatrick     return PGO_HASH_V2;
151ec727ea7Spatrick   return PGO_HASH_V3;
152e5dd7070Spatrick }
153e5dd7070Spatrick 
154e5dd7070Spatrick /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
155e5dd7070Spatrick struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
156e5dd7070Spatrick   using Base = RecursiveASTVisitor<MapRegionCounters>;
157e5dd7070Spatrick 
158e5dd7070Spatrick   /// The next counter value to assign.
159e5dd7070Spatrick   unsigned NextCounter;
160e5dd7070Spatrick   /// The function hash.
161e5dd7070Spatrick   PGOHash Hash;
162e5dd7070Spatrick   /// The map of statements to counters.
163e5dd7070Spatrick   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
164a9ac8606Spatrick   /// The profile version.
165a9ac8606Spatrick   uint64_t ProfileVersion;
166e5dd7070Spatrick 
MapRegionCounters__anon393d18f00111::MapRegionCounters167a9ac8606Spatrick   MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
168e5dd7070Spatrick                     llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
169a9ac8606Spatrick       : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
170a9ac8606Spatrick         ProfileVersion(ProfileVersion) {}
171e5dd7070Spatrick 
172e5dd7070Spatrick   // Blocks and lambdas are handled as separate functions, so we need not
173e5dd7070Spatrick   // traverse them in the parent context.
TraverseBlockExpr__anon393d18f00111::MapRegionCounters174e5dd7070Spatrick   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
TraverseLambdaExpr__anon393d18f00111::MapRegionCounters175e5dd7070Spatrick   bool TraverseLambdaExpr(LambdaExpr *LE) {
176e5dd7070Spatrick     // Traverse the captures, but not the body.
177e5dd7070Spatrick     for (auto C : zip(LE->captures(), LE->capture_inits()))
178e5dd7070Spatrick       TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
179e5dd7070Spatrick     return true;
180e5dd7070Spatrick   }
TraverseCapturedStmt__anon393d18f00111::MapRegionCounters181e5dd7070Spatrick   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
182e5dd7070Spatrick 
VisitDecl__anon393d18f00111::MapRegionCounters183e5dd7070Spatrick   bool VisitDecl(const Decl *D) {
184e5dd7070Spatrick     switch (D->getKind()) {
185e5dd7070Spatrick     default:
186e5dd7070Spatrick       break;
187e5dd7070Spatrick     case Decl::Function:
188e5dd7070Spatrick     case Decl::CXXMethod:
189e5dd7070Spatrick     case Decl::CXXConstructor:
190e5dd7070Spatrick     case Decl::CXXDestructor:
191e5dd7070Spatrick     case Decl::CXXConversion:
192e5dd7070Spatrick     case Decl::ObjCMethod:
193e5dd7070Spatrick     case Decl::Block:
194e5dd7070Spatrick     case Decl::Captured:
195e5dd7070Spatrick       CounterMap[D->getBody()] = NextCounter++;
196e5dd7070Spatrick       break;
197e5dd7070Spatrick     }
198e5dd7070Spatrick     return true;
199e5dd7070Spatrick   }
200e5dd7070Spatrick 
201e5dd7070Spatrick   /// If \p S gets a fresh counter, update the counter mappings. Return the
202e5dd7070Spatrick   /// V1 hash of \p S.
updateCounterMappings__anon393d18f00111::MapRegionCounters203e5dd7070Spatrick   PGOHash::HashType updateCounterMappings(Stmt *S) {
204e5dd7070Spatrick     auto Type = getHashType(PGO_HASH_V1, S);
205e5dd7070Spatrick     if (Type != PGOHash::None)
206e5dd7070Spatrick       CounterMap[S] = NextCounter++;
207e5dd7070Spatrick     return Type;
208e5dd7070Spatrick   }
209e5dd7070Spatrick 
210a9ac8606Spatrick   /// The RHS of all logical operators gets a fresh counter in order to count
211a9ac8606Spatrick   /// how many times the RHS evaluates to true or false, depending on the
212a9ac8606Spatrick   /// semantics of the operator. This is only valid for ">= v7" of the profile
213a9ac8606Spatrick   /// version so that we facilitate backward compatibility.
VisitBinaryOperator__anon393d18f00111::MapRegionCounters214a9ac8606Spatrick   bool VisitBinaryOperator(BinaryOperator *S) {
215a9ac8606Spatrick     if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
216a9ac8606Spatrick       if (S->isLogicalOp() &&
217a9ac8606Spatrick           CodeGenFunction::isInstrumentedCondition(S->getRHS()))
218a9ac8606Spatrick         CounterMap[S->getRHS()] = NextCounter++;
219a9ac8606Spatrick     return Base::VisitBinaryOperator(S);
220a9ac8606Spatrick   }
221a9ac8606Spatrick 
222e5dd7070Spatrick   /// Include \p S in the function hash.
VisitStmt__anon393d18f00111::MapRegionCounters223e5dd7070Spatrick   bool VisitStmt(Stmt *S) {
224e5dd7070Spatrick     auto Type = updateCounterMappings(S);
225e5dd7070Spatrick     if (Hash.getHashVersion() != PGO_HASH_V1)
226e5dd7070Spatrick       Type = getHashType(Hash.getHashVersion(), S);
227e5dd7070Spatrick     if (Type != PGOHash::None)
228e5dd7070Spatrick       Hash.combine(Type);
229e5dd7070Spatrick     return true;
230e5dd7070Spatrick   }
231e5dd7070Spatrick 
TraverseIfStmt__anon393d18f00111::MapRegionCounters232e5dd7070Spatrick   bool TraverseIfStmt(IfStmt *If) {
233e5dd7070Spatrick     // If we used the V1 hash, use the default traversal.
234e5dd7070Spatrick     if (Hash.getHashVersion() == PGO_HASH_V1)
235e5dd7070Spatrick       return Base::TraverseIfStmt(If);
236e5dd7070Spatrick 
237e5dd7070Spatrick     // Otherwise, keep track of which branch we're in while traversing.
238e5dd7070Spatrick     VisitStmt(If);
239e5dd7070Spatrick     for (Stmt *CS : If->children()) {
240e5dd7070Spatrick       if (!CS)
241e5dd7070Spatrick         continue;
242e5dd7070Spatrick       if (CS == If->getThen())
243e5dd7070Spatrick         Hash.combine(PGOHash::IfThenBranch);
244e5dd7070Spatrick       else if (CS == If->getElse())
245e5dd7070Spatrick         Hash.combine(PGOHash::IfElseBranch);
246e5dd7070Spatrick       TraverseStmt(CS);
247e5dd7070Spatrick     }
248e5dd7070Spatrick     Hash.combine(PGOHash::EndOfScope);
249e5dd7070Spatrick     return true;
250e5dd7070Spatrick   }
251e5dd7070Spatrick 
252e5dd7070Spatrick // If the statement type \p N is nestable, and its nesting impacts profile
253e5dd7070Spatrick // stability, define a custom traversal which tracks the end of the statement
254e5dd7070Spatrick // in the hash (provided we're not using the V1 hash).
255e5dd7070Spatrick #define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
256e5dd7070Spatrick   bool Traverse##N(N *S) {                                                     \
257e5dd7070Spatrick     Base::Traverse##N(S);                                                      \
258e5dd7070Spatrick     if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
259e5dd7070Spatrick       Hash.combine(PGOHash::EndOfScope);                                       \
260e5dd7070Spatrick     return true;                                                               \
261e5dd7070Spatrick   }
262e5dd7070Spatrick 
263e5dd7070Spatrick   DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
DEFINE_NESTABLE_TRAVERSAL__anon393d18f00111::MapRegionCounters264e5dd7070Spatrick   DEFINE_NESTABLE_TRAVERSAL(DoStmt)
265e5dd7070Spatrick   DEFINE_NESTABLE_TRAVERSAL(ForStmt)
266e5dd7070Spatrick   DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
267e5dd7070Spatrick   DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
268e5dd7070Spatrick   DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
269e5dd7070Spatrick   DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
270e5dd7070Spatrick 
271e5dd7070Spatrick   /// Get version \p HashVersion of the PGO hash for \p S.
272e5dd7070Spatrick   PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
273e5dd7070Spatrick     switch (S->getStmtClass()) {
274e5dd7070Spatrick     default:
275e5dd7070Spatrick       break;
276e5dd7070Spatrick     case Stmt::LabelStmtClass:
277e5dd7070Spatrick       return PGOHash::LabelStmt;
278e5dd7070Spatrick     case Stmt::WhileStmtClass:
279e5dd7070Spatrick       return PGOHash::WhileStmt;
280e5dd7070Spatrick     case Stmt::DoStmtClass:
281e5dd7070Spatrick       return PGOHash::DoStmt;
282e5dd7070Spatrick     case Stmt::ForStmtClass:
283e5dd7070Spatrick       return PGOHash::ForStmt;
284e5dd7070Spatrick     case Stmt::CXXForRangeStmtClass:
285e5dd7070Spatrick       return PGOHash::CXXForRangeStmt;
286e5dd7070Spatrick     case Stmt::ObjCForCollectionStmtClass:
287e5dd7070Spatrick       return PGOHash::ObjCForCollectionStmt;
288e5dd7070Spatrick     case Stmt::SwitchStmtClass:
289e5dd7070Spatrick       return PGOHash::SwitchStmt;
290e5dd7070Spatrick     case Stmt::CaseStmtClass:
291e5dd7070Spatrick       return PGOHash::CaseStmt;
292e5dd7070Spatrick     case Stmt::DefaultStmtClass:
293e5dd7070Spatrick       return PGOHash::DefaultStmt;
294e5dd7070Spatrick     case Stmt::IfStmtClass:
295e5dd7070Spatrick       return PGOHash::IfStmt;
296e5dd7070Spatrick     case Stmt::CXXTryStmtClass:
297e5dd7070Spatrick       return PGOHash::CXXTryStmt;
298e5dd7070Spatrick     case Stmt::CXXCatchStmtClass:
299e5dd7070Spatrick       return PGOHash::CXXCatchStmt;
300e5dd7070Spatrick     case Stmt::ConditionalOperatorClass:
301e5dd7070Spatrick       return PGOHash::ConditionalOperator;
302e5dd7070Spatrick     case Stmt::BinaryConditionalOperatorClass:
303e5dd7070Spatrick       return PGOHash::BinaryConditionalOperator;
304e5dd7070Spatrick     case Stmt::BinaryOperatorClass: {
305e5dd7070Spatrick       const BinaryOperator *BO = cast<BinaryOperator>(S);
306e5dd7070Spatrick       if (BO->getOpcode() == BO_LAnd)
307e5dd7070Spatrick         return PGOHash::BinaryOperatorLAnd;
308e5dd7070Spatrick       if (BO->getOpcode() == BO_LOr)
309e5dd7070Spatrick         return PGOHash::BinaryOperatorLOr;
310ec727ea7Spatrick       if (HashVersion >= PGO_HASH_V2) {
311e5dd7070Spatrick         switch (BO->getOpcode()) {
312e5dd7070Spatrick         default:
313e5dd7070Spatrick           break;
314e5dd7070Spatrick         case BO_LT:
315e5dd7070Spatrick           return PGOHash::BinaryOperatorLT;
316e5dd7070Spatrick         case BO_GT:
317e5dd7070Spatrick           return PGOHash::BinaryOperatorGT;
318e5dd7070Spatrick         case BO_LE:
319e5dd7070Spatrick           return PGOHash::BinaryOperatorLE;
320e5dd7070Spatrick         case BO_GE:
321e5dd7070Spatrick           return PGOHash::BinaryOperatorGE;
322e5dd7070Spatrick         case BO_EQ:
323e5dd7070Spatrick           return PGOHash::BinaryOperatorEQ;
324e5dd7070Spatrick         case BO_NE:
325e5dd7070Spatrick           return PGOHash::BinaryOperatorNE;
326e5dd7070Spatrick         }
327e5dd7070Spatrick       }
328e5dd7070Spatrick       break;
329e5dd7070Spatrick     }
330e5dd7070Spatrick     }
331e5dd7070Spatrick 
332ec727ea7Spatrick     if (HashVersion >= PGO_HASH_V2) {
333e5dd7070Spatrick       switch (S->getStmtClass()) {
334e5dd7070Spatrick       default:
335e5dd7070Spatrick         break;
336e5dd7070Spatrick       case Stmt::GotoStmtClass:
337e5dd7070Spatrick         return PGOHash::GotoStmt;
338e5dd7070Spatrick       case Stmt::IndirectGotoStmtClass:
339e5dd7070Spatrick         return PGOHash::IndirectGotoStmt;
340e5dd7070Spatrick       case Stmt::BreakStmtClass:
341e5dd7070Spatrick         return PGOHash::BreakStmt;
342e5dd7070Spatrick       case Stmt::ContinueStmtClass:
343e5dd7070Spatrick         return PGOHash::ContinueStmt;
344e5dd7070Spatrick       case Stmt::ReturnStmtClass:
345e5dd7070Spatrick         return PGOHash::ReturnStmt;
346e5dd7070Spatrick       case Stmt::CXXThrowExprClass:
347e5dd7070Spatrick         return PGOHash::ThrowExpr;
348e5dd7070Spatrick       case Stmt::UnaryOperatorClass: {
349e5dd7070Spatrick         const UnaryOperator *UO = cast<UnaryOperator>(S);
350e5dd7070Spatrick         if (UO->getOpcode() == UO_LNot)
351e5dd7070Spatrick           return PGOHash::UnaryOperatorLNot;
352e5dd7070Spatrick         break;
353e5dd7070Spatrick       }
354e5dd7070Spatrick       }
355e5dd7070Spatrick     }
356e5dd7070Spatrick 
357e5dd7070Spatrick     return PGOHash::None;
358e5dd7070Spatrick   }
359e5dd7070Spatrick };
360e5dd7070Spatrick 
361e5dd7070Spatrick /// A StmtVisitor that propagates the raw counts through the AST and
362e5dd7070Spatrick /// records the count at statements where the value may change.
363e5dd7070Spatrick struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
364e5dd7070Spatrick   /// PGO state.
365e5dd7070Spatrick   CodeGenPGO &PGO;
366e5dd7070Spatrick 
367e5dd7070Spatrick   /// A flag that is set when the current count should be recorded on the
368e5dd7070Spatrick   /// next statement, such as at the exit of a loop.
369e5dd7070Spatrick   bool RecordNextStmtCount;
370e5dd7070Spatrick 
371e5dd7070Spatrick   /// The count at the current location in the traversal.
372e5dd7070Spatrick   uint64_t CurrentCount;
373e5dd7070Spatrick 
374e5dd7070Spatrick   /// The map of statements to count values.
375e5dd7070Spatrick   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
376e5dd7070Spatrick 
377e5dd7070Spatrick   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
378e5dd7070Spatrick   struct BreakContinue {
379e5dd7070Spatrick     uint64_t BreakCount;
380e5dd7070Spatrick     uint64_t ContinueCount;
BreakContinue__anon393d18f00111::ComputeRegionCounts::BreakContinue381e5dd7070Spatrick     BreakContinue() : BreakCount(0), ContinueCount(0) {}
382e5dd7070Spatrick   };
383e5dd7070Spatrick   SmallVector<BreakContinue, 8> BreakContinueStack;
384e5dd7070Spatrick 
ComputeRegionCounts__anon393d18f00111::ComputeRegionCounts385e5dd7070Spatrick   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
386e5dd7070Spatrick                       CodeGenPGO &PGO)
387e5dd7070Spatrick       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
388e5dd7070Spatrick 
RecordStmtCount__anon393d18f00111::ComputeRegionCounts389e5dd7070Spatrick   void RecordStmtCount(const Stmt *S) {
390e5dd7070Spatrick     if (RecordNextStmtCount) {
391e5dd7070Spatrick       CountMap[S] = CurrentCount;
392e5dd7070Spatrick       RecordNextStmtCount = false;
393e5dd7070Spatrick     }
394e5dd7070Spatrick   }
395e5dd7070Spatrick 
396e5dd7070Spatrick   /// Set and return the current count.
setCount__anon393d18f00111::ComputeRegionCounts397e5dd7070Spatrick   uint64_t setCount(uint64_t Count) {
398e5dd7070Spatrick     CurrentCount = Count;
399e5dd7070Spatrick     return Count;
400e5dd7070Spatrick   }
401e5dd7070Spatrick 
VisitStmt__anon393d18f00111::ComputeRegionCounts402e5dd7070Spatrick   void VisitStmt(const Stmt *S) {
403e5dd7070Spatrick     RecordStmtCount(S);
404e5dd7070Spatrick     for (const Stmt *Child : S->children())
405e5dd7070Spatrick       if (Child)
406e5dd7070Spatrick         this->Visit(Child);
407e5dd7070Spatrick   }
408e5dd7070Spatrick 
VisitFunctionDecl__anon393d18f00111::ComputeRegionCounts409e5dd7070Spatrick   void VisitFunctionDecl(const FunctionDecl *D) {
410e5dd7070Spatrick     // Counter tracks entry to the function body.
411e5dd7070Spatrick     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
412e5dd7070Spatrick     CountMap[D->getBody()] = BodyCount;
413e5dd7070Spatrick     Visit(D->getBody());
414e5dd7070Spatrick   }
415e5dd7070Spatrick 
416e5dd7070Spatrick   // Skip lambda expressions. We visit these as FunctionDecls when we're
417e5dd7070Spatrick   // generating them and aren't interested in the body when generating a
418e5dd7070Spatrick   // parent context.
VisitLambdaExpr__anon393d18f00111::ComputeRegionCounts419e5dd7070Spatrick   void VisitLambdaExpr(const LambdaExpr *LE) {}
420e5dd7070Spatrick 
VisitCapturedDecl__anon393d18f00111::ComputeRegionCounts421e5dd7070Spatrick   void VisitCapturedDecl(const CapturedDecl *D) {
422e5dd7070Spatrick     // Counter tracks entry to the capture body.
423e5dd7070Spatrick     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
424e5dd7070Spatrick     CountMap[D->getBody()] = BodyCount;
425e5dd7070Spatrick     Visit(D->getBody());
426e5dd7070Spatrick   }
427e5dd7070Spatrick 
VisitObjCMethodDecl__anon393d18f00111::ComputeRegionCounts428e5dd7070Spatrick   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
429e5dd7070Spatrick     // Counter tracks entry to the method body.
430e5dd7070Spatrick     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
431e5dd7070Spatrick     CountMap[D->getBody()] = BodyCount;
432e5dd7070Spatrick     Visit(D->getBody());
433e5dd7070Spatrick   }
434e5dd7070Spatrick 
VisitBlockDecl__anon393d18f00111::ComputeRegionCounts435e5dd7070Spatrick   void VisitBlockDecl(const BlockDecl *D) {
436e5dd7070Spatrick     // Counter tracks entry to the block body.
437e5dd7070Spatrick     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
438e5dd7070Spatrick     CountMap[D->getBody()] = BodyCount;
439e5dd7070Spatrick     Visit(D->getBody());
440e5dd7070Spatrick   }
441e5dd7070Spatrick 
VisitReturnStmt__anon393d18f00111::ComputeRegionCounts442e5dd7070Spatrick   void VisitReturnStmt(const ReturnStmt *S) {
443e5dd7070Spatrick     RecordStmtCount(S);
444e5dd7070Spatrick     if (S->getRetValue())
445e5dd7070Spatrick       Visit(S->getRetValue());
446e5dd7070Spatrick     CurrentCount = 0;
447e5dd7070Spatrick     RecordNextStmtCount = true;
448e5dd7070Spatrick   }
449e5dd7070Spatrick 
VisitCXXThrowExpr__anon393d18f00111::ComputeRegionCounts450e5dd7070Spatrick   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
451e5dd7070Spatrick     RecordStmtCount(E);
452e5dd7070Spatrick     if (E->getSubExpr())
453e5dd7070Spatrick       Visit(E->getSubExpr());
454e5dd7070Spatrick     CurrentCount = 0;
455e5dd7070Spatrick     RecordNextStmtCount = true;
456e5dd7070Spatrick   }
457e5dd7070Spatrick 
VisitGotoStmt__anon393d18f00111::ComputeRegionCounts458e5dd7070Spatrick   void VisitGotoStmt(const GotoStmt *S) {
459e5dd7070Spatrick     RecordStmtCount(S);
460e5dd7070Spatrick     CurrentCount = 0;
461e5dd7070Spatrick     RecordNextStmtCount = true;
462e5dd7070Spatrick   }
463e5dd7070Spatrick 
VisitLabelStmt__anon393d18f00111::ComputeRegionCounts464e5dd7070Spatrick   void VisitLabelStmt(const LabelStmt *S) {
465e5dd7070Spatrick     RecordNextStmtCount = false;
466e5dd7070Spatrick     // Counter tracks the block following the label.
467e5dd7070Spatrick     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
468e5dd7070Spatrick     CountMap[S] = BlockCount;
469e5dd7070Spatrick     Visit(S->getSubStmt());
470e5dd7070Spatrick   }
471e5dd7070Spatrick 
VisitBreakStmt__anon393d18f00111::ComputeRegionCounts472e5dd7070Spatrick   void VisitBreakStmt(const BreakStmt *S) {
473e5dd7070Spatrick     RecordStmtCount(S);
474e5dd7070Spatrick     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
475e5dd7070Spatrick     BreakContinueStack.back().BreakCount += CurrentCount;
476e5dd7070Spatrick     CurrentCount = 0;
477e5dd7070Spatrick     RecordNextStmtCount = true;
478e5dd7070Spatrick   }
479e5dd7070Spatrick 
VisitContinueStmt__anon393d18f00111::ComputeRegionCounts480e5dd7070Spatrick   void VisitContinueStmt(const ContinueStmt *S) {
481e5dd7070Spatrick     RecordStmtCount(S);
482e5dd7070Spatrick     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
483e5dd7070Spatrick     BreakContinueStack.back().ContinueCount += CurrentCount;
484e5dd7070Spatrick     CurrentCount = 0;
485e5dd7070Spatrick     RecordNextStmtCount = true;
486e5dd7070Spatrick   }
487e5dd7070Spatrick 
VisitWhileStmt__anon393d18f00111::ComputeRegionCounts488e5dd7070Spatrick   void VisitWhileStmt(const WhileStmt *S) {
489e5dd7070Spatrick     RecordStmtCount(S);
490e5dd7070Spatrick     uint64_t ParentCount = CurrentCount;
491e5dd7070Spatrick 
492e5dd7070Spatrick     BreakContinueStack.push_back(BreakContinue());
493e5dd7070Spatrick     // Visit the body region first so the break/continue adjustments can be
494e5dd7070Spatrick     // included when visiting the condition.
495e5dd7070Spatrick     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
496e5dd7070Spatrick     CountMap[S->getBody()] = CurrentCount;
497e5dd7070Spatrick     Visit(S->getBody());
498e5dd7070Spatrick     uint64_t BackedgeCount = CurrentCount;
499e5dd7070Spatrick 
500e5dd7070Spatrick     // ...then go back and propagate counts through the condition. The count
501e5dd7070Spatrick     // at the start of the condition is the sum of the incoming edges,
502e5dd7070Spatrick     // the backedge from the end of the loop body, and the edges from
503e5dd7070Spatrick     // continue statements.
504e5dd7070Spatrick     BreakContinue BC = BreakContinueStack.pop_back_val();
505e5dd7070Spatrick     uint64_t CondCount =
506e5dd7070Spatrick         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
507e5dd7070Spatrick     CountMap[S->getCond()] = CondCount;
508e5dd7070Spatrick     Visit(S->getCond());
509e5dd7070Spatrick     setCount(BC.BreakCount + CondCount - BodyCount);
510e5dd7070Spatrick     RecordNextStmtCount = true;
511e5dd7070Spatrick   }
512e5dd7070Spatrick 
VisitDoStmt__anon393d18f00111::ComputeRegionCounts513e5dd7070Spatrick   void VisitDoStmt(const DoStmt *S) {
514e5dd7070Spatrick     RecordStmtCount(S);
515e5dd7070Spatrick     uint64_t LoopCount = PGO.getRegionCount(S);
516e5dd7070Spatrick 
517e5dd7070Spatrick     BreakContinueStack.push_back(BreakContinue());
518e5dd7070Spatrick     // The count doesn't include the fallthrough from the parent scope. Add it.
519e5dd7070Spatrick     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
520e5dd7070Spatrick     CountMap[S->getBody()] = BodyCount;
521e5dd7070Spatrick     Visit(S->getBody());
522e5dd7070Spatrick     uint64_t BackedgeCount = CurrentCount;
523e5dd7070Spatrick 
524e5dd7070Spatrick     BreakContinue BC = BreakContinueStack.pop_back_val();
525e5dd7070Spatrick     // The count at the start of the condition is equal to the count at the
526e5dd7070Spatrick     // end of the body, plus any continues.
527e5dd7070Spatrick     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
528e5dd7070Spatrick     CountMap[S->getCond()] = CondCount;
529e5dd7070Spatrick     Visit(S->getCond());
530e5dd7070Spatrick     setCount(BC.BreakCount + CondCount - LoopCount);
531e5dd7070Spatrick     RecordNextStmtCount = true;
532e5dd7070Spatrick   }
533e5dd7070Spatrick 
VisitForStmt__anon393d18f00111::ComputeRegionCounts534e5dd7070Spatrick   void VisitForStmt(const ForStmt *S) {
535e5dd7070Spatrick     RecordStmtCount(S);
536e5dd7070Spatrick     if (S->getInit())
537e5dd7070Spatrick       Visit(S->getInit());
538e5dd7070Spatrick 
539e5dd7070Spatrick     uint64_t ParentCount = CurrentCount;
540e5dd7070Spatrick 
541e5dd7070Spatrick     BreakContinueStack.push_back(BreakContinue());
542e5dd7070Spatrick     // Visit the body region first. (This is basically the same as a while
543e5dd7070Spatrick     // loop; see further comments in VisitWhileStmt.)
544e5dd7070Spatrick     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
545e5dd7070Spatrick     CountMap[S->getBody()] = BodyCount;
546e5dd7070Spatrick     Visit(S->getBody());
547e5dd7070Spatrick     uint64_t BackedgeCount = CurrentCount;
548e5dd7070Spatrick     BreakContinue BC = BreakContinueStack.pop_back_val();
549e5dd7070Spatrick 
550e5dd7070Spatrick     // The increment is essentially part of the body but it needs to include
551e5dd7070Spatrick     // the count for all the continue statements.
552e5dd7070Spatrick     if (S->getInc()) {
553e5dd7070Spatrick       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
554e5dd7070Spatrick       CountMap[S->getInc()] = IncCount;
555e5dd7070Spatrick       Visit(S->getInc());
556e5dd7070Spatrick     }
557e5dd7070Spatrick 
558e5dd7070Spatrick     // ...then go back and propagate counts through the condition.
559e5dd7070Spatrick     uint64_t CondCount =
560e5dd7070Spatrick         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
561e5dd7070Spatrick     if (S->getCond()) {
562e5dd7070Spatrick       CountMap[S->getCond()] = CondCount;
563e5dd7070Spatrick       Visit(S->getCond());
564e5dd7070Spatrick     }
565e5dd7070Spatrick     setCount(BC.BreakCount + CondCount - BodyCount);
566e5dd7070Spatrick     RecordNextStmtCount = true;
567e5dd7070Spatrick   }
568e5dd7070Spatrick 
VisitCXXForRangeStmt__anon393d18f00111::ComputeRegionCounts569e5dd7070Spatrick   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
570e5dd7070Spatrick     RecordStmtCount(S);
571e5dd7070Spatrick     if (S->getInit())
572e5dd7070Spatrick       Visit(S->getInit());
573e5dd7070Spatrick     Visit(S->getLoopVarStmt());
574e5dd7070Spatrick     Visit(S->getRangeStmt());
575e5dd7070Spatrick     Visit(S->getBeginStmt());
576e5dd7070Spatrick     Visit(S->getEndStmt());
577e5dd7070Spatrick 
578e5dd7070Spatrick     uint64_t ParentCount = CurrentCount;
579e5dd7070Spatrick     BreakContinueStack.push_back(BreakContinue());
580e5dd7070Spatrick     // Visit the body region first. (This is basically the same as a while
581e5dd7070Spatrick     // loop; see further comments in VisitWhileStmt.)
582e5dd7070Spatrick     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
583e5dd7070Spatrick     CountMap[S->getBody()] = BodyCount;
584e5dd7070Spatrick     Visit(S->getBody());
585e5dd7070Spatrick     uint64_t BackedgeCount = CurrentCount;
586e5dd7070Spatrick     BreakContinue BC = BreakContinueStack.pop_back_val();
587e5dd7070Spatrick 
588e5dd7070Spatrick     // The increment is essentially part of the body but it needs to include
589e5dd7070Spatrick     // the count for all the continue statements.
590e5dd7070Spatrick     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
591e5dd7070Spatrick     CountMap[S->getInc()] = IncCount;
592e5dd7070Spatrick     Visit(S->getInc());
593e5dd7070Spatrick 
594e5dd7070Spatrick     // ...then go back and propagate counts through the condition.
595e5dd7070Spatrick     uint64_t CondCount =
596e5dd7070Spatrick         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
597e5dd7070Spatrick     CountMap[S->getCond()] = CondCount;
598e5dd7070Spatrick     Visit(S->getCond());
599e5dd7070Spatrick     setCount(BC.BreakCount + CondCount - BodyCount);
600e5dd7070Spatrick     RecordNextStmtCount = true;
601e5dd7070Spatrick   }
602e5dd7070Spatrick 
VisitObjCForCollectionStmt__anon393d18f00111::ComputeRegionCounts603e5dd7070Spatrick   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
604e5dd7070Spatrick     RecordStmtCount(S);
605e5dd7070Spatrick     Visit(S->getElement());
606e5dd7070Spatrick     uint64_t ParentCount = CurrentCount;
607e5dd7070Spatrick     BreakContinueStack.push_back(BreakContinue());
608e5dd7070Spatrick     // Counter tracks the body of the loop.
609e5dd7070Spatrick     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
610e5dd7070Spatrick     CountMap[S->getBody()] = BodyCount;
611e5dd7070Spatrick     Visit(S->getBody());
612e5dd7070Spatrick     uint64_t BackedgeCount = CurrentCount;
613e5dd7070Spatrick     BreakContinue BC = BreakContinueStack.pop_back_val();
614e5dd7070Spatrick 
615e5dd7070Spatrick     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
616e5dd7070Spatrick              BodyCount);
617e5dd7070Spatrick     RecordNextStmtCount = true;
618e5dd7070Spatrick   }
619e5dd7070Spatrick 
VisitSwitchStmt__anon393d18f00111::ComputeRegionCounts620e5dd7070Spatrick   void VisitSwitchStmt(const SwitchStmt *S) {
621e5dd7070Spatrick     RecordStmtCount(S);
622e5dd7070Spatrick     if (S->getInit())
623e5dd7070Spatrick       Visit(S->getInit());
624e5dd7070Spatrick     Visit(S->getCond());
625e5dd7070Spatrick     CurrentCount = 0;
626e5dd7070Spatrick     BreakContinueStack.push_back(BreakContinue());
627e5dd7070Spatrick     Visit(S->getBody());
628e5dd7070Spatrick     // If the switch is inside a loop, add the continue counts.
629e5dd7070Spatrick     BreakContinue BC = BreakContinueStack.pop_back_val();
630e5dd7070Spatrick     if (!BreakContinueStack.empty())
631e5dd7070Spatrick       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
632e5dd7070Spatrick     // Counter tracks the exit block of the switch.
633e5dd7070Spatrick     setCount(PGO.getRegionCount(S));
634e5dd7070Spatrick     RecordNextStmtCount = true;
635e5dd7070Spatrick   }
636e5dd7070Spatrick 
VisitSwitchCase__anon393d18f00111::ComputeRegionCounts637e5dd7070Spatrick   void VisitSwitchCase(const SwitchCase *S) {
638e5dd7070Spatrick     RecordNextStmtCount = false;
639e5dd7070Spatrick     // Counter for this particular case. This counts only jumps from the
640e5dd7070Spatrick     // switch header and does not include fallthrough from the case before
641e5dd7070Spatrick     // this one.
642e5dd7070Spatrick     uint64_t CaseCount = PGO.getRegionCount(S);
643e5dd7070Spatrick     setCount(CurrentCount + CaseCount);
644e5dd7070Spatrick     // We need the count without fallthrough in the mapping, so it's more useful
645e5dd7070Spatrick     // for branch probabilities.
646e5dd7070Spatrick     CountMap[S] = CaseCount;
647e5dd7070Spatrick     RecordNextStmtCount = true;
648e5dd7070Spatrick     Visit(S->getSubStmt());
649e5dd7070Spatrick   }
650e5dd7070Spatrick 
VisitIfStmt__anon393d18f00111::ComputeRegionCounts651e5dd7070Spatrick   void VisitIfStmt(const IfStmt *S) {
652e5dd7070Spatrick     RecordStmtCount(S);
653*12c85518Srobert 
654*12c85518Srobert     if (S->isConsteval()) {
655*12c85518Srobert       const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
656*12c85518Srobert       if (Stm)
657*12c85518Srobert         Visit(Stm);
658*12c85518Srobert       return;
659*12c85518Srobert     }
660*12c85518Srobert 
661e5dd7070Spatrick     uint64_t ParentCount = CurrentCount;
662e5dd7070Spatrick     if (S->getInit())
663e5dd7070Spatrick       Visit(S->getInit());
664e5dd7070Spatrick     Visit(S->getCond());
665e5dd7070Spatrick 
666e5dd7070Spatrick     // Counter tracks the "then" part of an if statement. The count for
667e5dd7070Spatrick     // the "else" part, if it exists, will be calculated from this counter.
668e5dd7070Spatrick     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
669e5dd7070Spatrick     CountMap[S->getThen()] = ThenCount;
670e5dd7070Spatrick     Visit(S->getThen());
671e5dd7070Spatrick     uint64_t OutCount = CurrentCount;
672e5dd7070Spatrick 
673e5dd7070Spatrick     uint64_t ElseCount = ParentCount - ThenCount;
674e5dd7070Spatrick     if (S->getElse()) {
675e5dd7070Spatrick       setCount(ElseCount);
676e5dd7070Spatrick       CountMap[S->getElse()] = ElseCount;
677e5dd7070Spatrick       Visit(S->getElse());
678e5dd7070Spatrick       OutCount += CurrentCount;
679e5dd7070Spatrick     } else
680e5dd7070Spatrick       OutCount += ElseCount;
681e5dd7070Spatrick     setCount(OutCount);
682e5dd7070Spatrick     RecordNextStmtCount = true;
683e5dd7070Spatrick   }
684e5dd7070Spatrick 
VisitCXXTryStmt__anon393d18f00111::ComputeRegionCounts685e5dd7070Spatrick   void VisitCXXTryStmt(const CXXTryStmt *S) {
686e5dd7070Spatrick     RecordStmtCount(S);
687e5dd7070Spatrick     Visit(S->getTryBlock());
688e5dd7070Spatrick     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
689e5dd7070Spatrick       Visit(S->getHandler(I));
690e5dd7070Spatrick     // Counter tracks the continuation block of the try statement.
691e5dd7070Spatrick     setCount(PGO.getRegionCount(S));
692e5dd7070Spatrick     RecordNextStmtCount = true;
693e5dd7070Spatrick   }
694e5dd7070Spatrick 
VisitCXXCatchStmt__anon393d18f00111::ComputeRegionCounts695e5dd7070Spatrick   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
696e5dd7070Spatrick     RecordNextStmtCount = false;
697e5dd7070Spatrick     // Counter tracks the catch statement's handler block.
698e5dd7070Spatrick     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
699e5dd7070Spatrick     CountMap[S] = CatchCount;
700e5dd7070Spatrick     Visit(S->getHandlerBlock());
701e5dd7070Spatrick   }
702e5dd7070Spatrick 
VisitAbstractConditionalOperator__anon393d18f00111::ComputeRegionCounts703e5dd7070Spatrick   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
704e5dd7070Spatrick     RecordStmtCount(E);
705e5dd7070Spatrick     uint64_t ParentCount = CurrentCount;
706e5dd7070Spatrick     Visit(E->getCond());
707e5dd7070Spatrick 
708e5dd7070Spatrick     // Counter tracks the "true" part of a conditional operator. The
709e5dd7070Spatrick     // count in the "false" part will be calculated from this counter.
710e5dd7070Spatrick     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
711e5dd7070Spatrick     CountMap[E->getTrueExpr()] = TrueCount;
712e5dd7070Spatrick     Visit(E->getTrueExpr());
713e5dd7070Spatrick     uint64_t OutCount = CurrentCount;
714e5dd7070Spatrick 
715e5dd7070Spatrick     uint64_t FalseCount = setCount(ParentCount - TrueCount);
716e5dd7070Spatrick     CountMap[E->getFalseExpr()] = FalseCount;
717e5dd7070Spatrick     Visit(E->getFalseExpr());
718e5dd7070Spatrick     OutCount += CurrentCount;
719e5dd7070Spatrick 
720e5dd7070Spatrick     setCount(OutCount);
721e5dd7070Spatrick     RecordNextStmtCount = true;
722e5dd7070Spatrick   }
723e5dd7070Spatrick 
VisitBinLAnd__anon393d18f00111::ComputeRegionCounts724e5dd7070Spatrick   void VisitBinLAnd(const BinaryOperator *E) {
725e5dd7070Spatrick     RecordStmtCount(E);
726e5dd7070Spatrick     uint64_t ParentCount = CurrentCount;
727e5dd7070Spatrick     Visit(E->getLHS());
728e5dd7070Spatrick     // Counter tracks the right hand side of a logical and operator.
729e5dd7070Spatrick     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
730e5dd7070Spatrick     CountMap[E->getRHS()] = RHSCount;
731e5dd7070Spatrick     Visit(E->getRHS());
732e5dd7070Spatrick     setCount(ParentCount + RHSCount - CurrentCount);
733e5dd7070Spatrick     RecordNextStmtCount = true;
734e5dd7070Spatrick   }
735e5dd7070Spatrick 
VisitBinLOr__anon393d18f00111::ComputeRegionCounts736e5dd7070Spatrick   void VisitBinLOr(const BinaryOperator *E) {
737e5dd7070Spatrick     RecordStmtCount(E);
738e5dd7070Spatrick     uint64_t ParentCount = CurrentCount;
739e5dd7070Spatrick     Visit(E->getLHS());
740e5dd7070Spatrick     // Counter tracks the right hand side of a logical or operator.
741e5dd7070Spatrick     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
742e5dd7070Spatrick     CountMap[E->getRHS()] = RHSCount;
743e5dd7070Spatrick     Visit(E->getRHS());
744e5dd7070Spatrick     setCount(ParentCount + RHSCount - CurrentCount);
745e5dd7070Spatrick     RecordNextStmtCount = true;
746e5dd7070Spatrick   }
747e5dd7070Spatrick };
748e5dd7070Spatrick } // end anonymous namespace
749e5dd7070Spatrick 
combine(HashType Type)750e5dd7070Spatrick void PGOHash::combine(HashType Type) {
751e5dd7070Spatrick   // Check that we never combine 0 and only have six bits.
752e5dd7070Spatrick   assert(Type && "Hash is invalid: unexpected type 0");
753e5dd7070Spatrick   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
754e5dd7070Spatrick 
755e5dd7070Spatrick   // Pass through MD5 if enough work has built up.
756e5dd7070Spatrick   if (Count && Count % NumTypesPerWord == 0) {
757e5dd7070Spatrick     using namespace llvm::support;
758e5dd7070Spatrick     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
759*12c85518Srobert     MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
760e5dd7070Spatrick     Working = 0;
761e5dd7070Spatrick   }
762e5dd7070Spatrick 
763e5dd7070Spatrick   // Accumulate the current type.
764e5dd7070Spatrick   ++Count;
765e5dd7070Spatrick   Working = Working << NumBitsPerType | Type;
766e5dd7070Spatrick }
767e5dd7070Spatrick 
finalize()768e5dd7070Spatrick uint64_t PGOHash::finalize() {
769e5dd7070Spatrick   // Use Working as the hash directly if we never used MD5.
770e5dd7070Spatrick   if (Count <= NumTypesPerWord)
771e5dd7070Spatrick     // No need to byte swap here, since none of the math was endian-dependent.
772e5dd7070Spatrick     // This number will be byte-swapped as required on endianness transitions,
773e5dd7070Spatrick     // so we will see the same value on the other side.
774e5dd7070Spatrick     return Working;
775e5dd7070Spatrick 
776e5dd7070Spatrick   // Check for remaining work in Working.
777ec727ea7Spatrick   if (Working) {
778ec727ea7Spatrick     // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
779ec727ea7Spatrick     // is buggy because it converts a uint64_t into an array of uint8_t.
780ec727ea7Spatrick     if (HashVersion < PGO_HASH_V3) {
781ec727ea7Spatrick       MD5.update({(uint8_t)Working});
782ec727ea7Spatrick     } else {
783ec727ea7Spatrick       using namespace llvm::support;
784ec727ea7Spatrick       uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
785*12c85518Srobert       MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
786ec727ea7Spatrick     }
787ec727ea7Spatrick   }
788e5dd7070Spatrick 
789e5dd7070Spatrick   // Finalize the MD5 and return the hash.
790e5dd7070Spatrick   llvm::MD5::MD5Result Result;
791e5dd7070Spatrick   MD5.final(Result);
792e5dd7070Spatrick   return Result.low();
793e5dd7070Spatrick }
794e5dd7070Spatrick 
assignRegionCounters(GlobalDecl GD,llvm::Function * Fn)795e5dd7070Spatrick void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
796e5dd7070Spatrick   const Decl *D = GD.getDecl();
797e5dd7070Spatrick   if (!D->hasBody())
798e5dd7070Spatrick     return;
799e5dd7070Spatrick 
800a9ac8606Spatrick   // Skip CUDA/HIP kernel launch stub functions.
801a9ac8606Spatrick   if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
802a9ac8606Spatrick       D->hasAttr<CUDAGlobalAttr>())
803a9ac8606Spatrick     return;
804a9ac8606Spatrick 
805e5dd7070Spatrick   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
806e5dd7070Spatrick   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
807e5dd7070Spatrick   if (!InstrumentRegions && !PGOReader)
808e5dd7070Spatrick     return;
809e5dd7070Spatrick   if (D->isImplicit())
810e5dd7070Spatrick     return;
811e5dd7070Spatrick   // Constructors and destructors may be represented by several functions in IR.
812e5dd7070Spatrick   // If so, instrument only base variant, others are implemented by delegation
813e5dd7070Spatrick   // to the base one, it would be counted twice otherwise.
814e5dd7070Spatrick   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
815e5dd7070Spatrick     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
816e5dd7070Spatrick       if (GD.getCtorType() != Ctor_Base &&
817e5dd7070Spatrick           CodeGenFunction::IsConstructorDelegationValid(CCD))
818e5dd7070Spatrick         return;
819e5dd7070Spatrick   }
820e5dd7070Spatrick   if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
821e5dd7070Spatrick     return;
822e5dd7070Spatrick 
823e5dd7070Spatrick   CGM.ClearUnusedCoverageMapping(D);
824a9ac8606Spatrick   if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
825a9ac8606Spatrick     return;
826*12c85518Srobert   if (Fn->hasFnAttribute(llvm::Attribute::SkipProfile))
827*12c85518Srobert     return;
828a9ac8606Spatrick 
829e5dd7070Spatrick   setFuncName(Fn);
830e5dd7070Spatrick 
831e5dd7070Spatrick   mapRegionCounters(D);
832e5dd7070Spatrick   if (CGM.getCodeGenOpts().CoverageMapping)
833e5dd7070Spatrick     emitCounterRegionMapping(D);
834e5dd7070Spatrick   if (PGOReader) {
835e5dd7070Spatrick     SourceManager &SM = CGM.getContext().getSourceManager();
836e5dd7070Spatrick     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
837e5dd7070Spatrick     computeRegionCounts(D);
838e5dd7070Spatrick     applyFunctionAttributes(PGOReader, Fn);
839e5dd7070Spatrick   }
840e5dd7070Spatrick }
841e5dd7070Spatrick 
mapRegionCounters(const Decl * D)842e5dd7070Spatrick void CodeGenPGO::mapRegionCounters(const Decl *D) {
843e5dd7070Spatrick   // Use the latest hash version when inserting instrumentation, but use the
844e5dd7070Spatrick   // version in the indexed profile if we're reading PGO data.
845e5dd7070Spatrick   PGOHashVersion HashVersion = PGO_HASH_LATEST;
846a9ac8606Spatrick   uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
847a9ac8606Spatrick   if (auto *PGOReader = CGM.getPGOReader()) {
848e5dd7070Spatrick     HashVersion = getPGOHashVersion(PGOReader, CGM);
849a9ac8606Spatrick     ProfileVersion = PGOReader->getVersion();
850a9ac8606Spatrick   }
851e5dd7070Spatrick 
852e5dd7070Spatrick   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
853a9ac8606Spatrick   MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap);
854e5dd7070Spatrick   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
855e5dd7070Spatrick     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
856e5dd7070Spatrick   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
857e5dd7070Spatrick     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
858e5dd7070Spatrick   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
859e5dd7070Spatrick     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
860e5dd7070Spatrick   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
861e5dd7070Spatrick     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
862e5dd7070Spatrick   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
863e5dd7070Spatrick   NumRegionCounters = Walker.NextCounter;
864e5dd7070Spatrick   FunctionHash = Walker.Hash.finalize();
865e5dd7070Spatrick }
866e5dd7070Spatrick 
skipRegionMappingForDecl(const Decl * D)867e5dd7070Spatrick bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
868e5dd7070Spatrick   if (!D->getBody())
869e5dd7070Spatrick     return true;
870e5dd7070Spatrick 
871a9ac8606Spatrick   // Skip host-only functions in the CUDA device compilation and device-only
872a9ac8606Spatrick   // functions in the host compilation. Just roughly filter them out based on
873a9ac8606Spatrick   // the function attributes. If there are effectively host-only or device-only
874a9ac8606Spatrick   // ones, their coverage mapping may still be generated.
875a9ac8606Spatrick   if (CGM.getLangOpts().CUDA &&
876a9ac8606Spatrick       ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
877a9ac8606Spatrick         !D->hasAttr<CUDAGlobalAttr>()) ||
878a9ac8606Spatrick        (!CGM.getLangOpts().CUDAIsDevice &&
879a9ac8606Spatrick         (D->hasAttr<CUDAGlobalAttr>() ||
880a9ac8606Spatrick          (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
881a9ac8606Spatrick     return true;
882a9ac8606Spatrick 
883e5dd7070Spatrick   // Don't map the functions in system headers.
884e5dd7070Spatrick   const auto &SM = CGM.getContext().getSourceManager();
885e5dd7070Spatrick   auto Loc = D->getBody()->getBeginLoc();
886e5dd7070Spatrick   return SM.isInSystemHeader(Loc);
887e5dd7070Spatrick }
888e5dd7070Spatrick 
emitCounterRegionMapping(const Decl * D)889e5dd7070Spatrick void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
890e5dd7070Spatrick   if (skipRegionMappingForDecl(D))
891e5dd7070Spatrick     return;
892e5dd7070Spatrick 
893e5dd7070Spatrick   std::string CoverageMapping;
894e5dd7070Spatrick   llvm::raw_string_ostream OS(CoverageMapping);
895e5dd7070Spatrick   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
896e5dd7070Spatrick                                 CGM.getContext().getSourceManager(),
897e5dd7070Spatrick                                 CGM.getLangOpts(), RegionCounterMap.get());
898e5dd7070Spatrick   MappingGen.emitCounterMapping(D, OS);
899e5dd7070Spatrick   OS.flush();
900e5dd7070Spatrick 
901e5dd7070Spatrick   if (CoverageMapping.empty())
902e5dd7070Spatrick     return;
903e5dd7070Spatrick 
904e5dd7070Spatrick   CGM.getCoverageMapping()->addFunctionMappingRecord(
905e5dd7070Spatrick       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
906e5dd7070Spatrick }
907e5dd7070Spatrick 
908e5dd7070Spatrick void
emitEmptyCounterMapping(const Decl * D,StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)909e5dd7070Spatrick CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
910e5dd7070Spatrick                                     llvm::GlobalValue::LinkageTypes Linkage) {
911e5dd7070Spatrick   if (skipRegionMappingForDecl(D))
912e5dd7070Spatrick     return;
913e5dd7070Spatrick 
914e5dd7070Spatrick   std::string CoverageMapping;
915e5dd7070Spatrick   llvm::raw_string_ostream OS(CoverageMapping);
916e5dd7070Spatrick   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
917e5dd7070Spatrick                                 CGM.getContext().getSourceManager(),
918e5dd7070Spatrick                                 CGM.getLangOpts());
919e5dd7070Spatrick   MappingGen.emitEmptyMapping(D, OS);
920e5dd7070Spatrick   OS.flush();
921e5dd7070Spatrick 
922e5dd7070Spatrick   if (CoverageMapping.empty())
923e5dd7070Spatrick     return;
924e5dd7070Spatrick 
925e5dd7070Spatrick   setFuncName(Name, Linkage);
926e5dd7070Spatrick   CGM.getCoverageMapping()->addFunctionMappingRecord(
927e5dd7070Spatrick       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
928e5dd7070Spatrick }
929e5dd7070Spatrick 
computeRegionCounts(const Decl * D)930e5dd7070Spatrick void CodeGenPGO::computeRegionCounts(const Decl *D) {
931e5dd7070Spatrick   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
932e5dd7070Spatrick   ComputeRegionCounts Walker(*StmtCountMap, *this);
933e5dd7070Spatrick   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
934e5dd7070Spatrick     Walker.VisitFunctionDecl(FD);
935e5dd7070Spatrick   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
936e5dd7070Spatrick     Walker.VisitObjCMethodDecl(MD);
937e5dd7070Spatrick   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
938e5dd7070Spatrick     Walker.VisitBlockDecl(BD);
939e5dd7070Spatrick   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
940e5dd7070Spatrick     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
941e5dd7070Spatrick }
942e5dd7070Spatrick 
943e5dd7070Spatrick void
applyFunctionAttributes(llvm::IndexedInstrProfReader * PGOReader,llvm::Function * Fn)944e5dd7070Spatrick CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
945e5dd7070Spatrick                                     llvm::Function *Fn) {
946e5dd7070Spatrick   if (!haveRegionCounts())
947e5dd7070Spatrick     return;
948e5dd7070Spatrick 
949e5dd7070Spatrick   uint64_t FunctionCount = getRegionCount(nullptr);
950e5dd7070Spatrick   Fn->setEntryCount(FunctionCount);
951e5dd7070Spatrick }
952e5dd7070Spatrick 
emitCounterIncrement(CGBuilderTy & Builder,const Stmt * S,llvm::Value * StepV)953e5dd7070Spatrick void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
954e5dd7070Spatrick                                       llvm::Value *StepV) {
955e5dd7070Spatrick   if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
956e5dd7070Spatrick     return;
957e5dd7070Spatrick   if (!Builder.GetInsertBlock())
958e5dd7070Spatrick     return;
959e5dd7070Spatrick 
960e5dd7070Spatrick   unsigned Counter = (*RegionCounterMap)[S];
961e5dd7070Spatrick   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
962e5dd7070Spatrick 
963e5dd7070Spatrick   llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
964e5dd7070Spatrick                          Builder.getInt64(FunctionHash),
965e5dd7070Spatrick                          Builder.getInt32(NumRegionCounters),
966e5dd7070Spatrick                          Builder.getInt32(Counter), StepV};
967e5dd7070Spatrick   if (!StepV)
968e5dd7070Spatrick     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
969*12c85518Srobert                        ArrayRef(Args, 4));
970e5dd7070Spatrick   else
971e5dd7070Spatrick     Builder.CreateCall(
972e5dd7070Spatrick         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
973*12c85518Srobert         ArrayRef(Args));
974e5dd7070Spatrick }
975e5dd7070Spatrick 
setValueProfilingFlag(llvm::Module & M)976a9ac8606Spatrick void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
977a9ac8606Spatrick   if (CGM.getCodeGenOpts().hasProfileClangInstr())
978a9ac8606Spatrick     M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
979a9ac8606Spatrick                     uint32_t(EnableValueProfiling));
980a9ac8606Spatrick }
981a9ac8606Spatrick 
982e5dd7070Spatrick // This method either inserts a call to the profile run-time during
983e5dd7070Spatrick // instrumentation or puts profile data into metadata for PGO use.
valueProfile(CGBuilderTy & Builder,uint32_t ValueKind,llvm::Instruction * ValueSite,llvm::Value * ValuePtr)984e5dd7070Spatrick void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
985e5dd7070Spatrick     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
986e5dd7070Spatrick 
987e5dd7070Spatrick   if (!EnableValueProfiling)
988e5dd7070Spatrick     return;
989e5dd7070Spatrick 
990e5dd7070Spatrick   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
991e5dd7070Spatrick     return;
992e5dd7070Spatrick 
993e5dd7070Spatrick   if (isa<llvm::Constant>(ValuePtr))
994e5dd7070Spatrick     return;
995e5dd7070Spatrick 
996e5dd7070Spatrick   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
997e5dd7070Spatrick   if (InstrumentValueSites && RegionCounterMap) {
998e5dd7070Spatrick     auto BuilderInsertPoint = Builder.saveIP();
999e5dd7070Spatrick     Builder.SetInsertPoint(ValueSite);
1000e5dd7070Spatrick     llvm::Value *Args[5] = {
1001e5dd7070Spatrick         llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
1002e5dd7070Spatrick         Builder.getInt64(FunctionHash),
1003e5dd7070Spatrick         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
1004e5dd7070Spatrick         Builder.getInt32(ValueKind),
1005e5dd7070Spatrick         Builder.getInt32(NumValueSites[ValueKind]++)
1006e5dd7070Spatrick     };
1007e5dd7070Spatrick     Builder.CreateCall(
1008e5dd7070Spatrick         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
1009e5dd7070Spatrick     Builder.restoreIP(BuilderInsertPoint);
1010e5dd7070Spatrick     return;
1011e5dd7070Spatrick   }
1012e5dd7070Spatrick 
1013e5dd7070Spatrick   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1014e5dd7070Spatrick   if (PGOReader && haveRegionCounts()) {
1015e5dd7070Spatrick     // We record the top most called three functions at each call site.
1016e5dd7070Spatrick     // Profile metadata contains "VP" string identifying this metadata
1017e5dd7070Spatrick     // as value profiling data, then a uint32_t value for the value profiling
1018e5dd7070Spatrick     // kind, a uint64_t value for the total number of times the call is
1019e5dd7070Spatrick     // executed, followed by the function hash and execution count (uint64_t)
1020e5dd7070Spatrick     // pairs for each function.
1021e5dd7070Spatrick     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
1022e5dd7070Spatrick       return;
1023e5dd7070Spatrick 
1024e5dd7070Spatrick     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
1025e5dd7070Spatrick                             (llvm::InstrProfValueKind)ValueKind,
1026e5dd7070Spatrick                             NumValueSites[ValueKind]);
1027e5dd7070Spatrick 
1028e5dd7070Spatrick     NumValueSites[ValueKind]++;
1029e5dd7070Spatrick   }
1030e5dd7070Spatrick }
1031e5dd7070Spatrick 
loadRegionCounts(llvm::IndexedInstrProfReader * PGOReader,bool IsInMainFile)1032e5dd7070Spatrick void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1033e5dd7070Spatrick                                   bool IsInMainFile) {
1034e5dd7070Spatrick   CGM.getPGOStats().addVisited(IsInMainFile);
1035e5dd7070Spatrick   RegionCounts.clear();
1036e5dd7070Spatrick   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1037e5dd7070Spatrick       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1038e5dd7070Spatrick   if (auto E = RecordExpected.takeError()) {
1039e5dd7070Spatrick     auto IPE = llvm::InstrProfError::take(std::move(E));
1040e5dd7070Spatrick     if (IPE == llvm::instrprof_error::unknown_function)
1041e5dd7070Spatrick       CGM.getPGOStats().addMissing(IsInMainFile);
1042e5dd7070Spatrick     else if (IPE == llvm::instrprof_error::hash_mismatch)
1043e5dd7070Spatrick       CGM.getPGOStats().addMismatched(IsInMainFile);
1044e5dd7070Spatrick     else if (IPE == llvm::instrprof_error::malformed)
1045e5dd7070Spatrick       // TODO: Consider a more specific warning for this case.
1046e5dd7070Spatrick       CGM.getPGOStats().addMismatched(IsInMainFile);
1047e5dd7070Spatrick     return;
1048e5dd7070Spatrick   }
1049e5dd7070Spatrick   ProfRecord =
1050e5dd7070Spatrick       std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1051e5dd7070Spatrick   RegionCounts = ProfRecord->Counts;
1052e5dd7070Spatrick }
1053e5dd7070Spatrick 
/// Compute the divisor that brings branch weights into 32-bit range.
///
/// Given the largest weight \p MaxWeight, returns a divisor such that any
/// weight no larger than \p MaxWeight, after division, is strictly below
/// UINT32_MAX. This leaves room for the +1 applied by scaleBranchWeight().
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
1061e5dd7070Spatrick 
/// Scale an individual branch weight down to 32 bits, adding 1.
///
/// Divides \p Weight by \p Scale and adds 1. Following Laplace's Rule of
/// Succession, estimating from count + 1 is better than from the raw count.
/// It also guarantees that every resulting weight is at least 1.
///
/// \pre \p Scale is non-zero and was computed by calculateWeightScale() from
/// a maximum weight no smaller than \p Weight. This is what guarantees that
/// the result fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Scaled);
}
1077e5dd7070Spatrick 
createProfileWeights(uint64_t TrueCount,uint64_t FalseCount) const1078e5dd7070Spatrick llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1079a9ac8606Spatrick                                                     uint64_t FalseCount) const {
1080e5dd7070Spatrick   // Check for empty weights.
1081e5dd7070Spatrick   if (!TrueCount && !FalseCount)
1082e5dd7070Spatrick     return nullptr;
1083e5dd7070Spatrick 
1084e5dd7070Spatrick   // Calculate how to scale down to 32-bits.
1085e5dd7070Spatrick   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1086e5dd7070Spatrick 
1087e5dd7070Spatrick   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1088e5dd7070Spatrick   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1089e5dd7070Spatrick                                       scaleBranchWeight(FalseCount, Scale));
1090e5dd7070Spatrick }
1091e5dd7070Spatrick 
1092e5dd7070Spatrick llvm::MDNode *
createProfileWeights(ArrayRef<uint64_t> Weights) const1093a9ac8606Spatrick CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1094e5dd7070Spatrick   // We need at least two elements to create meaningful weights.
1095e5dd7070Spatrick   if (Weights.size() < 2)
1096e5dd7070Spatrick     return nullptr;
1097e5dd7070Spatrick 
1098e5dd7070Spatrick   // Check for empty weights.
1099e5dd7070Spatrick   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1100e5dd7070Spatrick   if (MaxWeight == 0)
1101e5dd7070Spatrick     return nullptr;
1102e5dd7070Spatrick 
1103e5dd7070Spatrick   // Calculate how to scale down to 32-bits.
1104e5dd7070Spatrick   uint64_t Scale = calculateWeightScale(MaxWeight);
1105e5dd7070Spatrick 
1106e5dd7070Spatrick   SmallVector<uint32_t, 16> ScaledWeights;
1107e5dd7070Spatrick   ScaledWeights.reserve(Weights.size());
1108e5dd7070Spatrick   for (uint64_t W : Weights)
1109e5dd7070Spatrick     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1110e5dd7070Spatrick 
1111e5dd7070Spatrick   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1112e5dd7070Spatrick   return MDHelper.createBranchWeights(ScaledWeights);
1113e5dd7070Spatrick }
1114e5dd7070Spatrick 
1115a9ac8606Spatrick llvm::MDNode *
createProfileWeightsForLoop(const Stmt * Cond,uint64_t LoopCount) const1116a9ac8606Spatrick CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1117a9ac8606Spatrick                                              uint64_t LoopCount) const {
1118e5dd7070Spatrick   if (!PGO.haveRegionCounts())
1119e5dd7070Spatrick     return nullptr;
1120*12c85518Srobert   std::optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1121ec727ea7Spatrick   if (!CondCount || *CondCount == 0)
1122e5dd7070Spatrick     return nullptr;
1123e5dd7070Spatrick   return createProfileWeights(LoopCount,
1124e5dd7070Spatrick                               std::max(*CondCount, LoopCount) - LoopCount);
1125e5dd7070Spatrick }
1126