xref: /freebsd-src/contrib/llvm-project/clang/lib/CodeGen/CodeGenPGO.cpp (revision cb14a3fe5122c879eae1fb480ed7ce82a699ddb6)
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24 #include <optional>
25 
26 static llvm::cl::opt<bool>
27     EnableValueProfiling("enable-value-profiling",
28                          llvm::cl::desc("Enable value profiling"),
29                          llvm::cl::Hidden, llvm::cl::init(false));
30 
31 using namespace clang;
32 using namespace CodeGen;
33 
34 void CodeGenPGO::setFuncName(StringRef Name,
35                              llvm::GlobalValue::LinkageTypes Linkage) {
36   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
37   FuncName = llvm::getPGOFuncName(
38       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
39       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
40 
41   // If we're generating a profile, create a variable for the name.
42   if (CGM.getCodeGenOpts().hasProfileClangInstr())
43     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
44 }
45 
46 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
47   setFuncName(Fn->getName(), Fn->getLinkage());
48   // Create PGOFuncName meta data.
49   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
50 }
51 
/// Identifies which revision of the PGO hash algorithm is in effect.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,
  PGO_HASH_V3 = 2,

  // Alias for the newest revision; update it whenever a version is added.
  PGO_HASH_LATEST = PGO_HASH_V3
};
61 
62 namespace {
63 /// Stable hasher for PGO region counters.
64 ///
65 /// PGOHash produces a stable hash of a given function's control flow.
66 ///
67 /// Changing the output of this hash will invalidate all previously generated
68 /// profiles -- i.e., don't do it.
69 ///
70 /// \note  When this hash does eventually change (years?), we still need to
71 /// support old hashes.  We'll need to pull in the version number from the
72 /// profile data format and use the matching hash function.
73 class PGOHash {
74   uint64_t Working;
75   unsigned Count;
76   PGOHashVersion HashVersion;
77   llvm::MD5 MD5;
78 
79   static const int NumBitsPerType = 6;
80   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
81   static const unsigned TooBig = 1u << NumBitsPerType;
82 
83 public:
84   /// Hash values for AST nodes.
85   ///
86   /// Distinct values for AST nodes that have region counters attached.
87   ///
88   /// These values must be stable.  All new members must be added at the end,
89   /// and no members should be removed.  Changing the enumeration value for an
90   /// AST node will affect the hash of every function that contains that node.
91   enum HashType : unsigned char {
92     None = 0,
93     LabelStmt = 1,
94     WhileStmt,
95     DoStmt,
96     ForStmt,
97     CXXForRangeStmt,
98     ObjCForCollectionStmt,
99     SwitchStmt,
100     CaseStmt,
101     DefaultStmt,
102     IfStmt,
103     CXXTryStmt,
104     CXXCatchStmt,
105     ConditionalOperator,
106     BinaryOperatorLAnd,
107     BinaryOperatorLOr,
108     BinaryConditionalOperator,
109     // The preceding values are available with PGO_HASH_V1.
110 
111     EndOfScope,
112     IfThenBranch,
113     IfElseBranch,
114     GotoStmt,
115     IndirectGotoStmt,
116     BreakStmt,
117     ContinueStmt,
118     ReturnStmt,
119     ThrowExpr,
120     UnaryOperatorLNot,
121     BinaryOperatorLT,
122     BinaryOperatorGT,
123     BinaryOperatorLE,
124     BinaryOperatorGE,
125     BinaryOperatorEQ,
126     BinaryOperatorNE,
127     // The preceding values are available since PGO_HASH_V2.
128 
129     // Keep this last.  It's for the static assert that follows.
130     LastHashType
131   };
132   static_assert(LastHashType <= TooBig, "Too many types in HashType");
133 
134   PGOHash(PGOHashVersion HashVersion)
135       : Working(0), Count(0), HashVersion(HashVersion) {}
136   void combine(HashType Type);
137   uint64_t finalize();
138   PGOHashVersion getHashVersion() const { return HashVersion; }
139 };
140 const int PGOHash::NumBitsPerType;
141 const unsigned PGOHash::NumTypesPerWord;
142 const unsigned PGOHash::TooBig;
143 
144 /// Get the PGO hash version used in the given indexed profile.
145 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
146                                         CodeGenModule &CGM) {
147   if (PGOReader->getVersion() <= 4)
148     return PGO_HASH_V1;
149   if (PGOReader->getVersion() <= 5)
150     return PGO_HASH_V2;
151   return PGO_HASH_V3;
152 }
153 
154 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
155 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
156   using Base = RecursiveASTVisitor<MapRegionCounters>;
157 
158   /// The next counter value to assign.
159   unsigned NextCounter;
160   /// The function hash.
161   PGOHash Hash;
162   /// The map of statements to counters.
163   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
164   /// The profile version.
165   uint64_t ProfileVersion;
166 
167   MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
168                     llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
169       : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
170         ProfileVersion(ProfileVersion) {}
171 
172   // Blocks and lambdas are handled as separate functions, so we need not
173   // traverse them in the parent context.
174   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
175   bool TraverseLambdaExpr(LambdaExpr *LE) {
176     // Traverse the captures, but not the body.
177     for (auto C : zip(LE->captures(), LE->capture_inits()))
178       TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
179     return true;
180   }
181   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
182 
183   bool VisitDecl(const Decl *D) {
184     switch (D->getKind()) {
185     default:
186       break;
187     case Decl::Function:
188     case Decl::CXXMethod:
189     case Decl::CXXConstructor:
190     case Decl::CXXDestructor:
191     case Decl::CXXConversion:
192     case Decl::ObjCMethod:
193     case Decl::Block:
194     case Decl::Captured:
195       CounterMap[D->getBody()] = NextCounter++;
196       break;
197     }
198     return true;
199   }
200 
201   /// If \p S gets a fresh counter, update the counter mappings. Return the
202   /// V1 hash of \p S.
203   PGOHash::HashType updateCounterMappings(Stmt *S) {
204     auto Type = getHashType(PGO_HASH_V1, S);
205     if (Type != PGOHash::None)
206       CounterMap[S] = NextCounter++;
207     return Type;
208   }
209 
210   /// The RHS of all logical operators gets a fresh counter in order to count
211   /// how many times the RHS evaluates to true or false, depending on the
212   /// semantics of the operator. This is only valid for ">= v7" of the profile
213   /// version so that we facilitate backward compatibility.
214   bool VisitBinaryOperator(BinaryOperator *S) {
215     if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
216       if (S->isLogicalOp() &&
217           CodeGenFunction::isInstrumentedCondition(S->getRHS()))
218         CounterMap[S->getRHS()] = NextCounter++;
219     return Base::VisitBinaryOperator(S);
220   }
221 
222   /// Include \p S in the function hash.
223   bool VisitStmt(Stmt *S) {
224     auto Type = updateCounterMappings(S);
225     if (Hash.getHashVersion() != PGO_HASH_V1)
226       Type = getHashType(Hash.getHashVersion(), S);
227     if (Type != PGOHash::None)
228       Hash.combine(Type);
229     return true;
230   }
231 
232   bool TraverseIfStmt(IfStmt *If) {
233     // If we used the V1 hash, use the default traversal.
234     if (Hash.getHashVersion() == PGO_HASH_V1)
235       return Base::TraverseIfStmt(If);
236 
237     // Otherwise, keep track of which branch we're in while traversing.
238     VisitStmt(If);
239     for (Stmt *CS : If->children()) {
240       if (!CS)
241         continue;
242       if (CS == If->getThen())
243         Hash.combine(PGOHash::IfThenBranch);
244       else if (CS == If->getElse())
245         Hash.combine(PGOHash::IfElseBranch);
246       TraverseStmt(CS);
247     }
248     Hash.combine(PGOHash::EndOfScope);
249     return true;
250   }
251 
252 // If the statement type \p N is nestable, and its nesting impacts profile
253 // stability, define a custom traversal which tracks the end of the statement
254 // in the hash (provided we're not using the V1 hash).
255 #define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
256   bool Traverse##N(N *S) {                                                     \
257     Base::Traverse##N(S);                                                      \
258     if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
259       Hash.combine(PGOHash::EndOfScope);                                       \
260     return true;                                                               \
261   }
262 
263   DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
264   DEFINE_NESTABLE_TRAVERSAL(DoStmt)
265   DEFINE_NESTABLE_TRAVERSAL(ForStmt)
266   DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
267   DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
268   DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
269   DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
270 
271   /// Get version \p HashVersion of the PGO hash for \p S.
272   PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
273     switch (S->getStmtClass()) {
274     default:
275       break;
276     case Stmt::LabelStmtClass:
277       return PGOHash::LabelStmt;
278     case Stmt::WhileStmtClass:
279       return PGOHash::WhileStmt;
280     case Stmt::DoStmtClass:
281       return PGOHash::DoStmt;
282     case Stmt::ForStmtClass:
283       return PGOHash::ForStmt;
284     case Stmt::CXXForRangeStmtClass:
285       return PGOHash::CXXForRangeStmt;
286     case Stmt::ObjCForCollectionStmtClass:
287       return PGOHash::ObjCForCollectionStmt;
288     case Stmt::SwitchStmtClass:
289       return PGOHash::SwitchStmt;
290     case Stmt::CaseStmtClass:
291       return PGOHash::CaseStmt;
292     case Stmt::DefaultStmtClass:
293       return PGOHash::DefaultStmt;
294     case Stmt::IfStmtClass:
295       return PGOHash::IfStmt;
296     case Stmt::CXXTryStmtClass:
297       return PGOHash::CXXTryStmt;
298     case Stmt::CXXCatchStmtClass:
299       return PGOHash::CXXCatchStmt;
300     case Stmt::ConditionalOperatorClass:
301       return PGOHash::ConditionalOperator;
302     case Stmt::BinaryConditionalOperatorClass:
303       return PGOHash::BinaryConditionalOperator;
304     case Stmt::BinaryOperatorClass: {
305       const BinaryOperator *BO = cast<BinaryOperator>(S);
306       if (BO->getOpcode() == BO_LAnd)
307         return PGOHash::BinaryOperatorLAnd;
308       if (BO->getOpcode() == BO_LOr)
309         return PGOHash::BinaryOperatorLOr;
310       if (HashVersion >= PGO_HASH_V2) {
311         switch (BO->getOpcode()) {
312         default:
313           break;
314         case BO_LT:
315           return PGOHash::BinaryOperatorLT;
316         case BO_GT:
317           return PGOHash::BinaryOperatorGT;
318         case BO_LE:
319           return PGOHash::BinaryOperatorLE;
320         case BO_GE:
321           return PGOHash::BinaryOperatorGE;
322         case BO_EQ:
323           return PGOHash::BinaryOperatorEQ;
324         case BO_NE:
325           return PGOHash::BinaryOperatorNE;
326         }
327       }
328       break;
329     }
330     }
331 
332     if (HashVersion >= PGO_HASH_V2) {
333       switch (S->getStmtClass()) {
334       default:
335         break;
336       case Stmt::GotoStmtClass:
337         return PGOHash::GotoStmt;
338       case Stmt::IndirectGotoStmtClass:
339         return PGOHash::IndirectGotoStmt;
340       case Stmt::BreakStmtClass:
341         return PGOHash::BreakStmt;
342       case Stmt::ContinueStmtClass:
343         return PGOHash::ContinueStmt;
344       case Stmt::ReturnStmtClass:
345         return PGOHash::ReturnStmt;
346       case Stmt::CXXThrowExprClass:
347         return PGOHash::ThrowExpr;
348       case Stmt::UnaryOperatorClass: {
349         const UnaryOperator *UO = cast<UnaryOperator>(S);
350         if (UO->getOpcode() == UO_LNot)
351           return PGOHash::UnaryOperatorLNot;
352         break;
353       }
354       }
355     }
356 
357     return PGOHash::None;
358   }
359 };
360 
361 /// A StmtVisitor that propagates the raw counts through the AST and
362 /// records the count at statements where the value may change.
363 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
364   /// PGO state.
365   CodeGenPGO &PGO;
366 
367   /// A flag that is set when the current count should be recorded on the
368   /// next statement, such as at the exit of a loop.
369   bool RecordNextStmtCount;
370 
371   /// The count at the current location in the traversal.
372   uint64_t CurrentCount;
373 
374   /// The map of statements to count values.
375   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
376 
377   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
378   struct BreakContinue {
379     uint64_t BreakCount = 0;
380     uint64_t ContinueCount = 0;
381     BreakContinue() = default;
382   };
383   SmallVector<BreakContinue, 8> BreakContinueStack;
384 
385   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
386                       CodeGenPGO &PGO)
387       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
388 
389   void RecordStmtCount(const Stmt *S) {
390     if (RecordNextStmtCount) {
391       CountMap[S] = CurrentCount;
392       RecordNextStmtCount = false;
393     }
394   }
395 
396   /// Set and return the current count.
397   uint64_t setCount(uint64_t Count) {
398     CurrentCount = Count;
399     return Count;
400   }
401 
402   void VisitStmt(const Stmt *S) {
403     RecordStmtCount(S);
404     for (const Stmt *Child : S->children())
405       if (Child)
406         this->Visit(Child);
407   }
408 
409   void VisitFunctionDecl(const FunctionDecl *D) {
410     // Counter tracks entry to the function body.
411     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
412     CountMap[D->getBody()] = BodyCount;
413     Visit(D->getBody());
414   }
415 
416   // Skip lambda expressions. We visit these as FunctionDecls when we're
417   // generating them and aren't interested in the body when generating a
418   // parent context.
419   void VisitLambdaExpr(const LambdaExpr *LE) {}
420 
421   void VisitCapturedDecl(const CapturedDecl *D) {
422     // Counter tracks entry to the capture body.
423     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
424     CountMap[D->getBody()] = BodyCount;
425     Visit(D->getBody());
426   }
427 
428   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
429     // Counter tracks entry to the method body.
430     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
431     CountMap[D->getBody()] = BodyCount;
432     Visit(D->getBody());
433   }
434 
435   void VisitBlockDecl(const BlockDecl *D) {
436     // Counter tracks entry to the block body.
437     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
438     CountMap[D->getBody()] = BodyCount;
439     Visit(D->getBody());
440   }
441 
442   void VisitReturnStmt(const ReturnStmt *S) {
443     RecordStmtCount(S);
444     if (S->getRetValue())
445       Visit(S->getRetValue());
446     CurrentCount = 0;
447     RecordNextStmtCount = true;
448   }
449 
450   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
451     RecordStmtCount(E);
452     if (E->getSubExpr())
453       Visit(E->getSubExpr());
454     CurrentCount = 0;
455     RecordNextStmtCount = true;
456   }
457 
458   void VisitGotoStmt(const GotoStmt *S) {
459     RecordStmtCount(S);
460     CurrentCount = 0;
461     RecordNextStmtCount = true;
462   }
463 
464   void VisitLabelStmt(const LabelStmt *S) {
465     RecordNextStmtCount = false;
466     // Counter tracks the block following the label.
467     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
468     CountMap[S] = BlockCount;
469     Visit(S->getSubStmt());
470   }
471 
472   void VisitBreakStmt(const BreakStmt *S) {
473     RecordStmtCount(S);
474     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
475     BreakContinueStack.back().BreakCount += CurrentCount;
476     CurrentCount = 0;
477     RecordNextStmtCount = true;
478   }
479 
480   void VisitContinueStmt(const ContinueStmt *S) {
481     RecordStmtCount(S);
482     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
483     BreakContinueStack.back().ContinueCount += CurrentCount;
484     CurrentCount = 0;
485     RecordNextStmtCount = true;
486   }
487 
488   void VisitWhileStmt(const WhileStmt *S) {
489     RecordStmtCount(S);
490     uint64_t ParentCount = CurrentCount;
491 
492     BreakContinueStack.push_back(BreakContinue());
493     // Visit the body region first so the break/continue adjustments can be
494     // included when visiting the condition.
495     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
496     CountMap[S->getBody()] = CurrentCount;
497     Visit(S->getBody());
498     uint64_t BackedgeCount = CurrentCount;
499 
500     // ...then go back and propagate counts through the condition. The count
501     // at the start of the condition is the sum of the incoming edges,
502     // the backedge from the end of the loop body, and the edges from
503     // continue statements.
504     BreakContinue BC = BreakContinueStack.pop_back_val();
505     uint64_t CondCount =
506         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
507     CountMap[S->getCond()] = CondCount;
508     Visit(S->getCond());
509     setCount(BC.BreakCount + CondCount - BodyCount);
510     RecordNextStmtCount = true;
511   }
512 
513   void VisitDoStmt(const DoStmt *S) {
514     RecordStmtCount(S);
515     uint64_t LoopCount = PGO.getRegionCount(S);
516 
517     BreakContinueStack.push_back(BreakContinue());
518     // The count doesn't include the fallthrough from the parent scope. Add it.
519     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
520     CountMap[S->getBody()] = BodyCount;
521     Visit(S->getBody());
522     uint64_t BackedgeCount = CurrentCount;
523 
524     BreakContinue BC = BreakContinueStack.pop_back_val();
525     // The count at the start of the condition is equal to the count at the
526     // end of the body, plus any continues.
527     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
528     CountMap[S->getCond()] = CondCount;
529     Visit(S->getCond());
530     setCount(BC.BreakCount + CondCount - LoopCount);
531     RecordNextStmtCount = true;
532   }
533 
534   void VisitForStmt(const ForStmt *S) {
535     RecordStmtCount(S);
536     if (S->getInit())
537       Visit(S->getInit());
538 
539     uint64_t ParentCount = CurrentCount;
540 
541     BreakContinueStack.push_back(BreakContinue());
542     // Visit the body region first. (This is basically the same as a while
543     // loop; see further comments in VisitWhileStmt.)
544     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
545     CountMap[S->getBody()] = BodyCount;
546     Visit(S->getBody());
547     uint64_t BackedgeCount = CurrentCount;
548     BreakContinue BC = BreakContinueStack.pop_back_val();
549 
550     // The increment is essentially part of the body but it needs to include
551     // the count for all the continue statements.
552     if (S->getInc()) {
553       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
554       CountMap[S->getInc()] = IncCount;
555       Visit(S->getInc());
556     }
557 
558     // ...then go back and propagate counts through the condition.
559     uint64_t CondCount =
560         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
561     if (S->getCond()) {
562       CountMap[S->getCond()] = CondCount;
563       Visit(S->getCond());
564     }
565     setCount(BC.BreakCount + CondCount - BodyCount);
566     RecordNextStmtCount = true;
567   }
568 
569   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
570     RecordStmtCount(S);
571     if (S->getInit())
572       Visit(S->getInit());
573     Visit(S->getLoopVarStmt());
574     Visit(S->getRangeStmt());
575     Visit(S->getBeginStmt());
576     Visit(S->getEndStmt());
577 
578     uint64_t ParentCount = CurrentCount;
579     BreakContinueStack.push_back(BreakContinue());
580     // Visit the body region first. (This is basically the same as a while
581     // loop; see further comments in VisitWhileStmt.)
582     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
583     CountMap[S->getBody()] = BodyCount;
584     Visit(S->getBody());
585     uint64_t BackedgeCount = CurrentCount;
586     BreakContinue BC = BreakContinueStack.pop_back_val();
587 
588     // The increment is essentially part of the body but it needs to include
589     // the count for all the continue statements.
590     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
591     CountMap[S->getInc()] = IncCount;
592     Visit(S->getInc());
593 
594     // ...then go back and propagate counts through the condition.
595     uint64_t CondCount =
596         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
597     CountMap[S->getCond()] = CondCount;
598     Visit(S->getCond());
599     setCount(BC.BreakCount + CondCount - BodyCount);
600     RecordNextStmtCount = true;
601   }
602 
603   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
604     RecordStmtCount(S);
605     Visit(S->getElement());
606     uint64_t ParentCount = CurrentCount;
607     BreakContinueStack.push_back(BreakContinue());
608     // Counter tracks the body of the loop.
609     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
610     CountMap[S->getBody()] = BodyCount;
611     Visit(S->getBody());
612     uint64_t BackedgeCount = CurrentCount;
613     BreakContinue BC = BreakContinueStack.pop_back_val();
614 
615     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
616              BodyCount);
617     RecordNextStmtCount = true;
618   }
619 
620   void VisitSwitchStmt(const SwitchStmt *S) {
621     RecordStmtCount(S);
622     if (S->getInit())
623       Visit(S->getInit());
624     Visit(S->getCond());
625     CurrentCount = 0;
626     BreakContinueStack.push_back(BreakContinue());
627     Visit(S->getBody());
628     // If the switch is inside a loop, add the continue counts.
629     BreakContinue BC = BreakContinueStack.pop_back_val();
630     if (!BreakContinueStack.empty())
631       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
632     // Counter tracks the exit block of the switch.
633     setCount(PGO.getRegionCount(S));
634     RecordNextStmtCount = true;
635   }
636 
637   void VisitSwitchCase(const SwitchCase *S) {
638     RecordNextStmtCount = false;
639     // Counter for this particular case. This counts only jumps from the
640     // switch header and does not include fallthrough from the case before
641     // this one.
642     uint64_t CaseCount = PGO.getRegionCount(S);
643     setCount(CurrentCount + CaseCount);
644     // We need the count without fallthrough in the mapping, so it's more useful
645     // for branch probabilities.
646     CountMap[S] = CaseCount;
647     RecordNextStmtCount = true;
648     Visit(S->getSubStmt());
649   }
650 
651   void VisitIfStmt(const IfStmt *S) {
652     RecordStmtCount(S);
653 
654     if (S->isConsteval()) {
655       const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
656       if (Stm)
657         Visit(Stm);
658       return;
659     }
660 
661     uint64_t ParentCount = CurrentCount;
662     if (S->getInit())
663       Visit(S->getInit());
664     Visit(S->getCond());
665 
666     // Counter tracks the "then" part of an if statement. The count for
667     // the "else" part, if it exists, will be calculated from this counter.
668     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
669     CountMap[S->getThen()] = ThenCount;
670     Visit(S->getThen());
671     uint64_t OutCount = CurrentCount;
672 
673     uint64_t ElseCount = ParentCount - ThenCount;
674     if (S->getElse()) {
675       setCount(ElseCount);
676       CountMap[S->getElse()] = ElseCount;
677       Visit(S->getElse());
678       OutCount += CurrentCount;
679     } else
680       OutCount += ElseCount;
681     setCount(OutCount);
682     RecordNextStmtCount = true;
683   }
684 
685   void VisitCXXTryStmt(const CXXTryStmt *S) {
686     RecordStmtCount(S);
687     Visit(S->getTryBlock());
688     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
689       Visit(S->getHandler(I));
690     // Counter tracks the continuation block of the try statement.
691     setCount(PGO.getRegionCount(S));
692     RecordNextStmtCount = true;
693   }
694 
695   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
696     RecordNextStmtCount = false;
697     // Counter tracks the catch statement's handler block.
698     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
699     CountMap[S] = CatchCount;
700     Visit(S->getHandlerBlock());
701   }
702 
703   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
704     RecordStmtCount(E);
705     uint64_t ParentCount = CurrentCount;
706     Visit(E->getCond());
707 
708     // Counter tracks the "true" part of a conditional operator. The
709     // count in the "false" part will be calculated from this counter.
710     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
711     CountMap[E->getTrueExpr()] = TrueCount;
712     Visit(E->getTrueExpr());
713     uint64_t OutCount = CurrentCount;
714 
715     uint64_t FalseCount = setCount(ParentCount - TrueCount);
716     CountMap[E->getFalseExpr()] = FalseCount;
717     Visit(E->getFalseExpr());
718     OutCount += CurrentCount;
719 
720     setCount(OutCount);
721     RecordNextStmtCount = true;
722   }
723 
724   void VisitBinLAnd(const BinaryOperator *E) {
725     RecordStmtCount(E);
726     uint64_t ParentCount = CurrentCount;
727     Visit(E->getLHS());
728     // Counter tracks the right hand side of a logical and operator.
729     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
730     CountMap[E->getRHS()] = RHSCount;
731     Visit(E->getRHS());
732     setCount(ParentCount + RHSCount - CurrentCount);
733     RecordNextStmtCount = true;
734   }
735 
736   void VisitBinLOr(const BinaryOperator *E) {
737     RecordStmtCount(E);
738     uint64_t ParentCount = CurrentCount;
739     Visit(E->getLHS());
740     // Counter tracks the right hand side of a logical or operator.
741     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
742     CountMap[E->getRHS()] = RHSCount;
743     Visit(E->getRHS());
744     setCount(ParentCount + RHSCount - CurrentCount);
745     RecordNextStmtCount = true;
746   }
747 };
748 } // end anonymous namespace
749 
750 void PGOHash::combine(HashType Type) {
751   // Check that we never combine 0 and only have six bits.
752   assert(Type && "Hash is invalid: unexpected type 0");
753   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
754 
755   // Pass through MD5 if enough work has built up.
756   if (Count && Count % NumTypesPerWord == 0) {
757     using namespace llvm::support;
758     uint64_t Swapped =
759         endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
760     MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
761     Working = 0;
762   }
763 
764   // Accumulate the current type.
765   ++Count;
766   Working = Working << NumBitsPerType | Type;
767 }
768 
769 uint64_t PGOHash::finalize() {
770   // Use Working as the hash directly if we never used MD5.
771   if (Count <= NumTypesPerWord)
772     // No need to byte swap here, since none of the math was endian-dependent.
773     // This number will be byte-swapped as required on endianness transitions,
774     // so we will see the same value on the other side.
775     return Working;
776 
777   // Check for remaining work in Working.
778   if (Working) {
779     // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
780     // is buggy because it converts a uint64_t into an array of uint8_t.
781     if (HashVersion < PGO_HASH_V3) {
782       MD5.update({(uint8_t)Working});
783     } else {
784       using namespace llvm::support;
785       uint64_t Swapped =
786           endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
787       MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
788     }
789   }
790 
791   // Finalize the MD5 and return the hash.
792   llvm::MD5::MD5Result Result;
793   MD5.final(Result);
794   return Result.low();
795 }
796 
797 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
798   const Decl *D = GD.getDecl();
799   if (!D->hasBody())
800     return;
801 
802   // Skip CUDA/HIP kernel launch stub functions.
803   if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
804       D->hasAttr<CUDAGlobalAttr>())
805     return;
806 
807   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
808   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
809   if (!InstrumentRegions && !PGOReader)
810     return;
811   if (D->isImplicit())
812     return;
813   // Constructors and destructors may be represented by several functions in IR.
814   // If so, instrument only base variant, others are implemented by delegation
815   // to the base one, it would be counted twice otherwise.
816   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
817     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
818       if (GD.getCtorType() != Ctor_Base &&
819           CodeGenFunction::IsConstructorDelegationValid(CCD))
820         return;
821   }
822   if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
823     return;
824 
825   CGM.ClearUnusedCoverageMapping(D);
826   if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
827     return;
828   if (Fn->hasFnAttribute(llvm::Attribute::SkipProfile))
829     return;
830 
831   setFuncName(Fn);
832 
833   mapRegionCounters(D);
834   if (CGM.getCodeGenOpts().CoverageMapping)
835     emitCounterRegionMapping(D);
836   if (PGOReader) {
837     SourceManager &SM = CGM.getContext().getSourceManager();
838     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
839     computeRegionCounts(D);
840     applyFunctionAttributes(PGOReader, Fn);
841   }
842 }
843 
844 void CodeGenPGO::mapRegionCounters(const Decl *D) {
845   // Use the latest hash version when inserting instrumentation, but use the
846   // version in the indexed profile if we're reading PGO data.
847   PGOHashVersion HashVersion = PGO_HASH_LATEST;
848   uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
849   if (auto *PGOReader = CGM.getPGOReader()) {
850     HashVersion = getPGOHashVersion(PGOReader, CGM);
851     ProfileVersion = PGOReader->getVersion();
852   }
853 
854   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
855   MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap);
856   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
857     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
858   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
859     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
860   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
861     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
862   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
863     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
864   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
865   NumRegionCounters = Walker.NextCounter;
866   FunctionHash = Walker.Hash.finalize();
867 }
868 
869 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
870   if (!D->getBody())
871     return true;
872 
873   // Skip host-only functions in the CUDA device compilation and device-only
874   // functions in the host compilation. Just roughly filter them out based on
875   // the function attributes. If there are effectively host-only or device-only
876   // ones, their coverage mapping may still be generated.
877   if (CGM.getLangOpts().CUDA &&
878       ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
879         !D->hasAttr<CUDAGlobalAttr>()) ||
880        (!CGM.getLangOpts().CUDAIsDevice &&
881         (D->hasAttr<CUDAGlobalAttr>() ||
882          (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
883     return true;
884 
885   // Don't map the functions in system headers.
886   const auto &SM = CGM.getContext().getSourceManager();
887   auto Loc = D->getBody()->getBeginLoc();
888   return SM.isInSystemHeader(Loc);
889 }
890 
891 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
892   if (skipRegionMappingForDecl(D))
893     return;
894 
895   std::string CoverageMapping;
896   llvm::raw_string_ostream OS(CoverageMapping);
897   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
898                                 CGM.getContext().getSourceManager(),
899                                 CGM.getLangOpts(), RegionCounterMap.get());
900   MappingGen.emitCounterMapping(D, OS);
901   OS.flush();
902 
903   if (CoverageMapping.empty())
904     return;
905 
906   CGM.getCoverageMapping()->addFunctionMappingRecord(
907       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
908 }
909 
910 void
911 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
912                                     llvm::GlobalValue::LinkageTypes Linkage) {
913   if (skipRegionMappingForDecl(D))
914     return;
915 
916   std::string CoverageMapping;
917   llvm::raw_string_ostream OS(CoverageMapping);
918   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
919                                 CGM.getContext().getSourceManager(),
920                                 CGM.getLangOpts());
921   MappingGen.emitEmptyMapping(D, OS);
922   OS.flush();
923 
924   if (CoverageMapping.empty())
925     return;
926 
927   setFuncName(Name, Linkage);
928   CGM.getCoverageMapping()->addFunctionMappingRecord(
929       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
930 }
931 
932 void CodeGenPGO::computeRegionCounts(const Decl *D) {
933   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
934   ComputeRegionCounts Walker(*StmtCountMap, *this);
935   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
936     Walker.VisitFunctionDecl(FD);
937   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
938     Walker.VisitObjCMethodDecl(MD);
939   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
940     Walker.VisitBlockDecl(BD);
941   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
942     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
943 }
944 
945 void
946 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
947                                     llvm::Function *Fn) {
948   if (!haveRegionCounts())
949     return;
950 
951   uint64_t FunctionCount = getRegionCount(nullptr);
952   Fn->setEntryCount(FunctionCount);
953 }
954 
955 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
956                                       llvm::Value *StepV) {
957   if (!RegionCounterMap || !Builder.GetInsertBlock())
958     return;
959 
960   unsigned Counter = (*RegionCounterMap)[S];
961 
962   llvm::Value *Args[] = {FuncNameVar,
963                          Builder.getInt64(FunctionHash),
964                          Builder.getInt32(NumRegionCounters),
965                          Builder.getInt32(Counter), StepV};
966   if (!StepV)
967     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
968                        ArrayRef(Args, 4));
969   else
970     Builder.CreateCall(
971         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
972         ArrayRef(Args));
973 }
974 
975 void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
976   if (CGM.getCodeGenOpts().hasProfileClangInstr())
977     M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
978                     uint32_t(EnableValueProfiling));
979 }
980 
981 // This method either inserts a call to the profile run-time during
982 // instrumentation or puts profile data into metadata for PGO use.
983 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
984     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
985 
986   if (!EnableValueProfiling)
987     return;
988 
989   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
990     return;
991 
992   if (isa<llvm::Constant>(ValuePtr))
993     return;
994 
995   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
996   if (InstrumentValueSites && RegionCounterMap) {
997     auto BuilderInsertPoint = Builder.saveIP();
998     Builder.SetInsertPoint(ValueSite);
999     llvm::Value *Args[5] = {
1000         FuncNameVar,
1001         Builder.getInt64(FunctionHash),
1002         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
1003         Builder.getInt32(ValueKind),
1004         Builder.getInt32(NumValueSites[ValueKind]++)
1005     };
1006     Builder.CreateCall(
1007         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
1008     Builder.restoreIP(BuilderInsertPoint);
1009     return;
1010   }
1011 
1012   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1013   if (PGOReader && haveRegionCounts()) {
1014     // We record the top most called three functions at each call site.
1015     // Profile metadata contains "VP" string identifying this metadata
1016     // as value profiling data, then a uint32_t value for the value profiling
1017     // kind, a uint64_t value for the total number of times the call is
1018     // executed, followed by the function hash and execution count (uint64_t)
1019     // pairs for each function.
1020     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
1021       return;
1022 
1023     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
1024                             (llvm::InstrProfValueKind)ValueKind,
1025                             NumValueSites[ValueKind]);
1026 
1027     NumValueSites[ValueKind]++;
1028   }
1029 }
1030 
1031 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1032                                   bool IsInMainFile) {
1033   CGM.getPGOStats().addVisited(IsInMainFile);
1034   RegionCounts.clear();
1035   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1036       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1037   if (auto E = RecordExpected.takeError()) {
1038     auto IPE = std::get<0>(llvm::InstrProfError::take(std::move(E)));
1039     if (IPE == llvm::instrprof_error::unknown_function)
1040       CGM.getPGOStats().addMissing(IsInMainFile);
1041     else if (IPE == llvm::instrprof_error::hash_mismatch)
1042       CGM.getPGOStats().addMismatched(IsInMainFile);
1043     else if (IPE == llvm::instrprof_error::malformed)
1044       // TODO: Consider a more specific warning for this case.
1045       CGM.getPGOStats().addMismatched(IsInMainFile);
1046     return;
1047   }
1048   ProfRecord =
1049       std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1050   RegionCounts = ProfRecord->Counts;
1051 }
1052 
/// Calculate what to divide by to scale weights.
///
/// Returns 1 when \p MaxWeight is already below UINT32_MAX; otherwise returns
/// \c MaxWeight / UINT32_MAX + 1, so that dividing any weight no greater than
/// \p MaxWeight by the result yields a value strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  const uint64_t WholeMultiples = MaxWeight / UINT32_MAX;
  return WholeMultiples + 1;
}
1060 
/// Scale an individual branch weight (and add 1).
///
/// Divides the 64-bit \p Weight by \p Scale to bring it down to 32 bits.
///
/// Following Laplace's Rule of Succession, the weight is computed from the
/// count plus 1, so 1 is added unconditionally to the scaled value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// greater than \c Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Quotient = Weight / Scale;
  const uint64_t Scaled = Quotient + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Scaled);
}
1076 
1077 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1078                                                     uint64_t FalseCount) const {
1079   // Check for empty weights.
1080   if (!TrueCount && !FalseCount)
1081     return nullptr;
1082 
1083   // Calculate how to scale down to 32-bits.
1084   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1085 
1086   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1087   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1088                                       scaleBranchWeight(FalseCount, Scale));
1089 }
1090 
1091 llvm::MDNode *
1092 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1093   // We need at least two elements to create meaningful weights.
1094   if (Weights.size() < 2)
1095     return nullptr;
1096 
1097   // Check for empty weights.
1098   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1099   if (MaxWeight == 0)
1100     return nullptr;
1101 
1102   // Calculate how to scale down to 32-bits.
1103   uint64_t Scale = calculateWeightScale(MaxWeight);
1104 
1105   SmallVector<uint32_t, 16> ScaledWeights;
1106   ScaledWeights.reserve(Weights.size());
1107   for (uint64_t W : Weights)
1108     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1109 
1110   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1111   return MDHelper.createBranchWeights(ScaledWeights);
1112 }
1113 
1114 llvm::MDNode *
1115 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1116                                              uint64_t LoopCount) const {
1117   if (!PGO.haveRegionCounts())
1118     return nullptr;
1119   std::optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1120   if (!CondCount || *CondCount == 0)
1121     return nullptr;
1122   return createProfileWeights(LoopCount,
1123                               std::max(*CondCount, LoopCount) - LoopCount);
1124 }
1125