1 //===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains interfaces and analyses for defining a nested callgraph. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/CallGraph.h" 14 #include "mlir/IR/Operation.h" 15 #include "mlir/IR/SymbolTable.h" 16 #include "mlir/Interfaces/CallInterfaces.h" 17 #include "mlir/Support/LLVM.h" 18 #include "llvm/ADT/SCCIterator.h" 19 #include "llvm/ADT/STLExtras.h" 20 #include "llvm/ADT/iterator_range.h" 21 #include "llvm/Support/raw_ostream.h" 22 #include <cassert> 23 #include <memory> 24 25 using namespace mlir; 26 27 //===----------------------------------------------------------------------===// 28 // CallGraphNode 29 //===----------------------------------------------------------------------===// 30 31 /// Returns true if this node refers to the indirect/external node. 32 bool CallGraphNode::isExternal() const { return !callableRegion; } 33 34 /// Return the callable region this node represents. This can only be called 35 /// on non-external nodes. 36 Region *CallGraphNode::getCallableRegion() const { 37 assert(!isExternal() && "the external node has no callable region"); 38 return callableRegion; 39 } 40 41 /// Adds an reference edge to the given node. This is only valid on the 42 /// external node. 43 void CallGraphNode::addAbstractEdge(CallGraphNode *node) { 44 assert(isExternal() && "abstract edges are only valid on external nodes"); 45 addEdge(node, Edge::Kind::Abstract); 46 } 47 48 /// Add an outgoing call edge from this node. 49 void CallGraphNode::addCallEdge(CallGraphNode *node) { 50 addEdge(node, Edge::Kind::Call); 51 } 52 53 /// Adds a reference edge to the given child node. 54 void CallGraphNode::addChildEdge(CallGraphNode *child) { 55 addEdge(child, Edge::Kind::Child); 56 } 57 58 /// Returns true if this node has any child edges. 59 bool CallGraphNode::hasChildren() const { 60 return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); }); 61 } 62 63 /// Add an edge to 'node' with the given kind. 64 void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) { 65 edges.insert({node, kind}); 66 } 67 68 //===----------------------------------------------------------------------===// 69 // CallGraph 70 //===----------------------------------------------------------------------===// 71 72 /// Recursively compute the callgraph edges for the given operation. Computed 73 /// edges are placed into the given callgraph object. 74 static void computeCallGraph(Operation *op, CallGraph &cg, 75 SymbolTableCollection &symbolTable, 76 CallGraphNode *parentNode, bool resolveCalls) { 77 if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) { 78 // If there is no parent node, we ignore this operation. Even if this 79 // operation was a call, there would be no callgraph node to attribute it 80 // to. 81 if (resolveCalls && parentNode) 82 parentNode->addCallEdge(cg.resolveCallable(call, symbolTable)); 83 return; 84 } 85 86 // Compute the callgraph nodes and edges for each of the nested operations. 87 if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) { 88 if (auto *callableRegion = callable.getCallableRegion()) 89 parentNode = cg.getOrAddNode(callableRegion, parentNode); 90 else 91 return; 92 } 93 94 for (Region ®ion : op->getRegions()) 95 for (Operation &nested : region.getOps()) 96 computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls); 97 } 98 99 CallGraph::CallGraph(Operation *op) 100 : externalCallerNode(/*callableRegion=*/nullptr), 101 unknownCalleeNode(/*callableRegion=*/nullptr) { 102 // Make two passes over the graph, one to compute the callables and one to 103 // resolve the calls. We split these up as we may have nested callable objects 104 // that need to be reserved before the calls. 105 SymbolTableCollection symbolTable; 106 computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr, 107 /*resolveCalls=*/false); 108 computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr, 109 /*resolveCalls=*/true); 110 } 111 112 /// Get or add a call graph node for the given region. 113 CallGraphNode *CallGraph::getOrAddNode(Region *region, 114 CallGraphNode *parentNode) { 115 assert(region && isa<CallableOpInterface>(region->getParentOp()) && 116 "expected parent operation to be callable"); 117 std::unique_ptr<CallGraphNode> &node = nodes[region]; 118 if (!node) { 119 node.reset(new CallGraphNode(region)); 120 121 // Add this node to the given parent node if necessary. 122 if (parentNode) { 123 parentNode->addChildEdge(node.get()); 124 } else { 125 // Otherwise, connect all callable nodes to the external node, this allows 126 // for conservatively including all callable nodes within the graph. 127 // FIXME This isn't correct, this is only necessary for callable nodes 128 // that *could* be called from external sources. This requires extending 129 // the interface for callables to check if they may be referenced 130 // externally. 131 externalCallerNode.addAbstractEdge(node.get()); 132 } 133 } 134 return node.get(); 135 } 136 137 /// Lookup a call graph node for the given region, or nullptr if none is 138 /// registered. 139 CallGraphNode *CallGraph::lookupNode(Region *region) const { 140 const auto *it = nodes.find(region); 141 return it == nodes.end() ? nullptr : it->second.get(); 142 } 143 144 /// Resolve the callable for given callee to a node in the callgraph, or the 145 /// unknown callee node if a valid node was not resolved. 146 CallGraphNode * 147 CallGraph::resolveCallable(CallOpInterface call, 148 SymbolTableCollection &symbolTable) const { 149 Operation *callable = call.resolveCallableInTable(&symbolTable); 150 if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable)) 151 if (auto *node = lookupNode(callableOp.getCallableRegion())) 152 return node; 153 154 return getUnknownCalleeNode(); 155 } 156 157 /// Erase the given node from the callgraph. 158 void CallGraph::eraseNode(CallGraphNode *node) { 159 // Erase any children of this node first. 160 if (node->hasChildren()) { 161 for (const CallGraphNode::Edge &edge : llvm::make_early_inc_range(*node)) 162 if (edge.isChild()) 163 eraseNode(edge.getTarget()); 164 } 165 // Erase any edges to this node from any other nodes. 166 for (auto &it : nodes) { 167 it.second->edges.remove_if([node](const CallGraphNode::Edge &edge) { 168 return edge.getTarget() == node; 169 }); 170 } 171 nodes.erase(node->getCallableRegion()); 172 } 173 174 //===----------------------------------------------------------------------===// 175 // Printing 176 177 /// Dump the graph in a human readable format. 178 void CallGraph::dump() const { print(llvm::errs()); } 179 void CallGraph::print(raw_ostream &os) const { 180 os << "// ---- CallGraph ----\n"; 181 182 // Functor used to output the name for the given node. 183 auto emitNodeName = [&](const CallGraphNode *node) { 184 if (node == getExternalCallerNode()) { 185 os << "<External-Caller-Node>"; 186 return; 187 } 188 if (node == getUnknownCalleeNode()) { 189 os << "<Unknown-Callee-Node>"; 190 return; 191 } 192 193 auto *callableRegion = node->getCallableRegion(); 194 auto *parentOp = callableRegion->getParentOp(); 195 os << "'" << callableRegion->getParentOp()->getName() << "' - Region #" 196 << callableRegion->getRegionNumber(); 197 auto attrs = parentOp->getAttrDictionary(); 198 if (!attrs.empty()) 199 os << " : " << attrs; 200 }; 201 202 for (auto &nodeIt : nodes) { 203 const CallGraphNode *node = nodeIt.second.get(); 204 205 // Dump the header for this node. 206 os << "// - Node : "; 207 emitNodeName(node); 208 os << "\n"; 209 210 // Emit each of the edges. 211 for (auto &edge : *node) { 212 os << "// -- "; 213 if (edge.isCall()) 214 os << "Call"; 215 else if (edge.isChild()) 216 os << "Child"; 217 218 os << "-Edge : "; 219 emitNodeName(edge.getTarget()); 220 os << "\n"; 221 } 222 os << "//\n"; 223 } 224 225 os << "// -- SCCs --\n"; 226 227 for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) { 228 os << "// - SCC : \n"; 229 for (auto &node : scc) { 230 os << "// -- Node :"; 231 emitNodeName(node); 232 os << "\n"; 233 } 234 os << "\n"; 235 } 236 237 os << "// -------------------\n"; 238 } 239