xref: /llvm-project/mlir/lib/Analysis/CallGraph.cpp (revision d1cad2290c10712ea27509081f50769ed597ee0f)
1 //===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains interfaces and analyses for defining a nested callgraph.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/CallGraph.h"
14 #include "mlir/IR/Operation.h"
15 #include "mlir/IR/SymbolTable.h"
16 #include "mlir/Interfaces/CallInterfaces.h"
17 #include "mlir/Support/LLVM.h"
18 #include "llvm/ADT/SCCIterator.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include <cassert>
23 #include <memory>
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // CallGraphNode
29 //===----------------------------------------------------------------------===//
30 
31 /// Returns true if this node refers to the indirect/external node.
32 bool CallGraphNode::isExternal() const { return !callableRegion; }
33 
34 /// Return the callable region this node represents. This can only be called
35 /// on non-external nodes.
36 Region *CallGraphNode::getCallableRegion() const {
37   assert(!isExternal() && "the external node has no callable region");
38   return callableRegion;
39 }
40 
41 /// Adds an reference edge to the given node. This is only valid on the
42 /// external node.
43 void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
44   assert(isExternal() && "abstract edges are only valid on external nodes");
45   addEdge(node, Edge::Kind::Abstract);
46 }
47 
48 /// Add an outgoing call edge from this node.
49 void CallGraphNode::addCallEdge(CallGraphNode *node) {
50   addEdge(node, Edge::Kind::Call);
51 }
52 
53 /// Adds a reference edge to the given child node.
54 void CallGraphNode::addChildEdge(CallGraphNode *child) {
55   addEdge(child, Edge::Kind::Child);
56 }
57 
58 /// Returns true if this node has any child edges.
59 bool CallGraphNode::hasChildren() const {
60   return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
61 }
62 
63 /// Add an edge to 'node' with the given kind.
64 void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
65   edges.insert({node, kind});
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // CallGraph
70 //===----------------------------------------------------------------------===//
71 
72 /// Recursively compute the callgraph edges for the given operation. Computed
73 /// edges are placed into the given callgraph object.
74 static void computeCallGraph(Operation *op, CallGraph &cg,
75                              SymbolTableCollection &symbolTable,
76                              CallGraphNode *parentNode, bool resolveCalls) {
77   if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
78     // If there is no parent node, we ignore this operation. Even if this
79     // operation was a call, there would be no callgraph node to attribute it
80     // to.
81     if (resolveCalls && parentNode)
82       parentNode->addCallEdge(cg.resolveCallable(call, symbolTable));
83     return;
84   }
85 
86   // Compute the callgraph nodes and edges for each of the nested operations.
87   if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) {
88     if (auto *callableRegion = callable.getCallableRegion())
89       parentNode = cg.getOrAddNode(callableRegion, parentNode);
90     else
91       return;
92   }
93 
94   for (Region &region : op->getRegions())
95     for (Operation &nested : region.getOps())
96       computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls);
97 }
98 
99 CallGraph::CallGraph(Operation *op)
100     : externalCallerNode(/*callableRegion=*/nullptr),
101       unknownCalleeNode(/*callableRegion=*/nullptr) {
102   // Make two passes over the graph, one to compute the callables and one to
103   // resolve the calls. We split these up as we may have nested callable objects
104   // that need to be reserved before the calls.
105   SymbolTableCollection symbolTable;
106   computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
107                    /*resolveCalls=*/false);
108   computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
109                    /*resolveCalls=*/true);
110 }
111 
112 /// Get or add a call graph node for the given region.
113 CallGraphNode *CallGraph::getOrAddNode(Region *region,
114                                        CallGraphNode *parentNode) {
115   assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
116          "expected parent operation to be callable");
117   std::unique_ptr<CallGraphNode> &node = nodes[region];
118   if (!node) {
119     node.reset(new CallGraphNode(region));
120 
121     // Add this node to the given parent node if necessary.
122     if (parentNode) {
123       parentNode->addChildEdge(node.get());
124     } else {
125       // Otherwise, connect all callable nodes to the external node, this allows
126       // for conservatively including all callable nodes within the graph.
127       // FIXME This isn't correct, this is only necessary for callable nodes
128       // that *could* be called from external sources. This requires extending
129       // the interface for callables to check if they may be referenced
130       // externally.
131       externalCallerNode.addAbstractEdge(node.get());
132     }
133   }
134   return node.get();
135 }
136 
137 /// Lookup a call graph node for the given region, or nullptr if none is
138 /// registered.
139 CallGraphNode *CallGraph::lookupNode(Region *region) const {
140   const auto *it = nodes.find(region);
141   return it == nodes.end() ? nullptr : it->second.get();
142 }
143 
144 /// Resolve the callable for given callee to a node in the callgraph, or the
145 /// unknown callee node if a valid node was not resolved.
146 CallGraphNode *
147 CallGraph::resolveCallable(CallOpInterface call,
148                            SymbolTableCollection &symbolTable) const {
149   Operation *callable = call.resolveCallableInTable(&symbolTable);
150   if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
151     if (auto *node = lookupNode(callableOp.getCallableRegion()))
152       return node;
153 
154   return getUnknownCalleeNode();
155 }
156 
157 /// Erase the given node from the callgraph.
158 void CallGraph::eraseNode(CallGraphNode *node) {
159   // Erase any children of this node first.
160   if (node->hasChildren()) {
161     for (const CallGraphNode::Edge &edge : llvm::make_early_inc_range(*node))
162       if (edge.isChild())
163         eraseNode(edge.getTarget());
164   }
165   // Erase any edges to this node from any other nodes.
166   for (auto &it : nodes) {
167     it.second->edges.remove_if([node](const CallGraphNode::Edge &edge) {
168       return edge.getTarget() == node;
169     });
170   }
171   nodes.erase(node->getCallableRegion());
172 }
173 
174 //===----------------------------------------------------------------------===//
175 // Printing
176 
177 /// Dump the graph in a human readable format.
178 void CallGraph::dump() const { print(llvm::errs()); }
179 void CallGraph::print(raw_ostream &os) const {
180   os << "// ---- CallGraph ----\n";
181 
182   // Functor used to output the name for the given node.
183   auto emitNodeName = [&](const CallGraphNode *node) {
184     if (node == getExternalCallerNode()) {
185       os << "<External-Caller-Node>";
186       return;
187     }
188     if (node == getUnknownCalleeNode()) {
189       os << "<Unknown-Callee-Node>";
190       return;
191     }
192 
193     auto *callableRegion = node->getCallableRegion();
194     auto *parentOp = callableRegion->getParentOp();
195     os << "'" << callableRegion->getParentOp()->getName() << "' - Region #"
196        << callableRegion->getRegionNumber();
197     auto attrs = parentOp->getAttrDictionary();
198     if (!attrs.empty())
199       os << " : " << attrs;
200   };
201 
202   for (auto &nodeIt : nodes) {
203     const CallGraphNode *node = nodeIt.second.get();
204 
205     // Dump the header for this node.
206     os << "// - Node : ";
207     emitNodeName(node);
208     os << "\n";
209 
210     // Emit each of the edges.
211     for (auto &edge : *node) {
212       os << "// -- ";
213       if (edge.isCall())
214         os << "Call";
215       else if (edge.isChild())
216         os << "Child";
217 
218       os << "-Edge : ";
219       emitNodeName(edge.getTarget());
220       os << "\n";
221     }
222     os << "//\n";
223   }
224 
225   os << "// -- SCCs --\n";
226 
227   for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) {
228     os << "// - SCC : \n";
229     for (auto &node : scc) {
230       os << "// -- Node :";
231       emitNodeName(node);
232       os << "\n";
233     }
234     os << "\n";
235   }
236 
237   os << "// -------------------\n";
238 }
239