1 //===- Inliner.cpp ---- SCC-based inliner ---------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements Inliner that uses a basic inlining 10 // algorithm that operates bottom up over the Strongly Connect Components(SCCs) 11 // of the CallGraph. This enables a more incremental propagation of inlining 12 // decisions from the leafs to the roots of the callgraph. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "mlir/Transforms/Inliner.h" 17 #include "mlir/IR/Threading.h" 18 #include "mlir/Interfaces/CallInterfaces.h" 19 #include "mlir/Interfaces/SideEffectInterfaces.h" 20 #include "mlir/Pass/Pass.h" 21 #include "mlir/Support/DebugStringHelper.h" 22 #include "mlir/Transforms/InliningUtils.h" 23 #include "llvm/ADT/SCCIterator.h" 24 #include "llvm/ADT/STLExtras.h" 25 #include "llvm/ADT/SmallPtrSet.h" 26 #include "llvm/Support/Debug.h" 27 28 #define DEBUG_TYPE "inlining" 29 30 using namespace mlir; 31 32 using ResolvedCall = Inliner::ResolvedCall; 33 34 //===----------------------------------------------------------------------===// 35 // Symbol Use Tracking 36 //===----------------------------------------------------------------------===// 37 38 /// Walk all of the used symbol callgraph nodes referenced with the given op. 39 static void walkReferencedSymbolNodes( 40 Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable, 41 DenseMap<Attribute, CallGraphNode *> &resolvedRefs, 42 function_ref<void(CallGraphNode *, Operation *)> callback) { 43 auto symbolUses = SymbolTable::getSymbolUses(op); 44 assert(symbolUses && "expected uses to be valid"); 45 46 Operation *symbolTableOp = op->getParentOp(); 47 for (const SymbolTable::SymbolUse &use : *symbolUses) { 48 auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr}); 49 CallGraphNode *&node = refIt.first->second; 50 51 // If this is the first instance of this reference, try to resolve a 52 // callgraph node for it. 53 if (refIt.second) { 54 auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp, 55 use.getSymbolRef()); 56 auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp); 57 if (!callableOp) 58 continue; 59 node = cg.lookupNode(callableOp.getCallableRegion()); 60 } 61 if (node) 62 callback(node, use.getUser()); 63 } 64 } 65 66 //===----------------------------------------------------------------------===// 67 // CGUseList 68 69 namespace { 70 /// This struct tracks the uses of callgraph nodes that can be dropped when 71 /// use_empty. It directly tracks and manages a use-list for all of the 72 /// call-graph nodes. This is necessary because many callgraph nodes are 73 /// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use` 74 /// class. 75 struct CGUseList { 76 /// This struct tracks the uses of callgraph nodes within a specific 77 /// operation. 78 struct CGUser { 79 /// Any nodes referenced in the top-level attribute list of this user. We 80 /// use a set here because the number of references does not matter. 81 DenseSet<CallGraphNode *> topLevelUses; 82 83 /// Uses of nodes referenced by nested operations. 84 DenseMap<CallGraphNode *, int> innerUses; 85 }; 86 87 CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable); 88 89 /// Drop uses of nodes referred to by the given call operation that resides 90 /// within 'userNode'. 91 void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg); 92 93 /// Remove the given node from the use list. 94 void eraseNode(CallGraphNode *node); 95 96 /// Returns true if the given callgraph node has no uses and can be pruned. 97 bool isDead(CallGraphNode *node) const; 98 99 /// Returns true if the given callgraph node has a single use and can be 100 /// discarded. 101 bool hasOneUseAndDiscardable(CallGraphNode *node) const; 102 103 /// Recompute the uses held by the given callgraph node. 104 void recomputeUses(CallGraphNode *node, CallGraph &cg); 105 106 /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy 107 /// of 'lhs' into 'rhs'. 108 void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs); 109 110 private: 111 /// Decrement the uses of discardable nodes referenced by the given user. 112 void decrementDiscardableUses(CGUser &uses); 113 114 /// A mapping between a discardable callgraph node (that is a symbol) and the 115 /// number of uses for this node. 116 DenseMap<CallGraphNode *, int> discardableSymNodeUses; 117 118 /// A mapping between a callgraph node and the symbol callgraph nodes that it 119 /// uses. 120 DenseMap<CallGraphNode *, CGUser> nodeUses; 121 122 /// A symbol table to use when resolving call lookups. 123 SymbolTableCollection &symbolTable; 124 }; 125 } // namespace 126 127 CGUseList::CGUseList(Operation *op, CallGraph &cg, 128 SymbolTableCollection &symbolTable) 129 : symbolTable(symbolTable) { 130 /// A set of callgraph nodes that are always known to be live during inlining. 131 DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes; 132 133 // Walk each of the symbol tables looking for discardable callgraph nodes. 134 auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) { 135 for (Operation &op : symbolTableOp->getRegion(0).getOps()) { 136 // If this is a callgraph operation, check to see if it is discardable. 137 if (auto callable = dyn_cast<CallableOpInterface>(&op)) { 138 if (auto *node = cg.lookupNode(callable.getCallableRegion())) { 139 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op); 140 if (symbol && (allUsesVisible || symbol.isPrivate()) && 141 symbol.canDiscardOnUseEmpty()) { 142 discardableSymNodeUses.try_emplace(node, 0); 143 } 144 continue; 145 } 146 } 147 // Otherwise, check for any referenced nodes. These will be always-live. 148 walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes, 149 [](CallGraphNode *, Operation *) {}); 150 } 151 }; 152 SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), 153 walkFn); 154 155 // Drop the use information for any discardable nodes that are always live. 156 for (auto &it : alwaysLiveNodes) 157 discardableSymNodeUses.erase(it.second); 158 159 // Compute the uses for each of the callable nodes in the graph. 160 for (CallGraphNode *node : cg) 161 recomputeUses(node, cg); 162 } 163 164 void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp, 165 CallGraph &cg) { 166 auto &userRefs = nodeUses[userNode].innerUses; 167 auto walkFn = [&](CallGraphNode *node, Operation *user) { 168 auto parentIt = userRefs.find(node); 169 if (parentIt == userRefs.end()) 170 return; 171 --parentIt->second; 172 --discardableSymNodeUses[node]; 173 }; 174 DenseMap<Attribute, CallGraphNode *> resolvedRefs; 175 walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn); 176 } 177 178 void CGUseList::eraseNode(CallGraphNode *node) { 179 // Drop all child nodes. 180 for (auto &edge : *node) 181 if (edge.isChild()) 182 eraseNode(edge.getTarget()); 183 184 // Drop the uses held by this node and erase it. 185 auto useIt = nodeUses.find(node); 186 assert(useIt != nodeUses.end() && "expected node to be valid"); 187 decrementDiscardableUses(useIt->getSecond()); 188 nodeUses.erase(useIt); 189 discardableSymNodeUses.erase(node); 190 } 191 192 bool CGUseList::isDead(CallGraphNode *node) const { 193 // If the parent operation isn't a symbol, simply check normal SSA deadness. 194 Operation *nodeOp = node->getCallableRegion()->getParentOp(); 195 if (!isa<SymbolOpInterface>(nodeOp)) 196 return isMemoryEffectFree(nodeOp) && nodeOp->use_empty(); 197 198 // Otherwise, check the number of symbol uses. 199 auto symbolIt = discardableSymNodeUses.find(node); 200 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0; 201 } 202 203 bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const { 204 // If this isn't a symbol node, check for side-effects and SSA use count. 205 Operation *nodeOp = node->getCallableRegion()->getParentOp(); 206 if (!isa<SymbolOpInterface>(nodeOp)) 207 return isMemoryEffectFree(nodeOp) && nodeOp->hasOneUse(); 208 209 // Otherwise, check the number of symbol uses. 210 auto symbolIt = discardableSymNodeUses.find(node); 211 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1; 212 } 213 214 void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) { 215 Operation *parentOp = node->getCallableRegion()->getParentOp(); 216 CGUser &uses = nodeUses[node]; 217 decrementDiscardableUses(uses); 218 219 // Collect the new discardable uses within this node. 220 uses = CGUser(); 221 DenseMap<Attribute, CallGraphNode *> resolvedRefs; 222 auto walkFn = [&](CallGraphNode *refNode, Operation *user) { 223 auto discardSymIt = discardableSymNodeUses.find(refNode); 224 if (discardSymIt == discardableSymNodeUses.end()) 225 return; 226 227 if (user != parentOp) 228 ++uses.innerUses[refNode]; 229 else if (!uses.topLevelUses.insert(refNode).second) 230 return; 231 ++discardSymIt->second; 232 }; 233 walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn); 234 } 235 236 void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) { 237 auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs]; 238 for (auto &useIt : lhsUses.innerUses) { 239 rhsUses.innerUses[useIt.first] += useIt.second; 240 discardableSymNodeUses[useIt.first] += useIt.second; 241 } 242 } 243 244 void CGUseList::decrementDiscardableUses(CGUser &uses) { 245 for (CallGraphNode *node : uses.topLevelUses) 246 --discardableSymNodeUses[node]; 247 for (auto &it : uses.innerUses) 248 discardableSymNodeUses[it.first] -= it.second; 249 } 250 251 //===----------------------------------------------------------------------===// 252 // CallGraph traversal 253 //===----------------------------------------------------------------------===// 254 255 namespace { 256 /// This class represents a specific callgraph SCC. 257 class CallGraphSCC { 258 public: 259 CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator) 260 : parentIterator(parentIterator) {} 261 /// Return a range over the nodes within this SCC. 262 std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); } 263 std::vector<CallGraphNode *>::iterator end() { return nodes.end(); } 264 265 /// Reset the nodes of this SCC with those provided. 266 void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; } 267 268 /// Remove the given node from this SCC. 269 void remove(CallGraphNode *node) { 270 auto it = llvm::find(nodes, node); 271 if (it != nodes.end()) { 272 nodes.erase(it); 273 parentIterator.ReplaceNode(node, nullptr); 274 } 275 } 276 277 private: 278 std::vector<CallGraphNode *> nodes; 279 llvm::scc_iterator<const CallGraph *> &parentIterator; 280 }; 281 } // namespace 282 283 /// Run a given transformation over the SCCs of the callgraph in a bottom up 284 /// traversal. 285 static LogicalResult runTransformOnCGSCCs( 286 const CallGraph &cg, 287 function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) { 288 llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg); 289 CallGraphSCC currentSCC(cgi); 290 while (!cgi.isAtEnd()) { 291 // Copy the current SCC and increment so that the transformer can modify the 292 // SCC without invalidating our iterator. 293 currentSCC.reset(*cgi); 294 ++cgi; 295 if (failed(sccTransformer(currentSCC))) 296 return failure(); 297 } 298 return success(); 299 } 300 301 /// Collect all of the callable operations within the given range of blocks. If 302 /// `traverseNestedCGNodes` is true, this will also collect call operations 303 /// inside of nested callgraph nodes. 304 static void collectCallOps(iterator_range<Region::iterator> blocks, 305 CallGraphNode *sourceNode, CallGraph &cg, 306 SymbolTableCollection &symbolTable, 307 SmallVectorImpl<ResolvedCall> &calls, 308 bool traverseNestedCGNodes) { 309 SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist; 310 auto addToWorklist = [&](CallGraphNode *node, 311 iterator_range<Region::iterator> blocks) { 312 for (Block &block : blocks) 313 worklist.emplace_back(&block, node); 314 }; 315 316 addToWorklist(sourceNode, blocks); 317 while (!worklist.empty()) { 318 Block *block; 319 std::tie(block, sourceNode) = worklist.pop_back_val(); 320 321 for (Operation &op : *block) { 322 if (auto call = dyn_cast<CallOpInterface>(op)) { 323 // TODO: Support inlining nested call references. 324 CallInterfaceCallable callable = call.getCallableForCallee(); 325 if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) { 326 if (!isa<FlatSymbolRefAttr>(symRef)) 327 continue; 328 } 329 330 CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable); 331 if (!targetNode->isExternal()) 332 calls.emplace_back(call, sourceNode, targetNode); 333 continue; 334 } 335 336 // If this is not a call, traverse the nested regions. If 337 // `traverseNestedCGNodes` is false, then don't traverse nested call graph 338 // regions. 339 for (auto &nestedRegion : op.getRegions()) { 340 CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion); 341 if (traverseNestedCGNodes || !nestedNode) 342 addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion); 343 } 344 } 345 } 346 } 347 348 //===----------------------------------------------------------------------===// 349 // InlinerInterfaceImpl 350 //===----------------------------------------------------------------------===// 351 352 #ifndef NDEBUG 353 static std::string getNodeName(CallOpInterface op) { 354 if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee())) 355 return debugString(op); 356 return "_unnamed_callee_"; 357 } 358 #endif 359 360 /// Return true if the specified `inlineHistoryID` indicates an inline history 361 /// that already includes `node`. 362 static bool inlineHistoryIncludes( 363 CallGraphNode *node, std::optional<size_t> inlineHistoryID, 364 MutableArrayRef<std::pair<CallGraphNode *, std::optional<size_t>>> 365 inlineHistory) { 366 while (inlineHistoryID.has_value()) { 367 assert(*inlineHistoryID < inlineHistory.size() && 368 "Invalid inline history ID"); 369 if (inlineHistory[*inlineHistoryID].first == node) 370 return true; 371 inlineHistoryID = inlineHistory[*inlineHistoryID].second; 372 } 373 return false; 374 } 375 376 namespace { 377 /// This class provides a specialization of the main inlining interface. 378 struct InlinerInterfaceImpl : public InlinerInterface { 379 InlinerInterfaceImpl(MLIRContext *context, CallGraph &cg, 380 SymbolTableCollection &symbolTable) 381 : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {} 382 383 /// Process a set of blocks that have been inlined. This callback is invoked 384 /// *before* inlined terminator operations have been processed. 385 void 386 processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final { 387 // Find the closest callgraph node from the first block. 388 CallGraphNode *node; 389 Region *region = inlinedBlocks.begin()->getParent(); 390 while (!(node = cg.lookupNode(region))) { 391 region = region->getParentRegion(); 392 assert(region && "expected valid parent node"); 393 } 394 395 collectCallOps(inlinedBlocks, node, cg, symbolTable, calls, 396 /*traverseNestedCGNodes=*/true); 397 } 398 399 /// Mark the given callgraph node for deletion. 400 void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); } 401 402 /// This method properly disposes of callables that became dead during 403 /// inlining. This should not be called while iterating over the SCCs. 404 void eraseDeadCallables() { 405 for (CallGraphNode *node : deadNodes) 406 node->getCallableRegion()->getParentOp()->erase(); 407 } 408 409 /// The set of callables known to be dead. 410 SmallPtrSet<CallGraphNode *, 8> deadNodes; 411 412 /// The current set of call instructions to consider for inlining. 413 SmallVector<ResolvedCall, 8> calls; 414 415 /// The callgraph being operated on. 416 CallGraph &cg; 417 418 /// A symbol table to use when resolving call lookups. 419 SymbolTableCollection &symbolTable; 420 }; 421 } // namespace 422 423 namespace mlir { 424 425 class Inliner::Impl { 426 public: 427 Impl(Inliner &inliner) : inliner(inliner) {} 428 429 /// Attempt to inline calls within the given scc, and run simplifications, 430 /// until a fixed point is reached. This allows for the inlining of newly 431 /// devirtualized calls. Returns failure if there was a fatal error during 432 /// inlining. 433 LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface, 434 CGUseList &useList, CallGraphSCC ¤tSCC, 435 MLIRContext *context); 436 437 private: 438 /// Optimize the nodes within the given SCC with one of the held optimization 439 /// pass pipelines. Returns failure if an error occurred during the 440 /// optimization of the SCC, success otherwise. 441 LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList, 442 CallGraphSCC ¤tSCC, MLIRContext *context); 443 444 /// Optimize the nodes within the given SCC in parallel. Returns failure if an 445 /// error occurred during the optimization of the SCC, success otherwise. 446 LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit, 447 MLIRContext *context); 448 449 /// Optimize the given callable node with one of the pass managers provided 450 /// with `pipelines`, or the generic pre-inline pipeline. Returns failure if 451 /// an error occurred during the optimization of the callable, success 452 /// otherwise. 453 LogicalResult optimizeCallable(CallGraphNode *node, 454 llvm::StringMap<OpPassManager> &pipelines); 455 456 /// Attempt to inline calls within the given scc. This function returns 457 /// success if any calls were inlined, failure otherwise. 458 LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, 459 CGUseList &useList, CallGraphSCC ¤tSCC); 460 461 /// Returns true if the given call should be inlined. 462 bool shouldInline(ResolvedCall &resolvedCall); 463 464 private: 465 Inliner &inliner; 466 llvm::SmallVector<llvm::StringMap<OpPassManager>> pipelines; 467 }; 468 469 LogicalResult Inliner::Impl::inlineSCC(InlinerInterfaceImpl &inlinerIface, 470 CGUseList &useList, 471 CallGraphSCC ¤tSCC, 472 MLIRContext *context) { 473 // Continuously simplify and inline until we either reach a fixed point, or 474 // hit the maximum iteration count. Simplifying early helps to refine the cost 475 // model, and in future iterations may devirtualize new calls. 476 unsigned iterationCount = 0; 477 do { 478 if (failed(optimizeSCC(inlinerIface.cg, useList, currentSCC, context))) 479 return failure(); 480 if (failed(inlineCallsInSCC(inlinerIface, useList, currentSCC))) 481 break; 482 } while (++iterationCount < inliner.config.getMaxInliningIterations()); 483 return success(); 484 } 485 486 LogicalResult Inliner::Impl::optimizeSCC(CallGraph &cg, CGUseList &useList, 487 CallGraphSCC ¤tSCC, 488 MLIRContext *context) { 489 // Collect the sets of nodes to simplify. 490 SmallVector<CallGraphNode *, 4> nodesToVisit; 491 for (auto *node : currentSCC) { 492 if (node->isExternal()) 493 continue; 494 495 // Don't simplify nodes with children. Nodes with children require special 496 // handling as we may remove the node during simplification. In the future, 497 // we should be able to handle this case with proper node deletion tracking. 498 if (node->hasChildren()) 499 continue; 500 501 // We also won't apply simplifications to nodes that can't have passes 502 // scheduled on them. 503 auto *region = node->getCallableRegion(); 504 if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()) 505 continue; 506 nodesToVisit.push_back(node); 507 } 508 if (nodesToVisit.empty()) 509 return success(); 510 511 // Optimize each of the nodes within the SCC in parallel. 512 if (failed(optimizeSCCAsync(nodesToVisit, context))) 513 return failure(); 514 515 // Recompute the uses held by each of the nodes. 516 for (CallGraphNode *node : nodesToVisit) 517 useList.recomputeUses(node, cg); 518 return success(); 519 } 520 521 LogicalResult 522 Inliner::Impl::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit, 523 MLIRContext *ctx) { 524 // We must maintain a fixed pool of pass managers which is at least as large 525 // as the maximum parallelism of the failableParallelForEach below. 526 // Note: The number of pass managers here needs to remain constant 527 // to prevent issues with pass instrumentations that rely on having the same 528 // pass manager for the main thread. 529 size_t numThreads = ctx->getNumThreads(); 530 const auto &opPipelines = inliner.config.getOpPipelines(); 531 if (pipelines.size() < numThreads) { 532 pipelines.reserve(numThreads); 533 pipelines.resize(numThreads, opPipelines); 534 } 535 536 // Ensure an analysis manager has been constructed for each of the nodes. 537 // This prevents thread races when running the nested pipelines. 538 for (CallGraphNode *node : nodesToVisit) 539 inliner.am.nest(node->getCallableRegion()->getParentOp()); 540 541 // An atomic failure variable for the async executors. 542 std::vector<std::atomic<bool>> activePMs(pipelines.size()); 543 std::fill(activePMs.begin(), activePMs.end(), false); 544 return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) { 545 // Find a pass manager for this operation. 546 auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) { 547 bool expectedInactive = false; 548 return isActive.compare_exchange_strong(expectedInactive, true); 549 }); 550 assert(it != activePMs.end() && 551 "could not find inactive pass manager for thread"); 552 unsigned pmIndex = it - activePMs.begin(); 553 554 // Optimize this callable node. 555 LogicalResult result = optimizeCallable(node, pipelines[pmIndex]); 556 557 // Reset the active bit for this pass manager. 558 activePMs[pmIndex].store(false); 559 return result; 560 }); 561 } 562 563 LogicalResult 564 Inliner::Impl::optimizeCallable(CallGraphNode *node, 565 llvm::StringMap<OpPassManager> &pipelines) { 566 Operation *callable = node->getCallableRegion()->getParentOp(); 567 StringRef opName = callable->getName().getStringRef(); 568 auto pipelineIt = pipelines.find(opName); 569 const auto &defaultPipeline = inliner.config.getDefaultPipeline(); 570 if (pipelineIt == pipelines.end()) { 571 // If a pipeline didn't exist, use the generic pipeline if possible. 572 if (!defaultPipeline) 573 return success(); 574 575 OpPassManager defaultPM(opName); 576 defaultPipeline(defaultPM); 577 pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first; 578 } 579 return inliner.runPipelineHelper(inliner.pass, pipelineIt->second, callable); 580 } 581 582 /// Attempt to inline calls within the given scc. This function returns 583 /// success if any calls were inlined, failure otherwise. 584 LogicalResult 585 Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, 586 CGUseList &useList, CallGraphSCC ¤tSCC) { 587 CallGraph &cg = inlinerIface.cg; 588 auto &calls = inlinerIface.calls; 589 590 // A set of dead nodes to remove after inlining. 591 llvm::SmallSetVector<CallGraphNode *, 1> deadNodes; 592 593 // Collect all of the direct calls within the nodes of the current SCC. We 594 // don't traverse nested callgraph nodes, because they are handled separately 595 // likely within a different SCC. 596 for (CallGraphNode *node : currentSCC) { 597 if (node->isExternal()) 598 continue; 599 600 // Don't collect calls if the node is already dead. 601 if (useList.isDead(node)) { 602 deadNodes.insert(node); 603 } else { 604 collectCallOps(*node->getCallableRegion(), node, cg, 605 inlinerIface.symbolTable, calls, 606 /*traverseNestedCGNodes=*/false); 607 } 608 } 609 610 // When inlining a callee produces new call sites, we want to keep track of 611 // the fact that they were inlined from the callee. This allows us to avoid 612 // infinite inlining. 613 using InlineHistoryT = std::optional<size_t>; 614 SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory; 615 std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{}); 616 617 LLVM_DEBUG({ 618 llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n"; 619 for (unsigned i = 0, e = calls.size(); i < e; ++i) 620 llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n"; 621 llvm::dbgs() << "}\n"; 622 }); 623 624 // Try to inline each of the call operations. Don't cache the end iterator 625 // here as more calls may be added during inlining. 626 bool inlinedAnyCalls = false; 627 for (unsigned i = 0; i < calls.size(); ++i) { 628 if (deadNodes.contains(calls[i].sourceNode)) 629 continue; 630 ResolvedCall it = calls[i]; 631 632 InlineHistoryT inlineHistoryID = callHistory[i]; 633 bool inHistory = 634 inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory); 635 bool doInline = !inHistory && shouldInline(it); 636 CallOpInterface call = it.call; 637 LLVM_DEBUG({ 638 if (doInline) 639 llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n"; 640 else 641 llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n"; 642 }); 643 if (!doInline) 644 continue; 645 646 unsigned prevSize = calls.size(); 647 Region *targetRegion = it.targetNode->getCallableRegion(); 648 649 // If this is the last call to the target node and the node is discardable, 650 // then inline it in-place and delete the node if successful. 651 bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode); 652 653 LogicalResult inlineResult = 654 inlineCall(inlinerIface, call, 655 cast<CallableOpInterface>(targetRegion->getParentOp()), 656 targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace); 657 if (failed(inlineResult)) { 658 LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n"); 659 continue; 660 } 661 inlinedAnyCalls = true; 662 663 // Create a inline history entry for this inlined call, so that we remember 664 // that new callsites came about due to inlining Callee. 665 InlineHistoryT newInlineHistoryID{inlineHistory.size()}; 666 inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID)); 667 668 auto historyToString = [](InlineHistoryT h) { 669 return h.has_value() ? std::to_string(*h) : "root"; 670 }; 671 (void)historyToString; 672 LLVM_DEBUG(llvm::dbgs() 673 << "* new inlineHistory entry: " << newInlineHistoryID << ". [" 674 << getNodeName(call) << ", " << historyToString(inlineHistoryID) 675 << "]\n"); 676 677 for (unsigned k = prevSize; k != calls.size(); ++k) { 678 callHistory.push_back(newInlineHistoryID); 679 LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call 680 << "}\n with historyID = " << newInlineHistoryID 681 << ", added due to inlining of\n call {" << call 682 << "}\n with historyID = " 683 << historyToString(inlineHistoryID) << "\n"); 684 } 685 686 // If the inlining was successful, Merge the new uses into the source node. 687 useList.dropCallUses(it.sourceNode, call.getOperation(), cg); 688 useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode); 689 690 // then erase the call. 691 call.erase(); 692 693 // If we inlined in place, mark the node for deletion. 694 if (inlineInPlace) { 695 useList.eraseNode(it.targetNode); 696 deadNodes.insert(it.targetNode); 697 } 698 } 699 700 for (CallGraphNode *node : deadNodes) { 701 currentSCC.remove(node); 702 inlinerIface.markForDeletion(node); 703 } 704 calls.clear(); 705 return success(inlinedAnyCalls); 706 } 707 708 /// Returns true if the given call should be inlined. 709 bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) { 710 // Don't allow inlining terminator calls. We currently don't support this 711 // case. 712 if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>()) 713 return false; 714 715 // Don't allow inlining if the target is a self-recursive function. 716 // Don't allow inlining if the call graph is like A->B->A. 717 if (llvm::count_if(*resolvedCall.targetNode, 718 [&](CallGraphNode::Edge const &edge) -> bool { 719 return edge.getTarget() == resolvedCall.targetNode || 720 edge.getTarget() == resolvedCall.sourceNode; 721 }) > 0) 722 return false; 723 724 // Don't allow inlining if the target is an ancestor of the call. This 725 // prevents inlining recursively. 726 Region *callableRegion = resolvedCall.targetNode->getCallableRegion(); 727 if (callableRegion->isAncestor(resolvedCall.call->getParentRegion())) 728 return false; 729 730 // Don't allow inlining if the callee has multiple blocks (unstructured 731 // control flow) but we cannot be sure that the caller region supports that. 732 bool calleeHasMultipleBlocks = 733 llvm::hasNItemsOrMore(*callableRegion, /*N=*/2); 734 // If both parent ops have the same type, it is safe to inline. Otherwise, 735 // decide based on whether the op has the SingleBlock trait or not. 736 // Note: This check does currently not account for SizedRegion/MaxSizedRegion. 737 auto callerRegionSupportsMultipleBlocks = [&]() { 738 return callableRegion->getParentOp()->getName() == 739 resolvedCall.call->getParentOp()->getName() || 740 !resolvedCall.call->getParentOp() 741 ->mightHaveTrait<OpTrait::SingleBlock>(); 742 }; 743 if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks()) 744 return false; 745 746 if (!inliner.isProfitableToInline(resolvedCall)) 747 return false; 748 749 // Otherwise, inline. 750 return true; 751 } 752 753 LogicalResult Inliner::doInlining() { 754 Impl impl(*this); 755 auto *context = op->getContext(); 756 // Run the inline transform in post-order over the SCCs in the callgraph. 757 SymbolTableCollection symbolTable; 758 // FIXME: some clean-up can be done for the arguments 759 // of the Impl's methods, if the inlinerIface and useList 760 // become the states of the Impl. 761 InlinerInterfaceImpl inlinerIface(context, cg, symbolTable); 762 CGUseList useList(op, cg, symbolTable); 763 LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) { 764 return impl.inlineSCC(inlinerIface, useList, scc, context); 765 }); 766 if (failed(result)) 767 return result; 768 769 // After inlining, make sure to erase any callables proven to be dead. 770 inlinerIface.eraseDeadCallables(); 771 return success(); 772 } 773 } // namespace mlir 774