1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/DialectConversion.h" 10 #include "mlir/Config/mlir-config.h" 11 #include "mlir/IR/Block.h" 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/BuiltinOps.h" 14 #include "mlir/IR/Dominance.h" 15 #include "mlir/IR/IRMapping.h" 16 #include "mlir/IR/Iterators.h" 17 #include "mlir/Interfaces/FunctionInterfaces.h" 18 #include "mlir/Rewrite/PatternApplicator.h" 19 #include "llvm/ADT/ScopeExit.h" 20 #include "llvm/ADT/SetVector.h" 21 #include "llvm/ADT/SmallPtrSet.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/FormatVariadic.h" 24 #include "llvm/Support/SaveAndRestore.h" 25 #include "llvm/Support/ScopedPrinter.h" 26 #include <optional> 27 28 using namespace mlir; 29 using namespace mlir::detail; 30 31 #define DEBUG_TYPE "dialect-conversion" 32 33 /// A utility function to log a successful result for the given reason. 34 template <typename... Args> 35 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { 36 LLVM_DEBUG({ 37 os.unindent(); 38 os.startLine() << "} -> SUCCESS"; 39 if (!fmt.empty()) 40 os.getOStream() << " : " 41 << llvm::formatv(fmt.data(), std::forward<Args>(args)...); 42 os.getOStream() << "\n"; 43 }); 44 } 45 46 /// A utility function to log a failure result for the given reason. 47 template <typename... Args> 48 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { 49 LLVM_DEBUG({ 50 os.unindent(); 51 os.startLine() << "} -> FAILURE : " 52 << llvm::formatv(fmt.data(), std::forward<Args>(args)...) 
53 << "\n"; 54 }); 55 } 56 57 /// Helper function that computes an insertion point where the given value is 58 /// defined and can be used without a dominance violation. 59 static OpBuilder::InsertPoint computeInsertPoint(Value value) { 60 Block *insertBlock = value.getParentBlock(); 61 Block::iterator insertPt = insertBlock->begin(); 62 if (OpResult inputRes = dyn_cast<OpResult>(value)) 63 insertPt = ++inputRes.getOwner()->getIterator(); 64 return OpBuilder::InsertPoint(insertBlock, insertPt); 65 } 66 67 /// Helper function that computes an insertion point where the given values are 68 /// defined and can be used without a dominance violation. 69 static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) { 70 assert(!vals.empty() && "expected at least one value"); 71 DominanceInfo domInfo; 72 OpBuilder::InsertPoint pt = computeInsertPoint(vals.front()); 73 for (Value v : vals.drop_front()) { 74 // Choose the "later" insertion point. 75 OpBuilder::InsertPoint nextPt = computeInsertPoint(v); 76 if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(), 77 nextPt.getPoint())) { 78 // pt is before nextPt => choose nextPt. 79 pt = nextPt; 80 } else { 81 #ifndef NDEBUG 82 // nextPt should be before pt => choose pt. 83 // If pt, nextPt are no dominance relationship, then there is no valid 84 // insertion point at which all given values are defined. 85 bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(), 86 pt.getBlock(), pt.getPoint()); 87 assert(dom && "unable to find valid insertion point"); 88 #endif // NDEBUG 89 } 90 } 91 return pt; 92 } 93 94 //===----------------------------------------------------------------------===// 95 // ConversionValueMapping 96 //===----------------------------------------------------------------------===// 97 98 /// A vector of SSA values, optimized for the most common case of a single 99 /// value. 
100 using ValueVector = SmallVector<Value, 1>; 101 102 namespace { 103 104 /// Helper class to make it possible to use `ValueVector` as a key in DenseMap. 105 struct ValueVectorMapInfo { 106 static ValueVector getEmptyKey() { return ValueVector{Value()}; } 107 static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; } 108 static ::llvm::hash_code getHashValue(const ValueVector &val) { 109 return ::llvm::hash_combine_range(val.begin(), val.end()); 110 } 111 static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) { 112 return LHS == RHS; 113 } 114 }; 115 116 /// This class wraps a IRMapping to provide recursive lookup 117 /// functionality, i.e. we will traverse if the mapped value also has a mapping. 118 struct ConversionValueMapping { 119 /// Return "true" if an SSA value is mapped to the given value. May return 120 /// false positives. 121 bool isMappedTo(Value value) const { return mappedTo.contains(value); } 122 123 /// Lookup the most recently mapped values with the desired types in the 124 /// mapping. 125 /// 126 /// Special cases: 127 /// - If the desired type range is empty, simply return the most recently 128 /// mapped values. 129 /// - If there is no mapping to the desired types, also return the most 130 /// recently mapped values. 131 /// - If there is no mapping for the given values at all, return the given 132 /// value. 133 ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; 134 135 /// Lookup the given value within the map, or return an empty vector if the 136 /// value is not mapped. If it is mapped, this follows the same behavior 137 /// as `lookupOrDefault`. 138 ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const; 139 140 template <typename T> 141 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {}; 142 143 /// Map a value vector to the one provided. 
144 template <typename OldVal, typename NewVal> 145 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value> 146 map(OldVal &&oldVal, NewVal &&newVal) { 147 LLVM_DEBUG({ 148 ValueVector next(newVal); 149 while (true) { 150 assert(next != oldVal && "inserting cyclic mapping"); 151 auto it = mapping.find(next); 152 if (it == mapping.end()) 153 break; 154 next = it->second; 155 } 156 }); 157 for (Value v : newVal) 158 mappedTo.insert(v); 159 160 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal); 161 } 162 163 /// Map a value vector or single value to the one provided. 164 template <typename OldVal, typename NewVal> 165 std::enable_if_t<!IsValueVector<OldVal>::value || 166 !IsValueVector<NewVal>::value> 167 map(OldVal &&oldVal, NewVal &&newVal) { 168 if constexpr (IsValueVector<OldVal>{}) { 169 map(std::forward<OldVal>(oldVal), ValueVector{newVal}); 170 } else if constexpr (IsValueVector<NewVal>{}) { 171 map(ValueVector{oldVal}, std::forward<NewVal>(newVal)); 172 } else { 173 map(ValueVector{oldVal}, ValueVector{newVal}); 174 } 175 } 176 177 /// Drop the last mapping for the given values. 178 void erase(const ValueVector &value) { mapping.erase(value); } 179 180 private: 181 /// Current value mappings. 182 DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> mapping; 183 184 /// All SSA values that are mapped to. May contain false positives. 185 DenseSet<Value> mappedTo; 186 }; 187 } // namespace 188 189 ValueVector 190 ConversionValueMapping::lookupOrDefault(Value from, 191 TypeRange desiredTypes) const { 192 // Try to find the deepest values that have the desired types. If there is no 193 // such mapping, simply return the deepest values. 194 ValueVector desiredValue; 195 ValueVector current{from}; 196 do { 197 // Store the current value if the types match. 198 if (TypeRange(ValueRange(current)) == desiredTypes) 199 desiredValue = current; 200 201 // If possible, Replace each value with (one or multiple) mapped values. 
202 ValueVector next; 203 for (Value v : current) { 204 auto it = mapping.find({v}); 205 if (it != mapping.end()) { 206 llvm::append_range(next, it->second); 207 } else { 208 next.push_back(v); 209 } 210 } 211 if (next != current) { 212 // If at least one value was replaced, continue the lookup from there. 213 current = std::move(next); 214 continue; 215 } 216 217 // Otherwise: Check if there is a mapping for the entire vector. Such 218 // mappings are materializations. (N:M mapping are not supported for value 219 // replacements.) 220 // 221 // Note: From a correctness point of view, materializations do not have to 222 // be stored (and looked up) in the mapping. But for performance reasons, 223 // we choose to reuse existing IR (when possible) instead of creating it 224 // multiple times. 225 auto it = mapping.find(current); 226 if (it == mapping.end()) { 227 // No mapping found: The lookup stops here. 228 break; 229 } 230 current = it->second; 231 } while (true); 232 233 // If the desired values were found use them, otherwise default to the leaf 234 // values. 235 // Note: If `desiredTypes` is empty, this function always returns `current`. 236 return !desiredValue.empty() ? std::move(desiredValue) : std::move(current); 237 } 238 239 ValueVector ConversionValueMapping::lookupOrNull(Value from, 240 TypeRange desiredTypes) const { 241 ValueVector result = lookupOrDefault(from, desiredTypes); 242 if (result == ValueVector{from} || 243 (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes)) 244 return {}; 245 return result; 246 } 247 248 //===----------------------------------------------------------------------===// 249 // Rewriter and Translation State 250 //===----------------------------------------------------------------------===// 251 namespace { 252 /// This class contains a snapshot of the current conversion rewriter state. 253 /// This is useful when saving and undoing a set of rewrites. 
254 struct RewriterState { 255 RewriterState(unsigned numRewrites, unsigned numIgnoredOperations, 256 unsigned numReplacedOps) 257 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations), 258 numReplacedOps(numReplacedOps) {} 259 260 /// The current number of rewrites performed. 261 unsigned numRewrites; 262 263 /// The current number of ignored operations. 264 unsigned numIgnoredOperations; 265 266 /// The current number of replaced ops that are scheduled for erasure. 267 unsigned numReplacedOps; 268 }; 269 270 //===----------------------------------------------------------------------===// 271 // IR rewrites 272 //===----------------------------------------------------------------------===// 273 274 /// An IR rewrite that can be committed (upon success) or rolled back (upon 275 /// failure). 276 /// 277 /// The dialect conversion keeps track of IR modifications (requested by the 278 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites 279 /// are directly applied to the IR as the rewriter API is used, some are applied 280 /// partially, and some are delayed until the `IRRewrite` objects are committed. 281 class IRRewrite { 282 public: 283 /// The kind of the rewrite. Rewrites can be undone if the conversion fails. 284 /// Enum values are ordered, so that they can be used in `classof`: first all 285 /// block rewrites, then all operation rewrites. 286 enum class Kind { 287 // Block rewrites 288 CreateBlock, 289 EraseBlock, 290 InlineBlock, 291 MoveBlock, 292 BlockTypeConversion, 293 ReplaceBlockArg, 294 // Operation rewrites 295 MoveOperation, 296 ModifyOperation, 297 ReplaceOperation, 298 CreateOperation, 299 UnresolvedMaterialization 300 }; 301 302 virtual ~IRRewrite() = default; 303 304 /// Roll back the rewrite. Operations may be erased during rollback. 305 virtual void rollback() = 0; 306 307 /// Commit the rewrite. At this point, it is certain that the dialect 308 /// conversion will succeed. 
All IR modifications, except for operation/block 309 /// erasure, must be performed through the given rewriter. 310 /// 311 /// Instead of erasing operations/blocks, they should merely be unlinked 312 /// commit phase and finally be erased during the cleanup phase. This is 313 /// because internal dialect conversion state (such as `mapping`) may still 314 /// be using them. 315 /// 316 /// Any IR modification that was already performed before the commit phase 317 /// (e.g., insertion of an op) must be communicated to the listener that may 318 /// be attached to the given rewriter. 319 virtual void commit(RewriterBase &rewriter) {} 320 321 /// Cleanup operations/blocks. Cleanup is called after commit. 322 virtual void cleanup(RewriterBase &rewriter) {} 323 324 Kind getKind() const { return kind; } 325 326 static bool classof(const IRRewrite *rewrite) { return true; } 327 328 protected: 329 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl) 330 : kind(kind), rewriterImpl(rewriterImpl) {} 331 332 const ConversionConfig &getConfig() const; 333 334 const Kind kind; 335 ConversionPatternRewriterImpl &rewriterImpl; 336 }; 337 338 /// A block rewrite. 339 class BlockRewrite : public IRRewrite { 340 public: 341 /// Return the block that this rewrite operates on. 342 Block *getBlock() const { return block; } 343 344 static bool classof(const IRRewrite *rewrite) { 345 return rewrite->getKind() >= Kind::CreateBlock && 346 rewrite->getKind() <= Kind::ReplaceBlockArg; 347 } 348 349 protected: 350 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl, 351 Block *block) 352 : IRRewrite(kind, rewriterImpl), block(block) {} 353 354 // The block that this rewrite operates on. 355 Block *block; 356 }; 357 358 /// Creation of a block. Block creations are immediately reflected in the IR. 359 /// There is no extra work to commit the rewrite. During rollback, the newly 360 /// created block is erased. 
361 class CreateBlockRewrite : public BlockRewrite { 362 public: 363 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block) 364 : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {} 365 366 static bool classof(const IRRewrite *rewrite) { 367 return rewrite->getKind() == Kind::CreateBlock; 368 } 369 370 void commit(RewriterBase &rewriter) override { 371 // The block was already created and inserted. Just inform the listener. 372 if (auto *listener = rewriter.getListener()) 373 listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{}); 374 } 375 376 void rollback() override { 377 // Unlink all of the operations within this block, they will be deleted 378 // separately. 379 auto &blockOps = block->getOperations(); 380 while (!blockOps.empty()) 381 blockOps.remove(blockOps.begin()); 382 block->dropAllUses(); 383 if (block->getParent()) 384 block->erase(); 385 else 386 delete block; 387 } 388 }; 389 390 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased 391 /// blocks are immediately unlinked, but only erased during cleanup. This makes 392 /// it easier to rollback a block erasure: the block is simply inserted into its 393 /// original location. 394 class EraseBlockRewrite : public BlockRewrite { 395 public: 396 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block) 397 : BlockRewrite(Kind::EraseBlock, rewriterImpl, block), 398 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {} 399 400 static bool classof(const IRRewrite *rewrite) { 401 return rewrite->getKind() == Kind::EraseBlock; 402 } 403 404 ~EraseBlockRewrite() override { 405 assert(!block && 406 "rewrite was neither rolled back nor committed/cleaned up"); 407 } 408 409 void rollback() override { 410 // The block (owned by this rewrite) was not actually erased yet. It was 411 // just unlinked. Put it back into its original position. 
412 assert(block && "expected block"); 413 auto &blockList = region->getBlocks(); 414 Region::iterator before = insertBeforeBlock 415 ? Region::iterator(insertBeforeBlock) 416 : blockList.end(); 417 blockList.insert(before, block); 418 block = nullptr; 419 } 420 421 void commit(RewriterBase &rewriter) override { 422 // Erase the block. 423 assert(block && "expected block"); 424 assert(block->empty() && "expected empty block"); 425 426 // Notify the listener that the block is about to be erased. 427 if (auto *listener = 428 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener())) 429 listener->notifyBlockErased(block); 430 } 431 432 void cleanup(RewriterBase &rewriter) override { 433 // Erase the block. 434 block->dropAllDefinedValueUses(); 435 delete block; 436 block = nullptr; 437 } 438 439 private: 440 // The region in which this block was previously contained. 441 Region *region; 442 443 // The original successor of this block before it was unlinked. "nullptr" if 444 // this block was the only block in the region. 445 Block *insertBeforeBlock; 446 }; 447 448 /// Inlining of a block. This rewrite is immediately reflected in the IR. 449 /// Note: This rewrite represents only the inlining of the operations. The 450 /// erasure of the inlined block is a separate rewrite. 451 class InlineBlockRewrite : public BlockRewrite { 452 public: 453 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, 454 Block *sourceBlock, Block::iterator before) 455 : BlockRewrite(Kind::InlineBlock, rewriterImpl, block), 456 sourceBlock(sourceBlock), 457 firstInlinedInst(sourceBlock->empty() ? nullptr 458 : &sourceBlock->front()), 459 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) { 460 // If a listener is attached to the dialect conversion, ops must be moved 461 // one-by-one. 
When they are moved in bulk, notifications cannot be sent 462 // because the ops that used to be in the source block at the time of the 463 // inlining (before the "commit" phase) are unknown at the time when 464 // notifications are sent (which is during the "commit" phase). 465 assert(!getConfig().listener && 466 "InlineBlockRewrite not supported if listener is attached"); 467 } 468 469 static bool classof(const IRRewrite *rewrite) { 470 return rewrite->getKind() == Kind::InlineBlock; 471 } 472 473 void rollback() override { 474 // Put the operations from the destination block (owned by the rewrite) 475 // back into the source block. 476 if (firstInlinedInst) { 477 assert(lastInlinedInst && "expected operation"); 478 sourceBlock->getOperations().splice(sourceBlock->begin(), 479 block->getOperations(), 480 Block::iterator(firstInlinedInst), 481 ++Block::iterator(lastInlinedInst)); 482 } 483 } 484 485 private: 486 // The block that originally contained the operations. 487 Block *sourceBlock; 488 489 // The first inlined operation. 490 Operation *firstInlinedInst; 491 492 // The last inlined operation. 493 Operation *lastInlinedInst; 494 }; 495 496 /// Moving of a block. This rewrite is immediately reflected in the IR. 497 class MoveBlockRewrite : public BlockRewrite { 498 public: 499 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, 500 Region *region, Block *insertBeforeBlock) 501 : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region), 502 insertBeforeBlock(insertBeforeBlock) {} 503 504 static bool classof(const IRRewrite *rewrite) { 505 return rewrite->getKind() == Kind::MoveBlock; 506 } 507 508 void commit(RewriterBase &rewriter) override { 509 // The block was already moved. Just inform the listener. 510 if (auto *listener = rewriter.getListener()) { 511 // Note: `previousIt` cannot be passed because this is a delayed 512 // notification and iterators into past IR state cannot be represented. 
513 listener->notifyBlockInserted(block, /*previous=*/region, 514 /*previousIt=*/{}); 515 } 516 } 517 518 void rollback() override { 519 // Move the block back to its original position. 520 Region::iterator before = 521 insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end(); 522 region->getBlocks().splice(before, block->getParent()->getBlocks(), block); 523 } 524 525 private: 526 // The region in which this block was previously contained. 527 Region *region; 528 529 // The original successor of this block before it was moved. "nullptr" if 530 // this block was the only block in the region. 531 Block *insertBeforeBlock; 532 }; 533 534 /// Block type conversion. This rewrite is partially reflected in the IR. 535 class BlockTypeConversionRewrite : public BlockRewrite { 536 public: 537 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl, 538 Block *origBlock, Block *newBlock) 539 : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock), 540 newBlock(newBlock) {} 541 542 static bool classof(const IRRewrite *rewrite) { 543 return rewrite->getKind() == Kind::BlockTypeConversion; 544 } 545 546 Block *getOrigBlock() const { return block; } 547 548 Block *getNewBlock() const { return newBlock; } 549 550 void commit(RewriterBase &rewriter) override; 551 552 void rollback() override; 553 554 private: 555 /// The new block that was created as part of this signature conversion. 556 Block *newBlock; 557 }; 558 559 /// Replacing a block argument. This rewrite is not immediately reflected in the 560 /// IR. An internal IR mapping is updated, but the actual replacement is delayed 561 /// until the rewrite is committed. 
562 class ReplaceBlockArgRewrite : public BlockRewrite { 563 public: 564 ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl, 565 Block *block, BlockArgument arg, 566 const TypeConverter *converter) 567 : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg), 568 converter(converter) {} 569 570 static bool classof(const IRRewrite *rewrite) { 571 return rewrite->getKind() == Kind::ReplaceBlockArg; 572 } 573 574 void commit(RewriterBase &rewriter) override; 575 576 void rollback() override; 577 578 private: 579 BlockArgument arg; 580 581 /// The current type converter when the block argument was replaced. 582 const TypeConverter *converter; 583 }; 584 585 /// An operation rewrite. 586 class OperationRewrite : public IRRewrite { 587 public: 588 /// Return the operation that this rewrite operates on. 589 Operation *getOperation() const { return op; } 590 591 static bool classof(const IRRewrite *rewrite) { 592 return rewrite->getKind() >= Kind::MoveOperation && 593 rewrite->getKind() <= Kind::UnresolvedMaterialization; 594 } 595 596 protected: 597 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl, 598 Operation *op) 599 : IRRewrite(kind, rewriterImpl), op(op) {} 600 601 // The operation that this rewrite operates on. 602 Operation *op; 603 }; 604 605 /// Moving of an operation. This rewrite is immediately reflected in the IR. 606 class MoveOperationRewrite : public OperationRewrite { 607 public: 608 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, 609 Operation *op, Block *block, Operation *insertBeforeOp) 610 : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block), 611 insertBeforeOp(insertBeforeOp) {} 612 613 static bool classof(const IRRewrite *rewrite) { 614 return rewrite->getKind() == Kind::MoveOperation; 615 } 616 617 void commit(RewriterBase &rewriter) override { 618 // The operation was already moved. Just inform the listener. 
619 if (auto *listener = rewriter.getListener()) { 620 // Note: `previousIt` cannot be passed because this is a delayed 621 // notification and iterators into past IR state cannot be represented. 622 listener->notifyOperationInserted( 623 op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block, 624 /*insertPt=*/{})); 625 } 626 } 627 628 void rollback() override { 629 // Move the operation back to its original position. 630 Block::iterator before = 631 insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end(); 632 block->getOperations().splice(before, op->getBlock()->getOperations(), op); 633 } 634 635 private: 636 // The block in which this operation was previously contained. 637 Block *block; 638 639 // The original successor of this operation before it was moved. "nullptr" 640 // if this operation was the only operation in the region. 641 Operation *insertBeforeOp; 642 }; 643 644 /// In-place modification of an op. This rewrite is immediately reflected in 645 /// the IR. The previous state of the operation is stored in this object. 646 class ModifyOperationRewrite : public OperationRewrite { 647 public: 648 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, 649 Operation *op) 650 : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op), 651 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()), 652 operands(op->operand_begin(), op->operand_end()), 653 successors(op->successor_begin(), op->successor_end()) { 654 if (OpaqueProperties prop = op->getPropertiesStorage()) { 655 // Make a copy of the properties. 
656 propertiesStorage = operator new(op->getPropertiesStorageSize()); 657 OpaqueProperties propCopy(propertiesStorage); 658 name.initOpProperties(propCopy, /*init=*/prop); 659 } 660 } 661 662 static bool classof(const IRRewrite *rewrite) { 663 return rewrite->getKind() == Kind::ModifyOperation; 664 } 665 666 ~ModifyOperationRewrite() override { 667 assert(!propertiesStorage && 668 "rewrite was neither committed nor rolled back"); 669 } 670 671 void commit(RewriterBase &rewriter) override { 672 // Notify the listener that the operation was modified in-place. 673 if (auto *listener = 674 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener())) 675 listener->notifyOperationModified(op); 676 677 if (propertiesStorage) { 678 OpaqueProperties propCopy(propertiesStorage); 679 // Note: The operation may have been erased in the mean time, so 680 // OperationName must be stored in this object. 681 name.destroyOpProperties(propCopy); 682 operator delete(propertiesStorage); 683 propertiesStorage = nullptr; 684 } 685 } 686 687 void rollback() override { 688 op->setLoc(loc); 689 op->setAttrs(attrs); 690 op->setOperands(operands); 691 for (const auto &it : llvm::enumerate(successors)) 692 op->setSuccessor(it.value(), it.index()); 693 if (propertiesStorage) { 694 OpaqueProperties propCopy(propertiesStorage); 695 op->copyProperties(propCopy); 696 name.destroyOpProperties(propCopy); 697 operator delete(propertiesStorage); 698 propertiesStorage = nullptr; 699 } 700 } 701 702 private: 703 OperationName name; 704 LocationAttr loc; 705 DictionaryAttr attrs; 706 SmallVector<Value, 8> operands; 707 SmallVector<Block *, 2> successors; 708 void *propertiesStorage = nullptr; 709 }; 710 711 /// Replacing an operation. Erasing an operation is treated as a special case 712 /// with "null" replacements. This rewrite is not immediately reflected in the 713 /// IR. 
An internal IR mapping is updated, but values are not replaced and the 714 /// original op is not erased until the rewrite is committed. 715 class ReplaceOperationRewrite : public OperationRewrite { 716 public: 717 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, 718 Operation *op, const TypeConverter *converter) 719 : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op), 720 converter(converter) {} 721 722 static bool classof(const IRRewrite *rewrite) { 723 return rewrite->getKind() == Kind::ReplaceOperation; 724 } 725 726 void commit(RewriterBase &rewriter) override; 727 728 void rollback() override; 729 730 void cleanup(RewriterBase &rewriter) override; 731 732 private: 733 /// An optional type converter that can be used to materialize conversions 734 /// between the new and old values if necessary. 735 const TypeConverter *converter; 736 }; 737 738 class CreateOperationRewrite : public OperationRewrite { 739 public: 740 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, 741 Operation *op) 742 : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {} 743 744 static bool classof(const IRRewrite *rewrite) { 745 return rewrite->getKind() == Kind::CreateOperation; 746 } 747 748 void commit(RewriterBase &rewriter) override { 749 // The operation was already created and inserted. Just inform the listener. 750 if (auto *listener = rewriter.getListener()) 751 listener->notifyOperationInserted(op, /*previous=*/{}); 752 } 753 754 void rollback() override; 755 }; 756 757 /// The type of materialization. 758 enum MaterializationKind { 759 /// This materialization materializes a conversion from an illegal type to a 760 /// legal one. 761 Target, 762 763 /// This materialization materializes a conversion from a legal type back to 764 /// an illegal one. 765 Source 766 }; 767 768 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast" 769 /// op. 
Unresolved materializations are erased at the end of the dialect
/// conversion.
///
/// The rewrite remembers the type converter and the materialization kind
/// (e.g. source or target) that were used to create the cast op, as well as the
/// values whose mapping was redirected to the results of the cast op. Rolling
/// back this rewrite removes those mapping entries, unregisters the op from
/// `unresolvedMaterializations` and erases the cast op.
class UnresolvedMaterializationRewrite : public OperationRewrite {
public:
  /// Create the rewrite for the given cast op and register it in
  /// `rewriterImpl.unresolvedMaterializations`. `originalType` may only be set
  /// for target materializations (this is asserted). `mappedValues` are the
  /// values that were mapped to the results of `op` (possibly empty).
  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
                                   UnrealizedConversionCastOp op,
                                   const TypeConverter *converter,
                                   MaterializationKind kind, Type originalType,
                                   ValueVector mappedValues);

  /// LLVM-style RTTI support: matches rewrites of kind
  /// `UnresolvedMaterialization`.
  static bool classof(const IRRewrite *rewrite) {
    return rewrite->getKind() == Kind::UnresolvedMaterialization;
  }

  /// Undo this materialization: drop the mapping entries of `mappedValues`,
  /// unregister the op from `unresolvedMaterializations` and erase the op.
  void rollback() override;

  /// Return the `unrealized_conversion_cast` op of this materialization.
  UnrealizedConversionCastOp getOperation() const {
    return cast<UnrealizedConversionCastOp>(op);
  }

  /// Return the type converter of this materialization (which may be null).
  const TypeConverter *getConverter() const {
    return converterAndKind.getPointer();
  }

  /// Return the kind of this materialization.
  MaterializationKind getMaterializationKind() const {
    return converterAndKind.getInt();
  }

  /// Return the original type of the SSA value.
  Type getOriginalType() const { return originalType; }

private:
  /// The corresponding type converter to use when resolving this
  /// materialization, and the kind of this materialization. The kind is packed
  /// into 2 bits of the converter pointer via PointerIntPair.
  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
      converterAndKind;

  /// The original type of the SSA value. Only used for target
  /// materializations; null for all other kinds (asserted in the constructor).
  Type originalType;

  /// The values in the conversion value mapping that are being replaced by the
  /// results of this unresolved materialization.
  ValueVector mappedValues;
};
} // namespace

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Return "true" if there is an operation rewrite that matches the specified
/// rewrite type and operation among the given rewrites.
///
/// `rewrites` is a range of smart pointers to IRRewrite: `.get()` is called on
/// each element and the result is dyn_cast to `RewriteTy`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) {
  return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
    auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
    return rewriteTy && rewriteTy->getOperation() == op;
  });
}

/// Return "true" if there is a block rewrite that matches the specified
/// rewrite type and block among the given rewrites.
///
/// Same range requirements as the operation overload above.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Block *block) {
  return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
    auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
    return rewriteTy && rewriteTy->getBlock() == block;
  });
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

//===----------------------------------------------------------------------===//
// ConversionPatternRewriterImpl
//===----------------------------------------------------------------------===//
namespace mlir {
namespace detail {
/// Internal state of a ConversionPatternRewriter. It acts as the rewriter's
/// listener, records every IR modification as an `IRRewrite` so that it can be
/// committed when the conversion succeeds or rolled back when it fails, and
/// owns the conversion value mapping.
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
  explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                         const ConversionConfig &config)
      : context(ctx), eraseRewriter(ctx), config(config) {}

  //===--------------------------------------------------------------------===//
  // State Management
  //===--------------------------------------------------------------------===//

  /// Return the current state of the rewriter.
  RewriterState getCurrentState();

  /// Apply all requested operation rewrites. This method is invoked when the
  /// conversion process succeeds.
  void applyRewrites();

  /// Reset the state of the rewriter to a previously saved point.
  void resetState(RewriterState state);

  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
  /// failure. The rewrite is constructed with `*this` as its first constructor
  /// argument, followed by `args`.
  template <typename RewriteTy, typename... Args>
  void appendRewrite(Args &&...args) {
    rewrites.push_back(
        std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
  }

  /// Undo the rewrites one by one in reverse order until `numRewritesToKeep`
  /// rewrites remain.
  void undoRewrites(unsigned numRewritesToKeep = 0);

  /// Remap the given values to those with potentially different types. Returns
  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
  /// is the tag used when describing a value within a diagnostic, e.g.
  /// "operand".
  ///
  /// On success, `remapped` receives one ValueVector per entry of `values`.
  /// Without a current type converter, the most recently mapped values are
  /// passed through unchanged; otherwise a target materialization is built
  /// when the mapped values do not have the converted types. If `inputLoc` is
  /// set, it is used instead of each value's own location.
  LogicalResult remapValues(StringRef valueDiagTag,
                            std::optional<Location> inputLoc,
                            PatternRewriter &rewriter, ValueRange values,
                            SmallVector<ValueVector> &remapped);

  /// Return "true" if the given operation is ignored, and does not need to be
  /// converted. Replaced/erased ops are considered ignored as well.
  bool isOpIgnored(Operation *op) const;

  /// Return "true" if the given operation was replaced or erased.
  bool wasOpReplaced(Operation *op) const;

  //===--------------------------------------------------------------------===//
  // Type Conversion
  //===--------------------------------------------------------------------===//

  /// Convert the types of block arguments within the given region.
  ///
  /// `converter` is also recorded as the type converter of `region`. Non-entry
  /// blocks are converted with signatures computed by `converter`. The entry
  /// block uses `entryConversion` if provided. Returns the new entry block,
  /// nullptr for an empty region, or failure if a block signature could not be
  /// computed.
  FailureOr<Block *>
  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
                     const TypeConverter &converter,
                     TypeConverter::SignatureConversion *entryConversion);

  /// Apply the given signature conversion on the given block. The new block
  /// containing the updated signature is returned. If no conversions were
  /// necessary, e.g. if the block has no arguments, `block` is returned.
  /// `converter` is used to generate any necessary cast operations that
  /// translate between the original argument types and those specified in the
  /// signature conversion.
  Block *applySignatureConversion(
      ConversionPatternRewriter &rewriter, Block *block,
      const TypeConverter *converter,
      TypeConverter::SignatureConversion &signatureConversion);

  //===--------------------------------------------------------------------===//
  // Materializations
  //===--------------------------------------------------------------------===//

  /// Build an unresolved materialization operation (an
  /// `unrealized_conversion_cast`) given a range of output types and a list of
  /// input operands, and return the results of the new op. The input types
  /// must differ from the output types; calling this when no materialization
  /// is necessary is an error (asserted).
  ///
  /// If a cast op was built, it can optionally be returned with the `castOp`
  /// output argument.
  ///
  /// If `valuesToMap` is non-empty, those values are mapped to the results of
  /// the unresolved materialization in the conversion value mapping.
  ValueRange buildUnresolvedMaterialization(
      MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
      ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
      Type originalType, const TypeConverter *converter,
      UnrealizedConversionCastOp *castOp = nullptr);

  /// Find a replacement value for the given SSA value in the conversion value
  /// mapping. The replacement value must have the same type as the given SSA
  /// value. If there is no replacement value with the correct type, find the
  /// latest replacement value (regardless of the type) and build a source
  /// materialization.
  ///
  /// Returns a null Value if the value appears to be dead or if no replacement
  /// value is registered for it.
  Value findOrBuildReplacementValue(Value value,
                                    const TypeConverter *converter);

  //===--------------------------------------------------------------------===//
  // Rewriter Notification Hooks
  //===--------------------------------------------------------------------===//

  /// Notifies that an op was inserted.
  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override;

  /// Notifies that an op is about to be replaced with the given values. An
  /// empty ValueRange means that the corresponding result was dropped.
  void notifyOpReplaced(Operation *op, ArrayRef<ValueRange> newValues);

  /// Notifies that a block is about to be erased.
  void notifyBlockIsBeingErased(Block *block);

  /// Notifies that a block was inserted.
  void notifyBlockInserted(Block *block, Region *previous,
                           Region::iterator previousIt) override;

  /// Notifies that a block is being inlined into another block.
  void notifyBlockBeingInlined(Block *block, Block *srcBlock,
                               Block::iterator before);

  /// Notifies that a pattern match failed for the given reason.
  void
  notifyMatchFailure(Location loc,
                     function_ref<void(Diagnostic &)> reasonCallback) override;

  //===--------------------------------------------------------------------===//
  // IR Erasure
  //===--------------------------------------------------------------------===//

  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
  /// operation or block is erased multiple times. This rewriter assumes that
  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
  struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
  public:
    /// The rewriter is its own listener so that it can record erasures.
    SingleEraseRewriter(MLIRContext *context)
        : RewriterBase(context, /*listener=*/this) {}

    /// Erase the given op (unless it was already erased).
    void eraseOp(Operation *op) override {
      if (wasErased(op))
        return;
      op->dropAllUses();
      RewriterBase::eraseOp(op);
    }

    /// Erase the given block (unless it was already erased).
    void eraseBlock(Block *block) override {
      if (wasErased(block))
        return;
      assert(block->empty() && "expected empty block");
      block->dropAllDefinedValueUses();
      RewriterBase::eraseBlock(block);
    }

    /// Return "true" if the given op or block pointer was erased by this
    /// rewriter.
    bool wasErased(void *ptr) const { return erased.contains(ptr); }

    /// Listener hook: remember the erased op.
    void notifyOperationErased(Operation *op) override { erased.insert(op); }

    /// Listener hook: remember the erased block.
    void notifyBlockErased(Block *block) override { erased.insert(block); }

  private:
    /// Pointers to all erased operations and blocks.
    DenseSet<void *> erased;
  };

  //===--------------------------------------------------------------------===//
  // State
  //===--------------------------------------------------------------------===//

  /// MLIR context.
  MLIRContext *context;

  /// A rewriter that keeps track of ops/blocks that were already erased and
  /// skips duplicate op/block erasures. This rewriter is used during the
  /// "cleanup" phase.
  SingleEraseRewriter eraseRewriter;

  /// Mapping between replaced values that differ in type. This happens when
  /// replacing a value with one of a different type.
  ConversionValueMapping mapping;

  /// Ordered list of all IR rewrites (op/block creations, moves, inlinings,
  /// replacements, materializations, ...). Committed in order by
  /// `applyRewrites` and rolled back in reverse order by `undoRewrites`.
  SmallVector<std::unique_ptr<IRRewrite>> rewrites;

  /// A set of operations that should no longer be considered for legalization.
  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
  /// tracked separately.
  SetVector<Operation *> ignoredOps;

  /// A set of operations that were replaced/erased. Such ops are not erased
  /// immediately but only when the dialect conversion succeeds. In the
  /// meantime, they should no longer be considered for legalization and any
  /// attempt to modify/access them is invalid rewriter API usage.
  SetVector<Operation *> replacedOps;

  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
  /// to the corresponding rewrite objects.
  DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
      unresolvedMaterializations;

  /// The current type converter, or nullptr if no type converter is currently
  /// active.
  const TypeConverter *currentTypeConverter = nullptr;

  /// A mapping of regions to type converters that should be used when
  /// converting the arguments of blocks within that region.
  DenseMap<Region *, const TypeConverter *> regionToConverter;

  /// Dialect conversion configuration.
  const ConversionConfig &config;

#ifndef NDEBUG
  /// A set of operations that have pending updates. This tracking isn't
  /// strictly necessary, and is thus only active during debug builds for extra
  /// verification.
  SmallPtrSet<Operation *, 1> pendingRootUpdates;

  /// A logger used to emit diagnostics during the conversion process.
  llvm::ScopedPrinter logger{llvm::dbgs()};
#endif
};
} // namespace detail
} // namespace mlir

/// Return the dialect conversion configuration of the owning rewriter.
const ConversionConfig &IRRewrite::getConfig() const {
  return rewriterImpl.config;
}

void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
  // Inform the listener about all IR modifications that have already taken
  // place: References to the original block have been replaced with the new
  // block.
  if (auto *listener =
          dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
    for (Operation *op : getNewBlock()->getUsers())
      listener->notifyOperationModified(op);
}

void BlockTypeConversionRewrite::rollback() {
  // Redirect all uses of the new block back to the original block.
  getNewBlock()->replaceAllUsesWith(getOrigBlock());
}

void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
  // A null replacement means that no replacement is needed (see
  // `findOrBuildReplacementValue`).
  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
  if (!repl)
    return;

  if (isa<BlockArgument>(repl)) {
    rewriter.replaceAllUsesWith(arg, repl);
    return;
  }

  // If the replacement value is an op result, uses in the same block as the
  // defining op are only replaced when the user comes after the defining op,
  // so that the replacement value dominates the replaced uses. Users in other
  // blocks are always replaced.
  Operation *replOp = cast<OpResult>(repl).getOwner();
  Block *replBlock = replOp->getBlock();
  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
    Operation *user = operand.getOwner();
    return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
  });
}

void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }

void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
  auto *listener =
      dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());

  // Compute replacement values.
  SmallVector<Value> replacements =
      llvm::map_to_vector(op->getResults(), [&](OpResult result) {
        return rewriterImpl.findOrBuildReplacementValue(result, converter);
      });

  // Notify the listener that the operation is about to be replaced.
  if (listener)
    listener->notifyOperationReplaced(op, replacements);

  // Replace all uses with the new values. Results without a replacement value
  // (null) are left untouched.
  for (auto [result, newValue] :
       llvm::zip_equal(op->getResults(), replacements))
    if (newValue)
      rewriter.replaceAllUsesWith(result, newValue);

  // The original op will be erased, so remove it from the set of unlegalized
  // ops.
  if (getConfig().unlegalizedOps)
    getConfig().unlegalizedOps->erase(op);

  // Notify the listener that the operation (and its nested operations) was
  // erased.
  if (listener) {
    op->walk<WalkOrder::PostOrder>(
        [&](Operation *op) { listener->notifyOperationErased(op); });
  }

  // Do not erase the operation yet. It may still be referenced in `mapping`.
  // Just unlink it for now and erase it during cleanup.
  op->getBlock()->getOperations().remove(op);
}

void ReplaceOperationRewrite::rollback() {
  // Drop the mapping entries that were created for the results of the op.
  for (auto result : op->getResults())
    rewriterImpl.mapping.erase({result});
}

void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
  // The op was unlinked during `commit`; erase it now. `applyRewrites` passes
  // the SingleEraseRewriter here, which skips ops that were already erased.
  rewriter.eraseOp(op);
}

void CreateOperationRewrite::rollback() {
  // Unlink (but do not erase) the blocks of all nested regions before erasing
  // the op. NOTE(review): presumably those blocks are handled by their own
  // rewrites during rollback -- confirm.
  for (Region &region : op->getRegions()) {
    while (!region.getBlocks().empty())
      region.getBlocks().remove(region.getBlocks().begin());
  }
  op->dropAllUses();
  op->erase();
}

UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
    ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
    const TypeConverter *converter, MaterializationKind kind, Type originalType,
    ValueVector mappedValues)
    : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
      converterAndKind(converter, kind), originalType(originalType),
      mappedValues(std::move(mappedValues)) {
  assert((!originalType || kind == MaterializationKind::Target) &&
         "original type is valid only for target materializations");
  // Register the materialization so that the driver can look it up by op.
  rewriterImpl.unresolvedMaterializations[op] = this;
}

void UnresolvedMaterializationRewrite::rollback() {
  if (!mappedValues.empty())
    rewriterImpl.mapping.erase(mappedValues);
  rewriterImpl.unresolvedMaterializations.erase(getOperation());
  op->erase();
}

void ConversionPatternRewriterImpl::applyRewrites() {
  // Commit all rewrites.
  IRRewriter rewriter(context, config.listener);
  // Note: New rewrites may be added during the "commit" phase and the
  // `rewrites` vector may reallocate.
  for (size_t i = 0; i < rewrites.size(); ++i)
    rewrites[i]->commit(rewriter);

  // Clean up all rewrites. The erase rewriter guarantees that no op/block is
  // erased twice.
  for (auto &rewrite : rewrites)
    rewrite->cleanup(eraseRewriter);
}

//===----------------------------------------------------------------------===//
// State Management

RewriterState ConversionPatternRewriterImpl::getCurrentState() {
  // A snapshot is the current size of each append-only list; `resetState`
  // truncates the lists back to these sizes.
  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
}

void ConversionPatternRewriterImpl::resetState(RewriterState state) {
  // Undo any rewrites.
  undoRewrites(state.numRewrites);

  // Pop all of the recorded ignored operations that are no longer valid.
  while (ignoredOps.size() != state.numIgnoredOperations)
    ignoredOps.pop_back();

  while (replacedOps.size() != state.numReplacedOps)
    replacedOps.pop_back();
}

void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
  // Roll back the newest rewrites first, then drop them from the list.
  for (auto &rewrite :
       llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
    rewrite->rollback();
  rewrites.resize(numRewritesToKeep);
}

LogicalResult ConversionPatternRewriterImpl::remapValues(
    StringRef valueDiagTag, std::optional<Location> inputLoc,
    PatternRewriter &rewriter, ValueRange values,
    SmallVector<ValueVector> &remapped) {
  remapped.reserve(llvm::size(values));

  for (const auto &it : llvm::enumerate(values)) {
    Value operand = it.value();
    Type origType = operand.getType();
    Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();

    if (!currentTypeConverter) {
      // The current pattern does not have a type converter. I.e., it does not
      // distinguish between legal and illegal types. For each operand, simply
      // pass through the most recently mapped values.
      remapped.push_back(mapping.lookupOrDefault(operand));
      continue;
    }

    // If there is no legal conversion, fail to match this pattern.
    SmallVector<Type, 1> legalTypes;
    if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
      notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
        diag << "unable to convert type for " << valueDiagTag << " #"
             << it.index() << ", type was " << origType;
      });
      return failure();
    }
    // If a type is converted to 0 types, there is nothing to do.
    if (legalTypes.empty()) {
      remapped.push_back({});
      continue;
    }

    ValueVector repl = mapping.lookupOrDefault(operand, legalTypes);
    if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
      // Mapped values have the correct type or there is an existing
      // materialization. Or the operand is not mapped at all and has the
      // correct type.
      remapped.push_back(std::move(repl));
      continue;
    }

    // Create a materialization for the most recently mapped values.
    repl = mapping.lookupOrDefault(operand);
    ValueRange castValues = buildUnresolvedMaterialization(
        MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
        /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
        /*originalType=*/origType, currentTypeConverter);
    remapped.push_back(castValues);
  }
  return success();
}

bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
  // Check to see if this operation is ignored or was replaced.
  return replacedOps.count(op) || ignoredOps.count(op);
}

bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
  // Check to see if this operation was replaced.
  return replacedOps.count(op);
}

//===----------------------------------------------------------------------===//
// Type Conversion

FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
    ConversionPatternRewriter &rewriter, Region *region,
    const TypeConverter &converter,
    TypeConverter::SignatureConversion *entryConversion) {
  regionToConverter[region] = &converter;
  if (region->empty())
    return nullptr;

  // Convert the arguments of each non-entry block within the region. An
  // early-inc range is used because each conversion inserts a new block after
  // the converted one and unlinks the original block.
  for (Block &block :
       llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
    // Compute the signature for the block with the provided converter.
    std::optional<TypeConverter::SignatureConversion> conversion =
        converter.convertBlockSignature(&block);
    if (!conversion)
      return failure();
    // Convert the block with the computed signature.
    applySignatureConversion(rewriter, &block, &converter, *conversion);
  }

  // Convert the entry block. If an entry signature conversion was provided,
  // use that one. Otherwise, compute the signature with the type converter.
  if (entryConversion)
    return applySignatureConversion(rewriter, &region->front(), &converter,
                                    *entryConversion);
  std::optional<TypeConverter::SignatureConversion> conversion =
      converter.convertBlockSignature(&region->front());
  if (!conversion)
    return failure();
  return applySignatureConversion(rewriter, &region->front(), &converter,
                                  *conversion);
}

Block *ConversionPatternRewriterImpl::applySignatureConversion(
    ConversionPatternRewriter &rewriter, Block *block,
    const TypeConverter *converter,
    TypeConverter::SignatureConversion &signatureConversion) {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
  // A block cannot be converted multiple times.
  if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
    llvm::report_fatal_error("block was already converted");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

  OpBuilder::InsertionGuard g(rewriter);

  // If no arguments are being changed or added, there is nothing to do.
  unsigned origArgCount = block->getNumArguments();
  auto convertedTypes = signatureConversion.getConvertedTypes();
  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
    return block;

  // Compute the locations of all block arguments in the new block. New
  // arguments that replace an original argument inherit its location; all
  // others get an unknown location.
  SmallVector<Location> newLocs(convertedTypes.size(),
                                rewriter.getUnknownLoc());
  for (unsigned i = 0; i < origArgCount; ++i) {
    auto inputMap = signatureConversion.getInputMapping(i);
    if (!inputMap || inputMap->replacementValue)
      continue;
    Location origLoc = block->getArgument(i).getLoc();
    for (unsigned j = 0; j < inputMap->size; ++j)
      newLocs[inputMap->inputNo + j] = origLoc;
  }

  // Insert a new block with the converted block argument types and move all ops
  // from the old block to the new block.
  Block *newBlock =
      rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
                           convertedTypes, newLocs);

  // If a listener is attached to the dialect conversion, ops cannot be moved
  // to the destination block in bulk ("fast path"). This is because at the time
  // the notifications are sent, it is unknown which ops were moved. Instead,
  // ops should be moved one-by-one ("slow path"), so that a separate
  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
  // a bit more efficient, so we try to do that when possible.
  bool fastPath = !config.listener;
  if (fastPath) {
    // The rewrite is recorded before the splice, while `block` still holds
    // the ops being moved.
    appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
    newBlock->getOperations().splice(newBlock->end(), block->getOperations());
  } else {
    while (!block->empty())
      rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
  }

  // Replace all uses of the old block with the new block.
  block->replaceAllUsesWith(newBlock);

  for (unsigned i = 0; i != origArgCount; ++i) {
    BlockArgument origArg = block->getArgument(i);
    Type origArgType = origArg.getType();

    std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
        signatureConversion.getInputMapping(i);
    if (!inputMap) {
      // This block argument was dropped and no replacement value was provided.
      // Materialize a replacement value "out of thin air".
      buildUnresolvedMaterialization(
          MaterializationKind::Source,
          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
          /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
          /*outputType=*/origArgType, /*originalType=*/Type(), converter);
      appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
      continue;
    }

    if (Value repl = inputMap->replacementValue) {
      // This block argument was dropped and a replacement value was provided.
      assert(inputMap->size == 0 &&
             "invalid to provide a replacement value when the argument isn't "
             "dropped");
      mapping.map(origArg, repl);
      appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
      continue;
    }

    // This is a 1->1+ mapping: the original argument is mapped to the slice
    // [inputNo, inputNo + size) of the new block's arguments.
    auto replArgs =
        newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
    ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
    mapping.map(origArg, std::move(replArgVals));
    appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
  }

  appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);

  // Erase the old block. (It is just unlinked for now and will be erased during
  // cleanup.)
  rewriter.eraseBlock(block);

  return newBlock;
}

//===----------------------------------------------------------------------===//
// Materializations
//===----------------------------------------------------------------------===//

/// Build an unresolved materialization operation (an
/// `unrealized_conversion_cast`) given a range of output types and a list of
/// input operands. The cast is created with a plain OpBuilder so that it is
/// not reported through the rewriter's listener; instead it is tracked by an
/// UnresolvedMaterializationRewrite.
1426 ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( 1427 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, 1428 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, 1429 Type originalType, const TypeConverter *converter, 1430 UnrealizedConversionCastOp *castOp) { 1431 assert((!originalType || kind == MaterializationKind::Target) && 1432 "original type is valid only for target materializations"); 1433 assert(TypeRange(inputs) != outputTypes && 1434 "materialization is not necessary"); 1435 1436 // Create an unresolved materialization. We use a new OpBuilder to avoid 1437 // tracking the materialization like we do for other operations. 1438 OpBuilder builder(outputTypes.front().getContext()); 1439 builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); 1440 auto convertOp = 1441 builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs); 1442 if (!valuesToMap.empty()) 1443 mapping.map(valuesToMap, convertOp.getResults()); 1444 if (castOp) 1445 *castOp = convertOp; 1446 appendRewrite<UnresolvedMaterializationRewrite>( 1447 convertOp, converter, kind, originalType, std::move(valuesToMap)); 1448 return convertOp.getResults(); 1449 } 1450 1451 Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( 1452 Value value, const TypeConverter *converter) { 1453 // Try to find a replacement value with the same type in the conversion value 1454 // mapping. This includes cached materializations. We try to reuse those 1455 // instead of generating duplicate IR. 1456 ValueVector repl = mapping.lookupOrNull(value, value.getType()); 1457 if (!repl.empty()) 1458 return repl.front(); 1459 1460 // Check if the value is dead. No replacement value is needed in that case. 1461 // This is an approximate check that may have false negatives but does not 1462 // require computing and traversing an inverse mapping. (We may end up 1463 // building source materializations that are never used and that fold away.) 
1464 if (llvm::all_of(value.getUsers(), 1465 [&](Operation *op) { return replacedOps.contains(op); }) && 1466 !mapping.isMappedTo(value)) 1467 return Value(); 1468 1469 // No replacement value was found. Get the latest replacement value 1470 // (regardless of the type) and build a source materialization to the 1471 // original type. 1472 repl = mapping.lookupOrNull(value); 1473 if (repl.empty()) { 1474 // No replacement value is registered in the mapping. This means that the 1475 // value is dropped and no longer needed. (If the value were still needed, 1476 // a source materialization producing a replacement value "out of thin air" 1477 // would have already been created during `replaceOp` or 1478 // `applySignatureConversion`.) 1479 return Value(); 1480 } 1481 1482 // Note: `computeInsertPoint` computes the "earliest" insertion point at 1483 // which all values in `repl` are defined. It is important to emit the 1484 // materialization at that location because the same materialization may be 1485 // reused in a different context. (That's because materializations are cached 1486 // in the conversion value mapping.) The insertion point of the 1487 // materialization must be valid for all future users that may be created 1488 // later in the conversion process. 
1489 Value castValue = 1490 buildUnresolvedMaterialization(MaterializationKind::Source, 1491 computeInsertPoint(repl), value.getLoc(), 1492 /*valuesToMap=*/repl, /*inputs=*/repl, 1493 /*outputType=*/value.getType(), 1494 /*originalType=*/Type(), converter) 1495 .front(); 1496 return castValue; 1497 } 1498 1499 //===----------------------------------------------------------------------===// 1500 // Rewriter Notification Hooks 1501 1502 void ConversionPatternRewriterImpl::notifyOperationInserted( 1503 Operation *op, OpBuilder::InsertPoint previous) { 1504 LLVM_DEBUG({ 1505 logger.startLine() << "** Insert : '" << op->getName() << "'(" << op 1506 << ")\n"; 1507 }); 1508 assert(!wasOpReplaced(op->getParentOp()) && 1509 "attempting to insert into a block within a replaced/erased op"); 1510 1511 if (!previous.isSet()) { 1512 // This is a newly created op. 1513 appendRewrite<CreateOperationRewrite>(op); 1514 return; 1515 } 1516 Operation *prevOp = previous.getPoint() == previous.getBlock()->end() 1517 ? nullptr 1518 : &*previous.getPoint(); 1519 appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp); 1520 } 1521 1522 void ConversionPatternRewriterImpl::notifyOpReplaced( 1523 Operation *op, ArrayRef<ValueRange> newValues) { 1524 assert(newValues.size() == op->getNumResults()); 1525 assert(!ignoredOps.contains(op) && "operation was already replaced"); 1526 1527 // Check if replaced op is an unresolved materialization, i.e., an 1528 // unrealized_conversion_cast op that was created by the conversion driver. 1529 bool isUnresolvedMaterialization = false; 1530 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) 1531 if (unresolvedMaterializations.contains(castOp)) 1532 isUnresolvedMaterialization = true; 1533 1534 // Create mappings for each of the new result values. 1535 for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) { 1536 if (repl.empty()) { 1537 // This result was dropped and no replacement value was provided. 
1538 if (isUnresolvedMaterialization) { 1539 // Do not create another materializations if we are erasing a 1540 // materialization. 1541 continue; 1542 } 1543 1544 // Materialize a replacement value "out of thin air". 1545 buildUnresolvedMaterialization( 1546 MaterializationKind::Source, computeInsertPoint(result), 1547 result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(), 1548 /*outputType=*/result.getType(), /*originalType=*/Type(), 1549 currentTypeConverter); 1550 continue; 1551 } else { 1552 // Make sure that the user does not mess with unresolved materializations 1553 // that were inserted by the conversion driver. We keep track of these 1554 // ops in internal data structures. Erasing them must be allowed because 1555 // this can happen when the user is erasing an entire block (including 1556 // its body). But replacing them with another value should be forbidden 1557 // to avoid problems with the `mapping`. 1558 assert(!isUnresolvedMaterialization && 1559 "attempting to replace an unresolved materialization"); 1560 } 1561 1562 // Remap result to replacement value. 1563 if (repl.empty()) 1564 continue; 1565 mapping.map(result, repl); 1566 } 1567 1568 appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter); 1569 // Mark this operation and all nested ops as replaced. 
1570 op->walk([&](Operation *op) { replacedOps.insert(op); }); 1571 } 1572 1573 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) { 1574 appendRewrite<EraseBlockRewrite>(block); 1575 } 1576 1577 void ConversionPatternRewriterImpl::notifyBlockInserted( 1578 Block *block, Region *previous, Region::iterator previousIt) { 1579 assert(!wasOpReplaced(block->getParentOp()) && 1580 "attempting to insert into a region within a replaced/erased op"); 1581 LLVM_DEBUG( 1582 { 1583 Operation *parent = block->getParentOp(); 1584 if (parent) { 1585 logger.startLine() << "** Insert Block into : '" << parent->getName() 1586 << "'(" << parent << ")\n"; 1587 } else { 1588 logger.startLine() 1589 << "** Insert Block into detached Region (nullptr parent op)'"; 1590 } 1591 }); 1592 1593 if (!previous) { 1594 // This is a newly created block. 1595 appendRewrite<CreateBlockRewrite>(block); 1596 return; 1597 } 1598 Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt; 1599 appendRewrite<MoveBlockRewrite>(block, previous, prevBlock); 1600 } 1601 1602 void ConversionPatternRewriterImpl::notifyBlockBeingInlined( 1603 Block *block, Block *srcBlock, Block::iterator before) { 1604 appendRewrite<InlineBlockRewrite>(block, srcBlock, before); 1605 } 1606 1607 void ConversionPatternRewriterImpl::notifyMatchFailure( 1608 Location loc, function_ref<void(Diagnostic &)> reasonCallback) { 1609 LLVM_DEBUG({ 1610 Diagnostic diag(loc, DiagnosticSeverity::Remark); 1611 reasonCallback(diag); 1612 logger.startLine() << "** Failure : " << diag.str() << "\n"; 1613 if (config.notifyCallback) 1614 config.notifyCallback(diag); 1615 }); 1616 } 1617 1618 //===----------------------------------------------------------------------===// 1619 // ConversionPatternRewriter 1620 //===----------------------------------------------------------------------===// 1621 1622 ConversionPatternRewriter::ConversionPatternRewriter( 1623 MLIRContext *ctx, const ConversionConfig &config) 1624 
: PatternRewriter(ctx), 1625 impl(new detail::ConversionPatternRewriterImpl(ctx, config)) { 1626 setListener(impl.get()); 1627 } 1628 1629 ConversionPatternRewriter::~ConversionPatternRewriter() = default; 1630 1631 void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) { 1632 assert(op && newOp && "expected non-null op"); 1633 replaceOp(op, newOp->getResults()); 1634 } 1635 1636 void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { 1637 assert(op->getNumResults() == newValues.size() && 1638 "incorrect # of replacement values"); 1639 LLVM_DEBUG({ 1640 impl->logger.startLine() 1641 << "** Replace : '" << op->getName() << "'(" << op << ")\n"; 1642 }); 1643 SmallVector<ValueRange> newVals; 1644 for (size_t i = 0; i < newValues.size(); ++i) { 1645 if (newValues[i]) { 1646 newVals.push_back(newValues.slice(i, 1)); 1647 } else { 1648 newVals.push_back(ValueRange()); 1649 } 1650 } 1651 impl->notifyOpReplaced(op, newVals); 1652 } 1653 1654 void ConversionPatternRewriter::replaceOpWithMultiple( 1655 Operation *op, ArrayRef<ValueRange> newValues) { 1656 assert(op->getNumResults() == newValues.size() && 1657 "incorrect # of replacement values"); 1658 LLVM_DEBUG({ 1659 impl->logger.startLine() 1660 << "** Replace : '" << op->getName() << "'(" << op << ")\n"; 1661 }); 1662 impl->notifyOpReplaced(op, newValues); 1663 } 1664 1665 void ConversionPatternRewriter::eraseOp(Operation *op) { 1666 LLVM_DEBUG({ 1667 impl->logger.startLine() 1668 << "** Erase : '" << op->getName() << "'(" << op << ")\n"; 1669 }); 1670 SmallVector<ValueRange> nullRepls(op->getNumResults(), {}); 1671 impl->notifyOpReplaced(op, nullRepls); 1672 } 1673 1674 void ConversionPatternRewriter::eraseBlock(Block *block) { 1675 assert(!impl->wasOpReplaced(block->getParentOp()) && 1676 "attempting to erase a block within a replaced/erased op"); 1677 1678 // Mark all ops for erasure. 
1679 for (Operation &op : *block) 1680 eraseOp(&op); 1681 1682 // Unlink the block from its parent region. The block is kept in the rewrite 1683 // object and will be actually destroyed when rewrites are applied. This 1684 // allows us to keep the operations in the block live and undo the removal by 1685 // re-inserting the block. 1686 impl->notifyBlockIsBeingErased(block); 1687 block->getParent()->getBlocks().remove(block); 1688 } 1689 1690 Block *ConversionPatternRewriter::applySignatureConversion( 1691 Block *block, TypeConverter::SignatureConversion &conversion, 1692 const TypeConverter *converter) { 1693 assert(!impl->wasOpReplaced(block->getParentOp()) && 1694 "attempting to apply a signature conversion to a block within a " 1695 "replaced/erased op"); 1696 return impl->applySignatureConversion(*this, block, converter, conversion); 1697 } 1698 1699 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( 1700 Region *region, const TypeConverter &converter, 1701 TypeConverter::SignatureConversion *entryConversion) { 1702 assert(!impl->wasOpReplaced(region->getParentOp()) && 1703 "attempting to apply a signature conversion to a block within a " 1704 "replaced/erased op"); 1705 return impl->convertRegionTypes(*this, region, converter, entryConversion); 1706 } 1707 1708 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, 1709 Value to) { 1710 LLVM_DEBUG({ 1711 Operation *parentOp = from.getOwner()->getParentOp(); 1712 impl->logger.startLine() << "** Replace Argument : '" << from 1713 << "'(in region of '" << parentOp->getName() 1714 << "'(" << from.getOwner()->getParentOp() << ")\n"; 1715 }); 1716 impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, 1717 impl->currentTypeConverter); 1718 impl->mapping.map(impl->mapping.lookupOrDefault(from), to); 1719 } 1720 1721 Value ConversionPatternRewriter::getRemappedValue(Value key) { 1722 SmallVector<ValueVector> remappedValues; 1723 if (failed(impl->remapValues("value", 
/*inputLoc=*/std::nullopt, *this, key, 1724 remappedValues))) 1725 return nullptr; 1726 assert(remappedValues.front().size() == 1 && "1:N conversion not supported"); 1727 return remappedValues.front().front(); 1728 } 1729 1730 LogicalResult 1731 ConversionPatternRewriter::getRemappedValues(ValueRange keys, 1732 SmallVectorImpl<Value> &results) { 1733 if (keys.empty()) 1734 return success(); 1735 SmallVector<ValueVector> remapped; 1736 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, 1737 remapped))) 1738 return failure(); 1739 for (const auto &values : remapped) { 1740 assert(values.size() == 1 && "1:N conversion not supported"); 1741 results.push_back(values.front()); 1742 } 1743 return success(); 1744 } 1745 1746 void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, 1747 Block::iterator before, 1748 ValueRange argValues) { 1749 #ifndef NDEBUG 1750 assert(argValues.size() == source->getNumArguments() && 1751 "incorrect # of argument replacement values"); 1752 assert(!impl->wasOpReplaced(source->getParentOp()) && 1753 "attempting to inline a block from a replaced/erased op"); 1754 assert(!impl->wasOpReplaced(dest->getParentOp()) && 1755 "attempting to inline a block into a replaced/erased op"); 1756 auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); }; 1757 // The source block will be deleted, so it should not have any users (i.e., 1758 // there should be no predecessors). 1759 assert(llvm::all_of(source->getUsers(), opIgnored) && 1760 "expected 'source' to have no predecessors"); 1761 #endif // NDEBUG 1762 1763 // If a listener is attached to the dialect conversion, ops cannot be moved 1764 // to the destination block in bulk ("fast path"). This is because at the time 1765 // the notifications are sent, it is unknown which ops were moved. Instead, 1766 // ops should be moved one-by-one ("slow path"), so that a separate 1767 // `MoveOperationRewrite` is enqueued for each moved op. 
Moving ops in bulk is 1768 // a bit more efficient, so we try to do that when possible. 1769 bool fastPath = !impl->config.listener; 1770 1771 if (fastPath) 1772 impl->notifyBlockBeingInlined(dest, source, before); 1773 1774 // Replace all uses of block arguments. 1775 for (auto it : llvm::zip(source->getArguments(), argValues)) 1776 replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it)); 1777 1778 if (fastPath) { 1779 // Move all ops at once. 1780 dest->getOperations().splice(before, source->getOperations()); 1781 } else { 1782 // Move op by op. 1783 while (!source->empty()) 1784 moveOpBefore(&source->front(), dest, before); 1785 } 1786 1787 // Erase the source block. 1788 eraseBlock(source); 1789 } 1790 1791 void ConversionPatternRewriter::startOpModification(Operation *op) { 1792 assert(!impl->wasOpReplaced(op) && 1793 "attempting to modify a replaced/erased op"); 1794 #ifndef NDEBUG 1795 impl->pendingRootUpdates.insert(op); 1796 #endif 1797 impl->appendRewrite<ModifyOperationRewrite>(op); 1798 } 1799 1800 void ConversionPatternRewriter::finalizeOpModification(Operation *op) { 1801 assert(!impl->wasOpReplaced(op) && 1802 "attempting to modify a replaced/erased op"); 1803 PatternRewriter::finalizeOpModification(op); 1804 // There is nothing to do here, we only need to track the operation at the 1805 // start of the update. 1806 #ifndef NDEBUG 1807 assert(impl->pendingRootUpdates.erase(op) && 1808 "operation did not have a pending in-place update"); 1809 #endif 1810 } 1811 1812 void ConversionPatternRewriter::cancelOpModification(Operation *op) { 1813 #ifndef NDEBUG 1814 assert(impl->pendingRootUpdates.erase(op) && 1815 "operation did not have a pending in-place update"); 1816 #endif 1817 // Erase the last update for this operation. 
1818 auto it = llvm::find_if( 1819 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) { 1820 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get()); 1821 return modifyRewrite && modifyRewrite->getOperation() == op; 1822 }); 1823 assert(it != impl->rewrites.rend() && "no root update started on op"); 1824 (*it)->rollback(); 1825 int updateIdx = std::prev(impl->rewrites.rend()) - it; 1826 impl->rewrites.erase(impl->rewrites.begin() + updateIdx); 1827 } 1828 1829 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { 1830 return *impl; 1831 } 1832 1833 //===----------------------------------------------------------------------===// 1834 // ConversionPattern 1835 //===----------------------------------------------------------------------===// 1836 1837 SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands( 1838 ArrayRef<ValueRange> operands) const { 1839 SmallVector<Value> oneToOneOperands; 1840 oneToOneOperands.reserve(operands.size()); 1841 for (ValueRange operand : operands) { 1842 if (operand.size() != 1) 1843 llvm::report_fatal_error("pattern '" + getDebugName() + 1844 "' does not support 1:N conversion"); 1845 oneToOneOperands.push_back(operand.front()); 1846 } 1847 return oneToOneOperands; 1848 } 1849 1850 LogicalResult 1851 ConversionPattern::matchAndRewrite(Operation *op, 1852 PatternRewriter &rewriter) const { 1853 auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter); 1854 auto &rewriterImpl = dialectRewriter.getImpl(); 1855 1856 // Track the current conversion pattern type converter in the rewriter. 1857 llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter, 1858 getTypeConverter()); 1859 1860 // Remap the operands of the operation. 
1861 SmallVector<ValueVector> remapped; 1862 if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, 1863 op->getOperands(), remapped))) { 1864 return failure(); 1865 } 1866 SmallVector<ValueRange> remappedAsRange = 1867 llvm::to_vector_of<ValueRange>(remapped); 1868 return matchAndRewrite(op, remappedAsRange, dialectRewriter); 1869 } 1870 1871 //===----------------------------------------------------------------------===// 1872 // OperationLegalizer 1873 //===----------------------------------------------------------------------===// 1874 1875 namespace { 1876 /// A set of rewrite patterns that can be used to legalize a given operation. 1877 using LegalizationPatterns = SmallVector<const Pattern *, 1>; 1878 1879 /// This class defines a recursive operation legalizer. 1880 class OperationLegalizer { 1881 public: 1882 using LegalizationAction = ConversionTarget::LegalizationAction; 1883 1884 OperationLegalizer(const ConversionTarget &targetInfo, 1885 const FrozenRewritePatternSet &patterns, 1886 const ConversionConfig &config); 1887 1888 /// Returns true if the given operation is known to be illegal on the target. 1889 bool isIllegal(Operation *op) const; 1890 1891 /// Attempt to legalize the given operation. Returns success if the operation 1892 /// was legalized, failure otherwise. 1893 LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter); 1894 1895 /// Returns the conversion target in use by the legalizer. 1896 const ConversionTarget &getTarget() { return target; } 1897 1898 private: 1899 /// Attempt to legalize the given operation by folding it. 1900 LogicalResult legalizeWithFold(Operation *op, 1901 ConversionPatternRewriter &rewriter); 1902 1903 /// Attempt to legalize the given operation by applying a pattern. Returns 1904 /// success if the operation was legalized, failure otherwise. 
1905 LogicalResult legalizeWithPattern(Operation *op, 1906 ConversionPatternRewriter &rewriter); 1907 1908 /// Return true if the given pattern may be applied to the given operation, 1909 /// false otherwise. 1910 bool canApplyPattern(Operation *op, const Pattern &pattern, 1911 ConversionPatternRewriter &rewriter); 1912 1913 /// Legalize the resultant IR after successfully applying the given pattern. 1914 LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern, 1915 ConversionPatternRewriter &rewriter, 1916 RewriterState &curState); 1917 1918 /// Legalizes the actions registered during the execution of a pattern. 1919 LogicalResult 1920 legalizePatternBlockRewrites(Operation *op, 1921 ConversionPatternRewriter &rewriter, 1922 ConversionPatternRewriterImpl &impl, 1923 RewriterState &state, RewriterState &newState); 1924 LogicalResult legalizePatternCreatedOperations( 1925 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, 1926 RewriterState &state, RewriterState &newState); 1927 LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter, 1928 ConversionPatternRewriterImpl &impl, 1929 RewriterState &state, 1930 RewriterState &newState); 1931 1932 //===--------------------------------------------------------------------===// 1933 // Cost Model 1934 //===--------------------------------------------------------------------===// 1935 1936 /// Build an optimistic legalization graph given the provided patterns. This 1937 /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with 1938 /// patterns for operations that are not directly legal, but may be 1939 /// transitively legal for the current target given the provided patterns. 1940 void buildLegalizationGraph( 1941 LegalizationPatterns &anyOpLegalizerPatterns, 1942 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1943 1944 /// Compute the benefit of each node within the computed legalization graph. 
1945 /// This orders the patterns within 'legalizerPatterns' based upon two 1946 /// criteria: 1947 /// 1) Prefer patterns that have the lowest legalization depth, i.e. 1948 /// represent the more direct mapping to the target. 1949 /// 2) When comparing patterns with the same legalization depth, prefer the 1950 /// pattern with the highest PatternBenefit. This allows for users to 1951 /// prefer specific legalizations over others. 1952 void computeLegalizationGraphBenefit( 1953 LegalizationPatterns &anyOpLegalizerPatterns, 1954 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1955 1956 /// Compute the legalization depth when legalizing an operation of the given 1957 /// type. 1958 unsigned computeOpLegalizationDepth( 1959 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth, 1960 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1961 1962 /// Apply the conversion cost model to the given set of patterns, and return 1963 /// the smallest legalization depth of any of the patterns. See 1964 /// `computeLegalizationGraphBenefit` for the breakdown of the cost model. 1965 unsigned applyCostModelToPatterns( 1966 LegalizationPatterns &patterns, 1967 DenseMap<OperationName, unsigned> &minOpPatternDepth, 1968 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1969 1970 /// The current set of patterns that have been applied. 1971 SmallPtrSet<const Pattern *, 8> appliedPatterns; 1972 1973 /// The legalization information provided by the target. 1974 const ConversionTarget ⌖ 1975 1976 /// The pattern applicator to use for conversions. 1977 PatternApplicator applicator; 1978 1979 /// Dialect conversion configuration. 
1980 const ConversionConfig &config; 1981 }; 1982 } // namespace 1983 1984 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo, 1985 const FrozenRewritePatternSet &patterns, 1986 const ConversionConfig &config) 1987 : target(targetInfo), applicator(patterns), config(config) { 1988 // The set of patterns that can be applied to illegal operations to transform 1989 // them into legal ones. 1990 DenseMap<OperationName, LegalizationPatterns> legalizerPatterns; 1991 LegalizationPatterns anyOpLegalizerPatterns; 1992 1993 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns); 1994 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns); 1995 } 1996 1997 bool OperationLegalizer::isIllegal(Operation *op) const { 1998 return target.isIllegal(op); 1999 } 2000 2001 LogicalResult 2002 OperationLegalizer::legalize(Operation *op, 2003 ConversionPatternRewriter &rewriter) { 2004 #ifndef NDEBUG 2005 const char *logLineComment = 2006 "//===-------------------------------------------===//\n"; 2007 2008 auto &logger = rewriter.getImpl().logger; 2009 #endif 2010 LLVM_DEBUG({ 2011 logger.getOStream() << "\n"; 2012 logger.startLine() << logLineComment; 2013 logger.startLine() << "Legalizing operation : '" << op->getName() << "'(" 2014 << op << ") {\n"; 2015 logger.indent(); 2016 2017 // If the operation has no regions, just print it here. 2018 if (op->getNumRegions() == 0) { 2019 op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm()); 2020 logger.getOStream() << "\n\n"; 2021 } 2022 }); 2023 2024 // Check if this operation is legal on the target. 2025 if (auto legalityInfo = target.isLegal(op)) { 2026 LLVM_DEBUG({ 2027 logSuccess( 2028 logger, "operation marked legal by the target{0}", 2029 legalityInfo->isRecursivelyLegal 2030 ? 
"; NOTE: operation is recursively legal; skipping internals" 2031 : ""); 2032 logger.startLine() << logLineComment; 2033 }); 2034 2035 // If this operation is recursively legal, mark its children as ignored so 2036 // that we don't consider them for legalization. 2037 if (legalityInfo->isRecursivelyLegal) { 2038 op->walk([&](Operation *nested) { 2039 if (op != nested) 2040 rewriter.getImpl().ignoredOps.insert(nested); 2041 }); 2042 } 2043 2044 return success(); 2045 } 2046 2047 // Check to see if the operation is ignored and doesn't need to be converted. 2048 if (rewriter.getImpl().isOpIgnored(op)) { 2049 LLVM_DEBUG({ 2050 logSuccess(logger, "operation marked 'ignored' during conversion"); 2051 logger.startLine() << logLineComment; 2052 }); 2053 return success(); 2054 } 2055 2056 // If the operation isn't legal, try to fold it in-place. 2057 // TODO: Should we always try to do this, even if the op is 2058 // already legal? 2059 if (succeeded(legalizeWithFold(op, rewriter))) { 2060 LLVM_DEBUG({ 2061 logSuccess(logger, "operation was folded"); 2062 logger.startLine() << logLineComment; 2063 }); 2064 return success(); 2065 } 2066 2067 // Otherwise, we need to apply a legalization pattern to this operation. 2068 if (succeeded(legalizeWithPattern(op, rewriter))) { 2069 LLVM_DEBUG({ 2070 logSuccess(logger, ""); 2071 logger.startLine() << logLineComment; 2072 }); 2073 return success(); 2074 } 2075 2076 LLVM_DEBUG({ 2077 logFailure(logger, "no matched legalization pattern"); 2078 logger.startLine() << logLineComment; 2079 }); 2080 return failure(); 2081 } 2082 2083 LogicalResult 2084 OperationLegalizer::legalizeWithFold(Operation *op, 2085 ConversionPatternRewriter &rewriter) { 2086 auto &rewriterImpl = rewriter.getImpl(); 2087 RewriterState curState = rewriterImpl.getCurrentState(); 2088 2089 LLVM_DEBUG({ 2090 rewriterImpl.logger.startLine() << "* Fold {\n"; 2091 rewriterImpl.logger.indent(); 2092 }); 2093 2094 // Try to fold the operation. 
2095 SmallVector<Value, 2> replacementValues; 2096 rewriter.setInsertionPoint(op); 2097 if (failed(rewriter.tryFold(op, replacementValues))) { 2098 LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold")); 2099 return failure(); 2100 } 2101 // An empty list of replacement values indicates that the fold was in-place. 2102 // As the operation changed, a new legalization needs to be attempted. 2103 if (replacementValues.empty()) 2104 return legalize(op, rewriter); 2105 2106 // Insert a replacement for 'op' with the folded replacement values. 2107 rewriter.replaceOp(op, replacementValues); 2108 2109 // Recursively legalize any new constant operations. 2110 for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size(); 2111 i != e; ++i) { 2112 auto *createOp = 2113 dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get()); 2114 if (!createOp) 2115 continue; 2116 if (failed(legalize(createOp->getOperation(), rewriter))) { 2117 LLVM_DEBUG(logFailure(rewriterImpl.logger, 2118 "failed to legalize generated constant '{0}'", 2119 createOp->getOperation()->getName())); 2120 rewriterImpl.resetState(curState); 2121 return failure(); 2122 } 2123 } 2124 2125 LLVM_DEBUG(logSuccess(rewriterImpl.logger, "")); 2126 return success(); 2127 } 2128 2129 LogicalResult 2130 OperationLegalizer::legalizeWithPattern(Operation *op, 2131 ConversionPatternRewriter &rewriter) { 2132 auto &rewriterImpl = rewriter.getImpl(); 2133 2134 // Functor that returns if the given pattern may be applied. 2135 auto canApply = [&](const Pattern &pattern) { 2136 bool canApply = canApplyPattern(op, pattern, rewriter); 2137 if (canApply && config.listener) 2138 config.listener->notifyPatternBegin(pattern, op); 2139 return canApply; 2140 }; 2141 2142 // Functor that cleans up the rewriter state after a pattern failed to match. 
2143 RewriterState curState = rewriterImpl.getCurrentState(); 2144 auto onFailure = [&](const Pattern &pattern) { 2145 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); 2146 LLVM_DEBUG({ 2147 logFailure(rewriterImpl.logger, "pattern failed to match"); 2148 if (rewriterImpl.config.notifyCallback) { 2149 Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark); 2150 diag << "Failed to apply pattern \"" << pattern.getDebugName() 2151 << "\" on op:\n" 2152 << *op; 2153 rewriterImpl.config.notifyCallback(diag); 2154 } 2155 }); 2156 if (config.listener) 2157 config.listener->notifyPatternEnd(pattern, failure()); 2158 rewriterImpl.resetState(curState); 2159 appliedPatterns.erase(&pattern); 2160 }; 2161 2162 // Functor that performs additional legalization when a pattern is 2163 // successfully applied. 2164 auto onSuccess = [&](const Pattern &pattern) { 2165 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); 2166 auto result = legalizePatternResult(op, pattern, rewriter, curState); 2167 appliedPatterns.erase(&pattern); 2168 if (failed(result)) 2169 rewriterImpl.resetState(curState); 2170 if (config.listener) 2171 config.listener->notifyPatternEnd(pattern, result); 2172 return result; 2173 }; 2174 2175 // Try to match and rewrite a pattern on this operation. 
2176 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure, 2177 onSuccess); 2178 } 2179 2180 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern, 2181 ConversionPatternRewriter &rewriter) { 2182 LLVM_DEBUG({ 2183 auto &os = rewriter.getImpl().logger; 2184 os.getOStream() << "\n"; 2185 os.startLine() << "* Pattern : '" << op->getName() << " -> ("; 2186 llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream()); 2187 os.getOStream() << ")' {\n"; 2188 os.indent(); 2189 }); 2190 2191 // Ensure that we don't cycle by not allowing the same pattern to be 2192 // applied twice in the same recursion stack if it is not known to be safe. 2193 if (!pattern.hasBoundedRewriteRecursion() && 2194 !appliedPatterns.insert(&pattern).second) { 2195 LLVM_DEBUG( 2196 logFailure(rewriter.getImpl().logger, "pattern was already applied")); 2197 return false; 2198 } 2199 return true; 2200 } 2201 2202 LogicalResult 2203 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, 2204 ConversionPatternRewriter &rewriter, 2205 RewriterState &curState) { 2206 auto &impl = rewriter.getImpl(); 2207 assert(impl.pendingRootUpdates.empty() && "dangling root updates"); 2208 2209 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 2210 // Check that the root was either replaced or updated in place. 2211 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites); 2212 auto replacedRoot = [&] { 2213 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op); 2214 }; 2215 auto updatedRootInPlace = [&] { 2216 return hasRewrite<ModifyOperationRewrite>(newRewrites, op); 2217 }; 2218 if (!replacedRoot() && !updatedRootInPlace()) 2219 llvm::report_fatal_error("expected pattern to replace the root operation"); 2220 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 2221 2222 // Legalize each of the actions registered during application. 
2223 RewriterState newState = impl.getCurrentState(); 2224 if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState, 2225 newState)) || 2226 failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) || 2227 failed(legalizePatternCreatedOperations(rewriter, impl, curState, 2228 newState))) { 2229 return failure(); 2230 } 2231 2232 LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully")); 2233 return success(); 2234 } 2235 2236 LogicalResult OperationLegalizer::legalizePatternBlockRewrites( 2237 Operation *op, ConversionPatternRewriter &rewriter, 2238 ConversionPatternRewriterImpl &impl, RewriterState &state, 2239 RewriterState &newState) { 2240 SmallPtrSet<Operation *, 16> operationsToIgnore; 2241 2242 // If the pattern moved or created any blocks, make sure the types of block 2243 // arguments get legalized. 2244 for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { 2245 BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get()); 2246 if (!rewrite) 2247 continue; 2248 Block *block = rewrite->getBlock(); 2249 if (isa<BlockTypeConversionRewrite, EraseBlockRewrite, 2250 ReplaceBlockArgRewrite>(rewrite)) 2251 continue; 2252 // Only check blocks outside of the current operation. 2253 Operation *parentOp = block->getParentOp(); 2254 if (!parentOp || parentOp == op || block->getNumArguments() == 0) 2255 continue; 2256 2257 // If the region of the block has a type converter, try to convert the block 2258 // directly. 
2259 if (auto *converter = impl.regionToConverter.lookup(block->getParent())) { 2260 std::optional<TypeConverter::SignatureConversion> conversion = 2261 converter->convertBlockSignature(block); 2262 if (!conversion) { 2263 LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " 2264 "block")); 2265 return failure(); 2266 } 2267 impl.applySignatureConversion(rewriter, block, converter, *conversion); 2268 continue; 2269 } 2270 2271 // Otherwise, check that this operation isn't one generated by this pattern. 2272 // This is because we will attempt to legalize the parent operation, and 2273 // blocks in regions created by this pattern will already be legalized later 2274 // on. If we haven't built the set yet, build it now. 2275 if (operationsToIgnore.empty()) { 2276 for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e; 2277 ++i) { 2278 auto *createOp = 2279 dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get()); 2280 if (!createOp) 2281 continue; 2282 operationsToIgnore.insert(createOp->getOperation()); 2283 } 2284 } 2285 2286 // If this operation should be considered for re-legalization, try it. 
2287 if (operationsToIgnore.insert(parentOp).second && 2288 failed(legalize(parentOp, rewriter))) { 2289 LLVM_DEBUG(logFailure(impl.logger, 2290 "operation '{0}'({1}) became illegal after rewrite", 2291 parentOp->getName(), parentOp)); 2292 return failure(); 2293 } 2294 } 2295 return success(); 2296 } 2297 2298 LogicalResult OperationLegalizer::legalizePatternCreatedOperations( 2299 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, 2300 RewriterState &state, RewriterState &newState) { 2301 for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { 2302 auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get()); 2303 if (!createOp) 2304 continue; 2305 Operation *op = createOp->getOperation(); 2306 if (failed(legalize(op, rewriter))) { 2307 LLVM_DEBUG(logFailure(impl.logger, 2308 "failed to legalize generated operation '{0}'({1})", 2309 op->getName(), op)); 2310 return failure(); 2311 } 2312 } 2313 return success(); 2314 } 2315 2316 LogicalResult OperationLegalizer::legalizePatternRootUpdates( 2317 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, 2318 RewriterState &state, RewriterState &newState) { 2319 for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { 2320 auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get()); 2321 if (!rewrite) 2322 continue; 2323 Operation *op = rewrite->getOperation(); 2324 if (failed(legalize(op, rewriter))) { 2325 LLVM_DEBUG(logFailure( 2326 impl.logger, "failed to legalize operation updated in-place '{0}'", 2327 op->getName())); 2328 return failure(); 2329 } 2330 } 2331 return success(); 2332 } 2333 2334 //===----------------------------------------------------------------------===// 2335 // Cost Model 2336 2337 void OperationLegalizer::buildLegalizationGraph( 2338 LegalizationPatterns &anyOpLegalizerPatterns, 2339 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { 2340 // A mapping between an operation and 
a set of operations that can be used to 2341 // generate it. 2342 DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps; 2343 // A mapping between an operation and any currently invalid patterns it has. 2344 DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns; 2345 // A worklist of patterns to consider for legality. 2346 SetVector<const Pattern *> patternWorklist; 2347 2348 // Build the mapping from operations to the parent ops that may generate them. 2349 applicator.walkAllPatterns([&](const Pattern &pattern) { 2350 std::optional<OperationName> root = pattern.getRootKind(); 2351 2352 // If the pattern has no specific root, we can't analyze the relationship 2353 // between the root op and generated operations. Given that, add all such 2354 // patterns to the legalization set. 2355 if (!root) { 2356 anyOpLegalizerPatterns.push_back(&pattern); 2357 return; 2358 } 2359 2360 // Skip operations that are always known to be legal. 2361 if (target.getOpAction(*root) == LegalizationAction::Legal) 2362 return; 2363 2364 // Add this pattern to the invalid set for the root op and record this root 2365 // as a parent for any generated operations. 2366 invalidPatterns[*root].insert(&pattern); 2367 for (auto op : pattern.getGeneratedOps()) 2368 parentOps[op].insert(*root); 2369 2370 // Add this pattern to the worklist. 2371 patternWorklist.insert(&pattern); 2372 }); 2373 2374 // If there are any patterns that don't have a specific root kind, we can't 2375 // make direct assumptions about what operations will never be legalized. 2376 // Note: Technically we could, but it would require an analysis that may 2377 // recurse into itself. It would be better to perform this kind of filtering 2378 // at a higher level than here anyways. 
2379 if (!anyOpLegalizerPatterns.empty()) { 2380 for (const Pattern *pattern : patternWorklist) 2381 legalizerPatterns[*pattern->getRootKind()].push_back(pattern); 2382 return; 2383 } 2384 2385 while (!patternWorklist.empty()) { 2386 auto *pattern = patternWorklist.pop_back_val(); 2387 2388 // Check to see if any of the generated operations are invalid. 2389 if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) { 2390 std::optional<LegalizationAction> action = target.getOpAction(op); 2391 return !legalizerPatterns.count(op) && 2392 (!action || action == LegalizationAction::Illegal); 2393 })) 2394 continue; 2395 2396 // Otherwise, if all of the generated operation are valid, this op is now 2397 // legal so add all of the child patterns to the worklist. 2398 legalizerPatterns[*pattern->getRootKind()].push_back(pattern); 2399 invalidPatterns[*pattern->getRootKind()].erase(pattern); 2400 2401 // Add any invalid patterns of the parent operations to see if they have now 2402 // become legal. 2403 for (auto op : parentOps[*pattern->getRootKind()]) 2404 patternWorklist.set_union(invalidPatterns[op]); 2405 } 2406 } 2407 2408 void OperationLegalizer::computeLegalizationGraphBenefit( 2409 LegalizationPatterns &anyOpLegalizerPatterns, 2410 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { 2411 // The smallest pattern depth, when legalizing an operation. 2412 DenseMap<OperationName, unsigned> minOpPatternDepth; 2413 2414 // For each operation that is transitively legal, compute a cost for it. 2415 for (auto &opIt : legalizerPatterns) 2416 if (!minOpPatternDepth.count(opIt.first)) 2417 computeOpLegalizationDepth(opIt.first, minOpPatternDepth, 2418 legalizerPatterns); 2419 2420 // Apply the cost model to the patterns that can match any operation. Those 2421 // with a specific operation type are already resolved when computing the op 2422 // legalization depth. 
2423 if (!anyOpLegalizerPatterns.empty()) 2424 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth, 2425 legalizerPatterns); 2426 2427 // Apply a cost model to the pattern applicator. We order patterns first by 2428 // depth then benefit. `legalizerPatterns` contains per-op patterns by 2429 // decreasing benefit. 2430 applicator.applyCostModel([&](const Pattern &pattern) { 2431 ArrayRef<const Pattern *> orderedPatternList; 2432 if (std::optional<OperationName> rootName = pattern.getRootKind()) 2433 orderedPatternList = legalizerPatterns[*rootName]; 2434 else 2435 orderedPatternList = anyOpLegalizerPatterns; 2436 2437 // If the pattern is not found, then it was removed and cannot be matched. 2438 auto *it = llvm::find(orderedPatternList, &pattern); 2439 if (it == orderedPatternList.end()) 2440 return PatternBenefit::impossibleToMatch(); 2441 2442 // Patterns found earlier in the list have higher benefit. 2443 return PatternBenefit(std::distance(it, orderedPatternList.end())); 2444 }); 2445 } 2446 2447 unsigned OperationLegalizer::computeOpLegalizationDepth( 2448 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth, 2449 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { 2450 // Check for existing depth. 2451 auto depthIt = minOpPatternDepth.find(op); 2452 if (depthIt != minOpPatternDepth.end()) 2453 return depthIt->second; 2454 2455 // If a mapping for this operation does not exist, then this operation 2456 // is always legal. Return 0 as the depth for a directly legal operation. 2457 auto opPatternsIt = legalizerPatterns.find(op); 2458 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty()) 2459 return 0u; 2460 2461 // Record this initial depth in case we encounter this op again when 2462 // recursively computing the depth. 
2463 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max()); 2464 2465 // Apply the cost model to the operation patterns, and update the minimum 2466 // depth. 2467 unsigned minDepth = applyCostModelToPatterns( 2468 opPatternsIt->second, minOpPatternDepth, legalizerPatterns); 2469 minOpPatternDepth[op] = minDepth; 2470 return minDepth; 2471 } 2472 2473 unsigned OperationLegalizer::applyCostModelToPatterns( 2474 LegalizationPatterns &patterns, 2475 DenseMap<OperationName, unsigned> &minOpPatternDepth, 2476 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { 2477 unsigned minDepth = std::numeric_limits<unsigned>::max(); 2478 2479 // Compute the depth for each pattern within the set. 2480 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth; 2481 patternsByDepth.reserve(patterns.size()); 2482 for (const Pattern *pattern : patterns) { 2483 unsigned depth = 1; 2484 for (auto generatedOp : pattern->getGeneratedOps()) { 2485 unsigned generatedOpDepth = computeOpLegalizationDepth( 2486 generatedOp, minOpPatternDepth, legalizerPatterns); 2487 depth = std::max(depth, generatedOpDepth + 1); 2488 } 2489 patternsByDepth.emplace_back(pattern, depth); 2490 2491 // Update the minimum depth of the pattern list. 2492 minDepth = std::min(minDepth, depth); 2493 } 2494 2495 // If the operation only has one legalization pattern, there is no need to 2496 // sort them. 2497 if (patternsByDepth.size() == 1) 2498 return minDepth; 2499 2500 // Sort the patterns by those likely to be the most beneficial. 2501 std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(), 2502 [](const std::pair<const Pattern *, unsigned> &lhs, 2503 const std::pair<const Pattern *, unsigned> &rhs) { 2504 // First sort by the smaller pattern legalization 2505 // depth. 2506 if (lhs.second != rhs.second) 2507 return lhs.second < rhs.second; 2508 2509 // Then sort by the larger pattern benefit. 
2510 auto lhsBenefit = lhs.first->getBenefit(); 2511 auto rhsBenefit = rhs.first->getBenefit(); 2512 return lhsBenefit > rhsBenefit; 2513 }); 2514 2515 // Update the legalization pattern to use the new sorted list. 2516 patterns.clear(); 2517 for (auto &patternIt : patternsByDepth) 2518 patterns.push_back(patternIt.first); 2519 return minDepth; 2520 } 2521 2522 //===----------------------------------------------------------------------===// 2523 // OperationConverter 2524 //===----------------------------------------------------------------------===// 2525 namespace { 2526 enum OpConversionMode { 2527 /// In this mode, the conversion will ignore failed conversions to allow 2528 /// illegal operations to co-exist in the IR. 2529 Partial, 2530 2531 /// In this mode, all operations must be legal for the given target for the 2532 /// conversion to succeed. 2533 Full, 2534 2535 /// In this mode, operations are analyzed for legality. No actual rewrites are 2536 /// applied to the operations on success. 2537 Analysis, 2538 }; 2539 } // namespace 2540 2541 namespace mlir { 2542 // This class converts operations to a given conversion target via a set of 2543 // rewrite patterns. The conversion behaves differently depending on the 2544 // conversion mode. 2545 struct OperationConverter { 2546 explicit OperationConverter(const ConversionTarget &target, 2547 const FrozenRewritePatternSet &patterns, 2548 const ConversionConfig &config, 2549 OpConversionMode mode) 2550 : config(config), opLegalizer(target, patterns, this->config), 2551 mode(mode) {} 2552 2553 /// Converts the given operations to the conversion target. 2554 LogicalResult convertOperations(ArrayRef<Operation *> ops); 2555 2556 private: 2557 /// Converts an operation with the given rewriter. 2558 LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op); 2559 2560 /// Dialect conversion configuration. 2561 ConversionConfig config; 2562 2563 /// The legalizer to use when converting operations. 
2564 OperationLegalizer opLegalizer; 2565 2566 /// The conversion mode to use when legalizing operations. 2567 OpConversionMode mode; 2568 }; 2569 } // namespace mlir 2570 2571 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, 2572 Operation *op) { 2573 // Legalize the given operation. 2574 if (failed(opLegalizer.legalize(op, rewriter))) { 2575 // Handle the case of a failed conversion for each of the different modes. 2576 // Full conversions expect all operations to be converted. 2577 if (mode == OpConversionMode::Full) 2578 return op->emitError() 2579 << "failed to legalize operation '" << op->getName() << "'"; 2580 // Partial conversions allow conversions to fail iff the operation was not 2581 // explicitly marked as illegal. If the user provided a `unlegalizedOps` 2582 // set, non-legalizable ops are added to that set. 2583 if (mode == OpConversionMode::Partial) { 2584 if (opLegalizer.isIllegal(op)) 2585 return op->emitError() 2586 << "failed to legalize operation '" << op->getName() 2587 << "' that was explicitly marked illegal"; 2588 if (config.unlegalizedOps) 2589 config.unlegalizedOps->insert(op); 2590 } 2591 } else if (mode == OpConversionMode::Analysis) { 2592 // Analysis conversions don't fail if any operations fail to legalize, 2593 // they are only interested in the operations that were successfully 2594 // legalized. 2595 if (config.legalizableOps) 2596 config.legalizableOps->insert(op); 2597 } 2598 return success(); 2599 } 2600 2601 static LogicalResult 2602 legalizeUnresolvedMaterialization(RewriterBase &rewriter, 2603 UnresolvedMaterializationRewrite *rewrite) { 2604 UnrealizedConversionCastOp op = rewrite->getOperation(); 2605 assert(!op.use_empty() && 2606 "expected that dead materializations have already been DCE'd"); 2607 Operation::operand_range inputOperands = op.getOperands(); 2608 2609 // Try to materialize the conversion. 
2610 if (const TypeConverter *converter = rewrite->getConverter()) { 2611 rewriter.setInsertionPoint(op); 2612 SmallVector<Value> newMaterialization; 2613 switch (rewrite->getMaterializationKind()) { 2614 case MaterializationKind::Target: 2615 newMaterialization = converter->materializeTargetConversion( 2616 rewriter, op->getLoc(), op.getResultTypes(), inputOperands, 2617 rewrite->getOriginalType()); 2618 break; 2619 case MaterializationKind::Source: 2620 assert(op->getNumResults() == 1 && "expected single result"); 2621 Value sourceMat = converter->materializeSourceConversion( 2622 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands); 2623 if (sourceMat) 2624 newMaterialization.push_back(sourceMat); 2625 break; 2626 } 2627 if (!newMaterialization.empty()) { 2628 #ifndef NDEBUG 2629 ValueRange newMaterializationRange(newMaterialization); 2630 assert(TypeRange(newMaterializationRange) == op.getResultTypes() && 2631 "materialization callback produced value of incorrect type"); 2632 #endif // NDEBUG 2633 rewriter.replaceOp(op, newMaterialization); 2634 return success(); 2635 } 2636 } 2637 2638 InFlightDiagnostic diag = op->emitError() 2639 << "failed to legalize unresolved materialization " 2640 "from (" 2641 << inputOperands.getTypes() << ") to (" 2642 << op.getResultTypes() 2643 << ") that remained live after conversion"; 2644 diag.attachNote(op->getUsers().begin()->getLoc()) 2645 << "see existing live user here: " << *op->getUsers().begin(); 2646 return failure(); 2647 } 2648 2649 LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { 2650 if (ops.empty()) 2651 return success(); 2652 const ConversionTarget &target = opLegalizer.getTarget(); 2653 2654 // Compute the set of operations and blocks to convert. 
2655 SmallVector<Operation *> toConvert; 2656 for (auto *op : ops) { 2657 op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>( 2658 [&](Operation *op) { 2659 toConvert.push_back(op); 2660 // Don't check this operation's children for conversion if the 2661 // operation is recursively legal. 2662 auto legalityInfo = target.isLegal(op); 2663 if (legalityInfo && legalityInfo->isRecursivelyLegal) 2664 return WalkResult::skip(); 2665 return WalkResult::advance(); 2666 }); 2667 } 2668 2669 // Convert each operation and discard rewrites on failure. 2670 ConversionPatternRewriter rewriter(ops.front()->getContext(), config); 2671 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); 2672 2673 for (auto *op : toConvert) 2674 if (failed(convert(rewriter, op))) 2675 return rewriterImpl.undoRewrites(), failure(); 2676 2677 // After a successful conversion, apply rewrites. 2678 rewriterImpl.applyRewrites(); 2679 2680 // Gather all unresolved materializations. 2681 SmallVector<UnrealizedConversionCastOp> allCastOps; 2682 const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *> 2683 &materializations = rewriterImpl.unresolvedMaterializations; 2684 for (auto it : materializations) { 2685 if (rewriterImpl.eraseRewriter.wasErased(it.first)) 2686 continue; 2687 allCastOps.push_back(it.first); 2688 } 2689 2690 // Reconcile all UnrealizedConversionCastOps that were inserted by the 2691 // dialect conversion frameworks. (Not the one that were inserted by 2692 // patterns.) 2693 SmallVector<UnrealizedConversionCastOp> remainingCastOps; 2694 reconcileUnrealizedCasts(allCastOps, &remainingCastOps); 2695 2696 // Try to legalize all unresolved materializations. 
2697 if (config.buildMaterializations) { 2698 IRRewriter rewriter(rewriterImpl.context, config.listener); 2699 for (UnrealizedConversionCastOp castOp : remainingCastOps) { 2700 auto it = materializations.find(castOp); 2701 assert(it != materializations.end() && "inconsistent state"); 2702 if (failed(legalizeUnresolvedMaterialization(rewriter, it->second))) 2703 return failure(); 2704 } 2705 } 2706 2707 return success(); 2708 } 2709 2710 //===----------------------------------------------------------------------===// 2711 // Reconcile Unrealized Casts 2712 //===----------------------------------------------------------------------===// 2713 2714 void mlir::reconcileUnrealizedCasts( 2715 ArrayRef<UnrealizedConversionCastOp> castOps, 2716 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) { 2717 SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(), 2718 castOps.end()); 2719 // This set is maintained only if `remainingCastOps` is provided. 2720 DenseSet<Operation *> erasedOps; 2721 2722 // Helper function that adds all operands to the worklist that are an 2723 // unrealized_conversion_cast op result. 2724 auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) { 2725 for (Value v : castOp.getInputs()) 2726 if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>()) 2727 worklist.insert(inputCastOp); 2728 }; 2729 2730 // Helper function that return the unrealized_conversion_cast op that 2731 // defines all inputs of the given op (in the same order). Return "nullptr" 2732 // if there is no such op. 
2733 auto getInputCast = 2734 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp { 2735 if (castOp.getInputs().empty()) 2736 return {}; 2737 auto inputCastOp = 2738 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>(); 2739 if (!inputCastOp) 2740 return {}; 2741 if (inputCastOp.getOutputs() != castOp.getInputs()) 2742 return {}; 2743 return inputCastOp; 2744 }; 2745 2746 // Process ops in the worklist bottom-to-top. 2747 while (!worklist.empty()) { 2748 UnrealizedConversionCastOp castOp = worklist.pop_back_val(); 2749 if (castOp->use_empty()) { 2750 // DCE: If the op has no users, erase it. Add the operands to the 2751 // worklist to find additional DCE opportunities. 2752 enqueueOperands(castOp); 2753 if (remainingCastOps) 2754 erasedOps.insert(castOp.getOperation()); 2755 castOp->erase(); 2756 continue; 2757 } 2758 2759 // Traverse the chain of input cast ops to see if an op with the same 2760 // input types can be found. 2761 UnrealizedConversionCastOp nextCast = castOp; 2762 while (nextCast) { 2763 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) { 2764 // Found a cast where the input types match the output types of the 2765 // matched op. We can directly use those inputs and the matched op can 2766 // be removed. 
2767 enqueueOperands(castOp); 2768 castOp.replaceAllUsesWith(nextCast.getInputs()); 2769 if (remainingCastOps) 2770 erasedOps.insert(castOp.getOperation()); 2771 castOp->erase(); 2772 break; 2773 } 2774 nextCast = getInputCast(nextCast); 2775 } 2776 } 2777 2778 if (remainingCastOps) 2779 for (UnrealizedConversionCastOp op : castOps) 2780 if (!erasedOps.contains(op.getOperation())) 2781 remainingCastOps->push_back(op); 2782 } 2783 2784 //===----------------------------------------------------------------------===// 2785 // Type Conversion 2786 //===----------------------------------------------------------------------===// 2787 2788 void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo, 2789 ArrayRef<Type> types) { 2790 assert(!types.empty() && "expected valid types"); 2791 remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size()); 2792 addInputs(types); 2793 } 2794 2795 void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) { 2796 assert(!types.empty() && 2797 "1->0 type remappings don't need to be added explicitly"); 2798 argTypes.append(types.begin(), types.end()); 2799 } 2800 2801 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, 2802 unsigned newInputNo, 2803 unsigned newInputCount) { 2804 assert(!remappedInputs[origInputNo] && "input has already been remapped"); 2805 assert(newInputCount != 0 && "expected valid input count"); 2806 remappedInputs[origInputNo] = 2807 InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr}; 2808 } 2809 2810 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, 2811 Value replacementValue) { 2812 assert(!remappedInputs[origInputNo] && "input has already been remapped"); 2813 remappedInputs[origInputNo] = 2814 InputMapping{origInputNo, /*size=*/0, replacementValue}; 2815 } 2816 2817 LogicalResult TypeConverter::convertType(Type t, 2818 SmallVectorImpl<Type> &results) const { 2819 assert(t && "expected non-null type"); 2820 2821 
{ 2822 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex, 2823 std::defer_lock); 2824 if (t.getContext()->isMultithreadingEnabled()) 2825 cacheReadLock.lock(); 2826 auto existingIt = cachedDirectConversions.find(t); 2827 if (existingIt != cachedDirectConversions.end()) { 2828 if (existingIt->second) 2829 results.push_back(existingIt->second); 2830 return success(existingIt->second != nullptr); 2831 } 2832 auto multiIt = cachedMultiConversions.find(t); 2833 if (multiIt != cachedMultiConversions.end()) { 2834 results.append(multiIt->second.begin(), multiIt->second.end()); 2835 return success(); 2836 } 2837 } 2838 // Walk the added converters in reverse order to apply the most recently 2839 // registered first. 2840 size_t currentCount = results.size(); 2841 2842 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex, 2843 std::defer_lock); 2844 2845 for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) { 2846 if (std::optional<LogicalResult> result = converter(t, results)) { 2847 if (t.getContext()->isMultithreadingEnabled()) 2848 cacheWriteLock.lock(); 2849 if (!succeeded(*result)) { 2850 cachedDirectConversions.try_emplace(t, nullptr); 2851 return failure(); 2852 } 2853 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount); 2854 if (newTypes.size() == 1) 2855 cachedDirectConversions.try_emplace(t, newTypes.front()); 2856 else 2857 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes)); 2858 return success(); 2859 } 2860 } 2861 return failure(); 2862 } 2863 2864 Type TypeConverter::convertType(Type t) const { 2865 // Use the multi-type result version to convert the type. 2866 SmallVector<Type, 1> results; 2867 if (failed(convertType(t, results))) 2868 return nullptr; 2869 2870 // Check to ensure that only one type was produced. 2871 return results.size() == 1 ? 
results.front() : nullptr; 2872 } 2873 2874 LogicalResult 2875 TypeConverter::convertTypes(TypeRange types, 2876 SmallVectorImpl<Type> &results) const { 2877 for (Type type : types) 2878 if (failed(convertType(type, results))) 2879 return failure(); 2880 return success(); 2881 } 2882 2883 bool TypeConverter::isLegal(Type type) const { 2884 return convertType(type) == type; 2885 } 2886 bool TypeConverter::isLegal(Operation *op) const { 2887 return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes()); 2888 } 2889 2890 bool TypeConverter::isLegal(Region *region) const { 2891 return llvm::all_of(*region, [this](Block &block) { 2892 return isLegal(block.getArgumentTypes()); 2893 }); 2894 } 2895 2896 bool TypeConverter::isSignatureLegal(FunctionType ty) const { 2897 return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults())); 2898 } 2899 2900 LogicalResult 2901 TypeConverter::convertSignatureArg(unsigned inputNo, Type type, 2902 SignatureConversion &result) const { 2903 // Try to convert the given input type. 2904 SmallVector<Type, 1> convertedTypes; 2905 if (failed(convertType(type, convertedTypes))) 2906 return failure(); 2907 2908 // If this argument is being dropped, there is nothing left to do. 2909 if (convertedTypes.empty()) 2910 return success(); 2911 2912 // Otherwise, add the new inputs. 
2913 result.addInputs(inputNo, convertedTypes); 2914 return success(); 2915 } 2916 LogicalResult 2917 TypeConverter::convertSignatureArgs(TypeRange types, 2918 SignatureConversion &result, 2919 unsigned origInputOffset) const { 2920 for (unsigned i = 0, e = types.size(); i != e; ++i) 2921 if (failed(convertSignatureArg(origInputOffset + i, types[i], result))) 2922 return failure(); 2923 return success(); 2924 } 2925 2926 Value TypeConverter::materializeArgumentConversion(OpBuilder &builder, 2927 Location loc, 2928 Type resultType, 2929 ValueRange inputs) const { 2930 for (const MaterializationCallbackFn &fn : 2931 llvm::reverse(argumentMaterializations)) 2932 if (Value result = fn(builder, resultType, inputs, loc)) 2933 return result; 2934 return nullptr; 2935 } 2936 2937 Value TypeConverter::materializeSourceConversion(OpBuilder &builder, 2938 Location loc, Type resultType, 2939 ValueRange inputs) const { 2940 for (const MaterializationCallbackFn &fn : 2941 llvm::reverse(sourceMaterializations)) 2942 if (Value result = fn(builder, resultType, inputs, loc)) 2943 return result; 2944 return nullptr; 2945 } 2946 2947 Value TypeConverter::materializeTargetConversion(OpBuilder &builder, 2948 Location loc, Type resultType, 2949 ValueRange inputs, 2950 Type originalType) const { 2951 SmallVector<Value> result = materializeTargetConversion( 2952 builder, loc, TypeRange(resultType), inputs, originalType); 2953 if (result.empty()) 2954 return nullptr; 2955 assert(result.size() == 1 && "expected single result"); 2956 return result.front(); 2957 } 2958 2959 SmallVector<Value> TypeConverter::materializeTargetConversion( 2960 OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs, 2961 Type originalType) const { 2962 for (const TargetMaterializationCallbackFn &fn : 2963 llvm::reverse(targetMaterializations)) { 2964 SmallVector<Value> result = 2965 fn(builder, resultTypes, inputs, loc, originalType); 2966 if (result.empty()) 2967 continue; 2968 
assert(TypeRange(ValueRange(result)) == resultTypes && 2969 "callback produced incorrect number of values or values with " 2970 "incorrect types"); 2971 return result; 2972 } 2973 return {}; 2974 } 2975 2976 std::optional<TypeConverter::SignatureConversion> 2977 TypeConverter::convertBlockSignature(Block *block) const { 2978 SignatureConversion conversion(block->getNumArguments()); 2979 if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion))) 2980 return std::nullopt; 2981 return conversion; 2982 } 2983 2984 //===----------------------------------------------------------------------===// 2985 // Type attribute conversion 2986 //===----------------------------------------------------------------------===// 2987 TypeConverter::AttributeConversionResult 2988 TypeConverter::AttributeConversionResult::result(Attribute attr) { 2989 return AttributeConversionResult(attr, resultTag); 2990 } 2991 2992 TypeConverter::AttributeConversionResult 2993 TypeConverter::AttributeConversionResult::na() { 2994 return AttributeConversionResult(nullptr, naTag); 2995 } 2996 2997 TypeConverter::AttributeConversionResult 2998 TypeConverter::AttributeConversionResult::abort() { 2999 return AttributeConversionResult(nullptr, abortTag); 3000 } 3001 3002 bool TypeConverter::AttributeConversionResult::hasResult() const { 3003 return impl.getInt() == resultTag; 3004 } 3005 3006 bool TypeConverter::AttributeConversionResult::isNa() const { 3007 return impl.getInt() == naTag; 3008 } 3009 3010 bool TypeConverter::AttributeConversionResult::isAbort() const { 3011 return impl.getInt() == abortTag; 3012 } 3013 3014 Attribute TypeConverter::AttributeConversionResult::getResult() const { 3015 assert(hasResult() && "Cannot get result from N/A or abort"); 3016 return impl.getPointer(); 3017 } 3018 3019 std::optional<Attribute> 3020 TypeConverter::convertTypeAttribute(Type type, Attribute attr) const { 3021 for (const TypeAttributeConversionCallbackFn &fn : 3022 
llvm::reverse(typeAttributeConversions)) { 3023 AttributeConversionResult res = fn(type, attr); 3024 if (res.hasResult()) 3025 return res.getResult(); 3026 if (res.isAbort()) 3027 return std::nullopt; 3028 } 3029 return std::nullopt; 3030 } 3031 3032 //===----------------------------------------------------------------------===// 3033 // FunctionOpInterfaceSignatureConversion 3034 //===----------------------------------------------------------------------===// 3035 3036 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, 3037 const TypeConverter &typeConverter, 3038 ConversionPatternRewriter &rewriter) { 3039 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType()); 3040 if (!type) 3041 return failure(); 3042 3043 // Convert the original function types. 3044 TypeConverter::SignatureConversion result(type.getNumInputs()); 3045 SmallVector<Type, 1> newResults; 3046 if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) || 3047 failed(typeConverter.convertTypes(type.getResults(), newResults)) || 3048 failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(), 3049 typeConverter, &result))) 3050 return failure(); 3051 3052 // Update the function signature in-place. 3053 auto newType = FunctionType::get(rewriter.getContext(), 3054 result.getConvertedTypes(), newResults); 3055 3056 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); }); 3057 3058 return success(); 3059 } 3060 3061 /// Create a default conversion pattern that rewrites the type signature of a 3062 /// FunctionOpInterface op. This only supports ops which use FunctionType to 3063 /// represent their type. 
3064 namespace { 3065 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern { 3066 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName, 3067 MLIRContext *ctx, 3068 const TypeConverter &converter) 3069 : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {} 3070 3071 LogicalResult 3072 matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/, 3073 ConversionPatternRewriter &rewriter) const override { 3074 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op); 3075 return convertFuncOpTypes(funcOp, *typeConverter, rewriter); 3076 } 3077 }; 3078 3079 struct AnyFunctionOpInterfaceSignatureConversion 3080 : public OpInterfaceConversionPattern<FunctionOpInterface> { 3081 using OpInterfaceConversionPattern::OpInterfaceConversionPattern; 3082 3083 LogicalResult 3084 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/, 3085 ConversionPatternRewriter &rewriter) const override { 3086 return convertFuncOpTypes(funcOp, *typeConverter, rewriter); 3087 } 3088 }; 3089 } // namespace 3090 3091 FailureOr<Operation *> 3092 mlir::convertOpResultTypes(Operation *op, ValueRange operands, 3093 const TypeConverter &converter, 3094 ConversionPatternRewriter &rewriter) { 3095 assert(op && "Invalid op"); 3096 Location loc = op->getLoc(); 3097 if (converter.isLegal(op)) 3098 return rewriter.notifyMatchFailure(loc, "op already legal"); 3099 3100 OperationState newOp(loc, op->getName()); 3101 newOp.addOperands(operands); 3102 3103 SmallVector<Type> newResultTypes; 3104 if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes))) 3105 return rewriter.notifyMatchFailure(loc, "couldn't convert return types"); 3106 3107 newOp.addTypes(newResultTypes); 3108 newOp.addAttributes(op->getAttrs()); 3109 return rewriter.create(newOp); 3110 } 3111 3112 void mlir::populateFunctionOpInterfaceTypeConversionPattern( 3113 StringRef functionLikeOpName, RewritePatternSet &patterns, 3114 const TypeConverter &converter) { 
3115 patterns.add<FunctionOpInterfaceSignatureConversion>( 3116 functionLikeOpName, patterns.getContext(), converter); 3117 } 3118 3119 void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern( 3120 RewritePatternSet &patterns, const TypeConverter &converter) { 3121 patterns.add<AnyFunctionOpInterfaceSignatureConversion>( 3122 converter, patterns.getContext()); 3123 } 3124 3125 //===----------------------------------------------------------------------===// 3126 // ConversionTarget 3127 //===----------------------------------------------------------------------===// 3128 3129 void ConversionTarget::setOpAction(OperationName op, 3130 LegalizationAction action) { 3131 legalOperations[op].action = action; 3132 } 3133 3134 void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames, 3135 LegalizationAction action) { 3136 for (StringRef dialect : dialectNames) 3137 legalDialects[dialect] = action; 3138 } 3139 3140 auto ConversionTarget::getOpAction(OperationName op) const 3141 -> std::optional<LegalizationAction> { 3142 std::optional<LegalizationInfo> info = getOpInfo(op); 3143 return info ? info->action : std::optional<LegalizationAction>(); 3144 } 3145 3146 auto ConversionTarget::isLegal(Operation *op) const 3147 -> std::optional<LegalOpDetails> { 3148 std::optional<LegalizationInfo> info = getOpInfo(op->getName()); 3149 if (!info) 3150 return std::nullopt; 3151 3152 // Returns true if this operation instance is known to be legal. 3153 auto isOpLegal = [&] { 3154 // Handle dynamic legality either with the provided legality function. 3155 if (info->action == LegalizationAction::Dynamic) { 3156 std::optional<bool> result = info->legalityFn(op); 3157 if (result) 3158 return *result; 3159 } 3160 3161 // Otherwise, the operation is only legal if it was marked 'Legal'. 
3162 return info->action == LegalizationAction::Legal; 3163 }; 3164 if (!isOpLegal()) 3165 return std::nullopt; 3166 3167 // This operation is legal, compute any additional legality information. 3168 LegalOpDetails legalityDetails; 3169 if (info->isRecursivelyLegal) { 3170 auto legalityFnIt = opRecursiveLegalityFns.find(op->getName()); 3171 if (legalityFnIt != opRecursiveLegalityFns.end()) { 3172 legalityDetails.isRecursivelyLegal = 3173 legalityFnIt->second(op).value_or(true); 3174 } else { 3175 legalityDetails.isRecursivelyLegal = true; 3176 } 3177 } 3178 return legalityDetails; 3179 } 3180 3181 bool ConversionTarget::isIllegal(Operation *op) const { 3182 std::optional<LegalizationInfo> info = getOpInfo(op->getName()); 3183 if (!info) 3184 return false; 3185 3186 if (info->action == LegalizationAction::Dynamic) { 3187 std::optional<bool> result = info->legalityFn(op); 3188 if (!result) 3189 return false; 3190 3191 return !(*result); 3192 } 3193 3194 return info->action == LegalizationAction::Illegal; 3195 } 3196 3197 static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks( 3198 ConversionTarget::DynamicLegalityCallbackFn oldCallback, 3199 ConversionTarget::DynamicLegalityCallbackFn newCallback) { 3200 if (!oldCallback) 3201 return newCallback; 3202 3203 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)]( 3204 Operation *op) -> std::optional<bool> { 3205 if (std::optional<bool> result = newCl(op)) 3206 return *result; 3207 3208 return oldCl(op); 3209 }; 3210 return chain; 3211 } 3212 3213 void ConversionTarget::setLegalityCallback( 3214 OperationName name, const DynamicLegalityCallbackFn &callback) { 3215 assert(callback && "expected valid legality callback"); 3216 auto *infoIt = legalOperations.find(name); 3217 assert(infoIt != legalOperations.end() && 3218 infoIt->second.action == LegalizationAction::Dynamic && 3219 "expected operation to already be marked as dynamically legal"); 3220 infoIt->second.legalityFn = 
3221 composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback); 3222 } 3223 3224 void ConversionTarget::markOpRecursivelyLegal( 3225 OperationName name, const DynamicLegalityCallbackFn &callback) { 3226 auto *infoIt = legalOperations.find(name); 3227 assert(infoIt != legalOperations.end() && 3228 infoIt->second.action != LegalizationAction::Illegal && 3229 "expected operation to already be marked as legal"); 3230 infoIt->second.isRecursivelyLegal = true; 3231 if (callback) 3232 opRecursiveLegalityFns[name] = composeLegalityCallbacks( 3233 std::move(opRecursiveLegalityFns[name]), callback); 3234 else 3235 opRecursiveLegalityFns.erase(name); 3236 } 3237 3238 void ConversionTarget::setLegalityCallback( 3239 ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) { 3240 assert(callback && "expected valid legality callback"); 3241 for (StringRef dialect : dialects) 3242 dialectLegalityFns[dialect] = composeLegalityCallbacks( 3243 std::move(dialectLegalityFns[dialect]), callback); 3244 } 3245 3246 void ConversionTarget::setLegalityCallback( 3247 const DynamicLegalityCallbackFn &callback) { 3248 assert(callback && "expected valid legality callback"); 3249 unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback); 3250 } 3251 3252 auto ConversionTarget::getOpInfo(OperationName op) const 3253 -> std::optional<LegalizationInfo> { 3254 // Check for info for this specific operation. 3255 const auto *it = legalOperations.find(op); 3256 if (it != legalOperations.end()) 3257 return it->second; 3258 // Check for info for the parent dialect. 
3259 auto dialectIt = legalDialects.find(op.getDialectNamespace()); 3260 if (dialectIt != legalDialects.end()) { 3261 DynamicLegalityCallbackFn callback; 3262 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace()); 3263 if (dialectFn != dialectLegalityFns.end()) 3264 callback = dialectFn->second; 3265 return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false, 3266 callback}; 3267 } 3268 // Otherwise, check if we mark unknown operations as dynamic. 3269 if (unknownLegalityFn) 3270 return LegalizationInfo{LegalizationAction::Dynamic, 3271 /*isRecursivelyLegal=*/false, unknownLegalityFn}; 3272 return std::nullopt; 3273 } 3274 3275 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH 3276 //===----------------------------------------------------------------------===// 3277 // PDL Configuration 3278 //===----------------------------------------------------------------------===// 3279 3280 void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) { 3281 auto &rewriterImpl = 3282 static_cast<ConversionPatternRewriter &>(rewriter).getImpl(); 3283 rewriterImpl.currentTypeConverter = getTypeConverter(); 3284 } 3285 3286 void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) { 3287 auto &rewriterImpl = 3288 static_cast<ConversionPatternRewriter &>(rewriter).getImpl(); 3289 rewriterImpl.currentTypeConverter = nullptr; 3290 } 3291 3292 /// Remap the given value using the rewriter and the type converter in the 3293 /// provided config. 
3294 static FailureOr<SmallVector<Value>> 3295 pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) { 3296 SmallVector<Value> mappedValues; 3297 if (failed(rewriter.getRemappedValues(values, mappedValues))) 3298 return failure(); 3299 return std::move(mappedValues); 3300 } 3301 3302 void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) { 3303 patterns.getPDLPatterns().registerRewriteFunction( 3304 "convertValue", 3305 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> { 3306 auto results = pdllConvertValues( 3307 static_cast<ConversionPatternRewriter &>(rewriter), value); 3308 if (failed(results)) 3309 return failure(); 3310 return results->front(); 3311 }); 3312 patterns.getPDLPatterns().registerRewriteFunction( 3313 "convertValues", [](PatternRewriter &rewriter, ValueRange values) { 3314 return pdllConvertValues( 3315 static_cast<ConversionPatternRewriter &>(rewriter), values); 3316 }); 3317 patterns.getPDLPatterns().registerRewriteFunction( 3318 "convertType", 3319 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> { 3320 auto &rewriterImpl = 3321 static_cast<ConversionPatternRewriter &>(rewriter).getImpl(); 3322 if (const TypeConverter *converter = 3323 rewriterImpl.currentTypeConverter) { 3324 if (Type newType = converter->convertType(type)) 3325 return newType; 3326 return failure(); 3327 } 3328 return type; 3329 }); 3330 patterns.getPDLPatterns().registerRewriteFunction( 3331 "convertTypes", 3332 [](PatternRewriter &rewriter, 3333 TypeRange types) -> FailureOr<SmallVector<Type>> { 3334 auto &rewriterImpl = 3335 static_cast<ConversionPatternRewriter &>(rewriter).getImpl(); 3336 const TypeConverter *converter = rewriterImpl.currentTypeConverter; 3337 if (!converter) 3338 return SmallVector<Type>(types); 3339 3340 SmallVector<Type> remappedTypes; 3341 if (failed(converter->convertTypes(types, remappedTypes))) 3342 return failure(); 3343 return std::move(remappedTypes); 3344 }); 3345 } 3346 #endif // 
MLIR_ENABLE_PDL_IN_PATTERNMATCH 3347 3348 //===----------------------------------------------------------------------===// 3349 // Op Conversion Entry Points 3350 //===----------------------------------------------------------------------===// 3351 3352 //===----------------------------------------------------------------------===// 3353 // Partial Conversion 3354 3355 LogicalResult mlir::applyPartialConversion( 3356 ArrayRef<Operation *> ops, const ConversionTarget &target, 3357 const FrozenRewritePatternSet &patterns, ConversionConfig config) { 3358 OperationConverter opConverter(target, patterns, config, 3359 OpConversionMode::Partial); 3360 return opConverter.convertOperations(ops); 3361 } 3362 LogicalResult 3363 mlir::applyPartialConversion(Operation *op, const ConversionTarget &target, 3364 const FrozenRewritePatternSet &patterns, 3365 ConversionConfig config) { 3366 return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config); 3367 } 3368 3369 //===----------------------------------------------------------------------===// 3370 // Full Conversion 3371 3372 LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops, 3373 const ConversionTarget &target, 3374 const FrozenRewritePatternSet &patterns, 3375 ConversionConfig config) { 3376 OperationConverter opConverter(target, patterns, config, 3377 OpConversionMode::Full); 3378 return opConverter.convertOperations(ops); 3379 } 3380 LogicalResult mlir::applyFullConversion(Operation *op, 3381 const ConversionTarget &target, 3382 const FrozenRewritePatternSet &patterns, 3383 ConversionConfig config) { 3384 return applyFullConversion(llvm::ArrayRef(op), target, patterns, config); 3385 } 3386 3387 //===----------------------------------------------------------------------===// 3388 // Analysis Conversion 3389 3390 /// Find a common IsolatedFromAbove ancestor of the given ops. 
If at least one 3391 /// op is a top-level module op (which is expected to be isolated from above), 3392 /// return that op. 3393 static Operation *findCommonAncestor(ArrayRef<Operation *> ops) { 3394 // Check if there is a top-level operation within `ops`. If so, return that 3395 // op. 3396 for (Operation *op : ops) { 3397 if (!op->getParentOp()) { 3398 #ifndef NDEBUG 3399 assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() && 3400 "expected top-level op to be isolated from above"); 3401 for (Operation *other : ops) 3402 assert(op->isAncestor(other) && 3403 "expected ops to have a common ancestor"); 3404 #endif // NDEBUG 3405 return op; 3406 } 3407 } 3408 3409 // No top-level op. Find a common ancestor. 3410 Operation *commonAncestor = 3411 ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(); 3412 for (Operation *op : ops.drop_front()) { 3413 while (!commonAncestor->isProperAncestor(op)) { 3414 commonAncestor = 3415 commonAncestor->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(); 3416 assert(commonAncestor && 3417 "expected to find a common isolated from above ancestor"); 3418 } 3419 } 3420 3421 return commonAncestor; 3422 } 3423 3424 LogicalResult mlir::applyAnalysisConversion( 3425 ArrayRef<Operation *> ops, ConversionTarget &target, 3426 const FrozenRewritePatternSet &patterns, ConversionConfig config) { 3427 #ifndef NDEBUG 3428 if (config.legalizableOps) 3429 assert(config.legalizableOps->empty() && "expected empty set"); 3430 #endif // NDEBUG 3431 3432 // Clone closted common ancestor that is isolated from above. 3433 Operation *commonAncestor = findCommonAncestor(ops); 3434 IRMapping mapping; 3435 Operation *clonedAncestor = commonAncestor->clone(mapping); 3436 // Compute inverse IR mapping. 3437 DenseMap<Operation *, Operation *> inverseOperationMap; 3438 for (auto &it : mapping.getOperationMap()) 3439 inverseOperationMap[it.second] = it.first; 3440 3441 // Convert the cloned operations. The original IR will remain unchanged. 
3442 SmallVector<Operation *> opsToConvert = llvm::map_to_vector( 3443 ops, [&](Operation *op) { return mapping.lookup(op); }); 3444 OperationConverter opConverter(target, patterns, config, 3445 OpConversionMode::Analysis); 3446 LogicalResult status = opConverter.convertOperations(opsToConvert); 3447 3448 // Remap `legalizableOps`, so that they point to the original ops and not the 3449 // cloned ops. 3450 if (config.legalizableOps) { 3451 DenseSet<Operation *> originalLegalizableOps; 3452 for (Operation *op : *config.legalizableOps) 3453 originalLegalizableOps.insert(inverseOperationMap[op]); 3454 *config.legalizableOps = std::move(originalLegalizableOps); 3455 } 3456 3457 // Erase the cloned IR. 3458 clonedAncestor->erase(); 3459 return status; 3460 } 3461 3462 LogicalResult 3463 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target, 3464 const FrozenRewritePatternSet &patterns, 3465 ConversionConfig config) { 3466 return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config); 3467 } 3468