//===- Inliner.cpp ---- SCC-based inliner ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Inliner, which uses a basic inlining algorithm that
// operates bottom-up over the Strongly Connected Components (SCCs) of the
// CallGraph. This enables a more incremental propagation of inlining decisions
// from the leaves to the roots of the callgraph.
//
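// As an illustrative sketch (hypothetical IR, not tied to any specific
// dialect):
//
//   func.func @c() { ... }
//   func.func @b() { func.call @c() : () -> () ... }
//   func.func @a() { func.call @b() : () -> () ... }
//
// The SCCs here are visited as {@c}, then {@b}, then {@a}, so the call to @c
// is considered for inlining into @b before the call to @b is considered for
// inlining into @a.
//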
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Inliner.h"
#include "mlir/IR/Threading.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

using ResolvedCall = Inliner::ResolvedCall;

//===----------------------------------------------------------------------===//
// Symbol Use Tracking
//===----------------------------------------------------------------------===//

/// Walk all of the used symbol callgraph nodes referenced by the given op.
static void walkReferencedSymbolNodes(
    Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
    DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
    function_ref<void(CallGraphNode *, Operation *)> callback) {
  auto symbolUses = SymbolTable::getSymbolUses(op);
  assert(symbolUses && "expected uses to be valid");

  Operation *symbolTableOp = op->getParentOp();
  for (const SymbolTable::SymbolUse &use : *symbolUses) {
    auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
    CallGraphNode *&node = refIt.first->second;

    // If this is the first instance of this reference, try to resolve a
    // callgraph node for it.
    if (refIt.second) {
      auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
                                                           use.getSymbolRef());
      auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
      if (!callableOp)
        continue;
      node = cg.lookupNode(callableOp.getCallableRegion());
    }
    if (node)
      callback(node, use.getUser());
  }
}

//===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//

namespace {
/// This struct tracks the uses of callgraph nodes that can be dropped when
/// use_empty. It directly tracks and manages a use-list for all of the
/// call-graph nodes. This is necessary because many callgraph nodes are
/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
/// class.
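///
/// As a hypothetical sketch, a discardable callable such as
///
///   func.func private @callee() { ... }
///
/// is typically referenced only through symbol uses like
/// `func.call @callee() : () -> ()`. Once the use count tracked here for
/// @callee drops to zero, the callable is known to be dead and can be
/// discarded.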
struct CGUseList {
  /// This struct tracks the uses of callgraph nodes within a specific
  /// operation.
  struct CGUser {
    /// Any nodes referenced in the top-level attribute list of this user. We
    /// use a set here because the number of references does not matter.
    DenseSet<CallGraphNode *> topLevelUses;

    /// Uses of nodes referenced by nested operations.
    DenseMap<CallGraphNode *, int> innerUses;
  };

  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);

  /// Drop uses of nodes referred to by the given call operation that resides
  /// within 'userNode'.
  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);

  /// Remove the given node from the use list.
  void eraseNode(CallGraphNode *node);

  /// Returns true if the given callgraph node has no uses and can be pruned.
  bool isDead(CallGraphNode *node) const;

  /// Returns true if the given callgraph node has a single use and can be
  /// discarded.
  bool hasOneUseAndDiscardable(CallGraphNode *node) const;

  /// Recompute the uses held by the given callgraph node.
  void recomputeUses(CallGraphNode *node, CallGraph &cg);

  /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
  /// of 'lhs' into 'rhs'.
  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);

private:
  /// Decrement the uses of discardable nodes referenced by the given user.
  void decrementDiscardableUses(CGUser &uses);

  /// A mapping between a discardable callgraph node (that is a symbol) and the
  /// number of uses for this node.
  DenseMap<CallGraphNode *, int> discardableSymNodeUses;

  /// A mapping between a callgraph node and the symbol callgraph nodes that it
  /// uses.
  DenseMap<CallGraphNode *, CGUser> nodeUses;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // namespace

CGUseList::CGUseList(Operation *op, CallGraph &cg,
                     SymbolTableCollection &symbolTable)
    : symbolTable(symbolTable) {
  // A set of callgraph nodes that are always known to be live during inlining.
  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;

  // Walk each of the symbol tables looking for discardable callgraph nodes.
  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
    for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
      // If this is a callgraph operation, check to see if it is discardable.
      if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
        if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
          SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
          if (symbol && (allUsesVisible || symbol.isPrivate()) &&
              symbol.canDiscardOnUseEmpty()) {
            discardableSymNodeUses.try_emplace(node, 0);
          }
          continue;
        }
      }
      // Otherwise, check for any referenced nodes. These will be always-live.
      walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
                                [](CallGraphNode *, Operation *) {});
    }
  };
  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
                                walkFn);

  // Drop the use information for any discardable nodes that are always live.
  for (auto &it : alwaysLiveNodes)
    discardableSymNodeUses.erase(it.second);

  // Compute the uses for each of the callable nodes in the graph.
  for (CallGraphNode *node : cg)
    recomputeUses(node, cg);
}

void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
                             CallGraph &cg) {
  auto &userRefs = nodeUses[userNode].innerUses;
  auto walkFn = [&](CallGraphNode *node, Operation *user) {
    auto parentIt = userRefs.find(node);
    if (parentIt == userRefs.end())
      return;
    --parentIt->second;
    --discardableSymNodeUses[node];
  };
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::eraseNode(CallGraphNode *node) {
  // Drop all child nodes.
  for (auto &edge : *node)
    if (edge.isChild())
      eraseNode(edge.getTarget());

  // Drop the uses held by this node and erase it.
  auto useIt = nodeUses.find(node);
  assert(useIt != nodeUses.end() && "expected node to be valid");
  decrementDiscardableUses(useIt->getSecond());
  nodeUses.erase(useIt);
  discardableSymNodeUses.erase(node);
}

bool CGUseList::isDead(CallGraphNode *node) const {
  // If the parent operation isn't a symbol, simply check normal SSA deadness.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return isMemoryEffectFree(nodeOp) && nodeOp->use_empty();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
}

bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
  // If this isn't a symbol node, check for side-effects and SSA use count.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return isMemoryEffectFree(nodeOp) && nodeOp->hasOneUse();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
}

void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
  Operation *parentOp = node->getCallableRegion()->getParentOp();
  CGUser &uses = nodeUses[node];
  decrementDiscardableUses(uses);

  // Collect the new discardable uses within this node.
  uses = CGUser();
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
    auto discardSymIt = discardableSymNodeUses.find(refNode);
    if (discardSymIt == discardableSymNodeUses.end())
      return;

    if (user != parentOp)
      ++uses.innerUses[refNode];
    else if (!uses.topLevelUses.insert(refNode).second)
      return;
    ++discardSymIt->second;
  };
  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
  for (auto &useIt : lhsUses.innerUses) {
    rhsUses.innerUses[useIt.first] += useIt.second;
    discardableSymNodeUses[useIt.first] += useIt.second;
  }
}

void CGUseList::decrementDiscardableUses(CGUser &uses) {
  for (CallGraphNode *node : uses.topLevelUses)
    --discardableSymNodeUses[node];
  for (auto &it : uses.innerUses)
    discardableSymNodeUses[it.first] -= it.second;
}

//===----------------------------------------------------------------------===//
// CallGraph traversal
//===----------------------------------------------------------------------===//

namespace {
/// This class represents a specific callgraph SCC.
class CallGraphSCC {
public:
  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
      : parentIterator(parentIterator) {}
  /// Return a range over the nodes within this SCC.
  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }

  /// Reset the nodes of this SCC with those provided.
  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }

  /// Remove the given node from this SCC.
  void remove(CallGraphNode *node) {
    auto it = llvm::find(nodes, node);
    if (it != nodes.end()) {
      nodes.erase(it);
      parentIterator.ReplaceNode(node, nullptr);
    }
  }

private:
  std::vector<CallGraphNode *> nodes;
  llvm::scc_iterator<const CallGraph *> &parentIterator;
};
} // namespace

/// Run a given transformation over the SCCs of the callgraph in a bottom-up
/// traversal.
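///
/// Note: llvm::scc_begin enumerates SCCs in post-order, i.e. an SCC is
/// yielded only after all of the SCCs it has edges into (its callees), which
/// is what makes this a bottom-up traversal of the callgraph.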
static LogicalResult runTransformOnCGSCCs(
    const CallGraph &cg,
    function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
  CallGraphSCC currentSCC(cgi);
  while (!cgi.isAtEnd()) {
    // Copy the current SCC and increment so that the transformer can modify
    // the SCC without invalidating our iterator.
    currentSCC.reset(*cgi);
    ++cgi;
    if (failed(sccTransformer(currentSCC)))
      return failure();
  }
  return success();
}

/// Collect all of the call operations within the given range of blocks. If
/// `traverseNestedCGNodes` is true, this will also collect call operations
/// inside of nested callgraph nodes.
static void collectCallOps(iterator_range<Region::iterator> blocks,
                           CallGraphNode *sourceNode, CallGraph &cg,
                           SymbolTableCollection &symbolTable,
                           SmallVectorImpl<ResolvedCall> &calls,
                           bool traverseNestedCGNodes) {
  SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
  auto addToWorklist = [&](CallGraphNode *node,
                           iterator_range<Region::iterator> blocks) {
    for (Block &block : blocks)
      worklist.emplace_back(&block, node);
  };

  addToWorklist(sourceNode, blocks);
  while (!worklist.empty()) {
    Block *block;
    std::tie(block, sourceNode) = worklist.pop_back_val();

    for (Operation &op : *block) {
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        // TODO: Support inlining nested call references.
        CallInterfaceCallable callable = call.getCallableForCallee();
        if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) {
          if (!isa<FlatSymbolRefAttr>(symRef))
            continue;
        }

        CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
        if (!targetNode->isExternal())
          calls.emplace_back(call, sourceNode, targetNode);
        continue;
      }

      // If this is not a call, traverse the nested regions. If
      // `traverseNestedCGNodes` is false, then don't traverse nested call graph
      // regions.
      for (auto &nestedRegion : op.getRegions()) {
        CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
        if (traverseNestedCGNodes || !nestedNode)
          addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// InlinerInterfaceImpl
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
static std::string getNodeName(CallOpInterface op) {
  if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
    return debugString(op);
  return "_unnamed_callee_";
}
#endif

/// Return true if the specified `inlineHistoryID` indicates an inline history
/// that already includes `node`.
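///
/// Illustrative example: with entries
///   inlineHistory[0] = (B, std::nullopt)  // B inlined into an original call
///   inlineHistory[1] = (C, 0)             // C inlined via a call exposed by 0
/// a call carrying history ID 1 "includes" both C and B, so inlining either
/// node again through that call would be rejected.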
static bool inlineHistoryIncludes(
    CallGraphNode *node, std::optional<size_t> inlineHistoryID,
    MutableArrayRef<std::pair<CallGraphNode *, std::optional<size_t>>>
        inlineHistory) {
  while (inlineHistoryID.has_value()) {
    assert(*inlineHistoryID < inlineHistory.size() &&
           "Invalid inline history ID");
    if (inlineHistory[*inlineHistoryID].first == node)
      return true;
    inlineHistoryID = inlineHistory[*inlineHistoryID].second;
  }
  return false;
}

namespace {
/// This class provides a specialization of the main inlining interface.
struct InlinerInterfaceImpl : public InlinerInterface {
  InlinerInterfaceImpl(MLIRContext *context, CallGraph &cg,
                       SymbolTableCollection &symbolTable)
      : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}

  /// Process a set of blocks that have been inlined. This callback is invoked
  /// *before* inlined terminator operations have been processed.
  void
  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
    // Find the closest callgraph node from the first block.
    CallGraphNode *node;
    Region *region = inlinedBlocks.begin()->getParent();
    while (!(node = cg.lookupNode(region))) {
      region = region->getParentRegion();
      assert(region && "expected valid parent node");
    }

    collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
                   /*traverseNestedCGNodes=*/true);
  }

  /// Mark the given callgraph node for deletion.
  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }

  /// This method properly disposes of callables that became dead during
  /// inlining. This should not be called while iterating over the SCCs.
  void eraseDeadCallables() {
    for (CallGraphNode *node : deadNodes)
      node->getCallableRegion()->getParentOp()->erase();
  }

  /// The set of callables known to be dead.
  SmallPtrSet<CallGraphNode *, 8> deadNodes;

  /// The current set of call operations to consider for inlining.
  SmallVector<ResolvedCall, 8> calls;

  /// The callgraph being operated on.
  CallGraph &cg;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // namespace

namespace mlir {

class Inliner::Impl {
public:
  Impl(Inliner &inliner) : inliner(inliner) {}

  /// Attempt to inline calls within the given SCC, and run simplifications,
  /// until a fixed point is reached. This allows for the inlining of newly
  /// devirtualized calls. Returns failure if there was a fatal error during
  /// inlining.
  LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface,
                          CGUseList &useList, CallGraphSCC &currentSCC,
                          MLIRContext *context);

private:
  /// Optimize the nodes within the given SCC with one of the held optimization
  /// pass pipelines. Returns failure if an error occurred during the
  /// optimization of the SCC, success otherwise.
  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
                            CallGraphSCC &currentSCC, MLIRContext *context);

  /// Optimize the nodes within the given SCC in parallel. Returns failure if
  /// an error occurred during the optimization of the SCC, success otherwise.
  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                                 MLIRContext *context);

  /// Optimize the given callable node with one of the pass managers provided
  /// with `pipelines`, or the generic pre-inline pipeline. Returns failure if
  /// an error occurred during the optimization of the callable, success
  /// otherwise.
  LogicalResult optimizeCallable(CallGraphNode *node,
                                 llvm::StringMap<OpPassManager> &pipelines);

  /// Attempt to inline calls within the given SCC. This function returns
  /// success if any calls were inlined, failure otherwise.
  LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
                                 CGUseList &useList, CallGraphSCC &currentSCC);

  /// Returns true if the given call should be inlined.
  bool shouldInline(ResolvedCall &resolvedCall);

private:
  Inliner &inliner;
  llvm::SmallVector<llvm::StringMap<OpPassManager>> pipelines;
};

LogicalResult Inliner::Impl::inlineSCC(InlinerInterfaceImpl &inlinerIface,
                                       CGUseList &useList,
                                       CallGraphSCC &currentSCC,
                                       MLIRContext *context) {
  // Continuously simplify and inline until we either reach a fixed point, or
  // hit the maximum iteration count. Simplifying early helps to refine the cost
  // model, and in future iterations may devirtualize new calls.
  unsigned iterationCount = 0;
  do {
    if (failed(optimizeSCC(inlinerIface.cg, useList, currentSCC, context)))
      return failure();
    if (failed(inlineCallsInSCC(inlinerIface, useList, currentSCC)))
      break;
  } while (++iterationCount < inliner.config.getMaxInliningIterations());
  return success();
}

LogicalResult Inliner::Impl::optimizeSCC(CallGraph &cg, CGUseList &useList,
                                         CallGraphSCC &currentSCC,
                                         MLIRContext *context) {
  // Collect the sets of nodes to simplify.
  SmallVector<CallGraphNode *, 4> nodesToVisit;
  for (auto *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't simplify nodes with children. Nodes with children require special
    // handling as we may remove the node during simplification. In the future,
    // we should be able to handle this case with proper node deletion tracking.
    if (node->hasChildren())
      continue;

    // We also won't apply simplifications to nodes that can't have passes
    // scheduled on them.
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
      continue;
    nodesToVisit.push_back(node);
  }
  if (nodesToVisit.empty())
    return success();

  // Optimize each of the nodes within the SCC in parallel.
  if (failed(optimizeSCCAsync(nodesToVisit, context)))
    return failure();

  // Recompute the uses held by each of the nodes.
  for (CallGraphNode *node : nodesToVisit)
    useList.recomputeUses(node, cg);
  return success();
}

LogicalResult
Inliner::Impl::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                                MLIRContext *ctx) {
  // We must maintain a fixed pool of pass managers which is at least as large
  // as the maximum parallelism of the failableParallelForEach below.
  // Note: The number of pass managers here needs to remain constant
  // to prevent issues with pass instrumentations that rely on having the same
  // pass manager for the main thread.
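  // For example, with a context configured for 4 threads this keeps 4 copies
  // of the op pipelines alive; each worker below claims a free copy by
  // atomically flipping its `activePMs` slot from false to true, and releases
  // it once its node has been optimized.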
  size_t numThreads = ctx->getNumThreads();
  const auto &opPipelines = inliner.config.getOpPipelines();
  if (pipelines.size() < numThreads) {
    pipelines.reserve(numThreads);
    pipelines.resize(numThreads, opPipelines);
  }

  // Ensure an analysis manager has been constructed for each of the nodes.
  // This prevents thread races when running the nested pipelines.
  for (CallGraphNode *node : nodesToVisit)
    inliner.am.nest(node->getCallableRegion()->getParentOp());

  // Atomic flags tracking which pass managers from the pool are in use.
  std::vector<std::atomic<bool>> activePMs(pipelines.size());
  std::fill(activePMs.begin(), activePMs.end(), false);
  return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
    // Find a pass manager for this operation.
    auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
      bool expectedInactive = false;
      return isActive.compare_exchange_strong(expectedInactive, true);
    });
    assert(it != activePMs.end() &&
           "could not find inactive pass manager for thread");
    unsigned pmIndex = it - activePMs.begin();

    // Optimize this callable node.
    LogicalResult result = optimizeCallable(node, pipelines[pmIndex]);

    // Reset the active bit for this pass manager.
    activePMs[pmIndex].store(false);
    return result;
  });
}

LogicalResult
Inliner::Impl::optimizeCallable(CallGraphNode *node,
                                llvm::StringMap<OpPassManager> &pipelines) {
  Operation *callable = node->getCallableRegion()->getParentOp();
  StringRef opName = callable->getName().getStringRef();
  auto pipelineIt = pipelines.find(opName);
  const auto &defaultPipeline = inliner.config.getDefaultPipeline();
  if (pipelineIt == pipelines.end()) {
    // If a pipeline doesn't exist for this op, use the default pipeline if
    // possible.
    if (!defaultPipeline)
      return success();

    OpPassManager defaultPM(opName);
    defaultPipeline(defaultPM);
    pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
  }
  return inliner.runPipelineHelper(inliner.pass, pipelineIt->second, callable);
}

/// Attempt to inline calls within the given SCC. This function returns
/// success if any calls were inlined, failure otherwise.
LogicalResult
Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
                                CGUseList &useList, CallGraphSCC &currentSCC) {
  CallGraph &cg = inlinerIface.cg;
  auto &calls = inlinerIface.calls;

  // A set of dead nodes to remove after inlining.
  llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled separately,
  // likely within a different SCC.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't collect calls if the node is already dead.
    if (useList.isDead(node)) {
      deadNodes.insert(node);
    } else {
      collectCallOps(*node->getCallableRegion(), node, cg,
                     inlinerIface.symbolTable, calls,
                     /*traverseNestedCGNodes=*/false);
    }
  }

  // When inlining a callee produces new call sites, we want to keep track of
  // the fact that they were inlined from the callee. This allows us to avoid
  // infinite inlining.
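  // Concretely, `callHistory[i]` is the index into `inlineHistory` of the
  // inlining step that produced call `i`, or std::nullopt for the calls
  // collected above. For instance, if inlining call 2 exposes new calls 5 and
  // 6, both get the history ID of the entry recorded for call 2's target.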
  using InlineHistoryT = std::optional<size_t>;
  SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
  std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});

  LLVM_DEBUG({
    llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
    for (unsigned i = 0, e = calls.size(); i < e; ++i)
      llvm::dbgs() << "  " << i << ". " << calls[i].call << ",\n";
    llvm::dbgs() << "}\n";
  });

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i < calls.size(); ++i) {
    if (deadNodes.contains(calls[i].sourceNode))
      continue;
    ResolvedCall it = calls[i];

    InlineHistoryT inlineHistoryID = callHistory[i];
    bool inHistory =
        inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
    bool doInline = !inHistory && shouldInline(it);
    CallOpInterface call = it.call;
    LLVM_DEBUG({
      if (doInline)
        llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
      else
        llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
    });
    if (!doInline)
      continue;

    unsigned prevSize = calls.size();
    Region *targetRegion = it.targetNode->getCallableRegion();

    // If this is the last call to the target node and the node is discardable,
    // then inline it in-place and delete the node if successful.
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult =
        inlineCall(inlinerIface, call,
                   cast<CallableOpInterface>(targetRegion->getParentOp()),
                   targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    if (failed(inlineResult)) {
      LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
      continue;
    }
    inlinedAnyCalls = true;

    // Create an inline history entry for this inlined call, so that we
    // remember that new call sites came about due to inlining the callee.
    InlineHistoryT newInlineHistoryID{inlineHistory.size()};
    inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));

    auto historyToString = [](InlineHistoryT h) {
      return h.has_value() ? std::to_string(*h) : "root";
    };
    (void)historyToString;
    LLVM_DEBUG(llvm::dbgs()
               << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
               << getNodeName(call) << ", " << historyToString(inlineHistoryID)
               << "]\n");

    for (unsigned k = prevSize; k != calls.size(); ++k) {
      callHistory.push_back(newInlineHistoryID);
      LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[k].call
                              << "}\n   with historyID = " << newInlineHistoryID
                              << ", added due to inlining of\n  call {" << call
                              << "}\n with historyID = "
                              << historyToString(inlineHistoryID) << "\n");
    }

    // The inlining was successful; merge the new uses into the source node.
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);

    // Then erase the call.
    call.erase();

    // If we inlined in place, mark the node for deletion.
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.insert(it.targetNode);
    }
  }

  for (CallGraphNode *node : deadNodes) {
    currentSCC.remove(node);
    inlinerIface.markForDeletion(node);
  }
  calls.clear();
  return success(inlinedAnyCalls);
}

/// Returns true if the given call should be inlined.
bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
  // Don't allow inlining terminator calls. We currently don't support this
  // case.
  if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
    return false;

  // Don't allow inlining if the target is self-recursive, or if it calls back
  // into the source node (i.e. the call graph contains a cycle like A->B->A).
  if (llvm::count_if(*resolvedCall.targetNode,
                     [&](CallGraphNode::Edge const &edge) -> bool {
                       return edge.getTarget() == resolvedCall.targetNode ||
                              edge.getTarget() == resolvedCall.sourceNode;
                     }) > 0)
    return false;

  // Don't allow inlining if the target is an ancestor of the call. This
  // prevents inlining recursively.
  Region *callableRegion = resolvedCall.targetNode->getCallableRegion();
  if (callableRegion->isAncestor(resolvedCall.call->getParentRegion()))
    return false;

  // Don't allow inlining if the callee has multiple blocks (unstructured
  // control flow) but we cannot be sure that the caller region supports that.
  bool calleeHasMultipleBlocks =
      llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
  // If both parent ops have the same type, it is safe to inline. Otherwise,
  // decide based on whether the op has the SingleBlock trait or not.
  // Note: This check currently does not account for SizedRegion/MaxSizedRegion.
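  // As a hypothetical example, a callee with unstructured control flow:
  //
  //   func.func @callee() {
  //     cf.br ^bb1
  //   ^bb1:
  //     return
  //   }
  //
  // can be inlined into another func.func body, but not into the single-block
  // region of an op such as scf.if.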
  auto callerRegionSupportsMultipleBlocks = [&]() {
    return callableRegion->getParentOp()->getName() ==
               resolvedCall.call->getParentOp()->getName() ||
           !resolvedCall.call->getParentOp()
                ->mightHaveTrait<OpTrait::SingleBlock>();
  };
  if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
    return false;

  if (!inliner.isProfitableToInline(resolvedCall))
    return false;

  // Otherwise, inline.
  return true;
}

LogicalResult Inliner::doInlining() {
  Impl impl(*this);
  auto *context = op->getContext();
  SymbolTableCollection symbolTable;
  // FIXME: Some cleanup of the arguments of the Impl's methods would be
  // possible if inlinerIface and useList became state of the Impl.
  InlinerInterfaceImpl inlinerIface(context, cg, symbolTable);
  CGUseList useList(op, cg, symbolTable);
  // Run the inline transform in post-order over the SCCs in the callgraph.
  LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
    return impl.inlineSCC(inlinerIface, useList, scc, context);
  });
  if (failed(result))
    return result;

  // After inlining, make sure to erase any callables proven to be dead.
  inlinerIface.eraseDeadCallables();
  return success();
}
} // namespace mlir