xref: /llvm-project/mlir/lib/Transforms/Utils/InliningUtils.cpp (revision b39c5cb6977f35ad727d86b2dd6232099734ffd3)
1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous inlining utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/InliningUtils.h"
14 
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/IRMapping.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Interfaces/CallInterfaces.h"
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include <optional>
23 
24 #define DEBUG_TYPE "inlining"
25 
26 using namespace mlir;
27 
28 /// Remap all locations reachable from the inlined blocks with CallSiteLoc
29 /// locations with the provided caller location.
30 static void
31 remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,
32                       Location callerLoc) {
33   DenseMap<Location, LocationAttr> mappedLocations;
34   auto remapLoc = [&](Location loc) {
35     auto [it, inserted] = mappedLocations.try_emplace(loc);
36     // Only query the attribute uniquer once per callsite attribute.
37     if (inserted) {
38       auto newLoc = CallSiteLoc::get(loc, callerLoc);
39       it->getSecond() = newLoc;
40     }
41     return it->second;
42   };
43 
44   AttrTypeReplacer attrReplacer;
45   attrReplacer.addReplacement(
46       [&](LocationAttr loc) -> std::pair<LocationAttr, WalkResult> {
47         return {remapLoc(loc), WalkResult::skip()};
48       });
49 
50   for (Block &block : inlinedBlocks) {
51     for (BlockArgument &arg : block.getArguments())
52       if (LocationAttr newLoc = remapLoc(arg.getLoc()))
53         arg.setLoc(newLoc);
54 
55     for (Operation &op : block)
56       attrReplacer.recursivelyReplaceElementsIn(&op, /*replaceAttrs=*/false,
57                                                 /*replaceLocs=*/true);
58   }
59 }
60 
61 static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
62                                  IRMapping &mapper) {
63   auto remapOperands = [&](Operation *op) {
64     for (auto &operand : op->getOpOperands())
65       if (auto mappedOp = mapper.lookupOrNull(operand.get()))
66         operand.set(mappedOp);
67   };
68   for (auto &block : inlinedBlocks)
69     block.walk(remapOperands);
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // InlinerInterface
74 //===----------------------------------------------------------------------===//
75 
76 bool InlinerInterface::isLegalToInline(Operation *call, Operation *callable,
77                                        bool wouldBeCloned) const {
78   if (auto *handler = getInterfaceFor(call))
79     return handler->isLegalToInline(call, callable, wouldBeCloned);
80   return false;
81 }
82 
83 bool InlinerInterface::isLegalToInline(Region *dest, Region *src,
84                                        bool wouldBeCloned,
85                                        IRMapping &valueMapping) const {
86   if (auto *handler = getInterfaceFor(dest->getParentOp()))
87     return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
88   return false;
89 }
90 
91 bool InlinerInterface::isLegalToInline(Operation *op, Region *dest,
92                                        bool wouldBeCloned,
93                                        IRMapping &valueMapping) const {
94   if (auto *handler = getInterfaceFor(op))
95     return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
96   return false;
97 }
98 
99 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
100   auto *handler = getInterfaceFor(op);
101   return handler ? handler->shouldAnalyzeRecursively(op) : true;
102 }
103 
104 /// Handle the given inlined terminator by replacing it with a new operation
105 /// as necessary.
106 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
107   auto *handler = getInterfaceFor(op);
108   assert(handler && "expected valid dialect handler");
109   handler->handleTerminator(op, newDest);
110 }
111 
112 /// Handle the given inlined terminator by replacing it with a new operation
113 /// as necessary.
114 void InlinerInterface::handleTerminator(Operation *op,
115                                         ValueRange valuesToRepl) const {
116   auto *handler = getInterfaceFor(op);
117   assert(handler && "expected valid dialect handler");
118   handler->handleTerminator(op, valuesToRepl);
119 }
120 
121 /// Returns true if the inliner can assume a fast path of not creating a
122 /// new block, if there is only one block.
123 bool InlinerInterface::allowSingleBlockOptimization(
124     iterator_range<Region::iterator> inlinedBlocks) const {
125   if (inlinedBlocks.empty()) {
126     return true;
127   }
128   auto *handler = getInterfaceFor(inlinedBlocks.begin()->getParentOp());
129   assert(handler && "expected valid dialect handler");
130   return handler->allowSingleBlockOptimization(inlinedBlocks);
131 }
132 
133 Value InlinerInterface::handleArgument(OpBuilder &builder, Operation *call,
134                                        Operation *callable, Value argument,
135                                        DictionaryAttr argumentAttrs) const {
136   auto *handler = getInterfaceFor(callable);
137   assert(handler && "expected valid dialect handler");
138   return handler->handleArgument(builder, call, callable, argument,
139                                  argumentAttrs);
140 }
141 
142 Value InlinerInterface::handleResult(OpBuilder &builder, Operation *call,
143                                      Operation *callable, Value result,
144                                      DictionaryAttr resultAttrs) const {
145   auto *handler = getInterfaceFor(callable);
146   assert(handler && "expected valid dialect handler");
147   return handler->handleResult(builder, call, callable, result, resultAttrs);
148 }
149 
150 void InlinerInterface::processInlinedCallBlocks(
151     Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
152   auto *handler = getInterfaceFor(call);
153   assert(handler && "expected valid dialect handler");
154   handler->processInlinedCallBlocks(call, inlinedBlocks);
155 }
156 
157 /// Utility to check that all of the operations within 'src' can be inlined.
158 static bool isLegalToInline(InlinerInterface &interface, Region *src,
159                             Region *insertRegion, bool shouldCloneInlinedRegion,
160                             IRMapping &valueMapping) {
161   for (auto &block : *src) {
162     for (auto &op : block) {
163       // Check this operation.
164       if (!interface.isLegalToInline(&op, insertRegion,
165                                      shouldCloneInlinedRegion, valueMapping)) {
166         LLVM_DEBUG({
167           llvm::dbgs() << "* Illegal to inline because of op: ";
168           op.dump();
169         });
170         return false;
171       }
172       // Check any nested regions.
173       if (interface.shouldAnalyzeRecursively(&op) &&
174           llvm::any_of(op.getRegions(), [&](Region &region) {
175             return !isLegalToInline(interface, &region, insertRegion,
176                                     shouldCloneInlinedRegion, valueMapping);
177           }))
178         return false;
179     }
180   }
181   return true;
182 }
183 
184 //===----------------------------------------------------------------------===//
185 // Inline Methods
186 //===----------------------------------------------------------------------===//
187 
188 static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder,
189                                CallOpInterface call,
190                                CallableOpInterface callable,
191                                IRMapping &mapper) {
192   // Unpack the argument attributes if there are any.
193   SmallVector<DictionaryAttr> argAttrs(
194       callable.getCallableRegion()->getNumArguments(),
195       builder.getDictionaryAttr({}));
196   if (ArrayAttr arrayAttr = callable.getArgAttrsAttr()) {
197     assert(arrayAttr.size() == argAttrs.size());
198     for (auto [idx, attr] : llvm::enumerate(arrayAttr))
199       argAttrs[idx] = cast<DictionaryAttr>(attr);
200   }
201 
202   // Run the argument attribute handler for the given argument and attribute.
203   for (auto [blockArg, argAttr] :
204        llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) {
205     Value newArgument = interface.handleArgument(
206         builder, call, callable, mapper.lookup(blockArg), argAttr);
207     assert(newArgument.getType() == mapper.lookup(blockArg).getType() &&
208            "expected the argument type to not change");
209 
210     // Update the mapping to point the new argument returned by the handler.
211     mapper.map(blockArg, newArgument);
212   }
213 }
214 
215 static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
216                              CallOpInterface call, CallableOpInterface callable,
217                              ValueRange results) {
218   // Unpack the result attributes if there are any.
219   SmallVector<DictionaryAttr> resAttrs(results.size(),
220                                        builder.getDictionaryAttr({}));
221   if (ArrayAttr arrayAttr = callable.getResAttrsAttr()) {
222     assert(arrayAttr.size() == resAttrs.size());
223     for (auto [idx, attr] : llvm::enumerate(arrayAttr))
224       resAttrs[idx] = cast<DictionaryAttr>(attr);
225   }
226 
227   // Run the result attribute handler for the given result and attribute.
228   SmallVector<DictionaryAttr> resultAttributes;
229   for (auto [result, resAttr] : llvm::zip(results, resAttrs)) {
230     // Store the original result users before running the handler.
231     DenseSet<Operation *> resultUsers;
232     for (Operation *user : result.getUsers())
233       resultUsers.insert(user);
234 
235     Value newResult =
236         interface.handleResult(builder, call, callable, result, resAttr);
237     assert(newResult.getType() == result.getType() &&
238            "expected the result type to not change");
239 
240     // Replace the result uses except for the ones introduce by the handler.
241     result.replaceUsesWithIf(newResult, [&](OpOperand &operand) {
242       return resultUsers.count(operand.getOwner());
243     });
244   }
245 }
246 
/// Core region inliner: inline the blocks of `src` into `inlineBlock` at
/// `inlinePoint`.
///
/// Preconditions and parameters:
///   - `mapper` must already map every entry-block argument of `src`.
///   - `mapper` is also used by `cloneInto` when `shouldCloneInlinedRegion` is
///     true. Otherwise it is used to remap the operands of the blocks spliced
///     out of `src`, which leaves `src` empty.
///   - `resultsToReplace` and `regionResultTypes` must have the same size
///     (asserted). Uses of `resultsToReplace` are rewired to the values the
///     inlined region produces, whose types are `regionResultTypes`.
///   - If `inlineLoc` is present and is not an UnknownLoc, the inlined
///     locations are wrapped as call-site locations of `inlineLoc`.
///   - When `call` is non-null, processInlinedCallBlocks is invoked. If the
///     parent op of `src` is also a CallableOpInterface, the argument and
///     result attribute handlers run too.
///
/// Returns failure in these cases, all of which return before any IR is
/// created or moved:
///   - `src` is empty;
///   - an entry argument is unmapped;
///   - the interface rejects inlining.
static LogicalResult
inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
                 Block::iterator inlinePoint, IRMapping &mapper,
                 ValueRange resultsToReplace, TypeRange regionResultTypes,
                 std::optional<Location> inlineLoc,
                 bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
  assert(resultsToReplace.size() == regionResultTypes.size());
  // We expect the region to have at least one block.
  if (src->empty())
    return failure();

  // Check that all of the region arguments have been mapped.
  auto *srcEntryBlock = &src->front();
  if (llvm::any_of(srcEntryBlock->getArguments(),
                   [&](BlockArgument arg) { return !mapper.contains(arg); }))
    return failure();

  // Check that the operations within the source region are valid to inline.
  Region *insertRegion = inlineBlock->getParent();
  if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
                                 mapper) ||
      !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
                       mapper))
    return failure();

  // Run the argument attribute handler before inlining the callable region.
  // Any ops the handler creates are inserted before `inlinePoint`, so they end
  // up in `inlineBlock` ahead of the inlined body.
  OpBuilder builder(inlineBlock, inlinePoint);
  auto callable = dyn_cast<CallableOpInterface>(src->getParentOp());
  if (call && callable)
    handleArgumentImpl(interface, builder, call, callable, mapper);

  // Split `inlineBlock` at the inline point. Whether the region is cloned or
  // moved, the new blocks are placed between 'inlineBlock' and the split-off
  // 'postInsertBlock' to keep the IR readable.
  Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
  if (shouldCloneInlinedRegion)
    src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
  else
    insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
                                     src->getBlocks(), src->begin(),
                                     src->end());

  // Get the range of newly inserted blocks.
  auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
                                    postInsertBlock->getIterator());
  Block *firstNewBlock = &*newBlocks.begin();

  // Remap the locations of the inlined operations if a valid source location
  // was provided.
  if (inlineLoc && !llvm::isa<UnknownLoc>(*inlineLoc))
    remapInlinedLocations(newBlocks, *inlineLoc);

  // If the blocks were moved in-place, make sure to remap any necessary
  // operands. (Cloning already applied `mapper` while copying.)
  if (!shouldCloneInlinedRegion)
    remapInlinedOperands(newBlocks, mapper);

  // Process the newly inlined blocks.
  if (call)
    interface.processInlinedCallBlocks(call, newBlocks);
  interface.processInlinedBlocks(newBlocks);

  bool singleBlockFastPath = interface.allowSingleBlockOptimization(newBlocks);

  // Handle the case where only a single block was inlined.
  // In the fast path:
  //   - the terminator is resolved directly against `resultsToReplace` and
  //     erased;
  //   - the operations after the inline point are merged back, so no extra
  //     block survives.
  if (singleBlockFastPath && std::next(newBlocks.begin()) == newBlocks.end()) {
    // Run the result attribute handler on the terminator operands.
    Operation *firstBlockTerminator = firstNewBlock->getTerminator();
    builder.setInsertionPoint(firstBlockTerminator);
    if (call && callable)
      handleResultImpl(interface, builder, call, callable,
                       firstBlockTerminator->getOperands());

    // Have the interface handle the terminator of this block.
    interface.handleTerminator(firstBlockTerminator, resultsToReplace);
    firstBlockTerminator->erase();

    // Merge the post insert block into the cloned entry block.
    firstNewBlock->getOperations().splice(firstNewBlock->end(),
                                          postInsertBlock->getOperations());
    postInsertBlock->erase();
  } else {
    // Otherwise, there were multiple blocks inlined (or the interface disabled
    // the fast path). Add arguments to the post insertion block to represent
    // the results to replace, and redirect all uses of those results to them.
    for (const auto &resultToRepl : llvm::enumerate(resultsToReplace)) {
      resultToRepl.value().replaceAllUsesWith(
          postInsertBlock->addArgument(regionResultTypes[resultToRepl.index()],
                                       resultToRepl.value().getLoc()));
    }

    // Run the result attribute handler on the post insertion block arguments.
    builder.setInsertionPointToStart(postInsertBlock);
    if (call && callable)
      handleResultImpl(interface, builder, call, callable,
                       postInsertBlock->getArguments());

    // Handle the terminators for each of the new blocks; the interface is
    // given `postInsertBlock` as the destination to continue to.
    for (auto &newBlock : newBlocks)
      interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
  }

  // Splice the instructions of the inlined entry block into the insert block.
  inlineBlock->getOperations().splice(inlineBlock->end(),
                                      firstNewBlock->getOperations());
  firstNewBlock->erase();
  return success();
}
354 
355 static LogicalResult
356 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
357                  Block::iterator inlinePoint, ValueRange inlinedOperands,
358                  ValueRange resultsToReplace, std::optional<Location> inlineLoc,
359                  bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
360   // We expect the region to have at least one block.
361   if (src->empty())
362     return failure();
363 
364   auto *entryBlock = &src->front();
365   if (inlinedOperands.size() != entryBlock->getNumArguments())
366     return failure();
367 
368   // Map the provided call operands to the arguments of the region.
369   IRMapping mapper;
370   for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
371     // Verify that the types of the provided values match the function argument
372     // types.
373     BlockArgument regionArg = entryBlock->getArgument(i);
374     if (inlinedOperands[i].getType() != regionArg.getType())
375       return failure();
376     mapper.map(regionArg, inlinedOperands[i]);
377   }
378 
379   // Call into the main region inliner function.
380   return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
381                           resultsToReplace, resultsToReplace.getTypes(),
382                           inlineLoc, shouldCloneInlinedRegion, call);
383 }
384 
385 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
386                                  Operation *inlinePoint, IRMapping &mapper,
387                                  ValueRange resultsToReplace,
388                                  TypeRange regionResultTypes,
389                                  std::optional<Location> inlineLoc,
390                                  bool shouldCloneInlinedRegion) {
391   return inlineRegion(interface, src, inlinePoint->getBlock(),
392                       ++inlinePoint->getIterator(), mapper, resultsToReplace,
393                       regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
394 }
395 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
396                                  Block *inlineBlock,
397                                  Block::iterator inlinePoint, IRMapping &mapper,
398                                  ValueRange resultsToReplace,
399                                  TypeRange regionResultTypes,
400                                  std::optional<Location> inlineLoc,
401                                  bool shouldCloneInlinedRegion) {
402   return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
403                           resultsToReplace, regionResultTypes, inlineLoc,
404                           shouldCloneInlinedRegion);
405 }
406 
407 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
408                                  Operation *inlinePoint,
409                                  ValueRange inlinedOperands,
410                                  ValueRange resultsToReplace,
411                                  std::optional<Location> inlineLoc,
412                                  bool shouldCloneInlinedRegion) {
413   return inlineRegion(interface, src, inlinePoint->getBlock(),
414                       ++inlinePoint->getIterator(), inlinedOperands,
415                       resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
416 }
417 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
418                                  Block *inlineBlock,
419                                  Block::iterator inlinePoint,
420                                  ValueRange inlinedOperands,
421                                  ValueRange resultsToReplace,
422                                  std::optional<Location> inlineLoc,
423                                  bool shouldCloneInlinedRegion) {
424   return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,
425                           inlinedOperands, resultsToReplace, inlineLoc,
426                           shouldCloneInlinedRegion);
427 }
428 
429 /// Utility function used to generate a cast operation from the given interface,
430 /// or return nullptr if a cast could not be generated.
431 static Value materializeConversion(const DialectInlinerInterface *interface,
432                                    SmallVectorImpl<Operation *> &castOps,
433                                    OpBuilder &castBuilder, Value arg, Type type,
434                                    Location conversionLoc) {
435   if (!interface)
436     return nullptr;
437 
438   // Check to see if the interface for the call can materialize a conversion.
439   Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
440                                                            type, conversionLoc);
441   if (!castOp)
442     return nullptr;
443   castOps.push_back(castOp);
444 
445   // Ensure that the generated cast is correct.
446   assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
447          castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
448   return castOp->getResult(0);
449 }
450 
/// This function inlines a given region, 'src', of a callable operation,
/// 'callable', at the location defined by the given call operation. It returns
/// failure if inlining is not possible and success otherwise.
///
/// On failure, no changes are made to the module: any cast operations created
/// to reconcile the two signatures are rewired away and erased again.
/// 'shouldCloneInlinedRegion' selects whether the source region is cloned into
/// the 'call' or spliced directly.
///
/// Type mismatches are bridged with casts created by materializeCallConversion
/// on the call dialect's interface. Two kinds of mismatch are handled:
///   - a call operand versus the region's entry argument;
///   - a call result versus the callable's result type.
///
/// This function does not erase `call`.
LogicalResult mlir::inlineCall(InlinerInterface &interface,
                               CallOpInterface call,
                               CallableOpInterface callable, Region *src,
                               bool shouldCloneInlinedRegion) {
  // We expect the region to have at least one block.
  if (src->empty())
    return failure();
  auto *entryBlock = &src->front();
  ArrayRef<Type> callableResultTypes = callable.getResultTypes();

  // Make sure that the number of arguments and results match up between the
  // call and the region.
  SmallVector<Value, 8> callOperands(call.getArgOperands());
  SmallVector<Value, 8> callResults(call->getResults());
  if (callOperands.size() != entryBlock->getNumArguments() ||
      callResults.size() != callableResultTypes.size())
    return failure();

  // A set of cast operations generated to match up the signature of the region
  // with the signature of the call.
  SmallVector<Operation *, 4> castOps;
  castOps.reserve(callOperands.size() + callResults.size());

  // Functor used to clean up generated state on failure. Each cast's uses are
  // redirected back to its input before the cast is erased.
  auto cleanupState = [&] {
    for (auto *op : castOps) {
      op->getResult(0).replaceAllUsesWith(op->getOperand(0));
      op->erase();
    }
    return failure();
  };

  // Builder used for any conversion operations that need to be materialized.
  OpBuilder castBuilder(call);
  Location castLoc = call.getLoc();
  const auto *callInterface = interface.getInterfaceFor(call->getDialect());

  // Map the provided call operands to the arguments of the region.
  IRMapping mapper;
  for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
    BlockArgument regionArg = entryBlock->getArgument(i);
    Value operand = callOperands[i];

    // If the call operand doesn't match the expected region argument, try to
    // generate a cast (inserted before the call).
    Type regionArgType = regionArg.getType();
    if (operand.getType() != regionArgType) {
      if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
                                            operand, regionArgType, castLoc)))
        return cleanupState();
    }
    mapper.map(regionArg, operand);
  }

  // Ensure that the resultant values of the call match the callable.
  castBuilder.setInsertionPointAfter(call);
  for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
    Value callResult = callResults[i];
    if (callResult.getType() == callableResultTypes[i])
      continue;

    // Generate a conversion that will produce the original type, so that the IR
    // is still valid after the original call gets replaced.
    //
    // How the cast is set up:
    //   1. Create a cast of `callResult` to its own type.
    //   2. Route every existing use of `callResult` through the cast. This RAUW
    //      also rewires the cast's own operand to itself.
    //   3. Point the cast's operand back at `callResult`.
    // Inlining later replaces uses of `callResult` (listed in
    // `resultsToReplace`) with the inlined value of type
    // `callableResultTypes[i]`. The cast then converts that value back to the
    // original call result type.
    Value castResult =
        materializeConversion(callInterface, castOps, castBuilder, callResult,
                              callResult.getType(), castLoc);
    if (!castResult)
      return cleanupState();
    callResult.replaceAllUsesWith(castResult);
    castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
  }

  // Check that it is legal to inline the callable into the call.
  if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion))
    return cleanupState();

  // Attempt to inline the call, immediately after the call operation.
  if (failed(inlineRegionImpl(interface, src, call->getBlock(),
                              ++call->getIterator(), mapper, callResults,
                              callableResultTypes, call.getLoc(),
                              shouldCloneInlinedRegion, call)))
    return cleanupState();
  return success();
}
541