xref: /llvm-project/llvm/lib/CodeGen/MIRSampleProfile.cpp (revision 67efbd0bf1b2df8a479e09eb2be7db4c3c892f2c)
1 //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides the implementation of the MIRSampleProfile loader, mainly
10 // for flow sensitive SampleFDO.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/CodeGen/MIRSampleProfile.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/DenseSet.h"
17 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
18 #include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
19 #include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
20 #include "llvm/CodeGen/MachineDominators.h"
21 #include "llvm/CodeGen/MachineInstr.h"
22 #include "llvm/CodeGen/MachineLoopInfo.h"
23 #include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
24 #include "llvm/CodeGen/MachinePostDominators.h"
25 #include "llvm/CodeGen/Passes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/PseudoProbe.h"
28 #include "llvm/InitializePasses.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/VirtualFileSystem.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
34 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"
35 #include <optional>
36 
37 using namespace llvm;
38 using namespace sampleprof;
39 using namespace llvm::sampleprofutil;
40 using ProfileCount = Function::ProfileCount;
41 
42 #define DEBUG_TYPE "fs-profile-loader"
43 
44 static cl::opt<bool> ShowFSBranchProb(
45     "show-fs-branchprob", cl::Hidden, cl::init(false),
46     cl::desc("Print setting flow sensitive branch probabilities"));
47 static cl::opt<unsigned> FSProfileDebugProbDiffThreshold(
48     "fs-profile-debug-prob-diff-threshold", cl::init(10),
49     cl::desc(
50         "Only show debug message if the branch probability is greater than "
51         "this value (in percentage)."));
52 
53 static cl::opt<unsigned> FSProfileDebugBWThreshold(
54     "fs-profile-debug-bw-threshold", cl::init(10000),
55     cl::desc("Only show debug message if the source branch weight is greater "
56              " than this value."));
57 
58 static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden,
59                                    cl::init(false),
60                                    cl::desc("View BFI before MIR loader"));
61 static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden,
62                                   cl::init(false),
63                                   cl::desc("View BFI after MIR loader"));
64 
namespace llvm {
// Defined outside this file. MIRProfileLoader::getInstWeight() consults it to
// decide whether meta instructions are skipped.
extern cl::opt<bool> ImprovedFSDiscriminator;
}
// Static pass ID; it is exported below through MIRProfileLoaderPassID.
char MIRProfileLoaderPass::ID = 0;

// Register the pass under DEBUG_TYPE together with the analyses it depends
// on. These are the same analyses requested in getAnalysisUsage().
INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE,
                      "Load MIR Sample Profile",
                      /* cfg = */ false, /* is_analysis = */ false)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass)
INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile",
                    /* cfg = */ false, /* is_analysis = */ false)

// Reference to the pass ID for code that refers to the pass without seeing
// the MIRProfileLoaderPass class.
char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID;
82 
83 FunctionPass *
84 llvm::createMIRProfileLoaderPass(std::string File, std::string RemappingFile,
85                                  FSDiscriminatorPass P,
86                                  IntrusiveRefCntPtr<vfs::FileSystem> FS) {
87   return new MIRProfileLoaderPass(File, RemappingFile, P, std::move(FS));
88 }
89 
90 namespace llvm {
91 
// Internal option used to control BFI display only after MBP pass.
// Defined in CodeGen/MachineBlockFrequencyInfo.cpp:
// -view-block-layout-with-bfi={none | fraction | integer | count}
// In this file it gates the -fs-viewbfi-before / -fs-viewbfi-after displays:
// nothing is shown while it is GVDT_None.
extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI;

// Command line option to specify the name of the function for CFG dump
// Defined in Analysis/BlockFrequencyInfo.cpp:  -view-bfi-func-name=
// An empty value means the display is not restricted to a single function.
extern cl::opt<std::string> ViewBlockFreqFuncName;
100 
101 std::optional<PseudoProbe> extractProbe(const MachineInstr &MI) {
102   if (MI.isPseudoProbe()) {
103     PseudoProbe Probe;
104     Probe.Id = MI.getOperand(1).getImm();
105     Probe.Type = MI.getOperand(2).getImm();
106     Probe.Attr = MI.getOperand(3).getImm();
107     Probe.Factor = 1;
108     DILocation *DebugLoc = MI.getDebugLoc();
109     Probe.Discriminator = DebugLoc ? DebugLoc->getDiscriminator() : 0;
110     return Probe;
111   }
112 
113   // Ignore callsite probes since they do not have FS discriminators.
114   return std::nullopt;
115 }
116 
namespace afdo_detail {
// Specialization of IRTraits for machine IR. It names the MachineFunction
// level types and accessors that the MachineFunction instantiation of
// SampleProfileLoaderBaseImpl (the base of MIRProfileLoader) works with.
template <> struct IRTraits<MachineBasicBlock> {
  // IR entity types.
  using InstructionT = MachineInstr;
  using BasicBlockT = MachineBasicBlock;
  using FunctionT = MachineFunction;
  // Analysis types.
  using BlockFrequencyInfoT = MachineBlockFrequencyInfo;
  using LoopT = MachineLoop;
  using LoopInfoPtrT = MachineLoopInfo *;
  using DominatorTreePtrT = MachineDominatorTree *;
  using PostDominatorTreePtrT = MachinePostDominatorTree *;
  using PostDominatorTreeT = MachinePostDominatorTree;
  // Optimization remark types.
  using OptRemarkEmitterT = MachineOptimizationRemarkEmitter;
  using OptRemarkAnalysisT = MachineOptimizationRemarkAnalysis;
  // Ranges over the predecessor / successor blocks of a block.
  using PredRangeT =
      iterator_range<SmallVectorImpl<MachineBasicBlock *>::iterator>;
  using SuccRangeT =
      iterator_range<SmallVectorImpl<MachineBasicBlock *>::iterator>;
  // Returns the IR function that the machine function F was built from.
  static Function &getFunction(MachineFunction &F) { return F.getFunction(); }
  // Returns the entry block of F, as reported by GraphTraits.
  static const MachineBasicBlock *getEntryBB(const MachineFunction *F) {
    return GraphTraits<const MachineFunction *>::getEntryNode(F);
  }
  // Returns the predecessor blocks of BB.
  static PredRangeT getPredecessors(MachineBasicBlock *BB) {
    return BB->predecessors();
  }
  // Returns the successor blocks of BB.
  static SuccRangeT getSuccessors(MachineBasicBlock *BB) {
    return BB->successors();
  }
};
} // namespace afdo_detail
146 
147 class MIRProfileLoader final
148     : public SampleProfileLoaderBaseImpl<MachineFunction> {
149 public:
150   void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT,
151                    MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI,
152                    MachineOptimizationRemarkEmitter *MORE) {
153     DT = MDT;
154     PDT = MPDT;
155     LI = MLI;
156     BFI = MBFI;
157     ORE = MORE;
158   }
159   void setFSPass(FSDiscriminatorPass Pass) {
160     P = Pass;
161     LowBit = getFSPassBitBegin(P);
162     HighBit = getFSPassBitEnd(P);
163     assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
164   }
165 
166   MIRProfileLoader(StringRef Name, StringRef RemapName,
167                    IntrusiveRefCntPtr<vfs::FileSystem> FS)
168       : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName),
169                                     std::move(FS)) {}
170 
171   void setBranchProbs(MachineFunction &F);
172   bool runOnFunction(MachineFunction &F);
173   bool doInitialization(Module &M);
174   bool isValid() const { return ProfileIsValid; }
175 
176 protected:
177   friend class SampleCoverageTracker;
178 
179   /// Hold the information of the basic block frequency.
180   MachineBlockFrequencyInfo *BFI;
181 
182   /// PassNum is the sequence number this pass is called, start from 1.
183   FSDiscriminatorPass P;
184 
185   // LowBit in the FS discriminator used by this instance. Note the number is
186   // 0-based. Base discrimnator use bit 0 to bit 11.
187   unsigned LowBit;
188   // HighwBit in the FS discriminator used by this instance. Note the number
189   // is 0-based.
190   unsigned HighBit;
191 
192   bool ProfileIsValid = true;
193   ErrorOr<uint64_t> getInstWeight(const MachineInstr &MI) override {
194     if (FunctionSamples::ProfileIsProbeBased)
195       return getProbeWeight(MI);
196     if (ImprovedFSDiscriminator && MI.isMetaInstruction())
197       return std::error_code();
198     return getInstWeightImpl(MI);
199   }
200 };
201 
// Specialization for machine functions. The body is intentionally empty:
// nothing is computed here, and the dominator tree, post-dominator tree and
// loop info members are instead assigned in MIRProfileLoader::setInitVals().
template <>
void SampleProfileLoaderBaseImpl<MachineFunction>::computeDominanceAndLoopInfo(
    MachineFunction &F) {}
205 
206 void MIRProfileLoader::setBranchProbs(MachineFunction &F) {
207   LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n");
208   for (auto &BI : F) {
209     MachineBasicBlock *BB = &BI;
210     if (BB->succ_size() < 2)
211       continue;
212     const MachineBasicBlock *EC = EquivalenceClass[BB];
213     uint64_t BBWeight = BlockWeights[EC];
214     uint64_t SumEdgeWeight = 0;
215     for (MachineBasicBlock *Succ : BB->successors()) {
216       Edge E = std::make_pair(BB, Succ);
217       SumEdgeWeight += EdgeWeights[E];
218     }
219 
220     if (BBWeight != SumEdgeWeight) {
221       LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight="
222                         << BBWeight << " SumEdgeWeight= " << SumEdgeWeight
223                         << "\n");
224       BBWeight = SumEdgeWeight;
225     }
226     if (BBWeight == 0) {
227       LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n");
228       continue;
229     }
230 
231 #ifndef NDEBUG
232     uint64_t BBWeightOrig = BBWeight;
233 #endif
234     uint32_t MaxWeight = std::numeric_limits<uint32_t>::max();
235     uint32_t Factor = 1;
236     if (BBWeight > MaxWeight) {
237       Factor = BBWeight / MaxWeight + 1;
238       BBWeight /= Factor;
239       LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n");
240     }
241 
242     for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(),
243                                           SE = BB->succ_end();
244          SI != SE; ++SI) {
245       MachineBasicBlock *Succ = *SI;
246       Edge E = std::make_pair(BB, Succ);
247       uint64_t EdgeWeight = EdgeWeights[E];
248       EdgeWeight /= Factor;
249 
250       assert(BBWeight >= EdgeWeight &&
251              "BBweight is larger than EdgeWeight -- should not happen.\n");
252 
253       BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(BB, SI);
254       BranchProbability NewProb(EdgeWeight, BBWeight);
255       if (OldProb == NewProb)
256         continue;
257       BB->setSuccProbability(SI, NewProb);
258 #ifndef NDEBUG
259       if (!ShowFSBranchProb)
260         continue;
261       bool Show = false;
262       BranchProbability Diff;
263       if (OldProb > NewProb)
264         Diff = OldProb - NewProb;
265       else
266         Diff = NewProb - OldProb;
267       Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100));
268       Show &= (BBWeightOrig >= FSProfileDebugBWThreshold);
269 
270       auto DIL = BB->findBranchDebugLoc();
271       auto SuccDIL = Succ->findBranchDebugLoc();
272       if (Show) {
273         dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> "
274                << Succ->getNumber() << "): ";
275         if (DIL)
276           dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":"
277                  << DIL->getColumn();
278         if (SuccDIL)
279           dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine()
280                  << ":" << SuccDIL->getColumn();
281         dbgs() << " W=" << BBWeightOrig << "  " << OldProb << " --> " << NewProb
282                << "\n";
283       }
284 #endif
285     }
286   }
287 }
288 
289 bool MIRProfileLoader::doInitialization(Module &M) {
290   auto &Ctx = M.getContext();
291 
292   auto ReaderOrErr = sampleprof::SampleProfileReader::create(
293       Filename, Ctx, *FS, P, RemappingFilename);
294   if (std::error_code EC = ReaderOrErr.getError()) {
295     std::string Msg = "Could not open profile: " + EC.message();
296     Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg));
297     return false;
298   }
299 
300   Reader = std::move(ReaderOrErr.get());
301   Reader->setModule(&M);
302   ProfileIsValid = (Reader->read() == sampleprof_error::success);
303 
304   // Load pseudo probe descriptors for probe-based function samples.
305   if (Reader->profileIsProbeBased()) {
306     ProbeManager = std::make_unique<PseudoProbeManager>(M);
307     if (!ProbeManager->moduleIsProbed(M)) {
308       return false;
309     }
310   }
311 
312   return true;
313 }
314 
315 bool MIRProfileLoader::runOnFunction(MachineFunction &MF) {
316   // Do not load non-FS profiles. A line or probe can get a zero-valued
317   // discriminator at certain pass which could result in accidentally loading
318   // the corresponding base counter in the non-FS profile, while a non-zero
319   // discriminator would end up getting zero samples. This could in turn undo
320   // the sample distribution effort done by previous BFI maintenance and the
321   // probe distribution factor work for pseudo probes.
322   if (!Reader->profileIsFS())
323     return false;
324 
325   Function &Func = MF.getFunction();
326   clearFunctionData(false);
327   Samples = Reader->getSamplesFor(Func);
328   if (!Samples || Samples->empty())
329     return false;
330 
331   if (FunctionSamples::ProfileIsProbeBased) {
332     if (!ProbeManager->profileIsValid(MF.getFunction(), *Samples))
333       return false;
334   } else {
335     if (getFunctionLoc(MF) == 0)
336       return false;
337   }
338 
339   DenseSet<GlobalValue::GUID> InlinedGUIDs;
340   bool Changed = computeAndPropagateWeights(MF, InlinedGUIDs);
341 
342   // Set the new BPI, BFI.
343   setBranchProbs(MF);
344 
345   return Changed;
346 }
347 
348 } // namespace llvm
349 
350 MIRProfileLoaderPass::MIRProfileLoaderPass(
351     std::string FileName, std::string RemappingFileName, FSDiscriminatorPass P,
352     IntrusiveRefCntPtr<vfs::FileSystem> FS)
353     : MachineFunctionPass(ID), ProfileFileName(FileName), P(P) {
354   LowBit = getFSPassBitBegin(P);
355   HighBit = getFSPassBitEnd(P);
356 
357   auto VFS = FS ? std::move(FS) : vfs::getRealFileSystem();
358   MIRSampleLoader = std::make_unique<MIRProfileLoader>(
359       FileName, RemappingFileName, std::move(VFS));
360   assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
361 }
362 
363 bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) {
364   if (!MIRSampleLoader->isValid())
365     return false;
366 
367   LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
368                     << MF.getFunction().getName() << "\n");
369   MBFI = &getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI();
370   auto *MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
371   auto *MPDT =
372       &getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree();
373 
374   MF.RenumberBlocks();
375   MDT->updateBlockNumbers();
376   MPDT->updateBlockNumbers();
377 
378   MIRSampleLoader->setInitVals(
379       MDT, MPDT, &getAnalysis<MachineLoopInfoWrapperPass>().getLI(), MBFI,
380       &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE());
381 
382   if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None &&
383       (ViewBlockFreqFuncName.empty() ||
384        MF.getFunction().getName() == ViewBlockFreqFuncName)) {
385     MBFI->view("MIR_Prof_loader_b." + MF.getName(), false);
386   }
387 
388   bool Changed = MIRSampleLoader->runOnFunction(MF);
389   if (Changed)
390     MBFI->calculate(MF, *MBFI->getMBPI(),
391                     *&getAnalysis<MachineLoopInfoWrapperPass>().getLI());
392 
393   if (ViewBFIAfter && ViewBlockLayoutWithBFI != GVDT_None &&
394       (ViewBlockFreqFuncName.empty() ||
395        MF.getFunction().getName() == ViewBlockFreqFuncName)) {
396     MBFI->view("MIR_prof_loader_a." + MF.getName(), false);
397   }
398 
399   return Changed;
400 }
401 
402 bool MIRProfileLoaderPass::doInitialization(Module &M) {
403   LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName()
404                     << "\n");
405 
406   MIRSampleLoader->setFSPass(P);
407   return MIRSampleLoader->doInitialization(M);
408 }
409 
// Declare the analyses this pass uses. All analyses are reported as
// preserved. Block frequency info, both dominator trees, loop info and the
// remark emitter are required; they are the analyses fetched in
// runOnMachineFunction().
void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<MachineBlockFrequencyInfoWrapperPass>();
  AU.addRequired<MachineDominatorTreeWrapperPass>();
  AU.addRequired<MachinePostDominatorTreeWrapperPass>();
  // Loop info is the only analysis requested transitively.
  AU.addRequiredTransitive<MachineLoopInfoWrapperPass>();
  AU.addRequired<MachineOptimizationRemarkEmitterPass>();
  // Let the base class add its own requirements.
  MachineFunctionPass::getAnalysisUsage(AU);
}
419