xref: /llvm-project/mlir/lib/Analysis/CallGraph.cpp (revision d1cad2290c10712ea27509081f50769ed597ee0f)
18965011fSRiver Riddle //===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===//
28965011fSRiver Riddle //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68965011fSRiver Riddle //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
88965011fSRiver Riddle //
98965011fSRiver Riddle // This file contains interfaces and analyses for defining a nested callgraph.
108965011fSRiver Riddle //
118965011fSRiver Riddle //===----------------------------------------------------------------------===//
128965011fSRiver Riddle 
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <memory>
248965011fSRiver Riddle 
258965011fSRiver Riddle using namespace mlir;
268965011fSRiver Riddle 
278965011fSRiver Riddle //===----------------------------------------------------------------------===//
288cb405a8SRiver Riddle // CallGraphNode
298cb405a8SRiver Riddle //===----------------------------------------------------------------------===//
308cb405a8SRiver Riddle 
31deb99610SKamlesh Kumar /// Returns true if this node refers to the indirect/external node.
328cb405a8SRiver Riddle bool CallGraphNode::isExternal() const { return !callableRegion; }
338cb405a8SRiver Riddle 
348cb405a8SRiver Riddle /// Return the callable region this node represents. This can only be called
358cb405a8SRiver Riddle /// on non-external nodes.
368cb405a8SRiver Riddle Region *CallGraphNode::getCallableRegion() const {
378cb405a8SRiver Riddle   assert(!isExternal() && "the external node has no callable region");
388cb405a8SRiver Riddle   return callableRegion;
398cb405a8SRiver Riddle }
408cb405a8SRiver Riddle 
418cb405a8SRiver Riddle /// Adds an reference edge to the given node. This is only valid on the
428cb405a8SRiver Riddle /// external node.
438cb405a8SRiver Riddle void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
448cb405a8SRiver Riddle   assert(isExternal() && "abstract edges are only valid on external nodes");
458cb405a8SRiver Riddle   addEdge(node, Edge::Kind::Abstract);
468cb405a8SRiver Riddle }
478cb405a8SRiver Riddle 
488cb405a8SRiver Riddle /// Add an outgoing call edge from this node.
498cb405a8SRiver Riddle void CallGraphNode::addCallEdge(CallGraphNode *node) {
508cb405a8SRiver Riddle   addEdge(node, Edge::Kind::Call);
518cb405a8SRiver Riddle }
528cb405a8SRiver Riddle 
538cb405a8SRiver Riddle /// Adds a reference edge to the given child node.
548cb405a8SRiver Riddle void CallGraphNode::addChildEdge(CallGraphNode *child) {
558cb405a8SRiver Riddle   addEdge(child, Edge::Kind::Child);
568cb405a8SRiver Riddle }
578cb405a8SRiver Riddle 
586b1cc3c6SRiver Riddle /// Returns true if this node has any child edges.
596b1cc3c6SRiver Riddle bool CallGraphNode::hasChildren() const {
606b1cc3c6SRiver Riddle   return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
616b1cc3c6SRiver Riddle }
626b1cc3c6SRiver Riddle 
638cb405a8SRiver Riddle /// Add an edge to 'node' with the given kind.
648cb405a8SRiver Riddle void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
658cb405a8SRiver Riddle   edges.insert({node, kind});
668cb405a8SRiver Riddle }
678cb405a8SRiver Riddle 
688cb405a8SRiver Riddle //===----------------------------------------------------------------------===//
698cb405a8SRiver Riddle // CallGraph
708cb405a8SRiver Riddle //===----------------------------------------------------------------------===//
718cb405a8SRiver Riddle 
726b1cc3c6SRiver Riddle /// Recursively compute the callgraph edges for the given operation. Computed
736b1cc3c6SRiver Riddle /// edges are placed into the given callgraph object.
746b1cc3c6SRiver Riddle static void computeCallGraph(Operation *op, CallGraph &cg,
75a5ea6045SRiver Riddle                              SymbolTableCollection &symbolTable,
76c7748404SRiver Riddle                              CallGraphNode *parentNode, bool resolveCalls) {
77c7748404SRiver Riddle   if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
788cb405a8SRiver Riddle     // If there is no parent node, we ignore this operation. Even if this
79c7748404SRiver Riddle     // operation was a call, there would be no callgraph node to attribute it
80c7748404SRiver Riddle     // to.
815c159b91SRiver Riddle     if (resolveCalls && parentNode)
82a5ea6045SRiver Riddle       parentNode->addCallEdge(cg.resolveCallable(call, symbolTable));
83c7748404SRiver Riddle     return;
84c7748404SRiver Riddle   }
85c7748404SRiver Riddle 
86c7748404SRiver Riddle   // Compute the callgraph nodes and edges for each of the nested operations.
87c7748404SRiver Riddle   if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) {
88c7748404SRiver Riddle     if (auto *callableRegion = callable.getCallableRegion())
89c7748404SRiver Riddle       parentNode = cg.getOrAddNode(callableRegion, parentNode);
90c7748404SRiver Riddle     else
91c7748404SRiver Riddle       return;
92c7748404SRiver Riddle   }
93c7748404SRiver Riddle 
94c7748404SRiver Riddle   for (Region &region : op->getRegions())
951e4faf23SRiver Riddle     for (Operation &nested : region.getOps())
96a5ea6045SRiver Riddle       computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls);
978cb405a8SRiver Riddle }
988cb405a8SRiver Riddle 
9982f86b86SMarkus Böck CallGraph::CallGraph(Operation *op)
10082f86b86SMarkus Böck     : externalCallerNode(/*callableRegion=*/nullptr),
10182f86b86SMarkus Böck       unknownCalleeNode(/*callableRegion=*/nullptr) {
102c7748404SRiver Riddle   // Make two passes over the graph, one to compute the callables and one to
103c7748404SRiver Riddle   // resolve the calls. We split these up as we may have nested callable objects
104c7748404SRiver Riddle   // that need to be reserved before the calls.
105a5ea6045SRiver Riddle   SymbolTableCollection symbolTable;
106a5ea6045SRiver Riddle   computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
107a5ea6045SRiver Riddle                    /*resolveCalls=*/false);
108a5ea6045SRiver Riddle   computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
109a5ea6045SRiver Riddle                    /*resolveCalls=*/true);
1108cb405a8SRiver Riddle }
1118cb405a8SRiver Riddle 
1128cb405a8SRiver Riddle /// Get or add a call graph node for the given region.
1138cb405a8SRiver Riddle CallGraphNode *CallGraph::getOrAddNode(Region *region,
1148cb405a8SRiver Riddle                                        CallGraphNode *parentNode) {
1158cb405a8SRiver Riddle   assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
1168cb405a8SRiver Riddle          "expected parent operation to be callable");
1178cb405a8SRiver Riddle   std::unique_ptr<CallGraphNode> &node = nodes[region];
1188cb405a8SRiver Riddle   if (!node) {
1198cb405a8SRiver Riddle     node.reset(new CallGraphNode(region));
1208cb405a8SRiver Riddle 
1218cb405a8SRiver Riddle     // Add this node to the given parent node if necessary.
122a5ea6045SRiver Riddle     if (parentNode) {
1238cb405a8SRiver Riddle       parentNode->addChildEdge(node.get());
124a5ea6045SRiver Riddle     } else {
1258cb405a8SRiver Riddle       // Otherwise, connect all callable nodes to the external node, this allows
1268cb405a8SRiver Riddle       // for conservatively including all callable nodes within the graph.
127a5ea6045SRiver Riddle       // FIXME This isn't correct, this is only necessary for callable nodes
128a5ea6045SRiver Riddle       // that *could* be called from external sources. This requires extending
129a5ea6045SRiver Riddle       // the interface for callables to check if they may be referenced
130a5ea6045SRiver Riddle       // externally.
13182f86b86SMarkus Böck       externalCallerNode.addAbstractEdge(node.get());
1328cb405a8SRiver Riddle     }
133a5ea6045SRiver Riddle   }
1348cb405a8SRiver Riddle   return node.get();
1358cb405a8SRiver Riddle }
1368cb405a8SRiver Riddle 
1378cb405a8SRiver Riddle /// Lookup a call graph node for the given region, or nullptr if none is
1388cb405a8SRiver Riddle /// registered.
1398cb405a8SRiver Riddle CallGraphNode *CallGraph::lookupNode(Region *region) const {
140a2e57209SMehdi Amini   const auto *it = nodes.find(region);
1418cb405a8SRiver Riddle   return it == nodes.end() ? nullptr : it->second.get();
1428cb405a8SRiver Riddle }
1438cb405a8SRiver Riddle 
1448cb405a8SRiver Riddle /// Resolve the callable for given callee to a node in the callgraph, or the
14582f86b86SMarkus Böck /// unknown callee node if a valid node was not resolved.
146a5ea6045SRiver Riddle CallGraphNode *
147a5ea6045SRiver Riddle CallGraph::resolveCallable(CallOpInterface call,
148a5ea6045SRiver Riddle                            SymbolTableCollection &symbolTable) const {
149*d1cad229SHenrich Lauko   Operation *callable = call.resolveCallableInTable(&symbolTable);
1505c159b91SRiver Riddle   if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
151c7748404SRiver Riddle     if (auto *node = lookupNode(callableOp.getCallableRegion()))
1528cb405a8SRiver Riddle       return node;
1538cb405a8SRiver Riddle 
15482f86b86SMarkus Böck   return getUnknownCalleeNode();
1558cb405a8SRiver Riddle }
1568cb405a8SRiver Riddle 
1574be504a9SRiver Riddle /// Erase the given node from the callgraph.
1584be504a9SRiver Riddle void CallGraph::eraseNode(CallGraphNode *node) {
1594be504a9SRiver Riddle   // Erase any children of this node first.
1604be504a9SRiver Riddle   if (node->hasChildren()) {
1614be504a9SRiver Riddle     for (const CallGraphNode::Edge &edge : llvm::make_early_inc_range(*node))
1624be504a9SRiver Riddle       if (edge.isChild())
1634be504a9SRiver Riddle         eraseNode(edge.getTarget());
1644be504a9SRiver Riddle   }
1654be504a9SRiver Riddle   // Erase any edges to this node from any other nodes.
1664be504a9SRiver Riddle   for (auto &it : nodes) {
1674be504a9SRiver Riddle     it.second->edges.remove_if([node](const CallGraphNode::Edge &edge) {
1684be504a9SRiver Riddle       return edge.getTarget() == node;
1694be504a9SRiver Riddle     });
1704be504a9SRiver Riddle   }
1714be504a9SRiver Riddle   nodes.erase(node->getCallableRegion());
1724be504a9SRiver Riddle }
1734be504a9SRiver Riddle 
//===----------------------------------------------------------------------===//
// Printing
//===----------------------------------------------------------------------===//
1766b1cc3c6SRiver Riddle 
177e5026165SAlexander Belyaev /// Dump the graph in a human readable format.
1788cb405a8SRiver Riddle void CallGraph::dump() const { print(llvm::errs()); }
1798cb405a8SRiver Riddle void CallGraph::print(raw_ostream &os) const {
1808cb405a8SRiver Riddle   os << "// ---- CallGraph ----\n";
1818cb405a8SRiver Riddle 
1828cb405a8SRiver Riddle   // Functor used to output the name for the given node.
1838cb405a8SRiver Riddle   auto emitNodeName = [&](const CallGraphNode *node) {
18482f86b86SMarkus Böck     if (node == getExternalCallerNode()) {
18582f86b86SMarkus Böck       os << "<External-Caller-Node>";
18682f86b86SMarkus Böck       return;
18782f86b86SMarkus Böck     }
18882f86b86SMarkus Böck     if (node == getUnknownCalleeNode()) {
18982f86b86SMarkus Böck       os << "<Unknown-Callee-Node>";
1908cb405a8SRiver Riddle       return;
1918cb405a8SRiver Riddle     }
1928cb405a8SRiver Riddle 
1938cb405a8SRiver Riddle     auto *callableRegion = node->getCallableRegion();
1948cb405a8SRiver Riddle     auto *parentOp = callableRegion->getParentOp();
1958cb405a8SRiver Riddle     os << "'" << callableRegion->getParentOp()->getName() << "' - Region #"
1968cb405a8SRiver Riddle        << callableRegion->getRegionNumber();
1975eae715aSJacques Pienaar     auto attrs = parentOp->getAttrDictionary();
1985eae715aSJacques Pienaar     if (!attrs.empty())
1998cb405a8SRiver Riddle       os << " : " << attrs;
2008cb405a8SRiver Riddle   };
2018cb405a8SRiver Riddle 
2028cb405a8SRiver Riddle   for (auto &nodeIt : nodes) {
2038cb405a8SRiver Riddle     const CallGraphNode *node = nodeIt.second.get();
2048cb405a8SRiver Riddle 
2058cb405a8SRiver Riddle     // Dump the header for this node.
2068cb405a8SRiver Riddle     os << "// - Node : ";
2078cb405a8SRiver Riddle     emitNodeName(node);
2088cb405a8SRiver Riddle     os << "\n";
2098cb405a8SRiver Riddle 
2108cb405a8SRiver Riddle     // Emit each of the edges.
2118cb405a8SRiver Riddle     for (auto &edge : *node) {
2128cb405a8SRiver Riddle       os << "// -- ";
2138cb405a8SRiver Riddle       if (edge.isCall())
2148cb405a8SRiver Riddle         os << "Call";
2158cb405a8SRiver Riddle       else if (edge.isChild())
2168cb405a8SRiver Riddle         os << "Child";
2178cb405a8SRiver Riddle 
2188cb405a8SRiver Riddle       os << "-Edge : ";
2198cb405a8SRiver Riddle       emitNodeName(edge.getTarget());
2208cb405a8SRiver Riddle       os << "\n";
2218cb405a8SRiver Riddle     }
2228cb405a8SRiver Riddle     os << "//\n";
2238cb405a8SRiver Riddle   }
2248cb405a8SRiver Riddle 
2258cb405a8SRiver Riddle   os << "// -- SCCs --\n";
2268cb405a8SRiver Riddle 
2278cb405a8SRiver Riddle   for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) {
2288cb405a8SRiver Riddle     os << "// - SCC : \n";
2298cb405a8SRiver Riddle     for (auto &node : scc) {
2308cb405a8SRiver Riddle       os << "// -- Node :";
2318cb405a8SRiver Riddle       emitNodeName(node);
2328cb405a8SRiver Riddle       os << "\n";
2338cb405a8SRiver Riddle     }
2348cb405a8SRiver Riddle     os << "\n";
2358cb405a8SRiver Riddle   }
2368cb405a8SRiver Riddle 
2378cb405a8SRiver Riddle   os << "// -------------------\n";
2388cb405a8SRiver Riddle }
239