//===- Inliner.cpp ---- SCC-based inliner ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements Inliner that uses a basic inlining
// algorithm that operates bottom up over the Strongly Connected Components
// (SCCs) of the CallGraph. This enables a more incremental propagation of
// inlining decisions from the leaves to the roots of the callgraph.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Inliner.h"
#include "mlir/IR/Threading.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

322542d345SSlava Zakharin using ResolvedCall = Inliner::ResolvedCall; 332542d345SSlava Zakharin 342542d345SSlava Zakharin //===----------------------------------------------------------------------===// 352542d345SSlava Zakharin // Symbol Use Tracking 362542d345SSlava Zakharin //===----------------------------------------------------------------------===// 372542d345SSlava Zakharin 382542d345SSlava Zakharin /// Walk all of the used symbol callgraph nodes referenced with the given op. 392542d345SSlava Zakharin static void walkReferencedSymbolNodes( 402542d345SSlava Zakharin Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable, 412542d345SSlava Zakharin DenseMap<Attribute, CallGraphNode *> &resolvedRefs, 422542d345SSlava Zakharin function_ref<void(CallGraphNode *, Operation *)> callback) { 432542d345SSlava Zakharin auto symbolUses = SymbolTable::getSymbolUses(op); 442542d345SSlava Zakharin assert(symbolUses && "expected uses to be valid"); 452542d345SSlava Zakharin 462542d345SSlava Zakharin Operation *symbolTableOp = op->getParentOp(); 472542d345SSlava Zakharin for (const SymbolTable::SymbolUse &use : *symbolUses) { 482542d345SSlava Zakharin auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr}); 492542d345SSlava Zakharin CallGraphNode *&node = refIt.first->second; 502542d345SSlava Zakharin 512542d345SSlava Zakharin // If this is the first instance of this reference, try to resolve a 522542d345SSlava Zakharin // callgraph node for it. 
532542d345SSlava Zakharin if (refIt.second) { 542542d345SSlava Zakharin auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp, 552542d345SSlava Zakharin use.getSymbolRef()); 562542d345SSlava Zakharin auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp); 572542d345SSlava Zakharin if (!callableOp) 582542d345SSlava Zakharin continue; 592542d345SSlava Zakharin node = cg.lookupNode(callableOp.getCallableRegion()); 602542d345SSlava Zakharin } 612542d345SSlava Zakharin if (node) 622542d345SSlava Zakharin callback(node, use.getUser()); 632542d345SSlava Zakharin } 642542d345SSlava Zakharin } 652542d345SSlava Zakharin 662542d345SSlava Zakharin //===----------------------------------------------------------------------===// 672542d345SSlava Zakharin // CGUseList 682542d345SSlava Zakharin 692542d345SSlava Zakharin namespace { 702542d345SSlava Zakharin /// This struct tracks the uses of callgraph nodes that can be dropped when 712542d345SSlava Zakharin /// use_empty. It directly tracks and manages a use-list for all of the 722542d345SSlava Zakharin /// call-graph nodes. This is necessary because many callgraph nodes are 732542d345SSlava Zakharin /// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use` 742542d345SSlava Zakharin /// class. 752542d345SSlava Zakharin struct CGUseList { 762542d345SSlava Zakharin /// This struct tracks the uses of callgraph nodes within a specific 772542d345SSlava Zakharin /// operation. 782542d345SSlava Zakharin struct CGUser { 792542d345SSlava Zakharin /// Any nodes referenced in the top-level attribute list of this user. We 802542d345SSlava Zakharin /// use a set here because the number of references does not matter. 812542d345SSlava Zakharin DenseSet<CallGraphNode *> topLevelUses; 822542d345SSlava Zakharin 832542d345SSlava Zakharin /// Uses of nodes referenced by nested operations. 
842542d345SSlava Zakharin DenseMap<CallGraphNode *, int> innerUses; 852542d345SSlava Zakharin }; 862542d345SSlava Zakharin 872542d345SSlava Zakharin CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable); 882542d345SSlava Zakharin 892542d345SSlava Zakharin /// Drop uses of nodes referred to by the given call operation that resides 902542d345SSlava Zakharin /// within 'userNode'. 912542d345SSlava Zakharin void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg); 922542d345SSlava Zakharin 932542d345SSlava Zakharin /// Remove the given node from the use list. 942542d345SSlava Zakharin void eraseNode(CallGraphNode *node); 952542d345SSlava Zakharin 962542d345SSlava Zakharin /// Returns true if the given callgraph node has no uses and can be pruned. 972542d345SSlava Zakharin bool isDead(CallGraphNode *node) const; 982542d345SSlava Zakharin 992542d345SSlava Zakharin /// Returns true if the given callgraph node has a single use and can be 1002542d345SSlava Zakharin /// discarded. 1012542d345SSlava Zakharin bool hasOneUseAndDiscardable(CallGraphNode *node) const; 1022542d345SSlava Zakharin 1032542d345SSlava Zakharin /// Recompute the uses held by the given callgraph node. 1042542d345SSlava Zakharin void recomputeUses(CallGraphNode *node, CallGraph &cg); 1052542d345SSlava Zakharin 1062542d345SSlava Zakharin /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy 1072542d345SSlava Zakharin /// of 'lhs' into 'rhs'. 1082542d345SSlava Zakharin void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs); 1092542d345SSlava Zakharin 1102542d345SSlava Zakharin private: 1112542d345SSlava Zakharin /// Decrement the uses of discardable nodes referenced by the given user. 
1122542d345SSlava Zakharin void decrementDiscardableUses(CGUser &uses); 1132542d345SSlava Zakharin 1142542d345SSlava Zakharin /// A mapping between a discardable callgraph node (that is a symbol) and the 1152542d345SSlava Zakharin /// number of uses for this node. 1162542d345SSlava Zakharin DenseMap<CallGraphNode *, int> discardableSymNodeUses; 1172542d345SSlava Zakharin 1182542d345SSlava Zakharin /// A mapping between a callgraph node and the symbol callgraph nodes that it 1192542d345SSlava Zakharin /// uses. 1202542d345SSlava Zakharin DenseMap<CallGraphNode *, CGUser> nodeUses; 1212542d345SSlava Zakharin 1222542d345SSlava Zakharin /// A symbol table to use when resolving call lookups. 1232542d345SSlava Zakharin SymbolTableCollection &symbolTable; 1242542d345SSlava Zakharin }; 1252542d345SSlava Zakharin } // namespace 1262542d345SSlava Zakharin 1272542d345SSlava Zakharin CGUseList::CGUseList(Operation *op, CallGraph &cg, 1282542d345SSlava Zakharin SymbolTableCollection &symbolTable) 1292542d345SSlava Zakharin : symbolTable(symbolTable) { 1302542d345SSlava Zakharin /// A set of callgraph nodes that are always known to be live during inlining. 1312542d345SSlava Zakharin DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes; 1322542d345SSlava Zakharin 1332542d345SSlava Zakharin // Walk each of the symbol tables looking for discardable callgraph nodes. 1342542d345SSlava Zakharin auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) { 1352542d345SSlava Zakharin for (Operation &op : symbolTableOp->getRegion(0).getOps()) { 1362542d345SSlava Zakharin // If this is a callgraph operation, check to see if it is discardable. 
1372542d345SSlava Zakharin if (auto callable = dyn_cast<CallableOpInterface>(&op)) { 1382542d345SSlava Zakharin if (auto *node = cg.lookupNode(callable.getCallableRegion())) { 1392542d345SSlava Zakharin SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op); 1402542d345SSlava Zakharin if (symbol && (allUsesVisible || symbol.isPrivate()) && 1412542d345SSlava Zakharin symbol.canDiscardOnUseEmpty()) { 1422542d345SSlava Zakharin discardableSymNodeUses.try_emplace(node, 0); 1432542d345SSlava Zakharin } 1442542d345SSlava Zakharin continue; 1452542d345SSlava Zakharin } 1462542d345SSlava Zakharin } 1472542d345SSlava Zakharin // Otherwise, check for any referenced nodes. These will be always-live. 1482542d345SSlava Zakharin walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes, 1492542d345SSlava Zakharin [](CallGraphNode *, Operation *) {}); 1502542d345SSlava Zakharin } 1512542d345SSlava Zakharin }; 1522542d345SSlava Zakharin SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), 1532542d345SSlava Zakharin walkFn); 1542542d345SSlava Zakharin 1552542d345SSlava Zakharin // Drop the use information for any discardable nodes that are always live. 1562542d345SSlava Zakharin for (auto &it : alwaysLiveNodes) 1572542d345SSlava Zakharin discardableSymNodeUses.erase(it.second); 1582542d345SSlava Zakharin 1592542d345SSlava Zakharin // Compute the uses for each of the callable nodes in the graph. 
1602542d345SSlava Zakharin for (CallGraphNode *node : cg) 1612542d345SSlava Zakharin recomputeUses(node, cg); 1622542d345SSlava Zakharin } 1632542d345SSlava Zakharin 1642542d345SSlava Zakharin void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp, 1652542d345SSlava Zakharin CallGraph &cg) { 1662542d345SSlava Zakharin auto &userRefs = nodeUses[userNode].innerUses; 1672542d345SSlava Zakharin auto walkFn = [&](CallGraphNode *node, Operation *user) { 1682542d345SSlava Zakharin auto parentIt = userRefs.find(node); 1692542d345SSlava Zakharin if (parentIt == userRefs.end()) 1702542d345SSlava Zakharin return; 1712542d345SSlava Zakharin --parentIt->second; 1722542d345SSlava Zakharin --discardableSymNodeUses[node]; 1732542d345SSlava Zakharin }; 1742542d345SSlava Zakharin DenseMap<Attribute, CallGraphNode *> resolvedRefs; 1752542d345SSlava Zakharin walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn); 1762542d345SSlava Zakharin } 1772542d345SSlava Zakharin 1782542d345SSlava Zakharin void CGUseList::eraseNode(CallGraphNode *node) { 1792542d345SSlava Zakharin // Drop all child nodes. 1802542d345SSlava Zakharin for (auto &edge : *node) 1812542d345SSlava Zakharin if (edge.isChild()) 1822542d345SSlava Zakharin eraseNode(edge.getTarget()); 1832542d345SSlava Zakharin 1842542d345SSlava Zakharin // Drop the uses held by this node and erase it. 1852542d345SSlava Zakharin auto useIt = nodeUses.find(node); 1862542d345SSlava Zakharin assert(useIt != nodeUses.end() && "expected node to be valid"); 1872542d345SSlava Zakharin decrementDiscardableUses(useIt->getSecond()); 1882542d345SSlava Zakharin nodeUses.erase(useIt); 1892542d345SSlava Zakharin discardableSymNodeUses.erase(node); 1902542d345SSlava Zakharin } 1912542d345SSlava Zakharin 1922542d345SSlava Zakharin bool CGUseList::isDead(CallGraphNode *node) const { 1932542d345SSlava Zakharin // If the parent operation isn't a symbol, simply check normal SSA deadness. 
1942542d345SSlava Zakharin Operation *nodeOp = node->getCallableRegion()->getParentOp(); 1952542d345SSlava Zakharin if (!isa<SymbolOpInterface>(nodeOp)) 1962542d345SSlava Zakharin return isMemoryEffectFree(nodeOp) && nodeOp->use_empty(); 1972542d345SSlava Zakharin 1982542d345SSlava Zakharin // Otherwise, check the number of symbol uses. 1992542d345SSlava Zakharin auto symbolIt = discardableSymNodeUses.find(node); 2002542d345SSlava Zakharin return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0; 2012542d345SSlava Zakharin } 2022542d345SSlava Zakharin 2032542d345SSlava Zakharin bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const { 2042542d345SSlava Zakharin // If this isn't a symbol node, check for side-effects and SSA use count. 2052542d345SSlava Zakharin Operation *nodeOp = node->getCallableRegion()->getParentOp(); 2062542d345SSlava Zakharin if (!isa<SymbolOpInterface>(nodeOp)) 2072542d345SSlava Zakharin return isMemoryEffectFree(nodeOp) && nodeOp->hasOneUse(); 2082542d345SSlava Zakharin 2092542d345SSlava Zakharin // Otherwise, check the number of symbol uses. 2102542d345SSlava Zakharin auto symbolIt = discardableSymNodeUses.find(node); 2112542d345SSlava Zakharin return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1; 2122542d345SSlava Zakharin } 2132542d345SSlava Zakharin 2142542d345SSlava Zakharin void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) { 2152542d345SSlava Zakharin Operation *parentOp = node->getCallableRegion()->getParentOp(); 2162542d345SSlava Zakharin CGUser &uses = nodeUses[node]; 2172542d345SSlava Zakharin decrementDiscardableUses(uses); 2182542d345SSlava Zakharin 2192542d345SSlava Zakharin // Collect the new discardable uses within this node. 
2202542d345SSlava Zakharin uses = CGUser(); 2212542d345SSlava Zakharin DenseMap<Attribute, CallGraphNode *> resolvedRefs; 2222542d345SSlava Zakharin auto walkFn = [&](CallGraphNode *refNode, Operation *user) { 2232542d345SSlava Zakharin auto discardSymIt = discardableSymNodeUses.find(refNode); 2242542d345SSlava Zakharin if (discardSymIt == discardableSymNodeUses.end()) 2252542d345SSlava Zakharin return; 2262542d345SSlava Zakharin 2272542d345SSlava Zakharin if (user != parentOp) 2282542d345SSlava Zakharin ++uses.innerUses[refNode]; 2292542d345SSlava Zakharin else if (!uses.topLevelUses.insert(refNode).second) 2302542d345SSlava Zakharin return; 2312542d345SSlava Zakharin ++discardSymIt->second; 2322542d345SSlava Zakharin }; 2332542d345SSlava Zakharin walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn); 2342542d345SSlava Zakharin } 2352542d345SSlava Zakharin 2362542d345SSlava Zakharin void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) { 2372542d345SSlava Zakharin auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs]; 2382542d345SSlava Zakharin for (auto &useIt : lhsUses.innerUses) { 2392542d345SSlava Zakharin rhsUses.innerUses[useIt.first] += useIt.second; 2402542d345SSlava Zakharin discardableSymNodeUses[useIt.first] += useIt.second; 2412542d345SSlava Zakharin } 2422542d345SSlava Zakharin } 2432542d345SSlava Zakharin 2442542d345SSlava Zakharin void CGUseList::decrementDiscardableUses(CGUser &uses) { 2452542d345SSlava Zakharin for (CallGraphNode *node : uses.topLevelUses) 2462542d345SSlava Zakharin --discardableSymNodeUses[node]; 2472542d345SSlava Zakharin for (auto &it : uses.innerUses) 2482542d345SSlava Zakharin discardableSymNodeUses[it.first] -= it.second; 2492542d345SSlava Zakharin } 2502542d345SSlava Zakharin 2512542d345SSlava Zakharin //===----------------------------------------------------------------------===// 2522542d345SSlava Zakharin // CallGraph traversal 2532542d345SSlava Zakharin 
//===----------------------------------------------------------------------===// 2542542d345SSlava Zakharin 2552542d345SSlava Zakharin namespace { 2562542d345SSlava Zakharin /// This class represents a specific callgraph SCC. 2572542d345SSlava Zakharin class CallGraphSCC { 2582542d345SSlava Zakharin public: 2592542d345SSlava Zakharin CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator) 2602542d345SSlava Zakharin : parentIterator(parentIterator) {} 2612542d345SSlava Zakharin /// Return a range over the nodes within this SCC. 2622542d345SSlava Zakharin std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); } 2632542d345SSlava Zakharin std::vector<CallGraphNode *>::iterator end() { return nodes.end(); } 2642542d345SSlava Zakharin 2652542d345SSlava Zakharin /// Reset the nodes of this SCC with those provided. 2662542d345SSlava Zakharin void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; } 2672542d345SSlava Zakharin 2682542d345SSlava Zakharin /// Remove the given node from this SCC. 2692542d345SSlava Zakharin void remove(CallGraphNode *node) { 2702542d345SSlava Zakharin auto it = llvm::find(nodes, node); 2712542d345SSlava Zakharin if (it != nodes.end()) { 2722542d345SSlava Zakharin nodes.erase(it); 2732542d345SSlava Zakharin parentIterator.ReplaceNode(node, nullptr); 2742542d345SSlava Zakharin } 2752542d345SSlava Zakharin } 2762542d345SSlava Zakharin 2772542d345SSlava Zakharin private: 2782542d345SSlava Zakharin std::vector<CallGraphNode *> nodes; 2792542d345SSlava Zakharin llvm::scc_iterator<const CallGraph *> &parentIterator; 2802542d345SSlava Zakharin }; 2812542d345SSlava Zakharin } // namespace 2822542d345SSlava Zakharin 2832542d345SSlava Zakharin /// Run a given transformation over the SCCs of the callgraph in a bottom up 2842542d345SSlava Zakharin /// traversal. 
2852542d345SSlava Zakharin static LogicalResult runTransformOnCGSCCs( 2862542d345SSlava Zakharin const CallGraph &cg, 2872542d345SSlava Zakharin function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) { 2882542d345SSlava Zakharin llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg); 2892542d345SSlava Zakharin CallGraphSCC currentSCC(cgi); 2902542d345SSlava Zakharin while (!cgi.isAtEnd()) { 2912542d345SSlava Zakharin // Copy the current SCC and increment so that the transformer can modify the 2922542d345SSlava Zakharin // SCC without invalidating our iterator. 2932542d345SSlava Zakharin currentSCC.reset(*cgi); 2942542d345SSlava Zakharin ++cgi; 2952542d345SSlava Zakharin if (failed(sccTransformer(currentSCC))) 2962542d345SSlava Zakharin return failure(); 2972542d345SSlava Zakharin } 2982542d345SSlava Zakharin return success(); 2992542d345SSlava Zakharin } 3002542d345SSlava Zakharin 3012542d345SSlava Zakharin /// Collect all of the callable operations within the given range of blocks. If 3022542d345SSlava Zakharin /// `traverseNestedCGNodes` is true, this will also collect call operations 3032542d345SSlava Zakharin /// inside of nested callgraph nodes. 
3042542d345SSlava Zakharin static void collectCallOps(iterator_range<Region::iterator> blocks, 3052542d345SSlava Zakharin CallGraphNode *sourceNode, CallGraph &cg, 3062542d345SSlava Zakharin SymbolTableCollection &symbolTable, 3072542d345SSlava Zakharin SmallVectorImpl<ResolvedCall> &calls, 3082542d345SSlava Zakharin bool traverseNestedCGNodes) { 3092542d345SSlava Zakharin SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist; 3102542d345SSlava Zakharin auto addToWorklist = [&](CallGraphNode *node, 3112542d345SSlava Zakharin iterator_range<Region::iterator> blocks) { 3122542d345SSlava Zakharin for (Block &block : blocks) 3132542d345SSlava Zakharin worklist.emplace_back(&block, node); 3142542d345SSlava Zakharin }; 3152542d345SSlava Zakharin 3162542d345SSlava Zakharin addToWorklist(sourceNode, blocks); 3172542d345SSlava Zakharin while (!worklist.empty()) { 3182542d345SSlava Zakharin Block *block; 3192542d345SSlava Zakharin std::tie(block, sourceNode) = worklist.pop_back_val(); 3202542d345SSlava Zakharin 3212542d345SSlava Zakharin for (Operation &op : *block) { 3222542d345SSlava Zakharin if (auto call = dyn_cast<CallOpInterface>(op)) { 3232542d345SSlava Zakharin // TODO: Support inlining nested call references. 3242542d345SSlava Zakharin CallInterfaceCallable callable = call.getCallableForCallee(); 3252542d345SSlava Zakharin if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) { 3262542d345SSlava Zakharin if (!isa<FlatSymbolRefAttr>(symRef)) 3272542d345SSlava Zakharin continue; 3282542d345SSlava Zakharin } 3292542d345SSlava Zakharin 3302542d345SSlava Zakharin CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable); 3312542d345SSlava Zakharin if (!targetNode->isExternal()) 3322542d345SSlava Zakharin calls.emplace_back(call, sourceNode, targetNode); 3332542d345SSlava Zakharin continue; 3342542d345SSlava Zakharin } 3352542d345SSlava Zakharin 3362542d345SSlava Zakharin // If this is not a call, traverse the nested regions. 
If 3372542d345SSlava Zakharin // `traverseNestedCGNodes` is false, then don't traverse nested call graph 3382542d345SSlava Zakharin // regions. 3392542d345SSlava Zakharin for (auto &nestedRegion : op.getRegions()) { 3402542d345SSlava Zakharin CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion); 3412542d345SSlava Zakharin if (traverseNestedCGNodes || !nestedNode) 3422542d345SSlava Zakharin addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion); 3432542d345SSlava Zakharin } 3442542d345SSlava Zakharin } 3452542d345SSlava Zakharin } 3462542d345SSlava Zakharin } 3472542d345SSlava Zakharin 3482542d345SSlava Zakharin //===----------------------------------------------------------------------===// 3492542d345SSlava Zakharin // InlinerInterfaceImpl 3502542d345SSlava Zakharin //===----------------------------------------------------------------------===// 3512542d345SSlava Zakharin 3522542d345SSlava Zakharin #ifndef NDEBUG 3532542d345SSlava Zakharin static std::string getNodeName(CallOpInterface op) { 3542542d345SSlava Zakharin if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee())) 3552542d345SSlava Zakharin return debugString(op); 3562542d345SSlava Zakharin return "_unnamed_callee_"; 3572542d345SSlava Zakharin } 3582542d345SSlava Zakharin #endif 3592542d345SSlava Zakharin 3602542d345SSlava Zakharin /// Return true if the specified `inlineHistoryID` indicates an inline history 3612542d345SSlava Zakharin /// that already includes `node`. 
3622542d345SSlava Zakharin static bool inlineHistoryIncludes( 3632542d345SSlava Zakharin CallGraphNode *node, std::optional<size_t> inlineHistoryID, 3642542d345SSlava Zakharin MutableArrayRef<std::pair<CallGraphNode *, std::optional<size_t>>> 3652542d345SSlava Zakharin inlineHistory) { 3662542d345SSlava Zakharin while (inlineHistoryID.has_value()) { 3672542d345SSlava Zakharin assert(*inlineHistoryID < inlineHistory.size() && 3682542d345SSlava Zakharin "Invalid inline history ID"); 3692542d345SSlava Zakharin if (inlineHistory[*inlineHistoryID].first == node) 3702542d345SSlava Zakharin return true; 3712542d345SSlava Zakharin inlineHistoryID = inlineHistory[*inlineHistoryID].second; 3722542d345SSlava Zakharin } 3732542d345SSlava Zakharin return false; 3742542d345SSlava Zakharin } 3752542d345SSlava Zakharin 3762542d345SSlava Zakharin namespace { 3772542d345SSlava Zakharin /// This class provides a specialization of the main inlining interface. 3782542d345SSlava Zakharin struct InlinerInterfaceImpl : public InlinerInterface { 3792542d345SSlava Zakharin InlinerInterfaceImpl(MLIRContext *context, CallGraph &cg, 3802542d345SSlava Zakharin SymbolTableCollection &symbolTable) 3812542d345SSlava Zakharin : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {} 3822542d345SSlava Zakharin 3832542d345SSlava Zakharin /// Process a set of blocks that have been inlined. This callback is invoked 3842542d345SSlava Zakharin /// *before* inlined terminator operations have been processed. 3852542d345SSlava Zakharin void 3862542d345SSlava Zakharin processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final { 3872542d345SSlava Zakharin // Find the closest callgraph node from the first block. 
3882542d345SSlava Zakharin CallGraphNode *node; 3892542d345SSlava Zakharin Region *region = inlinedBlocks.begin()->getParent(); 3902542d345SSlava Zakharin while (!(node = cg.lookupNode(region))) { 3912542d345SSlava Zakharin region = region->getParentRegion(); 3922542d345SSlava Zakharin assert(region && "expected valid parent node"); 3932542d345SSlava Zakharin } 3942542d345SSlava Zakharin 3952542d345SSlava Zakharin collectCallOps(inlinedBlocks, node, cg, symbolTable, calls, 3962542d345SSlava Zakharin /*traverseNestedCGNodes=*/true); 3972542d345SSlava Zakharin } 3982542d345SSlava Zakharin 3992542d345SSlava Zakharin /// Mark the given callgraph node for deletion. 4002542d345SSlava Zakharin void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); } 4012542d345SSlava Zakharin 4022542d345SSlava Zakharin /// This method properly disposes of callables that became dead during 4032542d345SSlava Zakharin /// inlining. This should not be called while iterating over the SCCs. 4042542d345SSlava Zakharin void eraseDeadCallables() { 4052542d345SSlava Zakharin for (CallGraphNode *node : deadNodes) 4062542d345SSlava Zakharin node->getCallableRegion()->getParentOp()->erase(); 4072542d345SSlava Zakharin } 4082542d345SSlava Zakharin 4092542d345SSlava Zakharin /// The set of callables known to be dead. 4102542d345SSlava Zakharin SmallPtrSet<CallGraphNode *, 8> deadNodes; 4112542d345SSlava Zakharin 4122542d345SSlava Zakharin /// The current set of call instructions to consider for inlining. 4132542d345SSlava Zakharin SmallVector<ResolvedCall, 8> calls; 4142542d345SSlava Zakharin 4152542d345SSlava Zakharin /// The callgraph being operated on. 4162542d345SSlava Zakharin CallGraph &cg; 4172542d345SSlava Zakharin 4182542d345SSlava Zakharin /// A symbol table to use when resolving call lookups. 
4192542d345SSlava Zakharin SymbolTableCollection &symbolTable; 4202542d345SSlava Zakharin }; 4212542d345SSlava Zakharin } // namespace 4222542d345SSlava Zakharin 4232542d345SSlava Zakharin namespace mlir { 4242542d345SSlava Zakharin 4252542d345SSlava Zakharin class Inliner::Impl { 4262542d345SSlava Zakharin public: 4272542d345SSlava Zakharin Impl(Inliner &inliner) : inliner(inliner) {} 4282542d345SSlava Zakharin 4292542d345SSlava Zakharin /// Attempt to inline calls within the given scc, and run simplifications, 4302542d345SSlava Zakharin /// until a fixed point is reached. This allows for the inlining of newly 4312542d345SSlava Zakharin /// devirtualized calls. Returns failure if there was a fatal error during 4322542d345SSlava Zakharin /// inlining. 4332542d345SSlava Zakharin LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface, 4342542d345SSlava Zakharin CGUseList &useList, CallGraphSCC ¤tSCC, 4352542d345SSlava Zakharin MLIRContext *context); 4362542d345SSlava Zakharin 4372542d345SSlava Zakharin private: 4382542d345SSlava Zakharin /// Optimize the nodes within the given SCC with one of the held optimization 4392542d345SSlava Zakharin /// pass pipelines. Returns failure if an error occurred during the 4402542d345SSlava Zakharin /// optimization of the SCC, success otherwise. 4412542d345SSlava Zakharin LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList, 4422542d345SSlava Zakharin CallGraphSCC ¤tSCC, MLIRContext *context); 4432542d345SSlava Zakharin 4442542d345SSlava Zakharin /// Optimize the nodes within the given SCC in parallel. Returns failure if an 4452542d345SSlava Zakharin /// error occurred during the optimization of the SCC, success otherwise. 
4462542d345SSlava Zakharin LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit, 4472542d345SSlava Zakharin MLIRContext *context); 4482542d345SSlava Zakharin 4492542d345SSlava Zakharin /// Optimize the given callable node with one of the pass managers provided 4502542d345SSlava Zakharin /// with `pipelines`, or the generic pre-inline pipeline. Returns failure if 4512542d345SSlava Zakharin /// an error occurred during the optimization of the callable, success 4522542d345SSlava Zakharin /// otherwise. 4532542d345SSlava Zakharin LogicalResult optimizeCallable(CallGraphNode *node, 4542542d345SSlava Zakharin llvm::StringMap<OpPassManager> &pipelines); 4552542d345SSlava Zakharin 4562542d345SSlava Zakharin /// Attempt to inline calls within the given scc. This function returns 4572542d345SSlava Zakharin /// success if any calls were inlined, failure otherwise. 4582542d345SSlava Zakharin LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, 4592542d345SSlava Zakharin CGUseList &useList, CallGraphSCC ¤tSCC); 4602542d345SSlava Zakharin 4612542d345SSlava Zakharin /// Returns true if the given call should be inlined. 4622542d345SSlava Zakharin bool shouldInline(ResolvedCall &resolvedCall); 4632542d345SSlava Zakharin 4642542d345SSlava Zakharin private: 4652542d345SSlava Zakharin Inliner &inliner; 4662542d345SSlava Zakharin llvm::SmallVector<llvm::StringMap<OpPassManager>> pipelines; 4672542d345SSlava Zakharin }; 4682542d345SSlava Zakharin 4692542d345SSlava Zakharin LogicalResult Inliner::Impl::inlineSCC(InlinerInterfaceImpl &inlinerIface, 4702542d345SSlava Zakharin CGUseList &useList, 4712542d345SSlava Zakharin CallGraphSCC ¤tSCC, 4722542d345SSlava Zakharin MLIRContext *context) { 4732542d345SSlava Zakharin // Continuously simplify and inline until we either reach a fixed point, or 4742542d345SSlava Zakharin // hit the maximum iteration count. 
Simplifying early helps to refine the cost 4752542d345SSlava Zakharin // model, and in future iterations may devirtualize new calls. 4762542d345SSlava Zakharin unsigned iterationCount = 0; 4772542d345SSlava Zakharin do { 4782542d345SSlava Zakharin if (failed(optimizeSCC(inlinerIface.cg, useList, currentSCC, context))) 4792542d345SSlava Zakharin return failure(); 4802542d345SSlava Zakharin if (failed(inlineCallsInSCC(inlinerIface, useList, currentSCC))) 4812542d345SSlava Zakharin break; 4822542d345SSlava Zakharin } while (++iterationCount < inliner.config.getMaxInliningIterations()); 4832542d345SSlava Zakharin return success(); 4842542d345SSlava Zakharin } 4852542d345SSlava Zakharin 4862542d345SSlava Zakharin LogicalResult Inliner::Impl::optimizeSCC(CallGraph &cg, CGUseList &useList, 4872542d345SSlava Zakharin CallGraphSCC ¤tSCC, 4882542d345SSlava Zakharin MLIRContext *context) { 4892542d345SSlava Zakharin // Collect the sets of nodes to simplify. 4902542d345SSlava Zakharin SmallVector<CallGraphNode *, 4> nodesToVisit; 4912542d345SSlava Zakharin for (auto *node : currentSCC) { 4922542d345SSlava Zakharin if (node->isExternal()) 4932542d345SSlava Zakharin continue; 4942542d345SSlava Zakharin 4952542d345SSlava Zakharin // Don't simplify nodes with children. Nodes with children require special 4962542d345SSlava Zakharin // handling as we may remove the node during simplification. In the future, 4972542d345SSlava Zakharin // we should be able to handle this case with proper node deletion tracking. 4982542d345SSlava Zakharin if (node->hasChildren()) 4992542d345SSlava Zakharin continue; 5002542d345SSlava Zakharin 5012542d345SSlava Zakharin // We also won't apply simplifications to nodes that can't have passes 5022542d345SSlava Zakharin // scheduled on them. 
5032542d345SSlava Zakharin auto *region = node->getCallableRegion(); 5042542d345SSlava Zakharin if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()) 5052542d345SSlava Zakharin continue; 5062542d345SSlava Zakharin nodesToVisit.push_back(node); 5072542d345SSlava Zakharin } 5082542d345SSlava Zakharin if (nodesToVisit.empty()) 5092542d345SSlava Zakharin return success(); 5102542d345SSlava Zakharin 5112542d345SSlava Zakharin // Optimize each of the nodes within the SCC in parallel. 5122542d345SSlava Zakharin if (failed(optimizeSCCAsync(nodesToVisit, context))) 5132542d345SSlava Zakharin return failure(); 5142542d345SSlava Zakharin 5152542d345SSlava Zakharin // Recompute the uses held by each of the nodes. 5162542d345SSlava Zakharin for (CallGraphNode *node : nodesToVisit) 5172542d345SSlava Zakharin useList.recomputeUses(node, cg); 5182542d345SSlava Zakharin return success(); 5192542d345SSlava Zakharin } 5202542d345SSlava Zakharin 5212542d345SSlava Zakharin LogicalResult 5222542d345SSlava Zakharin Inliner::Impl::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit, 5232542d345SSlava Zakharin MLIRContext *ctx) { 5242542d345SSlava Zakharin // We must maintain a fixed pool of pass managers which is at least as large 5252542d345SSlava Zakharin // as the maximum parallelism of the failableParallelForEach below. 5262542d345SSlava Zakharin // Note: The number of pass managers here needs to remain constant 5272542d345SSlava Zakharin // to prevent issues with pass instrumentations that rely on having the same 5282542d345SSlava Zakharin // pass manager for the main thread. 
5292542d345SSlava Zakharin size_t numThreads = ctx->getNumThreads(); 5302542d345SSlava Zakharin const auto &opPipelines = inliner.config.getOpPipelines(); 5312542d345SSlava Zakharin if (pipelines.size() < numThreads) { 5322542d345SSlava Zakharin pipelines.reserve(numThreads); 5332542d345SSlava Zakharin pipelines.resize(numThreads, opPipelines); 5342542d345SSlava Zakharin } 5352542d345SSlava Zakharin 5362542d345SSlava Zakharin // Ensure an analysis manager has been constructed for each of the nodes. 5372542d345SSlava Zakharin // This prevents thread races when running the nested pipelines. 5382542d345SSlava Zakharin for (CallGraphNode *node : nodesToVisit) 5392542d345SSlava Zakharin inliner.am.nest(node->getCallableRegion()->getParentOp()); 5402542d345SSlava Zakharin 5412542d345SSlava Zakharin // An atomic failure variable for the async executors. 5422542d345SSlava Zakharin std::vector<std::atomic<bool>> activePMs(pipelines.size()); 5432542d345SSlava Zakharin std::fill(activePMs.begin(), activePMs.end(), false); 5442542d345SSlava Zakharin return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) { 5452542d345SSlava Zakharin // Find a pass manager for this operation. 5462542d345SSlava Zakharin auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) { 5472542d345SSlava Zakharin bool expectedInactive = false; 5482542d345SSlava Zakharin return isActive.compare_exchange_strong(expectedInactive, true); 5492542d345SSlava Zakharin }); 5502542d345SSlava Zakharin assert(it != activePMs.end() && 5512542d345SSlava Zakharin "could not find inactive pass manager for thread"); 5522542d345SSlava Zakharin unsigned pmIndex = it - activePMs.begin(); 5532542d345SSlava Zakharin 5542542d345SSlava Zakharin // Optimize this callable node. 5552542d345SSlava Zakharin LogicalResult result = optimizeCallable(node, pipelines[pmIndex]); 5562542d345SSlava Zakharin 5572542d345SSlava Zakharin // Reset the active bit for this pass manager. 
5582542d345SSlava Zakharin activePMs[pmIndex].store(false); 5592542d345SSlava Zakharin return result; 5602542d345SSlava Zakharin }); 5612542d345SSlava Zakharin } 5622542d345SSlava Zakharin 5632542d345SSlava Zakharin LogicalResult 5642542d345SSlava Zakharin Inliner::Impl::optimizeCallable(CallGraphNode *node, 5652542d345SSlava Zakharin llvm::StringMap<OpPassManager> &pipelines) { 5662542d345SSlava Zakharin Operation *callable = node->getCallableRegion()->getParentOp(); 5672542d345SSlava Zakharin StringRef opName = callable->getName().getStringRef(); 5682542d345SSlava Zakharin auto pipelineIt = pipelines.find(opName); 5692542d345SSlava Zakharin const auto &defaultPipeline = inliner.config.getDefaultPipeline(); 5702542d345SSlava Zakharin if (pipelineIt == pipelines.end()) { 5712542d345SSlava Zakharin // If a pipeline didn't exist, use the generic pipeline if possible. 5722542d345SSlava Zakharin if (!defaultPipeline) 5732542d345SSlava Zakharin return success(); 5742542d345SSlava Zakharin 5752542d345SSlava Zakharin OpPassManager defaultPM(opName); 5762542d345SSlava Zakharin defaultPipeline(defaultPM); 5772542d345SSlava Zakharin pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first; 5782542d345SSlava Zakharin } 5792542d345SSlava Zakharin return inliner.runPipelineHelper(inliner.pass, pipelineIt->second, callable); 5802542d345SSlava Zakharin } 5812542d345SSlava Zakharin 5822542d345SSlava Zakharin /// Attempt to inline calls within the given scc. This function returns 5832542d345SSlava Zakharin /// success if any calls were inlined, failure otherwise. 
5842542d345SSlava Zakharin LogicalResult 5852542d345SSlava Zakharin Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, 5862542d345SSlava Zakharin CGUseList &useList, CallGraphSCC ¤tSCC) { 5872542d345SSlava Zakharin CallGraph &cg = inlinerIface.cg; 5882542d345SSlava Zakharin auto &calls = inlinerIface.calls; 5892542d345SSlava Zakharin 5902542d345SSlava Zakharin // A set of dead nodes to remove after inlining. 5912542d345SSlava Zakharin llvm::SmallSetVector<CallGraphNode *, 1> deadNodes; 5922542d345SSlava Zakharin 5932542d345SSlava Zakharin // Collect all of the direct calls within the nodes of the current SCC. We 5942542d345SSlava Zakharin // don't traverse nested callgraph nodes, because they are handled separately 5952542d345SSlava Zakharin // likely within a different SCC. 5962542d345SSlava Zakharin for (CallGraphNode *node : currentSCC) { 5972542d345SSlava Zakharin if (node->isExternal()) 5982542d345SSlava Zakharin continue; 5992542d345SSlava Zakharin 6002542d345SSlava Zakharin // Don't collect calls if the node is already dead. 6012542d345SSlava Zakharin if (useList.isDead(node)) { 6022542d345SSlava Zakharin deadNodes.insert(node); 6032542d345SSlava Zakharin } else { 6042542d345SSlava Zakharin collectCallOps(*node->getCallableRegion(), node, cg, 6052542d345SSlava Zakharin inlinerIface.symbolTable, calls, 6062542d345SSlava Zakharin /*traverseNestedCGNodes=*/false); 6072542d345SSlava Zakharin } 6082542d345SSlava Zakharin } 6092542d345SSlava Zakharin 6102542d345SSlava Zakharin // When inlining a callee produces new call sites, we want to keep track of 6112542d345SSlava Zakharin // the fact that they were inlined from the callee. This allows us to avoid 6122542d345SSlava Zakharin // infinite inlining. 
6132542d345SSlava Zakharin using InlineHistoryT = std::optional<size_t>; 6142542d345SSlava Zakharin SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory; 6152542d345SSlava Zakharin std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{}); 6162542d345SSlava Zakharin 6172542d345SSlava Zakharin LLVM_DEBUG({ 6182542d345SSlava Zakharin llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n"; 6192542d345SSlava Zakharin for (unsigned i = 0, e = calls.size(); i < e; ++i) 6202542d345SSlava Zakharin llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n"; 6212542d345SSlava Zakharin llvm::dbgs() << "}\n"; 6222542d345SSlava Zakharin }); 6232542d345SSlava Zakharin 6242542d345SSlava Zakharin // Try to inline each of the call operations. Don't cache the end iterator 6252542d345SSlava Zakharin // here as more calls may be added during inlining. 6262542d345SSlava Zakharin bool inlinedAnyCalls = false; 6272542d345SSlava Zakharin for (unsigned i = 0; i < calls.size(); ++i) { 6282542d345SSlava Zakharin if (deadNodes.contains(calls[i].sourceNode)) 6292542d345SSlava Zakharin continue; 6302542d345SSlava Zakharin ResolvedCall it = calls[i]; 6312542d345SSlava Zakharin 6322542d345SSlava Zakharin InlineHistoryT inlineHistoryID = callHistory[i]; 6332542d345SSlava Zakharin bool inHistory = 6342542d345SSlava Zakharin inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory); 6352542d345SSlava Zakharin bool doInline = !inHistory && shouldInline(it); 6362542d345SSlava Zakharin CallOpInterface call = it.call; 6372542d345SSlava Zakharin LLVM_DEBUG({ 6382542d345SSlava Zakharin if (doInline) 6392542d345SSlava Zakharin llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n"; 6402542d345SSlava Zakharin else 6412542d345SSlava Zakharin llvm::dbgs() << "* Not inlining call: " << i << ". 
" << call << "\n"; 6422542d345SSlava Zakharin }); 6432542d345SSlava Zakharin if (!doInline) 6442542d345SSlava Zakharin continue; 6452542d345SSlava Zakharin 6462542d345SSlava Zakharin unsigned prevSize = calls.size(); 6472542d345SSlava Zakharin Region *targetRegion = it.targetNode->getCallableRegion(); 6482542d345SSlava Zakharin 6492542d345SSlava Zakharin // If this is the last call to the target node and the node is discardable, 6502542d345SSlava Zakharin // then inline it in-place and delete the node if successful. 6512542d345SSlava Zakharin bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode); 6522542d345SSlava Zakharin 6532542d345SSlava Zakharin LogicalResult inlineResult = 6542542d345SSlava Zakharin inlineCall(inlinerIface, call, 6552542d345SSlava Zakharin cast<CallableOpInterface>(targetRegion->getParentOp()), 6562542d345SSlava Zakharin targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace); 6572542d345SSlava Zakharin if (failed(inlineResult)) { 6582542d345SSlava Zakharin LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n"); 6592542d345SSlava Zakharin continue; 6602542d345SSlava Zakharin } 6612542d345SSlava Zakharin inlinedAnyCalls = true; 6622542d345SSlava Zakharin 6632542d345SSlava Zakharin // Create a inline history entry for this inlined call, so that we remember 6642542d345SSlava Zakharin // that new callsites came about due to inlining Callee. 6652542d345SSlava Zakharin InlineHistoryT newInlineHistoryID{inlineHistory.size()}; 6662542d345SSlava Zakharin inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID)); 6672542d345SSlava Zakharin 6682542d345SSlava Zakharin auto historyToString = [](InlineHistoryT h) { 6692542d345SSlava Zakharin return h.has_value() ? std::to_string(*h) : "root"; 6702542d345SSlava Zakharin }; 6712542d345SSlava Zakharin (void)historyToString; 6722542d345SSlava Zakharin LLVM_DEBUG(llvm::dbgs() 6732542d345SSlava Zakharin << "* new inlineHistory entry: " << newInlineHistoryID << ". 
[" 6742542d345SSlava Zakharin << getNodeName(call) << ", " << historyToString(inlineHistoryID) 6752542d345SSlava Zakharin << "]\n"); 6762542d345SSlava Zakharin 6772542d345SSlava Zakharin for (unsigned k = prevSize; k != calls.size(); ++k) { 6782542d345SSlava Zakharin callHistory.push_back(newInlineHistoryID); 6792542d345SSlava Zakharin LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call 6802542d345SSlava Zakharin << "}\n with historyID = " << newInlineHistoryID 6812542d345SSlava Zakharin << ", added due to inlining of\n call {" << call 6822542d345SSlava Zakharin << "}\n with historyID = " 6832542d345SSlava Zakharin << historyToString(inlineHistoryID) << "\n"); 6842542d345SSlava Zakharin } 6852542d345SSlava Zakharin 6862542d345SSlava Zakharin // If the inlining was successful, Merge the new uses into the source node. 6872542d345SSlava Zakharin useList.dropCallUses(it.sourceNode, call.getOperation(), cg); 6882542d345SSlava Zakharin useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode); 6892542d345SSlava Zakharin 6902542d345SSlava Zakharin // then erase the call. 6912542d345SSlava Zakharin call.erase(); 6922542d345SSlava Zakharin 6932542d345SSlava Zakharin // If we inlined in place, mark the node for deletion. 6942542d345SSlava Zakharin if (inlineInPlace) { 6952542d345SSlava Zakharin useList.eraseNode(it.targetNode); 6962542d345SSlava Zakharin deadNodes.insert(it.targetNode); 6972542d345SSlava Zakharin } 6982542d345SSlava Zakharin } 6992542d345SSlava Zakharin 7002542d345SSlava Zakharin for (CallGraphNode *node : deadNodes) { 7012542d345SSlava Zakharin currentSCC.remove(node); 7022542d345SSlava Zakharin inlinerIface.markForDeletion(node); 7032542d345SSlava Zakharin } 7042542d345SSlava Zakharin calls.clear(); 7052542d345SSlava Zakharin return success(inlinedAnyCalls); 7062542d345SSlava Zakharin } 7072542d345SSlava Zakharin 7082542d345SSlava Zakharin /// Returns true if the given call should be inlined. 
7092542d345SSlava Zakharin bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) { 7102542d345SSlava Zakharin // Don't allow inlining terminator calls. We currently don't support this 7112542d345SSlava Zakharin // case. 7122542d345SSlava Zakharin if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>()) 7132542d345SSlava Zakharin return false; 7142542d345SSlava Zakharin 715ad231272SCongcong Cai // Don't allow inlining if the target is a self-recursive function. 716*83df39c6Sjunfengd-nv // Don't allow inlining if the call graph is like A->B->A. 717ad231272SCongcong Cai if (llvm::count_if(*resolvedCall.targetNode, 718ad231272SCongcong Cai [&](CallGraphNode::Edge const &edge) -> bool { 719*83df39c6Sjunfengd-nv return edge.getTarget() == resolvedCall.targetNode || 720*83df39c6Sjunfengd-nv edge.getTarget() == resolvedCall.sourceNode; 721ad231272SCongcong Cai }) > 0) 722ad231272SCongcong Cai return false; 723ad231272SCongcong Cai 7242542d345SSlava Zakharin // Don't allow inlining if the target is an ancestor of the call. This 7252542d345SSlava Zakharin // prevents inlining recursively. 7262542d345SSlava Zakharin Region *callableRegion = resolvedCall.targetNode->getCallableRegion(); 7272542d345SSlava Zakharin if (callableRegion->isAncestor(resolvedCall.call->getParentRegion())) 7282542d345SSlava Zakharin return false; 7292542d345SSlava Zakharin 7302542d345SSlava Zakharin // Don't allow inlining if the callee has multiple blocks (unstructured 7312542d345SSlava Zakharin // control flow) but we cannot be sure that the caller region supports that. 7322542d345SSlava Zakharin bool calleeHasMultipleBlocks = 7332542d345SSlava Zakharin llvm::hasNItemsOrMore(*callableRegion, /*N=*/2); 7342542d345SSlava Zakharin // If both parent ops have the same type, it is safe to inline. Otherwise, 7352542d345SSlava Zakharin // decide based on whether the op has the SingleBlock trait or not. 
7362542d345SSlava Zakharin // Note: This check does currently not account for SizedRegion/MaxSizedRegion. 7372542d345SSlava Zakharin auto callerRegionSupportsMultipleBlocks = [&]() { 7382542d345SSlava Zakharin return callableRegion->getParentOp()->getName() == 7392542d345SSlava Zakharin resolvedCall.call->getParentOp()->getName() || 7402542d345SSlava Zakharin !resolvedCall.call->getParentOp() 7412542d345SSlava Zakharin ->mightHaveTrait<OpTrait::SingleBlock>(); 7422542d345SSlava Zakharin }; 7432542d345SSlava Zakharin if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks()) 7442542d345SSlava Zakharin return false; 7452542d345SSlava Zakharin 746732f5368SSlava Zakharin if (!inliner.isProfitableToInline(resolvedCall)) 747732f5368SSlava Zakharin return false; 748732f5368SSlava Zakharin 7492542d345SSlava Zakharin // Otherwise, inline. 7502542d345SSlava Zakharin return true; 7512542d345SSlava Zakharin } 7522542d345SSlava Zakharin 7532542d345SSlava Zakharin LogicalResult Inliner::doInlining() { 7542542d345SSlava Zakharin Impl impl(*this); 7552542d345SSlava Zakharin auto *context = op->getContext(); 7562542d345SSlava Zakharin // Run the inline transform in post-order over the SCCs in the callgraph. 7572542d345SSlava Zakharin SymbolTableCollection symbolTable; 7582542d345SSlava Zakharin // FIXME: some clean-up can be done for the arguments 7592542d345SSlava Zakharin // of the Impl's methods, if the inlinerIface and useList 7602542d345SSlava Zakharin // become the states of the Impl. 
7612542d345SSlava Zakharin InlinerInterfaceImpl inlinerIface(context, cg, symbolTable); 7622542d345SSlava Zakharin CGUseList useList(op, cg, symbolTable); 7632542d345SSlava Zakharin LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) { 7642542d345SSlava Zakharin return impl.inlineSCC(inlinerIface, useList, scc, context); 7652542d345SSlava Zakharin }); 7662542d345SSlava Zakharin if (failed(result)) 7672542d345SSlava Zakharin return result; 7682542d345SSlava Zakharin 7692542d345SSlava Zakharin // After inlining, make sure to erase any callables proven to be dead. 7702542d345SSlava Zakharin inlinerIface.eraseDeadCallables(); 7712542d345SSlava Zakharin return success(); 7722542d345SSlava Zakharin } 7732542d345SSlava Zakharin } // namespace mlir 774