1 //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file provides the implementation of the MIRSampleProfile loader, mainly 10 // for flow sensitive SampleFDO. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/CodeGen/MIRSampleProfile.h" 15 #include "llvm/ADT/DenseMap.h" 16 #include "llvm/ADT/DenseSet.h" 17 #include "llvm/Analysis/BlockFrequencyInfoImpl.h" 18 #include "llvm/CodeGen/MachineBlockFrequencyInfo.h" 19 #include "llvm/CodeGen/MachineBranchProbabilityInfo.h" 20 #include "llvm/CodeGen/MachineDominators.h" 21 #include "llvm/CodeGen/MachineLoopInfo.h" 22 #include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h" 23 #include "llvm/CodeGen/MachinePostDominators.h" 24 #include "llvm/CodeGen/Passes.h" 25 #include "llvm/IR/Function.h" 26 #include "llvm/InitializePasses.h" 27 #include "llvm/Support/CommandLine.h" 28 #include "llvm/Support/Debug.h" 29 #include "llvm/Support/VirtualFileSystem.h" 30 #include "llvm/Support/raw_ostream.h" 31 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h" 32 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h" 33 34 using namespace llvm; 35 using namespace sampleprof; 36 using namespace llvm::sampleprofutil; 37 using ProfileCount = Function::ProfileCount; 38 39 #define DEBUG_TYPE "fs-profile-loader" 40 41 static cl::opt<bool> ShowFSBranchProb( 42 "show-fs-branchprob", cl::Hidden, cl::init(false), 43 cl::desc("Print setting flow sensitive branch probabilities")); 44 static cl::opt<unsigned> FSProfileDebugProbDiffThreshold( 45 "fs-profile-debug-prob-diff-threshold", cl::init(10), 46 cl::desc("Only show debug message if the branch probility is greater than " 47 "this value (in percentage).")); 48 49 static cl::opt<unsigned> FSProfileDebugBWThreshold( 50 "fs-profile-debug-bw-threshold", cl::init(10000), 51 cl::desc("Only show debug message if the source branch weight is greater " 52 " than this value.")); 53 54 static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden, 55 cl::init(false), 56 cl::desc("View BFI before MIR loader")); 57 static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden, 58 cl::init(false), 59 cl::desc("View BFI after MIR loader")); 60 61 char MIRProfileLoaderPass::ID = 0; 62 63 INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE, 64 "Load MIR Sample Profile", 65 /* cfg = */ false, /* is_analysis = */ false) 66 INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo) 67 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree) 68 INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree) 69 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo) 70 INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass) 71 INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile", 72 /* cfg = */ false, /* is_analysis = */ false) 73 74 char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID; 75 76 FunctionPass * 77 llvm::createMIRProfileLoaderPass(std::string File, std::string RemappingFile, 78 FSDiscriminatorPass P, 79 IntrusiveRefCntPtr<vfs::FileSystem> FS) { 80 return new MIRProfileLoaderPass(File, RemappingFile, P, std::move(FS)); 81 } 82 83 namespace llvm { 84 85 // Internal option used to control BFI display only after MBP pass. 86 // Defined in CodeGen/MachineBlockFrequencyInfo.cpp: 87 // -view-block-layout-with-bfi={none | fraction | integer | count} 88 extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI; 89 90 // Command line option to specify the name of the function for CFG dump 91 // Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name= 92 extern cl::opt<std::string> ViewBlockFreqFuncName; 93 94 namespace afdo_detail { 95 template <> struct IRTraits<MachineBasicBlock> { 96 using InstructionT = MachineInstr; 97 using BasicBlockT = MachineBasicBlock; 98 using FunctionT = MachineFunction; 99 using BlockFrequencyInfoT = MachineBlockFrequencyInfo; 100 using LoopT = MachineLoop; 101 using LoopInfoPtrT = MachineLoopInfo *; 102 using DominatorTreePtrT = MachineDominatorTree *; 103 using PostDominatorTreePtrT = MachinePostDominatorTree *; 104 using PostDominatorTreeT = MachinePostDominatorTree; 105 using OptRemarkEmitterT = MachineOptimizationRemarkEmitter; 106 using OptRemarkAnalysisT = MachineOptimizationRemarkAnalysis; 107 using PredRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>; 108 using SuccRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>; 109 static Function &getFunction(MachineFunction &F) { return F.getFunction(); } 110 static const MachineBasicBlock *getEntryBB(const MachineFunction *F) { 111 return GraphTraits<const MachineFunction *>::getEntryNode(F); 112 } 113 static PredRangeT getPredecessors(MachineBasicBlock *BB) { 114 return BB->predecessors(); 115 } 116 static SuccRangeT getSuccessors(MachineBasicBlock *BB) { 117 return BB->successors(); 118 } 119 }; 120 } // namespace afdo_detail 121 122 class MIRProfileLoader final 123 : public SampleProfileLoaderBaseImpl<MachineBasicBlock> { 124 public: 125 void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT, 126 MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI, 127 MachineOptimizationRemarkEmitter *MORE) { 128 DT = MDT; 129 PDT = MPDT; 130 LI = MLI; 131 BFI = MBFI; 132 ORE = MORE; 133 } 134 void setFSPass(FSDiscriminatorPass Pass) { 135 P = Pass; 136 LowBit = getFSPassBitBegin(P); 137 HighBit = getFSPassBitEnd(P); 138 assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit"); 139 } 140 141 MIRProfileLoader(StringRef Name, StringRef RemapName, 142 IntrusiveRefCntPtr<vfs::FileSystem> FS) 143 : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName), 144 std::move(FS)) {} 145 146 void setBranchProbs(MachineFunction &F); 147 bool runOnFunction(MachineFunction &F); 148 bool doInitialization(Module &M); 149 bool isValid() const { return ProfileIsValid; } 150 151 protected: 152 friend class SampleCoverageTracker; 153 154 /// Hold the information of the basic block frequency. 155 MachineBlockFrequencyInfo *BFI; 156 157 /// PassNum is the sequence number this pass is called, start from 1. 158 FSDiscriminatorPass P; 159 160 // LowBit in the FS discriminator used by this instance. Note the number is 161 // 0-based. Base discrimnator use bit 0 to bit 11. 162 unsigned LowBit; 163 // HighwBit in the FS discriminator used by this instance. Note the number 164 // is 0-based. 165 unsigned HighBit; 166 167 bool ProfileIsValid = true; 168 }; 169 170 template <> 171 void SampleProfileLoaderBaseImpl< 172 MachineBasicBlock>::computeDominanceAndLoopInfo(MachineFunction &F) {} 173 174 void MIRProfileLoader::setBranchProbs(MachineFunction &F) { 175 LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n"); 176 for (auto &BI : F) { 177 MachineBasicBlock *BB = &BI; 178 if (BB->succ_size() < 2) 179 continue; 180 const MachineBasicBlock *EC = EquivalenceClass[BB]; 181 uint64_t BBWeight = BlockWeights[EC]; 182 uint64_t SumEdgeWeight = 0; 183 for (MachineBasicBlock *Succ : BB->successors()) { 184 Edge E = std::make_pair(BB, Succ); 185 SumEdgeWeight += EdgeWeights[E]; 186 } 187 188 if (BBWeight != SumEdgeWeight) { 189 LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight=" 190 << BBWeight << " SumEdgeWeight= " << SumEdgeWeight 191 << "\n"); 192 BBWeight = SumEdgeWeight; 193 } 194 if (BBWeight == 0) { 195 LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n"); 196 continue; 197 } 198 199 #ifndef NDEBUG 200 uint64_t BBWeightOrig = BBWeight; 201 #endif 202 uint32_t MaxWeight = std::numeric_limits<uint32_t>::max(); 203 uint32_t Factor = 1; 204 if (BBWeight > MaxWeight) { 205 Factor = BBWeight / MaxWeight + 1; 206 BBWeight /= Factor; 207 LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n"); 208 } 209 210 for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(), 211 SE = BB->succ_end(); 212 SI != SE; ++SI) { 213 MachineBasicBlock *Succ = *SI; 214 Edge E = std::make_pair(BB, Succ); 215 uint64_t EdgeWeight = EdgeWeights[E]; 216 EdgeWeight /= Factor; 217 218 assert(BBWeight >= EdgeWeight && 219 "BBweight is larger than EdgeWeight -- should not happen.\n"); 220 221 BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(BB, SI); 222 BranchProbability NewProb(EdgeWeight, BBWeight); 223 if (OldProb == NewProb) 224 continue; 225 BB->setSuccProbability(SI, NewProb); 226 #ifndef NDEBUG 227 if (!ShowFSBranchProb) 228 continue; 229 bool Show = false; 230 BranchProbability Diff; 231 if (OldProb > NewProb) 232 Diff = OldProb - NewProb; 233 else 234 Diff = NewProb - OldProb; 235 Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100)); 236 Show &= (BBWeightOrig >= FSProfileDebugBWThreshold); 237 238 auto DIL = BB->findBranchDebugLoc(); 239 auto SuccDIL = Succ->findBranchDebugLoc(); 240 if (Show) { 241 dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> " 242 << Succ->getNumber() << "): "; 243 if (DIL) 244 dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":" 245 << DIL->getColumn(); 246 if (SuccDIL) 247 dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine() 248 << ":" << SuccDIL->getColumn(); 249 dbgs() << " W=" << BBWeightOrig << " " << OldProb << " --> " << NewProb 250 << "\n"; 251 } 252 #endif 253 } 254 } 255 } 256 257 bool MIRProfileLoader::doInitialization(Module &M) { 258 auto &Ctx = M.getContext(); 259 260 auto ReaderOrErr = sampleprof::SampleProfileReader::create( 261 Filename, Ctx, *FS, P, RemappingFilename); 262 if (std::error_code EC = ReaderOrErr.getError()) { 263 std::string Msg = "Could not open profile: " + EC.message(); 264 Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg)); 265 return false; 266 } 267 268 Reader = std::move(ReaderOrErr.get()); 269 Reader->setModule(&M); 270 ProfileIsValid = (Reader->read() == sampleprof_error::success); 271 Reader->getSummary(); 272 273 return true; 274 } 275 276 bool MIRProfileLoader::runOnFunction(MachineFunction &MF) { 277 Function &Func = MF.getFunction(); 278 clearFunctionData(false); 279 Samples = Reader->getSamplesFor(Func); 280 if (!Samples || Samples->empty()) 281 return false; 282 283 if (getFunctionLoc(MF) == 0) 284 return false; 285 286 DenseSet<GlobalValue::GUID> InlinedGUIDs; 287 bool Changed = computeAndPropagateWeights(MF, InlinedGUIDs); 288 289 // Set the new BPI, BFI. 290 setBranchProbs(MF); 291 292 return Changed; 293 } 294 295 } // namespace llvm 296 297 MIRProfileLoaderPass::MIRProfileLoaderPass( 298 std::string FileName, std::string RemappingFileName, FSDiscriminatorPass P, 299 IntrusiveRefCntPtr<vfs::FileSystem> FS) 300 : MachineFunctionPass(ID), ProfileFileName(FileName), P(P) { 301 LowBit = getFSPassBitBegin(P); 302 HighBit = getFSPassBitEnd(P); 303 304 auto VFS = FS ? std::move(FS) : vfs::getRealFileSystem(); 305 MIRSampleLoader = std::make_unique<MIRProfileLoader>( 306 FileName, RemappingFileName, std::move(VFS)); 307 assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit"); 308 } 309 310 bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) { 311 if (!MIRSampleLoader->isValid()) 312 return false; 313 314 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: " 315 << MF.getFunction().getName() << "\n"); 316 MBFI = &getAnalysis<MachineBlockFrequencyInfo>(); 317 MIRSampleLoader->setInitVals( 318 &getAnalysis<MachineDominatorTree>(), 319 &getAnalysis<MachinePostDominatorTree>(), &getAnalysis<MachineLoopInfo>(), 320 MBFI, &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE()); 321 322 MF.RenumberBlocks(); 323 if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None && 324 (ViewBlockFreqFuncName.empty() || 325 MF.getFunction().getName().equals(ViewBlockFreqFuncName))) { 326 MBFI->view("MIR_Prof_loader_b." + MF.getName(), false); 327 } 328 329 bool Changed = MIRSampleLoader->runOnFunction(MF); 330 if (Changed) 331 MBFI->calculate(MF, *MBFI->getMBPI(), *&getAnalysis<MachineLoopInfo>()); 332 333 if (ViewBFIAfter && ViewBlockLayoutWithBFI != GVDT_None && 334 (ViewBlockFreqFuncName.empty() || 335 MF.getFunction().getName().equals(ViewBlockFreqFuncName))) { 336 MBFI->view("MIR_prof_loader_a." + MF.getName(), false); 337 } 338 339 return Changed; 340 } 341 342 bool MIRProfileLoaderPass::doInitialization(Module &M) { 343 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName() 344 << "\n"); 345 346 MIRSampleLoader->setFSPass(P); 347 return MIRSampleLoader->doInitialization(M); 348 } 349 350 void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const { 351 AU.setPreservesAll(); 352 AU.addRequired<MachineBlockFrequencyInfo>(); 353 AU.addRequired<MachineDominatorTree>(); 354 AU.addRequired<MachinePostDominatorTree>(); 355 AU.addRequiredTransitive<MachineLoopInfo>(); 356 AU.addRequired<MachineOptimizationRemarkEmitterPass>(); 357 MachineFunctionPass::getAnalysisUsage(AU); 358 } 359