xref: /llvm-project/llvm/lib/Transforms/Instrumentation/PGOCtxProfFlattening.cpp (revision 783bac7ffb8f0d58d7381d90fcaa082eb0be1c1d)
1 //===- PGOCtxProfFlattening.cpp - Contextual Instr. Flattening ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Flattens the contextual profile and lowers it to MD_prof.
10 // This should happen after all IPO (which is assumed to have maintained the
11 // contextual profile) happened. Flattening consists of summing the values at
12 // the same index of the counters belonging to all the contexts of a function.
13 // The lowering consists of materializing the counter values to function
14 // entrypoint counts and branch probabilities.
15 //
16 // This pass also removes contextual instrumentation, which has been kept around
17 // to facilitate its functionality.
18 //
19 //===----------------------------------------------------------------------===//
20 
21 #include "llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/ScopeExit.h"
24 #include "llvm/Analysis/CtxProfAnalysis.h"
25 #include "llvm/Analysis/ProfileSummaryInfo.h"
26 #include "llvm/IR/Analysis.h"
27 #include "llvm/IR/CFG.h"
28 #include "llvm/IR/Dominators.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/IntrinsicInst.h"
31 #include "llvm/IR/Module.h"
32 #include "llvm/IR/PassManager.h"
33 #include "llvm/IR/ProfileSummary.h"
34 #include "llvm/ProfileData/ProfileCommon.h"
35 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
36 #include "llvm/Transforms/Scalar/DCE.h"
37 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
38 #include <deque>
39 
40 using namespace llvm;
41 
42 namespace {
43 
44 class ProfileAnnotator final {
45   class BBInfo;
46   struct EdgeInfo {
47     BBInfo *const Src;
48     BBInfo *const Dest;
49     std::optional<uint64_t> Count;
50 
51     explicit EdgeInfo(BBInfo &Src, BBInfo &Dest) : Src(&Src), Dest(&Dest) {}
52   };
53 
54   class BBInfo {
55     std::optional<uint64_t> Count;
56     // OutEdges is dimensioned to match the number of terminator operands.
57     // Entries in the vector match the index in the terminator operand list. In
58     // some cases - see `shouldExcludeEdge` and its implementation - an entry
59     // will be nullptr.
60     // InEdges doesn't have the above constraint.
61     SmallVector<EdgeInfo *> OutEdges;
62     SmallVector<EdgeInfo *> InEdges;
63     size_t UnknownCountOutEdges = 0;
64     size_t UnknownCountInEdges = 0;
65 
66     // Pass AssumeAllKnown when we try to propagate counts from edges to BBs -
67     // because all the edge counters must be known.
68     // Return std::nullopt if there were no edges to sum. The user can decide
69     // how to interpret that.
70     std::optional<uint64_t> getEdgeSum(const SmallVector<EdgeInfo *> &Edges,
71                                        bool AssumeAllKnown) const {
72       std::optional<uint64_t> Sum;
73       for (const auto *E : Edges) {
74         // `Edges` may be `OutEdges`, case in which `E` could be nullptr.
75         if (E) {
76           if (!Sum.has_value())
77             Sum = 0;
78           *Sum += (AssumeAllKnown ? *E->Count : E->Count.value_or(0U));
79         }
80       }
81       return Sum;
82     }
83 
84     bool computeCountFrom(const SmallVector<EdgeInfo *> &Edges) {
85       assert(!Count.has_value());
86       Count = getEdgeSum(Edges, true);
87       return Count.has_value();
88     }
89 
90     void setSingleUnknownEdgeCount(SmallVector<EdgeInfo *> &Edges) {
91       uint64_t KnownSum = getEdgeSum(Edges, false).value_or(0U);
92       uint64_t EdgeVal = *Count > KnownSum ? *Count - KnownSum : 0U;
93       EdgeInfo *E = nullptr;
94       for (auto *I : Edges)
95         if (I && !I->Count.has_value()) {
96           E = I;
97 #ifdef NDEBUG
98           break;
99 #else
100           assert((!E || E == I) &&
101                  "Expected exactly one edge to have an unknown count, "
102                  "found a second one");
103           continue;
104 #endif
105         }
106       assert(E && "Expected exactly one edge to have an unknown count");
107       assert(!E->Count.has_value());
108       E->Count = EdgeVal;
109       assert(E->Src->UnknownCountOutEdges > 0);
110       assert(E->Dest->UnknownCountInEdges > 0);
111       --E->Src->UnknownCountOutEdges;
112       --E->Dest->UnknownCountInEdges;
113     }
114 
115   public:
116     BBInfo(size_t NumInEdges, size_t NumOutEdges, std::optional<uint64_t> Count)
117         : Count(Count) {
118       // For in edges, we just want to pre-allocate enough space, since we know
119       // it at this stage. For out edges, we will insert edges at the indices
120       // corresponding to positions in this BB's terminator instruction, so we
121       // construct a default (nullptr values)-initialized vector. A nullptr edge
122       // corresponds to those that are excluded (see shouldExcludeEdge).
123       InEdges.reserve(NumInEdges);
124       OutEdges.resize(NumOutEdges);
125     }
126 
127     bool tryTakeCountFromKnownOutEdges(const BasicBlock &BB) {
128       if (!UnknownCountOutEdges) {
129         return computeCountFrom(OutEdges);
130       }
131       return false;
132     }
133 
134     bool tryTakeCountFromKnownInEdges(const BasicBlock &BB) {
135       if (!UnknownCountInEdges) {
136         return computeCountFrom(InEdges);
137       }
138       return false;
139     }
140 
141     void addInEdge(EdgeInfo &Info) {
142       InEdges.push_back(&Info);
143       ++UnknownCountInEdges;
144     }
145 
146     // For the out edges, we care about the position we place them in, which is
147     // the position in terminator instruction's list (at construction). Later,
148     // we build branch_weights metadata with edge frequency values matching
149     // these positions.
150     void addOutEdge(size_t Index, EdgeInfo &Info) {
151       OutEdges[Index] = &Info;
152       ++UnknownCountOutEdges;
153     }
154 
155     bool hasCount() const { return Count.has_value(); }
156 
157     uint64_t getCount() const { return *Count; }
158 
159     bool trySetSingleUnknownInEdgeCount() {
160       if (UnknownCountInEdges == 1) {
161         setSingleUnknownEdgeCount(InEdges);
162         return true;
163       }
164       return false;
165     }
166 
167     bool trySetSingleUnknownOutEdgeCount() {
168       if (UnknownCountOutEdges == 1) {
169         setSingleUnknownEdgeCount(OutEdges);
170         return true;
171       }
172       return false;
173     }
174     size_t getNumOutEdges() const { return OutEdges.size(); }
175 
176     uint64_t getEdgeCount(size_t Index) const {
177       if (auto *E = OutEdges[Index])
178         return *E->Count;
179       return 0U;
180     }
181   };
182 
183   Function &F;
184   const SmallVectorImpl<uint64_t> &Counters;
185   // To be accessed through getBBInfo() after construction.
186   std::map<const BasicBlock *, BBInfo> BBInfos;
187   std::vector<EdgeInfo> EdgeInfos;
188   InstrProfSummaryBuilder &PB;
189 
190   // This is an adaptation of PGOUseFunc::populateCounters.
191   // FIXME(mtrofin): look into factoring the code to share one implementation.
192   void propagateCounterValues(const SmallVectorImpl<uint64_t> &Counters) {
193     bool KeepGoing = true;
194     while (KeepGoing) {
195       KeepGoing = false;
196       for (const auto &BB : F) {
197         auto &Info = getBBInfo(BB);
198         if (!Info.hasCount())
199           KeepGoing |= Info.tryTakeCountFromKnownOutEdges(BB) ||
200                        Info.tryTakeCountFromKnownInEdges(BB);
201         if (Info.hasCount()) {
202           KeepGoing |= Info.trySetSingleUnknownOutEdgeCount();
203           KeepGoing |= Info.trySetSingleUnknownInEdgeCount();
204         }
205       }
206     }
207   }
208   // The only criteria for exclusion is faux suspend -> exit edges in presplit
209   // coroutines. The API serves for readability, currently.
210   bool shouldExcludeEdge(const BasicBlock &Src, const BasicBlock &Dest) const {
211     return llvm::isPresplitCoroSuspendExitEdge(Src, Dest);
212   }
213 
214   BBInfo &getBBInfo(const BasicBlock &BB) { return BBInfos.find(&BB)->second; }
215 
216   const BBInfo &getBBInfo(const BasicBlock &BB) const {
217     return BBInfos.find(&BB)->second;
218   }
219 
220   // validation function after we propagate the counters: all BBs and edges'
221   // counters must have a value.
222   bool allCountersAreAssigned() const {
223     for (const auto &BBInfo : BBInfos)
224       if (!BBInfo.second.hasCount())
225         return false;
226     for (const auto &EdgeInfo : EdgeInfos)
227       if (!EdgeInfo.Count.has_value())
228         return false;
229     return true;
230   }
231 
232   /// Check that all paths from the entry basic block that use edges with
233   /// non-zero counts arrive at a basic block with no successors (i.e. "exit")
234   bool allTakenPathsExit() const {
235     std::deque<const BasicBlock *> Worklist;
236     DenseSet<const BasicBlock *> Visited;
237     Worklist.push_back(&F.getEntryBlock());
238     bool HitExit = false;
239     while (!Worklist.empty()) {
240       const auto *BB = Worklist.front();
241       Worklist.pop_front();
242       if (!Visited.insert(BB).second)
243         continue;
244       if (succ_size(BB) == 0) {
245         if (isa<UnreachableInst>(BB->getTerminator()))
246           return false;
247         HitExit = true;
248         continue;
249       }
250       if (succ_size(BB) == 1) {
251         Worklist.push_back(BB->getUniqueSuccessor());
252         continue;
253       }
254       const auto &BBInfo = getBBInfo(*BB);
255       bool HasAWayOut = false;
256       for (auto I = 0U; I < BB->getTerminator()->getNumSuccessors(); ++I) {
257         const auto *Succ = BB->getTerminator()->getSuccessor(I);
258         if (!shouldExcludeEdge(*BB, *Succ)) {
259           if (BBInfo.getEdgeCount(I) > 0) {
260             HasAWayOut = true;
261             Worklist.push_back(Succ);
262           }
263         }
264       }
265       if (!HasAWayOut)
266         return false;
267     }
268     return HitExit;
269   }
270 
271   bool allNonColdSelectsHaveProfile() const {
272     for (const auto &BB : F) {
273       if (getBBInfo(BB).getCount() > 0) {
274         for (const auto &I : BB) {
275           if (const auto *SI = dyn_cast<SelectInst>(&I)) {
276             if (!SI->getMetadata(LLVMContext::MD_prof)) {
277               return false;
278             }
279           }
280         }
281       }
282     }
283     return true;
284   }
285 
286 public:
287   ProfileAnnotator(Function &F, const SmallVectorImpl<uint64_t> &Counters,
288                    InstrProfSummaryBuilder &PB)
289       : F(F), Counters(Counters), PB(PB) {
290     assert(!F.isDeclaration());
291     assert(!Counters.empty());
292     size_t NrEdges = 0;
293     for (const auto &BB : F) {
294       std::optional<uint64_t> Count;
295       if (auto *Ins = CtxProfAnalysis::getBBInstrumentation(
296               const_cast<BasicBlock &>(BB))) {
297         auto Index = Ins->getIndex()->getZExtValue();
298         assert(Index < Counters.size() &&
299                "The index must be inside the counters vector by construction - "
300                "tripping this assertion indicates a bug in how the contextual "
301                "profile is managed by IPO transforms");
302         (void)Index;
303         Count = Counters[Ins->getIndex()->getZExtValue()];
304       } else if (isa<UnreachableInst>(BB.getTerminator())) {
305         // The program presumably didn't crash.
306         Count = 0;
307       }
308       auto [It, Ins] =
309           BBInfos.insert({&BB, {pred_size(&BB), succ_size(&BB), Count}});
310       (void)Ins;
311       assert(Ins && "We iterate through the function's BBs, no reason to "
312                     "insert one more than once");
313       NrEdges += llvm::count_if(successors(&BB), [&](const auto *Succ) {
314         return !shouldExcludeEdge(BB, *Succ);
315       });
316     }
317     // Pre-allocate the vector, we want references to its contents to be stable.
318     EdgeInfos.reserve(NrEdges);
319     for (const auto &BB : F) {
320       auto &Info = getBBInfo(BB);
321       for (auto I = 0U; I < BB.getTerminator()->getNumSuccessors(); ++I) {
322         const auto *Succ = BB.getTerminator()->getSuccessor(I);
323         if (!shouldExcludeEdge(BB, *Succ)) {
324           auto &EI = EdgeInfos.emplace_back(getBBInfo(BB), getBBInfo(*Succ));
325           Info.addOutEdge(I, EI);
326           getBBInfo(*Succ).addInEdge(EI);
327         }
328       }
329     }
330     assert(EdgeInfos.capacity() == NrEdges &&
331            "The capacity of EdgeInfos should have stayed unchanged it was "
332            "populated, because we need pointers to its contents to be stable");
333   }
334 
335   void setProfileForSelectInstructions(BasicBlock &BB, const BBInfo &BBInfo) {
336     if (BBInfo.getCount() == 0)
337       return;
338 
339     for (auto &I : BB) {
340       if (auto *SI = dyn_cast<SelectInst>(&I)) {
341         if (auto *Step = CtxProfAnalysis::getSelectInstrumentation(*SI)) {
342           auto Index = Step->getIndex()->getZExtValue();
343           assert(Index < Counters.size() &&
344                  "The index of the step instruction must be inside the "
345                  "counters vector by "
346                  "construction - tripping this assertion indicates a bug in "
347                  "how the contextual profile is managed by IPO transforms");
348           auto TotalCount = BBInfo.getCount();
349           auto TrueCount = Counters[Index];
350           auto FalseCount =
351               (TotalCount > TrueCount ? TotalCount - TrueCount : 0U);
352           setProfMetadata(F.getParent(), SI, {TrueCount, FalseCount},
353                           std::max(TrueCount, FalseCount));
354           PB.addInternalCount(TrueCount);
355           PB.addInternalCount(FalseCount);
356         }
357       }
358     }
359   }
360 
361   /// Assign branch weights and function entry count. Also update the PSI
362   /// builder.
363   void assignProfileData() {
364     assert(!Counters.empty());
365     propagateCounterValues(Counters);
366     F.setEntryCount(Counters[0]);
367     PB.addEntryCount(Counters[0]);
368 
369     for (auto &BB : F) {
370       const auto &BBInfo = getBBInfo(BB);
371       setProfileForSelectInstructions(BB, BBInfo);
372       if (succ_size(&BB) < 2)
373         continue;
374       auto *Term = BB.getTerminator();
375       SmallVector<uint64_t, 2> EdgeCounts(Term->getNumSuccessors(), 0);
376       uint64_t MaxCount = 0;
377 
378       for (unsigned SuccIdx = 0, Size = BBInfo.getNumOutEdges(); SuccIdx < Size;
379            ++SuccIdx) {
380         uint64_t EdgeCount = BBInfo.getEdgeCount(SuccIdx);
381         if (EdgeCount > MaxCount)
382           MaxCount = EdgeCount;
383         EdgeCounts[SuccIdx] = EdgeCount;
384         PB.addInternalCount(EdgeCount);
385       }
386 
387       if (MaxCount != 0)
388         setProfMetadata(F.getParent(), Term, EdgeCounts, MaxCount);
389     }
390     assert(allCountersAreAssigned() &&
391            "[ctx-prof] Expected all counters have been assigned.");
392     assert(allTakenPathsExit() &&
393            "[ctx-prof] Encountered a BB with more than one successor, where "
394            "all outgoing edges have a 0 count. This occurs in non-exiting "
395            "functions (message pumps, usually) which are not supported in the "
396            "contextual profiling case");
397     assert(allNonColdSelectsHaveProfile() &&
398            "[ctx-prof] All non-cold select instructions were expected to have "
399            "a profile.");
400   }
401 };
402 
403 [[maybe_unused]] bool areAllBBsReachable(const Function &F,
404                                          FunctionAnalysisManager &FAM) {
405   auto &DT = FAM.getResult<DominatorTreeAnalysis>(const_cast<Function &>(F));
406   return llvm::all_of(
407       F, [&](const BasicBlock &BB) { return DT.isReachableFromEntry(&BB); });
408 }
409 
410 void clearColdFunctionProfile(Function &F) {
411   for (auto &BB : F)
412     BB.getTerminator()->setMetadata(LLVMContext::MD_prof, nullptr);
413   F.setEntryCount(0U);
414 }
415 
416 void removeInstrumentation(Function &F) {
417   for (auto &BB : F)
418     for (auto &I : llvm::make_early_inc_range(BB))
419       if (isa<InstrProfCntrInstBase>(I))
420         I.eraseFromParent();
421 }
422 
423 } // namespace
424 
425 PreservedAnalyses PGOCtxProfFlatteningPass::run(Module &M,
426                                                 ModuleAnalysisManager &MAM) {
427   // Ensure in all cases the instrumentation is removed: if this module had no
428   // roots, the contextual profile would evaluate to false, but there would
429   // still be instrumentation.
430   // Note: in such cases we leave as-is any other profile info (if present -
431   // e.g. synthetic weights, etc) because it wouldn't interfere with the
432   // contextual - based one (which would be in other modules)
433   auto OnExit = llvm::make_scope_exit([&]() {
434     for (auto &F : M)
435       removeInstrumentation(F);
436   });
437   auto &CtxProf = MAM.getResult<CtxProfAnalysis>(M);
438   if (!CtxProf)
439     return PreservedAnalyses::none();
440 
441   const auto FlattenedProfile = CtxProf.flatten();
442 
443   InstrProfSummaryBuilder PB(ProfileSummaryBuilder::DefaultCutoffs);
444   for (auto &F : M) {
445     if (F.isDeclaration())
446       continue;
447 
448     assert(areAllBBsReachable(
449                F, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M)
450                       .getManager()) &&
451            "Function has unreacheable basic blocks. The expectation was that "
452            "DCE was run before.");
453 
454     auto It = FlattenedProfile.find(AssignGUIDPass::getGUID(F));
455     // If this function didn't appear in the contextual profile, it's cold.
456     if (It == FlattenedProfile.end())
457       clearColdFunctionProfile(F);
458     else {
459       ProfileAnnotator S(F, It->second, PB);
460       S.assignProfileData();
461     }
462   }
463 
464   auto &PSI = MAM.getResult<ProfileSummaryAnalysis>(M);
465 
466   M.setProfileSummary(PB.getSummary()->getMD(M.getContext()),
467                       ProfileSummary::Kind::PSK_Instr);
468   PSI.refresh();
469   return PreservedAnalyses::none();
470 }
471