109467b48Spatrick //===-- CFGMST.h - Minimum Spanning Tree for CFG ----------------*- C++ -*-===// 209467b48Spatrick // 309467b48Spatrick // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 409467b48Spatrick // See https://llvm.org/LICENSE.txt for license information. 509467b48Spatrick // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 609467b48Spatrick // 709467b48Spatrick //===----------------------------------------------------------------------===// 809467b48Spatrick // 909467b48Spatrick // This file implements a Union-find algorithm to compute Minimum Spanning Tree 1009467b48Spatrick // for a given CFG. 1109467b48Spatrick // 1209467b48Spatrick //===----------------------------------------------------------------------===// 1309467b48Spatrick 1409467b48Spatrick #ifndef LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H 1509467b48Spatrick #define LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H 1609467b48Spatrick 1709467b48Spatrick #include "llvm/ADT/DenseMap.h" 1809467b48Spatrick #include "llvm/ADT/STLExtras.h" 1909467b48Spatrick #include "llvm/Analysis/BlockFrequencyInfo.h" 2009467b48Spatrick #include "llvm/Analysis/BranchProbabilityInfo.h" 2109467b48Spatrick #include "llvm/Analysis/CFG.h" 2209467b48Spatrick #include "llvm/Support/BranchProbability.h" 2309467b48Spatrick #include "llvm/Support/Debug.h" 2409467b48Spatrick #include "llvm/Support/raw_ostream.h" 2509467b48Spatrick #include "llvm/Transforms/Utils/BasicBlockUtils.h" 2609467b48Spatrick #include <utility> 2709467b48Spatrick #include <vector> 2809467b48Spatrick 2909467b48Spatrick #define DEBUG_TYPE "cfgmst" 3009467b48Spatrick 31097a140dSpatrick using namespace llvm; 32097a140dSpatrick 3309467b48Spatrick namespace llvm { 3409467b48Spatrick 3509467b48Spatrick /// An union-find based Minimum Spanning Tree for CFG 3609467b48Spatrick /// 3709467b48Spatrick /// Implements a Union-find algorithm to compute Minimum Spanning Tree 3809467b48Spatrick /// for a given CFG. 3909467b48Spatrick template <class Edge, class BBInfo> class CFGMST { 4009467b48Spatrick public: 4109467b48Spatrick Function &F; 4209467b48Spatrick 4309467b48Spatrick // Store all the edges in CFG. It may contain some stale edges 4409467b48Spatrick // when Removed is set. 4509467b48Spatrick std::vector<std::unique_ptr<Edge>> AllEdges; 4609467b48Spatrick 4709467b48Spatrick // This map records the auxiliary information for each BB. 4809467b48Spatrick DenseMap<const BasicBlock *, std::unique_ptr<BBInfo>> BBInfos; 4909467b48Spatrick 5009467b48Spatrick // Whehter the function has an exit block with no successors. 5109467b48Spatrick // (For function with an infinite loop, this block may be absent) 5209467b48Spatrick bool ExitBlockFound = false; 5309467b48Spatrick 5409467b48Spatrick // Find the root group of the G and compress the path from G to the root. findAndCompressGroup(BBInfo * G)5509467b48Spatrick BBInfo *findAndCompressGroup(BBInfo *G) { 5609467b48Spatrick if (G->Group != G) 5709467b48Spatrick G->Group = findAndCompressGroup(static_cast<BBInfo *>(G->Group)); 5809467b48Spatrick return static_cast<BBInfo *>(G->Group); 5909467b48Spatrick } 6009467b48Spatrick 6109467b48Spatrick // Union BB1 and BB2 into the same group and return true. 6209467b48Spatrick // Returns false if BB1 and BB2 are already in the same group. unionGroups(const BasicBlock * BB1,const BasicBlock * BB2)6309467b48Spatrick bool unionGroups(const BasicBlock *BB1, const BasicBlock *BB2) { 6409467b48Spatrick BBInfo *BB1G = findAndCompressGroup(&getBBInfo(BB1)); 6509467b48Spatrick BBInfo *BB2G = findAndCompressGroup(&getBBInfo(BB2)); 6609467b48Spatrick 6709467b48Spatrick if (BB1G == BB2G) 6809467b48Spatrick return false; 6909467b48Spatrick 7009467b48Spatrick // Make the smaller rank tree a direct child or the root of high rank tree. 7109467b48Spatrick if (BB1G->Rank < BB2G->Rank) 7209467b48Spatrick BB1G->Group = BB2G; 7309467b48Spatrick else { 7409467b48Spatrick BB2G->Group = BB1G; 7509467b48Spatrick // If the ranks are the same, increment root of one tree by one. 7609467b48Spatrick if (BB1G->Rank == BB2G->Rank) 7709467b48Spatrick BB1G->Rank++; 7809467b48Spatrick } 7909467b48Spatrick return true; 8009467b48Spatrick } 8109467b48Spatrick 8209467b48Spatrick // Give BB, return the auxiliary information. getBBInfo(const BasicBlock * BB)8309467b48Spatrick BBInfo &getBBInfo(const BasicBlock *BB) const { 8409467b48Spatrick auto It = BBInfos.find(BB); 8509467b48Spatrick assert(It->second.get() != nullptr); 8609467b48Spatrick return *It->second.get(); 8709467b48Spatrick } 8809467b48Spatrick 8909467b48Spatrick // Give BB, return the auxiliary information if it's available. findBBInfo(const BasicBlock * BB)9009467b48Spatrick BBInfo *findBBInfo(const BasicBlock *BB) const { 9109467b48Spatrick auto It = BBInfos.find(BB); 9209467b48Spatrick if (It == BBInfos.end()) 9309467b48Spatrick return nullptr; 9409467b48Spatrick return It->second.get(); 9509467b48Spatrick } 9609467b48Spatrick 9709467b48Spatrick // Traverse the CFG using a stack. Find all the edges and assign the weight. 9809467b48Spatrick // Edges with large weight will be put into MST first so they are less likely 9909467b48Spatrick // to be instrumented. buildEdges()10009467b48Spatrick void buildEdges() { 10109467b48Spatrick LLVM_DEBUG(dbgs() << "Build Edge on " << F.getName() << "\n"); 10209467b48Spatrick 10309467b48Spatrick const BasicBlock *Entry = &(F.getEntryBlock()); 10409467b48Spatrick uint64_t EntryWeight = (BFI != nullptr ? BFI->getEntryFreq() : 2); 105097a140dSpatrick // If we want to instrument the entry count, lower the weight to 0. 106*73471bf0Spatrick if (InstrumentFuncEntry) 107097a140dSpatrick EntryWeight = 0; 10809467b48Spatrick Edge *EntryIncoming = nullptr, *EntryOutgoing = nullptr, 10909467b48Spatrick *ExitOutgoing = nullptr, *ExitIncoming = nullptr; 11009467b48Spatrick uint64_t MaxEntryOutWeight = 0, MaxExitOutWeight = 0, MaxExitInWeight = 0; 11109467b48Spatrick 11209467b48Spatrick // Add a fake edge to the entry. 11309467b48Spatrick EntryIncoming = &addEdge(nullptr, Entry, EntryWeight); 11409467b48Spatrick LLVM_DEBUG(dbgs() << " Edge: from fake node to " << Entry->getName() 11509467b48Spatrick << " w = " << EntryWeight << "\n"); 11609467b48Spatrick 11709467b48Spatrick // Special handling for single BB functions. 11809467b48Spatrick if (succ_empty(Entry)) { 11909467b48Spatrick addEdge(Entry, nullptr, EntryWeight); 12009467b48Spatrick return; 12109467b48Spatrick } 12209467b48Spatrick 12309467b48Spatrick static const uint32_t CriticalEdgeMultiplier = 1000; 12409467b48Spatrick 125*73471bf0Spatrick for (BasicBlock &BB : F) { 126*73471bf0Spatrick Instruction *TI = BB.getTerminator(); 12709467b48Spatrick uint64_t BBWeight = 128*73471bf0Spatrick (BFI != nullptr ? BFI->getBlockFreq(&BB).getFrequency() : 2); 12909467b48Spatrick uint64_t Weight = 2; 13009467b48Spatrick if (int successors = TI->getNumSuccessors()) { 13109467b48Spatrick for (int i = 0; i != successors; ++i) { 13209467b48Spatrick BasicBlock *TargetBB = TI->getSuccessor(i); 13309467b48Spatrick bool Critical = isCriticalEdge(TI, i); 13409467b48Spatrick uint64_t scaleFactor = BBWeight; 13509467b48Spatrick if (Critical) { 13609467b48Spatrick if (scaleFactor < UINT64_MAX / CriticalEdgeMultiplier) 13709467b48Spatrick scaleFactor *= CriticalEdgeMultiplier; 13809467b48Spatrick else 13909467b48Spatrick scaleFactor = UINT64_MAX; 14009467b48Spatrick } 14109467b48Spatrick if (BPI != nullptr) 142*73471bf0Spatrick Weight = BPI->getEdgeProbability(&BB, TargetBB).scale(scaleFactor); 143097a140dSpatrick if (Weight == 0) 144097a140dSpatrick Weight++; 145*73471bf0Spatrick auto *E = &addEdge(&BB, TargetBB, Weight); 14609467b48Spatrick E->IsCritical = Critical; 147*73471bf0Spatrick LLVM_DEBUG(dbgs() << " Edge: from " << BB.getName() << " to " 14809467b48Spatrick << TargetBB->getName() << " w=" << Weight << "\n"); 14909467b48Spatrick 15009467b48Spatrick // Keep track of entry/exit edges: 151*73471bf0Spatrick if (&BB == Entry) { 15209467b48Spatrick if (Weight > MaxEntryOutWeight) { 15309467b48Spatrick MaxEntryOutWeight = Weight; 15409467b48Spatrick EntryOutgoing = E; 15509467b48Spatrick } 15609467b48Spatrick } 15709467b48Spatrick 15809467b48Spatrick auto *TargetTI = TargetBB->getTerminator(); 15909467b48Spatrick if (TargetTI && !TargetTI->getNumSuccessors()) { 16009467b48Spatrick if (Weight > MaxExitInWeight) { 16109467b48Spatrick MaxExitInWeight = Weight; 16209467b48Spatrick ExitIncoming = E; 16309467b48Spatrick } 16409467b48Spatrick } 16509467b48Spatrick } 16609467b48Spatrick } else { 16709467b48Spatrick ExitBlockFound = true; 168*73471bf0Spatrick Edge *ExitO = &addEdge(&BB, nullptr, BBWeight); 16909467b48Spatrick if (BBWeight > MaxExitOutWeight) { 17009467b48Spatrick MaxExitOutWeight = BBWeight; 17109467b48Spatrick ExitOutgoing = ExitO; 17209467b48Spatrick } 173*73471bf0Spatrick LLVM_DEBUG(dbgs() << " Edge: from " << BB.getName() << " to fake exit" 17409467b48Spatrick << " w = " << BBWeight << "\n"); 17509467b48Spatrick } 17609467b48Spatrick } 17709467b48Spatrick 17809467b48Spatrick // Entry/exit edge adjustment heurisitic: 17909467b48Spatrick // prefer instrumenting entry edge over exit edge 18009467b48Spatrick // if possible. Those exit edges may never have a chance to be 18109467b48Spatrick // executed (for instance the program is an event handling loop) 18209467b48Spatrick // before the profile is asynchronously dumped. 18309467b48Spatrick // 18409467b48Spatrick // If EntryIncoming and ExitOutgoing has similar weight, make sure 18509467b48Spatrick // ExitOutging is selected as the min-edge. Similarly, if EntryOutgoing 18609467b48Spatrick // and ExitIncoming has similar weight, make sure ExitIncoming becomes 18709467b48Spatrick // the min-edge. 18809467b48Spatrick uint64_t EntryInWeight = EntryWeight; 18909467b48Spatrick 19009467b48Spatrick if (EntryInWeight >= MaxExitOutWeight && 19109467b48Spatrick EntryInWeight * 2 < MaxExitOutWeight * 3) { 19209467b48Spatrick EntryIncoming->Weight = MaxExitOutWeight; 19309467b48Spatrick ExitOutgoing->Weight = EntryInWeight + 1; 19409467b48Spatrick } 19509467b48Spatrick 19609467b48Spatrick if (MaxEntryOutWeight >= MaxExitInWeight && 19709467b48Spatrick MaxEntryOutWeight * 2 < MaxExitInWeight * 3) { 19809467b48Spatrick EntryOutgoing->Weight = MaxExitInWeight; 19909467b48Spatrick ExitIncoming->Weight = MaxEntryOutWeight + 1; 20009467b48Spatrick } 20109467b48Spatrick } 20209467b48Spatrick 20309467b48Spatrick // Sort CFG edges based on its weight. sortEdgesByWeight()20409467b48Spatrick void sortEdgesByWeight() { 20509467b48Spatrick llvm::stable_sort(AllEdges, [](const std::unique_ptr<Edge> &Edge1, 20609467b48Spatrick const std::unique_ptr<Edge> &Edge2) { 20709467b48Spatrick return Edge1->Weight > Edge2->Weight; 20809467b48Spatrick }); 20909467b48Spatrick } 21009467b48Spatrick 21109467b48Spatrick // Traverse all the edges and compute the Minimum Weight Spanning Tree 21209467b48Spatrick // using union-find algorithm. computeMinimumSpanningTree()21309467b48Spatrick void computeMinimumSpanningTree() { 21409467b48Spatrick // First, put all the critical edge with landing-pad as the Dest to MST. 21509467b48Spatrick // This works around the insufficient support of critical edges split 21609467b48Spatrick // when destination BB is a landing pad. 21709467b48Spatrick for (auto &Ei : AllEdges) { 21809467b48Spatrick if (Ei->Removed) 21909467b48Spatrick continue; 22009467b48Spatrick if (Ei->IsCritical) { 22109467b48Spatrick if (Ei->DestBB && Ei->DestBB->isLandingPad()) { 22209467b48Spatrick if (unionGroups(Ei->SrcBB, Ei->DestBB)) 22309467b48Spatrick Ei->InMST = true; 22409467b48Spatrick } 22509467b48Spatrick } 22609467b48Spatrick } 22709467b48Spatrick 22809467b48Spatrick for (auto &Ei : AllEdges) { 22909467b48Spatrick if (Ei->Removed) 23009467b48Spatrick continue; 23109467b48Spatrick // If we detect infinite loops, force 23209467b48Spatrick // instrumenting the entry edge: 23309467b48Spatrick if (!ExitBlockFound && Ei->SrcBB == nullptr) 23409467b48Spatrick continue; 23509467b48Spatrick if (unionGroups(Ei->SrcBB, Ei->DestBB)) 23609467b48Spatrick Ei->InMST = true; 23709467b48Spatrick } 23809467b48Spatrick } 23909467b48Spatrick 24009467b48Spatrick // Dump the Debug information about the instrumentation. dumpEdges(raw_ostream & OS,const Twine & Message)24109467b48Spatrick void dumpEdges(raw_ostream &OS, const Twine &Message) const { 24209467b48Spatrick if (!Message.str().empty()) 24309467b48Spatrick OS << Message << "\n"; 24409467b48Spatrick OS << " Number of Basic Blocks: " << BBInfos.size() << "\n"; 24509467b48Spatrick for (auto &BI : BBInfos) { 24609467b48Spatrick const BasicBlock *BB = BI.first; 24709467b48Spatrick OS << " BB: " << (BB == nullptr ? "FakeNode" : BB->getName()) << " " 24809467b48Spatrick << BI.second->infoString() << "\n"; 24909467b48Spatrick } 25009467b48Spatrick 25109467b48Spatrick OS << " Number of Edges: " << AllEdges.size() 25209467b48Spatrick << " (*: Instrument, C: CriticalEdge, -: Removed)\n"; 25309467b48Spatrick uint32_t Count = 0; 25409467b48Spatrick for (auto &EI : AllEdges) 25509467b48Spatrick OS << " Edge " << Count++ << ": " << getBBInfo(EI->SrcBB).Index << "-->" 25609467b48Spatrick << getBBInfo(EI->DestBB).Index << EI->infoString() << "\n"; 25709467b48Spatrick } 25809467b48Spatrick 25909467b48Spatrick // Add an edge to AllEdges with weight W. addEdge(const BasicBlock * Src,const BasicBlock * Dest,uint64_t W)26009467b48Spatrick Edge &addEdge(const BasicBlock *Src, const BasicBlock *Dest, uint64_t W) { 26109467b48Spatrick uint32_t Index = BBInfos.size(); 26209467b48Spatrick auto Iter = BBInfos.end(); 26309467b48Spatrick bool Inserted; 26409467b48Spatrick std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Src, nullptr)); 26509467b48Spatrick if (Inserted) { 26609467b48Spatrick // Newly inserted, update the real info. 26709467b48Spatrick Iter->second = std::move(std::make_unique<BBInfo>(Index)); 26809467b48Spatrick Index++; 26909467b48Spatrick } 27009467b48Spatrick std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Dest, nullptr)); 27109467b48Spatrick if (Inserted) 27209467b48Spatrick // Newly inserted, update the real info. 27309467b48Spatrick Iter->second = std::move(std::make_unique<BBInfo>(Index)); 27409467b48Spatrick AllEdges.emplace_back(new Edge(Src, Dest, W)); 27509467b48Spatrick return *AllEdges.back(); 27609467b48Spatrick } 27709467b48Spatrick 27809467b48Spatrick BranchProbabilityInfo *BPI; 27909467b48Spatrick BlockFrequencyInfo *BFI; 28009467b48Spatrick 281*73471bf0Spatrick // If function entry will be always instrumented. 282*73471bf0Spatrick bool InstrumentFuncEntry; 283*73471bf0Spatrick 28409467b48Spatrick public: 285*73471bf0Spatrick CFGMST(Function &Func, bool InstrumentFuncEntry_, 286*73471bf0Spatrick BranchProbabilityInfo *BPI_ = nullptr, 28709467b48Spatrick BlockFrequencyInfo *BFI_ = nullptr) F(Func)288*73471bf0Spatrick : F(Func), BPI(BPI_), BFI(BFI_), 289*73471bf0Spatrick InstrumentFuncEntry(InstrumentFuncEntry_) { 29009467b48Spatrick buildEdges(); 29109467b48Spatrick sortEdgesByWeight(); 29209467b48Spatrick computeMinimumSpanningTree(); 293*73471bf0Spatrick if (AllEdges.size() > 1 && InstrumentFuncEntry) 294097a140dSpatrick std::iter_swap(std::move(AllEdges.begin()), 295097a140dSpatrick std::move(AllEdges.begin() + AllEdges.size() - 1)); 29609467b48Spatrick } 29709467b48Spatrick }; 29809467b48Spatrick 29909467b48Spatrick } // end namespace llvm 30009467b48Spatrick 30109467b48Spatrick #undef DEBUG_TYPE // "cfgmst" 30209467b48Spatrick 30309467b48Spatrick #endif // LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H 304