1e5dd7070Spatrick //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2e5dd7070Spatrick //
3e5dd7070Spatrick // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e5dd7070Spatrick // See https://llvm.org/LICENSE.txt for license information.
5e5dd7070Spatrick // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e5dd7070Spatrick //
7e5dd7070Spatrick //===----------------------------------------------------------------------===//
8e5dd7070Spatrick //
9e5dd7070Spatrick // Instrumentation-based profile-guided optimization
10e5dd7070Spatrick //
11e5dd7070Spatrick //===----------------------------------------------------------------------===//
12e5dd7070Spatrick
13e5dd7070Spatrick #include "CodeGenPGO.h"
14e5dd7070Spatrick #include "CodeGenFunction.h"
15e5dd7070Spatrick #include "CoverageMappingGen.h"
16e5dd7070Spatrick #include "clang/AST/RecursiveASTVisitor.h"
17e5dd7070Spatrick #include "clang/AST/StmtVisitor.h"
18e5dd7070Spatrick #include "llvm/IR/Intrinsics.h"
19e5dd7070Spatrick #include "llvm/IR/MDBuilder.h"
20e5dd7070Spatrick #include "llvm/Support/CommandLine.h"
21e5dd7070Spatrick #include "llvm/Support/Endian.h"
22e5dd7070Spatrick #include "llvm/Support/FileSystem.h"
23e5dd7070Spatrick #include "llvm/Support/MD5.h"
24*12c85518Srobert #include <optional>
25e5dd7070Spatrick
26e5dd7070Spatrick static llvm::cl::opt<bool>
27*12c85518Srobert EnableValueProfiling("enable-value-profiling",
28e5dd7070Spatrick llvm::cl::desc("Enable value profiling"),
29e5dd7070Spatrick llvm::cl::Hidden, llvm::cl::init(false));
30e5dd7070Spatrick
31e5dd7070Spatrick using namespace clang;
32e5dd7070Spatrick using namespace CodeGen;
33e5dd7070Spatrick
setFuncName(StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)34e5dd7070Spatrick void CodeGenPGO::setFuncName(StringRef Name,
35e5dd7070Spatrick llvm::GlobalValue::LinkageTypes Linkage) {
36e5dd7070Spatrick llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
37e5dd7070Spatrick FuncName = llvm::getPGOFuncName(
38e5dd7070Spatrick Name, Linkage, CGM.getCodeGenOpts().MainFileName,
39e5dd7070Spatrick PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
40e5dd7070Spatrick
41e5dd7070Spatrick // If we're generating a profile, create a variable for the name.
42e5dd7070Spatrick if (CGM.getCodeGenOpts().hasProfileClangInstr())
43e5dd7070Spatrick FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
44e5dd7070Spatrick }
45e5dd7070Spatrick
setFuncName(llvm::Function * Fn)46e5dd7070Spatrick void CodeGenPGO::setFuncName(llvm::Function *Fn) {
47e5dd7070Spatrick setFuncName(Fn->getName(), Fn->getLinkage());
48e5dd7070Spatrick // Create PGOFuncName meta data.
49e5dd7070Spatrick llvm::createPGOFuncNameMetadata(*Fn, FuncName);
50e5dd7070Spatrick }
51e5dd7070Spatrick
/// Identifies which revision of the PGO hash algorithm is in use.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,
  PGO_HASH_V3 = 2,

  // Must always alias the newest hash version listed above.
  PGO_HASH_LATEST = PGO_HASH_V3
};
61e5dd7070Spatrick
62e5dd7070Spatrick namespace {
63e5dd7070Spatrick /// Stable hasher for PGO region counters.
64e5dd7070Spatrick ///
65e5dd7070Spatrick /// PGOHash produces a stable hash of a given function's control flow.
66e5dd7070Spatrick ///
67e5dd7070Spatrick /// Changing the output of this hash will invalidate all previously generated
68e5dd7070Spatrick /// profiles -- i.e., don't do it.
69e5dd7070Spatrick ///
70e5dd7070Spatrick /// \note When this hash does eventually change (years?), we still need to
71e5dd7070Spatrick /// support old hashes. We'll need to pull in the version number from the
72e5dd7070Spatrick /// profile data format and use the matching hash function.
73e5dd7070Spatrick class PGOHash {
74e5dd7070Spatrick uint64_t Working;
75e5dd7070Spatrick unsigned Count;
76e5dd7070Spatrick PGOHashVersion HashVersion;
77e5dd7070Spatrick llvm::MD5 MD5;
78e5dd7070Spatrick
79e5dd7070Spatrick static const int NumBitsPerType = 6;
80e5dd7070Spatrick static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
81e5dd7070Spatrick static const unsigned TooBig = 1u << NumBitsPerType;
82e5dd7070Spatrick
83e5dd7070Spatrick public:
84e5dd7070Spatrick /// Hash values for AST nodes.
85e5dd7070Spatrick ///
86e5dd7070Spatrick /// Distinct values for AST nodes that have region counters attached.
87e5dd7070Spatrick ///
88e5dd7070Spatrick /// These values must be stable. All new members must be added at the end,
89e5dd7070Spatrick /// and no members should be removed. Changing the enumeration value for an
90e5dd7070Spatrick /// AST node will affect the hash of every function that contains that node.
91e5dd7070Spatrick enum HashType : unsigned char {
92e5dd7070Spatrick None = 0,
93e5dd7070Spatrick LabelStmt = 1,
94e5dd7070Spatrick WhileStmt,
95e5dd7070Spatrick DoStmt,
96e5dd7070Spatrick ForStmt,
97e5dd7070Spatrick CXXForRangeStmt,
98e5dd7070Spatrick ObjCForCollectionStmt,
99e5dd7070Spatrick SwitchStmt,
100e5dd7070Spatrick CaseStmt,
101e5dd7070Spatrick DefaultStmt,
102e5dd7070Spatrick IfStmt,
103e5dd7070Spatrick CXXTryStmt,
104e5dd7070Spatrick CXXCatchStmt,
105e5dd7070Spatrick ConditionalOperator,
106e5dd7070Spatrick BinaryOperatorLAnd,
107e5dd7070Spatrick BinaryOperatorLOr,
108e5dd7070Spatrick BinaryConditionalOperator,
109e5dd7070Spatrick // The preceding values are available with PGO_HASH_V1.
110e5dd7070Spatrick
111e5dd7070Spatrick EndOfScope,
112e5dd7070Spatrick IfThenBranch,
113e5dd7070Spatrick IfElseBranch,
114e5dd7070Spatrick GotoStmt,
115e5dd7070Spatrick IndirectGotoStmt,
116e5dd7070Spatrick BreakStmt,
117e5dd7070Spatrick ContinueStmt,
118e5dd7070Spatrick ReturnStmt,
119e5dd7070Spatrick ThrowExpr,
120e5dd7070Spatrick UnaryOperatorLNot,
121e5dd7070Spatrick BinaryOperatorLT,
122e5dd7070Spatrick BinaryOperatorGT,
123e5dd7070Spatrick BinaryOperatorLE,
124e5dd7070Spatrick BinaryOperatorGE,
125e5dd7070Spatrick BinaryOperatorEQ,
126e5dd7070Spatrick BinaryOperatorNE,
127ec727ea7Spatrick // The preceding values are available since PGO_HASH_V2.
128e5dd7070Spatrick
129e5dd7070Spatrick // Keep this last. It's for the static assert that follows.
130e5dd7070Spatrick LastHashType
131e5dd7070Spatrick };
132e5dd7070Spatrick static_assert(LastHashType <= TooBig, "Too many types in HashType");
133e5dd7070Spatrick
PGOHash(PGOHashVersion HashVersion)134e5dd7070Spatrick PGOHash(PGOHashVersion HashVersion)
135*12c85518Srobert : Working(0), Count(0), HashVersion(HashVersion) {}
136e5dd7070Spatrick void combine(HashType Type);
137e5dd7070Spatrick uint64_t finalize();
getHashVersion() const138e5dd7070Spatrick PGOHashVersion getHashVersion() const { return HashVersion; }
139e5dd7070Spatrick };
140e5dd7070Spatrick const int PGOHash::NumBitsPerType;
141e5dd7070Spatrick const unsigned PGOHash::NumTypesPerWord;
142e5dd7070Spatrick const unsigned PGOHash::TooBig;
143e5dd7070Spatrick
144e5dd7070Spatrick /// Get the PGO hash version used in the given indexed profile.
getPGOHashVersion(llvm::IndexedInstrProfReader * PGOReader,CodeGenModule & CGM)145e5dd7070Spatrick static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
146e5dd7070Spatrick CodeGenModule &CGM) {
147e5dd7070Spatrick if (PGOReader->getVersion() <= 4)
148e5dd7070Spatrick return PGO_HASH_V1;
149ec727ea7Spatrick if (PGOReader->getVersion() <= 5)
150e5dd7070Spatrick return PGO_HASH_V2;
151ec727ea7Spatrick return PGO_HASH_V3;
152e5dd7070Spatrick }
153e5dd7070Spatrick
154e5dd7070Spatrick /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
155e5dd7070Spatrick struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
156e5dd7070Spatrick using Base = RecursiveASTVisitor<MapRegionCounters>;
157e5dd7070Spatrick
158e5dd7070Spatrick /// The next counter value to assign.
159e5dd7070Spatrick unsigned NextCounter;
160e5dd7070Spatrick /// The function hash.
161e5dd7070Spatrick PGOHash Hash;
162e5dd7070Spatrick /// The map of statements to counters.
163e5dd7070Spatrick llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
164a9ac8606Spatrick /// The profile version.
165a9ac8606Spatrick uint64_t ProfileVersion;
166e5dd7070Spatrick
MapRegionCounters__anon393d18f00111::MapRegionCounters167a9ac8606Spatrick MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
168e5dd7070Spatrick llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
169a9ac8606Spatrick : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
170a9ac8606Spatrick ProfileVersion(ProfileVersion) {}
171e5dd7070Spatrick
172e5dd7070Spatrick // Blocks and lambdas are handled as separate functions, so we need not
173e5dd7070Spatrick // traverse them in the parent context.
TraverseBlockExpr__anon393d18f00111::MapRegionCounters174e5dd7070Spatrick bool TraverseBlockExpr(BlockExpr *BE) { return true; }
TraverseLambdaExpr__anon393d18f00111::MapRegionCounters175e5dd7070Spatrick bool TraverseLambdaExpr(LambdaExpr *LE) {
176e5dd7070Spatrick // Traverse the captures, but not the body.
177e5dd7070Spatrick for (auto C : zip(LE->captures(), LE->capture_inits()))
178e5dd7070Spatrick TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
179e5dd7070Spatrick return true;
180e5dd7070Spatrick }
TraverseCapturedStmt__anon393d18f00111::MapRegionCounters181e5dd7070Spatrick bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
182e5dd7070Spatrick
VisitDecl__anon393d18f00111::MapRegionCounters183e5dd7070Spatrick bool VisitDecl(const Decl *D) {
184e5dd7070Spatrick switch (D->getKind()) {
185e5dd7070Spatrick default:
186e5dd7070Spatrick break;
187e5dd7070Spatrick case Decl::Function:
188e5dd7070Spatrick case Decl::CXXMethod:
189e5dd7070Spatrick case Decl::CXXConstructor:
190e5dd7070Spatrick case Decl::CXXDestructor:
191e5dd7070Spatrick case Decl::CXXConversion:
192e5dd7070Spatrick case Decl::ObjCMethod:
193e5dd7070Spatrick case Decl::Block:
194e5dd7070Spatrick case Decl::Captured:
195e5dd7070Spatrick CounterMap[D->getBody()] = NextCounter++;
196e5dd7070Spatrick break;
197e5dd7070Spatrick }
198e5dd7070Spatrick return true;
199e5dd7070Spatrick }
200e5dd7070Spatrick
201e5dd7070Spatrick /// If \p S gets a fresh counter, update the counter mappings. Return the
202e5dd7070Spatrick /// V1 hash of \p S.
updateCounterMappings__anon393d18f00111::MapRegionCounters203e5dd7070Spatrick PGOHash::HashType updateCounterMappings(Stmt *S) {
204e5dd7070Spatrick auto Type = getHashType(PGO_HASH_V1, S);
205e5dd7070Spatrick if (Type != PGOHash::None)
206e5dd7070Spatrick CounterMap[S] = NextCounter++;
207e5dd7070Spatrick return Type;
208e5dd7070Spatrick }
209e5dd7070Spatrick
210a9ac8606Spatrick /// The RHS of all logical operators gets a fresh counter in order to count
211a9ac8606Spatrick /// how many times the RHS evaluates to true or false, depending on the
212a9ac8606Spatrick /// semantics of the operator. This is only valid for ">= v7" of the profile
213a9ac8606Spatrick /// version so that we facilitate backward compatibility.
VisitBinaryOperator__anon393d18f00111::MapRegionCounters214a9ac8606Spatrick bool VisitBinaryOperator(BinaryOperator *S) {
215a9ac8606Spatrick if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
216a9ac8606Spatrick if (S->isLogicalOp() &&
217a9ac8606Spatrick CodeGenFunction::isInstrumentedCondition(S->getRHS()))
218a9ac8606Spatrick CounterMap[S->getRHS()] = NextCounter++;
219a9ac8606Spatrick return Base::VisitBinaryOperator(S);
220a9ac8606Spatrick }
221a9ac8606Spatrick
222e5dd7070Spatrick /// Include \p S in the function hash.
VisitStmt__anon393d18f00111::MapRegionCounters223e5dd7070Spatrick bool VisitStmt(Stmt *S) {
224e5dd7070Spatrick auto Type = updateCounterMappings(S);
225e5dd7070Spatrick if (Hash.getHashVersion() != PGO_HASH_V1)
226e5dd7070Spatrick Type = getHashType(Hash.getHashVersion(), S);
227e5dd7070Spatrick if (Type != PGOHash::None)
228e5dd7070Spatrick Hash.combine(Type);
229e5dd7070Spatrick return true;
230e5dd7070Spatrick }
231e5dd7070Spatrick
TraverseIfStmt__anon393d18f00111::MapRegionCounters232e5dd7070Spatrick bool TraverseIfStmt(IfStmt *If) {
233e5dd7070Spatrick // If we used the V1 hash, use the default traversal.
234e5dd7070Spatrick if (Hash.getHashVersion() == PGO_HASH_V1)
235e5dd7070Spatrick return Base::TraverseIfStmt(If);
236e5dd7070Spatrick
237e5dd7070Spatrick // Otherwise, keep track of which branch we're in while traversing.
238e5dd7070Spatrick VisitStmt(If);
239e5dd7070Spatrick for (Stmt *CS : If->children()) {
240e5dd7070Spatrick if (!CS)
241e5dd7070Spatrick continue;
242e5dd7070Spatrick if (CS == If->getThen())
243e5dd7070Spatrick Hash.combine(PGOHash::IfThenBranch);
244e5dd7070Spatrick else if (CS == If->getElse())
245e5dd7070Spatrick Hash.combine(PGOHash::IfElseBranch);
246e5dd7070Spatrick TraverseStmt(CS);
247e5dd7070Spatrick }
248e5dd7070Spatrick Hash.combine(PGOHash::EndOfScope);
249e5dd7070Spatrick return true;
250e5dd7070Spatrick }
251e5dd7070Spatrick
252e5dd7070Spatrick // If the statement type \p N is nestable, and its nesting impacts profile
253e5dd7070Spatrick // stability, define a custom traversal which tracks the end of the statement
254e5dd7070Spatrick // in the hash (provided we're not using the V1 hash).
255e5dd7070Spatrick #define DEFINE_NESTABLE_TRAVERSAL(N) \
256e5dd7070Spatrick bool Traverse##N(N *S) { \
257e5dd7070Spatrick Base::Traverse##N(S); \
258e5dd7070Spatrick if (Hash.getHashVersion() != PGO_HASH_V1) \
259e5dd7070Spatrick Hash.combine(PGOHash::EndOfScope); \
260e5dd7070Spatrick return true; \
261e5dd7070Spatrick }
262e5dd7070Spatrick
263e5dd7070Spatrick DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
DEFINE_NESTABLE_TRAVERSAL__anon393d18f00111::MapRegionCounters264e5dd7070Spatrick DEFINE_NESTABLE_TRAVERSAL(DoStmt)
265e5dd7070Spatrick DEFINE_NESTABLE_TRAVERSAL(ForStmt)
266e5dd7070Spatrick DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
267e5dd7070Spatrick DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
268e5dd7070Spatrick DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
269e5dd7070Spatrick DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
270e5dd7070Spatrick
271e5dd7070Spatrick /// Get version \p HashVersion of the PGO hash for \p S.
272e5dd7070Spatrick PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
273e5dd7070Spatrick switch (S->getStmtClass()) {
274e5dd7070Spatrick default:
275e5dd7070Spatrick break;
276e5dd7070Spatrick case Stmt::LabelStmtClass:
277e5dd7070Spatrick return PGOHash::LabelStmt;
278e5dd7070Spatrick case Stmt::WhileStmtClass:
279e5dd7070Spatrick return PGOHash::WhileStmt;
280e5dd7070Spatrick case Stmt::DoStmtClass:
281e5dd7070Spatrick return PGOHash::DoStmt;
282e5dd7070Spatrick case Stmt::ForStmtClass:
283e5dd7070Spatrick return PGOHash::ForStmt;
284e5dd7070Spatrick case Stmt::CXXForRangeStmtClass:
285e5dd7070Spatrick return PGOHash::CXXForRangeStmt;
286e5dd7070Spatrick case Stmt::ObjCForCollectionStmtClass:
287e5dd7070Spatrick return PGOHash::ObjCForCollectionStmt;
288e5dd7070Spatrick case Stmt::SwitchStmtClass:
289e5dd7070Spatrick return PGOHash::SwitchStmt;
290e5dd7070Spatrick case Stmt::CaseStmtClass:
291e5dd7070Spatrick return PGOHash::CaseStmt;
292e5dd7070Spatrick case Stmt::DefaultStmtClass:
293e5dd7070Spatrick return PGOHash::DefaultStmt;
294e5dd7070Spatrick case Stmt::IfStmtClass:
295e5dd7070Spatrick return PGOHash::IfStmt;
296e5dd7070Spatrick case Stmt::CXXTryStmtClass:
297e5dd7070Spatrick return PGOHash::CXXTryStmt;
298e5dd7070Spatrick case Stmt::CXXCatchStmtClass:
299e5dd7070Spatrick return PGOHash::CXXCatchStmt;
300e5dd7070Spatrick case Stmt::ConditionalOperatorClass:
301e5dd7070Spatrick return PGOHash::ConditionalOperator;
302e5dd7070Spatrick case Stmt::BinaryConditionalOperatorClass:
303e5dd7070Spatrick return PGOHash::BinaryConditionalOperator;
304e5dd7070Spatrick case Stmt::BinaryOperatorClass: {
305e5dd7070Spatrick const BinaryOperator *BO = cast<BinaryOperator>(S);
306e5dd7070Spatrick if (BO->getOpcode() == BO_LAnd)
307e5dd7070Spatrick return PGOHash::BinaryOperatorLAnd;
308e5dd7070Spatrick if (BO->getOpcode() == BO_LOr)
309e5dd7070Spatrick return PGOHash::BinaryOperatorLOr;
310ec727ea7Spatrick if (HashVersion >= PGO_HASH_V2) {
311e5dd7070Spatrick switch (BO->getOpcode()) {
312e5dd7070Spatrick default:
313e5dd7070Spatrick break;
314e5dd7070Spatrick case BO_LT:
315e5dd7070Spatrick return PGOHash::BinaryOperatorLT;
316e5dd7070Spatrick case BO_GT:
317e5dd7070Spatrick return PGOHash::BinaryOperatorGT;
318e5dd7070Spatrick case BO_LE:
319e5dd7070Spatrick return PGOHash::BinaryOperatorLE;
320e5dd7070Spatrick case BO_GE:
321e5dd7070Spatrick return PGOHash::BinaryOperatorGE;
322e5dd7070Spatrick case BO_EQ:
323e5dd7070Spatrick return PGOHash::BinaryOperatorEQ;
324e5dd7070Spatrick case BO_NE:
325e5dd7070Spatrick return PGOHash::BinaryOperatorNE;
326e5dd7070Spatrick }
327e5dd7070Spatrick }
328e5dd7070Spatrick break;
329e5dd7070Spatrick }
330e5dd7070Spatrick }
331e5dd7070Spatrick
332ec727ea7Spatrick if (HashVersion >= PGO_HASH_V2) {
333e5dd7070Spatrick switch (S->getStmtClass()) {
334e5dd7070Spatrick default:
335e5dd7070Spatrick break;
336e5dd7070Spatrick case Stmt::GotoStmtClass:
337e5dd7070Spatrick return PGOHash::GotoStmt;
338e5dd7070Spatrick case Stmt::IndirectGotoStmtClass:
339e5dd7070Spatrick return PGOHash::IndirectGotoStmt;
340e5dd7070Spatrick case Stmt::BreakStmtClass:
341e5dd7070Spatrick return PGOHash::BreakStmt;
342e5dd7070Spatrick case Stmt::ContinueStmtClass:
343e5dd7070Spatrick return PGOHash::ContinueStmt;
344e5dd7070Spatrick case Stmt::ReturnStmtClass:
345e5dd7070Spatrick return PGOHash::ReturnStmt;
346e5dd7070Spatrick case Stmt::CXXThrowExprClass:
347e5dd7070Spatrick return PGOHash::ThrowExpr;
348e5dd7070Spatrick case Stmt::UnaryOperatorClass: {
349e5dd7070Spatrick const UnaryOperator *UO = cast<UnaryOperator>(S);
350e5dd7070Spatrick if (UO->getOpcode() == UO_LNot)
351e5dd7070Spatrick return PGOHash::UnaryOperatorLNot;
352e5dd7070Spatrick break;
353e5dd7070Spatrick }
354e5dd7070Spatrick }
355e5dd7070Spatrick }
356e5dd7070Spatrick
357e5dd7070Spatrick return PGOHash::None;
358e5dd7070Spatrick }
359e5dd7070Spatrick };
360e5dd7070Spatrick
361e5dd7070Spatrick /// A StmtVisitor that propagates the raw counts through the AST and
362e5dd7070Spatrick /// records the count at statements where the value may change.
363e5dd7070Spatrick struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
364e5dd7070Spatrick /// PGO state.
365e5dd7070Spatrick CodeGenPGO &PGO;
366e5dd7070Spatrick
367e5dd7070Spatrick /// A flag that is set when the current count should be recorded on the
368e5dd7070Spatrick /// next statement, such as at the exit of a loop.
369e5dd7070Spatrick bool RecordNextStmtCount;
370e5dd7070Spatrick
371e5dd7070Spatrick /// The count at the current location in the traversal.
372e5dd7070Spatrick uint64_t CurrentCount;
373e5dd7070Spatrick
374e5dd7070Spatrick /// The map of statements to count values.
375e5dd7070Spatrick llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
376e5dd7070Spatrick
/// Element type of BreakContinueStack: the accumulated counts of the breaks
/// and continues seen inside one loop (or switch). Both start at zero.
struct BreakContinue {
  uint64_t BreakCount = 0;
  uint64_t ContinueCount = 0;
  BreakContinue() = default;
};
383e5dd7070Spatrick SmallVector<BreakContinue, 8> BreakContinueStack;
384e5dd7070Spatrick
ComputeRegionCounts__anon393d18f00111::ComputeRegionCounts385e5dd7070Spatrick ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
386e5dd7070Spatrick CodeGenPGO &PGO)
387e5dd7070Spatrick : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
388e5dd7070Spatrick
RecordStmtCount__anon393d18f00111::ComputeRegionCounts389e5dd7070Spatrick void RecordStmtCount(const Stmt *S) {
390e5dd7070Spatrick if (RecordNextStmtCount) {
391e5dd7070Spatrick CountMap[S] = CurrentCount;
392e5dd7070Spatrick RecordNextStmtCount = false;
393e5dd7070Spatrick }
394e5dd7070Spatrick }
395e5dd7070Spatrick
396e5dd7070Spatrick /// Set and return the current count.
setCount__anon393d18f00111::ComputeRegionCounts397e5dd7070Spatrick uint64_t setCount(uint64_t Count) {
398e5dd7070Spatrick CurrentCount = Count;
399e5dd7070Spatrick return Count;
400e5dd7070Spatrick }
401e5dd7070Spatrick
VisitStmt__anon393d18f00111::ComputeRegionCounts402e5dd7070Spatrick void VisitStmt(const Stmt *S) {
403e5dd7070Spatrick RecordStmtCount(S);
404e5dd7070Spatrick for (const Stmt *Child : S->children())
405e5dd7070Spatrick if (Child)
406e5dd7070Spatrick this->Visit(Child);
407e5dd7070Spatrick }
408e5dd7070Spatrick
VisitFunctionDecl__anon393d18f00111::ComputeRegionCounts409e5dd7070Spatrick void VisitFunctionDecl(const FunctionDecl *D) {
410e5dd7070Spatrick // Counter tracks entry to the function body.
411e5dd7070Spatrick uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
412e5dd7070Spatrick CountMap[D->getBody()] = BodyCount;
413e5dd7070Spatrick Visit(D->getBody());
414e5dd7070Spatrick }
415e5dd7070Spatrick
416e5dd7070Spatrick // Skip lambda expressions. We visit these as FunctionDecls when we're
417e5dd7070Spatrick // generating them and aren't interested in the body when generating a
418e5dd7070Spatrick // parent context.
VisitLambdaExpr__anon393d18f00111::ComputeRegionCounts419e5dd7070Spatrick void VisitLambdaExpr(const LambdaExpr *LE) {}
420e5dd7070Spatrick
VisitCapturedDecl__anon393d18f00111::ComputeRegionCounts421e5dd7070Spatrick void VisitCapturedDecl(const CapturedDecl *D) {
422e5dd7070Spatrick // Counter tracks entry to the capture body.
423e5dd7070Spatrick uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
424e5dd7070Spatrick CountMap[D->getBody()] = BodyCount;
425e5dd7070Spatrick Visit(D->getBody());
426e5dd7070Spatrick }
427e5dd7070Spatrick
VisitObjCMethodDecl__anon393d18f00111::ComputeRegionCounts428e5dd7070Spatrick void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
429e5dd7070Spatrick // Counter tracks entry to the method body.
430e5dd7070Spatrick uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
431e5dd7070Spatrick CountMap[D->getBody()] = BodyCount;
432e5dd7070Spatrick Visit(D->getBody());
433e5dd7070Spatrick }
434e5dd7070Spatrick
VisitBlockDecl__anon393d18f00111::ComputeRegionCounts435e5dd7070Spatrick void VisitBlockDecl(const BlockDecl *D) {
436e5dd7070Spatrick // Counter tracks entry to the block body.
437e5dd7070Spatrick uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
438e5dd7070Spatrick CountMap[D->getBody()] = BodyCount;
439e5dd7070Spatrick Visit(D->getBody());
440e5dd7070Spatrick }
441e5dd7070Spatrick
VisitReturnStmt__anon393d18f00111::ComputeRegionCounts442e5dd7070Spatrick void VisitReturnStmt(const ReturnStmt *S) {
443e5dd7070Spatrick RecordStmtCount(S);
444e5dd7070Spatrick if (S->getRetValue())
445e5dd7070Spatrick Visit(S->getRetValue());
446e5dd7070Spatrick CurrentCount = 0;
447e5dd7070Spatrick RecordNextStmtCount = true;
448e5dd7070Spatrick }
449e5dd7070Spatrick
VisitCXXThrowExpr__anon393d18f00111::ComputeRegionCounts450e5dd7070Spatrick void VisitCXXThrowExpr(const CXXThrowExpr *E) {
451e5dd7070Spatrick RecordStmtCount(E);
452e5dd7070Spatrick if (E->getSubExpr())
453e5dd7070Spatrick Visit(E->getSubExpr());
454e5dd7070Spatrick CurrentCount = 0;
455e5dd7070Spatrick RecordNextStmtCount = true;
456e5dd7070Spatrick }
457e5dd7070Spatrick
VisitGotoStmt__anon393d18f00111::ComputeRegionCounts458e5dd7070Spatrick void VisitGotoStmt(const GotoStmt *S) {
459e5dd7070Spatrick RecordStmtCount(S);
460e5dd7070Spatrick CurrentCount = 0;
461e5dd7070Spatrick RecordNextStmtCount = true;
462e5dd7070Spatrick }
463e5dd7070Spatrick
VisitLabelStmt__anon393d18f00111::ComputeRegionCounts464e5dd7070Spatrick void VisitLabelStmt(const LabelStmt *S) {
465e5dd7070Spatrick RecordNextStmtCount = false;
466e5dd7070Spatrick // Counter tracks the block following the label.
467e5dd7070Spatrick uint64_t BlockCount = setCount(PGO.getRegionCount(S));
468e5dd7070Spatrick CountMap[S] = BlockCount;
469e5dd7070Spatrick Visit(S->getSubStmt());
470e5dd7070Spatrick }
471e5dd7070Spatrick
VisitBreakStmt__anon393d18f00111::ComputeRegionCounts472e5dd7070Spatrick void VisitBreakStmt(const BreakStmt *S) {
473e5dd7070Spatrick RecordStmtCount(S);
474e5dd7070Spatrick assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
475e5dd7070Spatrick BreakContinueStack.back().BreakCount += CurrentCount;
476e5dd7070Spatrick CurrentCount = 0;
477e5dd7070Spatrick RecordNextStmtCount = true;
478e5dd7070Spatrick }
479e5dd7070Spatrick
VisitContinueStmt__anon393d18f00111::ComputeRegionCounts480e5dd7070Spatrick void VisitContinueStmt(const ContinueStmt *S) {
481e5dd7070Spatrick RecordStmtCount(S);
482e5dd7070Spatrick assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
483e5dd7070Spatrick BreakContinueStack.back().ContinueCount += CurrentCount;
484e5dd7070Spatrick CurrentCount = 0;
485e5dd7070Spatrick RecordNextStmtCount = true;
486e5dd7070Spatrick }
487e5dd7070Spatrick
VisitWhileStmt__anon393d18f00111::ComputeRegionCounts488e5dd7070Spatrick void VisitWhileStmt(const WhileStmt *S) {
489e5dd7070Spatrick RecordStmtCount(S);
490e5dd7070Spatrick uint64_t ParentCount = CurrentCount;
491e5dd7070Spatrick
492e5dd7070Spatrick BreakContinueStack.push_back(BreakContinue());
493e5dd7070Spatrick // Visit the body region first so the break/continue adjustments can be
494e5dd7070Spatrick // included when visiting the condition.
495e5dd7070Spatrick uint64_t BodyCount = setCount(PGO.getRegionCount(S));
496e5dd7070Spatrick CountMap[S->getBody()] = CurrentCount;
497e5dd7070Spatrick Visit(S->getBody());
498e5dd7070Spatrick uint64_t BackedgeCount = CurrentCount;
499e5dd7070Spatrick
500e5dd7070Spatrick // ...then go back and propagate counts through the condition. The count
501e5dd7070Spatrick // at the start of the condition is the sum of the incoming edges,
502e5dd7070Spatrick // the backedge from the end of the loop body, and the edges from
503e5dd7070Spatrick // continue statements.
504e5dd7070Spatrick BreakContinue BC = BreakContinueStack.pop_back_val();
505e5dd7070Spatrick uint64_t CondCount =
506e5dd7070Spatrick setCount(ParentCount + BackedgeCount + BC.ContinueCount);
507e5dd7070Spatrick CountMap[S->getCond()] = CondCount;
508e5dd7070Spatrick Visit(S->getCond());
509e5dd7070Spatrick setCount(BC.BreakCount + CondCount - BodyCount);
510e5dd7070Spatrick RecordNextStmtCount = true;
511e5dd7070Spatrick }
512e5dd7070Spatrick
VisitDoStmt__anon393d18f00111::ComputeRegionCounts513e5dd7070Spatrick void VisitDoStmt(const DoStmt *S) {
514e5dd7070Spatrick RecordStmtCount(S);
515e5dd7070Spatrick uint64_t LoopCount = PGO.getRegionCount(S);
516e5dd7070Spatrick
517e5dd7070Spatrick BreakContinueStack.push_back(BreakContinue());
518e5dd7070Spatrick // The count doesn't include the fallthrough from the parent scope. Add it.
519e5dd7070Spatrick uint64_t BodyCount = setCount(LoopCount + CurrentCount);
520e5dd7070Spatrick CountMap[S->getBody()] = BodyCount;
521e5dd7070Spatrick Visit(S->getBody());
522e5dd7070Spatrick uint64_t BackedgeCount = CurrentCount;
523e5dd7070Spatrick
524e5dd7070Spatrick BreakContinue BC = BreakContinueStack.pop_back_val();
525e5dd7070Spatrick // The count at the start of the condition is equal to the count at the
526e5dd7070Spatrick // end of the body, plus any continues.
527e5dd7070Spatrick uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
528e5dd7070Spatrick CountMap[S->getCond()] = CondCount;
529e5dd7070Spatrick Visit(S->getCond());
530e5dd7070Spatrick setCount(BC.BreakCount + CondCount - LoopCount);
531e5dd7070Spatrick RecordNextStmtCount = true;
532e5dd7070Spatrick }
533e5dd7070Spatrick
VisitForStmt__anon393d18f00111::ComputeRegionCounts534e5dd7070Spatrick void VisitForStmt(const ForStmt *S) {
535e5dd7070Spatrick RecordStmtCount(S);
536e5dd7070Spatrick if (S->getInit())
537e5dd7070Spatrick Visit(S->getInit());
538e5dd7070Spatrick
539e5dd7070Spatrick uint64_t ParentCount = CurrentCount;
540e5dd7070Spatrick
541e5dd7070Spatrick BreakContinueStack.push_back(BreakContinue());
542e5dd7070Spatrick // Visit the body region first. (This is basically the same as a while
543e5dd7070Spatrick // loop; see further comments in VisitWhileStmt.)
544e5dd7070Spatrick uint64_t BodyCount = setCount(PGO.getRegionCount(S));
545e5dd7070Spatrick CountMap[S->getBody()] = BodyCount;
546e5dd7070Spatrick Visit(S->getBody());
547e5dd7070Spatrick uint64_t BackedgeCount = CurrentCount;
548e5dd7070Spatrick BreakContinue BC = BreakContinueStack.pop_back_val();
549e5dd7070Spatrick
550e5dd7070Spatrick // The increment is essentially part of the body but it needs to include
551e5dd7070Spatrick // the count for all the continue statements.
552e5dd7070Spatrick if (S->getInc()) {
553e5dd7070Spatrick uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
554e5dd7070Spatrick CountMap[S->getInc()] = IncCount;
555e5dd7070Spatrick Visit(S->getInc());
556e5dd7070Spatrick }
557e5dd7070Spatrick
558e5dd7070Spatrick // ...then go back and propagate counts through the condition.
559e5dd7070Spatrick uint64_t CondCount =
560e5dd7070Spatrick setCount(ParentCount + BackedgeCount + BC.ContinueCount);
561e5dd7070Spatrick if (S->getCond()) {
562e5dd7070Spatrick CountMap[S->getCond()] = CondCount;
563e5dd7070Spatrick Visit(S->getCond());
564e5dd7070Spatrick }
565e5dd7070Spatrick setCount(BC.BreakCount + CondCount - BodyCount);
566e5dd7070Spatrick RecordNextStmtCount = true;
567e5dd7070Spatrick }
568e5dd7070Spatrick
VisitCXXForRangeStmt__anon393d18f00111::ComputeRegionCounts569e5dd7070Spatrick void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
570e5dd7070Spatrick RecordStmtCount(S);
571e5dd7070Spatrick if (S->getInit())
572e5dd7070Spatrick Visit(S->getInit());
573e5dd7070Spatrick Visit(S->getLoopVarStmt());
574e5dd7070Spatrick Visit(S->getRangeStmt());
575e5dd7070Spatrick Visit(S->getBeginStmt());
576e5dd7070Spatrick Visit(S->getEndStmt());
577e5dd7070Spatrick
578e5dd7070Spatrick uint64_t ParentCount = CurrentCount;
579e5dd7070Spatrick BreakContinueStack.push_back(BreakContinue());
580e5dd7070Spatrick // Visit the body region first. (This is basically the same as a while
581e5dd7070Spatrick // loop; see further comments in VisitWhileStmt.)
582e5dd7070Spatrick uint64_t BodyCount = setCount(PGO.getRegionCount(S));
583e5dd7070Spatrick CountMap[S->getBody()] = BodyCount;
584e5dd7070Spatrick Visit(S->getBody());
585e5dd7070Spatrick uint64_t BackedgeCount = CurrentCount;
586e5dd7070Spatrick BreakContinue BC = BreakContinueStack.pop_back_val();
587e5dd7070Spatrick
588e5dd7070Spatrick // The increment is essentially part of the body but it needs to include
589e5dd7070Spatrick // the count for all the continue statements.
590e5dd7070Spatrick uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
591e5dd7070Spatrick CountMap[S->getInc()] = IncCount;
592e5dd7070Spatrick Visit(S->getInc());
593e5dd7070Spatrick
594e5dd7070Spatrick // ...then go back and propagate counts through the condition.
595e5dd7070Spatrick uint64_t CondCount =
596e5dd7070Spatrick setCount(ParentCount + BackedgeCount + BC.ContinueCount);
597e5dd7070Spatrick CountMap[S->getCond()] = CondCount;
598e5dd7070Spatrick Visit(S->getCond());
599e5dd7070Spatrick setCount(BC.BreakCount + CondCount - BodyCount);
600e5dd7070Spatrick RecordNextStmtCount = true;
601e5dd7070Spatrick }
602e5dd7070Spatrick
VisitObjCForCollectionStmt__anon393d18f00111::ComputeRegionCounts603e5dd7070Spatrick void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
604e5dd7070Spatrick RecordStmtCount(S);
605e5dd7070Spatrick Visit(S->getElement());
606e5dd7070Spatrick uint64_t ParentCount = CurrentCount;
607e5dd7070Spatrick BreakContinueStack.push_back(BreakContinue());
608e5dd7070Spatrick // Counter tracks the body of the loop.
609e5dd7070Spatrick uint64_t BodyCount = setCount(PGO.getRegionCount(S));
610e5dd7070Spatrick CountMap[S->getBody()] = BodyCount;
611e5dd7070Spatrick Visit(S->getBody());
612e5dd7070Spatrick uint64_t BackedgeCount = CurrentCount;
613e5dd7070Spatrick BreakContinue BC = BreakContinueStack.pop_back_val();
614e5dd7070Spatrick
615e5dd7070Spatrick setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
616e5dd7070Spatrick BodyCount);
617e5dd7070Spatrick RecordNextStmtCount = true;
618e5dd7070Spatrick }
619e5dd7070Spatrick
VisitSwitchStmt__anon393d18f00111::ComputeRegionCounts620e5dd7070Spatrick void VisitSwitchStmt(const SwitchStmt *S) {
621e5dd7070Spatrick RecordStmtCount(S);
622e5dd7070Spatrick if (S->getInit())
623e5dd7070Spatrick Visit(S->getInit());
624e5dd7070Spatrick Visit(S->getCond());
625e5dd7070Spatrick CurrentCount = 0;
626e5dd7070Spatrick BreakContinueStack.push_back(BreakContinue());
627e5dd7070Spatrick Visit(S->getBody());
628e5dd7070Spatrick // If the switch is inside a loop, add the continue counts.
629e5dd7070Spatrick BreakContinue BC = BreakContinueStack.pop_back_val();
630e5dd7070Spatrick if (!BreakContinueStack.empty())
631e5dd7070Spatrick BreakContinueStack.back().ContinueCount += BC.ContinueCount;
632e5dd7070Spatrick // Counter tracks the exit block of the switch.
633e5dd7070Spatrick setCount(PGO.getRegionCount(S));
634e5dd7070Spatrick RecordNextStmtCount = true;
635e5dd7070Spatrick }
636e5dd7070Spatrick
VisitSwitchCase__anon393d18f00111::ComputeRegionCounts637e5dd7070Spatrick void VisitSwitchCase(const SwitchCase *S) {
638e5dd7070Spatrick RecordNextStmtCount = false;
639e5dd7070Spatrick // Counter for this particular case. This counts only jumps from the
640e5dd7070Spatrick // switch header and does not include fallthrough from the case before
641e5dd7070Spatrick // this one.
642e5dd7070Spatrick uint64_t CaseCount = PGO.getRegionCount(S);
643e5dd7070Spatrick setCount(CurrentCount + CaseCount);
644e5dd7070Spatrick // We need the count without fallthrough in the mapping, so it's more useful
645e5dd7070Spatrick // for branch probabilities.
646e5dd7070Spatrick CountMap[S] = CaseCount;
647e5dd7070Spatrick RecordNextStmtCount = true;
648e5dd7070Spatrick Visit(S->getSubStmt());
649e5dd7070Spatrick }
650e5dd7070Spatrick
VisitIfStmt__anon393d18f00111::ComputeRegionCounts651e5dd7070Spatrick void VisitIfStmt(const IfStmt *S) {
652e5dd7070Spatrick RecordStmtCount(S);
653*12c85518Srobert
654*12c85518Srobert if (S->isConsteval()) {
655*12c85518Srobert const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
656*12c85518Srobert if (Stm)
657*12c85518Srobert Visit(Stm);
658*12c85518Srobert return;
659*12c85518Srobert }
660*12c85518Srobert
661e5dd7070Spatrick uint64_t ParentCount = CurrentCount;
662e5dd7070Spatrick if (S->getInit())
663e5dd7070Spatrick Visit(S->getInit());
664e5dd7070Spatrick Visit(S->getCond());
665e5dd7070Spatrick
666e5dd7070Spatrick // Counter tracks the "then" part of an if statement. The count for
667e5dd7070Spatrick // the "else" part, if it exists, will be calculated from this counter.
668e5dd7070Spatrick uint64_t ThenCount = setCount(PGO.getRegionCount(S));
669e5dd7070Spatrick CountMap[S->getThen()] = ThenCount;
670e5dd7070Spatrick Visit(S->getThen());
671e5dd7070Spatrick uint64_t OutCount = CurrentCount;
672e5dd7070Spatrick
673e5dd7070Spatrick uint64_t ElseCount = ParentCount - ThenCount;
674e5dd7070Spatrick if (S->getElse()) {
675e5dd7070Spatrick setCount(ElseCount);
676e5dd7070Spatrick CountMap[S->getElse()] = ElseCount;
677e5dd7070Spatrick Visit(S->getElse());
678e5dd7070Spatrick OutCount += CurrentCount;
679e5dd7070Spatrick } else
680e5dd7070Spatrick OutCount += ElseCount;
681e5dd7070Spatrick setCount(OutCount);
682e5dd7070Spatrick RecordNextStmtCount = true;
683e5dd7070Spatrick }
684e5dd7070Spatrick
VisitCXXTryStmt__anon393d18f00111::ComputeRegionCounts685e5dd7070Spatrick void VisitCXXTryStmt(const CXXTryStmt *S) {
686e5dd7070Spatrick RecordStmtCount(S);
687e5dd7070Spatrick Visit(S->getTryBlock());
688e5dd7070Spatrick for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
689e5dd7070Spatrick Visit(S->getHandler(I));
690e5dd7070Spatrick // Counter tracks the continuation block of the try statement.
691e5dd7070Spatrick setCount(PGO.getRegionCount(S));
692e5dd7070Spatrick RecordNextStmtCount = true;
693e5dd7070Spatrick }
694e5dd7070Spatrick
VisitCXXCatchStmt__anon393d18f00111::ComputeRegionCounts695e5dd7070Spatrick void VisitCXXCatchStmt(const CXXCatchStmt *S) {
696e5dd7070Spatrick RecordNextStmtCount = false;
697e5dd7070Spatrick // Counter tracks the catch statement's handler block.
698e5dd7070Spatrick uint64_t CatchCount = setCount(PGO.getRegionCount(S));
699e5dd7070Spatrick CountMap[S] = CatchCount;
700e5dd7070Spatrick Visit(S->getHandlerBlock());
701e5dd7070Spatrick }
702e5dd7070Spatrick
VisitAbstractConditionalOperator__anon393d18f00111::ComputeRegionCounts703e5dd7070Spatrick void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
704e5dd7070Spatrick RecordStmtCount(E);
705e5dd7070Spatrick uint64_t ParentCount = CurrentCount;
706e5dd7070Spatrick Visit(E->getCond());
707e5dd7070Spatrick
708e5dd7070Spatrick // Counter tracks the "true" part of a conditional operator. The
709e5dd7070Spatrick // count in the "false" part will be calculated from this counter.
710e5dd7070Spatrick uint64_t TrueCount = setCount(PGO.getRegionCount(E));
711e5dd7070Spatrick CountMap[E->getTrueExpr()] = TrueCount;
712e5dd7070Spatrick Visit(E->getTrueExpr());
713e5dd7070Spatrick uint64_t OutCount = CurrentCount;
714e5dd7070Spatrick
715e5dd7070Spatrick uint64_t FalseCount = setCount(ParentCount - TrueCount);
716e5dd7070Spatrick CountMap[E->getFalseExpr()] = FalseCount;
717e5dd7070Spatrick Visit(E->getFalseExpr());
718e5dd7070Spatrick OutCount += CurrentCount;
719e5dd7070Spatrick
720e5dd7070Spatrick setCount(OutCount);
721e5dd7070Spatrick RecordNextStmtCount = true;
722e5dd7070Spatrick }
723e5dd7070Spatrick
VisitBinLAnd__anon393d18f00111::ComputeRegionCounts724e5dd7070Spatrick void VisitBinLAnd(const BinaryOperator *E) {
725e5dd7070Spatrick RecordStmtCount(E);
726e5dd7070Spatrick uint64_t ParentCount = CurrentCount;
727e5dd7070Spatrick Visit(E->getLHS());
728e5dd7070Spatrick // Counter tracks the right hand side of a logical and operator.
729e5dd7070Spatrick uint64_t RHSCount = setCount(PGO.getRegionCount(E));
730e5dd7070Spatrick CountMap[E->getRHS()] = RHSCount;
731e5dd7070Spatrick Visit(E->getRHS());
732e5dd7070Spatrick setCount(ParentCount + RHSCount - CurrentCount);
733e5dd7070Spatrick RecordNextStmtCount = true;
734e5dd7070Spatrick }
735e5dd7070Spatrick
VisitBinLOr__anon393d18f00111::ComputeRegionCounts736e5dd7070Spatrick void VisitBinLOr(const BinaryOperator *E) {
737e5dd7070Spatrick RecordStmtCount(E);
738e5dd7070Spatrick uint64_t ParentCount = CurrentCount;
739e5dd7070Spatrick Visit(E->getLHS());
740e5dd7070Spatrick // Counter tracks the right hand side of a logical or operator.
741e5dd7070Spatrick uint64_t RHSCount = setCount(PGO.getRegionCount(E));
742e5dd7070Spatrick CountMap[E->getRHS()] = RHSCount;
743e5dd7070Spatrick Visit(E->getRHS());
744e5dd7070Spatrick setCount(ParentCount + RHSCount - CurrentCount);
745e5dd7070Spatrick RecordNextStmtCount = true;
746e5dd7070Spatrick }
747e5dd7070Spatrick };
748e5dd7070Spatrick } // end anonymous namespace
749e5dd7070Spatrick
combine(HashType Type)750e5dd7070Spatrick void PGOHash::combine(HashType Type) {
751e5dd7070Spatrick // Check that we never combine 0 and only have six bits.
752e5dd7070Spatrick assert(Type && "Hash is invalid: unexpected type 0");
753e5dd7070Spatrick assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
754e5dd7070Spatrick
755e5dd7070Spatrick // Pass through MD5 if enough work has built up.
756e5dd7070Spatrick if (Count && Count % NumTypesPerWord == 0) {
757e5dd7070Spatrick using namespace llvm::support;
758e5dd7070Spatrick uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
759*12c85518Srobert MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
760e5dd7070Spatrick Working = 0;
761e5dd7070Spatrick }
762e5dd7070Spatrick
763e5dd7070Spatrick // Accumulate the current type.
764e5dd7070Spatrick ++Count;
765e5dd7070Spatrick Working = Working << NumBitsPerType | Type;
766e5dd7070Spatrick }
767e5dd7070Spatrick
finalize()768e5dd7070Spatrick uint64_t PGOHash::finalize() {
769e5dd7070Spatrick // Use Working as the hash directly if we never used MD5.
770e5dd7070Spatrick if (Count <= NumTypesPerWord)
771e5dd7070Spatrick // No need to byte swap here, since none of the math was endian-dependent.
772e5dd7070Spatrick // This number will be byte-swapped as required on endianness transitions,
773e5dd7070Spatrick // so we will see the same value on the other side.
774e5dd7070Spatrick return Working;
775e5dd7070Spatrick
776e5dd7070Spatrick // Check for remaining work in Working.
777ec727ea7Spatrick if (Working) {
778ec727ea7Spatrick // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
779ec727ea7Spatrick // is buggy because it converts a uint64_t into an array of uint8_t.
780ec727ea7Spatrick if (HashVersion < PGO_HASH_V3) {
781ec727ea7Spatrick MD5.update({(uint8_t)Working});
782ec727ea7Spatrick } else {
783ec727ea7Spatrick using namespace llvm::support;
784ec727ea7Spatrick uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
785*12c85518Srobert MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
786ec727ea7Spatrick }
787ec727ea7Spatrick }
788e5dd7070Spatrick
789e5dd7070Spatrick // Finalize the MD5 and return the hash.
790e5dd7070Spatrick llvm::MD5::MD5Result Result;
791e5dd7070Spatrick MD5.final(Result);
792e5dd7070Spatrick return Result.low();
793e5dd7070Spatrick }
794e5dd7070Spatrick
assignRegionCounters(GlobalDecl GD,llvm::Function * Fn)795e5dd7070Spatrick void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
796e5dd7070Spatrick const Decl *D = GD.getDecl();
797e5dd7070Spatrick if (!D->hasBody())
798e5dd7070Spatrick return;
799e5dd7070Spatrick
800a9ac8606Spatrick // Skip CUDA/HIP kernel launch stub functions.
801a9ac8606Spatrick if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
802a9ac8606Spatrick D->hasAttr<CUDAGlobalAttr>())
803a9ac8606Spatrick return;
804a9ac8606Spatrick
805e5dd7070Spatrick bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
806e5dd7070Spatrick llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
807e5dd7070Spatrick if (!InstrumentRegions && !PGOReader)
808e5dd7070Spatrick return;
809e5dd7070Spatrick if (D->isImplicit())
810e5dd7070Spatrick return;
811e5dd7070Spatrick // Constructors and destructors may be represented by several functions in IR.
812e5dd7070Spatrick // If so, instrument only base variant, others are implemented by delegation
813e5dd7070Spatrick // to the base one, it would be counted twice otherwise.
814e5dd7070Spatrick if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
815e5dd7070Spatrick if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
816e5dd7070Spatrick if (GD.getCtorType() != Ctor_Base &&
817e5dd7070Spatrick CodeGenFunction::IsConstructorDelegationValid(CCD))
818e5dd7070Spatrick return;
819e5dd7070Spatrick }
820e5dd7070Spatrick if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
821e5dd7070Spatrick return;
822e5dd7070Spatrick
823e5dd7070Spatrick CGM.ClearUnusedCoverageMapping(D);
824a9ac8606Spatrick if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
825a9ac8606Spatrick return;
826*12c85518Srobert if (Fn->hasFnAttribute(llvm::Attribute::SkipProfile))
827*12c85518Srobert return;
828a9ac8606Spatrick
829e5dd7070Spatrick setFuncName(Fn);
830e5dd7070Spatrick
831e5dd7070Spatrick mapRegionCounters(D);
832e5dd7070Spatrick if (CGM.getCodeGenOpts().CoverageMapping)
833e5dd7070Spatrick emitCounterRegionMapping(D);
834e5dd7070Spatrick if (PGOReader) {
835e5dd7070Spatrick SourceManager &SM = CGM.getContext().getSourceManager();
836e5dd7070Spatrick loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
837e5dd7070Spatrick computeRegionCounts(D);
838e5dd7070Spatrick applyFunctionAttributes(PGOReader, Fn);
839e5dd7070Spatrick }
840e5dd7070Spatrick }
841e5dd7070Spatrick
mapRegionCounters(const Decl * D)842e5dd7070Spatrick void CodeGenPGO::mapRegionCounters(const Decl *D) {
843e5dd7070Spatrick // Use the latest hash version when inserting instrumentation, but use the
844e5dd7070Spatrick // version in the indexed profile if we're reading PGO data.
845e5dd7070Spatrick PGOHashVersion HashVersion = PGO_HASH_LATEST;
846a9ac8606Spatrick uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
847a9ac8606Spatrick if (auto *PGOReader = CGM.getPGOReader()) {
848e5dd7070Spatrick HashVersion = getPGOHashVersion(PGOReader, CGM);
849a9ac8606Spatrick ProfileVersion = PGOReader->getVersion();
850a9ac8606Spatrick }
851e5dd7070Spatrick
852e5dd7070Spatrick RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
853a9ac8606Spatrick MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap);
854e5dd7070Spatrick if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
855e5dd7070Spatrick Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
856e5dd7070Spatrick else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
857e5dd7070Spatrick Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
858e5dd7070Spatrick else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
859e5dd7070Spatrick Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
860e5dd7070Spatrick else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
861e5dd7070Spatrick Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
862e5dd7070Spatrick assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
863e5dd7070Spatrick NumRegionCounters = Walker.NextCounter;
864e5dd7070Spatrick FunctionHash = Walker.Hash.finalize();
865e5dd7070Spatrick }
866e5dd7070Spatrick
skipRegionMappingForDecl(const Decl * D)867e5dd7070Spatrick bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
868e5dd7070Spatrick if (!D->getBody())
869e5dd7070Spatrick return true;
870e5dd7070Spatrick
871a9ac8606Spatrick // Skip host-only functions in the CUDA device compilation and device-only
872a9ac8606Spatrick // functions in the host compilation. Just roughly filter them out based on
873a9ac8606Spatrick // the function attributes. If there are effectively host-only or device-only
874a9ac8606Spatrick // ones, their coverage mapping may still be generated.
875a9ac8606Spatrick if (CGM.getLangOpts().CUDA &&
876a9ac8606Spatrick ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
877a9ac8606Spatrick !D->hasAttr<CUDAGlobalAttr>()) ||
878a9ac8606Spatrick (!CGM.getLangOpts().CUDAIsDevice &&
879a9ac8606Spatrick (D->hasAttr<CUDAGlobalAttr>() ||
880a9ac8606Spatrick (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
881a9ac8606Spatrick return true;
882a9ac8606Spatrick
883e5dd7070Spatrick // Don't map the functions in system headers.
884e5dd7070Spatrick const auto &SM = CGM.getContext().getSourceManager();
885e5dd7070Spatrick auto Loc = D->getBody()->getBeginLoc();
886e5dd7070Spatrick return SM.isInSystemHeader(Loc);
887e5dd7070Spatrick }
888e5dd7070Spatrick
emitCounterRegionMapping(const Decl * D)889e5dd7070Spatrick void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
890e5dd7070Spatrick if (skipRegionMappingForDecl(D))
891e5dd7070Spatrick return;
892e5dd7070Spatrick
893e5dd7070Spatrick std::string CoverageMapping;
894e5dd7070Spatrick llvm::raw_string_ostream OS(CoverageMapping);
895e5dd7070Spatrick CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
896e5dd7070Spatrick CGM.getContext().getSourceManager(),
897e5dd7070Spatrick CGM.getLangOpts(), RegionCounterMap.get());
898e5dd7070Spatrick MappingGen.emitCounterMapping(D, OS);
899e5dd7070Spatrick OS.flush();
900e5dd7070Spatrick
901e5dd7070Spatrick if (CoverageMapping.empty())
902e5dd7070Spatrick return;
903e5dd7070Spatrick
904e5dd7070Spatrick CGM.getCoverageMapping()->addFunctionMappingRecord(
905e5dd7070Spatrick FuncNameVar, FuncName, FunctionHash, CoverageMapping);
906e5dd7070Spatrick }
907e5dd7070Spatrick
908e5dd7070Spatrick void
emitEmptyCounterMapping(const Decl * D,StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)909e5dd7070Spatrick CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
910e5dd7070Spatrick llvm::GlobalValue::LinkageTypes Linkage) {
911e5dd7070Spatrick if (skipRegionMappingForDecl(D))
912e5dd7070Spatrick return;
913e5dd7070Spatrick
914e5dd7070Spatrick std::string CoverageMapping;
915e5dd7070Spatrick llvm::raw_string_ostream OS(CoverageMapping);
916e5dd7070Spatrick CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
917e5dd7070Spatrick CGM.getContext().getSourceManager(),
918e5dd7070Spatrick CGM.getLangOpts());
919e5dd7070Spatrick MappingGen.emitEmptyMapping(D, OS);
920e5dd7070Spatrick OS.flush();
921e5dd7070Spatrick
922e5dd7070Spatrick if (CoverageMapping.empty())
923e5dd7070Spatrick return;
924e5dd7070Spatrick
925e5dd7070Spatrick setFuncName(Name, Linkage);
926e5dd7070Spatrick CGM.getCoverageMapping()->addFunctionMappingRecord(
927e5dd7070Spatrick FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
928e5dd7070Spatrick }
929e5dd7070Spatrick
computeRegionCounts(const Decl * D)930e5dd7070Spatrick void CodeGenPGO::computeRegionCounts(const Decl *D) {
931e5dd7070Spatrick StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
932e5dd7070Spatrick ComputeRegionCounts Walker(*StmtCountMap, *this);
933e5dd7070Spatrick if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
934e5dd7070Spatrick Walker.VisitFunctionDecl(FD);
935e5dd7070Spatrick else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
936e5dd7070Spatrick Walker.VisitObjCMethodDecl(MD);
937e5dd7070Spatrick else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
938e5dd7070Spatrick Walker.VisitBlockDecl(BD);
939e5dd7070Spatrick else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
940e5dd7070Spatrick Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
941e5dd7070Spatrick }
942e5dd7070Spatrick
943e5dd7070Spatrick void
applyFunctionAttributes(llvm::IndexedInstrProfReader * PGOReader,llvm::Function * Fn)944e5dd7070Spatrick CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
945e5dd7070Spatrick llvm::Function *Fn) {
946e5dd7070Spatrick if (!haveRegionCounts())
947e5dd7070Spatrick return;
948e5dd7070Spatrick
949e5dd7070Spatrick uint64_t FunctionCount = getRegionCount(nullptr);
950e5dd7070Spatrick Fn->setEntryCount(FunctionCount);
951e5dd7070Spatrick }
952e5dd7070Spatrick
emitCounterIncrement(CGBuilderTy & Builder,const Stmt * S,llvm::Value * StepV)953e5dd7070Spatrick void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
954e5dd7070Spatrick llvm::Value *StepV) {
955e5dd7070Spatrick if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
956e5dd7070Spatrick return;
957e5dd7070Spatrick if (!Builder.GetInsertBlock())
958e5dd7070Spatrick return;
959e5dd7070Spatrick
960e5dd7070Spatrick unsigned Counter = (*RegionCounterMap)[S];
961e5dd7070Spatrick auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
962e5dd7070Spatrick
963e5dd7070Spatrick llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
964e5dd7070Spatrick Builder.getInt64(FunctionHash),
965e5dd7070Spatrick Builder.getInt32(NumRegionCounters),
966e5dd7070Spatrick Builder.getInt32(Counter), StepV};
967e5dd7070Spatrick if (!StepV)
968e5dd7070Spatrick Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
969*12c85518Srobert ArrayRef(Args, 4));
970e5dd7070Spatrick else
971e5dd7070Spatrick Builder.CreateCall(
972e5dd7070Spatrick CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
973*12c85518Srobert ArrayRef(Args));
974e5dd7070Spatrick }
975e5dd7070Spatrick
setValueProfilingFlag(llvm::Module & M)976a9ac8606Spatrick void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
977a9ac8606Spatrick if (CGM.getCodeGenOpts().hasProfileClangInstr())
978a9ac8606Spatrick M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
979a9ac8606Spatrick uint32_t(EnableValueProfiling));
980a9ac8606Spatrick }
981a9ac8606Spatrick
982e5dd7070Spatrick // This method either inserts a call to the profile run-time during
983e5dd7070Spatrick // instrumentation or puts profile data into metadata for PGO use.
valueProfile(CGBuilderTy & Builder,uint32_t ValueKind,llvm::Instruction * ValueSite,llvm::Value * ValuePtr)984e5dd7070Spatrick void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
985e5dd7070Spatrick llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
986e5dd7070Spatrick
987e5dd7070Spatrick if (!EnableValueProfiling)
988e5dd7070Spatrick return;
989e5dd7070Spatrick
990e5dd7070Spatrick if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
991e5dd7070Spatrick return;
992e5dd7070Spatrick
993e5dd7070Spatrick if (isa<llvm::Constant>(ValuePtr))
994e5dd7070Spatrick return;
995e5dd7070Spatrick
996e5dd7070Spatrick bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
997e5dd7070Spatrick if (InstrumentValueSites && RegionCounterMap) {
998e5dd7070Spatrick auto BuilderInsertPoint = Builder.saveIP();
999e5dd7070Spatrick Builder.SetInsertPoint(ValueSite);
1000e5dd7070Spatrick llvm::Value *Args[5] = {
1001e5dd7070Spatrick llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
1002e5dd7070Spatrick Builder.getInt64(FunctionHash),
1003e5dd7070Spatrick Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
1004e5dd7070Spatrick Builder.getInt32(ValueKind),
1005e5dd7070Spatrick Builder.getInt32(NumValueSites[ValueKind]++)
1006e5dd7070Spatrick };
1007e5dd7070Spatrick Builder.CreateCall(
1008e5dd7070Spatrick CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
1009e5dd7070Spatrick Builder.restoreIP(BuilderInsertPoint);
1010e5dd7070Spatrick return;
1011e5dd7070Spatrick }
1012e5dd7070Spatrick
1013e5dd7070Spatrick llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1014e5dd7070Spatrick if (PGOReader && haveRegionCounts()) {
1015e5dd7070Spatrick // We record the top most called three functions at each call site.
1016e5dd7070Spatrick // Profile metadata contains "VP" string identifying this metadata
1017e5dd7070Spatrick // as value profiling data, then a uint32_t value for the value profiling
1018e5dd7070Spatrick // kind, a uint64_t value for the total number of times the call is
1019e5dd7070Spatrick // executed, followed by the function hash and execution count (uint64_t)
1020e5dd7070Spatrick // pairs for each function.
1021e5dd7070Spatrick if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
1022e5dd7070Spatrick return;
1023e5dd7070Spatrick
1024e5dd7070Spatrick llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
1025e5dd7070Spatrick (llvm::InstrProfValueKind)ValueKind,
1026e5dd7070Spatrick NumValueSites[ValueKind]);
1027e5dd7070Spatrick
1028e5dd7070Spatrick NumValueSites[ValueKind]++;
1029e5dd7070Spatrick }
1030e5dd7070Spatrick }
1031e5dd7070Spatrick
loadRegionCounts(llvm::IndexedInstrProfReader * PGOReader,bool IsInMainFile)1032e5dd7070Spatrick void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1033e5dd7070Spatrick bool IsInMainFile) {
1034e5dd7070Spatrick CGM.getPGOStats().addVisited(IsInMainFile);
1035e5dd7070Spatrick RegionCounts.clear();
1036e5dd7070Spatrick llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1037e5dd7070Spatrick PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1038e5dd7070Spatrick if (auto E = RecordExpected.takeError()) {
1039e5dd7070Spatrick auto IPE = llvm::InstrProfError::take(std::move(E));
1040e5dd7070Spatrick if (IPE == llvm::instrprof_error::unknown_function)
1041e5dd7070Spatrick CGM.getPGOStats().addMissing(IsInMainFile);
1042e5dd7070Spatrick else if (IPE == llvm::instrprof_error::hash_mismatch)
1043e5dd7070Spatrick CGM.getPGOStats().addMismatched(IsInMainFile);
1044e5dd7070Spatrick else if (IPE == llvm::instrprof_error::malformed)
1045e5dd7070Spatrick // TODO: Consider a more specific warning for this case.
1046e5dd7070Spatrick CGM.getPGOStats().addMismatched(IsInMainFile);
1047e5dd7070Spatrick return;
1048e5dd7070Spatrick }
1049e5dd7070Spatrick ProfRecord =
1050e5dd7070Spatrick std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1051e5dd7070Spatrick RegionCounts = ProfRecord->Counts;
1052e5dd7070Spatrick }
1053e5dd7070Spatrick
/// Calculate what to divide by to scale weights.
///
/// Given the maximum weight, return a divisor such that dividing any weight
/// no greater than \p MaxWeight by it yields a value strictly less than
/// UINT32_MAX. The divisor is 1 when \p MaxWeight is already below
/// UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
1061e5dd7070Spatrick
/// Scale an individual branch weight (and add 1).
///
/// Divides the 64-bit \p Weight by \p Scale to bring it into 32-bit range.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so 1 is universally added to the scaled
/// value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// greater than \c Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Quotient = Weight / Scale;
  const uint64_t Scaled = Quotient + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Scaled);
}
1077e5dd7070Spatrick
createProfileWeights(uint64_t TrueCount,uint64_t FalseCount) const1078e5dd7070Spatrick llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1079a9ac8606Spatrick uint64_t FalseCount) const {
1080e5dd7070Spatrick // Check for empty weights.
1081e5dd7070Spatrick if (!TrueCount && !FalseCount)
1082e5dd7070Spatrick return nullptr;
1083e5dd7070Spatrick
1084e5dd7070Spatrick // Calculate how to scale down to 32-bits.
1085e5dd7070Spatrick uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1086e5dd7070Spatrick
1087e5dd7070Spatrick llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1088e5dd7070Spatrick return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1089e5dd7070Spatrick scaleBranchWeight(FalseCount, Scale));
1090e5dd7070Spatrick }
1091e5dd7070Spatrick
1092e5dd7070Spatrick llvm::MDNode *
createProfileWeights(ArrayRef<uint64_t> Weights) const1093a9ac8606Spatrick CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1094e5dd7070Spatrick // We need at least two elements to create meaningful weights.
1095e5dd7070Spatrick if (Weights.size() < 2)
1096e5dd7070Spatrick return nullptr;
1097e5dd7070Spatrick
1098e5dd7070Spatrick // Check for empty weights.
1099e5dd7070Spatrick uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1100e5dd7070Spatrick if (MaxWeight == 0)
1101e5dd7070Spatrick return nullptr;
1102e5dd7070Spatrick
1103e5dd7070Spatrick // Calculate how to scale down to 32-bits.
1104e5dd7070Spatrick uint64_t Scale = calculateWeightScale(MaxWeight);
1105e5dd7070Spatrick
1106e5dd7070Spatrick SmallVector<uint32_t, 16> ScaledWeights;
1107e5dd7070Spatrick ScaledWeights.reserve(Weights.size());
1108e5dd7070Spatrick for (uint64_t W : Weights)
1109e5dd7070Spatrick ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1110e5dd7070Spatrick
1111e5dd7070Spatrick llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1112e5dd7070Spatrick return MDHelper.createBranchWeights(ScaledWeights);
1113e5dd7070Spatrick }
1114e5dd7070Spatrick
1115a9ac8606Spatrick llvm::MDNode *
createProfileWeightsForLoop(const Stmt * Cond,uint64_t LoopCount) const1116a9ac8606Spatrick CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1117a9ac8606Spatrick uint64_t LoopCount) const {
1118e5dd7070Spatrick if (!PGO.haveRegionCounts())
1119e5dd7070Spatrick return nullptr;
1120*12c85518Srobert std::optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1121ec727ea7Spatrick if (!CondCount || *CondCount == 0)
1122e5dd7070Spatrick return nullptr;
1123e5dd7070Spatrick return createProfileWeights(LoopCount,
1124e5dd7070Spatrick std::max(*CondCount, LoopCount) - LoopCount);
1125e5dd7070Spatrick }
1126