xref: /netbsd-src/external/apache2/llvm/dist/clang/lib/CodeGen/CodeGenPGO.cpp (revision 7330f729ccf0bd976a06f95fad452fe774fc7fd1)
1*7330f729Sjoerg //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2*7330f729Sjoerg //
3*7330f729Sjoerg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*7330f729Sjoerg // See https://llvm.org/LICENSE.txt for license information.
5*7330f729Sjoerg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*7330f729Sjoerg //
7*7330f729Sjoerg //===----------------------------------------------------------------------===//
8*7330f729Sjoerg //
9*7330f729Sjoerg // Instrumentation-based profile-guided optimization
10*7330f729Sjoerg //
11*7330f729Sjoerg //===----------------------------------------------------------------------===//
12*7330f729Sjoerg 
13*7330f729Sjoerg #include "CodeGenPGO.h"
14*7330f729Sjoerg #include "CodeGenFunction.h"
15*7330f729Sjoerg #include "CoverageMappingGen.h"
16*7330f729Sjoerg #include "clang/AST/RecursiveASTVisitor.h"
17*7330f729Sjoerg #include "clang/AST/StmtVisitor.h"
18*7330f729Sjoerg #include "llvm/IR/Intrinsics.h"
19*7330f729Sjoerg #include "llvm/IR/MDBuilder.h"
20*7330f729Sjoerg #include "llvm/Support/Endian.h"
21*7330f729Sjoerg #include "llvm/Support/FileSystem.h"
22*7330f729Sjoerg #include "llvm/Support/MD5.h"
23*7330f729Sjoerg 
24*7330f729Sjoerg static llvm::cl::opt<bool>
25*7330f729Sjoerg     EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
26*7330f729Sjoerg                          llvm::cl::desc("Enable value profiling"),
27*7330f729Sjoerg                          llvm::cl::Hidden, llvm::cl::init(false));
28*7330f729Sjoerg 
29*7330f729Sjoerg using namespace clang;
30*7330f729Sjoerg using namespace CodeGen;
31*7330f729Sjoerg 
32*7330f729Sjoerg void CodeGenPGO::setFuncName(StringRef Name,
33*7330f729Sjoerg                              llvm::GlobalValue::LinkageTypes Linkage) {
34*7330f729Sjoerg   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
35*7330f729Sjoerg   FuncName = llvm::getPGOFuncName(
36*7330f729Sjoerg       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
37*7330f729Sjoerg       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
38*7330f729Sjoerg 
39*7330f729Sjoerg   // If we're generating a profile, create a variable for the name.
40*7330f729Sjoerg   if (CGM.getCodeGenOpts().hasProfileClangInstr())
41*7330f729Sjoerg     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
42*7330f729Sjoerg }
43*7330f729Sjoerg 
44*7330f729Sjoerg void CodeGenPGO::setFuncName(llvm::Function *Fn) {
45*7330f729Sjoerg   setFuncName(Fn->getName(), Fn->getLinkage());
46*7330f729Sjoerg   // Create PGOFuncName meta data.
47*7330f729Sjoerg   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
48*7330f729Sjoerg }
49*7330f729Sjoerg 
/// Identifies which revision of the PGO hash algorithm is in use.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,

  // Alias for the newest version; update it whenever a version is added.
  PGO_HASH_LATEST = PGO_HASH_V2
};
58*7330f729Sjoerg 
59*7330f729Sjoerg namespace {
60*7330f729Sjoerg /// Stable hasher for PGO region counters.
61*7330f729Sjoerg ///
62*7330f729Sjoerg /// PGOHash produces a stable hash of a given function's control flow.
63*7330f729Sjoerg ///
64*7330f729Sjoerg /// Changing the output of this hash will invalidate all previously generated
65*7330f729Sjoerg /// profiles -- i.e., don't do it.
66*7330f729Sjoerg ///
67*7330f729Sjoerg /// \note  When this hash does eventually change (years?), we still need to
68*7330f729Sjoerg /// support old hashes.  We'll need to pull in the version number from the
69*7330f729Sjoerg /// profile data format and use the matching hash function.
70*7330f729Sjoerg class PGOHash {
71*7330f729Sjoerg   uint64_t Working;
72*7330f729Sjoerg   unsigned Count;
73*7330f729Sjoerg   PGOHashVersion HashVersion;
74*7330f729Sjoerg   llvm::MD5 MD5;
75*7330f729Sjoerg 
76*7330f729Sjoerg   static const int NumBitsPerType = 6;
77*7330f729Sjoerg   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
78*7330f729Sjoerg   static const unsigned TooBig = 1u << NumBitsPerType;
79*7330f729Sjoerg 
80*7330f729Sjoerg public:
81*7330f729Sjoerg   /// Hash values for AST nodes.
82*7330f729Sjoerg   ///
83*7330f729Sjoerg   /// Distinct values for AST nodes that have region counters attached.
84*7330f729Sjoerg   ///
85*7330f729Sjoerg   /// These values must be stable.  All new members must be added at the end,
86*7330f729Sjoerg   /// and no members should be removed.  Changing the enumeration value for an
87*7330f729Sjoerg   /// AST node will affect the hash of every function that contains that node.
88*7330f729Sjoerg   enum HashType : unsigned char {
89*7330f729Sjoerg     None = 0,
90*7330f729Sjoerg     LabelStmt = 1,
91*7330f729Sjoerg     WhileStmt,
92*7330f729Sjoerg     DoStmt,
93*7330f729Sjoerg     ForStmt,
94*7330f729Sjoerg     CXXForRangeStmt,
95*7330f729Sjoerg     ObjCForCollectionStmt,
96*7330f729Sjoerg     SwitchStmt,
97*7330f729Sjoerg     CaseStmt,
98*7330f729Sjoerg     DefaultStmt,
99*7330f729Sjoerg     IfStmt,
100*7330f729Sjoerg     CXXTryStmt,
101*7330f729Sjoerg     CXXCatchStmt,
102*7330f729Sjoerg     ConditionalOperator,
103*7330f729Sjoerg     BinaryOperatorLAnd,
104*7330f729Sjoerg     BinaryOperatorLOr,
105*7330f729Sjoerg     BinaryConditionalOperator,
106*7330f729Sjoerg     // The preceding values are available with PGO_HASH_V1.
107*7330f729Sjoerg 
108*7330f729Sjoerg     EndOfScope,
109*7330f729Sjoerg     IfThenBranch,
110*7330f729Sjoerg     IfElseBranch,
111*7330f729Sjoerg     GotoStmt,
112*7330f729Sjoerg     IndirectGotoStmt,
113*7330f729Sjoerg     BreakStmt,
114*7330f729Sjoerg     ContinueStmt,
115*7330f729Sjoerg     ReturnStmt,
116*7330f729Sjoerg     ThrowExpr,
117*7330f729Sjoerg     UnaryOperatorLNot,
118*7330f729Sjoerg     BinaryOperatorLT,
119*7330f729Sjoerg     BinaryOperatorGT,
120*7330f729Sjoerg     BinaryOperatorLE,
121*7330f729Sjoerg     BinaryOperatorGE,
122*7330f729Sjoerg     BinaryOperatorEQ,
123*7330f729Sjoerg     BinaryOperatorNE,
124*7330f729Sjoerg     // The preceding values are available with PGO_HASH_V2.
125*7330f729Sjoerg 
126*7330f729Sjoerg     // Keep this last.  It's for the static assert that follows.
127*7330f729Sjoerg     LastHashType
128*7330f729Sjoerg   };
129*7330f729Sjoerg   static_assert(LastHashType <= TooBig, "Too many types in HashType");
130*7330f729Sjoerg 
131*7330f729Sjoerg   PGOHash(PGOHashVersion HashVersion)
132*7330f729Sjoerg       : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
133*7330f729Sjoerg   void combine(HashType Type);
134*7330f729Sjoerg   uint64_t finalize();
135*7330f729Sjoerg   PGOHashVersion getHashVersion() const { return HashVersion; }
136*7330f729Sjoerg };
137*7330f729Sjoerg const int PGOHash::NumBitsPerType;
138*7330f729Sjoerg const unsigned PGOHash::NumTypesPerWord;
139*7330f729Sjoerg const unsigned PGOHash::TooBig;
140*7330f729Sjoerg 
141*7330f729Sjoerg /// Get the PGO hash version used in the given indexed profile.
142*7330f729Sjoerg static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
143*7330f729Sjoerg                                         CodeGenModule &CGM) {
144*7330f729Sjoerg   if (PGOReader->getVersion() <= 4)
145*7330f729Sjoerg     return PGO_HASH_V1;
146*7330f729Sjoerg   return PGO_HASH_V2;
147*7330f729Sjoerg }
148*7330f729Sjoerg 
149*7330f729Sjoerg /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
150*7330f729Sjoerg struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
151*7330f729Sjoerg   using Base = RecursiveASTVisitor<MapRegionCounters>;
152*7330f729Sjoerg 
153*7330f729Sjoerg   /// The next counter value to assign.
154*7330f729Sjoerg   unsigned NextCounter;
155*7330f729Sjoerg   /// The function hash.
156*7330f729Sjoerg   PGOHash Hash;
157*7330f729Sjoerg   /// The map of statements to counters.
158*7330f729Sjoerg   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
159*7330f729Sjoerg 
160*7330f729Sjoerg   MapRegionCounters(PGOHashVersion HashVersion,
161*7330f729Sjoerg                     llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
162*7330f729Sjoerg       : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}
163*7330f729Sjoerg 
164*7330f729Sjoerg   // Blocks and lambdas are handled as separate functions, so we need not
165*7330f729Sjoerg   // traverse them in the parent context.
166*7330f729Sjoerg   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
167*7330f729Sjoerg   bool TraverseLambdaExpr(LambdaExpr *LE) {
168*7330f729Sjoerg     // Traverse the captures, but not the body.
169*7330f729Sjoerg     for (const auto &C : zip(LE->captures(), LE->capture_inits()))
170*7330f729Sjoerg       TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
171*7330f729Sjoerg     return true;
172*7330f729Sjoerg   }
173*7330f729Sjoerg   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
174*7330f729Sjoerg 
175*7330f729Sjoerg   bool VisitDecl(const Decl *D) {
176*7330f729Sjoerg     switch (D->getKind()) {
177*7330f729Sjoerg     default:
178*7330f729Sjoerg       break;
179*7330f729Sjoerg     case Decl::Function:
180*7330f729Sjoerg     case Decl::CXXMethod:
181*7330f729Sjoerg     case Decl::CXXConstructor:
182*7330f729Sjoerg     case Decl::CXXDestructor:
183*7330f729Sjoerg     case Decl::CXXConversion:
184*7330f729Sjoerg     case Decl::ObjCMethod:
185*7330f729Sjoerg     case Decl::Block:
186*7330f729Sjoerg     case Decl::Captured:
187*7330f729Sjoerg       CounterMap[D->getBody()] = NextCounter++;
188*7330f729Sjoerg       break;
189*7330f729Sjoerg     }
190*7330f729Sjoerg     return true;
191*7330f729Sjoerg   }
192*7330f729Sjoerg 
193*7330f729Sjoerg   /// If \p S gets a fresh counter, update the counter mappings. Return the
194*7330f729Sjoerg   /// V1 hash of \p S.
195*7330f729Sjoerg   PGOHash::HashType updateCounterMappings(Stmt *S) {
196*7330f729Sjoerg     auto Type = getHashType(PGO_HASH_V1, S);
197*7330f729Sjoerg     if (Type != PGOHash::None)
198*7330f729Sjoerg       CounterMap[S] = NextCounter++;
199*7330f729Sjoerg     return Type;
200*7330f729Sjoerg   }
201*7330f729Sjoerg 
202*7330f729Sjoerg   /// Include \p S in the function hash.
203*7330f729Sjoerg   bool VisitStmt(Stmt *S) {
204*7330f729Sjoerg     auto Type = updateCounterMappings(S);
205*7330f729Sjoerg     if (Hash.getHashVersion() != PGO_HASH_V1)
206*7330f729Sjoerg       Type = getHashType(Hash.getHashVersion(), S);
207*7330f729Sjoerg     if (Type != PGOHash::None)
208*7330f729Sjoerg       Hash.combine(Type);
209*7330f729Sjoerg     return true;
210*7330f729Sjoerg   }
211*7330f729Sjoerg 
212*7330f729Sjoerg   bool TraverseIfStmt(IfStmt *If) {
213*7330f729Sjoerg     // If we used the V1 hash, use the default traversal.
214*7330f729Sjoerg     if (Hash.getHashVersion() == PGO_HASH_V1)
215*7330f729Sjoerg       return Base::TraverseIfStmt(If);
216*7330f729Sjoerg 
217*7330f729Sjoerg     // Otherwise, keep track of which branch we're in while traversing.
218*7330f729Sjoerg     VisitStmt(If);
219*7330f729Sjoerg     for (Stmt *CS : If->children()) {
220*7330f729Sjoerg       if (!CS)
221*7330f729Sjoerg         continue;
222*7330f729Sjoerg       if (CS == If->getThen())
223*7330f729Sjoerg         Hash.combine(PGOHash::IfThenBranch);
224*7330f729Sjoerg       else if (CS == If->getElse())
225*7330f729Sjoerg         Hash.combine(PGOHash::IfElseBranch);
226*7330f729Sjoerg       TraverseStmt(CS);
227*7330f729Sjoerg     }
228*7330f729Sjoerg     Hash.combine(PGOHash::EndOfScope);
229*7330f729Sjoerg     return true;
230*7330f729Sjoerg   }
231*7330f729Sjoerg 
232*7330f729Sjoerg // If the statement type \p N is nestable, and its nesting impacts profile
233*7330f729Sjoerg // stability, define a custom traversal which tracks the end of the statement
234*7330f729Sjoerg // in the hash (provided we're not using the V1 hash).
235*7330f729Sjoerg #define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
236*7330f729Sjoerg   bool Traverse##N(N *S) {                                                     \
237*7330f729Sjoerg     Base::Traverse##N(S);                                                      \
238*7330f729Sjoerg     if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
239*7330f729Sjoerg       Hash.combine(PGOHash::EndOfScope);                                       \
240*7330f729Sjoerg     return true;                                                               \
241*7330f729Sjoerg   }
242*7330f729Sjoerg 
243*7330f729Sjoerg   DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
244*7330f729Sjoerg   DEFINE_NESTABLE_TRAVERSAL(DoStmt)
245*7330f729Sjoerg   DEFINE_NESTABLE_TRAVERSAL(ForStmt)
246*7330f729Sjoerg   DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
247*7330f729Sjoerg   DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
248*7330f729Sjoerg   DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
249*7330f729Sjoerg   DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
250*7330f729Sjoerg 
251*7330f729Sjoerg   /// Get version \p HashVersion of the PGO hash for \p S.
252*7330f729Sjoerg   PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
253*7330f729Sjoerg     switch (S->getStmtClass()) {
254*7330f729Sjoerg     default:
255*7330f729Sjoerg       break;
256*7330f729Sjoerg     case Stmt::LabelStmtClass:
257*7330f729Sjoerg       return PGOHash::LabelStmt;
258*7330f729Sjoerg     case Stmt::WhileStmtClass:
259*7330f729Sjoerg       return PGOHash::WhileStmt;
260*7330f729Sjoerg     case Stmt::DoStmtClass:
261*7330f729Sjoerg       return PGOHash::DoStmt;
262*7330f729Sjoerg     case Stmt::ForStmtClass:
263*7330f729Sjoerg       return PGOHash::ForStmt;
264*7330f729Sjoerg     case Stmt::CXXForRangeStmtClass:
265*7330f729Sjoerg       return PGOHash::CXXForRangeStmt;
266*7330f729Sjoerg     case Stmt::ObjCForCollectionStmtClass:
267*7330f729Sjoerg       return PGOHash::ObjCForCollectionStmt;
268*7330f729Sjoerg     case Stmt::SwitchStmtClass:
269*7330f729Sjoerg       return PGOHash::SwitchStmt;
270*7330f729Sjoerg     case Stmt::CaseStmtClass:
271*7330f729Sjoerg       return PGOHash::CaseStmt;
272*7330f729Sjoerg     case Stmt::DefaultStmtClass:
273*7330f729Sjoerg       return PGOHash::DefaultStmt;
274*7330f729Sjoerg     case Stmt::IfStmtClass:
275*7330f729Sjoerg       return PGOHash::IfStmt;
276*7330f729Sjoerg     case Stmt::CXXTryStmtClass:
277*7330f729Sjoerg       return PGOHash::CXXTryStmt;
278*7330f729Sjoerg     case Stmt::CXXCatchStmtClass:
279*7330f729Sjoerg       return PGOHash::CXXCatchStmt;
280*7330f729Sjoerg     case Stmt::ConditionalOperatorClass:
281*7330f729Sjoerg       return PGOHash::ConditionalOperator;
282*7330f729Sjoerg     case Stmt::BinaryConditionalOperatorClass:
283*7330f729Sjoerg       return PGOHash::BinaryConditionalOperator;
284*7330f729Sjoerg     case Stmt::BinaryOperatorClass: {
285*7330f729Sjoerg       const BinaryOperator *BO = cast<BinaryOperator>(S);
286*7330f729Sjoerg       if (BO->getOpcode() == BO_LAnd)
287*7330f729Sjoerg         return PGOHash::BinaryOperatorLAnd;
288*7330f729Sjoerg       if (BO->getOpcode() == BO_LOr)
289*7330f729Sjoerg         return PGOHash::BinaryOperatorLOr;
290*7330f729Sjoerg       if (HashVersion == PGO_HASH_V2) {
291*7330f729Sjoerg         switch (BO->getOpcode()) {
292*7330f729Sjoerg         default:
293*7330f729Sjoerg           break;
294*7330f729Sjoerg         case BO_LT:
295*7330f729Sjoerg           return PGOHash::BinaryOperatorLT;
296*7330f729Sjoerg         case BO_GT:
297*7330f729Sjoerg           return PGOHash::BinaryOperatorGT;
298*7330f729Sjoerg         case BO_LE:
299*7330f729Sjoerg           return PGOHash::BinaryOperatorLE;
300*7330f729Sjoerg         case BO_GE:
301*7330f729Sjoerg           return PGOHash::BinaryOperatorGE;
302*7330f729Sjoerg         case BO_EQ:
303*7330f729Sjoerg           return PGOHash::BinaryOperatorEQ;
304*7330f729Sjoerg         case BO_NE:
305*7330f729Sjoerg           return PGOHash::BinaryOperatorNE;
306*7330f729Sjoerg         }
307*7330f729Sjoerg       }
308*7330f729Sjoerg       break;
309*7330f729Sjoerg     }
310*7330f729Sjoerg     }
311*7330f729Sjoerg 
312*7330f729Sjoerg     if (HashVersion == PGO_HASH_V2) {
313*7330f729Sjoerg       switch (S->getStmtClass()) {
314*7330f729Sjoerg       default:
315*7330f729Sjoerg         break;
316*7330f729Sjoerg       case Stmt::GotoStmtClass:
317*7330f729Sjoerg         return PGOHash::GotoStmt;
318*7330f729Sjoerg       case Stmt::IndirectGotoStmtClass:
319*7330f729Sjoerg         return PGOHash::IndirectGotoStmt;
320*7330f729Sjoerg       case Stmt::BreakStmtClass:
321*7330f729Sjoerg         return PGOHash::BreakStmt;
322*7330f729Sjoerg       case Stmt::ContinueStmtClass:
323*7330f729Sjoerg         return PGOHash::ContinueStmt;
324*7330f729Sjoerg       case Stmt::ReturnStmtClass:
325*7330f729Sjoerg         return PGOHash::ReturnStmt;
326*7330f729Sjoerg       case Stmt::CXXThrowExprClass:
327*7330f729Sjoerg         return PGOHash::ThrowExpr;
328*7330f729Sjoerg       case Stmt::UnaryOperatorClass: {
329*7330f729Sjoerg         const UnaryOperator *UO = cast<UnaryOperator>(S);
330*7330f729Sjoerg         if (UO->getOpcode() == UO_LNot)
331*7330f729Sjoerg           return PGOHash::UnaryOperatorLNot;
332*7330f729Sjoerg         break;
333*7330f729Sjoerg       }
334*7330f729Sjoerg       }
335*7330f729Sjoerg     }
336*7330f729Sjoerg 
337*7330f729Sjoerg     return PGOHash::None;
338*7330f729Sjoerg   }
339*7330f729Sjoerg };
340*7330f729Sjoerg 
341*7330f729Sjoerg /// A StmtVisitor that propagates the raw counts through the AST and
342*7330f729Sjoerg /// records the count at statements where the value may change.
343*7330f729Sjoerg struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
344*7330f729Sjoerg   /// PGO state.
345*7330f729Sjoerg   CodeGenPGO &PGO;
346*7330f729Sjoerg 
347*7330f729Sjoerg   /// A flag that is set when the current count should be recorded on the
348*7330f729Sjoerg   /// next statement, such as at the exit of a loop.
349*7330f729Sjoerg   bool RecordNextStmtCount;
350*7330f729Sjoerg 
351*7330f729Sjoerg   /// The count at the current location in the traversal.
352*7330f729Sjoerg   uint64_t CurrentCount;
353*7330f729Sjoerg 
354*7330f729Sjoerg   /// The map of statements to count values.
355*7330f729Sjoerg   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
356*7330f729Sjoerg 
357*7330f729Sjoerg   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
/// Running totals of the counts that left one enclosing construct through
/// break and through continue. Both totals start at zero.
struct BreakContinue {
  uint64_t BreakCount = 0;
  uint64_t ContinueCount = 0;
  BreakContinue() = default;
};
363*7330f729Sjoerg   SmallVector<BreakContinue, 8> BreakContinueStack;
364*7330f729Sjoerg 
365*7330f729Sjoerg   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
366*7330f729Sjoerg                       CodeGenPGO &PGO)
367*7330f729Sjoerg       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
368*7330f729Sjoerg 
369*7330f729Sjoerg   void RecordStmtCount(const Stmt *S) {
370*7330f729Sjoerg     if (RecordNextStmtCount) {
371*7330f729Sjoerg       CountMap[S] = CurrentCount;
372*7330f729Sjoerg       RecordNextStmtCount = false;
373*7330f729Sjoerg     }
374*7330f729Sjoerg   }
375*7330f729Sjoerg 
376*7330f729Sjoerg   /// Set and return the current count.
377*7330f729Sjoerg   uint64_t setCount(uint64_t Count) {
378*7330f729Sjoerg     CurrentCount = Count;
379*7330f729Sjoerg     return Count;
380*7330f729Sjoerg   }
381*7330f729Sjoerg 
382*7330f729Sjoerg   void VisitStmt(const Stmt *S) {
383*7330f729Sjoerg     RecordStmtCount(S);
384*7330f729Sjoerg     for (const Stmt *Child : S->children())
385*7330f729Sjoerg       if (Child)
386*7330f729Sjoerg         this->Visit(Child);
387*7330f729Sjoerg   }
388*7330f729Sjoerg 
389*7330f729Sjoerg   void VisitFunctionDecl(const FunctionDecl *D) {
390*7330f729Sjoerg     // Counter tracks entry to the function body.
391*7330f729Sjoerg     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
392*7330f729Sjoerg     CountMap[D->getBody()] = BodyCount;
393*7330f729Sjoerg     Visit(D->getBody());
394*7330f729Sjoerg   }
395*7330f729Sjoerg 
396*7330f729Sjoerg   // Skip lambda expressions. We visit these as FunctionDecls when we're
397*7330f729Sjoerg   // generating them and aren't interested in the body when generating a
398*7330f729Sjoerg   // parent context.
399*7330f729Sjoerg   void VisitLambdaExpr(const LambdaExpr *LE) {}
400*7330f729Sjoerg 
401*7330f729Sjoerg   void VisitCapturedDecl(const CapturedDecl *D) {
402*7330f729Sjoerg     // Counter tracks entry to the capture body.
403*7330f729Sjoerg     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
404*7330f729Sjoerg     CountMap[D->getBody()] = BodyCount;
405*7330f729Sjoerg     Visit(D->getBody());
406*7330f729Sjoerg   }
407*7330f729Sjoerg 
408*7330f729Sjoerg   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
409*7330f729Sjoerg     // Counter tracks entry to the method body.
410*7330f729Sjoerg     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
411*7330f729Sjoerg     CountMap[D->getBody()] = BodyCount;
412*7330f729Sjoerg     Visit(D->getBody());
413*7330f729Sjoerg   }
414*7330f729Sjoerg 
415*7330f729Sjoerg   void VisitBlockDecl(const BlockDecl *D) {
416*7330f729Sjoerg     // Counter tracks entry to the block body.
417*7330f729Sjoerg     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
418*7330f729Sjoerg     CountMap[D->getBody()] = BodyCount;
419*7330f729Sjoerg     Visit(D->getBody());
420*7330f729Sjoerg   }
421*7330f729Sjoerg 
422*7330f729Sjoerg   void VisitReturnStmt(const ReturnStmt *S) {
423*7330f729Sjoerg     RecordStmtCount(S);
424*7330f729Sjoerg     if (S->getRetValue())
425*7330f729Sjoerg       Visit(S->getRetValue());
426*7330f729Sjoerg     CurrentCount = 0;
427*7330f729Sjoerg     RecordNextStmtCount = true;
428*7330f729Sjoerg   }
429*7330f729Sjoerg 
430*7330f729Sjoerg   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
431*7330f729Sjoerg     RecordStmtCount(E);
432*7330f729Sjoerg     if (E->getSubExpr())
433*7330f729Sjoerg       Visit(E->getSubExpr());
434*7330f729Sjoerg     CurrentCount = 0;
435*7330f729Sjoerg     RecordNextStmtCount = true;
436*7330f729Sjoerg   }
437*7330f729Sjoerg 
438*7330f729Sjoerg   void VisitGotoStmt(const GotoStmt *S) {
439*7330f729Sjoerg     RecordStmtCount(S);
440*7330f729Sjoerg     CurrentCount = 0;
441*7330f729Sjoerg     RecordNextStmtCount = true;
442*7330f729Sjoerg   }
443*7330f729Sjoerg 
444*7330f729Sjoerg   void VisitLabelStmt(const LabelStmt *S) {
445*7330f729Sjoerg     RecordNextStmtCount = false;
446*7330f729Sjoerg     // Counter tracks the block following the label.
447*7330f729Sjoerg     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
448*7330f729Sjoerg     CountMap[S] = BlockCount;
449*7330f729Sjoerg     Visit(S->getSubStmt());
450*7330f729Sjoerg   }
451*7330f729Sjoerg 
452*7330f729Sjoerg   void VisitBreakStmt(const BreakStmt *S) {
453*7330f729Sjoerg     RecordStmtCount(S);
454*7330f729Sjoerg     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
455*7330f729Sjoerg     BreakContinueStack.back().BreakCount += CurrentCount;
456*7330f729Sjoerg     CurrentCount = 0;
457*7330f729Sjoerg     RecordNextStmtCount = true;
458*7330f729Sjoerg   }
459*7330f729Sjoerg 
460*7330f729Sjoerg   void VisitContinueStmt(const ContinueStmt *S) {
461*7330f729Sjoerg     RecordStmtCount(S);
462*7330f729Sjoerg     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
463*7330f729Sjoerg     BreakContinueStack.back().ContinueCount += CurrentCount;
464*7330f729Sjoerg     CurrentCount = 0;
465*7330f729Sjoerg     RecordNextStmtCount = true;
466*7330f729Sjoerg   }
467*7330f729Sjoerg 
468*7330f729Sjoerg   void VisitWhileStmt(const WhileStmt *S) {
469*7330f729Sjoerg     RecordStmtCount(S);
470*7330f729Sjoerg     uint64_t ParentCount = CurrentCount;
471*7330f729Sjoerg 
472*7330f729Sjoerg     BreakContinueStack.push_back(BreakContinue());
473*7330f729Sjoerg     // Visit the body region first so the break/continue adjustments can be
474*7330f729Sjoerg     // included when visiting the condition.
475*7330f729Sjoerg     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
476*7330f729Sjoerg     CountMap[S->getBody()] = CurrentCount;
477*7330f729Sjoerg     Visit(S->getBody());
478*7330f729Sjoerg     uint64_t BackedgeCount = CurrentCount;
479*7330f729Sjoerg 
480*7330f729Sjoerg     // ...then go back and propagate counts through the condition. The count
481*7330f729Sjoerg     // at the start of the condition is the sum of the incoming edges,
482*7330f729Sjoerg     // the backedge from the end of the loop body, and the edges from
483*7330f729Sjoerg     // continue statements.
484*7330f729Sjoerg     BreakContinue BC = BreakContinueStack.pop_back_val();
485*7330f729Sjoerg     uint64_t CondCount =
486*7330f729Sjoerg         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
487*7330f729Sjoerg     CountMap[S->getCond()] = CondCount;
488*7330f729Sjoerg     Visit(S->getCond());
489*7330f729Sjoerg     setCount(BC.BreakCount + CondCount - BodyCount);
490*7330f729Sjoerg     RecordNextStmtCount = true;
491*7330f729Sjoerg   }
492*7330f729Sjoerg 
493*7330f729Sjoerg   void VisitDoStmt(const DoStmt *S) {
494*7330f729Sjoerg     RecordStmtCount(S);
495*7330f729Sjoerg     uint64_t LoopCount = PGO.getRegionCount(S);
496*7330f729Sjoerg 
497*7330f729Sjoerg     BreakContinueStack.push_back(BreakContinue());
498*7330f729Sjoerg     // The count doesn't include the fallthrough from the parent scope. Add it.
499*7330f729Sjoerg     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
500*7330f729Sjoerg     CountMap[S->getBody()] = BodyCount;
501*7330f729Sjoerg     Visit(S->getBody());
502*7330f729Sjoerg     uint64_t BackedgeCount = CurrentCount;
503*7330f729Sjoerg 
504*7330f729Sjoerg     BreakContinue BC = BreakContinueStack.pop_back_val();
505*7330f729Sjoerg     // The count at the start of the condition is equal to the count at the
506*7330f729Sjoerg     // end of the body, plus any continues.
507*7330f729Sjoerg     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
508*7330f729Sjoerg     CountMap[S->getCond()] = CondCount;
509*7330f729Sjoerg     Visit(S->getCond());
510*7330f729Sjoerg     setCount(BC.BreakCount + CondCount - LoopCount);
511*7330f729Sjoerg     RecordNextStmtCount = true;
512*7330f729Sjoerg   }
513*7330f729Sjoerg 
514*7330f729Sjoerg   void VisitForStmt(const ForStmt *S) {
515*7330f729Sjoerg     RecordStmtCount(S);
516*7330f729Sjoerg     if (S->getInit())
517*7330f729Sjoerg       Visit(S->getInit());
518*7330f729Sjoerg 
519*7330f729Sjoerg     uint64_t ParentCount = CurrentCount;
520*7330f729Sjoerg 
521*7330f729Sjoerg     BreakContinueStack.push_back(BreakContinue());
522*7330f729Sjoerg     // Visit the body region first. (This is basically the same as a while
523*7330f729Sjoerg     // loop; see further comments in VisitWhileStmt.)
524*7330f729Sjoerg     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
525*7330f729Sjoerg     CountMap[S->getBody()] = BodyCount;
526*7330f729Sjoerg     Visit(S->getBody());
527*7330f729Sjoerg     uint64_t BackedgeCount = CurrentCount;
528*7330f729Sjoerg     BreakContinue BC = BreakContinueStack.pop_back_val();
529*7330f729Sjoerg 
530*7330f729Sjoerg     // The increment is essentially part of the body but it needs to include
531*7330f729Sjoerg     // the count for all the continue statements.
532*7330f729Sjoerg     if (S->getInc()) {
533*7330f729Sjoerg       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
534*7330f729Sjoerg       CountMap[S->getInc()] = IncCount;
535*7330f729Sjoerg       Visit(S->getInc());
536*7330f729Sjoerg     }
537*7330f729Sjoerg 
538*7330f729Sjoerg     // ...then go back and propagate counts through the condition.
539*7330f729Sjoerg     uint64_t CondCount =
540*7330f729Sjoerg         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
541*7330f729Sjoerg     if (S->getCond()) {
542*7330f729Sjoerg       CountMap[S->getCond()] = CondCount;
543*7330f729Sjoerg       Visit(S->getCond());
544*7330f729Sjoerg     }
545*7330f729Sjoerg     setCount(BC.BreakCount + CondCount - BodyCount);
546*7330f729Sjoerg     RecordNextStmtCount = true;
547*7330f729Sjoerg   }
548*7330f729Sjoerg 
549*7330f729Sjoerg   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
550*7330f729Sjoerg     RecordStmtCount(S);
551*7330f729Sjoerg     if (S->getInit())
552*7330f729Sjoerg       Visit(S->getInit());
553*7330f729Sjoerg     Visit(S->getLoopVarStmt());
554*7330f729Sjoerg     Visit(S->getRangeStmt());
555*7330f729Sjoerg     Visit(S->getBeginStmt());
556*7330f729Sjoerg     Visit(S->getEndStmt());
557*7330f729Sjoerg 
558*7330f729Sjoerg     uint64_t ParentCount = CurrentCount;
559*7330f729Sjoerg     BreakContinueStack.push_back(BreakContinue());
560*7330f729Sjoerg     // Visit the body region first. (This is basically the same as a while
561*7330f729Sjoerg     // loop; see further comments in VisitWhileStmt.)
562*7330f729Sjoerg     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
563*7330f729Sjoerg     CountMap[S->getBody()] = BodyCount;
564*7330f729Sjoerg     Visit(S->getBody());
565*7330f729Sjoerg     uint64_t BackedgeCount = CurrentCount;
566*7330f729Sjoerg     BreakContinue BC = BreakContinueStack.pop_back_val();
567*7330f729Sjoerg 
568*7330f729Sjoerg     // The increment is essentially part of the body but it needs to include
569*7330f729Sjoerg     // the count for all the continue statements.
570*7330f729Sjoerg     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
571*7330f729Sjoerg     CountMap[S->getInc()] = IncCount;
572*7330f729Sjoerg     Visit(S->getInc());
573*7330f729Sjoerg 
574*7330f729Sjoerg     // ...then go back and propagate counts through the condition.
575*7330f729Sjoerg     uint64_t CondCount =
576*7330f729Sjoerg         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
577*7330f729Sjoerg     CountMap[S->getCond()] = CondCount;
578*7330f729Sjoerg     Visit(S->getCond());
579*7330f729Sjoerg     setCount(BC.BreakCount + CondCount - BodyCount);
580*7330f729Sjoerg     RecordNextStmtCount = true;
581*7330f729Sjoerg   }
582*7330f729Sjoerg 
583*7330f729Sjoerg   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
584*7330f729Sjoerg     RecordStmtCount(S);
585*7330f729Sjoerg     Visit(S->getElement());
586*7330f729Sjoerg     uint64_t ParentCount = CurrentCount;
587*7330f729Sjoerg     BreakContinueStack.push_back(BreakContinue());
588*7330f729Sjoerg     // Counter tracks the body of the loop.
589*7330f729Sjoerg     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
590*7330f729Sjoerg     CountMap[S->getBody()] = BodyCount;
591*7330f729Sjoerg     Visit(S->getBody());
592*7330f729Sjoerg     uint64_t BackedgeCount = CurrentCount;
593*7330f729Sjoerg     BreakContinue BC = BreakContinueStack.pop_back_val();
594*7330f729Sjoerg 
595*7330f729Sjoerg     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
596*7330f729Sjoerg              BodyCount);
597*7330f729Sjoerg     RecordNextStmtCount = true;
598*7330f729Sjoerg   }
599*7330f729Sjoerg 
600*7330f729Sjoerg   void VisitSwitchStmt(const SwitchStmt *S) {
601*7330f729Sjoerg     RecordStmtCount(S);
602*7330f729Sjoerg     if (S->getInit())
603*7330f729Sjoerg       Visit(S->getInit());
604*7330f729Sjoerg     Visit(S->getCond());
605*7330f729Sjoerg     CurrentCount = 0;
606*7330f729Sjoerg     BreakContinueStack.push_back(BreakContinue());
607*7330f729Sjoerg     Visit(S->getBody());
608*7330f729Sjoerg     // If the switch is inside a loop, add the continue counts.
609*7330f729Sjoerg     BreakContinue BC = BreakContinueStack.pop_back_val();
610*7330f729Sjoerg     if (!BreakContinueStack.empty())
611*7330f729Sjoerg       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
612*7330f729Sjoerg     // Counter tracks the exit block of the switch.
613*7330f729Sjoerg     setCount(PGO.getRegionCount(S));
614*7330f729Sjoerg     RecordNextStmtCount = true;
615*7330f729Sjoerg   }
616*7330f729Sjoerg 
617*7330f729Sjoerg   void VisitSwitchCase(const SwitchCase *S) {
618*7330f729Sjoerg     RecordNextStmtCount = false;
619*7330f729Sjoerg     // Counter for this particular case. This counts only jumps from the
620*7330f729Sjoerg     // switch header and does not include fallthrough from the case before
621*7330f729Sjoerg     // this one.
622*7330f729Sjoerg     uint64_t CaseCount = PGO.getRegionCount(S);
623*7330f729Sjoerg     setCount(CurrentCount + CaseCount);
624*7330f729Sjoerg     // We need the count without fallthrough in the mapping, so it's more useful
625*7330f729Sjoerg     // for branch probabilities.
626*7330f729Sjoerg     CountMap[S] = CaseCount;
627*7330f729Sjoerg     RecordNextStmtCount = true;
628*7330f729Sjoerg     Visit(S->getSubStmt());
629*7330f729Sjoerg   }
630*7330f729Sjoerg 
631*7330f729Sjoerg   void VisitIfStmt(const IfStmt *S) {
632*7330f729Sjoerg     RecordStmtCount(S);
633*7330f729Sjoerg     uint64_t ParentCount = CurrentCount;
634*7330f729Sjoerg     if (S->getInit())
635*7330f729Sjoerg       Visit(S->getInit());
636*7330f729Sjoerg     Visit(S->getCond());
637*7330f729Sjoerg 
638*7330f729Sjoerg     // Counter tracks the "then" part of an if statement. The count for
639*7330f729Sjoerg     // the "else" part, if it exists, will be calculated from this counter.
640*7330f729Sjoerg     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
641*7330f729Sjoerg     CountMap[S->getThen()] = ThenCount;
642*7330f729Sjoerg     Visit(S->getThen());
643*7330f729Sjoerg     uint64_t OutCount = CurrentCount;
644*7330f729Sjoerg 
645*7330f729Sjoerg     uint64_t ElseCount = ParentCount - ThenCount;
646*7330f729Sjoerg     if (S->getElse()) {
647*7330f729Sjoerg       setCount(ElseCount);
648*7330f729Sjoerg       CountMap[S->getElse()] = ElseCount;
649*7330f729Sjoerg       Visit(S->getElse());
650*7330f729Sjoerg       OutCount += CurrentCount;
651*7330f729Sjoerg     } else
652*7330f729Sjoerg       OutCount += ElseCount;
653*7330f729Sjoerg     setCount(OutCount);
654*7330f729Sjoerg     RecordNextStmtCount = true;
655*7330f729Sjoerg   }
656*7330f729Sjoerg 
657*7330f729Sjoerg   void VisitCXXTryStmt(const CXXTryStmt *S) {
658*7330f729Sjoerg     RecordStmtCount(S);
659*7330f729Sjoerg     Visit(S->getTryBlock());
660*7330f729Sjoerg     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
661*7330f729Sjoerg       Visit(S->getHandler(I));
662*7330f729Sjoerg     // Counter tracks the continuation block of the try statement.
663*7330f729Sjoerg     setCount(PGO.getRegionCount(S));
664*7330f729Sjoerg     RecordNextStmtCount = true;
665*7330f729Sjoerg   }
666*7330f729Sjoerg 
667*7330f729Sjoerg   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
668*7330f729Sjoerg     RecordNextStmtCount = false;
669*7330f729Sjoerg     // Counter tracks the catch statement's handler block.
670*7330f729Sjoerg     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
671*7330f729Sjoerg     CountMap[S] = CatchCount;
672*7330f729Sjoerg     Visit(S->getHandlerBlock());
673*7330f729Sjoerg   }
674*7330f729Sjoerg 
675*7330f729Sjoerg   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
676*7330f729Sjoerg     RecordStmtCount(E);
677*7330f729Sjoerg     uint64_t ParentCount = CurrentCount;
678*7330f729Sjoerg     Visit(E->getCond());
679*7330f729Sjoerg 
680*7330f729Sjoerg     // Counter tracks the "true" part of a conditional operator. The
681*7330f729Sjoerg     // count in the "false" part will be calculated from this counter.
682*7330f729Sjoerg     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
683*7330f729Sjoerg     CountMap[E->getTrueExpr()] = TrueCount;
684*7330f729Sjoerg     Visit(E->getTrueExpr());
685*7330f729Sjoerg     uint64_t OutCount = CurrentCount;
686*7330f729Sjoerg 
687*7330f729Sjoerg     uint64_t FalseCount = setCount(ParentCount - TrueCount);
688*7330f729Sjoerg     CountMap[E->getFalseExpr()] = FalseCount;
689*7330f729Sjoerg     Visit(E->getFalseExpr());
690*7330f729Sjoerg     OutCount += CurrentCount;
691*7330f729Sjoerg 
692*7330f729Sjoerg     setCount(OutCount);
693*7330f729Sjoerg     RecordNextStmtCount = true;
694*7330f729Sjoerg   }
695*7330f729Sjoerg 
696*7330f729Sjoerg   void VisitBinLAnd(const BinaryOperator *E) {
697*7330f729Sjoerg     RecordStmtCount(E);
698*7330f729Sjoerg     uint64_t ParentCount = CurrentCount;
699*7330f729Sjoerg     Visit(E->getLHS());
700*7330f729Sjoerg     // Counter tracks the right hand side of a logical and operator.
701*7330f729Sjoerg     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
702*7330f729Sjoerg     CountMap[E->getRHS()] = RHSCount;
703*7330f729Sjoerg     Visit(E->getRHS());
704*7330f729Sjoerg     setCount(ParentCount + RHSCount - CurrentCount);
705*7330f729Sjoerg     RecordNextStmtCount = true;
706*7330f729Sjoerg   }
707*7330f729Sjoerg 
708*7330f729Sjoerg   void VisitBinLOr(const BinaryOperator *E) {
709*7330f729Sjoerg     RecordStmtCount(E);
710*7330f729Sjoerg     uint64_t ParentCount = CurrentCount;
711*7330f729Sjoerg     Visit(E->getLHS());
712*7330f729Sjoerg     // Counter tracks the right hand side of a logical or operator.
713*7330f729Sjoerg     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
714*7330f729Sjoerg     CountMap[E->getRHS()] = RHSCount;
715*7330f729Sjoerg     Visit(E->getRHS());
716*7330f729Sjoerg     setCount(ParentCount + RHSCount - CurrentCount);
717*7330f729Sjoerg     RecordNextStmtCount = true;
718*7330f729Sjoerg   }
719*7330f729Sjoerg };
720*7330f729Sjoerg } // end anonymous namespace
721*7330f729Sjoerg 
722*7330f729Sjoerg void PGOHash::combine(HashType Type) {
723*7330f729Sjoerg   // Check that we never combine 0 and only have six bits.
724*7330f729Sjoerg   assert(Type && "Hash is invalid: unexpected type 0");
725*7330f729Sjoerg   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
726*7330f729Sjoerg 
727*7330f729Sjoerg   // Pass through MD5 if enough work has built up.
728*7330f729Sjoerg   if (Count && Count % NumTypesPerWord == 0) {
729*7330f729Sjoerg     using namespace llvm::support;
730*7330f729Sjoerg     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
731*7330f729Sjoerg     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
732*7330f729Sjoerg     Working = 0;
733*7330f729Sjoerg   }
734*7330f729Sjoerg 
735*7330f729Sjoerg   // Accumulate the current type.
736*7330f729Sjoerg   ++Count;
737*7330f729Sjoerg   Working = Working << NumBitsPerType | Type;
738*7330f729Sjoerg }
739*7330f729Sjoerg 
740*7330f729Sjoerg uint64_t PGOHash::finalize() {
741*7330f729Sjoerg   // Use Working as the hash directly if we never used MD5.
742*7330f729Sjoerg   if (Count <= NumTypesPerWord)
743*7330f729Sjoerg     // No need to byte swap here, since none of the math was endian-dependent.
744*7330f729Sjoerg     // This number will be byte-swapped as required on endianness transitions,
745*7330f729Sjoerg     // so we will see the same value on the other side.
746*7330f729Sjoerg     return Working;
747*7330f729Sjoerg 
748*7330f729Sjoerg   // Check for remaining work in Working.
749*7330f729Sjoerg   if (Working)
750*7330f729Sjoerg     MD5.update(Working);
751*7330f729Sjoerg 
752*7330f729Sjoerg   // Finalize the MD5 and return the hash.
753*7330f729Sjoerg   llvm::MD5::MD5Result Result;
754*7330f729Sjoerg   MD5.final(Result);
755*7330f729Sjoerg   using namespace llvm::support;
756*7330f729Sjoerg   return Result.low();
757*7330f729Sjoerg }
758*7330f729Sjoerg 
759*7330f729Sjoerg void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
760*7330f729Sjoerg   const Decl *D = GD.getDecl();
761*7330f729Sjoerg   if (!D->hasBody())
762*7330f729Sjoerg     return;
763*7330f729Sjoerg 
764*7330f729Sjoerg   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
765*7330f729Sjoerg   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
766*7330f729Sjoerg   if (!InstrumentRegions && !PGOReader)
767*7330f729Sjoerg     return;
768*7330f729Sjoerg   if (D->isImplicit())
769*7330f729Sjoerg     return;
770*7330f729Sjoerg   // Constructors and destructors may be represented by several functions in IR.
771*7330f729Sjoerg   // If so, instrument only base variant, others are implemented by delegation
772*7330f729Sjoerg   // to the base one, it would be counted twice otherwise.
773*7330f729Sjoerg   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
774*7330f729Sjoerg     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
775*7330f729Sjoerg       if (GD.getCtorType() != Ctor_Base &&
776*7330f729Sjoerg           CodeGenFunction::IsConstructorDelegationValid(CCD))
777*7330f729Sjoerg         return;
778*7330f729Sjoerg   }
779*7330f729Sjoerg   if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
780*7330f729Sjoerg     return;
781*7330f729Sjoerg 
782*7330f729Sjoerg   CGM.ClearUnusedCoverageMapping(D);
783*7330f729Sjoerg   setFuncName(Fn);
784*7330f729Sjoerg 
785*7330f729Sjoerg   mapRegionCounters(D);
786*7330f729Sjoerg   if (CGM.getCodeGenOpts().CoverageMapping)
787*7330f729Sjoerg     emitCounterRegionMapping(D);
788*7330f729Sjoerg   if (PGOReader) {
789*7330f729Sjoerg     SourceManager &SM = CGM.getContext().getSourceManager();
790*7330f729Sjoerg     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
791*7330f729Sjoerg     computeRegionCounts(D);
792*7330f729Sjoerg     applyFunctionAttributes(PGOReader, Fn);
793*7330f729Sjoerg   }
794*7330f729Sjoerg }
795*7330f729Sjoerg 
796*7330f729Sjoerg void CodeGenPGO::mapRegionCounters(const Decl *D) {
797*7330f729Sjoerg   // Use the latest hash version when inserting instrumentation, but use the
798*7330f729Sjoerg   // version in the indexed profile if we're reading PGO data.
799*7330f729Sjoerg   PGOHashVersion HashVersion = PGO_HASH_LATEST;
800*7330f729Sjoerg   if (auto *PGOReader = CGM.getPGOReader())
801*7330f729Sjoerg     HashVersion = getPGOHashVersion(PGOReader, CGM);
802*7330f729Sjoerg 
803*7330f729Sjoerg   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
804*7330f729Sjoerg   MapRegionCounters Walker(HashVersion, *RegionCounterMap);
805*7330f729Sjoerg   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
806*7330f729Sjoerg     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
807*7330f729Sjoerg   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
808*7330f729Sjoerg     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
809*7330f729Sjoerg   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
810*7330f729Sjoerg     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
811*7330f729Sjoerg   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
812*7330f729Sjoerg     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
813*7330f729Sjoerg   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
814*7330f729Sjoerg   NumRegionCounters = Walker.NextCounter;
815*7330f729Sjoerg   FunctionHash = Walker.Hash.finalize();
816*7330f729Sjoerg }
817*7330f729Sjoerg 
818*7330f729Sjoerg bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
819*7330f729Sjoerg   if (!D->getBody())
820*7330f729Sjoerg     return true;
821*7330f729Sjoerg 
822*7330f729Sjoerg   // Don't map the functions in system headers.
823*7330f729Sjoerg   const auto &SM = CGM.getContext().getSourceManager();
824*7330f729Sjoerg   auto Loc = D->getBody()->getBeginLoc();
825*7330f729Sjoerg   return SM.isInSystemHeader(Loc);
826*7330f729Sjoerg }
827*7330f729Sjoerg 
828*7330f729Sjoerg void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
829*7330f729Sjoerg   if (skipRegionMappingForDecl(D))
830*7330f729Sjoerg     return;
831*7330f729Sjoerg 
832*7330f729Sjoerg   std::string CoverageMapping;
833*7330f729Sjoerg   llvm::raw_string_ostream OS(CoverageMapping);
834*7330f729Sjoerg   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
835*7330f729Sjoerg                                 CGM.getContext().getSourceManager(),
836*7330f729Sjoerg                                 CGM.getLangOpts(), RegionCounterMap.get());
837*7330f729Sjoerg   MappingGen.emitCounterMapping(D, OS);
838*7330f729Sjoerg   OS.flush();
839*7330f729Sjoerg 
840*7330f729Sjoerg   if (CoverageMapping.empty())
841*7330f729Sjoerg     return;
842*7330f729Sjoerg 
843*7330f729Sjoerg   CGM.getCoverageMapping()->addFunctionMappingRecord(
844*7330f729Sjoerg       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
845*7330f729Sjoerg }
846*7330f729Sjoerg 
847*7330f729Sjoerg void
848*7330f729Sjoerg CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
849*7330f729Sjoerg                                     llvm::GlobalValue::LinkageTypes Linkage) {
850*7330f729Sjoerg   if (skipRegionMappingForDecl(D))
851*7330f729Sjoerg     return;
852*7330f729Sjoerg 
853*7330f729Sjoerg   std::string CoverageMapping;
854*7330f729Sjoerg   llvm::raw_string_ostream OS(CoverageMapping);
855*7330f729Sjoerg   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
856*7330f729Sjoerg                                 CGM.getContext().getSourceManager(),
857*7330f729Sjoerg                                 CGM.getLangOpts());
858*7330f729Sjoerg   MappingGen.emitEmptyMapping(D, OS);
859*7330f729Sjoerg   OS.flush();
860*7330f729Sjoerg 
861*7330f729Sjoerg   if (CoverageMapping.empty())
862*7330f729Sjoerg     return;
863*7330f729Sjoerg 
864*7330f729Sjoerg   setFuncName(Name, Linkage);
865*7330f729Sjoerg   CGM.getCoverageMapping()->addFunctionMappingRecord(
866*7330f729Sjoerg       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
867*7330f729Sjoerg }
868*7330f729Sjoerg 
869*7330f729Sjoerg void CodeGenPGO::computeRegionCounts(const Decl *D) {
870*7330f729Sjoerg   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
871*7330f729Sjoerg   ComputeRegionCounts Walker(*StmtCountMap, *this);
872*7330f729Sjoerg   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
873*7330f729Sjoerg     Walker.VisitFunctionDecl(FD);
874*7330f729Sjoerg   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
875*7330f729Sjoerg     Walker.VisitObjCMethodDecl(MD);
876*7330f729Sjoerg   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
877*7330f729Sjoerg     Walker.VisitBlockDecl(BD);
878*7330f729Sjoerg   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
879*7330f729Sjoerg     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
880*7330f729Sjoerg }
881*7330f729Sjoerg 
882*7330f729Sjoerg void
883*7330f729Sjoerg CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
884*7330f729Sjoerg                                     llvm::Function *Fn) {
885*7330f729Sjoerg   if (!haveRegionCounts())
886*7330f729Sjoerg     return;
887*7330f729Sjoerg 
888*7330f729Sjoerg   uint64_t FunctionCount = getRegionCount(nullptr);
889*7330f729Sjoerg   Fn->setEntryCount(FunctionCount);
890*7330f729Sjoerg }
891*7330f729Sjoerg 
892*7330f729Sjoerg void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
893*7330f729Sjoerg                                       llvm::Value *StepV) {
894*7330f729Sjoerg   if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
895*7330f729Sjoerg     return;
896*7330f729Sjoerg   if (!Builder.GetInsertBlock())
897*7330f729Sjoerg     return;
898*7330f729Sjoerg 
899*7330f729Sjoerg   unsigned Counter = (*RegionCounterMap)[S];
900*7330f729Sjoerg   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
901*7330f729Sjoerg 
902*7330f729Sjoerg   llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
903*7330f729Sjoerg                          Builder.getInt64(FunctionHash),
904*7330f729Sjoerg                          Builder.getInt32(NumRegionCounters),
905*7330f729Sjoerg                          Builder.getInt32(Counter), StepV};
906*7330f729Sjoerg   if (!StepV)
907*7330f729Sjoerg     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
908*7330f729Sjoerg                        makeArrayRef(Args, 4));
909*7330f729Sjoerg   else
910*7330f729Sjoerg     Builder.CreateCall(
911*7330f729Sjoerg         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
912*7330f729Sjoerg         makeArrayRef(Args));
913*7330f729Sjoerg }
914*7330f729Sjoerg 
915*7330f729Sjoerg // This method either inserts a call to the profile run-time during
916*7330f729Sjoerg // instrumentation or puts profile data into metadata for PGO use.
917*7330f729Sjoerg void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
918*7330f729Sjoerg     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
919*7330f729Sjoerg 
920*7330f729Sjoerg   if (!EnableValueProfiling)
921*7330f729Sjoerg     return;
922*7330f729Sjoerg 
923*7330f729Sjoerg   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
924*7330f729Sjoerg     return;
925*7330f729Sjoerg 
926*7330f729Sjoerg   if (isa<llvm::Constant>(ValuePtr))
927*7330f729Sjoerg     return;
928*7330f729Sjoerg 
929*7330f729Sjoerg   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
930*7330f729Sjoerg   if (InstrumentValueSites && RegionCounterMap) {
931*7330f729Sjoerg     auto BuilderInsertPoint = Builder.saveIP();
932*7330f729Sjoerg     Builder.SetInsertPoint(ValueSite);
933*7330f729Sjoerg     llvm::Value *Args[5] = {
934*7330f729Sjoerg         llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
935*7330f729Sjoerg         Builder.getInt64(FunctionHash),
936*7330f729Sjoerg         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
937*7330f729Sjoerg         Builder.getInt32(ValueKind),
938*7330f729Sjoerg         Builder.getInt32(NumValueSites[ValueKind]++)
939*7330f729Sjoerg     };
940*7330f729Sjoerg     Builder.CreateCall(
941*7330f729Sjoerg         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
942*7330f729Sjoerg     Builder.restoreIP(BuilderInsertPoint);
943*7330f729Sjoerg     return;
944*7330f729Sjoerg   }
945*7330f729Sjoerg 
946*7330f729Sjoerg   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
947*7330f729Sjoerg   if (PGOReader && haveRegionCounts()) {
948*7330f729Sjoerg     // We record the top most called three functions at each call site.
949*7330f729Sjoerg     // Profile metadata contains "VP" string identifying this metadata
950*7330f729Sjoerg     // as value profiling data, then a uint32_t value for the value profiling
951*7330f729Sjoerg     // kind, a uint64_t value for the total number of times the call is
952*7330f729Sjoerg     // executed, followed by the function hash and execution count (uint64_t)
953*7330f729Sjoerg     // pairs for each function.
954*7330f729Sjoerg     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
955*7330f729Sjoerg       return;
956*7330f729Sjoerg 
957*7330f729Sjoerg     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
958*7330f729Sjoerg                             (llvm::InstrProfValueKind)ValueKind,
959*7330f729Sjoerg                             NumValueSites[ValueKind]);
960*7330f729Sjoerg 
961*7330f729Sjoerg     NumValueSites[ValueKind]++;
962*7330f729Sjoerg   }
963*7330f729Sjoerg }
964*7330f729Sjoerg 
965*7330f729Sjoerg void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
966*7330f729Sjoerg                                   bool IsInMainFile) {
967*7330f729Sjoerg   CGM.getPGOStats().addVisited(IsInMainFile);
968*7330f729Sjoerg   RegionCounts.clear();
969*7330f729Sjoerg   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
970*7330f729Sjoerg       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
971*7330f729Sjoerg   if (auto E = RecordExpected.takeError()) {
972*7330f729Sjoerg     auto IPE = llvm::InstrProfError::take(std::move(E));
973*7330f729Sjoerg     if (IPE == llvm::instrprof_error::unknown_function)
974*7330f729Sjoerg       CGM.getPGOStats().addMissing(IsInMainFile);
975*7330f729Sjoerg     else if (IPE == llvm::instrprof_error::hash_mismatch)
976*7330f729Sjoerg       CGM.getPGOStats().addMismatched(IsInMainFile);
977*7330f729Sjoerg     else if (IPE == llvm::instrprof_error::malformed)
978*7330f729Sjoerg       // TODO: Consider a more specific warning for this case.
979*7330f729Sjoerg       CGM.getPGOStats().addMismatched(IsInMainFile);
980*7330f729Sjoerg     return;
981*7330f729Sjoerg   }
982*7330f729Sjoerg   ProfRecord =
983*7330f729Sjoerg       std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
984*7330f729Sjoerg   RegionCounts = ProfRecord->Counts;
985*7330f729Sjoerg }
986*7330f729Sjoerg 
/// Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
///
/// \returns 1 when \p MaxWeight is already below UINT32_MAX, otherwise
/// \p MaxWeight / UINT32_MAX + 1.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
994*7330f729Sjoerg 
/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// greater than \c Weight.
///
/// \returns \p Weight / \p Scale + 1, narrowed to 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Biased = Weight / Scale + 1;
  assert(Biased <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Biased);
}
1010*7330f729Sjoerg 
1011*7330f729Sjoerg llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1012*7330f729Sjoerg                                                     uint64_t FalseCount) {
1013*7330f729Sjoerg   // Check for empty weights.
1014*7330f729Sjoerg   if (!TrueCount && !FalseCount)
1015*7330f729Sjoerg     return nullptr;
1016*7330f729Sjoerg 
1017*7330f729Sjoerg   // Calculate how to scale down to 32-bits.
1018*7330f729Sjoerg   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1019*7330f729Sjoerg 
1020*7330f729Sjoerg   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1021*7330f729Sjoerg   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1022*7330f729Sjoerg                                       scaleBranchWeight(FalseCount, Scale));
1023*7330f729Sjoerg }
1024*7330f729Sjoerg 
1025*7330f729Sjoerg llvm::MDNode *
1026*7330f729Sjoerg CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
1027*7330f729Sjoerg   // We need at least two elements to create meaningful weights.
1028*7330f729Sjoerg   if (Weights.size() < 2)
1029*7330f729Sjoerg     return nullptr;
1030*7330f729Sjoerg 
1031*7330f729Sjoerg   // Check for empty weights.
1032*7330f729Sjoerg   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1033*7330f729Sjoerg   if (MaxWeight == 0)
1034*7330f729Sjoerg     return nullptr;
1035*7330f729Sjoerg 
1036*7330f729Sjoerg   // Calculate how to scale down to 32-bits.
1037*7330f729Sjoerg   uint64_t Scale = calculateWeightScale(MaxWeight);
1038*7330f729Sjoerg 
1039*7330f729Sjoerg   SmallVector<uint32_t, 16> ScaledWeights;
1040*7330f729Sjoerg   ScaledWeights.reserve(Weights.size());
1041*7330f729Sjoerg   for (uint64_t W : Weights)
1042*7330f729Sjoerg     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1043*7330f729Sjoerg 
1044*7330f729Sjoerg   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1045*7330f729Sjoerg   return MDHelper.createBranchWeights(ScaledWeights);
1046*7330f729Sjoerg }
1047*7330f729Sjoerg 
1048*7330f729Sjoerg llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1049*7330f729Sjoerg                                                            uint64_t LoopCount) {
1050*7330f729Sjoerg   if (!PGO.haveRegionCounts())
1051*7330f729Sjoerg     return nullptr;
1052*7330f729Sjoerg   Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1053*7330f729Sjoerg   assert(CondCount.hasValue() && "missing expected loop condition count");
1054*7330f729Sjoerg   if (*CondCount == 0)
1055*7330f729Sjoerg     return nullptr;
1056*7330f729Sjoerg   return createProfileWeights(LoopCount,
1057*7330f729Sjoerg                               std::max(*CondCount, LoopCount) - LoopCount);
1058*7330f729Sjoerg }
1059