xref: /openbsd-src/gnu/llvm/llvm/lib/Transforms/Instrumentation/CFGMST.h (revision 73471bf04ceb096474c7f0fa83b1b65c70a787a1)
109467b48Spatrick //===-- CFGMST.h - Minimum Spanning Tree for CFG ----------------*- C++ -*-===//
209467b48Spatrick //
309467b48Spatrick // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
409467b48Spatrick // See https://llvm.org/LICENSE.txt for license information.
509467b48Spatrick // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
609467b48Spatrick //
709467b48Spatrick //===----------------------------------------------------------------------===//
809467b48Spatrick //
909467b48Spatrick // This file implements a Union-find algorithm to compute Minimum Spanning Tree
1009467b48Spatrick // for a given CFG.
1109467b48Spatrick //
1209467b48Spatrick //===----------------------------------------------------------------------===//
1309467b48Spatrick 
1409467b48Spatrick #ifndef LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H
1509467b48Spatrick #define LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H
1609467b48Spatrick 
1709467b48Spatrick #include "llvm/ADT/DenseMap.h"
1809467b48Spatrick #include "llvm/ADT/STLExtras.h"
1909467b48Spatrick #include "llvm/Analysis/BlockFrequencyInfo.h"
2009467b48Spatrick #include "llvm/Analysis/BranchProbabilityInfo.h"
2109467b48Spatrick #include "llvm/Analysis/CFG.h"
2209467b48Spatrick #include "llvm/Support/BranchProbability.h"
2309467b48Spatrick #include "llvm/Support/Debug.h"
2409467b48Spatrick #include "llvm/Support/raw_ostream.h"
2509467b48Spatrick #include "llvm/Transforms/Utils/BasicBlockUtils.h"
2609467b48Spatrick #include <utility>
2709467b48Spatrick #include <vector>
2809467b48Spatrick 
2909467b48Spatrick #define DEBUG_TYPE "cfgmst"
3009467b48Spatrick 
31097a140dSpatrick using namespace llvm;
32097a140dSpatrick 
3309467b48Spatrick namespace llvm {
3409467b48Spatrick 
3509467b48Spatrick /// An union-find based Minimum Spanning Tree for CFG
3609467b48Spatrick ///
3709467b48Spatrick /// Implements a Union-find algorithm to compute Minimum Spanning Tree
3809467b48Spatrick /// for a given CFG.
3909467b48Spatrick template <class Edge, class BBInfo> class CFGMST {
4009467b48Spatrick public:
4109467b48Spatrick   Function &F;
4209467b48Spatrick 
4309467b48Spatrick   // Store all the edges in CFG. It may contain some stale edges
4409467b48Spatrick   // when Removed is set.
4509467b48Spatrick   std::vector<std::unique_ptr<Edge>> AllEdges;
4609467b48Spatrick 
4709467b48Spatrick   // This map records the auxiliary information for each BB.
4809467b48Spatrick   DenseMap<const BasicBlock *, std::unique_ptr<BBInfo>> BBInfos;
4909467b48Spatrick 
5009467b48Spatrick   // Whehter the function has an exit block with no successors.
5109467b48Spatrick   // (For function with an infinite loop, this block may be absent)
5209467b48Spatrick   bool ExitBlockFound = false;
5309467b48Spatrick 
5409467b48Spatrick   // Find the root group of the G and compress the path from G to the root.
findAndCompressGroup(BBInfo * G)5509467b48Spatrick   BBInfo *findAndCompressGroup(BBInfo *G) {
5609467b48Spatrick     if (G->Group != G)
5709467b48Spatrick       G->Group = findAndCompressGroup(static_cast<BBInfo *>(G->Group));
5809467b48Spatrick     return static_cast<BBInfo *>(G->Group);
5909467b48Spatrick   }
6009467b48Spatrick 
6109467b48Spatrick   // Union BB1 and BB2 into the same group and return true.
6209467b48Spatrick   // Returns false if BB1 and BB2 are already in the same group.
unionGroups(const BasicBlock * BB1,const BasicBlock * BB2)6309467b48Spatrick   bool unionGroups(const BasicBlock *BB1, const BasicBlock *BB2) {
6409467b48Spatrick     BBInfo *BB1G = findAndCompressGroup(&getBBInfo(BB1));
6509467b48Spatrick     BBInfo *BB2G = findAndCompressGroup(&getBBInfo(BB2));
6609467b48Spatrick 
6709467b48Spatrick     if (BB1G == BB2G)
6809467b48Spatrick       return false;
6909467b48Spatrick 
7009467b48Spatrick     // Make the smaller rank tree a direct child or the root of high rank tree.
7109467b48Spatrick     if (BB1G->Rank < BB2G->Rank)
7209467b48Spatrick       BB1G->Group = BB2G;
7309467b48Spatrick     else {
7409467b48Spatrick       BB2G->Group = BB1G;
7509467b48Spatrick       // If the ranks are the same, increment root of one tree by one.
7609467b48Spatrick       if (BB1G->Rank == BB2G->Rank)
7709467b48Spatrick         BB1G->Rank++;
7809467b48Spatrick     }
7909467b48Spatrick     return true;
8009467b48Spatrick   }
8109467b48Spatrick 
8209467b48Spatrick   // Give BB, return the auxiliary information.
getBBInfo(const BasicBlock * BB)8309467b48Spatrick   BBInfo &getBBInfo(const BasicBlock *BB) const {
8409467b48Spatrick     auto It = BBInfos.find(BB);
8509467b48Spatrick     assert(It->second.get() != nullptr);
8609467b48Spatrick     return *It->second.get();
8709467b48Spatrick   }
8809467b48Spatrick 
8909467b48Spatrick   // Give BB, return the auxiliary information if it's available.
findBBInfo(const BasicBlock * BB)9009467b48Spatrick   BBInfo *findBBInfo(const BasicBlock *BB) const {
9109467b48Spatrick     auto It = BBInfos.find(BB);
9209467b48Spatrick     if (It == BBInfos.end())
9309467b48Spatrick       return nullptr;
9409467b48Spatrick     return It->second.get();
9509467b48Spatrick   }
9609467b48Spatrick 
9709467b48Spatrick   // Traverse the CFG using a stack. Find all the edges and assign the weight.
9809467b48Spatrick   // Edges with large weight will be put into MST first so they are less likely
9909467b48Spatrick   // to be instrumented.
buildEdges()10009467b48Spatrick   void buildEdges() {
10109467b48Spatrick     LLVM_DEBUG(dbgs() << "Build Edge on " << F.getName() << "\n");
10209467b48Spatrick 
10309467b48Spatrick     const BasicBlock *Entry = &(F.getEntryBlock());
10409467b48Spatrick     uint64_t EntryWeight = (BFI != nullptr ? BFI->getEntryFreq() : 2);
105097a140dSpatrick     // If we want to instrument the entry count, lower the weight to 0.
106*73471bf0Spatrick     if (InstrumentFuncEntry)
107097a140dSpatrick       EntryWeight = 0;
10809467b48Spatrick     Edge *EntryIncoming = nullptr, *EntryOutgoing = nullptr,
10909467b48Spatrick          *ExitOutgoing = nullptr, *ExitIncoming = nullptr;
11009467b48Spatrick     uint64_t MaxEntryOutWeight = 0, MaxExitOutWeight = 0, MaxExitInWeight = 0;
11109467b48Spatrick 
11209467b48Spatrick     // Add a fake edge to the entry.
11309467b48Spatrick     EntryIncoming = &addEdge(nullptr, Entry, EntryWeight);
11409467b48Spatrick     LLVM_DEBUG(dbgs() << "  Edge: from fake node to " << Entry->getName()
11509467b48Spatrick                       << " w = " << EntryWeight << "\n");
11609467b48Spatrick 
11709467b48Spatrick     // Special handling for single BB functions.
11809467b48Spatrick     if (succ_empty(Entry)) {
11909467b48Spatrick       addEdge(Entry, nullptr, EntryWeight);
12009467b48Spatrick       return;
12109467b48Spatrick     }
12209467b48Spatrick 
12309467b48Spatrick     static const uint32_t CriticalEdgeMultiplier = 1000;
12409467b48Spatrick 
125*73471bf0Spatrick     for (BasicBlock &BB : F) {
126*73471bf0Spatrick       Instruction *TI = BB.getTerminator();
12709467b48Spatrick       uint64_t BBWeight =
128*73471bf0Spatrick           (BFI != nullptr ? BFI->getBlockFreq(&BB).getFrequency() : 2);
12909467b48Spatrick       uint64_t Weight = 2;
13009467b48Spatrick       if (int successors = TI->getNumSuccessors()) {
13109467b48Spatrick         for (int i = 0; i != successors; ++i) {
13209467b48Spatrick           BasicBlock *TargetBB = TI->getSuccessor(i);
13309467b48Spatrick           bool Critical = isCriticalEdge(TI, i);
13409467b48Spatrick           uint64_t scaleFactor = BBWeight;
13509467b48Spatrick           if (Critical) {
13609467b48Spatrick             if (scaleFactor < UINT64_MAX / CriticalEdgeMultiplier)
13709467b48Spatrick               scaleFactor *= CriticalEdgeMultiplier;
13809467b48Spatrick             else
13909467b48Spatrick               scaleFactor = UINT64_MAX;
14009467b48Spatrick           }
14109467b48Spatrick           if (BPI != nullptr)
142*73471bf0Spatrick             Weight = BPI->getEdgeProbability(&BB, TargetBB).scale(scaleFactor);
143097a140dSpatrick           if (Weight == 0)
144097a140dSpatrick             Weight++;
145*73471bf0Spatrick           auto *E = &addEdge(&BB, TargetBB, Weight);
14609467b48Spatrick           E->IsCritical = Critical;
147*73471bf0Spatrick           LLVM_DEBUG(dbgs() << "  Edge: from " << BB.getName() << " to "
14809467b48Spatrick                             << TargetBB->getName() << "  w=" << Weight << "\n");
14909467b48Spatrick 
15009467b48Spatrick           // Keep track of entry/exit edges:
151*73471bf0Spatrick           if (&BB == Entry) {
15209467b48Spatrick             if (Weight > MaxEntryOutWeight) {
15309467b48Spatrick               MaxEntryOutWeight = Weight;
15409467b48Spatrick               EntryOutgoing = E;
15509467b48Spatrick             }
15609467b48Spatrick           }
15709467b48Spatrick 
15809467b48Spatrick           auto *TargetTI = TargetBB->getTerminator();
15909467b48Spatrick           if (TargetTI && !TargetTI->getNumSuccessors()) {
16009467b48Spatrick             if (Weight > MaxExitInWeight) {
16109467b48Spatrick               MaxExitInWeight = Weight;
16209467b48Spatrick               ExitIncoming = E;
16309467b48Spatrick             }
16409467b48Spatrick           }
16509467b48Spatrick         }
16609467b48Spatrick       } else {
16709467b48Spatrick         ExitBlockFound = true;
168*73471bf0Spatrick         Edge *ExitO = &addEdge(&BB, nullptr, BBWeight);
16909467b48Spatrick         if (BBWeight > MaxExitOutWeight) {
17009467b48Spatrick           MaxExitOutWeight = BBWeight;
17109467b48Spatrick           ExitOutgoing = ExitO;
17209467b48Spatrick         }
173*73471bf0Spatrick         LLVM_DEBUG(dbgs() << "  Edge: from " << BB.getName() << " to fake exit"
17409467b48Spatrick                           << " w = " << BBWeight << "\n");
17509467b48Spatrick       }
17609467b48Spatrick     }
17709467b48Spatrick 
17809467b48Spatrick     // Entry/exit edge adjustment heurisitic:
17909467b48Spatrick     // prefer instrumenting entry edge over exit edge
18009467b48Spatrick     // if possible. Those exit edges may never have a chance to be
18109467b48Spatrick     // executed (for instance the program is an event handling loop)
18209467b48Spatrick     // before the profile is asynchronously dumped.
18309467b48Spatrick     //
18409467b48Spatrick     // If EntryIncoming and ExitOutgoing has similar weight, make sure
18509467b48Spatrick     // ExitOutging is selected as the min-edge. Similarly, if EntryOutgoing
18609467b48Spatrick     // and ExitIncoming has similar weight, make sure ExitIncoming becomes
18709467b48Spatrick     // the min-edge.
18809467b48Spatrick     uint64_t EntryInWeight = EntryWeight;
18909467b48Spatrick 
19009467b48Spatrick     if (EntryInWeight >= MaxExitOutWeight &&
19109467b48Spatrick         EntryInWeight * 2 < MaxExitOutWeight * 3) {
19209467b48Spatrick       EntryIncoming->Weight = MaxExitOutWeight;
19309467b48Spatrick       ExitOutgoing->Weight = EntryInWeight + 1;
19409467b48Spatrick     }
19509467b48Spatrick 
19609467b48Spatrick     if (MaxEntryOutWeight >= MaxExitInWeight &&
19709467b48Spatrick         MaxEntryOutWeight * 2 < MaxExitInWeight * 3) {
19809467b48Spatrick       EntryOutgoing->Weight = MaxExitInWeight;
19909467b48Spatrick       ExitIncoming->Weight = MaxEntryOutWeight + 1;
20009467b48Spatrick     }
20109467b48Spatrick   }
20209467b48Spatrick 
20309467b48Spatrick   // Sort CFG edges based on its weight.
sortEdgesByWeight()20409467b48Spatrick   void sortEdgesByWeight() {
20509467b48Spatrick     llvm::stable_sort(AllEdges, [](const std::unique_ptr<Edge> &Edge1,
20609467b48Spatrick                                    const std::unique_ptr<Edge> &Edge2) {
20709467b48Spatrick       return Edge1->Weight > Edge2->Weight;
20809467b48Spatrick     });
20909467b48Spatrick   }
21009467b48Spatrick 
21109467b48Spatrick   // Traverse all the edges and compute the Minimum Weight Spanning Tree
21209467b48Spatrick   // using union-find algorithm.
computeMinimumSpanningTree()21309467b48Spatrick   void computeMinimumSpanningTree() {
21409467b48Spatrick     // First, put all the critical edge with landing-pad as the Dest to MST.
21509467b48Spatrick     // This works around the insufficient support of critical edges split
21609467b48Spatrick     // when destination BB is a landing pad.
21709467b48Spatrick     for (auto &Ei : AllEdges) {
21809467b48Spatrick       if (Ei->Removed)
21909467b48Spatrick         continue;
22009467b48Spatrick       if (Ei->IsCritical) {
22109467b48Spatrick         if (Ei->DestBB && Ei->DestBB->isLandingPad()) {
22209467b48Spatrick           if (unionGroups(Ei->SrcBB, Ei->DestBB))
22309467b48Spatrick             Ei->InMST = true;
22409467b48Spatrick         }
22509467b48Spatrick       }
22609467b48Spatrick     }
22709467b48Spatrick 
22809467b48Spatrick     for (auto &Ei : AllEdges) {
22909467b48Spatrick       if (Ei->Removed)
23009467b48Spatrick         continue;
23109467b48Spatrick       // If we detect infinite loops, force
23209467b48Spatrick       // instrumenting the entry edge:
23309467b48Spatrick       if (!ExitBlockFound && Ei->SrcBB == nullptr)
23409467b48Spatrick         continue;
23509467b48Spatrick       if (unionGroups(Ei->SrcBB, Ei->DestBB))
23609467b48Spatrick         Ei->InMST = true;
23709467b48Spatrick     }
23809467b48Spatrick   }
23909467b48Spatrick 
24009467b48Spatrick   // Dump the Debug information about the instrumentation.
dumpEdges(raw_ostream & OS,const Twine & Message)24109467b48Spatrick   void dumpEdges(raw_ostream &OS, const Twine &Message) const {
24209467b48Spatrick     if (!Message.str().empty())
24309467b48Spatrick       OS << Message << "\n";
24409467b48Spatrick     OS << "  Number of Basic Blocks: " << BBInfos.size() << "\n";
24509467b48Spatrick     for (auto &BI : BBInfos) {
24609467b48Spatrick       const BasicBlock *BB = BI.first;
24709467b48Spatrick       OS << "  BB: " << (BB == nullptr ? "FakeNode" : BB->getName()) << "  "
24809467b48Spatrick          << BI.second->infoString() << "\n";
24909467b48Spatrick     }
25009467b48Spatrick 
25109467b48Spatrick     OS << "  Number of Edges: " << AllEdges.size()
25209467b48Spatrick        << " (*: Instrument, C: CriticalEdge, -: Removed)\n";
25309467b48Spatrick     uint32_t Count = 0;
25409467b48Spatrick     for (auto &EI : AllEdges)
25509467b48Spatrick       OS << "  Edge " << Count++ << ": " << getBBInfo(EI->SrcBB).Index << "-->"
25609467b48Spatrick          << getBBInfo(EI->DestBB).Index << EI->infoString() << "\n";
25709467b48Spatrick   }
25809467b48Spatrick 
25909467b48Spatrick   // Add an edge to AllEdges with weight W.
addEdge(const BasicBlock * Src,const BasicBlock * Dest,uint64_t W)26009467b48Spatrick   Edge &addEdge(const BasicBlock *Src, const BasicBlock *Dest, uint64_t W) {
26109467b48Spatrick     uint32_t Index = BBInfos.size();
26209467b48Spatrick     auto Iter = BBInfos.end();
26309467b48Spatrick     bool Inserted;
26409467b48Spatrick     std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Src, nullptr));
26509467b48Spatrick     if (Inserted) {
26609467b48Spatrick       // Newly inserted, update the real info.
26709467b48Spatrick       Iter->second = std::move(std::make_unique<BBInfo>(Index));
26809467b48Spatrick       Index++;
26909467b48Spatrick     }
27009467b48Spatrick     std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Dest, nullptr));
27109467b48Spatrick     if (Inserted)
27209467b48Spatrick       // Newly inserted, update the real info.
27309467b48Spatrick       Iter->second = std::move(std::make_unique<BBInfo>(Index));
27409467b48Spatrick     AllEdges.emplace_back(new Edge(Src, Dest, W));
27509467b48Spatrick     return *AllEdges.back();
27609467b48Spatrick   }
27709467b48Spatrick 
27809467b48Spatrick   BranchProbabilityInfo *BPI;
27909467b48Spatrick   BlockFrequencyInfo *BFI;
28009467b48Spatrick 
281*73471bf0Spatrick   // If function entry will be always instrumented.
282*73471bf0Spatrick   bool InstrumentFuncEntry;
283*73471bf0Spatrick 
28409467b48Spatrick public:
285*73471bf0Spatrick   CFGMST(Function &Func, bool InstrumentFuncEntry_,
286*73471bf0Spatrick          BranchProbabilityInfo *BPI_ = nullptr,
28709467b48Spatrick          BlockFrequencyInfo *BFI_ = nullptr)
F(Func)288*73471bf0Spatrick       : F(Func), BPI(BPI_), BFI(BFI_),
289*73471bf0Spatrick         InstrumentFuncEntry(InstrumentFuncEntry_) {
29009467b48Spatrick     buildEdges();
29109467b48Spatrick     sortEdgesByWeight();
29209467b48Spatrick     computeMinimumSpanningTree();
293*73471bf0Spatrick     if (AllEdges.size() > 1 && InstrumentFuncEntry)
294097a140dSpatrick       std::iter_swap(std::move(AllEdges.begin()),
295097a140dSpatrick                      std::move(AllEdges.begin() + AllEdges.size() - 1));
29609467b48Spatrick   }
29709467b48Spatrick };
29809467b48Spatrick 
29909467b48Spatrick } // end namespace llvm
30009467b48Spatrick 
30109467b48Spatrick #undef DEBUG_TYPE // "cfgmst"
30209467b48Spatrick 
30309467b48Spatrick #endif // LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H
304