18965011fSRiver Riddle //===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===// 28965011fSRiver Riddle // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 68965011fSRiver Riddle // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 88965011fSRiver Riddle // 98965011fSRiver Riddle // This file contains interfaces and analyses for defining a nested callgraph. 108965011fSRiver Riddle // 118965011fSRiver Riddle //===----------------------------------------------------------------------===// 128965011fSRiver Riddle 138cb405a8SRiver Riddle #include "mlir/Analysis/CallGraph.h" 148cb405a8SRiver Riddle #include "mlir/IR/Operation.h" 158cb405a8SRiver Riddle #include "mlir/IR/SymbolTable.h" 167ce1e7abSRiver Riddle #include "mlir/Interfaces/CallInterfaces.h" 179dd06a12SMehdi Amini #include "mlir/Support/LLVM.h" 188cb405a8SRiver Riddle #include "llvm/ADT/SCCIterator.h" 199dd06a12SMehdi Amini #include "llvm/ADT/STLExtras.h" 209dd06a12SMehdi Amini #include "llvm/ADT/iterator_range.h" 218cb405a8SRiver Riddle #include "llvm/Support/raw_ostream.h" 229dd06a12SMehdi Amini #include <cassert> 239dd06a12SMehdi Amini #include <memory> 248965011fSRiver Riddle 258965011fSRiver Riddle using namespace mlir; 268965011fSRiver Riddle 278965011fSRiver Riddle //===----------------------------------------------------------------------===// 288cb405a8SRiver Riddle // CallGraphNode 298cb405a8SRiver Riddle //===----------------------------------------------------------------------===// 308cb405a8SRiver Riddle 31deb99610SKamlesh Kumar /// Returns true if this node refers to the indirect/external node. 328cb405a8SRiver Riddle bool CallGraphNode::isExternal() const { return !callableRegion; } 338cb405a8SRiver Riddle 348cb405a8SRiver Riddle /// Return the callable region this node represents. This can only be called 358cb405a8SRiver Riddle /// on non-external nodes. 368cb405a8SRiver Riddle Region *CallGraphNode::getCallableRegion() const { 378cb405a8SRiver Riddle assert(!isExternal() && "the external node has no callable region"); 388cb405a8SRiver Riddle return callableRegion; 398cb405a8SRiver Riddle } 408cb405a8SRiver Riddle 418cb405a8SRiver Riddle /// Adds an reference edge to the given node. This is only valid on the 428cb405a8SRiver Riddle /// external node. 438cb405a8SRiver Riddle void CallGraphNode::addAbstractEdge(CallGraphNode *node) { 448cb405a8SRiver Riddle assert(isExternal() && "abstract edges are only valid on external nodes"); 458cb405a8SRiver Riddle addEdge(node, Edge::Kind::Abstract); 468cb405a8SRiver Riddle } 478cb405a8SRiver Riddle 488cb405a8SRiver Riddle /// Add an outgoing call edge from this node. 498cb405a8SRiver Riddle void CallGraphNode::addCallEdge(CallGraphNode *node) { 508cb405a8SRiver Riddle addEdge(node, Edge::Kind::Call); 518cb405a8SRiver Riddle } 528cb405a8SRiver Riddle 538cb405a8SRiver Riddle /// Adds a reference edge to the given child node. 548cb405a8SRiver Riddle void CallGraphNode::addChildEdge(CallGraphNode *child) { 558cb405a8SRiver Riddle addEdge(child, Edge::Kind::Child); 568cb405a8SRiver Riddle } 578cb405a8SRiver Riddle 586b1cc3c6SRiver Riddle /// Returns true if this node has any child edges. 596b1cc3c6SRiver Riddle bool CallGraphNode::hasChildren() const { 606b1cc3c6SRiver Riddle return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); }); 616b1cc3c6SRiver Riddle } 626b1cc3c6SRiver Riddle 638cb405a8SRiver Riddle /// Add an edge to 'node' with the given kind. 648cb405a8SRiver Riddle void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) { 658cb405a8SRiver Riddle edges.insert({node, kind}); 668cb405a8SRiver Riddle } 678cb405a8SRiver Riddle 688cb405a8SRiver Riddle //===----------------------------------------------------------------------===// 698cb405a8SRiver Riddle // CallGraph 708cb405a8SRiver Riddle //===----------------------------------------------------------------------===// 718cb405a8SRiver Riddle 726b1cc3c6SRiver Riddle /// Recursively compute the callgraph edges for the given operation. Computed 736b1cc3c6SRiver Riddle /// edges are placed into the given callgraph object. 746b1cc3c6SRiver Riddle static void computeCallGraph(Operation *op, CallGraph &cg, 75a5ea6045SRiver Riddle SymbolTableCollection &symbolTable, 76c7748404SRiver Riddle CallGraphNode *parentNode, bool resolveCalls) { 77c7748404SRiver Riddle if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) { 788cb405a8SRiver Riddle // If there is no parent node, we ignore this operation. Even if this 79c7748404SRiver Riddle // operation was a call, there would be no callgraph node to attribute it 80c7748404SRiver Riddle // to. 815c159b91SRiver Riddle if (resolveCalls && parentNode) 82a5ea6045SRiver Riddle parentNode->addCallEdge(cg.resolveCallable(call, symbolTable)); 83c7748404SRiver Riddle return; 84c7748404SRiver Riddle } 85c7748404SRiver Riddle 86c7748404SRiver Riddle // Compute the callgraph nodes and edges for each of the nested operations. 87c7748404SRiver Riddle if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) { 88c7748404SRiver Riddle if (auto *callableRegion = callable.getCallableRegion()) 89c7748404SRiver Riddle parentNode = cg.getOrAddNode(callableRegion, parentNode); 90c7748404SRiver Riddle else 91c7748404SRiver Riddle return; 92c7748404SRiver Riddle } 93c7748404SRiver Riddle 94c7748404SRiver Riddle for (Region ®ion : op->getRegions()) 951e4faf23SRiver Riddle for (Operation &nested : region.getOps()) 96a5ea6045SRiver Riddle computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls); 978cb405a8SRiver Riddle } 988cb405a8SRiver Riddle 9982f86b86SMarkus Böck CallGraph::CallGraph(Operation *op) 10082f86b86SMarkus Böck : externalCallerNode(/*callableRegion=*/nullptr), 10182f86b86SMarkus Böck unknownCalleeNode(/*callableRegion=*/nullptr) { 102c7748404SRiver Riddle // Make two passes over the graph, one to compute the callables and one to 103c7748404SRiver Riddle // resolve the calls. We split these up as we may have nested callable objects 104c7748404SRiver Riddle // that need to be reserved before the calls. 105a5ea6045SRiver Riddle SymbolTableCollection symbolTable; 106a5ea6045SRiver Riddle computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr, 107a5ea6045SRiver Riddle /*resolveCalls=*/false); 108a5ea6045SRiver Riddle computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr, 109a5ea6045SRiver Riddle /*resolveCalls=*/true); 1108cb405a8SRiver Riddle } 1118cb405a8SRiver Riddle 1128cb405a8SRiver Riddle /// Get or add a call graph node for the given region. 1138cb405a8SRiver Riddle CallGraphNode *CallGraph::getOrAddNode(Region *region, 1148cb405a8SRiver Riddle CallGraphNode *parentNode) { 1158cb405a8SRiver Riddle assert(region && isa<CallableOpInterface>(region->getParentOp()) && 1168cb405a8SRiver Riddle "expected parent operation to be callable"); 1178cb405a8SRiver Riddle std::unique_ptr<CallGraphNode> &node = nodes[region]; 1188cb405a8SRiver Riddle if (!node) { 1198cb405a8SRiver Riddle node.reset(new CallGraphNode(region)); 1208cb405a8SRiver Riddle 1218cb405a8SRiver Riddle // Add this node to the given parent node if necessary. 122a5ea6045SRiver Riddle if (parentNode) { 1238cb405a8SRiver Riddle parentNode->addChildEdge(node.get()); 124a5ea6045SRiver Riddle } else { 1258cb405a8SRiver Riddle // Otherwise, connect all callable nodes to the external node, this allows 1268cb405a8SRiver Riddle // for conservatively including all callable nodes within the graph. 127a5ea6045SRiver Riddle // FIXME This isn't correct, this is only necessary for callable nodes 128a5ea6045SRiver Riddle // that *could* be called from external sources. This requires extending 129a5ea6045SRiver Riddle // the interface for callables to check if they may be referenced 130a5ea6045SRiver Riddle // externally. 13182f86b86SMarkus Böck externalCallerNode.addAbstractEdge(node.get()); 1328cb405a8SRiver Riddle } 133a5ea6045SRiver Riddle } 1348cb405a8SRiver Riddle return node.get(); 1358cb405a8SRiver Riddle } 1368cb405a8SRiver Riddle 1378cb405a8SRiver Riddle /// Lookup a call graph node for the given region, or nullptr if none is 1388cb405a8SRiver Riddle /// registered. 1398cb405a8SRiver Riddle CallGraphNode *CallGraph::lookupNode(Region *region) const { 140a2e57209SMehdi Amini const auto *it = nodes.find(region); 1418cb405a8SRiver Riddle return it == nodes.end() ? nullptr : it->second.get(); 1428cb405a8SRiver Riddle } 1438cb405a8SRiver Riddle 1448cb405a8SRiver Riddle /// Resolve the callable for given callee to a node in the callgraph, or the 14582f86b86SMarkus Böck /// unknown callee node if a valid node was not resolved. 146a5ea6045SRiver Riddle CallGraphNode * 147a5ea6045SRiver Riddle CallGraph::resolveCallable(CallOpInterface call, 148a5ea6045SRiver Riddle SymbolTableCollection &symbolTable) const { 149*d1cad229SHenrich Lauko Operation *callable = call.resolveCallableInTable(&symbolTable); 1505c159b91SRiver Riddle if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable)) 151c7748404SRiver Riddle if (auto *node = lookupNode(callableOp.getCallableRegion())) 1528cb405a8SRiver Riddle return node; 1538cb405a8SRiver Riddle 15482f86b86SMarkus Böck return getUnknownCalleeNode(); 1558cb405a8SRiver Riddle } 1568cb405a8SRiver Riddle 1574be504a9SRiver Riddle /// Erase the given node from the callgraph. 1584be504a9SRiver Riddle void CallGraph::eraseNode(CallGraphNode *node) { 1594be504a9SRiver Riddle // Erase any children of this node first. 1604be504a9SRiver Riddle if (node->hasChildren()) { 1614be504a9SRiver Riddle for (const CallGraphNode::Edge &edge : llvm::make_early_inc_range(*node)) 1624be504a9SRiver Riddle if (edge.isChild()) 1634be504a9SRiver Riddle eraseNode(edge.getTarget()); 1644be504a9SRiver Riddle } 1654be504a9SRiver Riddle // Erase any edges to this node from any other nodes. 1664be504a9SRiver Riddle for (auto &it : nodes) { 1674be504a9SRiver Riddle it.second->edges.remove_if([node](const CallGraphNode::Edge &edge) { 1684be504a9SRiver Riddle return edge.getTarget() == node; 1694be504a9SRiver Riddle }); 1704be504a9SRiver Riddle } 1714be504a9SRiver Riddle nodes.erase(node->getCallableRegion()); 1724be504a9SRiver Riddle } 1734be504a9SRiver Riddle 1746b1cc3c6SRiver Riddle //===----------------------------------------------------------------------===// 1756b1cc3c6SRiver Riddle // Printing 1766b1cc3c6SRiver Riddle 177e5026165SAlexander Belyaev /// Dump the graph in a human readable format. 1788cb405a8SRiver Riddle void CallGraph::dump() const { print(llvm::errs()); } 1798cb405a8SRiver Riddle void CallGraph::print(raw_ostream &os) const { 1808cb405a8SRiver Riddle os << "// ---- CallGraph ----\n"; 1818cb405a8SRiver Riddle 1828cb405a8SRiver Riddle // Functor used to output the name for the given node. 1838cb405a8SRiver Riddle auto emitNodeName = [&](const CallGraphNode *node) { 18482f86b86SMarkus Böck if (node == getExternalCallerNode()) { 18582f86b86SMarkus Böck os << "<External-Caller-Node>"; 18682f86b86SMarkus Böck return; 18782f86b86SMarkus Böck } 18882f86b86SMarkus Böck if (node == getUnknownCalleeNode()) { 18982f86b86SMarkus Böck os << "<Unknown-Callee-Node>"; 1908cb405a8SRiver Riddle return; 1918cb405a8SRiver Riddle } 1928cb405a8SRiver Riddle 1938cb405a8SRiver Riddle auto *callableRegion = node->getCallableRegion(); 1948cb405a8SRiver Riddle auto *parentOp = callableRegion->getParentOp(); 1958cb405a8SRiver Riddle os << "'" << callableRegion->getParentOp()->getName() << "' - Region #" 1968cb405a8SRiver Riddle << callableRegion->getRegionNumber(); 1975eae715aSJacques Pienaar auto attrs = parentOp->getAttrDictionary(); 1985eae715aSJacques Pienaar if (!attrs.empty()) 1998cb405a8SRiver Riddle os << " : " << attrs; 2008cb405a8SRiver Riddle }; 2018cb405a8SRiver Riddle 2028cb405a8SRiver Riddle for (auto &nodeIt : nodes) { 2038cb405a8SRiver Riddle const CallGraphNode *node = nodeIt.second.get(); 2048cb405a8SRiver Riddle 2058cb405a8SRiver Riddle // Dump the header for this node. 2068cb405a8SRiver Riddle os << "// - Node : "; 2078cb405a8SRiver Riddle emitNodeName(node); 2088cb405a8SRiver Riddle os << "\n"; 2098cb405a8SRiver Riddle 2108cb405a8SRiver Riddle // Emit each of the edges. 2118cb405a8SRiver Riddle for (auto &edge : *node) { 2128cb405a8SRiver Riddle os << "// -- "; 2138cb405a8SRiver Riddle if (edge.isCall()) 2148cb405a8SRiver Riddle os << "Call"; 2158cb405a8SRiver Riddle else if (edge.isChild()) 2168cb405a8SRiver Riddle os << "Child"; 2178cb405a8SRiver Riddle 2188cb405a8SRiver Riddle os << "-Edge : "; 2198cb405a8SRiver Riddle emitNodeName(edge.getTarget()); 2208cb405a8SRiver Riddle os << "\n"; 2218cb405a8SRiver Riddle } 2228cb405a8SRiver Riddle os << "//\n"; 2238cb405a8SRiver Riddle } 2248cb405a8SRiver Riddle 2258cb405a8SRiver Riddle os << "// -- SCCs --\n"; 2268cb405a8SRiver Riddle 2278cb405a8SRiver Riddle for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) { 2288cb405a8SRiver Riddle os << "// - SCC : \n"; 2298cb405a8SRiver Riddle for (auto &node : scc) { 2308cb405a8SRiver Riddle os << "// -- Node :"; 2318cb405a8SRiver Riddle emitNodeName(node); 2328cb405a8SRiver Riddle os << "\n"; 2338cb405a8SRiver Riddle } 2348cb405a8SRiver Riddle os << "\n"; 2358cb405a8SRiver Riddle } 2368cb405a8SRiver Riddle 2378cb405a8SRiver Riddle os << "// -------------------\n"; 2388cb405a8SRiver Riddle } 239