1 //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file provides the implementation of the MIRSampleProfile loader, mainly 10 // for flow sensitive SampleFDO. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/CodeGen/MIRSampleProfile.h" 15 #include "llvm/ADT/DenseMap.h" 16 #include "llvm/ADT/DenseSet.h" 17 #include "llvm/Analysis/BlockFrequencyInfoImpl.h" 18 #include "llvm/IR/Function.h" 19 #include "llvm/Support/CommandLine.h" 20 #include "llvm/Support/Debug.h" 21 #include "llvm/Support/raw_ostream.h" 22 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h" 23 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h" 24 25 using namespace llvm; 26 using namespace sampleprof; 27 using namespace llvm::sampleprofutil; 28 using ProfileCount = Function::ProfileCount; 29 30 #define DEBUG_TYPE "fs-profile-loader" 31 32 static cl::opt<bool> ShowFSBranchProb( 33 "show-fs-branchprob", cl::Hidden, cl::init(false), 34 cl::desc("Print setting flow sensitive branch probabilities")); 35 static cl::opt<unsigned> FSProfileDebugProbDiffThreshold( 36 "fs-profile-debug-prob-diff-threshold", cl::init(10), 37 cl::desc("Only show debug message if the branch probility is greater than " 38 "this value (in percentage).")); 39 40 static cl::opt<unsigned> FSProfileDebugBWThreshold( 41 "fs-profile-debug-bw-threshold", cl::init(10000), 42 cl::desc("Only show debug message if the source branch weight is greater " 43 " than this value.")); 44 45 static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden, 46 cl::init(false), 47 cl::desc("View BFI before MIR loader")); 48 static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden, 49 cl::init(false), 50 cl::desc("View BFI after MIR loader")); 51 52 char MIRProfileLoaderPass::ID = 0; 53 54 INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE, 55 "Load MIR Sample Profile", 56 /* cfg = */ false, /* is_analysis = */ false) 57 INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo) 58 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree) 59 INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree) 60 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo) 61 INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass) 62 INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile", 63 /* cfg = */ false, /* is_analysis = */ false) 64 65 char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID; 66 67 FunctionPass *llvm::createMIRProfileLoaderPass(std::string File, 68 std::string RemappingFile, 69 FSDiscriminatorPass P) { 70 return new MIRProfileLoaderPass(File, RemappingFile, P); 71 } 72 73 namespace llvm { 74 75 // Internal option used to control BFI display only after MBP pass. 76 // Defined in CodeGen/MachineBlockFrequencyInfo.cpp: 77 // -view-block-layout-with-bfi={none | fraction | integer | count} 78 extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI; 79 80 // Command line option to specify the name of the function for CFG dump 81 // Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name= 82 extern cl::opt<std::string> ViewBlockFreqFuncName; 83 84 namespace afdo_detail { 85 template <> struct IRTraits<MachineBasicBlock> { 86 using InstructionT = MachineInstr; 87 using BasicBlockT = MachineBasicBlock; 88 using FunctionT = MachineFunction; 89 using BlockFrequencyInfoT = MachineBlockFrequencyInfo; 90 using LoopT = MachineLoop; 91 using LoopInfoPtrT = MachineLoopInfo *; 92 using DominatorTreePtrT = MachineDominatorTree *; 93 using PostDominatorTreePtrT = MachinePostDominatorTree *; 94 using PostDominatorTreeT = MachinePostDominatorTree; 95 using OptRemarkEmitterT = MachineOptimizationRemarkEmitter; 96 using OptRemarkAnalysisT = MachineOptimizationRemarkAnalysis; 97 using PredRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>; 98 using SuccRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>; 99 static Function &getFunction(MachineFunction &F) { return F.getFunction(); } 100 static const MachineBasicBlock *getEntryBB(const MachineFunction *F) { 101 return GraphTraits<const MachineFunction *>::getEntryNode(F); 102 } 103 static PredRangeT getPredecessors(MachineBasicBlock *BB) { 104 return BB->predecessors(); 105 } 106 static SuccRangeT getSuccessors(MachineBasicBlock *BB) { 107 return BB->successors(); 108 } 109 }; 110 } // namespace afdo_detail 111 112 class MIRProfileLoader final 113 : public SampleProfileLoaderBaseImpl<MachineBasicBlock> { 114 public: 115 void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT, 116 MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI, 117 MachineOptimizationRemarkEmitter *MORE) { 118 DT = MDT; 119 PDT = MPDT; 120 LI = MLI; 121 BFI = MBFI; 122 ORE = MORE; 123 } 124 void setFSPass(FSDiscriminatorPass Pass) { 125 P = Pass; 126 LowBit = getFSPassBitBegin(P); 127 HighBit = getFSPassBitEnd(P); 128 assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit"); 129 } 130 131 MIRProfileLoader(StringRef Name, StringRef RemapName) 132 : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName)) { 133 } 134 135 void setBranchProbs(MachineFunction &F); 136 bool runOnFunction(MachineFunction &F); 137 bool doInitialization(Module &M); 138 bool isValid() const { return ProfileIsValid; } 139 140 protected: 141 friend class SampleCoverageTracker; 142 143 /// Hold the information of the basic block frequency. 144 MachineBlockFrequencyInfo *BFI; 145 146 /// PassNum is the sequence number this pass is called, start from 1. 147 FSDiscriminatorPass P; 148 149 // LowBit in the FS discriminator used by this instance. Note the number is 150 // 0-based. Base discrimnator use bit 0 to bit 11. 151 unsigned LowBit; 152 // HighwBit in the FS discriminator used by this instance. Note the number 153 // is 0-based. 154 unsigned HighBit; 155 156 bool ProfileIsValid = true; 157 }; 158 159 template <> 160 void SampleProfileLoaderBaseImpl< 161 MachineBasicBlock>::computeDominanceAndLoopInfo(MachineFunction &F) {} 162 163 void MIRProfileLoader::setBranchProbs(MachineFunction &F) { 164 LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n"); 165 for (auto &BI : F) { 166 MachineBasicBlock *BB = &BI; 167 if (BB->succ_size() < 2) 168 continue; 169 const MachineBasicBlock *EC = EquivalenceClass[BB]; 170 uint64_t BBWeight = BlockWeights[EC]; 171 uint64_t SumEdgeWeight = 0; 172 for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(), 173 SE = BB->succ_end(); 174 SI != SE; ++SI) { 175 MachineBasicBlock *Succ = *SI; 176 Edge E = std::make_pair(BB, Succ); 177 SumEdgeWeight += EdgeWeights[E]; 178 } 179 180 if (BBWeight != SumEdgeWeight) { 181 LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight=" 182 << BBWeight << " SumEdgeWeight= " << SumEdgeWeight 183 << "\n"); 184 BBWeight = SumEdgeWeight; 185 } 186 if (BBWeight == 0) { 187 LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n"); 188 continue; 189 } 190 191 #ifndef NDEBUG 192 uint64_t BBWeightOrig = BBWeight; 193 #endif 194 uint32_t MaxWeight = std::numeric_limits<uint32_t>::max(); 195 uint32_t Factor = 1; 196 if (BBWeight > MaxWeight) { 197 Factor = BBWeight / MaxWeight + 1; 198 BBWeight /= Factor; 199 LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n"); 200 } 201 202 for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(), 203 SE = BB->succ_end(); 204 SI != SE; ++SI) { 205 MachineBasicBlock *Succ = *SI; 206 Edge E = std::make_pair(BB, Succ); 207 uint64_t EdgeWeight = EdgeWeights[E]; 208 EdgeWeight /= Factor; 209 210 assert(BBWeight >= EdgeWeight && 211 "BBweight is larger than EdgeWeight -- should not happen.\n"); 212 213 BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(BB, SI); 214 BranchProbability NewProb(EdgeWeight, BBWeight); 215 if (OldProb == NewProb) 216 continue; 217 BB->setSuccProbability(SI, NewProb); 218 #ifndef NDEBUG 219 if (!ShowFSBranchProb) 220 continue; 221 bool Show = false; 222 BranchProbability Diff; 223 if (OldProb > NewProb) 224 Diff = OldProb - NewProb; 225 else 226 Diff = NewProb - OldProb; 227 Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100)); 228 Show &= (BBWeightOrig >= FSProfileDebugBWThreshold); 229 230 auto DIL = BB->findBranchDebugLoc(); 231 auto SuccDIL = Succ->findBranchDebugLoc(); 232 if (Show) { 233 dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> " 234 << Succ->getNumber() << "): "; 235 if (DIL) 236 dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":" 237 << DIL->getColumn(); 238 if (SuccDIL) 239 dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine() 240 << ":" << SuccDIL->getColumn(); 241 dbgs() << " W=" << BBWeightOrig << " " << OldProb << " --> " << NewProb 242 << "\n"; 243 } 244 #endif 245 } 246 } 247 } 248 249 bool MIRProfileLoader::doInitialization(Module &M) { 250 auto &Ctx = M.getContext(); 251 252 auto ReaderOrErr = sampleprof::SampleProfileReader::create(Filename, Ctx, P, 253 RemappingFilename); 254 if (std::error_code EC = ReaderOrErr.getError()) { 255 std::string Msg = "Could not open profile: " + EC.message(); 256 Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg)); 257 return false; 258 } 259 260 Reader = std::move(ReaderOrErr.get()); 261 Reader->setModule(&M); 262 ProfileIsValid = (Reader->read() == sampleprof_error::success); 263 Reader->getSummary(); 264 265 return true; 266 } 267 268 bool MIRProfileLoader::runOnFunction(MachineFunction &MF) { 269 Function &Func = MF.getFunction(); 270 clearFunctionData(false); 271 Samples = Reader->getSamplesFor(Func); 272 if (!Samples || Samples->empty()) 273 return false; 274 275 if (getFunctionLoc(MF) == 0) 276 return false; 277 278 DenseSet<GlobalValue::GUID> InlinedGUIDs; 279 bool Changed = computeAndPropagateWeights(MF, InlinedGUIDs); 280 281 // Set the new BPI, BFI. 282 setBranchProbs(MF); 283 284 return Changed; 285 } 286 287 } // namespace llvm 288 289 MIRProfileLoaderPass::MIRProfileLoaderPass(std::string FileName, 290 std::string RemappingFileName, 291 FSDiscriminatorPass P) 292 : MachineFunctionPass(ID), ProfileFileName(FileName), P(P), 293 MIRSampleLoader( 294 std::make_unique<MIRProfileLoader>(FileName, RemappingFileName)) { 295 LowBit = getFSPassBitBegin(P); 296 HighBit = getFSPassBitEnd(P); 297 assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit"); 298 } 299 300 bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) { 301 if (!MIRSampleLoader->isValid()) 302 return false; 303 304 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: " 305 << MF.getFunction().getName() << "\n"); 306 MBFI = &getAnalysis<MachineBlockFrequencyInfo>(); 307 MIRSampleLoader->setInitVals( 308 &getAnalysis<MachineDominatorTree>(), 309 &getAnalysis<MachinePostDominatorTree>(), &getAnalysis<MachineLoopInfo>(), 310 MBFI, &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE()); 311 312 MF.RenumberBlocks(); 313 if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None && 314 (ViewBlockFreqFuncName.empty() || 315 MF.getFunction().getName().equals(ViewBlockFreqFuncName))) { 316 MBFI->view("MIR_Prof_loader_b." + MF.getName(), false); 317 } 318 319 bool Changed = MIRSampleLoader->runOnFunction(MF); 320 321 if (ViewBFIAfter && ViewBlockLayoutWithBFI != GVDT_None && 322 (ViewBlockFreqFuncName.empty() || 323 MF.getFunction().getName().equals(ViewBlockFreqFuncName))) { 324 MBFI->view("MIR_prof_loader_a." + MF.getName(), false); 325 } 326 327 return Changed; 328 } 329 330 bool MIRProfileLoaderPass::doInitialization(Module &M) { 331 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName() 332 << "\n"); 333 334 MIRSampleLoader->setFSPass(P); 335 return MIRSampleLoader->doInitialization(M); 336 } 337 338 void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const { 339 AU.setPreservesAll(); 340 AU.addRequired<MachineBlockFrequencyInfo>(); 341 AU.addRequired<MachineDominatorTree>(); 342 AU.addRequired<MachinePostDominatorTree>(); 343 AU.addRequiredTransitive<MachineLoopInfo>(); 344 AU.addRequired<MachineOptimizationRemarkEmitterPass>(); 345 MachineFunctionPass::getAnalysisUsage(AU); 346 } 347