xref: /llvm-project/bolt/include/bolt/Core/CallGraph.h (revision 2430a354bfb9e8c08e0dd5f294012b40afb75ce0)
1 //===- bolt/Core/CallGraph.h ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef BOLT_PASSES_CALLGRAPH_H
10 #define BOLT_PASSES_CALLGRAPH_H
11 
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cstdint>
#include <functional>
#include <unordered_set>
#include <vector>
18 
19 namespace llvm {
20 namespace bolt {
21 
// TODO: find a better home for this hashing helper.
/// Mix the hash of \p Val into the running hash \p Seed, using the classic
/// boost::hash_combine formula:
///   Seed ^ (hash(Val) + 0x9e3779b9 + (Seed << 6) + (Seed >> 2)).
///
/// All arithmetic is done on uint64_t. Left-shifting a negative signed value
/// (or overflowing the shift/addition) is undefined behavior before C++20,
/// and \p Seed is routinely negative once hashes are chained. On
/// two's-complement targets the resulting bits match the previous all-signed
/// expression wherever that expression was defined.
///
/// \param Seed running hash to combine into.
/// \param Val  value whose std::hash is mixed in.
/// \returns the combined hash, reinterpreted as int64_t.
inline int64_t hashCombine(const int64_t Seed, const int64_t Val) {
  std::hash<int64_t> Hasher;
  const uint64_t USeed = static_cast<uint64_t>(Seed);
  // Keep the arithmetic (sign-extending) right shift of the signed seed, as
  // before; a right shift of a negative value is implementation-defined, not
  // undefined.
  const uint64_t ShiftedRight = static_cast<uint64_t>(Seed >> 2);
  // Sum in size_t first, exactly as the original expression did.
  const uint64_t Mixed = static_cast<uint64_t>(Hasher(Val) + 0x9e3779b9) +
                         (USeed << 6) + ShiftedRight;
  return static_cast<int64_t>(USeed ^ Mixed);
}
27 
28 /// A call graph class.
29 class CallGraph {
30 public:
31   using NodeId = size_t;
32   static constexpr NodeId InvalidId = -1;
33 
34   template <typename T> class iterator_range {
35     T Begin;
36     T End;
37 
38   public:
39     template <typename Container>
iterator_range(Container && c)40     iterator_range(Container &&c) : Begin(c.begin()), End(c.end()) {}
iterator_range(T Begin,T End)41     iterator_range(T Begin, T End)
42         : Begin(std::move(Begin)), End(std::move(End)) {}
43 
begin()44     T begin() const { return Begin; }
end()45     T end() const { return End; }
46   };
47 
48   class Arc {
49   public:
50     struct Hash {
51       int64_t operator()(const Arc &Arc) const;
52     };
53 
Src(S)54     Arc(NodeId S, NodeId D, double W = 0) : Src(S), Dst(D), Weight(W) {}
55     Arc(const Arc &) = delete;
56 
57     friend bool operator==(const Arc &Lhs, const Arc &Rhs) {
58       return Lhs.Src == Rhs.Src && Lhs.Dst == Rhs.Dst;
59     }
60 
src()61     NodeId src() const { return Src; }
dst()62     NodeId dst() const { return Dst; }
weight()63     double weight() const { return Weight; }
avgCallOffset()64     double avgCallOffset() const { return AvgCallOffset; }
normalizedWeight()65     double normalizedWeight() const { return NormalizedWeight; }
66 
67   private:
68     friend class CallGraph;
69     NodeId Src{InvalidId};
70     NodeId Dst{InvalidId};
71     mutable double Weight{0};
72     mutable double NormalizedWeight{0};
73     mutable double AvgCallOffset{0};
74   };
75 
76   using ArcsType = std::unordered_set<Arc, Arc::Hash>;
77   using ArcIterator = ArcsType::iterator;
78   using ArcConstIterator = ArcsType::const_iterator;
79 
80   class Node {
81   public:
82     explicit Node(uint32_t Size, uint64_t Samples = 0)
Size(Size)83         : Size(Size), Samples(Samples) {}
84 
size()85     uint32_t size() const { return Size; }
samples()86     uint64_t samples() const { return Samples; }
87 
successors()88     const std::vector<NodeId> &successors() const { return Succs; }
predecessors()89     const std::vector<NodeId> &predecessors() const { return Preds; }
90 
91   private:
92     friend class CallGraph;
93     uint32_t Size;
94     uint64_t Samples;
95 
96     // preds and succs contain no duplicate elements and self arcs are not
97     // allowed
98     std::vector<NodeId> Preds;
99     std::vector<NodeId> Succs;
100   };
101 
numNodes()102   size_t numNodes() const { return Nodes.size(); }
numArcs()103   size_t numArcs() const { return Arcs.size(); }
getNode(const NodeId Id)104   const Node &getNode(const NodeId Id) const {
105     assert(Id < Nodes.size());
106     return Nodes[Id];
107   }
size(const NodeId Id)108   uint32_t size(const NodeId Id) const {
109     assert(Id < Nodes.size());
110     return Nodes[Id].Size;
111   }
samples(const NodeId Id)112   uint64_t samples(const NodeId Id) const {
113     assert(Id < Nodes.size());
114     return Nodes[Id].Samples;
115   }
successors(const NodeId Id)116   const std::vector<NodeId> &successors(const NodeId Id) const {
117     assert(Id < Nodes.size());
118     return Nodes[Id].Succs;
119   }
predecessors(const NodeId Id)120   const std::vector<NodeId> &predecessors(const NodeId Id) const {
121     assert(Id < Nodes.size());
122     return Nodes[Id].Preds;
123   }
124   NodeId addNode(uint32_t Size, uint64_t Samples = 0);
125   const Arc &incArcWeight(NodeId Src, NodeId Dst, double W = 1.0,
126                           double Offset = 0.0);
findArc(NodeId Src,NodeId Dst)127   ArcIterator findArc(NodeId Src, NodeId Dst) {
128     return Arcs.find(Arc(Src, Dst));
129   }
findArc(NodeId Src,NodeId Dst)130   ArcConstIterator findArc(NodeId Src, NodeId Dst) const {
131     return Arcs.find(Arc(Src, Dst));
132   }
arcs()133   iterator_range<ArcConstIterator> arcs() const {
134     return iterator_range<ArcConstIterator>(Arcs.begin(), Arcs.end());
135   }
nodes()136   iterator_range<std::vector<Node>::const_iterator> nodes() const {
137     return iterator_range<std::vector<Node>::const_iterator>(Nodes.begin(),
138                                                              Nodes.end());
139   }
140 
density()141   double density() const {
142     return double(Arcs.size()) / (Nodes.size() * Nodes.size());
143   }
144 
145   // Initialize NormalizedWeight field for every arc
146   void normalizeArcWeights();
147   // Make sure that the sum of incoming arc weights is at least the number of
148   // samples for every node
149   void adjustArcWeights();
150 
151   template <typename L> void printDot(StringRef FileName, L getLabel) const;
152 
153 private:
setSamples(const NodeId Id,uint64_t Samples)154   void setSamples(const NodeId Id, uint64_t Samples) {
155     assert(Id < Nodes.size());
156     Nodes[Id].Samples = Samples;
157   }
158 
159   std::vector<Node> Nodes;
160   ArcsType Arcs;
161 };
162 
163 template <class L>
printDot(StringRef FileName,L GetLabel)164 void CallGraph::printDot(StringRef FileName, L GetLabel) const {
165   std::error_code EC;
166   raw_fd_ostream OS(FileName, EC, sys::fs::OF_None);
167   if (EC)
168     return;
169 
170   OS << "digraph g {\n";
171   for (NodeId F = 0; F < Nodes.size(); F++) {
172     if (Nodes[F].samples() == 0)
173       continue;
174     OS << "f" << F << " [label=\"" << GetLabel(F)
175        << "\\nsamples=" << Nodes[F].samples() << "\\nsize=" << Nodes[F].size()
176        << "\"];\n";
177   }
178   for (NodeId F = 0; F < Nodes.size(); F++) {
179     if (Nodes[F].samples() == 0)
180       continue;
181     for (NodeId Dst : Nodes[F].successors()) {
182       ArcConstIterator Arc = findArc(F, Dst);
183       OS << "f" << F << " -> f" << Dst
184          << " [label=\"normWgt=" << format("%.3lf", Arc->normalizedWeight())
185          << ",weight=" << format("%.0lf", Arc->weight())
186          << ",callOffset=" << format("%.1lf", Arc->avgCallOffset()) << "\"];\n";
187     }
188   }
189   OS << "}\n";
190 }
191 
192 } // namespace bolt
193 } // namespace llvm
194 
195 #endif
196