xref: /llvm-project/llvm/lib/CodeGen/MIRSampleProfile.cpp (revision d871b2e0d09b872c57139ee0e24f966d58b92d33)
1 //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides the implementation of the MIRSampleProfile loader, mainly
10 // for flow sensitive SampleFDO.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/CodeGen/MIRSampleProfile.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/DenseSet.h"
17 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
18 #include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
19 #include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
20 #include "llvm/CodeGen/MachineDominators.h"
21 #include "llvm/CodeGen/MachineInstr.h"
22 #include "llvm/CodeGen/MachineLoopInfo.h"
23 #include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
24 #include "llvm/CodeGen/MachinePostDominators.h"
25 #include "llvm/CodeGen/Passes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/PseudoProbe.h"
28 #include "llvm/InitializePasses.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/VirtualFileSystem.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
34 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"
35 #include <optional>
36 
37 using namespace llvm;
38 using namespace sampleprof;
39 using namespace llvm::sampleprofutil;
40 using ProfileCount = Function::ProfileCount;
41 
42 #define DEBUG_TYPE "fs-profile-loader"
43 
44 static cl::opt<bool> ShowFSBranchProb(
45     "show-fs-branchprob", cl::Hidden, cl::init(false),
46     cl::desc("Print setting flow sensitive branch probabilities"));
47 static cl::opt<unsigned> FSProfileDebugProbDiffThreshold(
48     "fs-profile-debug-prob-diff-threshold", cl::init(10),
49     cl::desc("Only show debug message if the branch probility is greater than "
50              "this value (in percentage)."));
51 
52 static cl::opt<unsigned> FSProfileDebugBWThreshold(
53     "fs-profile-debug-bw-threshold", cl::init(10000),
54     cl::desc("Only show debug message if the source branch weight is greater "
55              " than this value."));
56 
57 static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden,
58                                    cl::init(false),
59                                    cl::desc("View BFI before MIR loader"));
60 static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden,
61                                   cl::init(false),
62                                   cl::desc("View BFI after MIR loader"));
63 
// Defined elsewhere. MIRProfileLoader::getInstWeight reads it to decide
// whether meta instructions are excluded from sample weights.
namespace llvm {
extern cl::opt<bool> ImprovedFSDiscriminator;
} // namespace llvm
// Pass identifier; also exported below as llvm::MIRProfileLoaderPassID.
char MIRProfileLoaderPass::ID = 0;

// Register the pass under DEBUG_TYPE ("fs-profile-loader") together with the
// analyses it depends on. This is the same set requested in getAnalysisUsage.
INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE,
                      "Load MIR Sample Profile",
                      /* cfg = */ false, /* is_analysis = */ false)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass)
INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile",
                    /* cfg = */ false, /* is_analysis = */ false)

// Exported reference to MIRProfileLoaderPass::ID.
char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID;
81 
82 FunctionPass *
83 llvm::createMIRProfileLoaderPass(std::string File, std::string RemappingFile,
84                                  FSDiscriminatorPass P,
85                                  IntrusiveRefCntPtr<vfs::FileSystem> FS) {
86   return new MIRProfileLoaderPass(File, RemappingFile, P, std::move(FS));
87 }
88 
89 namespace llvm {
90 
// Internal option used to control BFI display only after MBP pass.
// Defined in CodeGen/MachineBlockFrequencyInfo.cpp:
// -view-block-layout-with-bfi={none | fraction | integer | count}
// In this file it also gates the -fs-viewbfi-before / -fs-viewbfi-after dumps
// in MIRProfileLoaderPass::runOnMachineFunction. With "none", nothing is shown.
extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI;

// Command line option to specify the name of the function for CFG dump
// Defined in Analysis/BlockFrequencyInfo.cpp:  -view-bfi-func-name=
// In this file an empty value means "dump every function".
extern cl::opt<std::string> ViewBlockFreqFuncName;
99 
100 std::optional<PseudoProbe> extractProbe(const MachineInstr &MI) {
101   if (MI.isPseudoProbe()) {
102     PseudoProbe Probe;
103     Probe.Id = MI.getOperand(1).getImm();
104     Probe.Type = MI.getOperand(2).getImm();
105     Probe.Attr = MI.getOperand(3).getImm();
106     Probe.Factor = 1;
107     DILocation *DebugLoc = MI.getDebugLoc();
108     Probe.Discriminator = DebugLoc ? DebugLoc->getDiscriminator() : 0;
109     return Probe;
110   }
111 
112   // Ignore callsite probes since they do not have FS discriminators.
113   return std::nullopt;
114 }
115 
namespace afdo_detail {
/// Machine-IR specialization of the afdo_detail::IRTraits table. It maps each
/// generic type/helper name to its MachineFunction-level counterpart
/// (e.g. InstructionT -> MachineInstr, LoopT -> MachineLoop). It is presumably
/// what lets SampleProfileLoaderBaseImpl<MachineFunction> (the base of
/// MIRProfileLoader) work on machine IR, so the member names must stay as
/// they are. TODO confirm against SampleProfileLoaderBaseImpl.h.
template <> struct IRTraits<MachineBasicBlock> {
  using InstructionT = MachineInstr;
  using BasicBlockT = MachineBasicBlock;
  using FunctionT = MachineFunction;
  using BlockFrequencyInfoT = MachineBlockFrequencyInfo;
  using LoopT = MachineLoop;
  using LoopInfoPtrT = MachineLoopInfo *;
  using DominatorTreePtrT = MachineDominatorTree *;
  using PostDominatorTreePtrT = MachinePostDominatorTree *;
  using PostDominatorTreeT = MachinePostDominatorTree;
  using OptRemarkEmitterT = MachineOptimizationRemarkEmitter;
  using OptRemarkAnalysisT = MachineOptimizationRemarkAnalysis;
  // Predecessor and successor ranges are both ranges of
  // SmallVectorImpl<MachineBasicBlock *> iterators, which is what
  // MachineBasicBlock::predecessors()/successors() convert to below.
  using PredRangeT =
      iterator_range<SmallVectorImpl<MachineBasicBlock *>::iterator>;
  using SuccRangeT =
      iterator_range<SmallVectorImpl<MachineBasicBlock *>::iterator>;
  /// Returns the IR Function associated with \p F.
  static Function &getFunction(MachineFunction &F) { return F.getFunction(); }
  /// Returns the entry block of \p F via the MachineFunction graph traits.
  static const MachineBasicBlock *getEntryBB(const MachineFunction *F) {
    return GraphTraits<const MachineFunction *>::getEntryNode(F);
  }
  /// Returns the predecessor blocks of \p BB.
  static PredRangeT getPredecessors(MachineBasicBlock *BB) {
    return BB->predecessors();
  }
  /// Returns the successor blocks of \p BB.
  static SuccRangeT getSuccessors(MachineBasicBlock *BB) {
    return BB->successors();
  }
};
} // namespace afdo_detail
145 
146 class MIRProfileLoader final
147     : public SampleProfileLoaderBaseImpl<MachineFunction> {
148 public:
149   void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT,
150                    MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI,
151                    MachineOptimizationRemarkEmitter *MORE) {
152     DT = MDT;
153     PDT = MPDT;
154     LI = MLI;
155     BFI = MBFI;
156     ORE = MORE;
157   }
158   void setFSPass(FSDiscriminatorPass Pass) {
159     P = Pass;
160     LowBit = getFSPassBitBegin(P);
161     HighBit = getFSPassBitEnd(P);
162     assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
163   }
164 
165   MIRProfileLoader(StringRef Name, StringRef RemapName,
166                    IntrusiveRefCntPtr<vfs::FileSystem> FS)
167       : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName),
168                                     std::move(FS)) {}
169 
170   void setBranchProbs(MachineFunction &F);
171   bool runOnFunction(MachineFunction &F);
172   bool doInitialization(Module &M);
173   bool isValid() const { return ProfileIsValid; }
174 
175 protected:
176   friend class SampleCoverageTracker;
177 
178   /// Hold the information of the basic block frequency.
179   MachineBlockFrequencyInfo *BFI;
180 
181   /// PassNum is the sequence number this pass is called, start from 1.
182   FSDiscriminatorPass P;
183 
184   // LowBit in the FS discriminator used by this instance. Note the number is
185   // 0-based. Base discrimnator use bit 0 to bit 11.
186   unsigned LowBit;
187   // HighwBit in the FS discriminator used by this instance. Note the number
188   // is 0-based.
189   unsigned HighBit;
190 
191   bool ProfileIsValid = true;
192   ErrorOr<uint64_t> getInstWeight(const MachineInstr &MI) override {
193     if (FunctionSamples::ProfileIsProbeBased)
194       return getProbeWeight(MI);
195     if (ImprovedFSDiscriminator && MI.isMetaInstruction())
196       return std::error_code();
197     return getInstWeightImpl(MI);
198   }
199 };
200 
// Intentionally a no-op for machine IR. MIRProfileLoader receives its
// dominator tree, post-dominator tree and loop info from the pass's analyses
// through setInitVals() instead of computing them here.
template <>
void SampleProfileLoaderBaseImpl<MachineFunction>::computeDominanceAndLoopInfo(
    MachineFunction &F) {}
204 
205 void MIRProfileLoader::setBranchProbs(MachineFunction &F) {
206   LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n");
207   for (auto &BI : F) {
208     MachineBasicBlock *BB = &BI;
209     if (BB->succ_size() < 2)
210       continue;
211     const MachineBasicBlock *EC = EquivalenceClass[BB];
212     uint64_t BBWeight = BlockWeights[EC];
213     uint64_t SumEdgeWeight = 0;
214     for (MachineBasicBlock *Succ : BB->successors()) {
215       Edge E = std::make_pair(BB, Succ);
216       SumEdgeWeight += EdgeWeights[E];
217     }
218 
219     if (BBWeight != SumEdgeWeight) {
220       LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight="
221                         << BBWeight << " SumEdgeWeight= " << SumEdgeWeight
222                         << "\n");
223       BBWeight = SumEdgeWeight;
224     }
225     if (BBWeight == 0) {
226       LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n");
227       continue;
228     }
229 
230 #ifndef NDEBUG
231     uint64_t BBWeightOrig = BBWeight;
232 #endif
233     uint32_t MaxWeight = std::numeric_limits<uint32_t>::max();
234     uint32_t Factor = 1;
235     if (BBWeight > MaxWeight) {
236       Factor = BBWeight / MaxWeight + 1;
237       BBWeight /= Factor;
238       LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n");
239     }
240 
241     for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(),
242                                           SE = BB->succ_end();
243          SI != SE; ++SI) {
244       MachineBasicBlock *Succ = *SI;
245       Edge E = std::make_pair(BB, Succ);
246       uint64_t EdgeWeight = EdgeWeights[E];
247       EdgeWeight /= Factor;
248 
249       assert(BBWeight >= EdgeWeight &&
250              "BBweight is larger than EdgeWeight -- should not happen.\n");
251 
252       BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(BB, SI);
253       BranchProbability NewProb(EdgeWeight, BBWeight);
254       if (OldProb == NewProb)
255         continue;
256       BB->setSuccProbability(SI, NewProb);
257 #ifndef NDEBUG
258       if (!ShowFSBranchProb)
259         continue;
260       bool Show = false;
261       BranchProbability Diff;
262       if (OldProb > NewProb)
263         Diff = OldProb - NewProb;
264       else
265         Diff = NewProb - OldProb;
266       Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100));
267       Show &= (BBWeightOrig >= FSProfileDebugBWThreshold);
268 
269       auto DIL = BB->findBranchDebugLoc();
270       auto SuccDIL = Succ->findBranchDebugLoc();
271       if (Show) {
272         dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> "
273                << Succ->getNumber() << "): ";
274         if (DIL)
275           dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":"
276                  << DIL->getColumn();
277         if (SuccDIL)
278           dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine()
279                  << ":" << SuccDIL->getColumn();
280         dbgs() << " W=" << BBWeightOrig << "  " << OldProb << " --> " << NewProb
281                << "\n";
282       }
283 #endif
284     }
285   }
286 }
287 
288 bool MIRProfileLoader::doInitialization(Module &M) {
289   auto &Ctx = M.getContext();
290 
291   auto ReaderOrErr = sampleprof::SampleProfileReader::create(
292       Filename, Ctx, *FS, P, RemappingFilename);
293   if (std::error_code EC = ReaderOrErr.getError()) {
294     std::string Msg = "Could not open profile: " + EC.message();
295     Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg));
296     return false;
297   }
298 
299   Reader = std::move(ReaderOrErr.get());
300   Reader->setModule(&M);
301   ProfileIsValid = (Reader->read() == sampleprof_error::success);
302 
303   // Load pseudo probe descriptors for probe-based function samples.
304   if (Reader->profileIsProbeBased()) {
305     ProbeManager = std::make_unique<PseudoProbeManager>(M);
306     if (!ProbeManager->moduleIsProbed(M)) {
307       return false;
308     }
309   }
310 
311   return true;
312 }
313 
314 bool MIRProfileLoader::runOnFunction(MachineFunction &MF) {
315   // Do not load non-FS profiles. A line or probe can get a zero-valued
316   // discriminator at certain pass which could result in accidentally loading
317   // the corresponding base counter in the non-FS profile, while a non-zero
318   // discriminator would end up getting zero samples. This could in turn undo
319   // the sample distribution effort done by previous BFI maintenance and the
320   // probe distribution factor work for pseudo probes.
321   if (!Reader->profileIsFS())
322     return false;
323 
324   Function &Func = MF.getFunction();
325   clearFunctionData(false);
326   Samples = Reader->getSamplesFor(Func);
327   if (!Samples || Samples->empty())
328     return false;
329 
330   if (FunctionSamples::ProfileIsProbeBased) {
331     if (!ProbeManager->profileIsValid(MF.getFunction(), *Samples))
332       return false;
333   } else {
334     if (getFunctionLoc(MF) == 0)
335       return false;
336   }
337 
338   DenseSet<GlobalValue::GUID> InlinedGUIDs;
339   bool Changed = computeAndPropagateWeights(MF, InlinedGUIDs);
340 
341   // Set the new BPI, BFI.
342   setBranchProbs(MF);
343 
344   return Changed;
345 }
346 
347 } // namespace llvm
348 
349 MIRProfileLoaderPass::MIRProfileLoaderPass(
350     std::string FileName, std::string RemappingFileName, FSDiscriminatorPass P,
351     IntrusiveRefCntPtr<vfs::FileSystem> FS)
352     : MachineFunctionPass(ID), ProfileFileName(FileName), P(P) {
353   LowBit = getFSPassBitBegin(P);
354   HighBit = getFSPassBitEnd(P);
355 
356   auto VFS = FS ? std::move(FS) : vfs::getRealFileSystem();
357   MIRSampleLoader = std::make_unique<MIRProfileLoader>(
358       FileName, RemappingFileName, std::move(VFS));
359   assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
360 }
361 
362 bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) {
363   if (!MIRSampleLoader->isValid())
364     return false;
365 
366   LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
367                     << MF.getFunction().getName() << "\n");
368   MBFI = &getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI();
369   auto *MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
370   auto *MPDT =
371       &getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree();
372 
373   MF.RenumberBlocks();
374   MDT->updateBlockNumbers();
375   MPDT->updateBlockNumbers();
376 
377   MIRSampleLoader->setInitVals(
378       MDT, MPDT, &getAnalysis<MachineLoopInfoWrapperPass>().getLI(), MBFI,
379       &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE());
380 
381   if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None &&
382       (ViewBlockFreqFuncName.empty() ||
383        MF.getFunction().getName() == ViewBlockFreqFuncName)) {
384     MBFI->view("MIR_Prof_loader_b." + MF.getName(), false);
385   }
386 
387   bool Changed = MIRSampleLoader->runOnFunction(MF);
388   if (Changed)
389     MBFI->calculate(MF, *MBFI->getMBPI(),
390                     *&getAnalysis<MachineLoopInfoWrapperPass>().getLI());
391 
392   if (ViewBFIAfter && ViewBlockLayoutWithBFI != GVDT_None &&
393       (ViewBlockFreqFuncName.empty() ||
394        MF.getFunction().getName() == ViewBlockFreqFuncName)) {
395     MBFI->view("MIR_prof_loader_a." + MF.getName(), false);
396   }
397 
398   return Changed;
399 }
400 
401 bool MIRProfileLoaderPass::doInitialization(Module &M) {
402   LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName()
403                     << "\n");
404 
405   MIRSampleLoader->setFSPass(P);
406   return MIRSampleLoader->doInitialization(M);
407 }
408 
/// Request the analyses the loader consumes. These are the same five
/// registered with INITIALIZE_PASS_DEPENDENCY for this pass.
void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const {
  // Everything is reported as preserved. When the loader changes the
  // function, runOnMachineFunction recalculates MachineBlockFrequencyInfo
  // itself. NOTE(review): confirm no other preserved analysis depends on the
  // rewritten successor probabilities.
  AU.setPreservesAll();
  AU.addRequired<MachineBlockFrequencyInfoWrapperPass>();
  AU.addRequired<MachineDominatorTreeWrapperPass>();
  AU.addRequired<MachinePostDominatorTreeWrapperPass>();
  // Required transitively. Presumably this is because the loop info is kept
  // by the loader (setInitVals) and reused for MBFI->calculate() after the
  // profile is applied. TODO confirm.
  AU.addRequiredTransitive<MachineLoopInfoWrapperPass>();
  AU.addRequired<MachineOptimizationRemarkEmitterPass>();
  MachineFunctionPass::getAnalysisUsage(AU);
}
418