1 //===- bolt/Core/CallGraph.h ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef BOLT_PASSES_CALLGRAPH_H
10 #define BOLT_PASSES_CALLGRAPH_H
11
12 #include "llvm/Support/FileSystem.h"
13 #include "llvm/Support/raw_ostream.h"
14 #include <cassert>
15 #include <cstdint>
16 #include <unordered_set>
17 #include <vector>
18
19 namespace llvm {
20 namespace bolt {
21
// TODO: find better place for this
/// Mix the hash of \p Val into \p Seed (hash_combine-style) and return the
/// combined value.
///
/// The additive terms are accumulated as unsigned 64-bit values. The original
/// `Seed << 6` shifted a signed value, which is undefined behaviour for a
/// negative or overflowing Seed before C++20; shifting the uint64_t image of
/// Seed yields the same bit pattern without UB. `Seed >> 2` is deliberately
/// kept as a signed shift so results match the previous implementation on
/// compilers that implement it as an arithmetic shift.
inline int64_t hashCombine(const int64_t Seed, const int64_t Val) {
  std::hash<int64_t> Hasher;
  const uint64_t USeed = static_cast<uint64_t>(Seed);
  const uint64_t Mixed = Hasher(Val) + 0x9e3779b9 + (USeed << 6) +
                         static_cast<uint64_t>(Seed >> 2);
  return static_cast<int64_t>(USeed ^ Mixed);
}
27
28 /// A call graph class.
29 class CallGraph {
30 public:
31 using NodeId = size_t;
32 static constexpr NodeId InvalidId = -1;
33
34 template <typename T> class iterator_range {
35 T Begin;
36 T End;
37
38 public:
39 template <typename Container>
iterator_range(Container && c)40 iterator_range(Container &&c) : Begin(c.begin()), End(c.end()) {}
iterator_range(T Begin,T End)41 iterator_range(T Begin, T End)
42 : Begin(std::move(Begin)), End(std::move(End)) {}
43
begin()44 T begin() const { return Begin; }
end()45 T end() const { return End; }
46 };
47
48 class Arc {
49 public:
50 struct Hash {
51 int64_t operator()(const Arc &Arc) const;
52 };
53
Src(S)54 Arc(NodeId S, NodeId D, double W = 0) : Src(S), Dst(D), Weight(W) {}
55 Arc(const Arc &) = delete;
56
57 friend bool operator==(const Arc &Lhs, const Arc &Rhs) {
58 return Lhs.Src == Rhs.Src && Lhs.Dst == Rhs.Dst;
59 }
60
src()61 NodeId src() const { return Src; }
dst()62 NodeId dst() const { return Dst; }
weight()63 double weight() const { return Weight; }
avgCallOffset()64 double avgCallOffset() const { return AvgCallOffset; }
normalizedWeight()65 double normalizedWeight() const { return NormalizedWeight; }
66
67 private:
68 friend class CallGraph;
69 NodeId Src{InvalidId};
70 NodeId Dst{InvalidId};
71 mutable double Weight{0};
72 mutable double NormalizedWeight{0};
73 mutable double AvgCallOffset{0};
74 };
75
76 using ArcsType = std::unordered_set<Arc, Arc::Hash>;
77 using ArcIterator = ArcsType::iterator;
78 using ArcConstIterator = ArcsType::const_iterator;
79
80 class Node {
81 public:
82 explicit Node(uint32_t Size, uint64_t Samples = 0)
Size(Size)83 : Size(Size), Samples(Samples) {}
84
size()85 uint32_t size() const { return Size; }
samples()86 uint64_t samples() const { return Samples; }
87
successors()88 const std::vector<NodeId> &successors() const { return Succs; }
predecessors()89 const std::vector<NodeId> &predecessors() const { return Preds; }
90
91 private:
92 friend class CallGraph;
93 uint32_t Size;
94 uint64_t Samples;
95
96 // preds and succs contain no duplicate elements and self arcs are not
97 // allowed
98 std::vector<NodeId> Preds;
99 std::vector<NodeId> Succs;
100 };
101
numNodes()102 size_t numNodes() const { return Nodes.size(); }
numArcs()103 size_t numArcs() const { return Arcs.size(); }
getNode(const NodeId Id)104 const Node &getNode(const NodeId Id) const {
105 assert(Id < Nodes.size());
106 return Nodes[Id];
107 }
size(const NodeId Id)108 uint32_t size(const NodeId Id) const {
109 assert(Id < Nodes.size());
110 return Nodes[Id].Size;
111 }
samples(const NodeId Id)112 uint64_t samples(const NodeId Id) const {
113 assert(Id < Nodes.size());
114 return Nodes[Id].Samples;
115 }
successors(const NodeId Id)116 const std::vector<NodeId> &successors(const NodeId Id) const {
117 assert(Id < Nodes.size());
118 return Nodes[Id].Succs;
119 }
predecessors(const NodeId Id)120 const std::vector<NodeId> &predecessors(const NodeId Id) const {
121 assert(Id < Nodes.size());
122 return Nodes[Id].Preds;
123 }
124 NodeId addNode(uint32_t Size, uint64_t Samples = 0);
125 const Arc &incArcWeight(NodeId Src, NodeId Dst, double W = 1.0,
126 double Offset = 0.0);
findArc(NodeId Src,NodeId Dst)127 ArcIterator findArc(NodeId Src, NodeId Dst) {
128 return Arcs.find(Arc(Src, Dst));
129 }
findArc(NodeId Src,NodeId Dst)130 ArcConstIterator findArc(NodeId Src, NodeId Dst) const {
131 return Arcs.find(Arc(Src, Dst));
132 }
arcs()133 iterator_range<ArcConstIterator> arcs() const {
134 return iterator_range<ArcConstIterator>(Arcs.begin(), Arcs.end());
135 }
nodes()136 iterator_range<std::vector<Node>::const_iterator> nodes() const {
137 return iterator_range<std::vector<Node>::const_iterator>(Nodes.begin(),
138 Nodes.end());
139 }
140
density()141 double density() const {
142 return double(Arcs.size()) / (Nodes.size() * Nodes.size());
143 }
144
145 // Initialize NormalizedWeight field for every arc
146 void normalizeArcWeights();
147 // Make sure that the sum of incoming arc weights is at least the number of
148 // samples for every node
149 void adjustArcWeights();
150
151 template <typename L> void printDot(StringRef FileName, L getLabel) const;
152
153 private:
setSamples(const NodeId Id,uint64_t Samples)154 void setSamples(const NodeId Id, uint64_t Samples) {
155 assert(Id < Nodes.size());
156 Nodes[Id].Samples = Samples;
157 }
158
159 std::vector<Node> Nodes;
160 ArcsType Arcs;
161 };
162
163 template <class L>
printDot(StringRef FileName,L GetLabel)164 void CallGraph::printDot(StringRef FileName, L GetLabel) const {
165 std::error_code EC;
166 raw_fd_ostream OS(FileName, EC, sys::fs::OF_None);
167 if (EC)
168 return;
169
170 OS << "digraph g {\n";
171 for (NodeId F = 0; F < Nodes.size(); F++) {
172 if (Nodes[F].samples() == 0)
173 continue;
174 OS << "f" << F << " [label=\"" << GetLabel(F)
175 << "\\nsamples=" << Nodes[F].samples() << "\\nsize=" << Nodes[F].size()
176 << "\"];\n";
177 }
178 for (NodeId F = 0; F < Nodes.size(); F++) {
179 if (Nodes[F].samples() == 0)
180 continue;
181 for (NodeId Dst : Nodes[F].successors()) {
182 ArcConstIterator Arc = findArc(F, Dst);
183 OS << "f" << F << " -> f" << Dst
184 << " [label=\"normWgt=" << format("%.3lf", Arc->normalizedWeight())
185 << ",weight=" << format("%.0lf", Arc->weight())
186 << ",callOffset=" << format("%.1lf", Arc->avgCallOffset()) << "\"];\n";
187 }
188 }
189 OS << "}\n";
190 }
191
192 } // namespace bolt
193 } // namespace llvm
194
195 #endif
196