xref: /llvm-project/clang/lib/CodeGen/CodeGenPGO.cpp (revision 397ac44f623f891d8f05d6673a95984ac0a26671)
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/MD5.h"
23 #include <optional>
24 
25 namespace llvm {
26 extern cl::opt<bool> EnableSingleByteCoverage;
27 } // namespace llvm
28 
29 static llvm::cl::opt<bool>
30     EnableValueProfiling("enable-value-profiling",
31                          llvm::cl::desc("Enable value profiling"),
32                          llvm::cl::Hidden, llvm::cl::init(false));
33 
34 using namespace clang;
35 using namespace CodeGen;
36 
37 void CodeGenPGO::setFuncName(StringRef Name,
38                              llvm::GlobalValue::LinkageTypes Linkage) {
39   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
40   FuncName = llvm::getPGOFuncName(
41       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
42       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
43 
44   // If we're generating a profile, create a variable for the name.
45   if (CGM.getCodeGenOpts().hasProfileClangInstr())
46     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
47 }
48 
49 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
50   setFuncName(Fn->getName(), Fn->getLinkage());
51   // Create PGOFuncName meta data.
52   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
53 }
54 
55 /// The version of the PGO hash algorithm.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,
  PGO_HASH_V3 = 2,

  // Alias for the newest version; repoint it whenever a version is added.
  PGO_HASH_LATEST = PGO_HASH_V3
};
64 
65 namespace {
66 /// Stable hasher for PGO region counters.
67 ///
68 /// PGOHash produces a stable hash of a given function's control flow.
69 ///
70 /// Changing the output of this hash will invalidate all previously generated
71 /// profiles -- i.e., don't do it.
72 ///
73 /// \note  When this hash does eventually change (years?), we still need to
74 /// support old hashes.  We'll need to pull in the version number from the
75 /// profile data format and use the matching hash function.
76 class PGOHash {
77   uint64_t Working;
78   unsigned Count;
79   PGOHashVersion HashVersion;
80   llvm::MD5 MD5;
81 
82   static const int NumBitsPerType = 6;
83   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
84   static const unsigned TooBig = 1u << NumBitsPerType;
85 
86 public:
87   /// Hash values for AST nodes.
88   ///
89   /// Distinct values for AST nodes that have region counters attached.
90   ///
91   /// These values must be stable.  All new members must be added at the end,
92   /// and no members should be removed.  Changing the enumeration value for an
93   /// AST node will affect the hash of every function that contains that node.
94   enum HashType : unsigned char {
95     None = 0,
96     LabelStmt = 1,
97     WhileStmt,
98     DoStmt,
99     ForStmt,
100     CXXForRangeStmt,
101     ObjCForCollectionStmt,
102     SwitchStmt,
103     CaseStmt,
104     DefaultStmt,
105     IfStmt,
106     CXXTryStmt,
107     CXXCatchStmt,
108     ConditionalOperator,
109     BinaryOperatorLAnd,
110     BinaryOperatorLOr,
111     BinaryConditionalOperator,
112     // The preceding values are available with PGO_HASH_V1.
113 
114     EndOfScope,
115     IfThenBranch,
116     IfElseBranch,
117     GotoStmt,
118     IndirectGotoStmt,
119     BreakStmt,
120     ContinueStmt,
121     ReturnStmt,
122     ThrowExpr,
123     UnaryOperatorLNot,
124     BinaryOperatorLT,
125     BinaryOperatorGT,
126     BinaryOperatorLE,
127     BinaryOperatorGE,
128     BinaryOperatorEQ,
129     BinaryOperatorNE,
130     // The preceding values are available since PGO_HASH_V2.
131 
132     // Keep this last.  It's for the static assert that follows.
133     LastHashType
134   };
135   static_assert(LastHashType <= TooBig, "Too many types in HashType");
136 
137   PGOHash(PGOHashVersion HashVersion)
138       : Working(0), Count(0), HashVersion(HashVersion) {}
139   void combine(HashType Type);
140   uint64_t finalize();
141   PGOHashVersion getHashVersion() const { return HashVersion; }
142 };
143 const int PGOHash::NumBitsPerType;
144 const unsigned PGOHash::NumTypesPerWord;
145 const unsigned PGOHash::TooBig;
146 
147 /// Get the PGO hash version used in the given indexed profile.
148 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
149                                         CodeGenModule &CGM) {
150   if (PGOReader->getVersion() <= 4)
151     return PGO_HASH_V1;
152   if (PGOReader->getVersion() <= 5)
153     return PGO_HASH_V2;
154   return PGO_HASH_V3;
155 }
156 
157 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
158 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
159   using Base = RecursiveASTVisitor<MapRegionCounters>;
160 
161   /// The next counter value to assign.
162   unsigned NextCounter;
163   /// The function hash.
164   PGOHash Hash;
165   /// The map of statements to counters.
166   llvm::DenseMap<const Stmt *, CounterPair> &CounterMap;
167   /// The state of MC/DC Coverage in this function.
168   MCDC::State &MCDCState;
169   /// Maximum number of supported MC/DC conditions in a boolean expression.
170   unsigned MCDCMaxCond;
171   /// The profile version.
172   uint64_t ProfileVersion;
173   /// Diagnostics Engine used to report warnings.
174   DiagnosticsEngine &Diag;
175 
176   MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
177                     llvm::DenseMap<const Stmt *, CounterPair> &CounterMap,
178                     MCDC::State &MCDCState, unsigned MCDCMaxCond,
179                     DiagnosticsEngine &Diag)
180       : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
181         MCDCState(MCDCState), MCDCMaxCond(MCDCMaxCond),
182         ProfileVersion(ProfileVersion), Diag(Diag) {}
183 
184   // Blocks and lambdas are handled as separate functions, so we need not
185   // traverse them in the parent context.
186   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
187   bool TraverseLambdaExpr(LambdaExpr *LE) {
188     // Traverse the captures, but not the body.
189     for (auto C : zip(LE->captures(), LE->capture_inits()))
190       TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
191     return true;
192   }
193   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
194 
195   bool VisitDecl(const Decl *D) {
196     switch (D->getKind()) {
197     default:
198       break;
199     case Decl::Function:
200     case Decl::CXXMethod:
201     case Decl::CXXConstructor:
202     case Decl::CXXDestructor:
203     case Decl::CXXConversion:
204     case Decl::ObjCMethod:
205     case Decl::Block:
206     case Decl::Captured:
207       CounterMap[D->getBody()] = NextCounter++;
208       break;
209     }
210     return true;
211   }
212 
213   /// If \p S gets a fresh counter, update the counter mappings. Return the
214   /// V1 hash of \p S.
215   PGOHash::HashType updateCounterMappings(Stmt *S) {
216     auto Type = getHashType(PGO_HASH_V1, S);
217     if (Type != PGOHash::None)
218       CounterMap[S] = NextCounter++;
219     return Type;
220   }
221 
222   /// The following stacks are used with dataTraverseStmtPre() and
223   /// dataTraverseStmtPost() to track the depth of nested logical operators in a
224   /// boolean expression in a function.  The ultimate purpose is to keep track
225   /// of the number of leaf-level conditions in the boolean expression so that a
226   /// profile bitmap can be allocated based on that number.
227   ///
228   /// The stacks are also used to find error cases and notify the user.  A
229   /// standard logical operator nest for a boolean expression could be in a form
230   /// similar to this: "x = a && b && c && (d || f)"
231   unsigned NumCond = 0;
232   bool SplitNestedLogicalOp = false;
233   SmallVector<const Stmt *, 16> NonLogOpStack;
234   SmallVector<const BinaryOperator *, 16> LogOpStack;
235 
236   // Hook: dataTraverseStmtPre() is invoked prior to visiting an AST Stmt node.
237   bool dataTraverseStmtPre(Stmt *S) {
238     /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
239     if (MCDCMaxCond == 0)
240       return true;
241 
242     /// At the top of the logical operator nest, reset the number of conditions,
243     /// also forget previously seen split nesting cases.
244     if (LogOpStack.empty()) {
245       NumCond = 0;
246       SplitNestedLogicalOp = false;
247     }
248 
249     if (const Expr *E = dyn_cast<Expr>(S)) {
250       const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(E->IgnoreParens());
251       if (BinOp && BinOp->isLogicalOp()) {
252         /// Check for "split-nested" logical operators. This happens when a new
253         /// boolean expression logical-op nest is encountered within an existing
254         /// boolean expression, separated by a non-logical operator.  For
255         /// example, in "x = (a && b && c && foo(d && f))", the "d && f" case
256         /// starts a new boolean expression that is separated from the other
257         /// conditions by the operator foo(). Split-nested cases are not
258         /// supported by MC/DC.
259         SplitNestedLogicalOp = SplitNestedLogicalOp || !NonLogOpStack.empty();
260 
261         LogOpStack.push_back(BinOp);
262         return true;
263       }
264     }
265 
266     /// Keep track of non-logical operators. These are OK as long as we don't
267     /// encounter a new logical operator after seeing one.
268     if (!LogOpStack.empty())
269       NonLogOpStack.push_back(S);
270 
271     return true;
272   }
273 
274   // Hook: dataTraverseStmtPost() is invoked by the AST visitor after visiting
275   // an AST Stmt node.  MC/DC will use it to to signal when the top of a
276   // logical operation (boolean expression) nest is encountered.
277   bool dataTraverseStmtPost(Stmt *S) {
278     /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
279     if (MCDCMaxCond == 0)
280       return true;
281 
282     if (const Expr *E = dyn_cast<Expr>(S)) {
283       const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(E->IgnoreParens());
284       if (BinOp && BinOp->isLogicalOp()) {
285         assert(LogOpStack.back() == BinOp);
286         LogOpStack.pop_back();
287 
288         /// At the top of logical operator nest:
289         if (LogOpStack.empty()) {
290           /// Was the "split-nested" logical operator case encountered?
291           if (SplitNestedLogicalOp) {
292             unsigned DiagID = Diag.getCustomDiagID(
293                 DiagnosticsEngine::Warning,
294                 "unsupported MC/DC boolean expression; "
295                 "contains an operation with a nested boolean expression. "
296                 "Expression will not be covered");
297             Diag.Report(S->getBeginLoc(), DiagID);
298             return true;
299           }
300 
301           /// Was the maximum number of conditions encountered?
302           if (NumCond > MCDCMaxCond) {
303             unsigned DiagID = Diag.getCustomDiagID(
304                 DiagnosticsEngine::Warning,
305                 "unsupported MC/DC boolean expression; "
306                 "number of conditions (%0) exceeds max (%1). "
307                 "Expression will not be covered");
308             Diag.Report(S->getBeginLoc(), DiagID) << NumCond << MCDCMaxCond;
309             return true;
310           }
311 
312           // Otherwise, allocate the Decision.
313           MCDCState.DecisionByStmt[BinOp].BitmapIdx = 0;
314         }
315         return true;
316       }
317     }
318 
319     if (!LogOpStack.empty())
320       NonLogOpStack.pop_back();
321 
322     return true;
323   }
324 
325   /// The RHS of all logical operators gets a fresh counter in order to count
326   /// how many times the RHS evaluates to true or false, depending on the
327   /// semantics of the operator. This is only valid for ">= v7" of the profile
328   /// version so that we facilitate backward compatibility. In addition, in
329   /// order to use MC/DC, count the number of total LHS and RHS conditions.
330   bool VisitBinaryOperator(BinaryOperator *S) {
331     if (S->isLogicalOp()) {
332       if (CodeGenFunction::isInstrumentedCondition(S->getLHS()))
333         NumCond++;
334 
335       if (CodeGenFunction::isInstrumentedCondition(S->getRHS())) {
336         if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
337           CounterMap[S->getRHS()] = NextCounter++;
338 
339         NumCond++;
340       }
341     }
342     return Base::VisitBinaryOperator(S);
343   }
344 
345   bool VisitConditionalOperator(ConditionalOperator *S) {
346     if (llvm::EnableSingleByteCoverage && S->getTrueExpr())
347       CounterMap[S->getTrueExpr()] = NextCounter++;
348     if (llvm::EnableSingleByteCoverage && S->getFalseExpr())
349       CounterMap[S->getFalseExpr()] = NextCounter++;
350     return Base::VisitConditionalOperator(S);
351   }
352 
353   /// Include \p S in the function hash.
354   bool VisitStmt(Stmt *S) {
355     auto Type = updateCounterMappings(S);
356     if (Hash.getHashVersion() != PGO_HASH_V1)
357       Type = getHashType(Hash.getHashVersion(), S);
358     if (Type != PGOHash::None)
359       Hash.combine(Type);
360     return true;
361   }
362 
363   bool TraverseIfStmt(IfStmt *If) {
364     // If we used the V1 hash, use the default traversal.
365     if (Hash.getHashVersion() == PGO_HASH_V1)
366       return Base::TraverseIfStmt(If);
367 
368     // When single byte coverage mode is enabled, add a counter to then and
369     // else.
370     bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
371     for (Stmt *CS : If->children()) {
372       if (!CS || NoSingleByteCoverage)
373         continue;
374       if (CS == If->getThen())
375         CounterMap[If->getThen()] = NextCounter++;
376       else if (CS == If->getElse())
377         CounterMap[If->getElse()] = NextCounter++;
378     }
379 
380     // Otherwise, keep track of which branch we're in while traversing.
381     VisitStmt(If);
382 
383     for (Stmt *CS : If->children()) {
384       if (!CS)
385         continue;
386       if (CS == If->getThen())
387         Hash.combine(PGOHash::IfThenBranch);
388       else if (CS == If->getElse())
389         Hash.combine(PGOHash::IfElseBranch);
390       TraverseStmt(CS);
391     }
392     Hash.combine(PGOHash::EndOfScope);
393     return true;
394   }
395 
396   bool TraverseWhileStmt(WhileStmt *While) {
397     // When single byte coverage mode is enabled, add a counter to condition and
398     // body.
399     bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
400     for (Stmt *CS : While->children()) {
401       if (!CS || NoSingleByteCoverage)
402         continue;
403       if (CS == While->getCond())
404         CounterMap[While->getCond()] = NextCounter++;
405       else if (CS == While->getBody())
406         CounterMap[While->getBody()] = NextCounter++;
407     }
408 
409     Base::TraverseWhileStmt(While);
410     if (Hash.getHashVersion() != PGO_HASH_V1)
411       Hash.combine(PGOHash::EndOfScope);
412     return true;
413   }
414 
415   bool TraverseDoStmt(DoStmt *Do) {
416     // When single byte coverage mode is enabled, add a counter to condition and
417     // body.
418     bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
419     for (Stmt *CS : Do->children()) {
420       if (!CS || NoSingleByteCoverage)
421         continue;
422       if (CS == Do->getCond())
423         CounterMap[Do->getCond()] = NextCounter++;
424       else if (CS == Do->getBody())
425         CounterMap[Do->getBody()] = NextCounter++;
426     }
427 
428     Base::TraverseDoStmt(Do);
429     if (Hash.getHashVersion() != PGO_HASH_V1)
430       Hash.combine(PGOHash::EndOfScope);
431     return true;
432   }
433 
434   bool TraverseForStmt(ForStmt *For) {
435     // When single byte coverage mode is enabled, add a counter to condition,
436     // increment and body.
437     bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
438     for (Stmt *CS : For->children()) {
439       if (!CS || NoSingleByteCoverage)
440         continue;
441       if (CS == For->getCond())
442         CounterMap[For->getCond()] = NextCounter++;
443       else if (CS == For->getInc())
444         CounterMap[For->getInc()] = NextCounter++;
445       else if (CS == For->getBody())
446         CounterMap[For->getBody()] = NextCounter++;
447     }
448 
449     Base::TraverseForStmt(For);
450     if (Hash.getHashVersion() != PGO_HASH_V1)
451       Hash.combine(PGOHash::EndOfScope);
452     return true;
453   }
454 
455   bool TraverseCXXForRangeStmt(CXXForRangeStmt *ForRange) {
456     // When single byte coverage mode is enabled, add a counter to body.
457     bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
458     for (Stmt *CS : ForRange->children()) {
459       if (!CS || NoSingleByteCoverage)
460         continue;
461       if (CS == ForRange->getBody())
462         CounterMap[ForRange->getBody()] = NextCounter++;
463     }
464 
465     Base::TraverseCXXForRangeStmt(ForRange);
466     if (Hash.getHashVersion() != PGO_HASH_V1)
467       Hash.combine(PGOHash::EndOfScope);
468     return true;
469   }
470 
471 // If the statement type \p N is nestable, and its nesting impacts profile
472 // stability, define a custom traversal which tracks the end of the statement
473 // in the hash (provided we're not using the V1 hash).
474 #define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
475   bool Traverse##N(N *S) {                                                     \
476     Base::Traverse##N(S);                                                      \
477     if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
478       Hash.combine(PGOHash::EndOfScope);                                       \
479     return true;                                                               \
480   }
481 
482   DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
483   DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
484   DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
485 
486   /// Get version \p HashVersion of the PGO hash for \p S.
487   PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
488     switch (S->getStmtClass()) {
489     default:
490       break;
491     case Stmt::LabelStmtClass:
492       return PGOHash::LabelStmt;
493     case Stmt::WhileStmtClass:
494       return PGOHash::WhileStmt;
495     case Stmt::DoStmtClass:
496       return PGOHash::DoStmt;
497     case Stmt::ForStmtClass:
498       return PGOHash::ForStmt;
499     case Stmt::CXXForRangeStmtClass:
500       return PGOHash::CXXForRangeStmt;
501     case Stmt::ObjCForCollectionStmtClass:
502       return PGOHash::ObjCForCollectionStmt;
503     case Stmt::SwitchStmtClass:
504       return PGOHash::SwitchStmt;
505     case Stmt::CaseStmtClass:
506       return PGOHash::CaseStmt;
507     case Stmt::DefaultStmtClass:
508       return PGOHash::DefaultStmt;
509     case Stmt::IfStmtClass:
510       return PGOHash::IfStmt;
511     case Stmt::CXXTryStmtClass:
512       return PGOHash::CXXTryStmt;
513     case Stmt::CXXCatchStmtClass:
514       return PGOHash::CXXCatchStmt;
515     case Stmt::ConditionalOperatorClass:
516       return PGOHash::ConditionalOperator;
517     case Stmt::BinaryConditionalOperatorClass:
518       return PGOHash::BinaryConditionalOperator;
519     case Stmt::BinaryOperatorClass: {
520       const BinaryOperator *BO = cast<BinaryOperator>(S);
521       if (BO->getOpcode() == BO_LAnd)
522         return PGOHash::BinaryOperatorLAnd;
523       if (BO->getOpcode() == BO_LOr)
524         return PGOHash::BinaryOperatorLOr;
525       if (HashVersion >= PGO_HASH_V2) {
526         switch (BO->getOpcode()) {
527         default:
528           break;
529         case BO_LT:
530           return PGOHash::BinaryOperatorLT;
531         case BO_GT:
532           return PGOHash::BinaryOperatorGT;
533         case BO_LE:
534           return PGOHash::BinaryOperatorLE;
535         case BO_GE:
536           return PGOHash::BinaryOperatorGE;
537         case BO_EQ:
538           return PGOHash::BinaryOperatorEQ;
539         case BO_NE:
540           return PGOHash::BinaryOperatorNE;
541         }
542       }
543       break;
544     }
545     }
546 
547     if (HashVersion >= PGO_HASH_V2) {
548       switch (S->getStmtClass()) {
549       default:
550         break;
551       case Stmt::GotoStmtClass:
552         return PGOHash::GotoStmt;
553       case Stmt::IndirectGotoStmtClass:
554         return PGOHash::IndirectGotoStmt;
555       case Stmt::BreakStmtClass:
556         return PGOHash::BreakStmt;
557       case Stmt::ContinueStmtClass:
558         return PGOHash::ContinueStmt;
559       case Stmt::ReturnStmtClass:
560         return PGOHash::ReturnStmt;
561       case Stmt::CXXThrowExprClass:
562         return PGOHash::ThrowExpr;
563       case Stmt::UnaryOperatorClass: {
564         const UnaryOperator *UO = cast<UnaryOperator>(S);
565         if (UO->getOpcode() == UO_LNot)
566           return PGOHash::UnaryOperatorLNot;
567         break;
568       }
569       }
570     }
571 
572     return PGOHash::None;
573   }
574 };
575 
576 /// A StmtVisitor that propagates the raw counts through the AST and
577 /// records the count at statements where the value may change.
578 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
579   /// PGO state.
580   CodeGenPGO &PGO;
581 
582   /// A flag that is set when the current count should be recorded on the
583   /// next statement, such as at the exit of a loop.
584   bool RecordNextStmtCount;
585 
586   /// The count at the current location in the traversal.
587   uint64_t CurrentCount;
588 
589   /// The map of statements to count values.
590   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
591 
592   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
593   struct BreakContinue {
594     uint64_t BreakCount = 0;
595     uint64_t ContinueCount = 0;
596     BreakContinue() = default;
597   };
598   SmallVector<BreakContinue, 8> BreakContinueStack;
599 
600   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
601                       CodeGenPGO &PGO)
602       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
603 
604   void RecordStmtCount(const Stmt *S) {
605     if (RecordNextStmtCount) {
606       CountMap[S] = CurrentCount;
607       RecordNextStmtCount = false;
608     }
609   }
610 
611   /// Set and return the current count.
612   uint64_t setCount(uint64_t Count) {
613     CurrentCount = Count;
614     return Count;
615   }
616 
617   void VisitStmt(const Stmt *S) {
618     RecordStmtCount(S);
619     for (const Stmt *Child : S->children())
620       if (Child)
621         this->Visit(Child);
622   }
623 
624   void VisitFunctionDecl(const FunctionDecl *D) {
625     // Counter tracks entry to the function body.
626     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
627     CountMap[D->getBody()] = BodyCount;
628     Visit(D->getBody());
629   }
630 
631   // Skip lambda expressions. We visit these as FunctionDecls when we're
632   // generating them and aren't interested in the body when generating a
633   // parent context.
634   void VisitLambdaExpr(const LambdaExpr *LE) {}
635 
636   void VisitCapturedDecl(const CapturedDecl *D) {
637     // Counter tracks entry to the capture body.
638     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
639     CountMap[D->getBody()] = BodyCount;
640     Visit(D->getBody());
641   }
642 
643   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
644     // Counter tracks entry to the method body.
645     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
646     CountMap[D->getBody()] = BodyCount;
647     Visit(D->getBody());
648   }
649 
650   void VisitBlockDecl(const BlockDecl *D) {
651     // Counter tracks entry to the block body.
652     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
653     CountMap[D->getBody()] = BodyCount;
654     Visit(D->getBody());
655   }
656 
657   void VisitReturnStmt(const ReturnStmt *S) {
658     RecordStmtCount(S);
659     if (S->getRetValue())
660       Visit(S->getRetValue());
661     CurrentCount = 0;
662     RecordNextStmtCount = true;
663   }
664 
665   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
666     RecordStmtCount(E);
667     if (E->getSubExpr())
668       Visit(E->getSubExpr());
669     CurrentCount = 0;
670     RecordNextStmtCount = true;
671   }
672 
673   void VisitGotoStmt(const GotoStmt *S) {
674     RecordStmtCount(S);
675     CurrentCount = 0;
676     RecordNextStmtCount = true;
677   }
678 
679   void VisitLabelStmt(const LabelStmt *S) {
680     RecordNextStmtCount = false;
681     // Counter tracks the block following the label.
682     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
683     CountMap[S] = BlockCount;
684     Visit(S->getSubStmt());
685   }
686 
687   void VisitBreakStmt(const BreakStmt *S) {
688     RecordStmtCount(S);
689     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
690     BreakContinueStack.back().BreakCount += CurrentCount;
691     CurrentCount = 0;
692     RecordNextStmtCount = true;
693   }
694 
695   void VisitContinueStmt(const ContinueStmt *S) {
696     RecordStmtCount(S);
697     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
698     BreakContinueStack.back().ContinueCount += CurrentCount;
699     CurrentCount = 0;
700     RecordNextStmtCount = true;
701   }
702 
703   void VisitWhileStmt(const WhileStmt *S) {
704     RecordStmtCount(S);
705     uint64_t ParentCount = CurrentCount;
706 
707     BreakContinueStack.push_back(BreakContinue());
708     // Visit the body region first so the break/continue adjustments can be
709     // included when visiting the condition.
710     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
711     CountMap[S->getBody()] = CurrentCount;
712     Visit(S->getBody());
713     uint64_t BackedgeCount = CurrentCount;
714 
715     // ...then go back and propagate counts through the condition. The count
716     // at the start of the condition is the sum of the incoming edges,
717     // the backedge from the end of the loop body, and the edges from
718     // continue statements.
719     BreakContinue BC = BreakContinueStack.pop_back_val();
720     uint64_t CondCount =
721         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
722     CountMap[S->getCond()] = CondCount;
723     Visit(S->getCond());
724     setCount(BC.BreakCount + CondCount - BodyCount);
725     RecordNextStmtCount = true;
726   }
727 
728   void VisitDoStmt(const DoStmt *S) {
729     RecordStmtCount(S);
730     uint64_t LoopCount = PGO.getRegionCount(S);
731 
732     BreakContinueStack.push_back(BreakContinue());
733     // The count doesn't include the fallthrough from the parent scope. Add it.
734     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
735     CountMap[S->getBody()] = BodyCount;
736     Visit(S->getBody());
737     uint64_t BackedgeCount = CurrentCount;
738 
739     BreakContinue BC = BreakContinueStack.pop_back_val();
740     // The count at the start of the condition is equal to the count at the
741     // end of the body, plus any continues.
742     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
743     CountMap[S->getCond()] = CondCount;
744     Visit(S->getCond());
745     setCount(BC.BreakCount + CondCount - LoopCount);
746     RecordNextStmtCount = true;
747   }
748 
749   void VisitForStmt(const ForStmt *S) {
750     RecordStmtCount(S);
751     if (S->getInit())
752       Visit(S->getInit());
753 
754     uint64_t ParentCount = CurrentCount;
755 
756     BreakContinueStack.push_back(BreakContinue());
757     // Visit the body region first. (This is basically the same as a while
758     // loop; see further comments in VisitWhileStmt.)
759     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
760     CountMap[S->getBody()] = BodyCount;
761     Visit(S->getBody());
762     uint64_t BackedgeCount = CurrentCount;
763     BreakContinue BC = BreakContinueStack.pop_back_val();
764 
765     // The increment is essentially part of the body but it needs to include
766     // the count for all the continue statements.
767     if (S->getInc()) {
768       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
769       CountMap[S->getInc()] = IncCount;
770       Visit(S->getInc());
771     }
772 
773     // ...then go back and propagate counts through the condition.
774     uint64_t CondCount =
775         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
776     if (S->getCond()) {
777       CountMap[S->getCond()] = CondCount;
778       Visit(S->getCond());
779     }
780     setCount(BC.BreakCount + CondCount - BodyCount);
781     RecordNextStmtCount = true;
782   }
783 
784   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
785     RecordStmtCount(S);
786     if (S->getInit())
787       Visit(S->getInit());
788     Visit(S->getLoopVarStmt());
789     Visit(S->getRangeStmt());
790     Visit(S->getBeginStmt());
791     Visit(S->getEndStmt());
792 
793     uint64_t ParentCount = CurrentCount;
794     BreakContinueStack.push_back(BreakContinue());
795     // Visit the body region first. (This is basically the same as a while
796     // loop; see further comments in VisitWhileStmt.)
797     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
798     CountMap[S->getBody()] = BodyCount;
799     Visit(S->getBody());
800     uint64_t BackedgeCount = CurrentCount;
801     BreakContinue BC = BreakContinueStack.pop_back_val();
802 
803     // The increment is essentially part of the body but it needs to include
804     // the count for all the continue statements.
805     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
806     CountMap[S->getInc()] = IncCount;
807     Visit(S->getInc());
808 
809     // ...then go back and propagate counts through the condition.
810     uint64_t CondCount =
811         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
812     CountMap[S->getCond()] = CondCount;
813     Visit(S->getCond());
814     setCount(BC.BreakCount + CondCount - BodyCount);
815     RecordNextStmtCount = true;
816   }
817 
818   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
819     RecordStmtCount(S);
820     Visit(S->getElement());
821     uint64_t ParentCount = CurrentCount;
822     BreakContinueStack.push_back(BreakContinue());
823     // Counter tracks the body of the loop.
824     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
825     CountMap[S->getBody()] = BodyCount;
826     Visit(S->getBody());
827     uint64_t BackedgeCount = CurrentCount;
828     BreakContinue BC = BreakContinueStack.pop_back_val();
829 
830     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
831              BodyCount);
832     RecordNextStmtCount = true;
833   }
834 
835   void VisitSwitchStmt(const SwitchStmt *S) {
836     RecordStmtCount(S);
837     if (S->getInit())
838       Visit(S->getInit());
839     Visit(S->getCond());
840     CurrentCount = 0;
841     BreakContinueStack.push_back(BreakContinue());
842     Visit(S->getBody());
843     // If the switch is inside a loop, add the continue counts.
844     BreakContinue BC = BreakContinueStack.pop_back_val();
845     if (!BreakContinueStack.empty())
846       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
847     // Counter tracks the exit block of the switch.
848     setCount(PGO.getRegionCount(S));
849     RecordNextStmtCount = true;
850   }
851 
852   void VisitSwitchCase(const SwitchCase *S) {
853     RecordNextStmtCount = false;
854     // Counter for this particular case. This counts only jumps from the
855     // switch header and does not include fallthrough from the case before
856     // this one.
857     uint64_t CaseCount = PGO.getRegionCount(S);
858     setCount(CurrentCount + CaseCount);
859     // We need the count without fallthrough in the mapping, so it's more useful
860     // for branch probabilities.
861     CountMap[S] = CaseCount;
862     RecordNextStmtCount = true;
863     Visit(S->getSubStmt());
864   }
865 
866   void VisitIfStmt(const IfStmt *S) {
867     RecordStmtCount(S);
868 
869     if (S->isConsteval()) {
870       const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
871       if (Stm)
872         Visit(Stm);
873       return;
874     }
875 
876     uint64_t ParentCount = CurrentCount;
877     if (S->getInit())
878       Visit(S->getInit());
879     Visit(S->getCond());
880 
881     // Counter tracks the "then" part of an if statement. The count for
882     // the "else" part, if it exists, will be calculated from this counter.
883     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
884     CountMap[S->getThen()] = ThenCount;
885     Visit(S->getThen());
886     uint64_t OutCount = CurrentCount;
887 
888     uint64_t ElseCount = ParentCount - ThenCount;
889     if (S->getElse()) {
890       setCount(ElseCount);
891       CountMap[S->getElse()] = ElseCount;
892       Visit(S->getElse());
893       OutCount += CurrentCount;
894     } else
895       OutCount += ElseCount;
896     setCount(OutCount);
897     RecordNextStmtCount = true;
898   }
899 
900   void VisitCXXTryStmt(const CXXTryStmt *S) {
901     RecordStmtCount(S);
902     Visit(S->getTryBlock());
903     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
904       Visit(S->getHandler(I));
905     // Counter tracks the continuation block of the try statement.
906     setCount(PGO.getRegionCount(S));
907     RecordNextStmtCount = true;
908   }
909 
910   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
911     RecordNextStmtCount = false;
912     // Counter tracks the catch statement's handler block.
913     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
914     CountMap[S] = CatchCount;
915     Visit(S->getHandlerBlock());
916   }
917 
918   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
919     RecordStmtCount(E);
920     uint64_t ParentCount = CurrentCount;
921     Visit(E->getCond());
922 
923     // Counter tracks the "true" part of a conditional operator. The
924     // count in the "false" part will be calculated from this counter.
925     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
926     CountMap[E->getTrueExpr()] = TrueCount;
927     Visit(E->getTrueExpr());
928     uint64_t OutCount = CurrentCount;
929 
930     uint64_t FalseCount = setCount(ParentCount - TrueCount);
931     CountMap[E->getFalseExpr()] = FalseCount;
932     Visit(E->getFalseExpr());
933     OutCount += CurrentCount;
934 
935     setCount(OutCount);
936     RecordNextStmtCount = true;
937   }
938 
939   void VisitBinLAnd(const BinaryOperator *E) {
940     RecordStmtCount(E);
941     uint64_t ParentCount = CurrentCount;
942     Visit(E->getLHS());
943     // Counter tracks the right hand side of a logical and operator.
944     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
945     CountMap[E->getRHS()] = RHSCount;
946     Visit(E->getRHS());
947     setCount(ParentCount + RHSCount - CurrentCount);
948     RecordNextStmtCount = true;
949   }
950 
951   void VisitBinLOr(const BinaryOperator *E) {
952     RecordStmtCount(E);
953     uint64_t ParentCount = CurrentCount;
954     Visit(E->getLHS());
955     // Counter tracks the right hand side of a logical or operator.
956     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
957     CountMap[E->getRHS()] = RHSCount;
958     Visit(E->getRHS());
959     setCount(ParentCount + RHSCount - CurrentCount);
960     RecordNextStmtCount = true;
961   }
962 };
963 } // end anonymous namespace
964 
965 void PGOHash::combine(HashType Type) {
966   // Check that we never combine 0 and only have six bits.
967   assert(Type && "Hash is invalid: unexpected type 0");
968   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
969 
970   // Pass through MD5 if enough work has built up.
971   if (Count && Count % NumTypesPerWord == 0) {
972     using namespace llvm::support;
973     uint64_t Swapped =
974         endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
975     MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
976     Working = 0;
977   }
978 
979   // Accumulate the current type.
980   ++Count;
981   Working = Working << NumBitsPerType | Type;
982 }
983 
984 uint64_t PGOHash::finalize() {
985   // Use Working as the hash directly if we never used MD5.
986   if (Count <= NumTypesPerWord)
987     // No need to byte swap here, since none of the math was endian-dependent.
988     // This number will be byte-swapped as required on endianness transitions,
989     // so we will see the same value on the other side.
990     return Working;
991 
992   // Check for remaining work in Working.
993   if (Working) {
994     // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
995     // is buggy because it converts a uint64_t into an array of uint8_t.
996     if (HashVersion < PGO_HASH_V3) {
997       MD5.update({(uint8_t)Working});
998     } else {
999       using namespace llvm::support;
1000       uint64_t Swapped =
1001           endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
1002       MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
1003     }
1004   }
1005 
1006   // Finalize the MD5 and return the hash.
1007   llvm::MD5::MD5Result Result;
1008   MD5.final(Result);
1009   return Result.low();
1010 }
1011 
1012 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
1013   const Decl *D = GD.getDecl();
1014   if (!D->hasBody())
1015     return;
1016 
1017   // Skip CUDA/HIP kernel launch stub functions.
1018   if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
1019       D->hasAttr<CUDAGlobalAttr>())
1020     return;
1021 
1022   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
1023   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1024   if (!InstrumentRegions && !PGOReader)
1025     return;
1026   if (D->isImplicit())
1027     return;
1028   // Constructors and destructors may be represented by several functions in IR.
1029   // If so, instrument only base variant, others are implemented by delegation
1030   // to the base one, it would be counted twice otherwise.
1031   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
1032     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
1033       if (GD.getCtorType() != Ctor_Base &&
1034           CodeGenFunction::IsConstructorDelegationValid(CCD))
1035         return;
1036   }
1037   if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
1038     return;
1039 
1040   CGM.ClearUnusedCoverageMapping(D);
1041   if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
1042     return;
1043   if (Fn->hasFnAttribute(llvm::Attribute::SkipProfile))
1044     return;
1045 
1046   SourceManager &SM = CGM.getContext().getSourceManager();
1047   if (!llvm::coverage::SystemHeadersCoverage &&
1048       SM.isInSystemHeader(D->getLocation()))
1049     return;
1050 
1051   setFuncName(Fn);
1052 
1053   mapRegionCounters(D);
1054   if (CGM.getCodeGenOpts().CoverageMapping)
1055     emitCounterRegionMapping(D);
1056   if (PGOReader) {
1057     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
1058     computeRegionCounts(D);
1059     applyFunctionAttributes(PGOReader, Fn);
1060   }
1061 }
1062 
1063 void CodeGenPGO::mapRegionCounters(const Decl *D) {
1064   // Use the latest hash version when inserting instrumentation, but use the
1065   // version in the indexed profile if we're reading PGO data.
1066   PGOHashVersion HashVersion = PGO_HASH_LATEST;
1067   uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
1068   if (auto *PGOReader = CGM.getPGOReader()) {
1069     HashVersion = getPGOHashVersion(PGOReader, CGM);
1070     ProfileVersion = PGOReader->getVersion();
1071   }
1072 
1073   // If MC/DC is enabled, set the MaxConditions to a preset value. Otherwise,
1074   // set it to zero. This value impacts the number of conditions accepted in a
1075   // given boolean expression, which impacts the size of the bitmap used to
1076   // track test vector execution for that boolean expression.  Because the
1077   // bitmap scales exponentially (2^n) based on the number of conditions seen,
1078   // the maximum value is hard-coded at 6 conditions, which is more than enough
1079   // for most embedded applications. Setting a maximum value prevents the
1080   // bitmap footprint from growing too large without the user's knowledge. In
1081   // the future, this value could be adjusted with a command-line option.
1082   unsigned MCDCMaxConditions =
1083       (CGM.getCodeGenOpts().MCDCCoverage ? CGM.getCodeGenOpts().MCDCMaxConds
1084                                          : 0);
1085 
1086   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, CounterPair>);
1087   RegionMCDCState.reset(new MCDC::State);
1088   MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap,
1089                            *RegionMCDCState, MCDCMaxConditions, CGM.getDiags());
1090   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
1091     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
1092   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
1093     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
1094   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
1095     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
1096   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
1097     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
1098   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
1099   NumRegionCounters = Walker.NextCounter;
1100   FunctionHash = Walker.Hash.finalize();
1101 }
1102 
1103 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
1104   if (!D->getBody())
1105     return true;
1106 
1107   // Skip host-only functions in the CUDA device compilation and device-only
1108   // functions in the host compilation. Just roughly filter them out based on
1109   // the function attributes. If there are effectively host-only or device-only
1110   // ones, their coverage mapping may still be generated.
1111   if (CGM.getLangOpts().CUDA &&
1112       ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
1113         !D->hasAttr<CUDAGlobalAttr>()) ||
1114        (!CGM.getLangOpts().CUDAIsDevice &&
1115         (D->hasAttr<CUDAGlobalAttr>() ||
1116          (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
1117     return true;
1118 
1119   // Don't map the functions in system headers.
1120   const auto &SM = CGM.getContext().getSourceManager();
1121   auto Loc = D->getBody()->getBeginLoc();
1122   return !llvm::coverage::SystemHeadersCoverage && SM.isInSystemHeader(Loc);
1123 }
1124 
1125 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
1126   if (skipRegionMappingForDecl(D))
1127     return;
1128 
1129   std::string CoverageMapping;
1130   llvm::raw_string_ostream OS(CoverageMapping);
1131   RegionMCDCState->BranchByStmt.clear();
1132   CoverageMappingGen MappingGen(
1133       *CGM.getCoverageMapping(), CGM.getContext().getSourceManager(),
1134       CGM.getLangOpts(), RegionCounterMap.get(), RegionMCDCState.get());
1135   MappingGen.emitCounterMapping(D, OS);
1136 
1137   if (CoverageMapping.empty())
1138     return;
1139 
1140   CGM.getCoverageMapping()->addFunctionMappingRecord(
1141       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
1142 }
1143 
1144 void
1145 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
1146                                     llvm::GlobalValue::LinkageTypes Linkage) {
1147   if (skipRegionMappingForDecl(D))
1148     return;
1149 
1150   std::string CoverageMapping;
1151   llvm::raw_string_ostream OS(CoverageMapping);
1152   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
1153                                 CGM.getContext().getSourceManager(),
1154                                 CGM.getLangOpts());
1155   MappingGen.emitEmptyMapping(D, OS);
1156 
1157   if (CoverageMapping.empty())
1158     return;
1159 
1160   setFuncName(Name, Linkage);
1161   CGM.getCoverageMapping()->addFunctionMappingRecord(
1162       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
1163 }
1164 
1165 void CodeGenPGO::computeRegionCounts(const Decl *D) {
1166   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
1167   ComputeRegionCounts Walker(*StmtCountMap, *this);
1168   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
1169     Walker.VisitFunctionDecl(FD);
1170   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
1171     Walker.VisitObjCMethodDecl(MD);
1172   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
1173     Walker.VisitBlockDecl(BD);
1174   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
1175     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
1176 }
1177 
1178 void
1179 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
1180                                     llvm::Function *Fn) {
1181   if (!haveRegionCounts())
1182     return;
1183 
1184   uint64_t FunctionCount = getRegionCount(nullptr);
1185   Fn->setEntryCount(FunctionCount);
1186 }
1187 
1188 std::pair<bool, bool> CodeGenPGO::getIsCounterPair(const Stmt *S) const {
1189   if (!RegionCounterMap)
1190     return {false, false};
1191 
1192   auto I = RegionCounterMap->find(S);
1193   if (I == RegionCounterMap->end())
1194     return {false, false};
1195 
1196   return {I->second.Executed.hasValue(), I->second.Skipped.hasValue()};
1197 }
1198 
1199 void CodeGenPGO::emitCounterSetOrIncrement(CGBuilderTy &Builder, const Stmt *S,
1200                                            llvm::Value *StepV) {
1201   if (!RegionCounterMap || !Builder.GetInsertBlock())
1202     return;
1203 
1204   unsigned Counter = (*RegionCounterMap)[S].Executed;
1205 
1206   // Make sure that pointer to global is passed in with zero addrspace
1207   // This is relevant during GPU profiling
1208   auto *NormalizedFuncNameVarPtr =
1209       llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
1210           FuncNameVar, llvm::PointerType::get(CGM.getLLVMContext(), 0));
1211 
1212   llvm::Value *Args[] = {
1213       NormalizedFuncNameVarPtr, Builder.getInt64(FunctionHash),
1214       Builder.getInt32(NumRegionCounters), Builder.getInt32(Counter), StepV};
1215 
1216   if (llvm::EnableSingleByteCoverage)
1217     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_cover),
1218                        ArrayRef(Args, 4));
1219   else if (!StepV)
1220     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
1221                        ArrayRef(Args, 4));
1222   else
1223     Builder.CreateCall(
1224         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step), Args);
1225 }
1226 
1227 bool CodeGenPGO::canEmitMCDCCoverage(const CGBuilderTy &Builder) {
1228   return (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1229           CGM.getCodeGenOpts().MCDCCoverage && Builder.GetInsertBlock());
1230 }
1231 
1232 void CodeGenPGO::emitMCDCParameters(CGBuilderTy &Builder) {
1233   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1234     return;
1235 
1236   auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1237 
1238   // Emit intrinsic representing MCDC bitmap parameters at function entry.
1239   // This is used by the instrumentation pass, but it isn't actually lowered to
1240   // anything.
1241   llvm::Value *Args[3] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1242                           Builder.getInt64(FunctionHash),
1243                           Builder.getInt32(RegionMCDCState->BitmapBits)};
1244   Builder.CreateCall(
1245       CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_parameters), Args);
1246 }
1247 
1248 void CodeGenPGO::emitMCDCTestVectorBitmapUpdate(CGBuilderTy &Builder,
1249                                                 const Expr *S,
1250                                                 Address MCDCCondBitmapAddr,
1251                                                 CodeGenFunction &CGF) {
1252   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1253     return;
1254 
1255   S = S->IgnoreParens();
1256 
1257   auto DecisionStateIter = RegionMCDCState->DecisionByStmt.find(S);
1258   if (DecisionStateIter == RegionMCDCState->DecisionByStmt.end())
1259     return;
1260 
1261   // Don't create tvbitmap_update if the record is allocated but excluded.
1262   // Or `bitmap |= (1 << 0)` would be wrongly executed to the next bitmap.
1263   if (DecisionStateIter->second.Indices.size() == 0)
1264     return;
1265 
1266   // Extract the offset of the global bitmap associated with this expression.
1267   unsigned MCDCTestVectorBitmapOffset = DecisionStateIter->second.BitmapIdx;
1268   auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1269 
1270   // Emit intrinsic responsible for updating the global bitmap corresponding to
1271   // a boolean expression. The index being set is based on the value loaded
1272   // from a pointer to a dedicated temporary value on the stack that is itself
1273   // updated via emitMCDCCondBitmapReset() and emitMCDCCondBitmapUpdate(). The
1274   // index represents an executed test vector.
1275   llvm::Value *Args[4] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1276                           Builder.getInt64(FunctionHash),
1277                           Builder.getInt32(MCDCTestVectorBitmapOffset),
1278                           MCDCCondBitmapAddr.emitRawPointer(CGF)};
1279   Builder.CreateCall(
1280       CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_tvbitmap_update), Args);
1281 }
1282 
1283 void CodeGenPGO::emitMCDCCondBitmapReset(CGBuilderTy &Builder, const Expr *S,
1284                                          Address MCDCCondBitmapAddr) {
1285   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1286     return;
1287 
1288   S = S->IgnoreParens();
1289 
1290   if (!RegionMCDCState->DecisionByStmt.contains(S))
1291     return;
1292 
1293   // Emit intrinsic that resets a dedicated temporary value on the stack to 0.
1294   Builder.CreateStore(Builder.getInt32(0), MCDCCondBitmapAddr);
1295 }
1296 
1297 void CodeGenPGO::emitMCDCCondBitmapUpdate(CGBuilderTy &Builder, const Expr *S,
1298                                           Address MCDCCondBitmapAddr,
1299                                           llvm::Value *Val,
1300                                           CodeGenFunction &CGF) {
1301   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1302     return;
1303 
1304   // Even though, for simplicity, parentheses and unary logical-NOT operators
1305   // are considered part of their underlying condition for both MC/DC and
1306   // branch coverage, the condition IDs themselves are assigned and tracked
1307   // using the underlying condition itself.  This is done solely for
1308   // consistency since parentheses and logical-NOTs are ignored when checking
1309   // whether the condition is actually an instrumentable condition. This can
1310   // also make debugging a bit easier.
1311   S = CodeGenFunction::stripCond(S);
1312 
1313   auto BranchStateIter = RegionMCDCState->BranchByStmt.find(S);
1314   if (BranchStateIter == RegionMCDCState->BranchByStmt.end())
1315     return;
1316 
1317   // Extract the ID of the condition we are setting in the bitmap.
1318   const auto &Branch = BranchStateIter->second;
1319   assert(Branch.ID >= 0 && "Condition has no ID!");
1320   assert(Branch.DecisionStmt);
1321 
1322   // Cancel the emission if the Decision is erased after the allocation.
1323   const auto DecisionIter =
1324       RegionMCDCState->DecisionByStmt.find(Branch.DecisionStmt);
1325   if (DecisionIter == RegionMCDCState->DecisionByStmt.end())
1326     return;
1327 
1328   const auto &TVIdxs = DecisionIter->second.Indices[Branch.ID];
1329 
1330   auto *CurTV = Builder.CreateLoad(MCDCCondBitmapAddr,
1331                                    "mcdc." + Twine(Branch.ID + 1) + ".cur");
1332   auto *NewTV = Builder.CreateAdd(CurTV, Builder.getInt32(TVIdxs[true]));
1333   NewTV = Builder.CreateSelect(
1334       Val, NewTV, Builder.CreateAdd(CurTV, Builder.getInt32(TVIdxs[false])));
1335   Builder.CreateStore(NewTV, MCDCCondBitmapAddr);
1336 }
1337 
1338 void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
1339   if (CGM.getCodeGenOpts().hasProfileClangInstr())
1340     M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
1341                     uint32_t(EnableValueProfiling));
1342 }
1343 
1344 void CodeGenPGO::setProfileVersion(llvm::Module &M) {
1345   if (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1346       llvm::EnableSingleByteCoverage) {
1347     const StringRef VarName(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
1348     llvm::Type *IntTy64 = llvm::Type::getInt64Ty(M.getContext());
1349     uint64_t ProfileVersion =
1350         (INSTR_PROF_RAW_VERSION | VARIANT_MASK_BYTE_COVERAGE);
1351 
1352     auto IRLevelVersionVariable = new llvm::GlobalVariable(
1353         M, IntTy64, true, llvm::GlobalValue::WeakAnyLinkage,
1354         llvm::Constant::getIntegerValue(IntTy64,
1355                                         llvm::APInt(64, ProfileVersion)),
1356         VarName);
1357 
1358     IRLevelVersionVariable->setVisibility(llvm::GlobalValue::HiddenVisibility);
1359     llvm::Triple TT(M.getTargetTriple());
1360     if (TT.supportsCOMDAT()) {
1361       IRLevelVersionVariable->setLinkage(llvm::GlobalValue::ExternalLinkage);
1362       IRLevelVersionVariable->setComdat(M.getOrInsertComdat(VarName));
1363     }
1364     IRLevelVersionVariable->setDSOLocal(true);
1365   }
1366 }
1367 
1368 // This method either inserts a call to the profile run-time during
1369 // instrumentation or puts profile data into metadata for PGO use.
1370 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
1371     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
1372 
1373   if (!EnableValueProfiling)
1374     return;
1375 
1376   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
1377     return;
1378 
1379   if (isa<llvm::Constant>(ValuePtr))
1380     return;
1381 
1382   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
1383   if (InstrumentValueSites && RegionCounterMap) {
1384     auto BuilderInsertPoint = Builder.saveIP();
1385     Builder.SetInsertPoint(ValueSite);
1386     llvm::Value *Args[5] = {
1387         FuncNameVar,
1388         Builder.getInt64(FunctionHash),
1389         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
1390         Builder.getInt32(ValueKind),
1391         Builder.getInt32(NumValueSites[ValueKind]++)
1392     };
1393     Builder.CreateCall(
1394         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
1395     Builder.restoreIP(BuilderInsertPoint);
1396     return;
1397   }
1398 
1399   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1400   if (PGOReader && haveRegionCounts()) {
1401     // We record the top most called three functions at each call site.
1402     // Profile metadata contains "VP" string identifying this metadata
1403     // as value profiling data, then a uint32_t value for the value profiling
1404     // kind, a uint64_t value for the total number of times the call is
1405     // executed, followed by the function hash and execution count (uint64_t)
1406     // pairs for each function.
1407     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
1408       return;
1409 
1410     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
1411                             (llvm::InstrProfValueKind)ValueKind,
1412                             NumValueSites[ValueKind]);
1413 
1414     NumValueSites[ValueKind]++;
1415   }
1416 }
1417 
1418 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1419                                   bool IsInMainFile) {
1420   CGM.getPGOStats().addVisited(IsInMainFile);
1421   RegionCounts.clear();
1422   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1423       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1424   if (auto E = RecordExpected.takeError()) {
1425     auto IPE = std::get<0>(llvm::InstrProfError::take(std::move(E)));
1426     if (IPE == llvm::instrprof_error::unknown_function)
1427       CGM.getPGOStats().addMissing(IsInMainFile);
1428     else if (IPE == llvm::instrprof_error::hash_mismatch)
1429       CGM.getPGOStats().addMismatched(IsInMainFile);
1430     else if (IPE == llvm::instrprof_error::malformed)
1431       // TODO: Consider a more specific warning for this case.
1432       CGM.getPGOStats().addMismatched(IsInMainFile);
1433     return;
1434   }
1435   ProfRecord =
1436       std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1437   RegionCounts = ProfRecord->Counts;
1438 }
1439 
/// Compute the divisor used to scale branch weights.
///
/// \param MaxWeight the largest weight that is going to be scaled.
/// \returns 1 when \p MaxWeight is already below UINT32_MAX, otherwise a
/// divisor that brings every weight up to \p MaxWeight strictly below
/// UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
1447 
/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \param Weight the 64-bit weight to scale.
/// \param Scale the non-zero divisor to apply.
/// \returns \p Weight / \p Scale + 1 as a 32-bit value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight (i.e. the maximum of the weights being scaled), so
/// that the result, including the added 1, fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // The narrowing is safe: the assertion above guarantees the value fits.
  return static_cast<uint32_t>(Scaled);
}
1463 
1464 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1465                                                     uint64_t FalseCount) const {
1466   // Check for empty weights.
1467   if (!TrueCount && !FalseCount)
1468     return nullptr;
1469 
1470   // Calculate how to scale down to 32-bits.
1471   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1472 
1473   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1474   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1475                                       scaleBranchWeight(FalseCount, Scale));
1476 }
1477 
1478 llvm::MDNode *
1479 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1480   // We need at least two elements to create meaningful weights.
1481   if (Weights.size() < 2)
1482     return nullptr;
1483 
1484   // Check for empty weights.
1485   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1486   if (MaxWeight == 0)
1487     return nullptr;
1488 
1489   // Calculate how to scale down to 32-bits.
1490   uint64_t Scale = calculateWeightScale(MaxWeight);
1491 
1492   SmallVector<uint32_t, 16> ScaledWeights;
1493   ScaledWeights.reserve(Weights.size());
1494   for (uint64_t W : Weights)
1495     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1496 
1497   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1498   return MDHelper.createBranchWeights(ScaledWeights);
1499 }
1500 
1501 llvm::MDNode *
1502 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1503                                              uint64_t LoopCount) const {
1504   if (!PGO.haveRegionCounts())
1505     return nullptr;
1506   std::optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1507   if (!CondCount || *CondCount == 0)
1508     return nullptr;
1509   return createProfileWeights(LoopCount,
1510                               std::max(*CondCount, LoopCount) - LoopCount);
1511 }
1512