1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements miscellaneous inlining utilities. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Transforms/InliningUtils.h" 14 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/IRMapping.h" 17 #include "mlir/IR/Operation.h" 18 #include "mlir/Interfaces/CallInterfaces.h" 19 #include "llvm/ADT/MapVector.h" 20 #include "llvm/Support/Debug.h" 21 #include "llvm/Support/raw_ostream.h" 22 #include <optional> 23 24 #define DEBUG_TYPE "inlining" 25 26 using namespace mlir; 27 28 /// Remap all locations reachable from the inlined blocks with CallSiteLoc 29 /// locations with the provided caller location. 30 static void 31 remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks, 32 Location callerLoc) { 33 DenseMap<Location, LocationAttr> mappedLocations; 34 auto remapLoc = [&](Location loc) { 35 auto [it, inserted] = mappedLocations.try_emplace(loc); 36 // Only query the attribute uniquer once per callsite attribute. 37 if (inserted) { 38 auto newLoc = CallSiteLoc::get(loc, callerLoc); 39 it->getSecond() = newLoc; 40 } 41 return it->second; 42 }; 43 44 AttrTypeReplacer attrReplacer; 45 attrReplacer.addReplacement( 46 [&](LocationAttr loc) -> std::pair<LocationAttr, WalkResult> { 47 return {remapLoc(loc), WalkResult::skip()}; 48 }); 49 50 for (Block &block : inlinedBlocks) { 51 for (BlockArgument &arg : block.getArguments()) 52 if (LocationAttr newLoc = remapLoc(arg.getLoc())) 53 arg.setLoc(newLoc); 54 55 for (Operation &op : block) 56 attrReplacer.recursivelyReplaceElementsIn(&op, /*replaceAttrs=*/false, 57 /*replaceLocs=*/true); 58 } 59 } 60 61 static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks, 62 IRMapping &mapper) { 63 auto remapOperands = [&](Operation *op) { 64 for (auto &operand : op->getOpOperands()) 65 if (auto mappedOp = mapper.lookupOrNull(operand.get())) 66 operand.set(mappedOp); 67 }; 68 for (auto &block : inlinedBlocks) 69 block.walk(remapOperands); 70 } 71 72 //===----------------------------------------------------------------------===// 73 // InlinerInterface 74 //===----------------------------------------------------------------------===// 75 76 bool InlinerInterface::isLegalToInline(Operation *call, Operation *callable, 77 bool wouldBeCloned) const { 78 if (auto *handler = getInterfaceFor(call)) 79 return handler->isLegalToInline(call, callable, wouldBeCloned); 80 return false; 81 } 82 83 bool InlinerInterface::isLegalToInline(Region *dest, Region *src, 84 bool wouldBeCloned, 85 IRMapping &valueMapping) const { 86 if (auto *handler = getInterfaceFor(dest->getParentOp())) 87 return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping); 88 return false; 89 } 90 91 bool InlinerInterface::isLegalToInline(Operation *op, Region *dest, 92 bool wouldBeCloned, 93 IRMapping &valueMapping) const { 94 if (auto *handler = getInterfaceFor(op)) 95 return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping); 96 return false; 97 } 98 99 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const { 100 auto *handler = getInterfaceFor(op); 101 return handler ? handler->shouldAnalyzeRecursively(op) : true; 102 } 103 104 /// Handle the given inlined terminator by replacing it with a new operation 105 /// as necessary. 106 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const { 107 auto *handler = getInterfaceFor(op); 108 assert(handler && "expected valid dialect handler"); 109 handler->handleTerminator(op, newDest); 110 } 111 112 /// Handle the given inlined terminator by replacing it with a new operation 113 /// as necessary. 114 void InlinerInterface::handleTerminator(Operation *op, 115 ValueRange valuesToRepl) const { 116 auto *handler = getInterfaceFor(op); 117 assert(handler && "expected valid dialect handler"); 118 handler->handleTerminator(op, valuesToRepl); 119 } 120 121 /// Returns true if the inliner can assume a fast path of not creating a 122 /// new block, if there is only one block. 123 bool InlinerInterface::allowSingleBlockOptimization( 124 iterator_range<Region::iterator> inlinedBlocks) const { 125 if (inlinedBlocks.empty()) { 126 return true; 127 } 128 auto *handler = getInterfaceFor(inlinedBlocks.begin()->getParentOp()); 129 assert(handler && "expected valid dialect handler"); 130 return handler->allowSingleBlockOptimization(inlinedBlocks); 131 } 132 133 Value InlinerInterface::handleArgument(OpBuilder &builder, Operation *call, 134 Operation *callable, Value argument, 135 DictionaryAttr argumentAttrs) const { 136 auto *handler = getInterfaceFor(callable); 137 assert(handler && "expected valid dialect handler"); 138 return handler->handleArgument(builder, call, callable, argument, 139 argumentAttrs); 140 } 141 142 Value InlinerInterface::handleResult(OpBuilder &builder, Operation *call, 143 Operation *callable, Value result, 144 DictionaryAttr resultAttrs) const { 145 auto *handler = getInterfaceFor(callable); 146 assert(handler && "expected valid dialect handler"); 147 return handler->handleResult(builder, call, callable, result, resultAttrs); 148 } 149 150 void InlinerInterface::processInlinedCallBlocks( 151 Operation *call, iterator_range<Region::iterator> inlinedBlocks) const { 152 auto *handler = getInterfaceFor(call); 153 assert(handler && "expected valid dialect handler"); 154 handler->processInlinedCallBlocks(call, inlinedBlocks); 155 } 156 157 /// Utility to check that all of the operations within 'src' can be inlined. 158 static bool isLegalToInline(InlinerInterface &interface, Region *src, 159 Region *insertRegion, bool shouldCloneInlinedRegion, 160 IRMapping &valueMapping) { 161 for (auto &block : *src) { 162 for (auto &op : block) { 163 // Check this operation. 164 if (!interface.isLegalToInline(&op, insertRegion, 165 shouldCloneInlinedRegion, valueMapping)) { 166 LLVM_DEBUG({ 167 llvm::dbgs() << "* Illegal to inline because of op: "; 168 op.dump(); 169 }); 170 return false; 171 } 172 // Check any nested regions. 173 if (interface.shouldAnalyzeRecursively(&op) && 174 llvm::any_of(op.getRegions(), [&](Region ®ion) { 175 return !isLegalToInline(interface, ®ion, insertRegion, 176 shouldCloneInlinedRegion, valueMapping); 177 })) 178 return false; 179 } 180 } 181 return true; 182 } 183 184 //===----------------------------------------------------------------------===// 185 // Inline Methods 186 //===----------------------------------------------------------------------===// 187 188 static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder, 189 CallOpInterface call, 190 CallableOpInterface callable, 191 IRMapping &mapper) { 192 // Unpack the argument attributes if there are any. 193 SmallVector<DictionaryAttr> argAttrs( 194 callable.getCallableRegion()->getNumArguments(), 195 builder.getDictionaryAttr({})); 196 if (ArrayAttr arrayAttr = callable.getArgAttrsAttr()) { 197 assert(arrayAttr.size() == argAttrs.size()); 198 for (auto [idx, attr] : llvm::enumerate(arrayAttr)) 199 argAttrs[idx] = cast<DictionaryAttr>(attr); 200 } 201 202 // Run the argument attribute handler for the given argument and attribute. 203 for (auto [blockArg, argAttr] : 204 llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) { 205 Value newArgument = interface.handleArgument( 206 builder, call, callable, mapper.lookup(blockArg), argAttr); 207 assert(newArgument.getType() == mapper.lookup(blockArg).getType() && 208 "expected the argument type to not change"); 209 210 // Update the mapping to point the new argument returned by the handler. 211 mapper.map(blockArg, newArgument); 212 } 213 } 214 215 static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder, 216 CallOpInterface call, CallableOpInterface callable, 217 ValueRange results) { 218 // Unpack the result attributes if there are any. 219 SmallVector<DictionaryAttr> resAttrs(results.size(), 220 builder.getDictionaryAttr({})); 221 if (ArrayAttr arrayAttr = callable.getResAttrsAttr()) { 222 assert(arrayAttr.size() == resAttrs.size()); 223 for (auto [idx, attr] : llvm::enumerate(arrayAttr)) 224 resAttrs[idx] = cast<DictionaryAttr>(attr); 225 } 226 227 // Run the result attribute handler for the given result and attribute. 228 SmallVector<DictionaryAttr> resultAttributes; 229 for (auto [result, resAttr] : llvm::zip(results, resAttrs)) { 230 // Store the original result users before running the handler. 231 DenseSet<Operation *> resultUsers; 232 for (Operation *user : result.getUsers()) 233 resultUsers.insert(user); 234 235 Value newResult = 236 interface.handleResult(builder, call, callable, result, resAttr); 237 assert(newResult.getType() == result.getType() && 238 "expected the result type to not change"); 239 240 // Replace the result uses except for the ones introduce by the handler. 241 result.replaceUsesWithIf(newResult, [&](OpOperand &operand) { 242 return resultUsers.count(operand.getOwner()); 243 }); 244 } 245 } 246 247 static LogicalResult 248 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, 249 Block::iterator inlinePoint, IRMapping &mapper, 250 ValueRange resultsToReplace, TypeRange regionResultTypes, 251 std::optional<Location> inlineLoc, 252 bool shouldCloneInlinedRegion, CallOpInterface call = {}) { 253 assert(resultsToReplace.size() == regionResultTypes.size()); 254 // We expect the region to have at least one block. 255 if (src->empty()) 256 return failure(); 257 258 // Check that all of the region arguments have been mapped. 259 auto *srcEntryBlock = &src->front(); 260 if (llvm::any_of(srcEntryBlock->getArguments(), 261 [&](BlockArgument arg) { return !mapper.contains(arg); })) 262 return failure(); 263 264 // Check that the operations within the source region are valid to inline. 265 Region *insertRegion = inlineBlock->getParent(); 266 if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion, 267 mapper) || 268 !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion, 269 mapper)) 270 return failure(); 271 272 // Run the argument attribute handler before inlining the callable region. 273 OpBuilder builder(inlineBlock, inlinePoint); 274 auto callable = dyn_cast<CallableOpInterface>(src->getParentOp()); 275 if (call && callable) 276 handleArgumentImpl(interface, builder, call, callable, mapper); 277 278 // Check to see if the region is being cloned, or moved inline. In either 279 // case, move the new blocks after the 'insertBlock' to improve IR 280 // readability. 281 Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint); 282 if (shouldCloneInlinedRegion) 283 src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper); 284 else 285 insertRegion->getBlocks().splice(postInsertBlock->getIterator(), 286 src->getBlocks(), src->begin(), 287 src->end()); 288 289 // Get the range of newly inserted blocks. 290 auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()), 291 postInsertBlock->getIterator()); 292 Block *firstNewBlock = &*newBlocks.begin(); 293 294 // Remap the locations of the inlined operations if a valid source location 295 // was provided. 296 if (inlineLoc && !llvm::isa<UnknownLoc>(*inlineLoc)) 297 remapInlinedLocations(newBlocks, *inlineLoc); 298 299 // If the blocks were moved in-place, make sure to remap any necessary 300 // operands. 301 if (!shouldCloneInlinedRegion) 302 remapInlinedOperands(newBlocks, mapper); 303 304 // Process the newly inlined blocks. 305 if (call) 306 interface.processInlinedCallBlocks(call, newBlocks); 307 interface.processInlinedBlocks(newBlocks); 308 309 bool singleBlockFastPath = interface.allowSingleBlockOptimization(newBlocks); 310 311 // Handle the case where only a single block was inlined. 312 if (singleBlockFastPath && std::next(newBlocks.begin()) == newBlocks.end()) { 313 // Run the result attribute handler on the terminator operands. 314 Operation *firstBlockTerminator = firstNewBlock->getTerminator(); 315 builder.setInsertionPoint(firstBlockTerminator); 316 if (call && callable) 317 handleResultImpl(interface, builder, call, callable, 318 firstBlockTerminator->getOperands()); 319 320 // Have the interface handle the terminator of this block. 321 interface.handleTerminator(firstBlockTerminator, resultsToReplace); 322 firstBlockTerminator->erase(); 323 324 // Merge the post insert block into the cloned entry block. 325 firstNewBlock->getOperations().splice(firstNewBlock->end(), 326 postInsertBlock->getOperations()); 327 postInsertBlock->erase(); 328 } else { 329 // Otherwise, there were multiple blocks inlined. Add arguments to the post 330 // insertion block to represent the results to replace. 331 for (const auto &resultToRepl : llvm::enumerate(resultsToReplace)) { 332 resultToRepl.value().replaceAllUsesWith( 333 postInsertBlock->addArgument(regionResultTypes[resultToRepl.index()], 334 resultToRepl.value().getLoc())); 335 } 336 337 // Run the result attribute handler on the post insertion block arguments. 338 builder.setInsertionPointToStart(postInsertBlock); 339 if (call && callable) 340 handleResultImpl(interface, builder, call, callable, 341 postInsertBlock->getArguments()); 342 343 /// Handle the terminators for each of the new blocks. 344 for (auto &newBlock : newBlocks) 345 interface.handleTerminator(newBlock.getTerminator(), postInsertBlock); 346 } 347 348 // Splice the instructions of the inlined entry block into the insert block. 349 inlineBlock->getOperations().splice(inlineBlock->end(), 350 firstNewBlock->getOperations()); 351 firstNewBlock->erase(); 352 return success(); 353 } 354 355 static LogicalResult 356 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, 357 Block::iterator inlinePoint, ValueRange inlinedOperands, 358 ValueRange resultsToReplace, std::optional<Location> inlineLoc, 359 bool shouldCloneInlinedRegion, CallOpInterface call = {}) { 360 // We expect the region to have at least one block. 361 if (src->empty()) 362 return failure(); 363 364 auto *entryBlock = &src->front(); 365 if (inlinedOperands.size() != entryBlock->getNumArguments()) 366 return failure(); 367 368 // Map the provided call operands to the arguments of the region. 369 IRMapping mapper; 370 for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) { 371 // Verify that the types of the provided values match the function argument 372 // types. 373 BlockArgument regionArg = entryBlock->getArgument(i); 374 if (inlinedOperands[i].getType() != regionArg.getType()) 375 return failure(); 376 mapper.map(regionArg, inlinedOperands[i]); 377 } 378 379 // Call into the main region inliner function. 380 return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper, 381 resultsToReplace, resultsToReplace.getTypes(), 382 inlineLoc, shouldCloneInlinedRegion, call); 383 } 384 385 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 386 Operation *inlinePoint, IRMapping &mapper, 387 ValueRange resultsToReplace, 388 TypeRange regionResultTypes, 389 std::optional<Location> inlineLoc, 390 bool shouldCloneInlinedRegion) { 391 return inlineRegion(interface, src, inlinePoint->getBlock(), 392 ++inlinePoint->getIterator(), mapper, resultsToReplace, 393 regionResultTypes, inlineLoc, shouldCloneInlinedRegion); 394 } 395 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 396 Block *inlineBlock, 397 Block::iterator inlinePoint, IRMapping &mapper, 398 ValueRange resultsToReplace, 399 TypeRange regionResultTypes, 400 std::optional<Location> inlineLoc, 401 bool shouldCloneInlinedRegion) { 402 return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper, 403 resultsToReplace, regionResultTypes, inlineLoc, 404 shouldCloneInlinedRegion); 405 } 406 407 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 408 Operation *inlinePoint, 409 ValueRange inlinedOperands, 410 ValueRange resultsToReplace, 411 std::optional<Location> inlineLoc, 412 bool shouldCloneInlinedRegion) { 413 return inlineRegion(interface, src, inlinePoint->getBlock(), 414 ++inlinePoint->getIterator(), inlinedOperands, 415 resultsToReplace, inlineLoc, shouldCloneInlinedRegion); 416 } 417 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 418 Block *inlineBlock, 419 Block::iterator inlinePoint, 420 ValueRange inlinedOperands, 421 ValueRange resultsToReplace, 422 std::optional<Location> inlineLoc, 423 bool shouldCloneInlinedRegion) { 424 return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, 425 inlinedOperands, resultsToReplace, inlineLoc, 426 shouldCloneInlinedRegion); 427 } 428 429 /// Utility function used to generate a cast operation from the given interface, 430 /// or return nullptr if a cast could not be generated. 431 static Value materializeConversion(const DialectInlinerInterface *interface, 432 SmallVectorImpl<Operation *> &castOps, 433 OpBuilder &castBuilder, Value arg, Type type, 434 Location conversionLoc) { 435 if (!interface) 436 return nullptr; 437 438 // Check to see if the interface for the call can materialize a conversion. 439 Operation *castOp = interface->materializeCallConversion(castBuilder, arg, 440 type, conversionLoc); 441 if (!castOp) 442 return nullptr; 443 castOps.push_back(castOp); 444 445 // Ensure that the generated cast is correct. 446 assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg && 447 castOp->getNumResults() == 1 && *castOp->result_type_begin() == type); 448 return castOp->getResult(0); 449 } 450 451 /// This function inlines a given region, 'src', of a callable operation, 452 /// 'callable', into the location defined by the given call operation. This 453 /// function returns failure if inlining is not possible, success otherwise. On 454 /// failure, no changes are made to the module. 'shouldCloneInlinedRegion' 455 /// corresponds to whether the source region should be cloned into the 'call' or 456 /// spliced directly. 457 LogicalResult mlir::inlineCall(InlinerInterface &interface, 458 CallOpInterface call, 459 CallableOpInterface callable, Region *src, 460 bool shouldCloneInlinedRegion) { 461 // We expect the region to have at least one block. 462 if (src->empty()) 463 return failure(); 464 auto *entryBlock = &src->front(); 465 ArrayRef<Type> callableResultTypes = callable.getResultTypes(); 466 467 // Make sure that the number of arguments and results matchup between the call 468 // and the region. 469 SmallVector<Value, 8> callOperands(call.getArgOperands()); 470 SmallVector<Value, 8> callResults(call->getResults()); 471 if (callOperands.size() != entryBlock->getNumArguments() || 472 callResults.size() != callableResultTypes.size()) 473 return failure(); 474 475 // A set of cast operations generated to matchup the signature of the region 476 // with the signature of the call. 477 SmallVector<Operation *, 4> castOps; 478 castOps.reserve(callOperands.size() + callResults.size()); 479 480 // Functor used to cleanup generated state on failure. 481 auto cleanupState = [&] { 482 for (auto *op : castOps) { 483 op->getResult(0).replaceAllUsesWith(op->getOperand(0)); 484 op->erase(); 485 } 486 return failure(); 487 }; 488 489 // Builder used for any conversion operations that need to be materialized. 490 OpBuilder castBuilder(call); 491 Location castLoc = call.getLoc(); 492 const auto *callInterface = interface.getInterfaceFor(call->getDialect()); 493 494 // Map the provided call operands to the arguments of the region. 495 IRMapping mapper; 496 for (unsigned i = 0, e = callOperands.size(); i != e; ++i) { 497 BlockArgument regionArg = entryBlock->getArgument(i); 498 Value operand = callOperands[i]; 499 500 // If the call operand doesn't match the expected region argument, try to 501 // generate a cast. 502 Type regionArgType = regionArg.getType(); 503 if (operand.getType() != regionArgType) { 504 if (!(operand = materializeConversion(callInterface, castOps, castBuilder, 505 operand, regionArgType, castLoc))) 506 return cleanupState(); 507 } 508 mapper.map(regionArg, operand); 509 } 510 511 // Ensure that the resultant values of the call match the callable. 512 castBuilder.setInsertionPointAfter(call); 513 for (unsigned i = 0, e = callResults.size(); i != e; ++i) { 514 Value callResult = callResults[i]; 515 if (callResult.getType() == callableResultTypes[i]) 516 continue; 517 518 // Generate a conversion that will produce the original type, so that the IR 519 // is still valid after the original call gets replaced. 520 Value castResult = 521 materializeConversion(callInterface, castOps, castBuilder, callResult, 522 callResult.getType(), castLoc); 523 if (!castResult) 524 return cleanupState(); 525 callResult.replaceAllUsesWith(castResult); 526 castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult); 527 } 528 529 // Check that it is legal to inline the callable into the call. 530 if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion)) 531 return cleanupState(); 532 533 // Attempt to inline the call. 534 if (failed(inlineRegionImpl(interface, src, call->getBlock(), 535 ++call->getIterator(), mapper, callResults, 536 callableResultTypes, call.getLoc(), 537 shouldCloneInlinedRegion, call))) 538 return cleanupState(); 539 return success(); 540 } 541