xref: /llvm-project/llvm/lib/CodeGen/MIRSampleProfile.cpp (revision 516e301752560311d2cd8c2b549493eb0f98d01b)
1 //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides the implementation of the MIRSampleProfile loader, mainly
10 // for flow sensitive SampleFDO.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/CodeGen/MIRSampleProfile.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/DenseSet.h"
17 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
18 #include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
19 #include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
20 #include "llvm/CodeGen/MachineDominators.h"
21 #include "llvm/CodeGen/MachineLoopInfo.h"
22 #include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
23 #include "llvm/CodeGen/MachinePostDominators.h"
24 #include "llvm/CodeGen/Passes.h"
25 #include "llvm/IR/Function.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Support/CommandLine.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/VirtualFileSystem.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
32 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"
33 
34 using namespace llvm;
35 using namespace sampleprof;
36 using namespace llvm::sampleprofutil;
37 using ProfileCount = Function::ProfileCount;
38 
39 #define DEBUG_TYPE "fs-profile-loader"
40 
41 static cl::opt<bool> ShowFSBranchProb(
42     "show-fs-branchprob", cl::Hidden, cl::init(false),
43     cl::desc("Print setting flow sensitive branch probabilities"));
44 static cl::opt<unsigned> FSProfileDebugProbDiffThreshold(
45     "fs-profile-debug-prob-diff-threshold", cl::init(10),
46     cl::desc("Only show debug message if the branch probility is greater than "
47              "this value (in percentage)."));
48 
49 static cl::opt<unsigned> FSProfileDebugBWThreshold(
50     "fs-profile-debug-bw-threshold", cl::init(10000),
51     cl::desc("Only show debug message if the source branch weight is greater "
52              " than this value."));
53 
54 static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden,
55                                    cl::init(false),
56                                    cl::desc("View BFI before MIR loader"));
57 static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden,
58                                   cl::init(false),
59                                   cl::desc("View BFI after MIR loader"));
60 
// Static pass ID. It is handed to the MachineFunctionPass base constructor and
// exported below through MIRProfileLoaderPassID.
char MIRProfileLoaderPass::ID = 0;

// Register the pass under DEBUG_TYPE ("fs-profile-loader"). The dependency
// list names the same analyses that getAnalysisUsage() marks as required.
INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE,
                      "Load MIR Sample Profile",
                      /* cfg = */ false, /* is_analysis = */ false)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass)
INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile",
                    /* cfg = */ false, /* is_analysis = */ false)

// Externally visible reference to the pass ID defined above.
char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID;
75 
76 FunctionPass *
77 llvm::createMIRProfileLoaderPass(std::string File, std::string RemappingFile,
78                                  FSDiscriminatorPass P,
79                                  IntrusiveRefCntPtr<vfs::FileSystem> FS) {
80   return new MIRProfileLoaderPass(File, RemappingFile, P, std::move(FS));
81 }
82 
83 namespace llvm {
84 
// Internal option used to control BFI display only after MBP pass.
// Defined in CodeGen/MachineBlockFrequencyInfo.cpp:
// -view-block-layout-with-bfi={none | fraction | integer | count}
// In this file it additionally gates both BFI views of the loader pass: no
// view is shown while it equals GVDT_None.
extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI;

// Command line option to specify the name of the function for CFG dump
// Defined in Analysis/BlockFrequencyInfo.cpp:  -view-bfi-func-name=
// When non-empty, the loader pass only shows BFI views for the function whose
// name equals this value.
extern cl::opt<std::string> ViewBlockFreqFuncName;
93 
namespace afdo_detail {
/// Trait specialization mapping the type and accessor names used by the
/// generic sample profile loader onto their MachineIR counterparts, so that
/// SampleProfileLoaderBaseImpl<MachineBasicBlock> can be instantiated.
template <> struct IRTraits<MachineBasicBlock> {
  using InstructionT = MachineInstr;
  using BasicBlockT = MachineBasicBlock;
  using FunctionT = MachineFunction;
  using BlockFrequencyInfoT = MachineBlockFrequencyInfo;
  using LoopT = MachineLoop;
  using LoopInfoPtrT = MachineLoopInfo *;
  using DominatorTreePtrT = MachineDominatorTree *;
  using PostDominatorTreePtrT = MachinePostDominatorTree *;
  using PostDominatorTreeT = MachinePostDominatorTree;
  using OptRemarkEmitterT = MachineOptimizationRemarkEmitter;
  using OptRemarkAnalysisT = MachineOptimizationRemarkAnalysis;
  // Predecessor and successor ranges share one iterator type: a vector of
  // block pointers.
  using PredRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
  using SuccRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
  /// Returns the IR function that \p F was created from.
  static Function &getFunction(MachineFunction &F) { return F.getFunction(); }
  /// Returns the entry block of \p F as reported by GraphTraits.
  static const MachineBasicBlock *getEntryBB(const MachineFunction *F) {
    return GraphTraits<const MachineFunction *>::getEntryNode(F);
  }
  /// Returns the range of CFG predecessors of \p BB.
  static PredRangeT getPredecessors(MachineBasicBlock *BB) {
    return BB->predecessors();
  }
  /// Returns the range of CFG successors of \p BB.
  static SuccRangeT getSuccessors(MachineBasicBlock *BB) {
    return BB->successors();
  }
};
} // namespace afdo_detail
121 
122 class MIRProfileLoader final
123     : public SampleProfileLoaderBaseImpl<MachineBasicBlock> {
124 public:
125   void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT,
126                    MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI,
127                    MachineOptimizationRemarkEmitter *MORE) {
128     DT = MDT;
129     PDT = MPDT;
130     LI = MLI;
131     BFI = MBFI;
132     ORE = MORE;
133   }
134   void setFSPass(FSDiscriminatorPass Pass) {
135     P = Pass;
136     LowBit = getFSPassBitBegin(P);
137     HighBit = getFSPassBitEnd(P);
138     assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
139   }
140 
141   MIRProfileLoader(StringRef Name, StringRef RemapName,
142                    IntrusiveRefCntPtr<vfs::FileSystem> FS)
143       : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName),
144                                     std::move(FS)) {}
145 
146   void setBranchProbs(MachineFunction &F);
147   bool runOnFunction(MachineFunction &F);
148   bool doInitialization(Module &M);
149   bool isValid() const { return ProfileIsValid; }
150 
151 protected:
152   friend class SampleCoverageTracker;
153 
154   /// Hold the information of the basic block frequency.
155   MachineBlockFrequencyInfo *BFI;
156 
157   /// PassNum is the sequence number this pass is called, start from 1.
158   FSDiscriminatorPass P;
159 
160   // LowBit in the FS discriminator used by this instance. Note the number is
161   // 0-based. Base discrimnator use bit 0 to bit 11.
162   unsigned LowBit;
163   // HighwBit in the FS discriminator used by this instance. Note the number
164   // is 0-based.
165   unsigned HighBit;
166 
167   bool ProfileIsValid = true;
168 };
169 
// Specialization for MachineBasicBlock: intentionally empty. Nothing is
// computed here; in this file the dominator, post-dominator and loop info
// pointers are supplied from outside via MIRProfileLoader::setInitVals().
template <>
void SampleProfileLoaderBaseImpl<
    MachineBasicBlock>::computeDominanceAndLoopInfo(MachineFunction &F) {}
173 
174 void MIRProfileLoader::setBranchProbs(MachineFunction &F) {
175   LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n");
176   for (auto &BI : F) {
177     MachineBasicBlock *BB = &BI;
178     if (BB->succ_size() < 2)
179       continue;
180     const MachineBasicBlock *EC = EquivalenceClass[BB];
181     uint64_t BBWeight = BlockWeights[EC];
182     uint64_t SumEdgeWeight = 0;
183     for (MachineBasicBlock *Succ : BB->successors()) {
184       Edge E = std::make_pair(BB, Succ);
185       SumEdgeWeight += EdgeWeights[E];
186     }
187 
188     if (BBWeight != SumEdgeWeight) {
189       LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight="
190                         << BBWeight << " SumEdgeWeight= " << SumEdgeWeight
191                         << "\n");
192       BBWeight = SumEdgeWeight;
193     }
194     if (BBWeight == 0) {
195       LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n");
196       continue;
197     }
198 
199 #ifndef NDEBUG
200     uint64_t BBWeightOrig = BBWeight;
201 #endif
202     uint32_t MaxWeight = std::numeric_limits<uint32_t>::max();
203     uint32_t Factor = 1;
204     if (BBWeight > MaxWeight) {
205       Factor = BBWeight / MaxWeight + 1;
206       BBWeight /= Factor;
207       LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n");
208     }
209 
210     for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(),
211                                           SE = BB->succ_end();
212          SI != SE; ++SI) {
213       MachineBasicBlock *Succ = *SI;
214       Edge E = std::make_pair(BB, Succ);
215       uint64_t EdgeWeight = EdgeWeights[E];
216       EdgeWeight /= Factor;
217 
218       assert(BBWeight >= EdgeWeight &&
219              "BBweight is larger than EdgeWeight -- should not happen.\n");
220 
221       BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(BB, SI);
222       BranchProbability NewProb(EdgeWeight, BBWeight);
223       if (OldProb == NewProb)
224         continue;
225       BB->setSuccProbability(SI, NewProb);
226 #ifndef NDEBUG
227       if (!ShowFSBranchProb)
228         continue;
229       bool Show = false;
230       BranchProbability Diff;
231       if (OldProb > NewProb)
232         Diff = OldProb - NewProb;
233       else
234         Diff = NewProb - OldProb;
235       Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100));
236       Show &= (BBWeightOrig >= FSProfileDebugBWThreshold);
237 
238       auto DIL = BB->findBranchDebugLoc();
239       auto SuccDIL = Succ->findBranchDebugLoc();
240       if (Show) {
241         dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> "
242                << Succ->getNumber() << "): ";
243         if (DIL)
244           dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":"
245                  << DIL->getColumn();
246         if (SuccDIL)
247           dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine()
248                  << ":" << SuccDIL->getColumn();
249         dbgs() << " W=" << BBWeightOrig << "  " << OldProb << " --> " << NewProb
250                << "\n";
251       }
252 #endif
253     }
254   }
255 }
256 
257 bool MIRProfileLoader::doInitialization(Module &M) {
258   auto &Ctx = M.getContext();
259 
260   auto ReaderOrErr = sampleprof::SampleProfileReader::create(
261       Filename, Ctx, *FS, P, RemappingFilename);
262   if (std::error_code EC = ReaderOrErr.getError()) {
263     std::string Msg = "Could not open profile: " + EC.message();
264     Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg));
265     return false;
266   }
267 
268   Reader = std::move(ReaderOrErr.get());
269   Reader->setModule(&M);
270   ProfileIsValid = (Reader->read() == sampleprof_error::success);
271   Reader->getSummary();
272 
273   return true;
274 }
275 
276 bool MIRProfileLoader::runOnFunction(MachineFunction &MF) {
277   Function &Func = MF.getFunction();
278   clearFunctionData(false);
279   Samples = Reader->getSamplesFor(Func);
280   if (!Samples || Samples->empty())
281     return false;
282 
283   if (getFunctionLoc(MF) == 0)
284     return false;
285 
286   DenseSet<GlobalValue::GUID> InlinedGUIDs;
287   bool Changed = computeAndPropagateWeights(MF, InlinedGUIDs);
288 
289   // Set the new BPI, BFI.
290   setBranchProbs(MF);
291 
292   return Changed;
293 }
294 
295 } // namespace llvm
296 
297 MIRProfileLoaderPass::MIRProfileLoaderPass(
298     std::string FileName, std::string RemappingFileName, FSDiscriminatorPass P,
299     IntrusiveRefCntPtr<vfs::FileSystem> FS)
300     : MachineFunctionPass(ID), ProfileFileName(FileName), P(P) {
301   LowBit = getFSPassBitBegin(P);
302   HighBit = getFSPassBitEnd(P);
303 
304   auto VFS = FS ? std::move(FS) : vfs::getRealFileSystem();
305   MIRSampleLoader = std::make_unique<MIRProfileLoader>(
306       FileName, RemappingFileName, std::move(VFS));
307   assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
308 }
309 
310 bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) {
311   if (!MIRSampleLoader->isValid())
312     return false;
313 
314   LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
315                     << MF.getFunction().getName() << "\n");
316   MBFI = &getAnalysis<MachineBlockFrequencyInfo>();
317   MIRSampleLoader->setInitVals(
318       &getAnalysis<MachineDominatorTree>(),
319       &getAnalysis<MachinePostDominatorTree>(), &getAnalysis<MachineLoopInfo>(),
320       MBFI, &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE());
321 
322   MF.RenumberBlocks();
323   if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None &&
324       (ViewBlockFreqFuncName.empty() ||
325        MF.getFunction().getName().equals(ViewBlockFreqFuncName))) {
326     MBFI->view("MIR_Prof_loader_b." + MF.getName(), false);
327   }
328 
329   bool Changed = MIRSampleLoader->runOnFunction(MF);
330   if (Changed)
331     MBFI->calculate(MF, *MBFI->getMBPI(), *&getAnalysis<MachineLoopInfo>());
332 
333   if (ViewBFIAfter && ViewBlockLayoutWithBFI != GVDT_None &&
334       (ViewBlockFreqFuncName.empty() ||
335        MF.getFunction().getName().equals(ViewBlockFreqFuncName))) {
336     MBFI->view("MIR_prof_loader_a." + MF.getName(), false);
337   }
338 
339   return Changed;
340 }
341 
342 bool MIRProfileLoaderPass::doInitialization(Module &M) {
343   LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName()
344                     << "\n");
345 
346   MIRSampleLoader->setFSPass(P);
347   return MIRSampleLoader->doInitialization(M);
348 }
349 
/// Declare the analyses this pass uses. Every analysis fetched with
/// getAnalysis<>() in runOnMachineFunction() is listed here as required.
void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const {
  // NOTE(review): the pass renumbers blocks and rewrites successor
  // probabilities, yet declares that it preserves all analyses -- confirm this
  // is intended.
  AU.setPreservesAll();
  AU.addRequired<MachineBlockFrequencyInfo>();
  AU.addRequired<MachineDominatorTree>();
  AU.addRequired<MachinePostDominatorTree>();
  // MachineLoopInfo is the only requirement registered as transitive.
  AU.addRequiredTransitive<MachineLoopInfo>();
  AU.addRequired<MachineOptimizationRemarkEmitterPass>();
  MachineFunctionPass::getAnalysisUsage(AU);
}
359