10ba00878SRiver Riddle //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===// 20ba00878SRiver Riddle // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 60ba00878SRiver Riddle // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 80ba00878SRiver Riddle // 90ba00878SRiver Riddle // This file implements miscellaneous inlining utilities. 100ba00878SRiver Riddle // 110ba00878SRiver Riddle //===----------------------------------------------------------------------===// 120ba00878SRiver Riddle 130ba00878SRiver Riddle #include "mlir/Transforms/InliningUtils.h" 140ba00878SRiver Riddle 155830f71aSRiver Riddle #include "mlir/IR/Builders.h" 164d67b278SJeff Niu #include "mlir/IR/IRMapping.h" 170ba00878SRiver Riddle #include "mlir/IR/Operation.h" 1836550692SRiver Riddle #include "mlir/Interfaces/CallInterfaces.h" 190ba00878SRiver Riddle #include "llvm/ADT/MapVector.h" 20553f794bSSean Silva #include "llvm/Support/Debug.h" 210ba00878SRiver Riddle #include "llvm/Support/raw_ostream.h" 22a1fe1f5fSKazu Hirata #include <optional> 230ba00878SRiver Riddle 240ba00878SRiver Riddle #define DEBUG_TYPE "inlining" 250ba00878SRiver Riddle 260ba00878SRiver Riddle using namespace mlir; 270ba00878SRiver Riddle 286f092e50SChristian Ulmann /// Remap all locations reachable from the inlined blocks with CallSiteLoc 296f092e50SChristian Ulmann /// locations with the provided caller location. 300ba00878SRiver Riddle static void 314562e389SRiver Riddle remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks, 320ba00878SRiver Riddle Location callerLoc) { 336f092e50SChristian Ulmann DenseMap<Location, LocationAttr> mappedLocations; 346f092e50SChristian Ulmann auto remapLoc = [&](Location loc) { 356f092e50SChristian Ulmann auto [it, inserted] = mappedLocations.try_emplace(loc); 366f092e50SChristian Ulmann // Only query the attribute uniquer once per callsite attribute. 376f092e50SChristian Ulmann if (inserted) { 386f092e50SChristian Ulmann auto newLoc = CallSiteLoc::get(loc, callerLoc); 396f092e50SChristian Ulmann it->getSecond() = newLoc; 400ba00878SRiver Riddle } 416f092e50SChristian Ulmann return it->second; 420ba00878SRiver Riddle }; 436f092e50SChristian Ulmann 446f092e50SChristian Ulmann AttrTypeReplacer attrReplacer; 456f092e50SChristian Ulmann attrReplacer.addReplacement( 466f092e50SChristian Ulmann [&](LocationAttr loc) -> std::pair<LocationAttr, WalkResult> { 476f092e50SChristian Ulmann return {remapLoc(loc), WalkResult::skip()}; 486f092e50SChristian Ulmann }); 496f092e50SChristian Ulmann 506f092e50SChristian Ulmann for (Block &block : inlinedBlocks) { 516f092e50SChristian Ulmann for (BlockArgument &arg : block.getArguments()) 526f092e50SChristian Ulmann if (LocationAttr newLoc = remapLoc(arg.getLoc())) 536f092e50SChristian Ulmann arg.setLoc(newLoc); 546f092e50SChristian Ulmann 556f092e50SChristian Ulmann for (Operation &op : block) 566f092e50SChristian Ulmann attrReplacer.recursivelyReplaceElementsIn(&op, /*replaceAttrs=*/false, 576f092e50SChristian Ulmann /*replaceLocs=*/true); 586f092e50SChristian Ulmann } 590ba00878SRiver Riddle } 600ba00878SRiver Riddle 614562e389SRiver Riddle static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks, 624d67b278SJeff Niu IRMapping &mapper) { 630ba00878SRiver Riddle auto remapOperands = [&](Operation *op) { 640ba00878SRiver Riddle for (auto &operand : op->getOpOperands()) 6535807bc4SRiver Riddle if (auto mappedOp = mapper.lookupOrNull(operand.get())) 660ba00878SRiver Riddle operand.set(mappedOp); 670ba00878SRiver Riddle }; 680ba00878SRiver Riddle for (auto &block : inlinedBlocks) 690ba00878SRiver Riddle block.walk(remapOperands); 700ba00878SRiver Riddle } 710ba00878SRiver Riddle 720ba00878SRiver Riddle //===----------------------------------------------------------------------===// 730ba00878SRiver Riddle // InlinerInterface 740ba00878SRiver Riddle //===----------------------------------------------------------------------===// 750ba00878SRiver Riddle 76fa417479SRiver Riddle bool InlinerInterface::isLegalToInline(Operation *call, Operation *callable, 77fa417479SRiver Riddle bool wouldBeCloned) const { 78fa417479SRiver Riddle if (auto *handler = getInterfaceFor(call)) 79fa417479SRiver Riddle return handler->isLegalToInline(call, callable, wouldBeCloned); 80fa417479SRiver Riddle return false; 81501fda01SRiver Riddle } 82501fda01SRiver Riddle 834d67b278SJeff Niu bool InlinerInterface::isLegalToInline(Region *dest, Region *src, 844d67b278SJeff Niu bool wouldBeCloned, 854d67b278SJeff Niu IRMapping &valueMapping) const { 86fa417479SRiver Riddle if (auto *handler = getInterfaceFor(dest->getParentOp())) 87fa417479SRiver Riddle return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping); 88fa417479SRiver Riddle return false; 890ba00878SRiver Riddle } 900ba00878SRiver Riddle 914d67b278SJeff Niu bool InlinerInterface::isLegalToInline(Operation *op, Region *dest, 924d67b278SJeff Niu bool wouldBeCloned, 934d67b278SJeff Niu IRMapping &valueMapping) const { 94fa417479SRiver Riddle if (auto *handler = getInterfaceFor(op)) 95fa417479SRiver Riddle return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping); 96fa417479SRiver Riddle return false; 970ba00878SRiver Riddle } 980ba00878SRiver Riddle 990ba00878SRiver Riddle bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const { 1000ba00878SRiver Riddle auto *handler = getInterfaceFor(op); 1010ba00878SRiver Riddle return handler ? handler->shouldAnalyzeRecursively(op) : true; 1020ba00878SRiver Riddle } 1030ba00878SRiver Riddle 1040ba00878SRiver Riddle /// Handle the given inlined terminator by replacing it with a new operation 1050ba00878SRiver Riddle /// as necessary. 1060ba00878SRiver Riddle void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const { 1070ba00878SRiver Riddle auto *handler = getInterfaceFor(op); 1080ba00878SRiver Riddle assert(handler && "expected valid dialect handler"); 1090ba00878SRiver Riddle handler->handleTerminator(op, newDest); 1100ba00878SRiver Riddle } 1110ba00878SRiver Riddle 1120ba00878SRiver Riddle /// Handle the given inlined terminator by replacing it with a new operation 1130ba00878SRiver Riddle /// as necessary. 1140ba00878SRiver Riddle void InlinerInterface::handleTerminator(Operation *op, 11526a0b277SMehdi Amini ValueRange valuesToRepl) const { 1160ba00878SRiver Riddle auto *handler = getInterfaceFor(op); 1170ba00878SRiver Riddle assert(handler && "expected valid dialect handler"); 1180ba00878SRiver Riddle handler->handleTerminator(op, valuesToRepl); 1190ba00878SRiver Riddle } 1200ba00878SRiver Riddle 121*b39c5cb6SWilliam Moses /// Returns true if the inliner can assume a fast path of not creating a 122*b39c5cb6SWilliam Moses /// new block, if there is only one block. 123*b39c5cb6SWilliam Moses bool InlinerInterface::allowSingleBlockOptimization( 124*b39c5cb6SWilliam Moses iterator_range<Region::iterator> inlinedBlocks) const { 125*b39c5cb6SWilliam Moses if (inlinedBlocks.empty()) { 126*b39c5cb6SWilliam Moses return true; 127*b39c5cb6SWilliam Moses } 128*b39c5cb6SWilliam Moses auto *handler = getInterfaceFor(inlinedBlocks.begin()->getParentOp()); 129*b39c5cb6SWilliam Moses assert(handler && "expected valid dialect handler"); 130*b39c5cb6SWilliam Moses return handler->allowSingleBlockOptimization(inlinedBlocks); 131*b39c5cb6SWilliam Moses } 132*b39c5cb6SWilliam Moses 133f809eb4dSTobias Gysi Value InlinerInterface::handleArgument(OpBuilder &builder, Operation *call, 134f809eb4dSTobias Gysi Operation *callable, Value argument, 135f809eb4dSTobias Gysi DictionaryAttr argumentAttrs) const { 136f809eb4dSTobias Gysi auto *handler = getInterfaceFor(callable); 137f809eb4dSTobias Gysi assert(handler && "expected valid dialect handler"); 1380fb4ac55STobias Gysi return handler->handleArgument(builder, call, callable, argument, 139f809eb4dSTobias Gysi argumentAttrs); 140f809eb4dSTobias Gysi } 141f809eb4dSTobias Gysi 142f809eb4dSTobias Gysi Value InlinerInterface::handleResult(OpBuilder &builder, Operation *call, 143f809eb4dSTobias Gysi Operation *callable, Value result, 144f809eb4dSTobias Gysi DictionaryAttr resultAttrs) const { 145f809eb4dSTobias Gysi auto *handler = getInterfaceFor(callable); 146f809eb4dSTobias Gysi assert(handler && "expected valid dialect handler"); 1470fb4ac55STobias Gysi return handler->handleResult(builder, call, callable, result, resultAttrs); 148f809eb4dSTobias Gysi } 149f809eb4dSTobias Gysi 1500e760a08SJacques Pienaar void InlinerInterface::processInlinedCallBlocks( 1510e760a08SJacques Pienaar Operation *call, iterator_range<Region::iterator> inlinedBlocks) const { 1520e760a08SJacques Pienaar auto *handler = getInterfaceFor(call); 1530e760a08SJacques Pienaar assert(handler && "expected valid dialect handler"); 1540e760a08SJacques Pienaar handler->processInlinedCallBlocks(call, inlinedBlocks); 1550e760a08SJacques Pienaar } 1560e760a08SJacques Pienaar 1570ba00878SRiver Riddle /// Utility to check that all of the operations within 'src' can be inlined. 1580ba00878SRiver Riddle static bool isLegalToInline(InlinerInterface &interface, Region *src, 159fa417479SRiver Riddle Region *insertRegion, bool shouldCloneInlinedRegion, 1604d67b278SJeff Niu IRMapping &valueMapping) { 1610ba00878SRiver Riddle for (auto &block : *src) { 1620ba00878SRiver Riddle for (auto &op : block) { 1630ba00878SRiver Riddle // Check this operation. 164fa417479SRiver Riddle if (!interface.isLegalToInline(&op, insertRegion, 165fa417479SRiver Riddle shouldCloneInlinedRegion, valueMapping)) { 166553f794bSSean Silva LLVM_DEBUG({ 167553f794bSSean Silva llvm::dbgs() << "* Illegal to inline because of op: "; 168553f794bSSean Silva op.dump(); 169553f794bSSean Silva }); 1700ba00878SRiver Riddle return false; 171553f794bSSean Silva } 1720ba00878SRiver Riddle // Check any nested regions. 1730ba00878SRiver Riddle if (interface.shouldAnalyzeRecursively(&op) && 1740ba00878SRiver Riddle llvm::any_of(op.getRegions(), [&](Region ®ion) { 1750ba00878SRiver Riddle return !isLegalToInline(interface, ®ion, insertRegion, 176fa417479SRiver Riddle shouldCloneInlinedRegion, valueMapping); 1770ba00878SRiver Riddle })) 1780ba00878SRiver Riddle return false; 1790ba00878SRiver Riddle } 1800ba00878SRiver Riddle } 1810ba00878SRiver Riddle return true; 1820ba00878SRiver Riddle } 1830ba00878SRiver Riddle 1840ba00878SRiver Riddle //===----------------------------------------------------------------------===// 1850ba00878SRiver Riddle // Inline Methods 1860ba00878SRiver Riddle //===----------------------------------------------------------------------===// 1870ba00878SRiver Riddle 188f809eb4dSTobias Gysi static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder, 189f809eb4dSTobias Gysi CallOpInterface call, 190f809eb4dSTobias Gysi CallableOpInterface callable, 191f809eb4dSTobias Gysi IRMapping &mapper) { 192f809eb4dSTobias Gysi // Unpack the argument attributes if there are any. 193f809eb4dSTobias Gysi SmallVector<DictionaryAttr> argAttrs( 194f809eb4dSTobias Gysi callable.getCallableRegion()->getNumArguments(), 195f809eb4dSTobias Gysi builder.getDictionaryAttr({})); 19634a35a8bSMartin Erhart if (ArrayAttr arrayAttr = callable.getArgAttrsAttr()) { 197f809eb4dSTobias Gysi assert(arrayAttr.size() == argAttrs.size()); 198f809eb4dSTobias Gysi for (auto [idx, attr] : llvm::enumerate(arrayAttr)) 199f809eb4dSTobias Gysi argAttrs[idx] = cast<DictionaryAttr>(attr); 200f809eb4dSTobias Gysi } 201f809eb4dSTobias Gysi 202f809eb4dSTobias Gysi // Run the argument attribute handler for the given argument and attribute. 203f809eb4dSTobias Gysi for (auto [blockArg, argAttr] : 204f809eb4dSTobias Gysi llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) { 2050fb4ac55STobias Gysi Value newArgument = interface.handleArgument( 2060fb4ac55STobias Gysi builder, call, callable, mapper.lookup(blockArg), argAttr); 2070fb4ac55STobias Gysi assert(newArgument.getType() == mapper.lookup(blockArg).getType() && 2080fb4ac55STobias Gysi "expected the argument type to not change"); 209f809eb4dSTobias Gysi 210f809eb4dSTobias Gysi // Update the mapping to point the new argument returned by the handler. 211f809eb4dSTobias Gysi mapper.map(blockArg, newArgument); 212f809eb4dSTobias Gysi } 213f809eb4dSTobias Gysi } 214f809eb4dSTobias Gysi 215f809eb4dSTobias Gysi static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder, 216f809eb4dSTobias Gysi CallOpInterface call, CallableOpInterface callable, 217f809eb4dSTobias Gysi ValueRange results) { 218f809eb4dSTobias Gysi // Unpack the result attributes if there are any. 219f809eb4dSTobias Gysi SmallVector<DictionaryAttr> resAttrs(results.size(), 220f809eb4dSTobias Gysi builder.getDictionaryAttr({})); 22134a35a8bSMartin Erhart if (ArrayAttr arrayAttr = callable.getResAttrsAttr()) { 222f809eb4dSTobias Gysi assert(arrayAttr.size() == resAttrs.size()); 223f809eb4dSTobias Gysi for (auto [idx, attr] : llvm::enumerate(arrayAttr)) 224f809eb4dSTobias Gysi resAttrs[idx] = cast<DictionaryAttr>(attr); 225f809eb4dSTobias Gysi } 226f809eb4dSTobias Gysi 227f809eb4dSTobias Gysi // Run the result attribute handler for the given result and attribute. 228f809eb4dSTobias Gysi SmallVector<DictionaryAttr> resultAttributes; 229f809eb4dSTobias Gysi for (auto [result, resAttr] : llvm::zip(results, resAttrs)) { 230f809eb4dSTobias Gysi // Store the original result users before running the handler. 231f809eb4dSTobias Gysi DenseSet<Operation *> resultUsers; 232f809eb4dSTobias Gysi for (Operation *user : result.getUsers()) 233f809eb4dSTobias Gysi resultUsers.insert(user); 234f809eb4dSTobias Gysi 2350fb4ac55STobias Gysi Value newResult = 2360fb4ac55STobias Gysi interface.handleResult(builder, call, callable, result, resAttr); 2370fb4ac55STobias Gysi assert(newResult.getType() == result.getType() && 2380fb4ac55STobias Gysi "expected the result type to not change"); 239f809eb4dSTobias Gysi 240f809eb4dSTobias Gysi // Replace the result uses except for the ones introduce by the handler. 241f809eb4dSTobias Gysi result.replaceUsesWithIf(newResult, [&](OpOperand &operand) { 242f809eb4dSTobias Gysi return resultUsers.count(operand.getOwner()); 243f809eb4dSTobias Gysi }); 244f809eb4dSTobias Gysi } 245f809eb4dSTobias Gysi } 246f809eb4dSTobias Gysi 2470e760a08SJacques Pienaar static LogicalResult 248da12d88bSRiver Riddle inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, 2494d67b278SJeff Niu Block::iterator inlinePoint, IRMapping &mapper, 2500e760a08SJacques Pienaar ValueRange resultsToReplace, TypeRange regionResultTypes, 2510a81ace0SKazu Hirata std::optional<Location> inlineLoc, 2526089d612SRahul Kayaith bool shouldCloneInlinedRegion, CallOpInterface call = {}) { 25322219cfcSSean Silva assert(resultsToReplace.size() == regionResultTypes.size()); 2540ba00878SRiver Riddle // We expect the region to have at least one block. 2550ba00878SRiver Riddle if (src->empty()) 2560ba00878SRiver Riddle return failure(); 2570ba00878SRiver Riddle 2580ba00878SRiver Riddle // Check that all of the region arguments have been mapped. 2590ba00878SRiver Riddle auto *srcEntryBlock = &src->front(); 2600ba00878SRiver Riddle if (llvm::any_of(srcEntryBlock->getArguments(), 261e62a6956SRiver Riddle [&](BlockArgument arg) { return !mapper.contains(arg); })) 2620ba00878SRiver Riddle return failure(); 2630ba00878SRiver Riddle 2640ba00878SRiver Riddle // Check that the operations within the source region are valid to inline. 265da12d88bSRiver Riddle Region *insertRegion = inlineBlock->getParent(); 266fa417479SRiver Riddle if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion, 267fa417479SRiver Riddle mapper) || 268fa417479SRiver Riddle !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion, 269fa417479SRiver Riddle mapper)) 2700ba00878SRiver Riddle return failure(); 2710ba00878SRiver Riddle 272f809eb4dSTobias Gysi // Run the argument attribute handler before inlining the callable region. 273f809eb4dSTobias Gysi OpBuilder builder(inlineBlock, inlinePoint); 274f809eb4dSTobias Gysi auto callable = dyn_cast<CallableOpInterface>(src->getParentOp()); 275f809eb4dSTobias Gysi if (call && callable) 276f809eb4dSTobias Gysi handleArgumentImpl(interface, builder, call, callable, mapper); 277f809eb4dSTobias Gysi 2780ba00878SRiver Riddle // Check to see if the region is being cloned, or moved inline. In either 2790ba00878SRiver Riddle // case, move the new blocks after the 'insertBlock' to improve IR 2800ba00878SRiver Riddle // readability. 281da12d88bSRiver Riddle Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint); 2820ba00878SRiver Riddle if (shouldCloneInlinedRegion) 2830ba00878SRiver Riddle src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper); 2840ba00878SRiver Riddle else 2850ba00878SRiver Riddle insertRegion->getBlocks().splice(postInsertBlock->getIterator(), 2860ba00878SRiver Riddle src->getBlocks(), src->begin(), 2870ba00878SRiver Riddle src->end()); 2880ba00878SRiver Riddle 2890ba00878SRiver Riddle // Get the range of newly inserted blocks. 290da12d88bSRiver Riddle auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()), 2910ba00878SRiver Riddle postInsertBlock->getIterator()); 2920ba00878SRiver Riddle Block *firstNewBlock = &*newBlocks.begin(); 2930ba00878SRiver Riddle 2940ba00878SRiver Riddle // Remap the locations of the inlined operations if a valid source location 2950ba00878SRiver Riddle // was provided. 29668f58812STres Popp if (inlineLoc && !llvm::isa<UnknownLoc>(*inlineLoc)) 2970ba00878SRiver Riddle remapInlinedLocations(newBlocks, *inlineLoc); 2980ba00878SRiver Riddle 2990ba00878SRiver Riddle // If the blocks were moved in-place, make sure to remap any necessary 3000ba00878SRiver Riddle // operands. 3010ba00878SRiver Riddle if (!shouldCloneInlinedRegion) 3020ba00878SRiver Riddle remapInlinedOperands(newBlocks, mapper); 3030ba00878SRiver Riddle 304a20d96e4SRiver Riddle // Process the newly inlined blocks. 3050e760a08SJacques Pienaar if (call) 3060e760a08SJacques Pienaar interface.processInlinedCallBlocks(call, newBlocks); 307a20d96e4SRiver Riddle interface.processInlinedBlocks(newBlocks); 308a20d96e4SRiver Riddle 309*b39c5cb6SWilliam Moses bool singleBlockFastPath = interface.allowSingleBlockOptimization(newBlocks); 310*b39c5cb6SWilliam Moses 3110ba00878SRiver Riddle // Handle the case where only a single block was inlined. 312*b39c5cb6SWilliam Moses if (singleBlockFastPath && std::next(newBlocks.begin()) == newBlocks.end()) { 313f809eb4dSTobias Gysi // Run the result attribute handler on the terminator operands. 314f809eb4dSTobias Gysi Operation *firstBlockTerminator = firstNewBlock->getTerminator(); 315f809eb4dSTobias Gysi builder.setInsertionPoint(firstBlockTerminator); 316f809eb4dSTobias Gysi if (call && callable) 317f809eb4dSTobias Gysi handleResultImpl(interface, builder, call, callable, 318f809eb4dSTobias Gysi firstBlockTerminator->getOperands()); 319f809eb4dSTobias Gysi 3200ba00878SRiver Riddle // Have the interface handle the terminator of this block. 32126a0b277SMehdi Amini interface.handleTerminator(firstBlockTerminator, resultsToReplace); 3220ba00878SRiver Riddle firstBlockTerminator->erase(); 3230ba00878SRiver Riddle 3240ba00878SRiver Riddle // Merge the post insert block into the cloned entry block. 3250ba00878SRiver Riddle firstNewBlock->getOperations().splice(firstNewBlock->end(), 3260ba00878SRiver Riddle postInsertBlock->getOperations()); 3270ba00878SRiver Riddle postInsertBlock->erase(); 3280ba00878SRiver Riddle } else { 3290ba00878SRiver Riddle // Otherwise, there were multiple blocks inlined. Add arguments to the post 3300ba00878SRiver Riddle // insertion block to represent the results to replace. 331e4853be2SMehdi Amini for (const auto &resultToRepl : llvm::enumerate(resultsToReplace)) { 332e084679fSRiver Riddle resultToRepl.value().replaceAllUsesWith( 333e084679fSRiver Riddle postInsertBlock->addArgument(regionResultTypes[resultToRepl.index()], 334e084679fSRiver Riddle resultToRepl.value().getLoc())); 3350ba00878SRiver Riddle } 3360ba00878SRiver Riddle 337f809eb4dSTobias Gysi // Run the result attribute handler on the post insertion block arguments. 338f809eb4dSTobias Gysi builder.setInsertionPointToStart(postInsertBlock); 339f809eb4dSTobias Gysi if (call && callable) 340f809eb4dSTobias Gysi handleResultImpl(interface, builder, call, callable, 341f809eb4dSTobias Gysi postInsertBlock->getArguments()); 342f809eb4dSTobias Gysi 3430ba00878SRiver Riddle /// Handle the terminators for each of the new blocks. 3440ba00878SRiver Riddle for (auto &newBlock : newBlocks) 3450ba00878SRiver Riddle interface.handleTerminator(newBlock.getTerminator(), postInsertBlock); 3460ba00878SRiver Riddle } 3470ba00878SRiver Riddle 3480ba00878SRiver Riddle // Splice the instructions of the inlined entry block into the insert block. 349da12d88bSRiver Riddle inlineBlock->getOperations().splice(inlineBlock->end(), 3500ba00878SRiver Riddle firstNewBlock->getOperations()); 3510ba00878SRiver Riddle firstNewBlock->erase(); 3520ba00878SRiver Riddle return success(); 3530ba00878SRiver Riddle } 3540ba00878SRiver Riddle 3550e760a08SJacques Pienaar static LogicalResult 356da12d88bSRiver Riddle inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, 357da12d88bSRiver Riddle Block::iterator inlinePoint, ValueRange inlinedOperands, 3580a81ace0SKazu Hirata ValueRange resultsToReplace, std::optional<Location> inlineLoc, 3596089d612SRahul Kayaith bool shouldCloneInlinedRegion, CallOpInterface call = {}) { 3600ba00878SRiver Riddle // We expect the region to have at least one block. 3610ba00878SRiver Riddle if (src->empty()) 3620ba00878SRiver Riddle return failure(); 3630ba00878SRiver Riddle 3640ba00878SRiver Riddle auto *entryBlock = &src->front(); 3650ba00878SRiver Riddle if (inlinedOperands.size() != entryBlock->getNumArguments()) 3660ba00878SRiver Riddle return failure(); 3670ba00878SRiver Riddle 3680ba00878SRiver Riddle // Map the provided call operands to the arguments of the region. 3694d67b278SJeff Niu IRMapping mapper; 3700ba00878SRiver Riddle for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) { 3710ba00878SRiver Riddle // Verify that the types of the provided values match the function argument 3720ba00878SRiver Riddle // types. 373e62a6956SRiver Riddle BlockArgument regionArg = entryBlock->getArgument(i); 3742bdf33ccSRiver Riddle if (inlinedOperands[i].getType() != regionArg.getType()) 3750ba00878SRiver Riddle return failure(); 3760ba00878SRiver Riddle mapper.map(regionArg, inlinedOperands[i]); 3770ba00878SRiver Riddle } 3780ba00878SRiver Riddle 3790ba00878SRiver Riddle // Call into the main region inliner function. 380da12d88bSRiver Riddle return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper, 381da12d88bSRiver Riddle resultsToReplace, resultsToReplace.getTypes(), 382da12d88bSRiver Riddle inlineLoc, shouldCloneInlinedRegion, call); 3830e760a08SJacques Pienaar } 3840e760a08SJacques Pienaar 3850e760a08SJacques Pienaar LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 3864d67b278SJeff Niu Operation *inlinePoint, IRMapping &mapper, 3870e760a08SJacques Pienaar ValueRange resultsToReplace, 3880e760a08SJacques Pienaar TypeRange regionResultTypes, 3890a81ace0SKazu Hirata std::optional<Location> inlineLoc, 3900e760a08SJacques Pienaar bool shouldCloneInlinedRegion) { 391da12d88bSRiver Riddle return inlineRegion(interface, src, inlinePoint->getBlock(), 392da12d88bSRiver Riddle ++inlinePoint->getIterator(), mapper, resultsToReplace, 393da12d88bSRiver Riddle regionResultTypes, inlineLoc, shouldCloneInlinedRegion); 394da12d88bSRiver Riddle } 3954d67b278SJeff Niu LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 3964d67b278SJeff Niu Block *inlineBlock, 3974d67b278SJeff Niu Block::iterator inlinePoint, IRMapping &mapper, 3984d67b278SJeff Niu ValueRange resultsToReplace, 3994d67b278SJeff Niu TypeRange regionResultTypes, 4000a81ace0SKazu Hirata std::optional<Location> inlineLoc, 401da12d88bSRiver Riddle bool shouldCloneInlinedRegion) { 402da12d88bSRiver Riddle return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper, 403da12d88bSRiver Riddle resultsToReplace, regionResultTypes, inlineLoc, 404da12d88bSRiver Riddle shouldCloneInlinedRegion); 4050e760a08SJacques Pienaar } 4060e760a08SJacques Pienaar 4070e760a08SJacques Pienaar LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 4080e760a08SJacques Pienaar Operation *inlinePoint, 4090e760a08SJacques Pienaar ValueRange inlinedOperands, 4100e760a08SJacques Pienaar ValueRange resultsToReplace, 4110a81ace0SKazu Hirata std::optional<Location> inlineLoc, 4120e760a08SJacques Pienaar bool shouldCloneInlinedRegion) { 413da12d88bSRiver Riddle return inlineRegion(interface, src, inlinePoint->getBlock(), 414da12d88bSRiver Riddle ++inlinePoint->getIterator(), inlinedOperands, 415da12d88bSRiver Riddle resultsToReplace, inlineLoc, shouldCloneInlinedRegion); 416da12d88bSRiver Riddle } 4170a81ace0SKazu Hirata LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 4180a81ace0SKazu Hirata Block *inlineBlock, 4190a81ace0SKazu Hirata Block::iterator inlinePoint, 4200a81ace0SKazu Hirata ValueRange inlinedOperands, 4210a81ace0SKazu Hirata ValueRange resultsToReplace, 4220a81ace0SKazu Hirata std::optional<Location> inlineLoc, 423da12d88bSRiver Riddle bool shouldCloneInlinedRegion) { 424da12d88bSRiver Riddle return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, 425da12d88bSRiver Riddle inlinedOperands, resultsToReplace, inlineLoc, 426da12d88bSRiver Riddle shouldCloneInlinedRegion); 4270ba00878SRiver Riddle } 4280ba00878SRiver Riddle 4295830f71aSRiver Riddle /// Utility function used to generate a cast operation from the given interface, 4305830f71aSRiver Riddle /// or return nullptr if a cast could not be generated. 431e62a6956SRiver Riddle static Value materializeConversion(const DialectInlinerInterface *interface, 4325830f71aSRiver Riddle SmallVectorImpl<Operation *> &castOps, 433e62a6956SRiver Riddle OpBuilder &castBuilder, Value arg, Type type, 434e62a6956SRiver Riddle Location conversionLoc) { 4355830f71aSRiver Riddle if (!interface) 4365830f71aSRiver Riddle return nullptr; 4375830f71aSRiver Riddle 4385830f71aSRiver Riddle // Check to see if the interface for the call can materialize a conversion. 4395830f71aSRiver Riddle Operation *castOp = interface->materializeCallConversion(castBuilder, arg, 4405830f71aSRiver Riddle type, conversionLoc); 4415830f71aSRiver Riddle if (!castOp) 4425830f71aSRiver Riddle return nullptr; 4435830f71aSRiver Riddle castOps.push_back(castOp); 4445830f71aSRiver Riddle 4455830f71aSRiver Riddle // Ensure that the generated cast is correct. 4465830f71aSRiver Riddle assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg && 4475830f71aSRiver Riddle castOp->getNumResults() == 1 && *castOp->result_type_begin() == type); 4485830f71aSRiver Riddle return castOp->getResult(0); 4495830f71aSRiver Riddle } 4505830f71aSRiver Riddle 4515830f71aSRiver Riddle /// This function inlines a given region, 'src', of a callable operation, 4525830f71aSRiver Riddle /// 'callable', into the location defined by the given call operation. This 4535830f71aSRiver Riddle /// function returns failure if inlining is not possible, success otherwise. On 4545830f71aSRiver Riddle /// failure, no changes are made to the module. 'shouldCloneInlinedRegion' 4555830f71aSRiver Riddle /// corresponds to whether the source region should be cloned into the 'call' or 4565830f71aSRiver Riddle /// spliced directly. 4575830f71aSRiver Riddle LogicalResult mlir::inlineCall(InlinerInterface &interface, 4585830f71aSRiver Riddle CallOpInterface call, 4595830f71aSRiver Riddle CallableOpInterface callable, Region *src, 4605830f71aSRiver Riddle bool shouldCloneInlinedRegion) { 4615830f71aSRiver Riddle // We expect the region to have at least one block. 4625830f71aSRiver Riddle if (src->empty()) 4635830f71aSRiver Riddle return failure(); 4645830f71aSRiver Riddle auto *entryBlock = &src->front(); 46534a35a8bSMartin Erhart ArrayRef<Type> callableResultTypes = callable.getResultTypes(); 4665830f71aSRiver Riddle 4675830f71aSRiver Riddle // Make sure that the number of arguments and results matchup between the call 4685830f71aSRiver Riddle // and the region. 469e62a6956SRiver Riddle SmallVector<Value, 8> callOperands(call.getArgOperands()); 470c4a04059SChristian Sigg SmallVector<Value, 8> callResults(call->getResults()); 4715830f71aSRiver Riddle if (callOperands.size() != entryBlock->getNumArguments() || 4725830f71aSRiver Riddle callResults.size() != callableResultTypes.size()) 4730ba00878SRiver Riddle return failure(); 4740ba00878SRiver Riddle 4755830f71aSRiver Riddle // A set of cast operations generated to matchup the signature of the region 4765830f71aSRiver Riddle // with the signature of the call. 4775830f71aSRiver Riddle SmallVector<Operation *, 4> castOps; 4785830f71aSRiver Riddle castOps.reserve(callOperands.size() + callResults.size()); 4790ba00878SRiver Riddle 4805830f71aSRiver Riddle // Functor used to cleanup generated state on failure. 4815830f71aSRiver Riddle auto cleanupState = [&] { 4825830f71aSRiver Riddle for (auto *op : castOps) { 4832bdf33ccSRiver Riddle op->getResult(0).replaceAllUsesWith(op->getOperand(0)); 4845830f71aSRiver Riddle op->erase(); 4855830f71aSRiver Riddle } 4860ba00878SRiver Riddle return failure(); 4875830f71aSRiver Riddle }; 4880ba00878SRiver Riddle 4895830f71aSRiver Riddle // Builder used for any conversion operations that need to be materialized. 4905830f71aSRiver Riddle OpBuilder castBuilder(call); 4915830f71aSRiver Riddle Location castLoc = call.getLoc(); 4920bf4a82aSChristian Sigg const auto *callInterface = interface.getInterfaceFor(call->getDialect()); 4935830f71aSRiver Riddle 4945830f71aSRiver Riddle // Map the provided call operands to the arguments of the region. 4954d67b278SJeff Niu IRMapping mapper; 4965830f71aSRiver Riddle for (unsigned i = 0, e = callOperands.size(); i != e; ++i) { 497e62a6956SRiver Riddle BlockArgument regionArg = entryBlock->getArgument(i); 498e62a6956SRiver Riddle Value operand = callOperands[i]; 4995830f71aSRiver Riddle 5005830f71aSRiver Riddle // If the call operand doesn't match the expected region argument, try to 5015830f71aSRiver Riddle // generate a cast. 5022bdf33ccSRiver Riddle Type regionArgType = regionArg.getType(); 5032bdf33ccSRiver Riddle if (operand.getType() != regionArgType) { 5045830f71aSRiver Riddle if (!(operand = materializeConversion(callInterface, castOps, castBuilder, 5055830f71aSRiver Riddle operand, regionArgType, castLoc))) 5065830f71aSRiver Riddle return cleanupState(); 5075830f71aSRiver Riddle } 5085830f71aSRiver Riddle mapper.map(regionArg, operand); 5095830f71aSRiver Riddle } 5105830f71aSRiver Riddle 511706d992cSRahul Joshi // Ensure that the resultant values of the call match the callable. 5125830f71aSRiver Riddle castBuilder.setInsertionPointAfter(call); 5135830f71aSRiver Riddle for (unsigned i = 0, e = callResults.size(); i != e; ++i) { 514e62a6956SRiver Riddle Value callResult = callResults[i]; 5152bdf33ccSRiver Riddle if (callResult.getType() == callableResultTypes[i]) 5165830f71aSRiver Riddle continue; 5175830f71aSRiver Riddle 5185830f71aSRiver Riddle // Generate a conversion that will produce the original type, so that the IR 5195830f71aSRiver Riddle // is still valid after the original call gets replaced. 520e62a6956SRiver Riddle Value castResult = 5215830f71aSRiver Riddle materializeConversion(callInterface, castOps, castBuilder, callResult, 5222bdf33ccSRiver Riddle callResult.getType(), castLoc); 5235830f71aSRiver Riddle if (!castResult) 5245830f71aSRiver Riddle return cleanupState(); 5252bdf33ccSRiver Riddle callResult.replaceAllUsesWith(castResult); 5262bdf33ccSRiver Riddle castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult); 5275830f71aSRiver Riddle } 5285830f71aSRiver Riddle 529501fda01SRiver Riddle // Check that it is legal to inline the callable into the call. 530fa417479SRiver Riddle if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion)) 531501fda01SRiver Riddle return cleanupState(); 532501fda01SRiver Riddle 5335830f71aSRiver Riddle // Attempt to inline the call. 534da12d88bSRiver Riddle if (failed(inlineRegionImpl(interface, src, call->getBlock(), 535da12d88bSRiver Riddle ++call->getIterator(), mapper, callResults, 53622219cfcSSean Silva callableResultTypes, call.getLoc(), 5370e760a08SJacques Pienaar shouldCloneInlinedRegion, call))) 5385830f71aSRiver Riddle return cleanupState(); 5395830f71aSRiver Riddle return success(); 5400ba00878SRiver Riddle } 541