xref: /llvm-project/mlir/lib/Transforms/Utils/InliningUtils.cpp (revision b39c5cb6977f35ad727d86b2dd6232099734ffd3)
10ba00878SRiver Riddle //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
20ba00878SRiver Riddle //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60ba00878SRiver Riddle //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
80ba00878SRiver Riddle //
90ba00878SRiver Riddle // This file implements miscellaneous inlining utilities.
100ba00878SRiver Riddle //
110ba00878SRiver Riddle //===----------------------------------------------------------------------===//
120ba00878SRiver Riddle 
130ba00878SRiver Riddle #include "mlir/Transforms/InliningUtils.h"
140ba00878SRiver Riddle 
155830f71aSRiver Riddle #include "mlir/IR/Builders.h"
164d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
170ba00878SRiver Riddle #include "mlir/IR/Operation.h"
1836550692SRiver Riddle #include "mlir/Interfaces/CallInterfaces.h"
190ba00878SRiver Riddle #include "llvm/ADT/MapVector.h"
20553f794bSSean Silva #include "llvm/Support/Debug.h"
210ba00878SRiver Riddle #include "llvm/Support/raw_ostream.h"
22a1fe1f5fSKazu Hirata #include <optional>
230ba00878SRiver Riddle 
240ba00878SRiver Riddle #define DEBUG_TYPE "inlining"
250ba00878SRiver Riddle 
260ba00878SRiver Riddle using namespace mlir;
270ba00878SRiver Riddle 
286f092e50SChristian Ulmann /// Remap all locations reachable from the inlined blocks with CallSiteLoc
296f092e50SChristian Ulmann /// locations with the provided caller location.
300ba00878SRiver Riddle static void
314562e389SRiver Riddle remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,
320ba00878SRiver Riddle                       Location callerLoc) {
336f092e50SChristian Ulmann   DenseMap<Location, LocationAttr> mappedLocations;
346f092e50SChristian Ulmann   auto remapLoc = [&](Location loc) {
356f092e50SChristian Ulmann     auto [it, inserted] = mappedLocations.try_emplace(loc);
366f092e50SChristian Ulmann     // Only query the attribute uniquer once per callsite attribute.
376f092e50SChristian Ulmann     if (inserted) {
386f092e50SChristian Ulmann       auto newLoc = CallSiteLoc::get(loc, callerLoc);
396f092e50SChristian Ulmann       it->getSecond() = newLoc;
400ba00878SRiver Riddle     }
416f092e50SChristian Ulmann     return it->second;
420ba00878SRiver Riddle   };
436f092e50SChristian Ulmann 
446f092e50SChristian Ulmann   AttrTypeReplacer attrReplacer;
456f092e50SChristian Ulmann   attrReplacer.addReplacement(
466f092e50SChristian Ulmann       [&](LocationAttr loc) -> std::pair<LocationAttr, WalkResult> {
476f092e50SChristian Ulmann         return {remapLoc(loc), WalkResult::skip()};
486f092e50SChristian Ulmann       });
496f092e50SChristian Ulmann 
506f092e50SChristian Ulmann   for (Block &block : inlinedBlocks) {
516f092e50SChristian Ulmann     for (BlockArgument &arg : block.getArguments())
526f092e50SChristian Ulmann       if (LocationAttr newLoc = remapLoc(arg.getLoc()))
536f092e50SChristian Ulmann         arg.setLoc(newLoc);
546f092e50SChristian Ulmann 
556f092e50SChristian Ulmann     for (Operation &op : block)
566f092e50SChristian Ulmann       attrReplacer.recursivelyReplaceElementsIn(&op, /*replaceAttrs=*/false,
576f092e50SChristian Ulmann                                                 /*replaceLocs=*/true);
586f092e50SChristian Ulmann   }
590ba00878SRiver Riddle }
600ba00878SRiver Riddle 
614562e389SRiver Riddle static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
624d67b278SJeff Niu                                  IRMapping &mapper) {
630ba00878SRiver Riddle   auto remapOperands = [&](Operation *op) {
640ba00878SRiver Riddle     for (auto &operand : op->getOpOperands())
6535807bc4SRiver Riddle       if (auto mappedOp = mapper.lookupOrNull(operand.get()))
660ba00878SRiver Riddle         operand.set(mappedOp);
670ba00878SRiver Riddle   };
680ba00878SRiver Riddle   for (auto &block : inlinedBlocks)
690ba00878SRiver Riddle     block.walk(remapOperands);
700ba00878SRiver Riddle }
710ba00878SRiver Riddle 
720ba00878SRiver Riddle //===----------------------------------------------------------------------===//
730ba00878SRiver Riddle // InlinerInterface
740ba00878SRiver Riddle //===----------------------------------------------------------------------===//
750ba00878SRiver Riddle 
76fa417479SRiver Riddle bool InlinerInterface::isLegalToInline(Operation *call, Operation *callable,
77fa417479SRiver Riddle                                        bool wouldBeCloned) const {
78fa417479SRiver Riddle   if (auto *handler = getInterfaceFor(call))
79fa417479SRiver Riddle     return handler->isLegalToInline(call, callable, wouldBeCloned);
80fa417479SRiver Riddle   return false;
81501fda01SRiver Riddle }
82501fda01SRiver Riddle 
834d67b278SJeff Niu bool InlinerInterface::isLegalToInline(Region *dest, Region *src,
844d67b278SJeff Niu                                        bool wouldBeCloned,
854d67b278SJeff Niu                                        IRMapping &valueMapping) const {
86fa417479SRiver Riddle   if (auto *handler = getInterfaceFor(dest->getParentOp()))
87fa417479SRiver Riddle     return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
88fa417479SRiver Riddle   return false;
890ba00878SRiver Riddle }
900ba00878SRiver Riddle 
914d67b278SJeff Niu bool InlinerInterface::isLegalToInline(Operation *op, Region *dest,
924d67b278SJeff Niu                                        bool wouldBeCloned,
934d67b278SJeff Niu                                        IRMapping &valueMapping) const {
94fa417479SRiver Riddle   if (auto *handler = getInterfaceFor(op))
95fa417479SRiver Riddle     return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
96fa417479SRiver Riddle   return false;
970ba00878SRiver Riddle }
980ba00878SRiver Riddle 
990ba00878SRiver Riddle bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
1000ba00878SRiver Riddle   auto *handler = getInterfaceFor(op);
1010ba00878SRiver Riddle   return handler ? handler->shouldAnalyzeRecursively(op) : true;
1020ba00878SRiver Riddle }
1030ba00878SRiver Riddle 
1040ba00878SRiver Riddle /// Handle the given inlined terminator by replacing it with a new operation
1050ba00878SRiver Riddle /// as necessary.
1060ba00878SRiver Riddle void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
1070ba00878SRiver Riddle   auto *handler = getInterfaceFor(op);
1080ba00878SRiver Riddle   assert(handler && "expected valid dialect handler");
1090ba00878SRiver Riddle   handler->handleTerminator(op, newDest);
1100ba00878SRiver Riddle }
1110ba00878SRiver Riddle 
1120ba00878SRiver Riddle /// Handle the given inlined terminator by replacing it with a new operation
1130ba00878SRiver Riddle /// as necessary.
1140ba00878SRiver Riddle void InlinerInterface::handleTerminator(Operation *op,
11526a0b277SMehdi Amini                                         ValueRange valuesToRepl) const {
1160ba00878SRiver Riddle   auto *handler = getInterfaceFor(op);
1170ba00878SRiver Riddle   assert(handler && "expected valid dialect handler");
1180ba00878SRiver Riddle   handler->handleTerminator(op, valuesToRepl);
1190ba00878SRiver Riddle }
1200ba00878SRiver Riddle 
121*b39c5cb6SWilliam Moses /// Returns true if the inliner can assume a fast path of not creating a
122*b39c5cb6SWilliam Moses /// new block, if there is only one block.
123*b39c5cb6SWilliam Moses bool InlinerInterface::allowSingleBlockOptimization(
124*b39c5cb6SWilliam Moses     iterator_range<Region::iterator> inlinedBlocks) const {
125*b39c5cb6SWilliam Moses   if (inlinedBlocks.empty()) {
126*b39c5cb6SWilliam Moses     return true;
127*b39c5cb6SWilliam Moses   }
128*b39c5cb6SWilliam Moses   auto *handler = getInterfaceFor(inlinedBlocks.begin()->getParentOp());
129*b39c5cb6SWilliam Moses   assert(handler && "expected valid dialect handler");
130*b39c5cb6SWilliam Moses   return handler->allowSingleBlockOptimization(inlinedBlocks);
131*b39c5cb6SWilliam Moses }
132*b39c5cb6SWilliam Moses 
133f809eb4dSTobias Gysi Value InlinerInterface::handleArgument(OpBuilder &builder, Operation *call,
134f809eb4dSTobias Gysi                                        Operation *callable, Value argument,
135f809eb4dSTobias Gysi                                        DictionaryAttr argumentAttrs) const {
136f809eb4dSTobias Gysi   auto *handler = getInterfaceFor(callable);
137f809eb4dSTobias Gysi   assert(handler && "expected valid dialect handler");
1380fb4ac55STobias Gysi   return handler->handleArgument(builder, call, callable, argument,
139f809eb4dSTobias Gysi                                  argumentAttrs);
140f809eb4dSTobias Gysi }
141f809eb4dSTobias Gysi 
142f809eb4dSTobias Gysi Value InlinerInterface::handleResult(OpBuilder &builder, Operation *call,
143f809eb4dSTobias Gysi                                      Operation *callable, Value result,
144f809eb4dSTobias Gysi                                      DictionaryAttr resultAttrs) const {
145f809eb4dSTobias Gysi   auto *handler = getInterfaceFor(callable);
146f809eb4dSTobias Gysi   assert(handler && "expected valid dialect handler");
1470fb4ac55STobias Gysi   return handler->handleResult(builder, call, callable, result, resultAttrs);
148f809eb4dSTobias Gysi }
149f809eb4dSTobias Gysi 
1500e760a08SJacques Pienaar void InlinerInterface::processInlinedCallBlocks(
1510e760a08SJacques Pienaar     Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
1520e760a08SJacques Pienaar   auto *handler = getInterfaceFor(call);
1530e760a08SJacques Pienaar   assert(handler && "expected valid dialect handler");
1540e760a08SJacques Pienaar   handler->processInlinedCallBlocks(call, inlinedBlocks);
1550e760a08SJacques Pienaar }
1560e760a08SJacques Pienaar 
1570ba00878SRiver Riddle /// Utility to check that all of the operations within 'src' can be inlined.
1580ba00878SRiver Riddle static bool isLegalToInline(InlinerInterface &interface, Region *src,
159fa417479SRiver Riddle                             Region *insertRegion, bool shouldCloneInlinedRegion,
1604d67b278SJeff Niu                             IRMapping &valueMapping) {
1610ba00878SRiver Riddle   for (auto &block : *src) {
1620ba00878SRiver Riddle     for (auto &op : block) {
1630ba00878SRiver Riddle       // Check this operation.
164fa417479SRiver Riddle       if (!interface.isLegalToInline(&op, insertRegion,
165fa417479SRiver Riddle                                      shouldCloneInlinedRegion, valueMapping)) {
166553f794bSSean Silva         LLVM_DEBUG({
167553f794bSSean Silva           llvm::dbgs() << "* Illegal to inline because of op: ";
168553f794bSSean Silva           op.dump();
169553f794bSSean Silva         });
1700ba00878SRiver Riddle         return false;
171553f794bSSean Silva       }
1720ba00878SRiver Riddle       // Check any nested regions.
1730ba00878SRiver Riddle       if (interface.shouldAnalyzeRecursively(&op) &&
1740ba00878SRiver Riddle           llvm::any_of(op.getRegions(), [&](Region &region) {
1750ba00878SRiver Riddle             return !isLegalToInline(interface, &region, insertRegion,
176fa417479SRiver Riddle                                     shouldCloneInlinedRegion, valueMapping);
1770ba00878SRiver Riddle           }))
1780ba00878SRiver Riddle         return false;
1790ba00878SRiver Riddle     }
1800ba00878SRiver Riddle   }
1810ba00878SRiver Riddle   return true;
1820ba00878SRiver Riddle }
1830ba00878SRiver Riddle 
1840ba00878SRiver Riddle //===----------------------------------------------------------------------===//
1850ba00878SRiver Riddle // Inline Methods
1860ba00878SRiver Riddle //===----------------------------------------------------------------------===//
1870ba00878SRiver Riddle 
188f809eb4dSTobias Gysi static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder,
189f809eb4dSTobias Gysi                                CallOpInterface call,
190f809eb4dSTobias Gysi                                CallableOpInterface callable,
191f809eb4dSTobias Gysi                                IRMapping &mapper) {
192f809eb4dSTobias Gysi   // Unpack the argument attributes if there are any.
193f809eb4dSTobias Gysi   SmallVector<DictionaryAttr> argAttrs(
194f809eb4dSTobias Gysi       callable.getCallableRegion()->getNumArguments(),
195f809eb4dSTobias Gysi       builder.getDictionaryAttr({}));
19634a35a8bSMartin Erhart   if (ArrayAttr arrayAttr = callable.getArgAttrsAttr()) {
197f809eb4dSTobias Gysi     assert(arrayAttr.size() == argAttrs.size());
198f809eb4dSTobias Gysi     for (auto [idx, attr] : llvm::enumerate(arrayAttr))
199f809eb4dSTobias Gysi       argAttrs[idx] = cast<DictionaryAttr>(attr);
200f809eb4dSTobias Gysi   }
201f809eb4dSTobias Gysi 
202f809eb4dSTobias Gysi   // Run the argument attribute handler for the given argument and attribute.
203f809eb4dSTobias Gysi   for (auto [blockArg, argAttr] :
204f809eb4dSTobias Gysi        llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) {
2050fb4ac55STobias Gysi     Value newArgument = interface.handleArgument(
2060fb4ac55STobias Gysi         builder, call, callable, mapper.lookup(blockArg), argAttr);
2070fb4ac55STobias Gysi     assert(newArgument.getType() == mapper.lookup(blockArg).getType() &&
2080fb4ac55STobias Gysi            "expected the argument type to not change");
209f809eb4dSTobias Gysi 
210f809eb4dSTobias Gysi     // Update the mapping to point the new argument returned by the handler.
211f809eb4dSTobias Gysi     mapper.map(blockArg, newArgument);
212f809eb4dSTobias Gysi   }
213f809eb4dSTobias Gysi }
214f809eb4dSTobias Gysi 
215f809eb4dSTobias Gysi static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
216f809eb4dSTobias Gysi                              CallOpInterface call, CallableOpInterface callable,
217f809eb4dSTobias Gysi                              ValueRange results) {
218f809eb4dSTobias Gysi   // Unpack the result attributes if there are any.
219f809eb4dSTobias Gysi   SmallVector<DictionaryAttr> resAttrs(results.size(),
220f809eb4dSTobias Gysi                                        builder.getDictionaryAttr({}));
22134a35a8bSMartin Erhart   if (ArrayAttr arrayAttr = callable.getResAttrsAttr()) {
222f809eb4dSTobias Gysi     assert(arrayAttr.size() == resAttrs.size());
223f809eb4dSTobias Gysi     for (auto [idx, attr] : llvm::enumerate(arrayAttr))
224f809eb4dSTobias Gysi       resAttrs[idx] = cast<DictionaryAttr>(attr);
225f809eb4dSTobias Gysi   }
226f809eb4dSTobias Gysi 
227f809eb4dSTobias Gysi   // Run the result attribute handler for the given result and attribute.
228f809eb4dSTobias Gysi   SmallVector<DictionaryAttr> resultAttributes;
229f809eb4dSTobias Gysi   for (auto [result, resAttr] : llvm::zip(results, resAttrs)) {
230f809eb4dSTobias Gysi     // Store the original result users before running the handler.
231f809eb4dSTobias Gysi     DenseSet<Operation *> resultUsers;
232f809eb4dSTobias Gysi     for (Operation *user : result.getUsers())
233f809eb4dSTobias Gysi       resultUsers.insert(user);
234f809eb4dSTobias Gysi 
2350fb4ac55STobias Gysi     Value newResult =
2360fb4ac55STobias Gysi         interface.handleResult(builder, call, callable, result, resAttr);
2370fb4ac55STobias Gysi     assert(newResult.getType() == result.getType() &&
2380fb4ac55STobias Gysi            "expected the result type to not change");
239f809eb4dSTobias Gysi 
240f809eb4dSTobias Gysi     // Replace the result uses except for the ones introduce by the handler.
241f809eb4dSTobias Gysi     result.replaceUsesWithIf(newResult, [&](OpOperand &operand) {
242f809eb4dSTobias Gysi       return resultUsers.count(operand.getOwner());
243f809eb4dSTobias Gysi     });
244f809eb4dSTobias Gysi   }
245f809eb4dSTobias Gysi }
246f809eb4dSTobias Gysi 
2470e760a08SJacques Pienaar static LogicalResult
248da12d88bSRiver Riddle inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
2494d67b278SJeff Niu                  Block::iterator inlinePoint, IRMapping &mapper,
2500e760a08SJacques Pienaar                  ValueRange resultsToReplace, TypeRange regionResultTypes,
2510a81ace0SKazu Hirata                  std::optional<Location> inlineLoc,
2526089d612SRahul Kayaith                  bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
25322219cfcSSean Silva   assert(resultsToReplace.size() == regionResultTypes.size());
2540ba00878SRiver Riddle   // We expect the region to have at least one block.
2550ba00878SRiver Riddle   if (src->empty())
2560ba00878SRiver Riddle     return failure();
2570ba00878SRiver Riddle 
2580ba00878SRiver Riddle   // Check that all of the region arguments have been mapped.
2590ba00878SRiver Riddle   auto *srcEntryBlock = &src->front();
2600ba00878SRiver Riddle   if (llvm::any_of(srcEntryBlock->getArguments(),
261e62a6956SRiver Riddle                    [&](BlockArgument arg) { return !mapper.contains(arg); }))
2620ba00878SRiver Riddle     return failure();
2630ba00878SRiver Riddle 
2640ba00878SRiver Riddle   // Check that the operations within the source region are valid to inline.
265da12d88bSRiver Riddle   Region *insertRegion = inlineBlock->getParent();
266fa417479SRiver Riddle   if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
267fa417479SRiver Riddle                                  mapper) ||
268fa417479SRiver Riddle       !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
269fa417479SRiver Riddle                        mapper))
2700ba00878SRiver Riddle     return failure();
2710ba00878SRiver Riddle 
272f809eb4dSTobias Gysi   // Run the argument attribute handler before inlining the callable region.
273f809eb4dSTobias Gysi   OpBuilder builder(inlineBlock, inlinePoint);
274f809eb4dSTobias Gysi   auto callable = dyn_cast<CallableOpInterface>(src->getParentOp());
275f809eb4dSTobias Gysi   if (call && callable)
276f809eb4dSTobias Gysi     handleArgumentImpl(interface, builder, call, callable, mapper);
277f809eb4dSTobias Gysi 
2780ba00878SRiver Riddle   // Check to see if the region is being cloned, or moved inline. In either
2790ba00878SRiver Riddle   // case, move the new blocks after the 'insertBlock' to improve IR
2800ba00878SRiver Riddle   // readability.
281da12d88bSRiver Riddle   Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
2820ba00878SRiver Riddle   if (shouldCloneInlinedRegion)
2830ba00878SRiver Riddle     src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
2840ba00878SRiver Riddle   else
2850ba00878SRiver Riddle     insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
2860ba00878SRiver Riddle                                      src->getBlocks(), src->begin(),
2870ba00878SRiver Riddle                                      src->end());
2880ba00878SRiver Riddle 
2890ba00878SRiver Riddle   // Get the range of newly inserted blocks.
290da12d88bSRiver Riddle   auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
2910ba00878SRiver Riddle                                     postInsertBlock->getIterator());
2920ba00878SRiver Riddle   Block *firstNewBlock = &*newBlocks.begin();
2930ba00878SRiver Riddle 
2940ba00878SRiver Riddle   // Remap the locations of the inlined operations if a valid source location
2950ba00878SRiver Riddle   // was provided.
29668f58812STres Popp   if (inlineLoc && !llvm::isa<UnknownLoc>(*inlineLoc))
2970ba00878SRiver Riddle     remapInlinedLocations(newBlocks, *inlineLoc);
2980ba00878SRiver Riddle 
2990ba00878SRiver Riddle   // If the blocks were moved in-place, make sure to remap any necessary
3000ba00878SRiver Riddle   // operands.
3010ba00878SRiver Riddle   if (!shouldCloneInlinedRegion)
3020ba00878SRiver Riddle     remapInlinedOperands(newBlocks, mapper);
3030ba00878SRiver Riddle 
304a20d96e4SRiver Riddle   // Process the newly inlined blocks.
3050e760a08SJacques Pienaar   if (call)
3060e760a08SJacques Pienaar     interface.processInlinedCallBlocks(call, newBlocks);
307a20d96e4SRiver Riddle   interface.processInlinedBlocks(newBlocks);
308a20d96e4SRiver Riddle 
309*b39c5cb6SWilliam Moses   bool singleBlockFastPath = interface.allowSingleBlockOptimization(newBlocks);
310*b39c5cb6SWilliam Moses 
3110ba00878SRiver Riddle   // Handle the case where only a single block was inlined.
312*b39c5cb6SWilliam Moses   if (singleBlockFastPath && std::next(newBlocks.begin()) == newBlocks.end()) {
313f809eb4dSTobias Gysi     // Run the result attribute handler on the terminator operands.
314f809eb4dSTobias Gysi     Operation *firstBlockTerminator = firstNewBlock->getTerminator();
315f809eb4dSTobias Gysi     builder.setInsertionPoint(firstBlockTerminator);
316f809eb4dSTobias Gysi     if (call && callable)
317f809eb4dSTobias Gysi       handleResultImpl(interface, builder, call, callable,
318f809eb4dSTobias Gysi                        firstBlockTerminator->getOperands());
319f809eb4dSTobias Gysi 
3200ba00878SRiver Riddle     // Have the interface handle the terminator of this block.
32126a0b277SMehdi Amini     interface.handleTerminator(firstBlockTerminator, resultsToReplace);
3220ba00878SRiver Riddle     firstBlockTerminator->erase();
3230ba00878SRiver Riddle 
3240ba00878SRiver Riddle     // Merge the post insert block into the cloned entry block.
3250ba00878SRiver Riddle     firstNewBlock->getOperations().splice(firstNewBlock->end(),
3260ba00878SRiver Riddle                                           postInsertBlock->getOperations());
3270ba00878SRiver Riddle     postInsertBlock->erase();
3280ba00878SRiver Riddle   } else {
3290ba00878SRiver Riddle     // Otherwise, there were multiple blocks inlined. Add arguments to the post
3300ba00878SRiver Riddle     // insertion block to represent the results to replace.
331e4853be2SMehdi Amini     for (const auto &resultToRepl : llvm::enumerate(resultsToReplace)) {
332e084679fSRiver Riddle       resultToRepl.value().replaceAllUsesWith(
333e084679fSRiver Riddle           postInsertBlock->addArgument(regionResultTypes[resultToRepl.index()],
334e084679fSRiver Riddle                                        resultToRepl.value().getLoc()));
3350ba00878SRiver Riddle     }
3360ba00878SRiver Riddle 
337f809eb4dSTobias Gysi     // Run the result attribute handler on the post insertion block arguments.
338f809eb4dSTobias Gysi     builder.setInsertionPointToStart(postInsertBlock);
339f809eb4dSTobias Gysi     if (call && callable)
340f809eb4dSTobias Gysi       handleResultImpl(interface, builder, call, callable,
341f809eb4dSTobias Gysi                        postInsertBlock->getArguments());
342f809eb4dSTobias Gysi 
3430ba00878SRiver Riddle     /// Handle the terminators for each of the new blocks.
3440ba00878SRiver Riddle     for (auto &newBlock : newBlocks)
3450ba00878SRiver Riddle       interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
3460ba00878SRiver Riddle   }
3470ba00878SRiver Riddle 
3480ba00878SRiver Riddle   // Splice the instructions of the inlined entry block into the insert block.
349da12d88bSRiver Riddle   inlineBlock->getOperations().splice(inlineBlock->end(),
3500ba00878SRiver Riddle                                       firstNewBlock->getOperations());
3510ba00878SRiver Riddle   firstNewBlock->erase();
3520ba00878SRiver Riddle   return success();
3530ba00878SRiver Riddle }
3540ba00878SRiver Riddle 
3550e760a08SJacques Pienaar static LogicalResult
356da12d88bSRiver Riddle inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
357da12d88bSRiver Riddle                  Block::iterator inlinePoint, ValueRange inlinedOperands,
3580a81ace0SKazu Hirata                  ValueRange resultsToReplace, std::optional<Location> inlineLoc,
3596089d612SRahul Kayaith                  bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
3600ba00878SRiver Riddle   // We expect the region to have at least one block.
3610ba00878SRiver Riddle   if (src->empty())
3620ba00878SRiver Riddle     return failure();
3630ba00878SRiver Riddle 
3640ba00878SRiver Riddle   auto *entryBlock = &src->front();
3650ba00878SRiver Riddle   if (inlinedOperands.size() != entryBlock->getNumArguments())
3660ba00878SRiver Riddle     return failure();
3670ba00878SRiver Riddle 
3680ba00878SRiver Riddle   // Map the provided call operands to the arguments of the region.
3694d67b278SJeff Niu   IRMapping mapper;
3700ba00878SRiver Riddle   for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
3710ba00878SRiver Riddle     // Verify that the types of the provided values match the function argument
3720ba00878SRiver Riddle     // types.
373e62a6956SRiver Riddle     BlockArgument regionArg = entryBlock->getArgument(i);
3742bdf33ccSRiver Riddle     if (inlinedOperands[i].getType() != regionArg.getType())
3750ba00878SRiver Riddle       return failure();
3760ba00878SRiver Riddle     mapper.map(regionArg, inlinedOperands[i]);
3770ba00878SRiver Riddle   }
3780ba00878SRiver Riddle 
3790ba00878SRiver Riddle   // Call into the main region inliner function.
380da12d88bSRiver Riddle   return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
381da12d88bSRiver Riddle                           resultsToReplace, resultsToReplace.getTypes(),
382da12d88bSRiver Riddle                           inlineLoc, shouldCloneInlinedRegion, call);
3830e760a08SJacques Pienaar }
3840e760a08SJacques Pienaar 
3850e760a08SJacques Pienaar LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
3864d67b278SJeff Niu                                  Operation *inlinePoint, IRMapping &mapper,
3870e760a08SJacques Pienaar                                  ValueRange resultsToReplace,
3880e760a08SJacques Pienaar                                  TypeRange regionResultTypes,
3890a81ace0SKazu Hirata                                  std::optional<Location> inlineLoc,
3900e760a08SJacques Pienaar                                  bool shouldCloneInlinedRegion) {
391da12d88bSRiver Riddle   return inlineRegion(interface, src, inlinePoint->getBlock(),
392da12d88bSRiver Riddle                       ++inlinePoint->getIterator(), mapper, resultsToReplace,
393da12d88bSRiver Riddle                       regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
394da12d88bSRiver Riddle }
3954d67b278SJeff Niu LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
3964d67b278SJeff Niu                                  Block *inlineBlock,
3974d67b278SJeff Niu                                  Block::iterator inlinePoint, IRMapping &mapper,
3984d67b278SJeff Niu                                  ValueRange resultsToReplace,
3994d67b278SJeff Niu                                  TypeRange regionResultTypes,
4000a81ace0SKazu Hirata                                  std::optional<Location> inlineLoc,
401da12d88bSRiver Riddle                                  bool shouldCloneInlinedRegion) {
402da12d88bSRiver Riddle   return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
403da12d88bSRiver Riddle                           resultsToReplace, regionResultTypes, inlineLoc,
404da12d88bSRiver Riddle                           shouldCloneInlinedRegion);
4050e760a08SJacques Pienaar }
4060e760a08SJacques Pienaar 
4070e760a08SJacques Pienaar LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
4080e760a08SJacques Pienaar                                  Operation *inlinePoint,
4090e760a08SJacques Pienaar                                  ValueRange inlinedOperands,
4100e760a08SJacques Pienaar                                  ValueRange resultsToReplace,
4110a81ace0SKazu Hirata                                  std::optional<Location> inlineLoc,
4120e760a08SJacques Pienaar                                  bool shouldCloneInlinedRegion) {
413da12d88bSRiver Riddle   return inlineRegion(interface, src, inlinePoint->getBlock(),
414da12d88bSRiver Riddle                       ++inlinePoint->getIterator(), inlinedOperands,
415da12d88bSRiver Riddle                       resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
416da12d88bSRiver Riddle }
4170a81ace0SKazu Hirata LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
4180a81ace0SKazu Hirata                                  Block *inlineBlock,
4190a81ace0SKazu Hirata                                  Block::iterator inlinePoint,
4200a81ace0SKazu Hirata                                  ValueRange inlinedOperands,
4210a81ace0SKazu Hirata                                  ValueRange resultsToReplace,
4220a81ace0SKazu Hirata                                  std::optional<Location> inlineLoc,
423da12d88bSRiver Riddle                                  bool shouldCloneInlinedRegion) {
424da12d88bSRiver Riddle   return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,
425da12d88bSRiver Riddle                           inlinedOperands, resultsToReplace, inlineLoc,
426da12d88bSRiver Riddle                           shouldCloneInlinedRegion);
4270ba00878SRiver Riddle }
4280ba00878SRiver Riddle 
4295830f71aSRiver Riddle /// Utility function used to generate a cast operation from the given interface,
4305830f71aSRiver Riddle /// or return nullptr if a cast could not be generated.
431e62a6956SRiver Riddle static Value materializeConversion(const DialectInlinerInterface *interface,
4325830f71aSRiver Riddle                                    SmallVectorImpl<Operation *> &castOps,
433e62a6956SRiver Riddle                                    OpBuilder &castBuilder, Value arg, Type type,
434e62a6956SRiver Riddle                                    Location conversionLoc) {
4355830f71aSRiver Riddle   if (!interface)
4365830f71aSRiver Riddle     return nullptr;
4375830f71aSRiver Riddle 
4385830f71aSRiver Riddle   // Check to see if the interface for the call can materialize a conversion.
4395830f71aSRiver Riddle   Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
4405830f71aSRiver Riddle                                                            type, conversionLoc);
4415830f71aSRiver Riddle   if (!castOp)
4425830f71aSRiver Riddle     return nullptr;
4435830f71aSRiver Riddle   castOps.push_back(castOp);
4445830f71aSRiver Riddle 
4455830f71aSRiver Riddle   // Ensure that the generated cast is correct.
4465830f71aSRiver Riddle   assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
4475830f71aSRiver Riddle          castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
4485830f71aSRiver Riddle   return castOp->getResult(0);
4495830f71aSRiver Riddle }
4505830f71aSRiver Riddle 
4515830f71aSRiver Riddle /// This function inlines a given region, 'src', of a callable operation,
4525830f71aSRiver Riddle /// 'callable', into the location defined by the given call operation. This
4535830f71aSRiver Riddle /// function returns failure if inlining is not possible, success otherwise. On
4545830f71aSRiver Riddle /// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
4555830f71aSRiver Riddle /// corresponds to whether the source region should be cloned into the 'call' or
4565830f71aSRiver Riddle /// spliced directly.
4575830f71aSRiver Riddle LogicalResult mlir::inlineCall(InlinerInterface &interface,
4585830f71aSRiver Riddle                                CallOpInterface call,
4595830f71aSRiver Riddle                                CallableOpInterface callable, Region *src,
4605830f71aSRiver Riddle                                bool shouldCloneInlinedRegion) {
4615830f71aSRiver Riddle   // We expect the region to have at least one block.
4625830f71aSRiver Riddle   if (src->empty())
4635830f71aSRiver Riddle     return failure();
4645830f71aSRiver Riddle   auto *entryBlock = &src->front();
46534a35a8bSMartin Erhart   ArrayRef<Type> callableResultTypes = callable.getResultTypes();
4665830f71aSRiver Riddle 
4675830f71aSRiver Riddle   // Make sure that the number of arguments and results matchup between the call
4685830f71aSRiver Riddle   // and the region.
469e62a6956SRiver Riddle   SmallVector<Value, 8> callOperands(call.getArgOperands());
470c4a04059SChristian Sigg   SmallVector<Value, 8> callResults(call->getResults());
4715830f71aSRiver Riddle   if (callOperands.size() != entryBlock->getNumArguments() ||
4725830f71aSRiver Riddle       callResults.size() != callableResultTypes.size())
4730ba00878SRiver Riddle     return failure();
4740ba00878SRiver Riddle 
4755830f71aSRiver Riddle   // A set of cast operations generated to matchup the signature of the region
4765830f71aSRiver Riddle   // with the signature of the call.
4775830f71aSRiver Riddle   SmallVector<Operation *, 4> castOps;
4785830f71aSRiver Riddle   castOps.reserve(callOperands.size() + callResults.size());
4790ba00878SRiver Riddle 
4805830f71aSRiver Riddle   // Functor used to cleanup generated state on failure.
4815830f71aSRiver Riddle   auto cleanupState = [&] {
4825830f71aSRiver Riddle     for (auto *op : castOps) {
4832bdf33ccSRiver Riddle       op->getResult(0).replaceAllUsesWith(op->getOperand(0));
4845830f71aSRiver Riddle       op->erase();
4855830f71aSRiver Riddle     }
4860ba00878SRiver Riddle     return failure();
4875830f71aSRiver Riddle   };
4880ba00878SRiver Riddle 
4895830f71aSRiver Riddle   // Builder used for any conversion operations that need to be materialized.
4905830f71aSRiver Riddle   OpBuilder castBuilder(call);
4915830f71aSRiver Riddle   Location castLoc = call.getLoc();
4920bf4a82aSChristian Sigg   const auto *callInterface = interface.getInterfaceFor(call->getDialect());
4935830f71aSRiver Riddle 
4945830f71aSRiver Riddle   // Map the provided call operands to the arguments of the region.
4954d67b278SJeff Niu   IRMapping mapper;
4965830f71aSRiver Riddle   for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
497e62a6956SRiver Riddle     BlockArgument regionArg = entryBlock->getArgument(i);
498e62a6956SRiver Riddle     Value operand = callOperands[i];
4995830f71aSRiver Riddle 
5005830f71aSRiver Riddle     // If the call operand doesn't match the expected region argument, try to
5015830f71aSRiver Riddle     // generate a cast.
5022bdf33ccSRiver Riddle     Type regionArgType = regionArg.getType();
5032bdf33ccSRiver Riddle     if (operand.getType() != regionArgType) {
5045830f71aSRiver Riddle       if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
5055830f71aSRiver Riddle                                             operand, regionArgType, castLoc)))
5065830f71aSRiver Riddle         return cleanupState();
5075830f71aSRiver Riddle     }
5085830f71aSRiver Riddle     mapper.map(regionArg, operand);
5095830f71aSRiver Riddle   }
5105830f71aSRiver Riddle 
511706d992cSRahul Joshi   // Ensure that the resultant values of the call match the callable.
5125830f71aSRiver Riddle   castBuilder.setInsertionPointAfter(call);
5135830f71aSRiver Riddle   for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
514e62a6956SRiver Riddle     Value callResult = callResults[i];
5152bdf33ccSRiver Riddle     if (callResult.getType() == callableResultTypes[i])
5165830f71aSRiver Riddle       continue;
5175830f71aSRiver Riddle 
5185830f71aSRiver Riddle     // Generate a conversion that will produce the original type, so that the IR
5195830f71aSRiver Riddle     // is still valid after the original call gets replaced.
520e62a6956SRiver Riddle     Value castResult =
5215830f71aSRiver Riddle         materializeConversion(callInterface, castOps, castBuilder, callResult,
5222bdf33ccSRiver Riddle                               callResult.getType(), castLoc);
5235830f71aSRiver Riddle     if (!castResult)
5245830f71aSRiver Riddle       return cleanupState();
5252bdf33ccSRiver Riddle     callResult.replaceAllUsesWith(castResult);
5262bdf33ccSRiver Riddle     castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
5275830f71aSRiver Riddle   }
5285830f71aSRiver Riddle 
529501fda01SRiver Riddle   // Check that it is legal to inline the callable into the call.
530fa417479SRiver Riddle   if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion))
531501fda01SRiver Riddle     return cleanupState();
532501fda01SRiver Riddle 
5335830f71aSRiver Riddle   // Attempt to inline the call.
534da12d88bSRiver Riddle   if (failed(inlineRegionImpl(interface, src, call->getBlock(),
535da12d88bSRiver Riddle                               ++call->getIterator(), mapper, callResults,
53622219cfcSSean Silva                               callableResultTypes, call.getLoc(),
5370e760a08SJacques Pienaar                               shouldCloneInlinedRegion, call)))
5385830f71aSRiver Riddle     return cleanupState();
5395830f71aSRiver Riddle   return success();
5400ba00878SRiver Riddle }
541