1 //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file provides the implementation of the MIRSampleProfile loader, mainly 10 // for flow sensitive SampleFDO. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/CodeGen/MIRSampleProfile.h" 15 #include "llvm/ADT/DenseMap.h" 16 #include "llvm/ADT/DenseSet.h" 17 #include "llvm/Analysis/BlockFrequencyInfoImpl.h" 18 #include "llvm/CodeGen/MachineBlockFrequencyInfo.h" 19 #include "llvm/CodeGen/MachineBranchProbabilityInfo.h" 20 #include "llvm/CodeGen/MachineDominators.h" 21 #include "llvm/CodeGen/MachineInstr.h" 22 #include "llvm/CodeGen/MachineLoopInfo.h" 23 #include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h" 24 #include "llvm/CodeGen/MachinePostDominators.h" 25 #include "llvm/CodeGen/Passes.h" 26 #include "llvm/IR/Function.h" 27 #include "llvm/IR/PseudoProbe.h" 28 #include "llvm/InitializePasses.h" 29 #include "llvm/Support/CommandLine.h" 30 #include "llvm/Support/Debug.h" 31 #include "llvm/Support/VirtualFileSystem.h" 32 #include "llvm/Support/raw_ostream.h" 33 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h" 34 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h" 35 #include <optional> 36 37 using namespace llvm; 38 using namespace sampleprof; 39 using namespace llvm::sampleprofutil; 40 using ProfileCount = Function::ProfileCount; 41 42 #define DEBUG_TYPE "fs-profile-loader" 43 44 static cl::opt<bool> ShowFSBranchProb( 45 "show-fs-branchprob", cl::Hidden, cl::init(false), 46 cl::desc("Print setting flow sensitive branch probabilities")); 47 static cl::opt<unsigned> FSProfileDebugProbDiffThreshold( 48 "fs-profile-debug-prob-diff-threshold", cl::init(10), 49 cl::desc( 50 "Only show debug message if the branch probability is greater than " 51 "this value (in percentage).")); 52 53 static cl::opt<unsigned> FSProfileDebugBWThreshold( 54 "fs-profile-debug-bw-threshold", cl::init(10000), 55 cl::desc("Only show debug message if the source branch weight is greater " 56 " than this value.")); 57 58 static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden, 59 cl::init(false), 60 cl::desc("View BFI before MIR loader")); 61 static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden, 62 cl::init(false), 63 cl::desc("View BFI after MIR loader")); 64 65 namespace llvm { 66 extern cl::opt<bool> ImprovedFSDiscriminator; 67 } 68 char MIRProfileLoaderPass::ID = 0; 69 70 INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE, 71 "Load MIR Sample Profile", 72 /* cfg = */ false, /* is_analysis = */ false) 73 INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfoWrapperPass) 74 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass) 75 INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass) 76 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass) 77 INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass) 78 INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile", 79 /* cfg = */ false, /* is_analysis = */ false) 80 81 char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID; 82 83 FunctionPass * 84 llvm::createMIRProfileLoaderPass(std::string File, std::string RemappingFile, 85 FSDiscriminatorPass P, 86 IntrusiveRefCntPtr<vfs::FileSystem> FS) { 87 return new MIRProfileLoaderPass(File, RemappingFile, P, std::move(FS)); 88 } 89 90 namespace llvm { 91 92 // Internal option used to control BFI display only after MBP pass. 93 // Defined in CodeGen/MachineBlockFrequencyInfo.cpp: 94 // -view-block-layout-with-bfi={none | fraction | integer | count} 95 extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI; 96 97 // Command line option to specify the name of the function for CFG dump 98 // Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name= 99 extern cl::opt<std::string> ViewBlockFreqFuncName; 100 101 std::optional<PseudoProbe> extractProbe(const MachineInstr &MI) { 102 if (MI.isPseudoProbe()) { 103 PseudoProbe Probe; 104 Probe.Id = MI.getOperand(1).getImm(); 105 Probe.Type = MI.getOperand(2).getImm(); 106 Probe.Attr = MI.getOperand(3).getImm(); 107 Probe.Factor = 1; 108 DILocation *DebugLoc = MI.getDebugLoc(); 109 Probe.Discriminator = DebugLoc ? DebugLoc->getDiscriminator() : 0; 110 return Probe; 111 } 112 113 // Ignore callsite probes since they do not have FS discriminators. 114 return std::nullopt; 115 } 116 117 namespace afdo_detail { 118 template <> struct IRTraits<MachineBasicBlock> { 119 using InstructionT = MachineInstr; 120 using BasicBlockT = MachineBasicBlock; 121 using FunctionT = MachineFunction; 122 using BlockFrequencyInfoT = MachineBlockFrequencyInfo; 123 using LoopT = MachineLoop; 124 using LoopInfoPtrT = MachineLoopInfo *; 125 using DominatorTreePtrT = MachineDominatorTree *; 126 using PostDominatorTreePtrT = MachinePostDominatorTree *; 127 using PostDominatorTreeT = MachinePostDominatorTree; 128 using OptRemarkEmitterT = MachineOptimizationRemarkEmitter; 129 using OptRemarkAnalysisT = MachineOptimizationRemarkAnalysis; 130 using PredRangeT = 131 iterator_range<SmallVectorImpl<MachineBasicBlock *>::iterator>; 132 using SuccRangeT = 133 iterator_range<SmallVectorImpl<MachineBasicBlock *>::iterator>; 134 static Function &getFunction(MachineFunction &F) { return F.getFunction(); } 135 static const MachineBasicBlock *getEntryBB(const MachineFunction *F) { 136 return GraphTraits<const MachineFunction *>::getEntryNode(F); 137 } 138 static PredRangeT getPredecessors(MachineBasicBlock *BB) { 139 return BB->predecessors(); 140 } 141 static SuccRangeT getSuccessors(MachineBasicBlock *BB) { 142 return BB->successors(); 143 } 144 }; 145 } // namespace afdo_detail 146 147 class MIRProfileLoader final 148 : public SampleProfileLoaderBaseImpl<MachineFunction> { 149 public: 150 void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT, 151 MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI, 152 MachineOptimizationRemarkEmitter *MORE) { 153 DT = MDT; 154 PDT = MPDT; 155 LI = MLI; 156 BFI = MBFI; 157 ORE = MORE; 158 } 159 void setFSPass(FSDiscriminatorPass Pass) { 160 P = Pass; 161 LowBit = getFSPassBitBegin(P); 162 HighBit = getFSPassBitEnd(P); 163 assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit"); 164 } 165 166 MIRProfileLoader(StringRef Name, StringRef RemapName, 167 IntrusiveRefCntPtr<vfs::FileSystem> FS) 168 : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName), 169 std::move(FS)) {} 170 171 void setBranchProbs(MachineFunction &F); 172 bool runOnFunction(MachineFunction &F); 173 bool doInitialization(Module &M); 174 bool isValid() const { return ProfileIsValid; } 175 176 protected: 177 friend class SampleCoverageTracker; 178 179 /// Hold the information of the basic block frequency. 180 MachineBlockFrequencyInfo *BFI; 181 182 /// PassNum is the sequence number this pass is called, start from 1. 183 FSDiscriminatorPass P; 184 185 // LowBit in the FS discriminator used by this instance. Note the number is 186 // 0-based. Base discrimnator use bit 0 to bit 11. 187 unsigned LowBit; 188 // HighwBit in the FS discriminator used by this instance. Note the number 189 // is 0-based. 190 unsigned HighBit; 191 192 bool ProfileIsValid = true; 193 ErrorOr<uint64_t> getInstWeight(const MachineInstr &MI) override { 194 if (FunctionSamples::ProfileIsProbeBased) 195 return getProbeWeight(MI); 196 if (ImprovedFSDiscriminator && MI.isMetaInstruction()) 197 return std::error_code(); 198 return getInstWeightImpl(MI); 199 } 200 }; 201 202 template <> 203 void SampleProfileLoaderBaseImpl<MachineFunction>::computeDominanceAndLoopInfo( 204 MachineFunction &F) {} 205 206 void MIRProfileLoader::setBranchProbs(MachineFunction &F) { 207 LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n"); 208 for (auto &BI : F) { 209 MachineBasicBlock *BB = &BI; 210 if (BB->succ_size() < 2) 211 continue; 212 const MachineBasicBlock *EC = EquivalenceClass[BB]; 213 uint64_t BBWeight = BlockWeights[EC]; 214 uint64_t SumEdgeWeight = 0; 215 for (MachineBasicBlock *Succ : BB->successors()) { 216 Edge E = std::make_pair(BB, Succ); 217 SumEdgeWeight += EdgeWeights[E]; 218 } 219 220 if (BBWeight != SumEdgeWeight) { 221 LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight=" 222 << BBWeight << " SumEdgeWeight= " << SumEdgeWeight 223 << "\n"); 224 BBWeight = SumEdgeWeight; 225 } 226 if (BBWeight == 0) { 227 LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n"); 228 continue; 229 } 230 231 #ifndef NDEBUG 232 uint64_t BBWeightOrig = BBWeight; 233 #endif 234 uint32_t MaxWeight = std::numeric_limits<uint32_t>::max(); 235 uint32_t Factor = 1; 236 if (BBWeight > MaxWeight) { 237 Factor = BBWeight / MaxWeight + 1; 238 BBWeight /= Factor; 239 LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n"); 240 } 241 242 for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(), 243 SE = BB->succ_end(); 244 SI != SE; ++SI) { 245 MachineBasicBlock *Succ = *SI; 246 Edge E = std::make_pair(BB, Succ); 247 uint64_t EdgeWeight = EdgeWeights[E]; 248 EdgeWeight /= Factor; 249 250 assert(BBWeight >= EdgeWeight && 251 "BBweight is larger than EdgeWeight -- should not happen.\n"); 252 253 BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(BB, SI); 254 BranchProbability NewProb(EdgeWeight, BBWeight); 255 if (OldProb == NewProb) 256 continue; 257 BB->setSuccProbability(SI, NewProb); 258 #ifndef NDEBUG 259 if (!ShowFSBranchProb) 260 continue; 261 bool Show = false; 262 BranchProbability Diff; 263 if (OldProb > NewProb) 264 Diff = OldProb - NewProb; 265 else 266 Diff = NewProb - OldProb; 267 Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100)); 268 Show &= (BBWeightOrig >= FSProfileDebugBWThreshold); 269 270 auto DIL = BB->findBranchDebugLoc(); 271 auto SuccDIL = Succ->findBranchDebugLoc(); 272 if (Show) { 273 dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> " 274 << Succ->getNumber() << "): "; 275 if (DIL) 276 dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":" 277 << DIL->getColumn(); 278 if (SuccDIL) 279 dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine() 280 << ":" << SuccDIL->getColumn(); 281 dbgs() << " W=" << BBWeightOrig << " " << OldProb << " --> " << NewProb 282 << "\n"; 283 } 284 #endif 285 } 286 } 287 } 288 289 bool MIRProfileLoader::doInitialization(Module &M) { 290 auto &Ctx = M.getContext(); 291 292 auto ReaderOrErr = sampleprof::SampleProfileReader::create( 293 Filename, Ctx, *FS, P, RemappingFilename); 294 if (std::error_code EC = ReaderOrErr.getError()) { 295 std::string Msg = "Could not open profile: " + EC.message(); 296 Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg)); 297 return false; 298 } 299 300 Reader = std::move(ReaderOrErr.get()); 301 Reader->setModule(&M); 302 ProfileIsValid = (Reader->read() == sampleprof_error::success); 303 304 // Load pseudo probe descriptors for probe-based function samples. 305 if (Reader->profileIsProbeBased()) { 306 ProbeManager = std::make_unique<PseudoProbeManager>(M); 307 if (!ProbeManager->moduleIsProbed(M)) { 308 return false; 309 } 310 } 311 312 return true; 313 } 314 315 bool MIRProfileLoader::runOnFunction(MachineFunction &MF) { 316 // Do not load non-FS profiles. A line or probe can get a zero-valued 317 // discriminator at certain pass which could result in accidentally loading 318 // the corresponding base counter in the non-FS profile, while a non-zero 319 // discriminator would end up getting zero samples. This could in turn undo 320 // the sample distribution effort done by previous BFI maintenance and the 321 // probe distribution factor work for pseudo probes. 322 if (!Reader->profileIsFS()) 323 return false; 324 325 Function &Func = MF.getFunction(); 326 clearFunctionData(false); 327 Samples = Reader->getSamplesFor(Func); 328 if (!Samples || Samples->empty()) 329 return false; 330 331 if (FunctionSamples::ProfileIsProbeBased) { 332 if (!ProbeManager->profileIsValid(MF.getFunction(), *Samples)) 333 return false; 334 } else { 335 if (getFunctionLoc(MF) == 0) 336 return false; 337 } 338 339 DenseSet<GlobalValue::GUID> InlinedGUIDs; 340 bool Changed = computeAndPropagateWeights(MF, InlinedGUIDs); 341 342 // Set the new BPI, BFI. 343 setBranchProbs(MF); 344 345 return Changed; 346 } 347 348 } // namespace llvm 349 350 MIRProfileLoaderPass::MIRProfileLoaderPass( 351 std::string FileName, std::string RemappingFileName, FSDiscriminatorPass P, 352 IntrusiveRefCntPtr<vfs::FileSystem> FS) 353 : MachineFunctionPass(ID), ProfileFileName(FileName), P(P) { 354 LowBit = getFSPassBitBegin(P); 355 HighBit = getFSPassBitEnd(P); 356 357 auto VFS = FS ? std::move(FS) : vfs::getRealFileSystem(); 358 MIRSampleLoader = std::make_unique<MIRProfileLoader>( 359 FileName, RemappingFileName, std::move(VFS)); 360 assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit"); 361 } 362 363 bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) { 364 if (!MIRSampleLoader->isValid()) 365 return false; 366 367 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: " 368 << MF.getFunction().getName() << "\n"); 369 MBFI = &getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI(); 370 auto *MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree(); 371 auto *MPDT = 372 &getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree(); 373 374 MF.RenumberBlocks(); 375 MDT->updateBlockNumbers(); 376 MPDT->updateBlockNumbers(); 377 378 MIRSampleLoader->setInitVals( 379 MDT, MPDT, &getAnalysis<MachineLoopInfoWrapperPass>().getLI(), MBFI, 380 &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE()); 381 382 if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None && 383 (ViewBlockFreqFuncName.empty() || 384 MF.getFunction().getName() == ViewBlockFreqFuncName)) { 385 MBFI->view("MIR_Prof_loader_b." + MF.getName(), false); 386 } 387 388 bool Changed = MIRSampleLoader->runOnFunction(MF); 389 if (Changed) 390 MBFI->calculate(MF, *MBFI->getMBPI(), 391 *&getAnalysis<MachineLoopInfoWrapperPass>().getLI()); 392 393 if (ViewBFIAfter && ViewBlockLayoutWithBFI != GVDT_None && 394 (ViewBlockFreqFuncName.empty() || 395 MF.getFunction().getName() == ViewBlockFreqFuncName)) { 396 MBFI->view("MIR_prof_loader_a." + MF.getName(), false); 397 } 398 399 return Changed; 400 } 401 402 bool MIRProfileLoaderPass::doInitialization(Module &M) { 403 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName() 404 << "\n"); 405 406 MIRSampleLoader->setFSPass(P); 407 return MIRSampleLoader->doInitialization(M); 408 } 409 410 void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const { 411 AU.setPreservesAll(); 412 AU.addRequired<MachineBlockFrequencyInfoWrapperPass>(); 413 AU.addRequired<MachineDominatorTreeWrapperPass>(); 414 AU.addRequired<MachinePostDominatorTreeWrapperPass>(); 415 AU.addRequiredTransitive<MachineLoopInfoWrapperPass>(); 416 AU.addRequired<MachineOptimizationRemarkEmitterPass>(); 417 MachineFunctionPass::getAnalysisUsage(AU); 418 } 419