xref: /llvm-project/mlir/lib/Transforms/Utils/Inliner.cpp (revision 83df39c649fe1b1dd556d8f2160999c65ce497eb)
12542d345SSlava Zakharin //===- Inliner.cpp ---- SCC-based inliner ---------------------------------===//
22542d345SSlava Zakharin //
32542d345SSlava Zakharin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42542d345SSlava Zakharin // See https://llvm.org/LICENSE.txt for license information.
52542d345SSlava Zakharin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62542d345SSlava Zakharin //
72542d345SSlava Zakharin //===----------------------------------------------------------------------===//
82542d345SSlava Zakharin //
// This file implements Inliner that uses a basic inlining
// algorithm that operates bottom up over the Strongly Connected Components
// (SCCs) of the CallGraph. This enables a more incremental propagation of
// inlining decisions from the leaves to the roots of the callgraph.
132542d345SSlava Zakharin //
142542d345SSlava Zakharin //===----------------------------------------------------------------------===//
152542d345SSlava Zakharin 
162542d345SSlava Zakharin #include "mlir/Transforms/Inliner.h"
172542d345SSlava Zakharin #include "mlir/IR/Threading.h"
182542d345SSlava Zakharin #include "mlir/Interfaces/CallInterfaces.h"
192542d345SSlava Zakharin #include "mlir/Interfaces/SideEffectInterfaces.h"
202542d345SSlava Zakharin #include "mlir/Pass/Pass.h"
212542d345SSlava Zakharin #include "mlir/Support/DebugStringHelper.h"
222542d345SSlava Zakharin #include "mlir/Transforms/InliningUtils.h"
232542d345SSlava Zakharin #include "llvm/ADT/SCCIterator.h"
24ad231272SCongcong Cai #include "llvm/ADT/STLExtras.h"
252542d345SSlava Zakharin #include "llvm/ADT/SmallPtrSet.h"
262542d345SSlava Zakharin #include "llvm/Support/Debug.h"
272542d345SSlava Zakharin 
282542d345SSlava Zakharin #define DEBUG_TYPE "inlining"
292542d345SSlava Zakharin 
302542d345SSlava Zakharin using namespace mlir;
312542d345SSlava Zakharin 
322542d345SSlava Zakharin using ResolvedCall = Inliner::ResolvedCall;
332542d345SSlava Zakharin 
342542d345SSlava Zakharin //===----------------------------------------------------------------------===//
352542d345SSlava Zakharin // Symbol Use Tracking
362542d345SSlava Zakharin //===----------------------------------------------------------------------===//
372542d345SSlava Zakharin 
382542d345SSlava Zakharin /// Walk all of the used symbol callgraph nodes referenced with the given op.
392542d345SSlava Zakharin static void walkReferencedSymbolNodes(
402542d345SSlava Zakharin     Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
412542d345SSlava Zakharin     DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
422542d345SSlava Zakharin     function_ref<void(CallGraphNode *, Operation *)> callback) {
432542d345SSlava Zakharin   auto symbolUses = SymbolTable::getSymbolUses(op);
442542d345SSlava Zakharin   assert(symbolUses && "expected uses to be valid");
452542d345SSlava Zakharin 
462542d345SSlava Zakharin   Operation *symbolTableOp = op->getParentOp();
472542d345SSlava Zakharin   for (const SymbolTable::SymbolUse &use : *symbolUses) {
482542d345SSlava Zakharin     auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
492542d345SSlava Zakharin     CallGraphNode *&node = refIt.first->second;
502542d345SSlava Zakharin 
512542d345SSlava Zakharin     // If this is the first instance of this reference, try to resolve a
522542d345SSlava Zakharin     // callgraph node for it.
532542d345SSlava Zakharin     if (refIt.second) {
542542d345SSlava Zakharin       auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
552542d345SSlava Zakharin                                                            use.getSymbolRef());
562542d345SSlava Zakharin       auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
572542d345SSlava Zakharin       if (!callableOp)
582542d345SSlava Zakharin         continue;
592542d345SSlava Zakharin       node = cg.lookupNode(callableOp.getCallableRegion());
602542d345SSlava Zakharin     }
612542d345SSlava Zakharin     if (node)
622542d345SSlava Zakharin       callback(node, use.getUser());
632542d345SSlava Zakharin   }
642542d345SSlava Zakharin }
652542d345SSlava Zakharin 
//===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//

692542d345SSlava Zakharin namespace {
702542d345SSlava Zakharin /// This struct tracks the uses of callgraph nodes that can be dropped when
712542d345SSlava Zakharin /// use_empty. It directly tracks and manages a use-list for all of the
722542d345SSlava Zakharin /// call-graph nodes. This is necessary because many callgraph nodes are
732542d345SSlava Zakharin /// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
742542d345SSlava Zakharin /// class.
752542d345SSlava Zakharin struct CGUseList {
762542d345SSlava Zakharin   /// This struct tracks the uses of callgraph nodes within a specific
772542d345SSlava Zakharin   /// operation.
782542d345SSlava Zakharin   struct CGUser {
792542d345SSlava Zakharin     /// Any nodes referenced in the top-level attribute list of this user. We
802542d345SSlava Zakharin     /// use a set here because the number of references does not matter.
812542d345SSlava Zakharin     DenseSet<CallGraphNode *> topLevelUses;
822542d345SSlava Zakharin 
832542d345SSlava Zakharin     /// Uses of nodes referenced by nested operations.
842542d345SSlava Zakharin     DenseMap<CallGraphNode *, int> innerUses;
852542d345SSlava Zakharin   };
862542d345SSlava Zakharin 
872542d345SSlava Zakharin   CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
882542d345SSlava Zakharin 
892542d345SSlava Zakharin   /// Drop uses of nodes referred to by the given call operation that resides
902542d345SSlava Zakharin   /// within 'userNode'.
912542d345SSlava Zakharin   void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);
922542d345SSlava Zakharin 
932542d345SSlava Zakharin   /// Remove the given node from the use list.
942542d345SSlava Zakharin   void eraseNode(CallGraphNode *node);
952542d345SSlava Zakharin 
962542d345SSlava Zakharin   /// Returns true if the given callgraph node has no uses and can be pruned.
972542d345SSlava Zakharin   bool isDead(CallGraphNode *node) const;
982542d345SSlava Zakharin 
992542d345SSlava Zakharin   /// Returns true if the given callgraph node has a single use and can be
1002542d345SSlava Zakharin   /// discarded.
1012542d345SSlava Zakharin   bool hasOneUseAndDiscardable(CallGraphNode *node) const;
1022542d345SSlava Zakharin 
1032542d345SSlava Zakharin   /// Recompute the uses held by the given callgraph node.
1042542d345SSlava Zakharin   void recomputeUses(CallGraphNode *node, CallGraph &cg);
1052542d345SSlava Zakharin 
1062542d345SSlava Zakharin   /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
1072542d345SSlava Zakharin   /// of 'lhs' into 'rhs'.
1082542d345SSlava Zakharin   void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);
1092542d345SSlava Zakharin 
1102542d345SSlava Zakharin private:
1112542d345SSlava Zakharin   /// Decrement the uses of discardable nodes referenced by the given user.
1122542d345SSlava Zakharin   void decrementDiscardableUses(CGUser &uses);
1132542d345SSlava Zakharin 
1142542d345SSlava Zakharin   /// A mapping between a discardable callgraph node (that is a symbol) and the
1152542d345SSlava Zakharin   /// number of uses for this node.
1162542d345SSlava Zakharin   DenseMap<CallGraphNode *, int> discardableSymNodeUses;
1172542d345SSlava Zakharin 
1182542d345SSlava Zakharin   /// A mapping between a callgraph node and the symbol callgraph nodes that it
1192542d345SSlava Zakharin   /// uses.
1202542d345SSlava Zakharin   DenseMap<CallGraphNode *, CGUser> nodeUses;
1212542d345SSlava Zakharin 
1222542d345SSlava Zakharin   /// A symbol table to use when resolving call lookups.
1232542d345SSlava Zakharin   SymbolTableCollection &symbolTable;
1242542d345SSlava Zakharin };
1252542d345SSlava Zakharin } // namespace
1262542d345SSlava Zakharin 
1272542d345SSlava Zakharin CGUseList::CGUseList(Operation *op, CallGraph &cg,
1282542d345SSlava Zakharin                      SymbolTableCollection &symbolTable)
1292542d345SSlava Zakharin     : symbolTable(symbolTable) {
1302542d345SSlava Zakharin   /// A set of callgraph nodes that are always known to be live during inlining.
1312542d345SSlava Zakharin   DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
1322542d345SSlava Zakharin 
1332542d345SSlava Zakharin   // Walk each of the symbol tables looking for discardable callgraph nodes.
1342542d345SSlava Zakharin   auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
1352542d345SSlava Zakharin     for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
1362542d345SSlava Zakharin       // If this is a callgraph operation, check to see if it is discardable.
1372542d345SSlava Zakharin       if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
1382542d345SSlava Zakharin         if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
1392542d345SSlava Zakharin           SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
1402542d345SSlava Zakharin           if (symbol && (allUsesVisible || symbol.isPrivate()) &&
1412542d345SSlava Zakharin               symbol.canDiscardOnUseEmpty()) {
1422542d345SSlava Zakharin             discardableSymNodeUses.try_emplace(node, 0);
1432542d345SSlava Zakharin           }
1442542d345SSlava Zakharin           continue;
1452542d345SSlava Zakharin         }
1462542d345SSlava Zakharin       }
1472542d345SSlava Zakharin       // Otherwise, check for any referenced nodes. These will be always-live.
1482542d345SSlava Zakharin       walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
1492542d345SSlava Zakharin                                 [](CallGraphNode *, Operation *) {});
1502542d345SSlava Zakharin     }
1512542d345SSlava Zakharin   };
1522542d345SSlava Zakharin   SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
1532542d345SSlava Zakharin                                 walkFn);
1542542d345SSlava Zakharin 
1552542d345SSlava Zakharin   // Drop the use information for any discardable nodes that are always live.
1562542d345SSlava Zakharin   for (auto &it : alwaysLiveNodes)
1572542d345SSlava Zakharin     discardableSymNodeUses.erase(it.second);
1582542d345SSlava Zakharin 
1592542d345SSlava Zakharin   // Compute the uses for each of the callable nodes in the graph.
1602542d345SSlava Zakharin   for (CallGraphNode *node : cg)
1612542d345SSlava Zakharin     recomputeUses(node, cg);
1622542d345SSlava Zakharin }
1632542d345SSlava Zakharin 
1642542d345SSlava Zakharin void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
1652542d345SSlava Zakharin                              CallGraph &cg) {
1662542d345SSlava Zakharin   auto &userRefs = nodeUses[userNode].innerUses;
1672542d345SSlava Zakharin   auto walkFn = [&](CallGraphNode *node, Operation *user) {
1682542d345SSlava Zakharin     auto parentIt = userRefs.find(node);
1692542d345SSlava Zakharin     if (parentIt == userRefs.end())
1702542d345SSlava Zakharin       return;
1712542d345SSlava Zakharin     --parentIt->second;
1722542d345SSlava Zakharin     --discardableSymNodeUses[node];
1732542d345SSlava Zakharin   };
1742542d345SSlava Zakharin   DenseMap<Attribute, CallGraphNode *> resolvedRefs;
1752542d345SSlava Zakharin   walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
1762542d345SSlava Zakharin }
1772542d345SSlava Zakharin 
1782542d345SSlava Zakharin void CGUseList::eraseNode(CallGraphNode *node) {
1792542d345SSlava Zakharin   // Drop all child nodes.
1802542d345SSlava Zakharin   for (auto &edge : *node)
1812542d345SSlava Zakharin     if (edge.isChild())
1822542d345SSlava Zakharin       eraseNode(edge.getTarget());
1832542d345SSlava Zakharin 
1842542d345SSlava Zakharin   // Drop the uses held by this node and erase it.
1852542d345SSlava Zakharin   auto useIt = nodeUses.find(node);
1862542d345SSlava Zakharin   assert(useIt != nodeUses.end() && "expected node to be valid");
1872542d345SSlava Zakharin   decrementDiscardableUses(useIt->getSecond());
1882542d345SSlava Zakharin   nodeUses.erase(useIt);
1892542d345SSlava Zakharin   discardableSymNodeUses.erase(node);
1902542d345SSlava Zakharin }
1912542d345SSlava Zakharin 
1922542d345SSlava Zakharin bool CGUseList::isDead(CallGraphNode *node) const {
1932542d345SSlava Zakharin   // If the parent operation isn't a symbol, simply check normal SSA deadness.
1942542d345SSlava Zakharin   Operation *nodeOp = node->getCallableRegion()->getParentOp();
1952542d345SSlava Zakharin   if (!isa<SymbolOpInterface>(nodeOp))
1962542d345SSlava Zakharin     return isMemoryEffectFree(nodeOp) && nodeOp->use_empty();
1972542d345SSlava Zakharin 
1982542d345SSlava Zakharin   // Otherwise, check the number of symbol uses.
1992542d345SSlava Zakharin   auto symbolIt = discardableSymNodeUses.find(node);
2002542d345SSlava Zakharin   return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
2012542d345SSlava Zakharin }
2022542d345SSlava Zakharin 
2032542d345SSlava Zakharin bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
2042542d345SSlava Zakharin   // If this isn't a symbol node, check for side-effects and SSA use count.
2052542d345SSlava Zakharin   Operation *nodeOp = node->getCallableRegion()->getParentOp();
2062542d345SSlava Zakharin   if (!isa<SymbolOpInterface>(nodeOp))
2072542d345SSlava Zakharin     return isMemoryEffectFree(nodeOp) && nodeOp->hasOneUse();
2082542d345SSlava Zakharin 
2092542d345SSlava Zakharin   // Otherwise, check the number of symbol uses.
2102542d345SSlava Zakharin   auto symbolIt = discardableSymNodeUses.find(node);
2112542d345SSlava Zakharin   return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
2122542d345SSlava Zakharin }
2132542d345SSlava Zakharin 
2142542d345SSlava Zakharin void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
2152542d345SSlava Zakharin   Operation *parentOp = node->getCallableRegion()->getParentOp();
2162542d345SSlava Zakharin   CGUser &uses = nodeUses[node];
2172542d345SSlava Zakharin   decrementDiscardableUses(uses);
2182542d345SSlava Zakharin 
2192542d345SSlava Zakharin   // Collect the new discardable uses within this node.
2202542d345SSlava Zakharin   uses = CGUser();
2212542d345SSlava Zakharin   DenseMap<Attribute, CallGraphNode *> resolvedRefs;
2222542d345SSlava Zakharin   auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
2232542d345SSlava Zakharin     auto discardSymIt = discardableSymNodeUses.find(refNode);
2242542d345SSlava Zakharin     if (discardSymIt == discardableSymNodeUses.end())
2252542d345SSlava Zakharin       return;
2262542d345SSlava Zakharin 
2272542d345SSlava Zakharin     if (user != parentOp)
2282542d345SSlava Zakharin       ++uses.innerUses[refNode];
2292542d345SSlava Zakharin     else if (!uses.topLevelUses.insert(refNode).second)
2302542d345SSlava Zakharin       return;
2312542d345SSlava Zakharin     ++discardSymIt->second;
2322542d345SSlava Zakharin   };
2332542d345SSlava Zakharin   walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
2342542d345SSlava Zakharin }
2352542d345SSlava Zakharin 
2362542d345SSlava Zakharin void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
2372542d345SSlava Zakharin   auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
2382542d345SSlava Zakharin   for (auto &useIt : lhsUses.innerUses) {
2392542d345SSlava Zakharin     rhsUses.innerUses[useIt.first] += useIt.second;
2402542d345SSlava Zakharin     discardableSymNodeUses[useIt.first] += useIt.second;
2412542d345SSlava Zakharin   }
2422542d345SSlava Zakharin }
2432542d345SSlava Zakharin 
2442542d345SSlava Zakharin void CGUseList::decrementDiscardableUses(CGUser &uses) {
2452542d345SSlava Zakharin   for (CallGraphNode *node : uses.topLevelUses)
2462542d345SSlava Zakharin     --discardableSymNodeUses[node];
2472542d345SSlava Zakharin   for (auto &it : uses.innerUses)
2482542d345SSlava Zakharin     discardableSymNodeUses[it.first] -= it.second;
2492542d345SSlava Zakharin }
2502542d345SSlava Zakharin 
2512542d345SSlava Zakharin //===----------------------------------------------------------------------===//
2522542d345SSlava Zakharin // CallGraph traversal
2532542d345SSlava Zakharin //===----------------------------------------------------------------------===//
2542542d345SSlava Zakharin 
2552542d345SSlava Zakharin namespace {
2562542d345SSlava Zakharin /// This class represents a specific callgraph SCC.
2572542d345SSlava Zakharin class CallGraphSCC {
2582542d345SSlava Zakharin public:
2592542d345SSlava Zakharin   CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
2602542d345SSlava Zakharin       : parentIterator(parentIterator) {}
2612542d345SSlava Zakharin   /// Return a range over the nodes within this SCC.
2622542d345SSlava Zakharin   std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
2632542d345SSlava Zakharin   std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
2642542d345SSlava Zakharin 
2652542d345SSlava Zakharin   /// Reset the nodes of this SCC with those provided.
2662542d345SSlava Zakharin   void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
2672542d345SSlava Zakharin 
2682542d345SSlava Zakharin   /// Remove the given node from this SCC.
2692542d345SSlava Zakharin   void remove(CallGraphNode *node) {
2702542d345SSlava Zakharin     auto it = llvm::find(nodes, node);
2712542d345SSlava Zakharin     if (it != nodes.end()) {
2722542d345SSlava Zakharin       nodes.erase(it);
2732542d345SSlava Zakharin       parentIterator.ReplaceNode(node, nullptr);
2742542d345SSlava Zakharin     }
2752542d345SSlava Zakharin   }
2762542d345SSlava Zakharin 
2772542d345SSlava Zakharin private:
2782542d345SSlava Zakharin   std::vector<CallGraphNode *> nodes;
2792542d345SSlava Zakharin   llvm::scc_iterator<const CallGraph *> &parentIterator;
2802542d345SSlava Zakharin };
2812542d345SSlava Zakharin } // namespace
2822542d345SSlava Zakharin 
2832542d345SSlava Zakharin /// Run a given transformation over the SCCs of the callgraph in a bottom up
2842542d345SSlava Zakharin /// traversal.
2852542d345SSlava Zakharin static LogicalResult runTransformOnCGSCCs(
2862542d345SSlava Zakharin     const CallGraph &cg,
2872542d345SSlava Zakharin     function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
2882542d345SSlava Zakharin   llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
2892542d345SSlava Zakharin   CallGraphSCC currentSCC(cgi);
2902542d345SSlava Zakharin   while (!cgi.isAtEnd()) {
2912542d345SSlava Zakharin     // Copy the current SCC and increment so that the transformer can modify the
2922542d345SSlava Zakharin     // SCC without invalidating our iterator.
2932542d345SSlava Zakharin     currentSCC.reset(*cgi);
2942542d345SSlava Zakharin     ++cgi;
2952542d345SSlava Zakharin     if (failed(sccTransformer(currentSCC)))
2962542d345SSlava Zakharin       return failure();
2972542d345SSlava Zakharin   }
2982542d345SSlava Zakharin   return success();
2992542d345SSlava Zakharin }
3002542d345SSlava Zakharin 
3012542d345SSlava Zakharin /// Collect all of the callable operations within the given range of blocks. If
3022542d345SSlava Zakharin /// `traverseNestedCGNodes` is true, this will also collect call operations
3032542d345SSlava Zakharin /// inside of nested callgraph nodes.
3042542d345SSlava Zakharin static void collectCallOps(iterator_range<Region::iterator> blocks,
3052542d345SSlava Zakharin                            CallGraphNode *sourceNode, CallGraph &cg,
3062542d345SSlava Zakharin                            SymbolTableCollection &symbolTable,
3072542d345SSlava Zakharin                            SmallVectorImpl<ResolvedCall> &calls,
3082542d345SSlava Zakharin                            bool traverseNestedCGNodes) {
3092542d345SSlava Zakharin   SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
3102542d345SSlava Zakharin   auto addToWorklist = [&](CallGraphNode *node,
3112542d345SSlava Zakharin                            iterator_range<Region::iterator> blocks) {
3122542d345SSlava Zakharin     for (Block &block : blocks)
3132542d345SSlava Zakharin       worklist.emplace_back(&block, node);
3142542d345SSlava Zakharin   };
3152542d345SSlava Zakharin 
3162542d345SSlava Zakharin   addToWorklist(sourceNode, blocks);
3172542d345SSlava Zakharin   while (!worklist.empty()) {
3182542d345SSlava Zakharin     Block *block;
3192542d345SSlava Zakharin     std::tie(block, sourceNode) = worklist.pop_back_val();
3202542d345SSlava Zakharin 
3212542d345SSlava Zakharin     for (Operation &op : *block) {
3222542d345SSlava Zakharin       if (auto call = dyn_cast<CallOpInterface>(op)) {
3232542d345SSlava Zakharin         // TODO: Support inlining nested call references.
3242542d345SSlava Zakharin         CallInterfaceCallable callable = call.getCallableForCallee();
3252542d345SSlava Zakharin         if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) {
3262542d345SSlava Zakharin           if (!isa<FlatSymbolRefAttr>(symRef))
3272542d345SSlava Zakharin             continue;
3282542d345SSlava Zakharin         }
3292542d345SSlava Zakharin 
3302542d345SSlava Zakharin         CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
3312542d345SSlava Zakharin         if (!targetNode->isExternal())
3322542d345SSlava Zakharin           calls.emplace_back(call, sourceNode, targetNode);
3332542d345SSlava Zakharin         continue;
3342542d345SSlava Zakharin       }
3352542d345SSlava Zakharin 
3362542d345SSlava Zakharin       // If this is not a call, traverse the nested regions. If
3372542d345SSlava Zakharin       // `traverseNestedCGNodes` is false, then don't traverse nested call graph
3382542d345SSlava Zakharin       // regions.
3392542d345SSlava Zakharin       for (auto &nestedRegion : op.getRegions()) {
3402542d345SSlava Zakharin         CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
3412542d345SSlava Zakharin         if (traverseNestedCGNodes || !nestedNode)
3422542d345SSlava Zakharin           addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
3432542d345SSlava Zakharin       }
3442542d345SSlava Zakharin     }
3452542d345SSlava Zakharin   }
3462542d345SSlava Zakharin }
3472542d345SSlava Zakharin 
3482542d345SSlava Zakharin //===----------------------------------------------------------------------===//
3492542d345SSlava Zakharin // InlinerInterfaceImpl
3502542d345SSlava Zakharin //===----------------------------------------------------------------------===//
3512542d345SSlava Zakharin 
3522542d345SSlava Zakharin #ifndef NDEBUG
3532542d345SSlava Zakharin static std::string getNodeName(CallOpInterface op) {
3542542d345SSlava Zakharin   if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
3552542d345SSlava Zakharin     return debugString(op);
3562542d345SSlava Zakharin   return "_unnamed_callee_";
3572542d345SSlava Zakharin }
3582542d345SSlava Zakharin #endif
3592542d345SSlava Zakharin 
3602542d345SSlava Zakharin /// Return true if the specified `inlineHistoryID`  indicates an inline history
3612542d345SSlava Zakharin /// that already includes `node`.
3622542d345SSlava Zakharin static bool inlineHistoryIncludes(
3632542d345SSlava Zakharin     CallGraphNode *node, std::optional<size_t> inlineHistoryID,
3642542d345SSlava Zakharin     MutableArrayRef<std::pair<CallGraphNode *, std::optional<size_t>>>
3652542d345SSlava Zakharin         inlineHistory) {
3662542d345SSlava Zakharin   while (inlineHistoryID.has_value()) {
3672542d345SSlava Zakharin     assert(*inlineHistoryID < inlineHistory.size() &&
3682542d345SSlava Zakharin            "Invalid inline history ID");
3692542d345SSlava Zakharin     if (inlineHistory[*inlineHistoryID].first == node)
3702542d345SSlava Zakharin       return true;
3712542d345SSlava Zakharin     inlineHistoryID = inlineHistory[*inlineHistoryID].second;
3722542d345SSlava Zakharin   }
3732542d345SSlava Zakharin   return false;
3742542d345SSlava Zakharin }
3752542d345SSlava Zakharin 
3762542d345SSlava Zakharin namespace {
3772542d345SSlava Zakharin /// This class provides a specialization of the main inlining interface.
3782542d345SSlava Zakharin struct InlinerInterfaceImpl : public InlinerInterface {
3792542d345SSlava Zakharin   InlinerInterfaceImpl(MLIRContext *context, CallGraph &cg,
3802542d345SSlava Zakharin                        SymbolTableCollection &symbolTable)
3812542d345SSlava Zakharin       : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
3822542d345SSlava Zakharin 
3832542d345SSlava Zakharin   /// Process a set of blocks that have been inlined. This callback is invoked
3842542d345SSlava Zakharin   /// *before* inlined terminator operations have been processed.
3852542d345SSlava Zakharin   void
3862542d345SSlava Zakharin   processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
3872542d345SSlava Zakharin     // Find the closest callgraph node from the first block.
3882542d345SSlava Zakharin     CallGraphNode *node;
3892542d345SSlava Zakharin     Region *region = inlinedBlocks.begin()->getParent();
3902542d345SSlava Zakharin     while (!(node = cg.lookupNode(region))) {
3912542d345SSlava Zakharin       region = region->getParentRegion();
3922542d345SSlava Zakharin       assert(region && "expected valid parent node");
3932542d345SSlava Zakharin     }
3942542d345SSlava Zakharin 
3952542d345SSlava Zakharin     collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
3962542d345SSlava Zakharin                    /*traverseNestedCGNodes=*/true);
3972542d345SSlava Zakharin   }
3982542d345SSlava Zakharin 
3992542d345SSlava Zakharin   /// Mark the given callgraph node for deletion.
4002542d345SSlava Zakharin   void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
4012542d345SSlava Zakharin 
4022542d345SSlava Zakharin   /// This method properly disposes of callables that became dead during
4032542d345SSlava Zakharin   /// inlining. This should not be called while iterating over the SCCs.
4042542d345SSlava Zakharin   void eraseDeadCallables() {
4052542d345SSlava Zakharin     for (CallGraphNode *node : deadNodes)
4062542d345SSlava Zakharin       node->getCallableRegion()->getParentOp()->erase();
4072542d345SSlava Zakharin   }
4082542d345SSlava Zakharin 
4092542d345SSlava Zakharin   /// The set of callables known to be dead.
4102542d345SSlava Zakharin   SmallPtrSet<CallGraphNode *, 8> deadNodes;
4112542d345SSlava Zakharin 
4122542d345SSlava Zakharin   /// The current set of call instructions to consider for inlining.
4132542d345SSlava Zakharin   SmallVector<ResolvedCall, 8> calls;
4142542d345SSlava Zakharin 
4152542d345SSlava Zakharin   /// The callgraph being operated on.
4162542d345SSlava Zakharin   CallGraph &cg;
4172542d345SSlava Zakharin 
4182542d345SSlava Zakharin   /// A symbol table to use when resolving call lookups.
4192542d345SSlava Zakharin   SymbolTableCollection &symbolTable;
4202542d345SSlava Zakharin };
4212542d345SSlava Zakharin } // namespace
4222542d345SSlava Zakharin 
4232542d345SSlava Zakharin namespace mlir {
4242542d345SSlava Zakharin 
4252542d345SSlava Zakharin class Inliner::Impl {
4262542d345SSlava Zakharin public:
4272542d345SSlava Zakharin   Impl(Inliner &inliner) : inliner(inliner) {}
4282542d345SSlava Zakharin 
4292542d345SSlava Zakharin   /// Attempt to inline calls within the given scc, and run simplifications,
4302542d345SSlava Zakharin   /// until a fixed point is reached. This allows for the inlining of newly
4312542d345SSlava Zakharin   /// devirtualized calls. Returns failure if there was a fatal error during
4322542d345SSlava Zakharin   /// inlining.
4332542d345SSlava Zakharin   LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface,
4342542d345SSlava Zakharin                           CGUseList &useList, CallGraphSCC &currentSCC,
4352542d345SSlava Zakharin                           MLIRContext *context);
4362542d345SSlava Zakharin 
4372542d345SSlava Zakharin private:
4382542d345SSlava Zakharin   /// Optimize the nodes within the given SCC with one of the held optimization
4392542d345SSlava Zakharin   /// pass pipelines. Returns failure if an error occurred during the
4402542d345SSlava Zakharin   /// optimization of the SCC, success otherwise.
4412542d345SSlava Zakharin   LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
4422542d345SSlava Zakharin                             CallGraphSCC &currentSCC, MLIRContext *context);
4432542d345SSlava Zakharin 
4442542d345SSlava Zakharin   /// Optimize the nodes within the given SCC in parallel. Returns failure if an
4452542d345SSlava Zakharin   /// error occurred during the optimization of the SCC, success otherwise.
4462542d345SSlava Zakharin   LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
4472542d345SSlava Zakharin                                  MLIRContext *context);
4482542d345SSlava Zakharin 
4492542d345SSlava Zakharin   /// Optimize the given callable node with one of the pass managers provided
4502542d345SSlava Zakharin   /// with `pipelines`, or the generic pre-inline pipeline. Returns failure if
4512542d345SSlava Zakharin   /// an error occurred during the optimization of the callable, success
4522542d345SSlava Zakharin   /// otherwise.
4532542d345SSlava Zakharin   LogicalResult optimizeCallable(CallGraphNode *node,
4542542d345SSlava Zakharin                                  llvm::StringMap<OpPassManager> &pipelines);
4552542d345SSlava Zakharin 
4562542d345SSlava Zakharin   /// Attempt to inline calls within the given scc. This function returns
4572542d345SSlava Zakharin   /// success if any calls were inlined, failure otherwise.
4582542d345SSlava Zakharin   LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
4592542d345SSlava Zakharin                                  CGUseList &useList, CallGraphSCC &currentSCC);
4602542d345SSlava Zakharin 
4612542d345SSlava Zakharin   /// Returns true if the given call should be inlined.
4622542d345SSlava Zakharin   bool shouldInline(ResolvedCall &resolvedCall);
4632542d345SSlava Zakharin 
4642542d345SSlava Zakharin private:
4652542d345SSlava Zakharin   Inliner &inliner;
4662542d345SSlava Zakharin   llvm::SmallVector<llvm::StringMap<OpPassManager>> pipelines;
4672542d345SSlava Zakharin };
4682542d345SSlava Zakharin 
4692542d345SSlava Zakharin LogicalResult Inliner::Impl::inlineSCC(InlinerInterfaceImpl &inlinerIface,
4702542d345SSlava Zakharin                                        CGUseList &useList,
4712542d345SSlava Zakharin                                        CallGraphSCC &currentSCC,
4722542d345SSlava Zakharin                                        MLIRContext *context) {
4732542d345SSlava Zakharin   // Continuously simplify and inline until we either reach a fixed point, or
4742542d345SSlava Zakharin   // hit the maximum iteration count. Simplifying early helps to refine the cost
4752542d345SSlava Zakharin   // model, and in future iterations may devirtualize new calls.
4762542d345SSlava Zakharin   unsigned iterationCount = 0;
4772542d345SSlava Zakharin   do {
4782542d345SSlava Zakharin     if (failed(optimizeSCC(inlinerIface.cg, useList, currentSCC, context)))
4792542d345SSlava Zakharin       return failure();
4802542d345SSlava Zakharin     if (failed(inlineCallsInSCC(inlinerIface, useList, currentSCC)))
4812542d345SSlava Zakharin       break;
4822542d345SSlava Zakharin   } while (++iterationCount < inliner.config.getMaxInliningIterations());
4832542d345SSlava Zakharin   return success();
4842542d345SSlava Zakharin }
4852542d345SSlava Zakharin 
4862542d345SSlava Zakharin LogicalResult Inliner::Impl::optimizeSCC(CallGraph &cg, CGUseList &useList,
4872542d345SSlava Zakharin                                          CallGraphSCC &currentSCC,
4882542d345SSlava Zakharin                                          MLIRContext *context) {
4892542d345SSlava Zakharin   // Collect the sets of nodes to simplify.
4902542d345SSlava Zakharin   SmallVector<CallGraphNode *, 4> nodesToVisit;
4912542d345SSlava Zakharin   for (auto *node : currentSCC) {
4922542d345SSlava Zakharin     if (node->isExternal())
4932542d345SSlava Zakharin       continue;
4942542d345SSlava Zakharin 
4952542d345SSlava Zakharin     // Don't simplify nodes with children. Nodes with children require special
4962542d345SSlava Zakharin     // handling as we may remove the node during simplification. In the future,
4972542d345SSlava Zakharin     // we should be able to handle this case with proper node deletion tracking.
4982542d345SSlava Zakharin     if (node->hasChildren())
4992542d345SSlava Zakharin       continue;
5002542d345SSlava Zakharin 
5012542d345SSlava Zakharin     // We also won't apply simplifications to nodes that can't have passes
5022542d345SSlava Zakharin     // scheduled on them.
5032542d345SSlava Zakharin     auto *region = node->getCallableRegion();
5042542d345SSlava Zakharin     if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
5052542d345SSlava Zakharin       continue;
5062542d345SSlava Zakharin     nodesToVisit.push_back(node);
5072542d345SSlava Zakharin   }
5082542d345SSlava Zakharin   if (nodesToVisit.empty())
5092542d345SSlava Zakharin     return success();
5102542d345SSlava Zakharin 
5112542d345SSlava Zakharin   // Optimize each of the nodes within the SCC in parallel.
5122542d345SSlava Zakharin   if (failed(optimizeSCCAsync(nodesToVisit, context)))
5132542d345SSlava Zakharin     return failure();
5142542d345SSlava Zakharin 
5152542d345SSlava Zakharin   // Recompute the uses held by each of the nodes.
5162542d345SSlava Zakharin   for (CallGraphNode *node : nodesToVisit)
5172542d345SSlava Zakharin     useList.recomputeUses(node, cg);
5182542d345SSlava Zakharin   return success();
5192542d345SSlava Zakharin }
5202542d345SSlava Zakharin 
5212542d345SSlava Zakharin LogicalResult
5222542d345SSlava Zakharin Inliner::Impl::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
5232542d345SSlava Zakharin                                 MLIRContext *ctx) {
5242542d345SSlava Zakharin   // We must maintain a fixed pool of pass managers which is at least as large
5252542d345SSlava Zakharin   // as the maximum parallelism of the failableParallelForEach below.
5262542d345SSlava Zakharin   // Note: The number of pass managers here needs to remain constant
5272542d345SSlava Zakharin   // to prevent issues with pass instrumentations that rely on having the same
5282542d345SSlava Zakharin   // pass manager for the main thread.
5292542d345SSlava Zakharin   size_t numThreads = ctx->getNumThreads();
5302542d345SSlava Zakharin   const auto &opPipelines = inliner.config.getOpPipelines();
5312542d345SSlava Zakharin   if (pipelines.size() < numThreads) {
5322542d345SSlava Zakharin     pipelines.reserve(numThreads);
5332542d345SSlava Zakharin     pipelines.resize(numThreads, opPipelines);
5342542d345SSlava Zakharin   }
5352542d345SSlava Zakharin 
5362542d345SSlava Zakharin   // Ensure an analysis manager has been constructed for each of the nodes.
5372542d345SSlava Zakharin   // This prevents thread races when running the nested pipelines.
5382542d345SSlava Zakharin   for (CallGraphNode *node : nodesToVisit)
5392542d345SSlava Zakharin     inliner.am.nest(node->getCallableRegion()->getParentOp());
5402542d345SSlava Zakharin 
5412542d345SSlava Zakharin   // An atomic failure variable for the async executors.
5422542d345SSlava Zakharin   std::vector<std::atomic<bool>> activePMs(pipelines.size());
5432542d345SSlava Zakharin   std::fill(activePMs.begin(), activePMs.end(), false);
5442542d345SSlava Zakharin   return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
5452542d345SSlava Zakharin     // Find a pass manager for this operation.
5462542d345SSlava Zakharin     auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
5472542d345SSlava Zakharin       bool expectedInactive = false;
5482542d345SSlava Zakharin       return isActive.compare_exchange_strong(expectedInactive, true);
5492542d345SSlava Zakharin     });
5502542d345SSlava Zakharin     assert(it != activePMs.end() &&
5512542d345SSlava Zakharin            "could not find inactive pass manager for thread");
5522542d345SSlava Zakharin     unsigned pmIndex = it - activePMs.begin();
5532542d345SSlava Zakharin 
5542542d345SSlava Zakharin     // Optimize this callable node.
5552542d345SSlava Zakharin     LogicalResult result = optimizeCallable(node, pipelines[pmIndex]);
5562542d345SSlava Zakharin 
5572542d345SSlava Zakharin     // Reset the active bit for this pass manager.
5582542d345SSlava Zakharin     activePMs[pmIndex].store(false);
5592542d345SSlava Zakharin     return result;
5602542d345SSlava Zakharin   });
5612542d345SSlava Zakharin }
5622542d345SSlava Zakharin 
5632542d345SSlava Zakharin LogicalResult
5642542d345SSlava Zakharin Inliner::Impl::optimizeCallable(CallGraphNode *node,
5652542d345SSlava Zakharin                                 llvm::StringMap<OpPassManager> &pipelines) {
5662542d345SSlava Zakharin   Operation *callable = node->getCallableRegion()->getParentOp();
5672542d345SSlava Zakharin   StringRef opName = callable->getName().getStringRef();
5682542d345SSlava Zakharin   auto pipelineIt = pipelines.find(opName);
5692542d345SSlava Zakharin   const auto &defaultPipeline = inliner.config.getDefaultPipeline();
5702542d345SSlava Zakharin   if (pipelineIt == pipelines.end()) {
5712542d345SSlava Zakharin     // If a pipeline didn't exist, use the generic pipeline if possible.
5722542d345SSlava Zakharin     if (!defaultPipeline)
5732542d345SSlava Zakharin       return success();
5742542d345SSlava Zakharin 
5752542d345SSlava Zakharin     OpPassManager defaultPM(opName);
5762542d345SSlava Zakharin     defaultPipeline(defaultPM);
5772542d345SSlava Zakharin     pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
5782542d345SSlava Zakharin   }
5792542d345SSlava Zakharin   return inliner.runPipelineHelper(inliner.pass, pipelineIt->second, callable);
5802542d345SSlava Zakharin }
5812542d345SSlava Zakharin 
5822542d345SSlava Zakharin /// Attempt to inline calls within the given scc. This function returns
5832542d345SSlava Zakharin /// success if any calls were inlined, failure otherwise.
5842542d345SSlava Zakharin LogicalResult
5852542d345SSlava Zakharin Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
5862542d345SSlava Zakharin                                 CGUseList &useList, CallGraphSCC &currentSCC) {
5872542d345SSlava Zakharin   CallGraph &cg = inlinerIface.cg;
5882542d345SSlava Zakharin   auto &calls = inlinerIface.calls;
5892542d345SSlava Zakharin 
5902542d345SSlava Zakharin   // A set of dead nodes to remove after inlining.
5912542d345SSlava Zakharin   llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;
5922542d345SSlava Zakharin 
5932542d345SSlava Zakharin   // Collect all of the direct calls within the nodes of the current SCC. We
5942542d345SSlava Zakharin   // don't traverse nested callgraph nodes, because they are handled separately
5952542d345SSlava Zakharin   // likely within a different SCC.
5962542d345SSlava Zakharin   for (CallGraphNode *node : currentSCC) {
5972542d345SSlava Zakharin     if (node->isExternal())
5982542d345SSlava Zakharin       continue;
5992542d345SSlava Zakharin 
6002542d345SSlava Zakharin     // Don't collect calls if the node is already dead.
6012542d345SSlava Zakharin     if (useList.isDead(node)) {
6022542d345SSlava Zakharin       deadNodes.insert(node);
6032542d345SSlava Zakharin     } else {
6042542d345SSlava Zakharin       collectCallOps(*node->getCallableRegion(), node, cg,
6052542d345SSlava Zakharin                      inlinerIface.symbolTable, calls,
6062542d345SSlava Zakharin                      /*traverseNestedCGNodes=*/false);
6072542d345SSlava Zakharin     }
6082542d345SSlava Zakharin   }
6092542d345SSlava Zakharin 
6102542d345SSlava Zakharin   // When inlining a callee produces new call sites, we want to keep track of
6112542d345SSlava Zakharin   // the fact that they were inlined from the callee. This allows us to avoid
6122542d345SSlava Zakharin   // infinite inlining.
6132542d345SSlava Zakharin   using InlineHistoryT = std::optional<size_t>;
6142542d345SSlava Zakharin   SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
6152542d345SSlava Zakharin   std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
6162542d345SSlava Zakharin 
6172542d345SSlava Zakharin   LLVM_DEBUG({
6182542d345SSlava Zakharin     llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
6192542d345SSlava Zakharin     for (unsigned i = 0, e = calls.size(); i < e; ++i)
6202542d345SSlava Zakharin       llvm::dbgs() << "  " << i << ". " << calls[i].call << ",\n";
6212542d345SSlava Zakharin     llvm::dbgs() << "}\n";
6222542d345SSlava Zakharin   });
6232542d345SSlava Zakharin 
6242542d345SSlava Zakharin   // Try to inline each of the call operations. Don't cache the end iterator
6252542d345SSlava Zakharin   // here as more calls may be added during inlining.
6262542d345SSlava Zakharin   bool inlinedAnyCalls = false;
6272542d345SSlava Zakharin   for (unsigned i = 0; i < calls.size(); ++i) {
6282542d345SSlava Zakharin     if (deadNodes.contains(calls[i].sourceNode))
6292542d345SSlava Zakharin       continue;
6302542d345SSlava Zakharin     ResolvedCall it = calls[i];
6312542d345SSlava Zakharin 
6322542d345SSlava Zakharin     InlineHistoryT inlineHistoryID = callHistory[i];
6332542d345SSlava Zakharin     bool inHistory =
6342542d345SSlava Zakharin         inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
6352542d345SSlava Zakharin     bool doInline = !inHistory && shouldInline(it);
6362542d345SSlava Zakharin     CallOpInterface call = it.call;
6372542d345SSlava Zakharin     LLVM_DEBUG({
6382542d345SSlava Zakharin       if (doInline)
6392542d345SSlava Zakharin         llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
6402542d345SSlava Zakharin       else
6412542d345SSlava Zakharin         llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
6422542d345SSlava Zakharin     });
6432542d345SSlava Zakharin     if (!doInline)
6442542d345SSlava Zakharin       continue;
6452542d345SSlava Zakharin 
6462542d345SSlava Zakharin     unsigned prevSize = calls.size();
6472542d345SSlava Zakharin     Region *targetRegion = it.targetNode->getCallableRegion();
6482542d345SSlava Zakharin 
6492542d345SSlava Zakharin     // If this is the last call to the target node and the node is discardable,
6502542d345SSlava Zakharin     // then inline it in-place and delete the node if successful.
6512542d345SSlava Zakharin     bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);
6522542d345SSlava Zakharin 
6532542d345SSlava Zakharin     LogicalResult inlineResult =
6542542d345SSlava Zakharin         inlineCall(inlinerIface, call,
6552542d345SSlava Zakharin                    cast<CallableOpInterface>(targetRegion->getParentOp()),
6562542d345SSlava Zakharin                    targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
6572542d345SSlava Zakharin     if (failed(inlineResult)) {
6582542d345SSlava Zakharin       LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
6592542d345SSlava Zakharin       continue;
6602542d345SSlava Zakharin     }
6612542d345SSlava Zakharin     inlinedAnyCalls = true;
6622542d345SSlava Zakharin 
6632542d345SSlava Zakharin     // Create a inline history entry for this inlined call, so that we remember
6642542d345SSlava Zakharin     // that new callsites came about due to inlining Callee.
6652542d345SSlava Zakharin     InlineHistoryT newInlineHistoryID{inlineHistory.size()};
6662542d345SSlava Zakharin     inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));
6672542d345SSlava Zakharin 
6682542d345SSlava Zakharin     auto historyToString = [](InlineHistoryT h) {
6692542d345SSlava Zakharin       return h.has_value() ? std::to_string(*h) : "root";
6702542d345SSlava Zakharin     };
6712542d345SSlava Zakharin     (void)historyToString;
6722542d345SSlava Zakharin     LLVM_DEBUG(llvm::dbgs()
6732542d345SSlava Zakharin                << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
6742542d345SSlava Zakharin                << getNodeName(call) << ", " << historyToString(inlineHistoryID)
6752542d345SSlava Zakharin                << "]\n");
6762542d345SSlava Zakharin 
6772542d345SSlava Zakharin     for (unsigned k = prevSize; k != calls.size(); ++k) {
6782542d345SSlava Zakharin       callHistory.push_back(newInlineHistoryID);
6792542d345SSlava Zakharin       LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call
6802542d345SSlava Zakharin                               << "}\n   with historyID = " << newInlineHistoryID
6812542d345SSlava Zakharin                               << ", added due to inlining of\n  call {" << call
6822542d345SSlava Zakharin                               << "}\n with historyID = "
6832542d345SSlava Zakharin                               << historyToString(inlineHistoryID) << "\n");
6842542d345SSlava Zakharin     }
6852542d345SSlava Zakharin 
6862542d345SSlava Zakharin     // If the inlining was successful, Merge the new uses into the source node.
6872542d345SSlava Zakharin     useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
6882542d345SSlava Zakharin     useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
6892542d345SSlava Zakharin 
6902542d345SSlava Zakharin     // then erase the call.
6912542d345SSlava Zakharin     call.erase();
6922542d345SSlava Zakharin 
6932542d345SSlava Zakharin     // If we inlined in place, mark the node for deletion.
6942542d345SSlava Zakharin     if (inlineInPlace) {
6952542d345SSlava Zakharin       useList.eraseNode(it.targetNode);
6962542d345SSlava Zakharin       deadNodes.insert(it.targetNode);
6972542d345SSlava Zakharin     }
6982542d345SSlava Zakharin   }
6992542d345SSlava Zakharin 
7002542d345SSlava Zakharin   for (CallGraphNode *node : deadNodes) {
7012542d345SSlava Zakharin     currentSCC.remove(node);
7022542d345SSlava Zakharin     inlinerIface.markForDeletion(node);
7032542d345SSlava Zakharin   }
7042542d345SSlava Zakharin   calls.clear();
7052542d345SSlava Zakharin   return success(inlinedAnyCalls);
7062542d345SSlava Zakharin }
7072542d345SSlava Zakharin 
7082542d345SSlava Zakharin /// Returns true if the given call should be inlined.
7092542d345SSlava Zakharin bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
7102542d345SSlava Zakharin   // Don't allow inlining terminator calls. We currently don't support this
7112542d345SSlava Zakharin   // case.
7122542d345SSlava Zakharin   if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
7132542d345SSlava Zakharin     return false;
7142542d345SSlava Zakharin 
715ad231272SCongcong Cai   // Don't allow inlining if the target is a self-recursive function.
716*83df39c6Sjunfengd-nv   // Don't allow inlining if the call graph is like A->B->A.
717ad231272SCongcong Cai   if (llvm::count_if(*resolvedCall.targetNode,
718ad231272SCongcong Cai                      [&](CallGraphNode::Edge const &edge) -> bool {
719*83df39c6Sjunfengd-nv                        return edge.getTarget() == resolvedCall.targetNode ||
720*83df39c6Sjunfengd-nv                               edge.getTarget() == resolvedCall.sourceNode;
721ad231272SCongcong Cai                      }) > 0)
722ad231272SCongcong Cai     return false;
723ad231272SCongcong Cai 
7242542d345SSlava Zakharin   // Don't allow inlining if the target is an ancestor of the call. This
7252542d345SSlava Zakharin   // prevents inlining recursively.
7262542d345SSlava Zakharin   Region *callableRegion = resolvedCall.targetNode->getCallableRegion();
7272542d345SSlava Zakharin   if (callableRegion->isAncestor(resolvedCall.call->getParentRegion()))
7282542d345SSlava Zakharin     return false;
7292542d345SSlava Zakharin 
7302542d345SSlava Zakharin   // Don't allow inlining if the callee has multiple blocks (unstructured
7312542d345SSlava Zakharin   // control flow) but we cannot be sure that the caller region supports that.
7322542d345SSlava Zakharin   bool calleeHasMultipleBlocks =
7332542d345SSlava Zakharin       llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
7342542d345SSlava Zakharin   // If both parent ops have the same type, it is safe to inline. Otherwise,
7352542d345SSlava Zakharin   // decide based on whether the op has the SingleBlock trait or not.
7362542d345SSlava Zakharin   // Note: This check does currently not account for SizedRegion/MaxSizedRegion.
7372542d345SSlava Zakharin   auto callerRegionSupportsMultipleBlocks = [&]() {
7382542d345SSlava Zakharin     return callableRegion->getParentOp()->getName() ==
7392542d345SSlava Zakharin                resolvedCall.call->getParentOp()->getName() ||
7402542d345SSlava Zakharin            !resolvedCall.call->getParentOp()
7412542d345SSlava Zakharin                 ->mightHaveTrait<OpTrait::SingleBlock>();
7422542d345SSlava Zakharin   };
7432542d345SSlava Zakharin   if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
7442542d345SSlava Zakharin     return false;
7452542d345SSlava Zakharin 
746732f5368SSlava Zakharin   if (!inliner.isProfitableToInline(resolvedCall))
747732f5368SSlava Zakharin     return false;
748732f5368SSlava Zakharin 
7492542d345SSlava Zakharin   // Otherwise, inline.
7502542d345SSlava Zakharin   return true;
7512542d345SSlava Zakharin }
7522542d345SSlava Zakharin 
7532542d345SSlava Zakharin LogicalResult Inliner::doInlining() {
7542542d345SSlava Zakharin   Impl impl(*this);
7552542d345SSlava Zakharin   auto *context = op->getContext();
7562542d345SSlava Zakharin   // Run the inline transform in post-order over the SCCs in the callgraph.
7572542d345SSlava Zakharin   SymbolTableCollection symbolTable;
7582542d345SSlava Zakharin   // FIXME: some clean-up can be done for the arguments
7592542d345SSlava Zakharin   // of the Impl's methods, if the inlinerIface and useList
7602542d345SSlava Zakharin   // become the states of the Impl.
7612542d345SSlava Zakharin   InlinerInterfaceImpl inlinerIface(context, cg, symbolTable);
7622542d345SSlava Zakharin   CGUseList useList(op, cg, symbolTable);
7632542d345SSlava Zakharin   LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
7642542d345SSlava Zakharin     return impl.inlineSCC(inlinerIface, useList, scc, context);
7652542d345SSlava Zakharin   });
7662542d345SSlava Zakharin   if (failed(result))
7672542d345SSlava Zakharin     return result;
7682542d345SSlava Zakharin 
7692542d345SSlava Zakharin   // After inlining, make sure to erase any callables proven to be dead.
7702542d345SSlava Zakharin   inlinerIface.eraseDeadCallables();
7712542d345SSlava Zakharin   return success();
7722542d345SSlava Zakharin }
7732542d345SSlava Zakharin } // namespace mlir
774