xref: /llvm-project/llvm/lib/CodeGen/MIRSampleProfile.cpp (revision b7d9322b4963e620dfd12246816e6f7b2da5fd88)
1 //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides the implementation of the MIRSampleProfile loader, mainly
10 // for flow sensitive SampleFDO.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/CodeGen/MIRSampleProfile.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/DenseSet.h"
17 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
18 #include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
19 #include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
20 #include "llvm/CodeGen/MachineDominators.h"
21 #include "llvm/CodeGen/MachineInstr.h"
22 #include "llvm/CodeGen/MachineLoopInfo.h"
23 #include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
24 #include "llvm/CodeGen/MachinePostDominators.h"
25 #include "llvm/CodeGen/Passes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/PseudoProbe.h"
28 #include "llvm/InitializePasses.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/VirtualFileSystem.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
34 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"
35 #include <optional>
36 
37 using namespace llvm;
38 using namespace sampleprof;
39 using namespace llvm::sampleprofutil;
40 using ProfileCount = Function::ProfileCount;
41 
42 #define DEBUG_TYPE "fs-profile-loader"
43 
44 static cl::opt<bool> ShowFSBranchProb(
45     "show-fs-branchprob", cl::Hidden, cl::init(false),
46     cl::desc("Print setting flow sensitive branch probabilities"));
47 static cl::opt<unsigned> FSProfileDebugProbDiffThreshold(
48     "fs-profile-debug-prob-diff-threshold", cl::init(10),
49     cl::desc("Only show debug message if the branch probility is greater than "
50              "this value (in percentage)."));
51 
52 static cl::opt<unsigned> FSProfileDebugBWThreshold(
53     "fs-profile-debug-bw-threshold", cl::init(10000),
54     cl::desc("Only show debug message if the source branch weight is greater "
55              " than this value."));
56 
57 static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden,
58                                    cl::init(false),
59                                    cl::desc("View BFI before MIR loader"));
60 static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden,
61                                   cl::init(false),
62                                   cl::desc("View BFI after MIR loader"));
63 
// Option defined elsewhere. When set, MIRProfileLoader::getInstWeight() below
// returns no weight for meta instructions (line-based profiles only).
extern cl::opt<bool> ImprovedFSDiscriminator;
// Pass identifier; handed to MachineFunctionPass(ID) in the pass constructor.
char MIRProfileLoaderPass::ID = 0;

// Register the pass and the analyses it depends on (the same set requested in
// getAnalysisUsage()).
INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE,
                      "Load MIR Sample Profile",
                      /* cfg = */ false, /* is_analysis = */ false)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass)
INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile",
                    /* cfg = */ false, /* is_analysis = */ false)

// Exported reference to the pass ID so other code can name this pass.
char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID;
79 
80 FunctionPass *
81 llvm::createMIRProfileLoaderPass(std::string File, std::string RemappingFile,
82                                  FSDiscriminatorPass P,
83                                  IntrusiveRefCntPtr<vfs::FileSystem> FS) {
84   return new MIRProfileLoaderPass(File, RemappingFile, P, std::move(FS));
85 }
86 
87 namespace llvm {
88 
// Internal option used to control BFI display only after MBP pass.
// Defined in CodeGen/MachineBlockFrequencyInfo.cpp:
// -view-block-layout-with-bfi={none | fraction | integer | count}
// In this file, a value other than "none" is also required for the
// -fs-viewbfi-before / -fs-viewbfi-after views to be shown.
extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI;

// Command line option to specify the name of the function for CFG dump
// Defined in Analysis/BlockFrequencyInfo.cpp:  -view-bfi-func-name=
// When empty, the views in this file are shown for every function.
extern cl::opt<std::string> ViewBlockFreqFuncName;
97 
98 std::optional<PseudoProbe> extractProbe(const MachineInstr &MI) {
99   if (MI.isPseudoProbe()) {
100     PseudoProbe Probe;
101     Probe.Id = MI.getOperand(1).getImm();
102     Probe.Type = MI.getOperand(2).getImm();
103     Probe.Attr = MI.getOperand(3).getImm();
104     Probe.Factor = 1;
105     DILocation *DebugLoc = MI.getDebugLoc();
106     Probe.Discriminator = DebugLoc ? DebugLoc->getDiscriminator() : 0;
107     return Probe;
108   }
109 
110   // Ignore callsite probes since they do not have FS discriminators.
111   return std::nullopt;
112 }
113 
namespace afdo_detail {
/// Maps the generic names used by SampleProfileLoaderBaseImpl onto their
/// machine-IR counterparts, so the shared sample-loading code can be
/// instantiated for MachineBasicBlock.
template <> struct IRTraits<MachineBasicBlock> {
  using InstructionT = MachineInstr;
  using BasicBlockT = MachineBasicBlock;
  using FunctionT = MachineFunction;
  using BlockFrequencyInfoT = MachineBlockFrequencyInfo;
  using LoopT = MachineLoop;
  using LoopInfoPtrT = MachineLoopInfo *;
  using DominatorTreePtrT = MachineDominatorTree *;
  using PostDominatorTreePtrT = MachinePostDominatorTree *;
  using PostDominatorTreeT = MachinePostDominatorTree;
  using OptRemarkEmitterT = MachineOptimizationRemarkEmitter;
  using OptRemarkAnalysisT = MachineOptimizationRemarkAnalysis;
  // Range types returned by MachineBasicBlock::predecessors()/successors().
  using PredRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
  using SuccRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
  /// Returns the IR Function associated with the machine function \p F.
  static Function &getFunction(MachineFunction &F) { return F.getFunction(); }
  /// Returns the entry block of \p F as defined by its GraphTraits.
  static const MachineBasicBlock *getEntryBB(const MachineFunction *F) {
    return GraphTraits<const MachineFunction *>::getEntryNode(F);
  }
  /// Returns the predecessor blocks of \p BB.
  static PredRangeT getPredecessors(MachineBasicBlock *BB) {
    return BB->predecessors();
  }
  /// Returns the successor blocks of \p BB.
  static SuccRangeT getSuccessors(MachineBasicBlock *BB) {
    return BB->successors();
  }
};
} // namespace afdo_detail
141 
142 class MIRProfileLoader final
143     : public SampleProfileLoaderBaseImpl<MachineBasicBlock> {
144 public:
145   void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT,
146                    MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI,
147                    MachineOptimizationRemarkEmitter *MORE) {
148     DT = MDT;
149     PDT = MPDT;
150     LI = MLI;
151     BFI = MBFI;
152     ORE = MORE;
153   }
154   void setFSPass(FSDiscriminatorPass Pass) {
155     P = Pass;
156     LowBit = getFSPassBitBegin(P);
157     HighBit = getFSPassBitEnd(P);
158     assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
159   }
160 
161   MIRProfileLoader(StringRef Name, StringRef RemapName,
162                    IntrusiveRefCntPtr<vfs::FileSystem> FS)
163       : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName),
164                                     std::move(FS)) {}
165 
166   void setBranchProbs(MachineFunction &F);
167   bool runOnFunction(MachineFunction &F);
168   bool doInitialization(Module &M);
169   bool isValid() const { return ProfileIsValid; }
170 
171 protected:
172   friend class SampleCoverageTracker;
173 
174   /// Hold the information of the basic block frequency.
175   MachineBlockFrequencyInfo *BFI;
176 
177   /// PassNum is the sequence number this pass is called, start from 1.
178   FSDiscriminatorPass P;
179 
180   // LowBit in the FS discriminator used by this instance. Note the number is
181   // 0-based. Base discrimnator use bit 0 to bit 11.
182   unsigned LowBit;
183   // HighwBit in the FS discriminator used by this instance. Note the number
184   // is 0-based.
185   unsigned HighBit;
186 
187   bool ProfileIsValid = true;
188   ErrorOr<uint64_t> getInstWeight(const MachineInstr &MI) override {
189     if (FunctionSamples::ProfileIsProbeBased)
190       return getProbeWeight(MI);
191     if (ImprovedFSDiscriminator && MI.isMetaInstruction())
192       return std::error_code();
193     return getInstWeightImpl(MI);
194   }
195 };
196 
// Machine-IR specialization: intentionally a no-op. Dominator trees and loop
// info are not computed here. MIRProfileLoader::setInitVals() installs them
// from the analyses obtained in MIRProfileLoaderPass::runOnMachineFunction().
template <>
void SampleProfileLoaderBaseImpl<
    MachineBasicBlock>::computeDominanceAndLoopInfo(MachineFunction &F) {}
200 
201 void MIRProfileLoader::setBranchProbs(MachineFunction &F) {
202   LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n");
203   for (auto &BI : F) {
204     MachineBasicBlock *BB = &BI;
205     if (BB->succ_size() < 2)
206       continue;
207     const MachineBasicBlock *EC = EquivalenceClass[BB];
208     uint64_t BBWeight = BlockWeights[EC];
209     uint64_t SumEdgeWeight = 0;
210     for (MachineBasicBlock *Succ : BB->successors()) {
211       Edge E = std::make_pair(BB, Succ);
212       SumEdgeWeight += EdgeWeights[E];
213     }
214 
215     if (BBWeight != SumEdgeWeight) {
216       LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight="
217                         << BBWeight << " SumEdgeWeight= " << SumEdgeWeight
218                         << "\n");
219       BBWeight = SumEdgeWeight;
220     }
221     if (BBWeight == 0) {
222       LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n");
223       continue;
224     }
225 
226 #ifndef NDEBUG
227     uint64_t BBWeightOrig = BBWeight;
228 #endif
229     uint32_t MaxWeight = std::numeric_limits<uint32_t>::max();
230     uint32_t Factor = 1;
231     if (BBWeight > MaxWeight) {
232       Factor = BBWeight / MaxWeight + 1;
233       BBWeight /= Factor;
234       LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n");
235     }
236 
237     for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(),
238                                           SE = BB->succ_end();
239          SI != SE; ++SI) {
240       MachineBasicBlock *Succ = *SI;
241       Edge E = std::make_pair(BB, Succ);
242       uint64_t EdgeWeight = EdgeWeights[E];
243       EdgeWeight /= Factor;
244 
245       assert(BBWeight >= EdgeWeight &&
246              "BBweight is larger than EdgeWeight -- should not happen.\n");
247 
248       BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(BB, SI);
249       BranchProbability NewProb(EdgeWeight, BBWeight);
250       if (OldProb == NewProb)
251         continue;
252       BB->setSuccProbability(SI, NewProb);
253 #ifndef NDEBUG
254       if (!ShowFSBranchProb)
255         continue;
256       bool Show = false;
257       BranchProbability Diff;
258       if (OldProb > NewProb)
259         Diff = OldProb - NewProb;
260       else
261         Diff = NewProb - OldProb;
262       Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100));
263       Show &= (BBWeightOrig >= FSProfileDebugBWThreshold);
264 
265       auto DIL = BB->findBranchDebugLoc();
266       auto SuccDIL = Succ->findBranchDebugLoc();
267       if (Show) {
268         dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> "
269                << Succ->getNumber() << "): ";
270         if (DIL)
271           dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":"
272                  << DIL->getColumn();
273         if (SuccDIL)
274           dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine()
275                  << ":" << SuccDIL->getColumn();
276         dbgs() << " W=" << BBWeightOrig << "  " << OldProb << " --> " << NewProb
277                << "\n";
278       }
279 #endif
280     }
281   }
282 }
283 
284 bool MIRProfileLoader::doInitialization(Module &M) {
285   auto &Ctx = M.getContext();
286 
287   auto ReaderOrErr = sampleprof::SampleProfileReader::create(
288       Filename, Ctx, *FS, P, RemappingFilename);
289   if (std::error_code EC = ReaderOrErr.getError()) {
290     std::string Msg = "Could not open profile: " + EC.message();
291     Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg));
292     return false;
293   }
294 
295   Reader = std::move(ReaderOrErr.get());
296   Reader->setModule(&M);
297   ProfileIsValid = (Reader->read() == sampleprof_error::success);
298 
299   // Load pseudo probe descriptors for probe-based function samples.
300   if (Reader->profileIsProbeBased()) {
301     ProbeManager = std::make_unique<PseudoProbeManager>(M);
302     if (!ProbeManager->moduleIsProbed(M)) {
303       return false;
304     }
305   }
306 
307   return true;
308 }
309 
310 bool MIRProfileLoader::runOnFunction(MachineFunction &MF) {
311   Function &Func = MF.getFunction();
312   clearFunctionData(false);
313   Samples = Reader->getSamplesFor(Func);
314   if (!Samples || Samples->empty())
315     return false;
316 
317   if (FunctionSamples::ProfileIsProbeBased) {
318     if (!ProbeManager->profileIsValid(MF.getFunction(), *Samples))
319       return false;
320   } else {
321     if (getFunctionLoc(MF) == 0)
322       return false;
323   }
324 
325   DenseSet<GlobalValue::GUID> InlinedGUIDs;
326   bool Changed = computeAndPropagateWeights(MF, InlinedGUIDs);
327 
328   // Set the new BPI, BFI.
329   setBranchProbs(MF);
330 
331   return Changed;
332 }
333 
334 } // namespace llvm
335 
336 MIRProfileLoaderPass::MIRProfileLoaderPass(
337     std::string FileName, std::string RemappingFileName, FSDiscriminatorPass P,
338     IntrusiveRefCntPtr<vfs::FileSystem> FS)
339     : MachineFunctionPass(ID), ProfileFileName(FileName), P(P) {
340   LowBit = getFSPassBitBegin(P);
341   HighBit = getFSPassBitEnd(P);
342 
343   auto VFS = FS ? std::move(FS) : vfs::getRealFileSystem();
344   MIRSampleLoader = std::make_unique<MIRProfileLoader>(
345       FileName, RemappingFileName, std::move(VFS));
346   assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
347 }
348 
349 bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) {
350   if (!MIRSampleLoader->isValid())
351     return false;
352 
353   LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
354                     << MF.getFunction().getName() << "\n");
355   MBFI = &getAnalysis<MachineBlockFrequencyInfo>();
356   MIRSampleLoader->setInitVals(
357       &getAnalysis<MachineDominatorTree>(),
358       &getAnalysis<MachinePostDominatorTree>(), &getAnalysis<MachineLoopInfo>(),
359       MBFI, &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE());
360 
361   MF.RenumberBlocks();
362   if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None &&
363       (ViewBlockFreqFuncName.empty() ||
364        MF.getFunction().getName().equals(ViewBlockFreqFuncName))) {
365     MBFI->view("MIR_Prof_loader_b." + MF.getName(), false);
366   }
367 
368   bool Changed = MIRSampleLoader->runOnFunction(MF);
369   if (Changed)
370     MBFI->calculate(MF, *MBFI->getMBPI(), *&getAnalysis<MachineLoopInfo>());
371 
372   if (ViewBFIAfter && ViewBlockLayoutWithBFI != GVDT_None &&
373       (ViewBlockFreqFuncName.empty() ||
374        MF.getFunction().getName().equals(ViewBlockFreqFuncName))) {
375     MBFI->view("MIR_prof_loader_a." + MF.getName(), false);
376   }
377 
378   return Changed;
379 }
380 
381 bool MIRProfileLoaderPass::doInitialization(Module &M) {
382   LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName()
383                     << "\n");
384 
385   MIRSampleLoader->setFSPass(P);
386   return MIRSampleLoader->doInitialization(M);
387 }
388 
/// Declare the analyses this pass needs.
///
/// The pass reports that it preserves everything; block frequencies are
/// recomputed in place by runOnMachineFunction() when weights change.
/// NOTE(review): MachineLoopInfo is requested transitively, presumably because
/// it is read again by MBFI->calculate() after propagation. Confirm before
/// changing. The order of the requirements is kept as-is in case the legacy
/// pass manager's scheduling depends on it.
void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<MachineBlockFrequencyInfo>();
  AU.addRequired<MachineDominatorTree>();
  AU.addRequired<MachinePostDominatorTree>();
  AU.addRequiredTransitive<MachineLoopInfo>();
  AU.addRequired<MachineOptimizationRemarkEmitterPass>();
  MachineFunctionPass::getAnalysisUsage(AU);
}
398