1 //===- CallPrinter.cpp - DOT printer for call graph -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines '-dot-callgraph', which emits a <module-id>.callgraph.dot
// file (or <prefix>.callgraph.dot when -callgraph-dot-filename-prefix is set)
// containing the call graph of a module.
11 //
12 // There is also a pass available to directly call dotty ('-view-callgraph').
13 //
14 //===----------------------------------------------------------------------===//
15
16 #include "llvm/Analysis/CallPrinter.h"
17 #include "llvm/Analysis/BlockFrequencyInfo.h"
18 #include "llvm/Analysis/BranchProbabilityInfo.h"
19 #include "llvm/Analysis/CallGraph.h"
20 #include "llvm/Analysis/DOTGraphTraitsPass.h"
21 #include "llvm/Analysis/HeatUtils.h"
22 #include "llvm/Support/CommandLine.h"
23 #include "llvm/InitializePasses.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/SmallSet.h"
26
27 using namespace llvm;
28
29 // This option shows static (relative) call counts.
30 // FIXME:
31 // Need to show real counts when profile data is available
32 static cl::opt<bool> ShowHeatColors("callgraph-heat-colors", cl::init(false),
33 cl::Hidden,
34 cl::desc("Show heat colors in call-graph"));
35
36 static cl::opt<bool>
37 ShowEdgeWeight("callgraph-show-weights", cl::init(false), cl::Hidden,
38 cl::desc("Show edges labeled with weights"));
39
40 static cl::opt<bool>
41 CallMultiGraph("callgraph-multigraph", cl::init(false), cl::Hidden,
42 cl::desc("Show call-multigraph (do not remove parallel edges)"));
43
44 static cl::opt<std::string> CallGraphDotFilenamePrefix(
45 "callgraph-dot-filename-prefix", cl::Hidden,
46 cl::desc("The prefix used for the CallGraph dot file names."));
47
48 namespace llvm {
49
50 class CallGraphDOTInfo {
51 private:
52 Module *M;
53 CallGraph *CG;
54 DenseMap<const Function *, uint64_t> Freq;
55 uint64_t MaxFreq;
56
57 public:
58 std::function<BlockFrequencyInfo *(Function &)> LookupBFI;
59
CallGraphDOTInfo(Module * M,CallGraph * CG,function_ref<BlockFrequencyInfo * (Function &)> LookupBFI)60 CallGraphDOTInfo(Module *M, CallGraph *CG,
61 function_ref<BlockFrequencyInfo *(Function &)> LookupBFI)
62 : M(M), CG(CG), LookupBFI(LookupBFI) {
63 MaxFreq = 0;
64
65 for (Function &F : M->getFunctionList()) {
66 uint64_t localSumFreq = 0;
67 SmallSet<Function *, 16> Callers;
68 for (User *U : F.users())
69 if (isa<CallInst>(U))
70 Callers.insert(cast<Instruction>(U)->getFunction());
71 for (Function *Caller : Callers)
72 localSumFreq += getNumOfCalls(*Caller, F);
73 if (localSumFreq >= MaxFreq)
74 MaxFreq = localSumFreq;
75 Freq[&F] = localSumFreq;
76 }
77 if (!CallMultiGraph)
78 removeParallelEdges();
79 }
80
getModule() const81 Module *getModule() const { return M; }
82
getCallGraph() const83 CallGraph *getCallGraph() const { return CG; }
84
getFreq(const Function * F)85 uint64_t getFreq(const Function *F) { return Freq[F]; }
86
getMaxFreq()87 uint64_t getMaxFreq() { return MaxFreq; }
88
89 private:
removeParallelEdges()90 void removeParallelEdges() {
91 for (auto &I : (*CG)) {
92 CallGraphNode *Node = I.second.get();
93
94 bool FoundParallelEdge = true;
95 while (FoundParallelEdge) {
96 SmallSet<Function *, 16> Visited;
97 FoundParallelEdge = false;
98 for (auto CI = Node->begin(), CE = Node->end(); CI != CE; CI++) {
99 if (!(Visited.insert(CI->second->getFunction())).second) {
100 FoundParallelEdge = true;
101 Node->removeCallEdge(CI);
102 break;
103 }
104 }
105 }
106 }
107 }
108 };
109
110 template <>
111 struct GraphTraits<CallGraphDOTInfo *>
112 : public GraphTraits<const CallGraphNode *> {
getEntryNodellvm::GraphTraits113 static NodeRef getEntryNode(CallGraphDOTInfo *CGInfo) {
114 // Start at the external node!
115 return CGInfo->getCallGraph()->getExternalCallingNode();
116 }
117
118 typedef std::pair<const Function *const, std::unique_ptr<CallGraphNode>>
119 PairTy;
CGGetValuePtrllvm::GraphTraits120 static const CallGraphNode *CGGetValuePtr(const PairTy &P) {
121 return P.second.get();
122 }
123
124 // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
125 typedef mapped_iterator<CallGraph::const_iterator, decltype(&CGGetValuePtr)>
126 nodes_iterator;
127
nodes_beginllvm::GraphTraits128 static nodes_iterator nodes_begin(CallGraphDOTInfo *CGInfo) {
129 return nodes_iterator(CGInfo->getCallGraph()->begin(), &CGGetValuePtr);
130 }
nodes_endllvm::GraphTraits131 static nodes_iterator nodes_end(CallGraphDOTInfo *CGInfo) {
132 return nodes_iterator(CGInfo->getCallGraph()->end(), &CGGetValuePtr);
133 }
134 };
135
136 template <>
137 struct DOTGraphTraits<CallGraphDOTInfo *> : public DefaultDOTGraphTraits {
138
DOTGraphTraitsllvm::DOTGraphTraits139 DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
140
getGraphNamellvm::DOTGraphTraits141 static std::string getGraphName(CallGraphDOTInfo *CGInfo) {
142 return "Call graph: " +
143 std::string(CGInfo->getModule()->getModuleIdentifier());
144 }
145
isNodeHiddenllvm::DOTGraphTraits146 static bool isNodeHidden(const CallGraphNode *Node,
147 const CallGraphDOTInfo *CGInfo) {
148 if (CallMultiGraph || Node->getFunction())
149 return false;
150 return true;
151 }
152
getNodeLabelllvm::DOTGraphTraits153 std::string getNodeLabel(const CallGraphNode *Node,
154 CallGraphDOTInfo *CGInfo) {
155 if (Node == CGInfo->getCallGraph()->getExternalCallingNode())
156 return "external caller";
157 if (Node == CGInfo->getCallGraph()->getCallsExternalNode())
158 return "external callee";
159
160 if (Function *Func = Node->getFunction())
161 return std::string(Func->getName());
162 return "external node";
163 }
CGGetValuePtrllvm::DOTGraphTraits164 static const CallGraphNode *CGGetValuePtr(CallGraphNode::CallRecord P) {
165 return P.second;
166 }
167
168 // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
169 typedef mapped_iterator<CallGraphNode::const_iterator,
170 decltype(&CGGetValuePtr)>
171 nodes_iterator;
172
getEdgeAttributesllvm::DOTGraphTraits173 std::string getEdgeAttributes(const CallGraphNode *Node, nodes_iterator I,
174 CallGraphDOTInfo *CGInfo) {
175 if (!ShowEdgeWeight)
176 return "";
177
178 Function *Caller = Node->getFunction();
179 if (Caller == nullptr || Caller->isDeclaration())
180 return "";
181
182 Function *Callee = (*I)->getFunction();
183 if (Callee == nullptr)
184 return "";
185
186 uint64_t Counter = getNumOfCalls(*Caller, *Callee);
187 double Width =
188 1 + 2 * (double(Counter) / CGInfo->getMaxFreq());
189 std::string Attrs = "label=\"" + std::to_string(Counter) +
190 "\" penwidth=" + std::to_string(Width);
191 return Attrs;
192 }
193
getNodeAttributesllvm::DOTGraphTraits194 std::string getNodeAttributes(const CallGraphNode *Node,
195 CallGraphDOTInfo *CGInfo) {
196 Function *F = Node->getFunction();
197 if (F == nullptr)
198 return "";
199 std::string attrs;
200 if (ShowHeatColors) {
201 uint64_t freq = CGInfo->getFreq(F);
202 std::string color = getHeatColor(freq, CGInfo->getMaxFreq());
203 std::string edgeColor = (freq <= (CGInfo->getMaxFreq() / 2))
204 ? getHeatColor(0)
205 : getHeatColor(1);
206 attrs = "color=\"" + edgeColor + "ff\", style=filled, fillcolor=\"" +
207 color + "80\"";
208 }
209 return attrs;
210 }
211 };
212
213 } // end llvm namespace
214
215 namespace {
216 // Viewer
217 class CallGraphViewer : public ModulePass {
218 public:
219 static char ID;
CallGraphViewer()220 CallGraphViewer() : ModulePass(ID) {}
221
222 void getAnalysisUsage(AnalysisUsage &AU) const override;
223 bool runOnModule(Module &M) override;
224 };
225
getAnalysisUsage(AnalysisUsage & AU) const226 void CallGraphViewer::getAnalysisUsage(AnalysisUsage &AU) const {
227 ModulePass::getAnalysisUsage(AU);
228 AU.addRequired<BlockFrequencyInfoWrapperPass>();
229 AU.setPreservesAll();
230 }
231
runOnModule(Module & M)232 bool CallGraphViewer::runOnModule(Module &M) {
233 auto LookupBFI = [this](Function &F) {
234 return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
235 };
236
237 CallGraph CG(M);
238 CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
239
240 std::string Title =
241 DOTGraphTraits<CallGraphDOTInfo *>::getGraphName(&CFGInfo);
242 ViewGraph(&CFGInfo, "callgraph", true, Title);
243
244 return false;
245 }
246
247 // DOT Printer
248
249 class CallGraphDOTPrinter : public ModulePass {
250 public:
251 static char ID;
CallGraphDOTPrinter()252 CallGraphDOTPrinter() : ModulePass(ID) {}
253
254 void getAnalysisUsage(AnalysisUsage &AU) const override;
255 bool runOnModule(Module &M) override;
256 };
257
getAnalysisUsage(AnalysisUsage & AU) const258 void CallGraphDOTPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
259 ModulePass::getAnalysisUsage(AU);
260 AU.addRequired<BlockFrequencyInfoWrapperPass>();
261 AU.setPreservesAll();
262 }
263
runOnModule(Module & M)264 bool CallGraphDOTPrinter::runOnModule(Module &M) {
265 auto LookupBFI = [this](Function &F) {
266 return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
267 };
268
269 std::string Filename;
270 if (!CallGraphDotFilenamePrefix.empty())
271 Filename = (CallGraphDotFilenamePrefix + ".callgraph.dot");
272 else
273 Filename = (std::string(M.getModuleIdentifier()) + ".callgraph.dot");
274 errs() << "Writing '" << Filename << "'...";
275
276 std::error_code EC;
277 raw_fd_ostream File(Filename, EC, sys::fs::OF_Text);
278
279 CallGraph CG(M);
280 CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
281
282 if (!EC)
283 WriteGraph(File, &CFGInfo);
284 else
285 errs() << " error opening file for writing!";
286 errs() << "\n";
287
288 return false;
289 }
290
291 } // end anonymous namespace
292
293 char CallGraphViewer::ID = 0;
294 INITIALIZE_PASS(CallGraphViewer, "view-callgraph", "View call graph", false,
295 false)
296
297 char CallGraphDOTPrinter::ID = 0;
298 INITIALIZE_PASS(CallGraphDOTPrinter, "dot-callgraph",
299 "Print call graph to 'dot' file", false, false)
300
301 // Create methods available outside of this file, to use them
302 // "include/llvm/LinkAllPasses.h". Otherwise the pass would be deleted by
303 // the link time optimization.
304
createCallGraphViewerPass()305 ModulePass *llvm::createCallGraphViewerPass() { return new CallGraphViewer(); }
306
createCallGraphDOTPrinterPass()307 ModulePass *llvm::createCallGraphDOTPrinterPass() {
308 return new CallGraphDOTPrinter();
309 }
310