1 //===-- CFGMST.h - Minimum Spanning Tree for CFG ----------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a Union-find algorithm to compute Minimum Spanning Tree 10 // for a given CFG. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H 15 #define LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H 16 17 #include "llvm/ADT/DenseMap.h" 18 #include "llvm/ADT/STLExtras.h" 19 #include "llvm/Analysis/BlockFrequencyInfo.h" 20 #include "llvm/Analysis/BranchProbabilityInfo.h" 21 #include "llvm/Analysis/CFG.h" 22 #include "llvm/Support/BranchProbability.h" 23 #include "llvm/Support/CommandLine.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/raw_ostream.h" 26 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 27 #include <utility> 28 #include <vector> 29 30 #define DEBUG_TYPE "cfgmst" 31 32 using namespace llvm; 33 static cl::opt<bool> PGOInstrumentEntry( 34 "pgo-instrument-entry", cl::init(false), cl::Hidden, 35 cl::desc("Force to instrument function entry basicblock.")); 36 37 namespace llvm { 38 39 /// An union-find based Minimum Spanning Tree for CFG 40 /// 41 /// Implements a Union-find algorithm to compute Minimum Spanning Tree 42 /// for a given CFG. 43 template <class Edge, class BBInfo> class CFGMST { 44 public: 45 Function &F; 46 47 // Store all the edges in CFG. It may contain some stale edges 48 // when Removed is set. 49 std::vector<std::unique_ptr<Edge>> AllEdges; 50 51 // This map records the auxiliary information for each BB. 52 DenseMap<const BasicBlock *, std::unique_ptr<BBInfo>> BBInfos; 53 54 // Whehter the function has an exit block with no successors. 55 // (For function with an infinite loop, this block may be absent) 56 bool ExitBlockFound = false; 57 58 // Find the root group of the G and compress the path from G to the root. 59 BBInfo *findAndCompressGroup(BBInfo *G) { 60 if (G->Group != G) 61 G->Group = findAndCompressGroup(static_cast<BBInfo *>(G->Group)); 62 return static_cast<BBInfo *>(G->Group); 63 } 64 65 // Union BB1 and BB2 into the same group and return true. 66 // Returns false if BB1 and BB2 are already in the same group. 67 bool unionGroups(const BasicBlock *BB1, const BasicBlock *BB2) { 68 BBInfo *BB1G = findAndCompressGroup(&getBBInfo(BB1)); 69 BBInfo *BB2G = findAndCompressGroup(&getBBInfo(BB2)); 70 71 if (BB1G == BB2G) 72 return false; 73 74 // Make the smaller rank tree a direct child or the root of high rank tree. 75 if (BB1G->Rank < BB2G->Rank) 76 BB1G->Group = BB2G; 77 else { 78 BB2G->Group = BB1G; 79 // If the ranks are the same, increment root of one tree by one. 80 if (BB1G->Rank == BB2G->Rank) 81 BB1G->Rank++; 82 } 83 return true; 84 } 85 86 // Give BB, return the auxiliary information. 87 BBInfo &getBBInfo(const BasicBlock *BB) const { 88 auto It = BBInfos.find(BB); 89 assert(It->second.get() != nullptr); 90 return *It->second.get(); 91 } 92 93 // Give BB, return the auxiliary information if it's available. 94 BBInfo *findBBInfo(const BasicBlock *BB) const { 95 auto It = BBInfos.find(BB); 96 if (It == BBInfos.end()) 97 return nullptr; 98 return It->second.get(); 99 } 100 101 // Traverse the CFG using a stack. Find all the edges and assign the weight. 102 // Edges with large weight will be put into MST first so they are less likely 103 // to be instrumented. 104 void buildEdges() { 105 LLVM_DEBUG(dbgs() << "Build Edge on " << F.getName() << "\n"); 106 107 const BasicBlock *Entry = &(F.getEntryBlock()); 108 uint64_t EntryWeight = (BFI != nullptr ? BFI->getEntryFreq() : 2); 109 // If we want to instrument the entry count, lower the weight to 0. 110 if (PGOInstrumentEntry) 111 EntryWeight = 0; 112 Edge *EntryIncoming = nullptr, *EntryOutgoing = nullptr, 113 *ExitOutgoing = nullptr, *ExitIncoming = nullptr; 114 uint64_t MaxEntryOutWeight = 0, MaxExitOutWeight = 0, MaxExitInWeight = 0; 115 116 // Add a fake edge to the entry. 117 EntryIncoming = &addEdge(nullptr, Entry, EntryWeight); 118 LLVM_DEBUG(dbgs() << " Edge: from fake node to " << Entry->getName() 119 << " w = " << EntryWeight << "\n"); 120 121 // Special handling for single BB functions. 122 if (succ_empty(Entry)) { 123 addEdge(Entry, nullptr, EntryWeight); 124 return; 125 } 126 127 static const uint32_t CriticalEdgeMultiplier = 1000; 128 129 for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { 130 Instruction *TI = BB->getTerminator(); 131 uint64_t BBWeight = 132 (BFI != nullptr ? BFI->getBlockFreq(&*BB).getFrequency() : 2); 133 uint64_t Weight = 2; 134 if (int successors = TI->getNumSuccessors()) { 135 for (int i = 0; i != successors; ++i) { 136 BasicBlock *TargetBB = TI->getSuccessor(i); 137 bool Critical = isCriticalEdge(TI, i); 138 uint64_t scaleFactor = BBWeight; 139 if (Critical) { 140 if (scaleFactor < UINT64_MAX / CriticalEdgeMultiplier) 141 scaleFactor *= CriticalEdgeMultiplier; 142 else 143 scaleFactor = UINT64_MAX; 144 } 145 if (BPI != nullptr) 146 Weight = BPI->getEdgeProbability(&*BB, TargetBB).scale(scaleFactor); 147 if (Weight == 0) 148 Weight++; 149 auto *E = &addEdge(&*BB, TargetBB, Weight); 150 E->IsCritical = Critical; 151 LLVM_DEBUG(dbgs() << " Edge: from " << BB->getName() << " to " 152 << TargetBB->getName() << " w=" << Weight << "\n"); 153 154 // Keep track of entry/exit edges: 155 if (&*BB == Entry) { 156 if (Weight > MaxEntryOutWeight) { 157 MaxEntryOutWeight = Weight; 158 EntryOutgoing = E; 159 } 160 } 161 162 auto *TargetTI = TargetBB->getTerminator(); 163 if (TargetTI && !TargetTI->getNumSuccessors()) { 164 if (Weight > MaxExitInWeight) { 165 MaxExitInWeight = Weight; 166 ExitIncoming = E; 167 } 168 } 169 } 170 } else { 171 ExitBlockFound = true; 172 Edge *ExitO = &addEdge(&*BB, nullptr, BBWeight); 173 if (BBWeight > MaxExitOutWeight) { 174 MaxExitOutWeight = BBWeight; 175 ExitOutgoing = ExitO; 176 } 177 LLVM_DEBUG(dbgs() << " Edge: from " << BB->getName() << " to fake exit" 178 << " w = " << BBWeight << "\n"); 179 } 180 } 181 182 // Entry/exit edge adjustment heurisitic: 183 // prefer instrumenting entry edge over exit edge 184 // if possible. Those exit edges may never have a chance to be 185 // executed (for instance the program is an event handling loop) 186 // before the profile is asynchronously dumped. 187 // 188 // If EntryIncoming and ExitOutgoing has similar weight, make sure 189 // ExitOutging is selected as the min-edge. Similarly, if EntryOutgoing 190 // and ExitIncoming has similar weight, make sure ExitIncoming becomes 191 // the min-edge. 192 uint64_t EntryInWeight = EntryWeight; 193 194 if (EntryInWeight >= MaxExitOutWeight && 195 EntryInWeight * 2 < MaxExitOutWeight * 3) { 196 EntryIncoming->Weight = MaxExitOutWeight; 197 ExitOutgoing->Weight = EntryInWeight + 1; 198 } 199 200 if (MaxEntryOutWeight >= MaxExitInWeight && 201 MaxEntryOutWeight * 2 < MaxExitInWeight * 3) { 202 EntryOutgoing->Weight = MaxExitInWeight; 203 ExitIncoming->Weight = MaxEntryOutWeight + 1; 204 } 205 } 206 207 // Sort CFG edges based on its weight. 208 void sortEdgesByWeight() { 209 llvm::stable_sort(AllEdges, [](const std::unique_ptr<Edge> &Edge1, 210 const std::unique_ptr<Edge> &Edge2) { 211 return Edge1->Weight > Edge2->Weight; 212 }); 213 } 214 215 // Traverse all the edges and compute the Minimum Weight Spanning Tree 216 // using union-find algorithm. 217 void computeMinimumSpanningTree() { 218 // First, put all the critical edge with landing-pad as the Dest to MST. 219 // This works around the insufficient support of critical edges split 220 // when destination BB is a landing pad. 221 for (auto &Ei : AllEdges) { 222 if (Ei->Removed) 223 continue; 224 if (Ei->IsCritical) { 225 if (Ei->DestBB && Ei->DestBB->isLandingPad()) { 226 if (unionGroups(Ei->SrcBB, Ei->DestBB)) 227 Ei->InMST = true; 228 } 229 } 230 } 231 232 for (auto &Ei : AllEdges) { 233 if (Ei->Removed) 234 continue; 235 // If we detect infinite loops, force 236 // instrumenting the entry edge: 237 if (!ExitBlockFound && Ei->SrcBB == nullptr) 238 continue; 239 if (unionGroups(Ei->SrcBB, Ei->DestBB)) 240 Ei->InMST = true; 241 } 242 } 243 244 // Dump the Debug information about the instrumentation. 245 void dumpEdges(raw_ostream &OS, const Twine &Message) const { 246 if (!Message.str().empty()) 247 OS << Message << "\n"; 248 OS << " Number of Basic Blocks: " << BBInfos.size() << "\n"; 249 for (auto &BI : BBInfos) { 250 const BasicBlock *BB = BI.first; 251 OS << " BB: " << (BB == nullptr ? "FakeNode" : BB->getName()) << " " 252 << BI.second->infoString() << "\n"; 253 } 254 255 OS << " Number of Edges: " << AllEdges.size() 256 << " (*: Instrument, C: CriticalEdge, -: Removed)\n"; 257 uint32_t Count = 0; 258 for (auto &EI : AllEdges) 259 OS << " Edge " << Count++ << ": " << getBBInfo(EI->SrcBB).Index << "-->" 260 << getBBInfo(EI->DestBB).Index << EI->infoString() << "\n"; 261 } 262 263 // Add an edge to AllEdges with weight W. 264 Edge &addEdge(const BasicBlock *Src, const BasicBlock *Dest, uint64_t W) { 265 uint32_t Index = BBInfos.size(); 266 auto Iter = BBInfos.end(); 267 bool Inserted; 268 std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Src, nullptr)); 269 if (Inserted) { 270 // Newly inserted, update the real info. 271 Iter->second = std::move(std::make_unique<BBInfo>(Index)); 272 Index++; 273 } 274 std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Dest, nullptr)); 275 if (Inserted) 276 // Newly inserted, update the real info. 277 Iter->second = std::move(std::make_unique<BBInfo>(Index)); 278 AllEdges.emplace_back(new Edge(Src, Dest, W)); 279 return *AllEdges.back(); 280 } 281 282 BranchProbabilityInfo *BPI; 283 BlockFrequencyInfo *BFI; 284 285 public: 286 CFGMST(Function &Func, BranchProbabilityInfo *BPI_ = nullptr, 287 BlockFrequencyInfo *BFI_ = nullptr) 288 : F(Func), BPI(BPI_), BFI(BFI_) { 289 buildEdges(); 290 sortEdgesByWeight(); 291 computeMinimumSpanningTree(); 292 if (PGOInstrumentEntry && (AllEdges.size() > 1)) 293 std::iter_swap(std::move(AllEdges.begin()), 294 std::move(AllEdges.begin() + AllEdges.size() - 1)); 295 } 296 }; 297 298 } // end namespace llvm 299 300 #undef DEBUG_TYPE // "cfgmst" 301 302 #endif // LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H 303