1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestOps.h" 11 #include "TestTypes.h" 12 #include "mlir/Dialect/Arith/IR/Arith.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/BuiltinAttributes.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/Visitors.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "mlir/Transforms/FoldUtils.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 #include "mlir/Transforms/WalkPatternRewriteDriver.h" 24 #include "llvm/ADT/ScopeExit.h" 25 #include <cstdint> 26 27 using namespace mlir; 28 using namespace test; 29 30 // Native function for testing NativeCodeCall 31 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 32 return choice.getValue() ? input1 : input2; 33 } 34 35 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 36 rewriter.create<OpI>(loc, input); 37 } 38 39 static void handleNoResultOp(PatternRewriter &rewriter, 40 OpSymbolBindingNoResult op) { 41 // Turn the no result op to a one-result op. 42 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), 43 op.getOperand()); 44 } 45 46 static bool getFirstI32Result(Operation *op, Value &value) { 47 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 48 return false; 49 value = op->getResult(0); 50 return true; 51 } 52 53 static Value bindNativeCodeCallResult(Value value) { return value; } 54 55 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 56 Value input2) { 57 return SmallVector<Value, 2>({input2, input1}); 58 } 59 60 // Test that natives calls are only called once during rewrites. 61 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 62 // This let us check the number of times OpM_Test was called by inspecting 63 // the returned value in the MLIR output. 64 static int64_t opMIncreasingValue = 314159265; 65 static Attribute opMTest(PatternRewriter &rewriter, Value val) { 66 int64_t i = opMIncreasingValue++; 67 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 68 } 69 70 namespace { 71 #include "TestPatterns.inc" 72 } // namespace 73 74 //===----------------------------------------------------------------------===// 75 // Test Reduce Pattern Interface 76 //===----------------------------------------------------------------------===// 77 78 void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 79 populateWithGenerated(patterns); 80 } 81 82 //===----------------------------------------------------------------------===// 83 // Canonicalizer Driver. 84 //===----------------------------------------------------------------------===// 85 86 namespace { 87 struct FoldingPattern : public RewritePattern { 88 public: 89 FoldingPattern(MLIRContext *context) 90 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 91 /*benefit=*/1, context) {} 92 93 LogicalResult matchAndRewrite(Operation *op, 94 PatternRewriter &rewriter) const override { 95 // Exercise createOrFold API for a single-result operation that is folded 96 // upon construction. The operation being created has an in-place folder, 97 // and it should be still present in the output. Furthermore, the folder 98 // should not crash when attempting to recover the (unchanged) operation 99 // result. 100 Value result = rewriter.createOrFold<TestOpInPlaceFold>( 101 op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0)); 102 assert(result); 103 rewriter.replaceOp(op, result); 104 return success(); 105 } 106 }; 107 108 /// This pattern creates a foldable operation at the entry point of the block. 109 /// This tests the situation where the operation folder will need to replace an 110 /// operation with a previously created constant that does not initially 111 /// dominate the operation to replace. 112 struct FolderInsertBeforePreviouslyFoldedConstantPattern 113 : public OpRewritePattern<TestCastOp> { 114 public: 115 using OpRewritePattern<TestCastOp>::OpRewritePattern; 116 117 LogicalResult matchAndRewrite(TestCastOp op, 118 PatternRewriter &rewriter) const override { 119 if (!op->hasAttr("test_fold_before_previously_folded_op")) 120 return failure(); 121 rewriter.setInsertionPointToStart(op->getBlock()); 122 123 auto constOp = rewriter.create<arith::ConstantOp>( 124 op.getLoc(), rewriter.getBoolAttr(true)); 125 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 126 Value(constOp)); 127 return success(); 128 } 129 }; 130 131 /// This pattern matches test.op_commutative2 with the first operand being 132 /// another test.op_commutative2 with a constant on the right side and fold it 133 /// away by propagating it as its result. This is intend to check that patterns 134 /// are applied after the commutative property moves constant to the right. 135 struct FolderCommutativeOp2WithConstant 136 : public OpRewritePattern<TestCommutative2Op> { 137 public: 138 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; 139 140 LogicalResult matchAndRewrite(TestCommutative2Op op, 141 PatternRewriter &rewriter) const override { 142 auto operand = 143 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); 144 if (!operand) 145 return failure(); 146 Attribute constInput; 147 if (!matchPattern(operand->getOperand(1), m_Constant(&constInput))) 148 return failure(); 149 rewriter.replaceOp(op, operand->getOperand(1)); 150 return success(); 151 } 152 }; 153 154 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer 155 /// attribute with value smaller than MaxVal, it increments the value by 1. 156 template <int MaxVal> 157 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> { 158 using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern; 159 160 LogicalResult matchAndRewrite(AnyAttrOfOp op, 161 PatternRewriter &rewriter) const override { 162 auto intAttr = dyn_cast<IntegerAttr>(op.getAttr()); 163 if (!intAttr) 164 return failure(); 165 int64_t val = intAttr.getInt(); 166 if (val >= MaxVal) 167 return failure(); 168 rewriter.modifyOpInPlace( 169 op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); }); 170 return success(); 171 } 172 }; 173 174 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op". 175 struct MakeOpEligible : public RewritePattern { 176 MakeOpEligible(MLIRContext *context) 177 : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {} 178 179 LogicalResult matchAndRewrite(Operation *op, 180 PatternRewriter &rewriter) const override { 181 if (op->hasAttr("eligible")) 182 return failure(); 183 rewriter.modifyOpInPlace( 184 op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); }); 185 return success(); 186 } 187 }; 188 189 /// This pattern hoists eligible ops out of a "test.one_region_op". 190 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> { 191 using OpRewritePattern<test::OneRegionOp>::OpRewritePattern; 192 193 LogicalResult matchAndRewrite(test::OneRegionOp op, 194 PatternRewriter &rewriter) const override { 195 Operation *terminator = op.getRegion().front().getTerminator(); 196 Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp(); 197 if (toBeHoisted->getParentOp() != op) 198 return failure(); 199 if (!toBeHoisted->hasAttr("eligible")) 200 return failure(); 201 rewriter.moveOpBefore(toBeHoisted, op); 202 return success(); 203 } 204 }; 205 206 /// This pattern moves "test.move_before_parent_op" before the parent op. 207 struct MoveBeforeParentOp : public RewritePattern { 208 MoveBeforeParentOp(MLIRContext *context) 209 : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {} 210 211 LogicalResult matchAndRewrite(Operation *op, 212 PatternRewriter &rewriter) const override { 213 // Do not hoist past functions. 214 if (isa<FunctionOpInterface>(op->getParentOp())) 215 return failure(); 216 rewriter.moveOpBefore(op, op->getParentOp()); 217 return success(); 218 } 219 }; 220 221 /// This pattern moves "test.move_after_parent_op" after the parent op. 222 struct MoveAfterParentOp : public RewritePattern { 223 MoveAfterParentOp(MLIRContext *context) 224 : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {} 225 226 LogicalResult matchAndRewrite(Operation *op, 227 PatternRewriter &rewriter) const override { 228 // Do not hoist past functions. 229 if (isa<FunctionOpInterface>(op->getParentOp())) 230 return failure(); 231 232 int64_t moveForwardBy = 0; 233 if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance")) 234 moveForwardBy = advanceBy.getInt(); 235 236 Operation *moveAfter = op->getParentOp(); 237 for (int64_t i = 0; i < moveForwardBy; ++i) 238 moveAfter = moveAfter->getNextNode(); 239 240 rewriter.moveOpAfter(op, moveAfter); 241 return success(); 242 } 243 }; 244 245 /// This pattern inlines blocks that are nested in 246 /// "test.inline_blocks_into_parent" into the parent block. 247 struct InlineBlocksIntoParent : public RewritePattern { 248 InlineBlocksIntoParent(MLIRContext *context) 249 : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1, 250 context) {} 251 252 LogicalResult matchAndRewrite(Operation *op, 253 PatternRewriter &rewriter) const override { 254 bool changed = false; 255 for (Region &r : op->getRegions()) { 256 while (!r.empty()) { 257 rewriter.inlineBlockBefore(&r.front(), op); 258 changed = true; 259 } 260 } 261 return success(changed); 262 } 263 }; 264 265 /// This pattern splits blocks at "test.split_block_here" and replaces the op 266 /// with a new op (to prevent an infinite loop of block splitting). 267 struct SplitBlockHere : public RewritePattern { 268 SplitBlockHere(MLIRContext *context) 269 : RewritePattern("test.split_block_here", /*benefit=*/1, context) {} 270 271 LogicalResult matchAndRewrite(Operation *op, 272 PatternRewriter &rewriter) const override { 273 rewriter.splitBlock(op->getBlock(), op->getIterator()); 274 Operation *newOp = rewriter.create( 275 op->getLoc(), 276 OperationName("test.new_op", op->getContext()).getIdentifier(), 277 op->getOperands(), op->getResultTypes()); 278 rewriter.replaceOp(op, newOp); 279 return success(); 280 } 281 }; 282 283 /// This pattern clones "test.clone_me" ops. 284 struct CloneOp : public RewritePattern { 285 CloneOp(MLIRContext *context) 286 : RewritePattern("test.clone_me", /*benefit=*/1, context) {} 287 288 LogicalResult matchAndRewrite(Operation *op, 289 PatternRewriter &rewriter) const override { 290 // Do not clone already cloned ops to avoid going into an infinite loop. 291 if (op->hasAttr("was_cloned")) 292 return failure(); 293 Operation *cloned = rewriter.clone(*op); 294 cloned->setAttr("was_cloned", rewriter.getUnitAttr()); 295 return success(); 296 } 297 }; 298 299 /// This pattern clones regions of "test.clone_region_before" ops before the 300 /// parent block. 301 struct CloneRegionBeforeOp : public RewritePattern { 302 CloneRegionBeforeOp(MLIRContext *context) 303 : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {} 304 305 LogicalResult matchAndRewrite(Operation *op, 306 PatternRewriter &rewriter) const override { 307 // Do not clone already cloned ops to avoid going into an infinite loop. 308 if (op->hasAttr("was_cloned")) 309 return failure(); 310 for (Region &r : op->getRegions()) 311 rewriter.cloneRegionBefore(r, op->getBlock()); 312 op->setAttr("was_cloned", rewriter.getUnitAttr()); 313 return success(); 314 } 315 }; 316 317 /// Replace an operation may introduce the re-visiting of its users. 318 class ReplaceWithNewOp : public RewritePattern { 319 public: 320 ReplaceWithNewOp(MLIRContext *context) 321 : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {} 322 323 LogicalResult matchAndRewrite(Operation *op, 324 PatternRewriter &rewriter) const override { 325 Operation *newOp; 326 if (op->hasAttr("create_erase_op")) { 327 newOp = rewriter.create( 328 op->getLoc(), 329 OperationName("test.erase_op", op->getContext()).getIdentifier(), 330 ValueRange(), TypeRange()); 331 } else { 332 newOp = rewriter.create( 333 op->getLoc(), 334 OperationName("test.new_op", op->getContext()).getIdentifier(), 335 op->getOperands(), op->getResultTypes()); 336 } 337 // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp". 338 // A "notifyOperationReplaced" callback is triggered in either case. 339 rewriter.replaceAllOpUsesWith(op, newOp->getResults()); 340 rewriter.eraseOp(op); 341 return success(); 342 } 343 }; 344 345 /// Erases the first child block of the matched "test.erase_first_block" 346 /// operation. 347 class EraseFirstBlock : public RewritePattern { 348 public: 349 EraseFirstBlock(MLIRContext *context) 350 : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {} 351 352 LogicalResult matchAndRewrite(Operation *op, 353 PatternRewriter &rewriter) const override { 354 llvm::errs() << "Num regions: " << op->getNumRegions() << "\n"; 355 for (Region &r : op->getRegions()) { 356 for (Block &b : r.getBlocks()) { 357 rewriter.eraseBlock(&b); 358 llvm::errs() << "Erasing block: " << b << "\n"; 359 return success(); 360 } 361 } 362 363 return failure(); 364 } 365 }; 366 367 struct TestGreedyPatternDriver 368 : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> { 369 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver) 370 371 TestGreedyPatternDriver() = default; 372 TestGreedyPatternDriver(const TestGreedyPatternDriver &other) 373 : PassWrapper(other) {} 374 375 StringRef getArgument() const final { return "test-greedy-patterns"; } 376 StringRef getDescription() const final { return "Run test dialect patterns"; } 377 void runOnOperation() override { 378 mlir::RewritePatternSet patterns(&getContext()); 379 populateWithGenerated(patterns); 380 381 // Verify named pattern is generated with expected name. 382 patterns.add<FoldingPattern, TestNamedPatternRule, 383 FolderInsertBeforePreviouslyFoldedConstantPattern, 384 FolderCommutativeOp2WithConstant, HoistEligibleOps, 385 MakeOpEligible>(&getContext()); 386 387 // Additional patterns for testing the GreedyPatternRewriteDriver. 388 patterns.insert<IncrementIntAttribute<3>>(&getContext()); 389 390 GreedyRewriteConfig config; 391 config.useTopDownTraversal = this->useTopDownTraversal; 392 config.maxIterations = this->maxIterations; 393 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), 394 config); 395 } 396 397 Option<bool> useTopDownTraversal{ 398 *this, "top-down", 399 llvm::cl::desc("Seed the worklist in general top-down order"), 400 llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)}; 401 Option<int> maxIterations{ 402 *this, "max-iterations", 403 llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"), 404 llvm::cl::init(GreedyRewriteConfig().maxIterations)}; 405 }; 406 407 struct DumpNotifications : public RewriterBase::Listener { 408 void notifyBlockInserted(Block *block, Region *previous, 409 Region::iterator previousIt) override { 410 llvm::outs() << "notifyBlockInserted"; 411 if (block->getParentOp()) { 412 llvm::outs() << " into " << block->getParentOp()->getName() << ": "; 413 } else { 414 llvm::outs() << " into unknown op: "; 415 } 416 if (previous == nullptr) { 417 llvm::outs() << "was unlinked\n"; 418 } else { 419 llvm::outs() << "was linked\n"; 420 } 421 } 422 void notifyOperationInserted(Operation *op, 423 OpBuilder::InsertPoint previous) override { 424 llvm::outs() << "notifyOperationInserted: " << op->getName(); 425 if (!previous.isSet()) { 426 llvm::outs() << ", was unlinked\n"; 427 } else { 428 if (!previous.getPoint().getNodePtr()) { 429 llvm::outs() << ", was linked, exact position unknown\n"; 430 } else if (previous.getPoint() == previous.getBlock()->end()) { 431 llvm::outs() << ", was last in block\n"; 432 } else { 433 llvm::outs() << ", previous = " << previous.getPoint()->getName() 434 << "\n"; 435 } 436 } 437 } 438 void notifyBlockErased(Block *block) override { 439 llvm::outs() << "notifyBlockErased\n"; 440 } 441 void notifyOperationErased(Operation *op) override { 442 llvm::outs() << "notifyOperationErased: " << op->getName() << "\n"; 443 } 444 void notifyOperationModified(Operation *op) override { 445 llvm::outs() << "notifyOperationModified: " << op->getName() << "\n"; 446 } 447 void notifyOperationReplaced(Operation *op, ValueRange values) override { 448 llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n"; 449 } 450 }; 451 452 struct TestStrictPatternDriver 453 : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> { 454 public: 455 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver) 456 457 TestStrictPatternDriver() = default; 458 TestStrictPatternDriver(const TestStrictPatternDriver &other) 459 : PassWrapper(other) { 460 strictMode = other.strictMode; 461 } 462 463 StringRef getArgument() const final { return "test-strict-pattern-driver"; } 464 StringRef getDescription() const final { 465 return "Test strict mode of pattern driver"; 466 } 467 468 void runOnOperation() override { 469 MLIRContext *ctx = &getContext(); 470 mlir::RewritePatternSet patterns(ctx); 471 patterns.add< 472 // clang-format off 473 ChangeBlockOp, 474 CloneOp, 475 CloneRegionBeforeOp, 476 EraseOp, 477 ImplicitChangeOp, 478 InlineBlocksIntoParent, 479 InsertSameOp, 480 MoveBeforeParentOp, 481 ReplaceWithNewOp, 482 SplitBlockHere 483 // clang-format on 484 >(ctx); 485 SmallVector<Operation *> ops; 486 getOperation()->walk([&](Operation *op) { 487 StringRef opName = op->getName().getStringRef(); 488 if (opName == "test.insert_same_op" || opName == "test.change_block_op" || 489 opName == "test.replace_with_new_op" || opName == "test.erase_op" || 490 opName == "test.move_before_parent_op" || 491 opName == "test.inline_blocks_into_parent" || 492 opName == "test.split_block_here" || opName == "test.clone_me" || 493 opName == "test.clone_region_before") { 494 ops.push_back(op); 495 } 496 }); 497 498 DumpNotifications dumpNotifications; 499 GreedyRewriteConfig config; 500 config.listener = &dumpNotifications; 501 if (strictMode == "AnyOp") { 502 config.strictMode = GreedyRewriteStrictness::AnyOp; 503 } else if (strictMode == "ExistingAndNewOps") { 504 config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; 505 } else if (strictMode == "ExistingOps") { 506 config.strictMode = GreedyRewriteStrictness::ExistingOps; 507 } else { 508 llvm_unreachable("invalid strictness option"); 509 } 510 511 // Check if these transformations introduce visiting of operations that 512 // are not in the `ops` set (The new created ops are valid). An invalid 513 // operation will trigger the assertion while processing. 514 bool changed = false; 515 bool allErased = false; 516 (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config, 517 &changed, &allErased); 518 Builder b(ctx); 519 getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed)); 520 getOperation()->setAttr("pattern_driver_all_erased", 521 b.getBoolAttr(allErased)); 522 } 523 524 Option<std::string> strictMode{ 525 *this, "strictness", 526 llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"), 527 llvm::cl::init("AnyOp")}; 528 529 private: 530 // New inserted operation is valid for further transformation. 531 class InsertSameOp : public RewritePattern { 532 public: 533 InsertSameOp(MLIRContext *context) 534 : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {} 535 536 LogicalResult matchAndRewrite(Operation *op, 537 PatternRewriter &rewriter) const override { 538 if (op->hasAttr("skip")) 539 return failure(); 540 541 Operation *newOp = 542 rewriter.create(op->getLoc(), op->getName().getIdentifier(), 543 op->getOperands(), op->getResultTypes()); 544 rewriter.modifyOpInPlace( 545 op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); }); 546 newOp->setAttr("skip", rewriter.getBoolAttr(true)); 547 548 return success(); 549 } 550 }; 551 552 // Remove an operation may introduce the re-visiting of its operands. 553 class EraseOp : public RewritePattern { 554 public: 555 EraseOp(MLIRContext *context) 556 : RewritePattern("test.erase_op", /*benefit=*/1, context) {} 557 LogicalResult matchAndRewrite(Operation *op, 558 PatternRewriter &rewriter) const override { 559 rewriter.eraseOp(op); 560 return success(); 561 } 562 }; 563 564 // The following two patterns test RewriterBase::replaceAllUsesWith. 565 // 566 // That function replaces all usages of a Block (or a Value) with another one 567 // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver 568 // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its 569 // worklist: when an op is modified, it is added to the worklist. The two 570 // patterns below make the tracking observable: ChangeBlockOp replaces all 571 // usages of a block and that pattern is applied because the corresponding ops 572 // are put on the initial worklist (see above). ImplicitChangeOp does an 573 // unrelated change but ops of the corresponding type are *not* on the initial 574 // worklist, so the effect of the second pattern is only visible if the 575 // tracking and subsequent adding to the worklist actually works. 576 577 // Replace all usages of the first successor with the second successor. 578 class ChangeBlockOp : public RewritePattern { 579 public: 580 ChangeBlockOp(MLIRContext *context) 581 : RewritePattern("test.change_block_op", /*benefit=*/1, context) {} 582 LogicalResult matchAndRewrite(Operation *op, 583 PatternRewriter &rewriter) const override { 584 if (op->getNumSuccessors() < 2) 585 return failure(); 586 Block *firstSuccessor = op->getSuccessor(0); 587 Block *secondSuccessor = op->getSuccessor(1); 588 if (firstSuccessor == secondSuccessor) 589 return failure(); 590 // This is the function being tested: 591 rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor); 592 // Using the following line instead would make the test fail: 593 // firstSuccessor->replaceAllUsesWith(secondSuccessor); 594 return success(); 595 } 596 }; 597 598 // Changes the successor to the parent block. 599 class ImplicitChangeOp : public RewritePattern { 600 public: 601 ImplicitChangeOp(MLIRContext *context) 602 : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {} 603 LogicalResult matchAndRewrite(Operation *op, 604 PatternRewriter &rewriter) const override { 605 if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock()) 606 return failure(); 607 rewriter.modifyOpInPlace(op, 608 [&]() { op->setSuccessor(op->getBlock(), 0); }); 609 return success(); 610 } 611 }; 612 }; 613 614 struct TestWalkPatternDriver final 615 : PassWrapper<TestWalkPatternDriver, OperationPass<>> { 616 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver) 617 618 TestWalkPatternDriver() = default; 619 TestWalkPatternDriver(const TestWalkPatternDriver &other) 620 : PassWrapper(other) {} 621 622 StringRef getArgument() const override { 623 return "test-walk-pattern-rewrite-driver"; 624 } 625 StringRef getDescription() const override { 626 return "Run test walk pattern rewrite driver"; 627 } 628 void runOnOperation() override { 629 mlir::RewritePatternSet patterns(&getContext()); 630 631 // Patterns for testing the WalkPatternRewriteDriver. 632 patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp, 633 MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>( 634 &getContext()); 635 636 DumpNotifications dumpListener; 637 walkAndApplyPatterns(getOperation(), std::move(patterns), 638 dumpNotifications ? &dumpListener : nullptr); 639 } 640 641 Option<bool> dumpNotifications{ 642 *this, "dump-notifications", 643 llvm::cl::desc("Print rewrite listener notifications"), 644 llvm::cl::init(false)}; 645 }; 646 647 } // namespace 648 649 //===----------------------------------------------------------------------===// 650 // ReturnType Driver. 651 //===----------------------------------------------------------------------===// 652 653 namespace { 654 // Generate ops for each instance where the type can be successfully inferred. 655 template <typename OpTy> 656 static void invokeCreateWithInferredReturnType(Operation *op) { 657 auto *context = op->getContext(); 658 auto fop = op->getParentOfType<func::FuncOp>(); 659 auto location = UnknownLoc::get(context); 660 OpBuilder b(op); 661 b.setInsertionPointAfter(op); 662 663 // Use permutations of 2 args as operands. 664 assert(fop.getNumArguments() >= 2); 665 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 666 for (int j = 0; j < e; ++j) { 667 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 668 SmallVector<Type, 2> inferredReturnTypes; 669 if (succeeded(OpTy::inferReturnTypes( 670 context, std::nullopt, values, op->getDiscardableAttrDictionary(), 671 op->getPropertiesStorage(), op->getRegions(), 672 inferredReturnTypes))) { 673 OperationState state(location, OpTy::getOperationName()); 674 // TODO: Expand to regions. 675 OpTy::build(b, state, values, op->getAttrs()); 676 (void)b.create(state); 677 } 678 } 679 } 680 } 681 682 static void reifyReturnShape(Operation *op) { 683 OpBuilder b(op); 684 685 // Use permutations of 2 args as operands. 686 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 687 SmallVector<Value, 2> shapes; 688 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 689 !llvm::hasSingleElement(shapes)) 690 return; 691 for (const auto &it : llvm::enumerate(shapes)) { 692 op->emitRemark() << "value " << it.index() << ": " 693 << it.value().getDefiningOp(); 694 } 695 } 696 697 struct TestReturnTypeDriver 698 : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> { 699 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver) 700 701 void getDependentDialects(DialectRegistry ®istry) const override { 702 registry.insert<tensor::TensorDialect>(); 703 } 704 StringRef getArgument() const final { return "test-return-type"; } 705 StringRef getDescription() const final { return "Run return type functions"; } 706 707 void runOnOperation() override { 708 if (getOperation().getName() == "testCreateFunctions") { 709 std::vector<Operation *> ops; 710 // Collect ops to avoid triggering on inserted ops. 711 for (auto &op : getOperation().getBody().front()) 712 ops.push_back(&op); 713 // Generate test patterns for each, but skip terminator. 714 for (auto *op : llvm::ArrayRef(ops).drop_back()) { 715 // Test create method of each of the Op classes below. The resultant 716 // output would be in reverse order underneath `op` from which 717 // the attributes and regions are used. 718 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 719 invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>( 720 op); 721 invokeCreateWithInferredReturnType< 722 OpWithShapedTypeInferTypeInterfaceOp>(op); 723 }; 724 return; 725 } 726 if (getOperation().getName() == "testReifyFunctions") { 727 std::vector<Operation *> ops; 728 // Collect ops to avoid triggering on inserted ops. 729 for (auto &op : getOperation().getBody().front()) 730 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 731 ops.push_back(&op); 732 // Generate test patterns for each, but skip terminator. 733 for (auto *op : ops) 734 reifyReturnShape(op); 735 } 736 } 737 }; 738 } // namespace 739 740 namespace { 741 struct TestDerivedAttributeDriver 742 : public PassWrapper<TestDerivedAttributeDriver, 743 OperationPass<func::FuncOp>> { 744 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver) 745 746 StringRef getArgument() const final { return "test-derived-attr"; } 747 StringRef getDescription() const final { 748 return "Run test derived attributes"; 749 } 750 void runOnOperation() override; 751 }; 752 } // namespace 753 754 void TestDerivedAttributeDriver::runOnOperation() { 755 getOperation().walk([](DerivedAttributeOpInterface dOp) { 756 auto dAttr = dOp.materializeDerivedAttributes(); 757 if (!dAttr) 758 return; 759 for (auto d : dAttr) 760 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); 761 }); 762 } 763 764 //===----------------------------------------------------------------------===// 765 // Legalization Driver. 766 //===----------------------------------------------------------------------===// 767 768 namespace { 769 //===----------------------------------------------------------------------===// 770 // Region-Block Rewrite Testing 771 772 /// This pattern applies a signature conversion to a block inside a detached 773 /// region. 774 struct TestDetachedSignatureConversion : public ConversionPattern { 775 TestDetachedSignatureConversion(MLIRContext *ctx) 776 : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1, 777 ctx) {} 778 779 LogicalResult 780 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 781 ConversionPatternRewriter &rewriter) const final { 782 if (op->getNumRegions() != 1) 783 return failure(); 784 OperationState state(op->getLoc(), "test.legal_op_with_region", operands, 785 op->getResultTypes(), {}, BlockRange()); 786 Region *newRegion = state.addRegion(); 787 rewriter.inlineRegionBefore(op->getRegion(0), *newRegion, 788 newRegion->begin()); 789 TypeConverter::SignatureConversion result(newRegion->getNumArguments()); 790 for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i) 791 result.addInputs(i, rewriter.getF64Type()); 792 rewriter.applySignatureConversion(&newRegion->front(), result); 793 Operation *newOp = rewriter.create(state); 794 rewriter.replaceOp(op, newOp->getResults()); 795 return success(); 796 } 797 }; 798 799 /// This pattern is a simple pattern that inlines the first region of a given 800 /// operation into the parent region. 801 struct TestRegionRewriteBlockMovement : public ConversionPattern { 802 TestRegionRewriteBlockMovement(MLIRContext *ctx) 803 : ConversionPattern("test.region", 1, ctx) {} 804 805 LogicalResult 806 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 807 ConversionPatternRewriter &rewriter) const final { 808 // Inline this region into the parent region. 809 auto &parentRegion = *op->getParentRegion(); 810 auto &opRegion = op->getRegion(0); 811 if (op->getDiscardableAttr("legalizer.should_clone")) 812 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 813 else 814 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 815 816 if (op->getDiscardableAttr("legalizer.erase_old_blocks")) { 817 while (!opRegion.empty()) 818 rewriter.eraseBlock(&opRegion.front()); 819 } 820 821 // Drop this operation. 822 rewriter.eraseOp(op); 823 return success(); 824 } 825 }; 826 /// This pattern is a simple pattern that generates a region containing an 827 /// illegal operation. 828 struct TestRegionRewriteUndo : public RewritePattern { 829 TestRegionRewriteUndo(MLIRContext *ctx) 830 : RewritePattern("test.region_builder", 1, ctx) {} 831 832 LogicalResult matchAndRewrite(Operation *op, 833 PatternRewriter &rewriter) const final { 834 // Create the region operation with an entry block containing arguments. 835 OperationState newRegion(op->getLoc(), "test.region"); 836 newRegion.addRegion(); 837 auto *regionOp = rewriter.create(newRegion); 838 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 839 entryBlock->addArgument(rewriter.getIntegerType(64), 840 rewriter.getUnknownLoc()); 841 842 // Add an explicitly illegal operation to ensure the conversion fails. 843 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 844 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 845 846 // Drop this operation. 847 rewriter.eraseOp(op); 848 return success(); 849 } 850 }; 851 /// A simple pattern that creates a block at the end of the parent region of the 852 /// matched operation. 853 struct TestCreateBlock : public RewritePattern { 854 TestCreateBlock(MLIRContext *ctx) 855 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 856 857 LogicalResult matchAndRewrite(Operation *op, 858 PatternRewriter &rewriter) const final { 859 Region ®ion = *op->getParentRegion(); 860 Type i32Type = rewriter.getIntegerType(32); 861 Location loc = op->getLoc(); 862 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 863 rewriter.create<TerminatorOp>(loc); 864 rewriter.eraseOp(op); 865 return success(); 866 } 867 }; 868 869 /// A simple pattern that creates a block containing an invalid operation in 870 /// order to trigger the block creation undo mechanism. 871 struct TestCreateIllegalBlock : public RewritePattern { 872 TestCreateIllegalBlock(MLIRContext *ctx) 873 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 874 875 LogicalResult matchAndRewrite(Operation *op, 876 PatternRewriter &rewriter) const final { 877 Region ®ion = *op->getParentRegion(); 878 Type i32Type = rewriter.getIntegerType(32); 879 Location loc = op->getLoc(); 880 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 881 // Create an illegal op to ensure the conversion fails. 882 rewriter.create<ILLegalOpF>(loc, i32Type); 883 rewriter.create<TerminatorOp>(loc); 884 rewriter.eraseOp(op); 885 return success(); 886 } 887 }; 888 889 /// A simple pattern that tests the undo mechanism when replacing the uses of a 890 /// block argument. 891 struct TestUndoBlockArgReplace : public ConversionPattern { 892 TestUndoBlockArgReplace(MLIRContext *ctx) 893 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 894 895 LogicalResult 896 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 897 ConversionPatternRewriter &rewriter) const final { 898 auto illegalOp = 899 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 900 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 901 illegalOp->getResult(0)); 902 rewriter.modifyOpInPlace(op, [] {}); 903 return success(); 904 } 905 }; 906 907 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion. 908 /// This is to test the rollback logic. 909 struct TestUndoMoveOpBefore : public ConversionPattern { 910 TestUndoMoveOpBefore(MLIRContext *ctx) 911 : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {} 912 913 LogicalResult 914 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 915 ConversionPatternRewriter &rewriter) const override { 916 rewriter.moveOpBefore(op, op->getParentOp()); 917 // Replace with an illegal op to ensure the conversion fails. 918 rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type()); 919 return success(); 920 } 921 }; 922 923 /// A rewrite pattern that tests the undo mechanism when erasing a block. 924 struct TestUndoBlockErase : public ConversionPattern { 925 TestUndoBlockErase(MLIRContext *ctx) 926 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 927 928 LogicalResult 929 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 930 ConversionPatternRewriter &rewriter) const final { 931 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 932 rewriter.setInsertionPointToStart(secondBlock); 933 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 934 rewriter.eraseBlock(secondBlock); 935 rewriter.modifyOpInPlace(op, [] {}); 936 return success(); 937 } 938 }; 939 940 /// A pattern that modifies a property in-place, but keeps the op illegal. 941 struct TestUndoPropertiesModification : public ConversionPattern { 942 TestUndoPropertiesModification(MLIRContext *ctx) 943 : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {} 944 LogicalResult 945 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 946 ConversionPatternRewriter &rewriter) const final { 947 if (!op->hasAttr("modify_inplace")) 948 return failure(); 949 rewriter.modifyOpInPlace( 950 op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); }); 951 return success(); 952 } 953 }; 954 955 //===----------------------------------------------------------------------===// 956 // Type-Conversion Rewrite Testing 957 958 /// This patterns erases a region operation that has had a type conversion. 959 struct TestDropOpSignatureConversion : public ConversionPattern { 960 TestDropOpSignatureConversion(MLIRContext *ctx, 961 const TypeConverter &converter) 962 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 963 LogicalResult 964 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 965 ConversionPatternRewriter &rewriter) const override { 966 Region ®ion = op->getRegion(0); 967 Block *entry = ®ion.front(); 968 969 // Convert the original entry arguments. 970 const TypeConverter &converter = *getTypeConverter(); 971 TypeConverter::SignatureConversion result(entry->getNumArguments()); 972 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 973 result)) || 974 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 975 return failure(); 976 977 // Convert the region signature and just drop the operation. 978 rewriter.eraseOp(op); 979 return success(); 980 } 981 }; 982 /// This pattern simply updates the operands of the given operation. 983 struct TestPassthroughInvalidOp : public ConversionPattern { 984 TestPassthroughInvalidOp(MLIRContext *ctx) 985 : ConversionPattern("test.invalid", 1, ctx) {} 986 LogicalResult 987 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 988 ConversionPatternRewriter &rewriter) const final { 989 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, operands, 990 std::nullopt); 991 return success(); 992 } 993 }; 994 /// Replace with valid op, but simply drop the operands. This is used in a 995 /// regression where we used to generate circular unrealized_conversion_cast 996 /// ops. 997 struct TestDropAndReplaceInvalidOp : public ConversionPattern { 998 TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter) 999 : ConversionPattern(converter, 1000 "test.drop_operands_and_replace_with_valid", 1, ctx) { 1001 } 1002 LogicalResult 1003 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1004 ConversionPatternRewriter &rewriter) const final { 1005 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(), 1006 std::nullopt); 1007 return success(); 1008 } 1009 }; 1010 /// This pattern handles the case of a split return value. 1011 struct TestSplitReturnType : public ConversionPattern { 1012 TestSplitReturnType(MLIRContext *ctx) 1013 : ConversionPattern("test.return", 1, ctx) {} 1014 LogicalResult 1015 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1016 ConversionPatternRewriter &rewriter) const final { 1017 // Check for a return of F32. 1018 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 1019 return failure(); 1020 1021 // Check if the first operation is a cast operation, if it is we use the 1022 // results directly. 1023 auto *defOp = operands[0].getDefiningOp(); 1024 if (auto packerOp = 1025 llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) { 1026 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 1027 return success(); 1028 } 1029 1030 // Otherwise, fail to match. 1031 return failure(); 1032 } 1033 }; 1034 1035 //===----------------------------------------------------------------------===// 1036 // Multi-Level Type-Conversion Rewrite Testing 1037 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 1038 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 1039 : ConversionPattern("test.type_producer", 1, ctx) {} 1040 LogicalResult 1041 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1042 ConversionPatternRewriter &rewriter) const final { 1043 // If the type is I32, change the type to F32. 1044 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 1045 return failure(); 1046 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 1047 return success(); 1048 } 1049 }; 1050 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 1051 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 1052 : ConversionPattern("test.type_producer", 1, ctx) {} 1053 LogicalResult 1054 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1055 ConversionPatternRewriter &rewriter) const final { 1056 // If the type is F32, change the type to F64. 1057 if (!Type(*op->result_type_begin()).isF32()) 1058 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 1059 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 1060 return success(); 1061 } 1062 }; 1063 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 1064 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 1065 : ConversionPattern("test.type_producer", 10, ctx) {} 1066 LogicalResult 1067 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1068 ConversionPatternRewriter &rewriter) const final { 1069 // Always convert to B16, even though it is not a legal type. This tests 1070 // that values are unmapped correctly. 1071 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 1072 return success(); 1073 } 1074 }; 1075 struct TestUpdateConsumerType : public ConversionPattern { 1076 TestUpdateConsumerType(MLIRContext *ctx) 1077 : ConversionPattern("test.type_consumer", 1, ctx) {} 1078 LogicalResult 1079 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1080 ConversionPatternRewriter &rewriter) const final { 1081 // Verify that the incoming operand has been successfully remapped to F64. 1082 if (!operands[0].getType().isF64()) 1083 return failure(); 1084 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 1085 return success(); 1086 } 1087 }; 1088 1089 //===----------------------------------------------------------------------===// 1090 // Non-Root Replacement Rewrite Testing 1091 /// This pattern generates an invalid operation, but replaces it before the 1092 /// pattern is finished. This checks that we don't need to legalize the 1093 /// temporary op. 1094 struct TestNonRootReplacement : public RewritePattern { 1095 TestNonRootReplacement(MLIRContext *ctx) 1096 : RewritePattern("test.replace_non_root", 1, ctx) {} 1097 1098 LogicalResult matchAndRewrite(Operation *op, 1099 PatternRewriter &rewriter) const final { 1100 auto resultType = *op->result_type_begin(); 1101 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 1102 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 1103 1104 rewriter.replaceOp(illegalOp, legalOp); 1105 rewriter.replaceOp(op, illegalOp); 1106 return success(); 1107 } 1108 }; 1109 1110 //===----------------------------------------------------------------------===// 1111 // Recursive Rewrite Testing 1112 /// This pattern is applied to the same operation multiple times, but has a 1113 /// bounded recursion. 1114 struct TestBoundedRecursiveRewrite 1115 : public OpRewritePattern<TestRecursiveRewriteOp> { 1116 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 1117 1118 void initialize() { 1119 // The conversion target handles bounding the recursion of this pattern. 1120 setHasBoundedRewriteRecursion(); 1121 } 1122 1123 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 1124 PatternRewriter &rewriter) const final { 1125 // Decrement the depth of the op in-place. 1126 rewriter.modifyOpInPlace(op, [&] { 1127 op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); 1128 }); 1129 return success(); 1130 } 1131 }; 1132 1133 struct TestNestedOpCreationUndoRewrite 1134 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 1135 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 1136 1137 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 1138 PatternRewriter &rewriter) const final { 1139 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 1140 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 1141 return success(); 1142 }; 1143 }; 1144 1145 // This pattern matches `test.blackhole` and delete this op and its producer. 1146 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 1147 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 1148 1149 LogicalResult matchAndRewrite(BlackHoleOp op, 1150 PatternRewriter &rewriter) const final { 1151 Operation *producer = op.getOperand().getDefiningOp(); 1152 // Always erase the user before the producer, the framework should handle 1153 // this correctly. 1154 rewriter.eraseOp(op); 1155 rewriter.eraseOp(producer); 1156 return success(); 1157 }; 1158 }; 1159 1160 // This pattern replaces explicitly illegal op with explicitly legal op, 1161 // but in addition creates unregistered operation. 1162 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { 1163 using OpRewritePattern<ILLegalOpG>::OpRewritePattern; 1164 1165 LogicalResult matchAndRewrite(ILLegalOpG op, 1166 PatternRewriter &rewriter) const final { 1167 IntegerAttr attr = rewriter.getI32IntegerAttr(0); 1168 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); 1169 rewriter.replaceOpWithNewOp<LegalOpC>(op, val); 1170 return success(); 1171 }; 1172 }; 1173 1174 class TestEraseOp : public ConversionPattern { 1175 public: 1176 TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {} 1177 LogicalResult 1178 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1179 ConversionPatternRewriter &rewriter) const final { 1180 // Erase op without replacements. 1181 rewriter.eraseOp(op); 1182 return success(); 1183 } 1184 }; 1185 1186 } // namespace 1187 1188 namespace { 1189 struct TestTypeConverter : public TypeConverter { 1190 using TypeConverter::TypeConverter; 1191 TestTypeConverter() { 1192 addConversion(convertType); 1193 addArgumentMaterialization(materializeCast); 1194 addSourceMaterialization(materializeCast); 1195 } 1196 1197 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 1198 // Drop I16 types. 1199 if (t.isSignlessInteger(16)) 1200 return success(); 1201 1202 // Convert I64 to F64. 1203 if (t.isSignlessInteger(64)) { 1204 results.push_back(FloatType::getF64(t.getContext())); 1205 return success(); 1206 } 1207 1208 // Convert I42 to I43. 1209 if (t.isInteger(42)) { 1210 results.push_back(IntegerType::get(t.getContext(), 43)); 1211 return success(); 1212 } 1213 1214 // Split F32 into F16,F16. 1215 if (t.isF32()) { 1216 results.assign(2, FloatType::getF16(t.getContext())); 1217 return success(); 1218 } 1219 1220 // Otherwise, convert the type directly. 1221 results.push_back(t); 1222 return success(); 1223 } 1224 1225 /// Hook for materializing a conversion. This is necessary because we generate 1226 /// 1->N type mappings. 1227 static Value materializeCast(OpBuilder &builder, Type resultType, 1228 ValueRange inputs, Location loc) { 1229 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1230 } 1231 }; 1232 1233 struct TestLegalizePatternDriver 1234 : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> { 1235 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) 1236 1237 StringRef getArgument() const final { return "test-legalize-patterns"; } 1238 StringRef getDescription() const final { 1239 return "Run test dialect legalization patterns"; 1240 } 1241 /// The mode of conversion to use with the driver. 1242 enum class ConversionMode { Analysis, Full, Partial }; 1243 1244 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 1245 1246 void getDependentDialects(DialectRegistry ®istry) const override { 1247 registry.insert<func::FuncDialect, test::TestDialect>(); 1248 } 1249 1250 void runOnOperation() override { 1251 TestTypeConverter converter; 1252 mlir::RewritePatternSet patterns(&getContext()); 1253 populateWithGenerated(patterns); 1254 patterns.add< 1255 TestRegionRewriteBlockMovement, TestDetachedSignatureConversion, 1256 TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock, 1257 TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp, 1258 TestSplitReturnType, TestChangeProducerTypeI32ToF32, 1259 TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid, 1260 TestUpdateConsumerType, TestNonRootReplacement, 1261 TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite, 1262 TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore, 1263 TestUndoPropertiesModification, TestEraseOp>(&getContext()); 1264 patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>( 1265 &getContext(), converter); 1266 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1267 converter); 1268 mlir::populateCallOpTypeConversionPattern(patterns, converter); 1269 1270 // Define the conversion target used for the test. 1271 ConversionTarget target(getContext()); 1272 target.addLegalOp<ModuleOp>(); 1273 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, 1274 TerminatorOp, OneRegionOp>(); 1275 target.addLegalOp( 1276 OperationName("test.legal_op_with_region", &getContext())); 1277 target 1278 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 1279 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 1280 // Don't allow F32 operands. 1281 return llvm::none_of(op.getOperandTypes(), 1282 [](Type type) { return type.isF32(); }); 1283 }); 1284 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1285 return converter.isSignatureLegal(op.getFunctionType()) && 1286 converter.isLegal(&op.getBody()); 1287 }); 1288 target.addDynamicallyLegalOp<func::CallOp>( 1289 [&](func::CallOp op) { return converter.isLegal(op); }); 1290 1291 // TestCreateUnregisteredOp creates `arith.constant` operation, 1292 // which was not added to target intentionally to test 1293 // correct error code from conversion driver. 1294 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); 1295 1296 // Expect the type_producer/type_consumer operations to only operate on f64. 1297 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1298 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1299 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1300 return op.getOperand().getType().isF64(); 1301 }); 1302 1303 // Check support for marking certain operations as recursively legal. 1304 target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) { 1305 return static_cast<bool>( 1306 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 1307 }); 1308 1309 // Mark the bound recursion operation as dynamically legal. 1310 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 1311 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); 1312 1313 // Create a dynamically legal rule that can only be legalized by folding it. 1314 target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>( 1315 [](TestOpInPlaceSelfFold op) { return op.getFolded(); }); 1316 1317 // Handle a partial conversion. 1318 if (mode == ConversionMode::Partial) { 1319 DenseSet<Operation *> unlegalizedOps; 1320 ConversionConfig config; 1321 DumpNotifications dumpNotifications; 1322 config.listener = &dumpNotifications; 1323 config.unlegalizedOps = &unlegalizedOps; 1324 if (failed(applyPartialConversion(getOperation(), target, 1325 std::move(patterns), config))) { 1326 getOperation()->emitRemark() << "applyPartialConversion failed"; 1327 } 1328 // Emit remarks for each legalizable operation. 1329 for (auto *op : unlegalizedOps) 1330 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1331 return; 1332 } 1333 1334 // Handle a full conversion. 1335 if (mode == ConversionMode::Full) { 1336 // Check support for marking unknown operations as dynamically legal. 1337 target.markUnknownOpDynamicallyLegal([](Operation *op) { 1338 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 1339 }); 1340 1341 ConversionConfig config; 1342 DumpNotifications dumpNotifications; 1343 config.listener = &dumpNotifications; 1344 if (failed(applyFullConversion(getOperation(), target, 1345 std::move(patterns), config))) { 1346 getOperation()->emitRemark() << "applyFullConversion failed"; 1347 } 1348 return; 1349 } 1350 1351 // Otherwise, handle an analysis conversion. 1352 assert(mode == ConversionMode::Analysis); 1353 1354 // Analyze the convertible operations. 1355 DenseSet<Operation *> legalizedOps; 1356 ConversionConfig config; 1357 config.legalizableOps = &legalizedOps; 1358 if (failed(applyAnalysisConversion(getOperation(), target, 1359 std::move(patterns), config))) 1360 return signalPassFailure(); 1361 1362 // Emit remarks for each legalizable operation. 1363 for (auto *op : legalizedOps) 1364 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 1365 } 1366 1367 /// The mode of conversion to use. 1368 ConversionMode mode; 1369 }; 1370 } // namespace 1371 1372 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 1373 legalizerConversionMode( 1374 "test-legalize-mode", 1375 llvm::cl::desc("The legalization mode to use with the test driver"), 1376 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 1377 llvm::cl::values( 1378 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 1379 "analysis", "Perform an analysis conversion"), 1380 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 1381 "Perform a full conversion"), 1382 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 1383 "partial", "Perform a partial conversion"))); 1384 1385 //===----------------------------------------------------------------------===// 1386 // ConversionPatternRewriter::getRemappedValue testing. This method is used 1387 // to get the remapped value of an original value that was replaced using 1388 // ConversionPatternRewriter. 1389 namespace { 1390 struct TestRemapValueTypeConverter : public TypeConverter { 1391 using TypeConverter::TypeConverter; 1392 1393 TestRemapValueTypeConverter() { 1394 addConversion( 1395 [](Float32Type type) { return Float64Type::get(type.getContext()); }); 1396 addConversion([](Type type) { return type; }); 1397 } 1398 }; 1399 1400 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 1401 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 1402 /// operand twice. 1403 /// 1404 /// Example: 1405 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 1406 /// is replaced with: 1407 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 1408 struct OneVResOneVOperandOp1Converter 1409 : public OpConversionPattern<OneVResOneVOperandOp1> { 1410 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 1411 1412 LogicalResult 1413 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, 1414 ConversionPatternRewriter &rewriter) const override { 1415 auto origOps = op.getOperands(); 1416 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 1417 "One operand expected"); 1418 Value origOp = *origOps.begin(); 1419 SmallVector<Value, 2> remappedOperands; 1420 // Replicate the remapped original operand twice. Note that we don't used 1421 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 1422 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1423 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1424 1425 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 1426 remappedOperands); 1427 return success(); 1428 } 1429 }; 1430 1431 /// A rewriter pattern that tests that blocks can be merged. 1432 struct TestRemapValueInRegion 1433 : public OpConversionPattern<TestRemappedValueRegionOp> { 1434 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; 1435 1436 LogicalResult 1437 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, 1438 ConversionPatternRewriter &rewriter) const final { 1439 Block &block = op.getBody().front(); 1440 Operation *terminator = block.getTerminator(); 1441 1442 // Merge the block into the parent region. 1443 Block *parentBlock = op->getBlock(); 1444 Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator()); 1445 rewriter.mergeBlocks(&block, parentBlock, ValueRange()); 1446 rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange()); 1447 1448 // Replace the results of this operation with the remapped terminator 1449 // values. 1450 SmallVector<Value> terminatorOperands; 1451 if (failed(rewriter.getRemappedValues(terminator->getOperands(), 1452 terminatorOperands))) 1453 return failure(); 1454 1455 rewriter.eraseOp(terminator); 1456 rewriter.replaceOp(op, terminatorOperands); 1457 return success(); 1458 } 1459 }; 1460 1461 struct TestRemappedValue 1462 : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> { 1463 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue) 1464 1465 StringRef getArgument() const final { return "test-remapped-value"; } 1466 StringRef getDescription() const final { 1467 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1468 } 1469 void runOnOperation() override { 1470 TestRemapValueTypeConverter typeConverter; 1471 1472 mlir::RewritePatternSet patterns(&getContext()); 1473 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 1474 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( 1475 &getContext()); 1476 patterns.add<TestRemapValueInRegion>(typeConverter, &getContext()); 1477 1478 mlir::ConversionTarget target(getContext()); 1479 target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>(); 1480 1481 // Expect the type_producer/type_consumer operations to only operate on f64. 1482 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1483 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1484 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1485 return op.getOperand().getType().isF64(); 1486 }); 1487 1488 // We make OneVResOneVOperandOp1 legal only when it has more that one 1489 // operand. This will trigger the conversion that will replace one-operand 1490 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 1491 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 1492 [](Operation *op) { return op->getNumOperands() > 1; }); 1493 1494 if (failed(mlir::applyFullConversion(getOperation(), target, 1495 std::move(patterns)))) { 1496 signalPassFailure(); 1497 } 1498 } 1499 }; 1500 } // namespace 1501 1502 //===----------------------------------------------------------------------===// 1503 // Test patterns without a specific root operation kind 1504 //===----------------------------------------------------------------------===// 1505 1506 namespace { 1507 /// This pattern matches and removes any operation in the test dialect. 1508 struct RemoveTestDialectOps : public RewritePattern { 1509 RemoveTestDialectOps(MLIRContext *context) 1510 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 1511 1512 LogicalResult matchAndRewrite(Operation *op, 1513 PatternRewriter &rewriter) const override { 1514 if (!isa<TestDialect>(op->getDialect())) 1515 return failure(); 1516 rewriter.eraseOp(op); 1517 return success(); 1518 } 1519 }; 1520 1521 struct TestUnknownRootOpDriver 1522 : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> { 1523 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver) 1524 1525 StringRef getArgument() const final { 1526 return "test-legalize-unknown-root-patterns"; 1527 } 1528 StringRef getDescription() const final { 1529 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1530 } 1531 void runOnOperation() override { 1532 mlir::RewritePatternSet patterns(&getContext()); 1533 patterns.add<RemoveTestDialectOps>(&getContext()); 1534 1535 mlir::ConversionTarget target(getContext()); 1536 target.addIllegalDialect<TestDialect>(); 1537 if (failed(applyPartialConversion(getOperation(), target, 1538 std::move(patterns)))) 1539 signalPassFailure(); 1540 } 1541 }; 1542 } // namespace 1543 1544 //===----------------------------------------------------------------------===// 1545 // Test patterns that uses operations and types defined at runtime 1546 //===----------------------------------------------------------------------===// 1547 1548 namespace { 1549 /// This pattern matches dynamic operations 'test.one_operand_two_results' and 1550 /// replace them with dynamic operations 'test.generic_dynamic_op'. 1551 struct RewriteDynamicOp : public RewritePattern { 1552 RewriteDynamicOp(MLIRContext *context) 1553 : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1, 1554 context) {} 1555 1556 LogicalResult matchAndRewrite(Operation *op, 1557 PatternRewriter &rewriter) const override { 1558 assert(op->getName().getStringRef() == 1559 "test.dynamic_one_operand_two_results" && 1560 "rewrite pattern should only match operations with the right name"); 1561 1562 OperationState state(op->getLoc(), "test.dynamic_generic", 1563 op->getOperands(), op->getResultTypes(), 1564 op->getAttrs()); 1565 auto *newOp = rewriter.create(state); 1566 rewriter.replaceOp(op, newOp->getResults()); 1567 return success(); 1568 } 1569 }; 1570 1571 struct TestRewriteDynamicOpDriver 1572 : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> { 1573 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver) 1574 1575 void getDependentDialects(DialectRegistry ®istry) const override { 1576 registry.insert<TestDialect>(); 1577 } 1578 StringRef getArgument() const final { return "test-rewrite-dynamic-op"; } 1579 StringRef getDescription() const final { 1580 return "Test rewritting on dynamic operations"; 1581 } 1582 void runOnOperation() override { 1583 RewritePatternSet patterns(&getContext()); 1584 patterns.add<RewriteDynamicOp>(&getContext()); 1585 1586 ConversionTarget target(getContext()); 1587 target.addIllegalOp( 1588 OperationName("test.dynamic_one_operand_two_results", &getContext())); 1589 target.addLegalOp(OperationName("test.dynamic_generic", &getContext())); 1590 if (failed(applyPartialConversion(getOperation(), target, 1591 std::move(patterns)))) 1592 signalPassFailure(); 1593 } 1594 }; 1595 } // end anonymous namespace 1596 1597 //===----------------------------------------------------------------------===// 1598 // Test type conversions 1599 //===----------------------------------------------------------------------===// 1600 1601 namespace { 1602 struct TestTypeConversionProducer 1603 : public OpConversionPattern<TestTypeProducerOp> { 1604 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 1605 LogicalResult 1606 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, 1607 ConversionPatternRewriter &rewriter) const final { 1608 Type resultType = op.getType(); 1609 Type convertedType = getTypeConverter() 1610 ? getTypeConverter()->convertType(resultType) 1611 : resultType; 1612 if (isa<FloatType>(resultType)) 1613 resultType = rewriter.getF64Type(); 1614 else if (resultType.isInteger(16)) 1615 resultType = rewriter.getIntegerType(64); 1616 else if (isa<test::TestRecursiveType>(resultType) && 1617 convertedType != resultType) 1618 resultType = convertedType; 1619 else 1620 return failure(); 1621 1622 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 1623 return success(); 1624 } 1625 }; 1626 1627 /// Call signature conversion and then fail the rewrite to trigger the undo 1628 /// mechanism. 1629 struct TestSignatureConversionUndo 1630 : public OpConversionPattern<TestSignatureConversionUndoOp> { 1631 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 1632 1633 LogicalResult 1634 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, 1635 ConversionPatternRewriter &rewriter) const final { 1636 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 1637 return failure(); 1638 } 1639 }; 1640 1641 /// Call signature conversion without providing a type converter to handle 1642 /// materializations. 1643 struct TestTestSignatureConversionNoConverter 1644 : public OpConversionPattern<TestSignatureConversionNoConverterOp> { 1645 TestTestSignatureConversionNoConverter(const TypeConverter &converter, 1646 MLIRContext *context) 1647 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), 1648 converter(converter) {} 1649 1650 LogicalResult 1651 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, 1652 ConversionPatternRewriter &rewriter) const final { 1653 Region ®ion = op->getRegion(0); 1654 Block *entry = ®ion.front(); 1655 1656 // Convert the original entry arguments. 1657 TypeConverter::SignatureConversion result(entry->getNumArguments()); 1658 if (failed( 1659 converter.convertSignatureArgs(entry->getArgumentTypes(), result))) 1660 return failure(); 1661 rewriter.modifyOpInPlace(op, [&] { 1662 rewriter.applySignatureConversion(®ion.front(), result); 1663 }); 1664 return success(); 1665 } 1666 1667 const TypeConverter &converter; 1668 }; 1669 1670 /// Just forward the operands to the root op. This is essentially a no-op 1671 /// pattern that is used to trigger target materialization. 1672 struct TestTypeConsumerForward 1673 : public OpConversionPattern<TestTypeConsumerOp> { 1674 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 1675 1676 LogicalResult 1677 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, 1678 ConversionPatternRewriter &rewriter) const final { 1679 rewriter.modifyOpInPlace(op, 1680 [&] { op->setOperands(adaptor.getOperands()); }); 1681 return success(); 1682 } 1683 }; 1684 1685 struct TestTypeConversionAnotherProducer 1686 : public OpRewritePattern<TestAnotherTypeProducerOp> { 1687 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 1688 1689 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 1690 PatternRewriter &rewriter) const final { 1691 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 1692 return success(); 1693 } 1694 }; 1695 1696 struct TestReplaceWithLegalOp : public ConversionPattern { 1697 TestReplaceWithLegalOp(MLIRContext *ctx) 1698 : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {} 1699 LogicalResult 1700 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1701 ConversionPatternRewriter &rewriter) const final { 1702 rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]); 1703 return success(); 1704 } 1705 }; 1706 1707 struct TestTypeConversionDriver 1708 : public PassWrapper<TestTypeConversionDriver, OperationPass<>> { 1709 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver) 1710 1711 void getDependentDialects(DialectRegistry ®istry) const override { 1712 registry.insert<TestDialect>(); 1713 } 1714 StringRef getArgument() const final { 1715 return "test-legalize-type-conversion"; 1716 } 1717 StringRef getDescription() const final { 1718 return "Test various type conversion functionalities in DialectConversion"; 1719 } 1720 1721 void runOnOperation() override { 1722 // Initialize the type converter. 1723 SmallVector<Type, 2> conversionCallStack; 1724 TypeConverter converter; 1725 1726 /// Add the legal set of type conversions. 1727 converter.addConversion([](Type type) -> Type { 1728 // Treat F64 as legal. 1729 if (type.isF64()) 1730 return type; 1731 // Allow converting BF16/F16/F32 to F64. 1732 if (type.isBF16() || type.isF16() || type.isF32()) 1733 return FloatType::getF64(type.getContext()); 1734 // Otherwise, the type is illegal. 1735 return nullptr; 1736 }); 1737 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 1738 // Drop all integer types. 1739 return success(); 1740 }); 1741 converter.addConversion( 1742 // Convert a recursive self-referring type into a non-self-referring 1743 // type named "outer_converted_type" that contains a SimpleAType. 1744 [&](test::TestRecursiveType type, 1745 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { 1746 // If the type is already converted, return it to indicate that it is 1747 // legal. 1748 if (type.getName() == "outer_converted_type") { 1749 results.push_back(type); 1750 return success(); 1751 } 1752 1753 conversionCallStack.push_back(type); 1754 auto popConversionCallStack = llvm::make_scope_exit( 1755 [&conversionCallStack]() { conversionCallStack.pop_back(); }); 1756 1757 // If the type is on the call stack more than once (it is there at 1758 // least once because of the _current_ call, which is always the last 1759 // element on the stack), we've hit the recursive case. Just return 1760 // SimpleAType here to create a non-recursive type as a result. 1761 if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(), 1762 type)) { 1763 results.push_back(test::SimpleAType::get(type.getContext())); 1764 return success(); 1765 } 1766 1767 // Convert the body recursively. 1768 auto result = test::TestRecursiveType::get(type.getContext(), 1769 "outer_converted_type"); 1770 if (failed(result.setBody(converter.convertType(type.getBody())))) 1771 return failure(); 1772 results.push_back(result); 1773 return success(); 1774 }); 1775 1776 /// Add the legal set of type materializations. 1777 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 1778 ValueRange inputs, 1779 Location loc) -> Value { 1780 // Allow casting from F64 back to F32. 1781 if (!resultType.isF16() && inputs.size() == 1 && 1782 inputs[0].getType().isF64()) 1783 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1784 // Allow producing an i32 or i64 from nothing. 1785 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 1786 inputs.empty()) 1787 return builder.create<TestTypeProducerOp>(loc, resultType); 1788 // Allow producing an i64 from an integer. 1789 if (isa<IntegerType>(resultType) && inputs.size() == 1 && 1790 isa<IntegerType>(inputs[0].getType())) 1791 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1792 // Otherwise, fail. 1793 return nullptr; 1794 }); 1795 1796 // Initialize the conversion target. 1797 mlir::ConversionTarget target(getContext()); 1798 target.addLegalOp<LegalOpD>(); 1799 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 1800 auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType()); 1801 return op.getType().isF64() || op.getType().isInteger(64) || 1802 (recursiveType && 1803 recursiveType.getName() == "outer_converted_type"); 1804 }); 1805 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1806 return converter.isSignatureLegal(op.getFunctionType()) && 1807 converter.isLegal(&op.getBody()); 1808 }); 1809 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 1810 // Allow casts from F64 to F32. 1811 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 1812 }); 1813 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>( 1814 [&](TestSignatureConversionNoConverterOp op) { 1815 return converter.isLegal(op.getRegion().front().getArgumentTypes()); 1816 }); 1817 1818 // Initialize the set of rewrite patterns. 1819 RewritePatternSet patterns(&getContext()); 1820 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 1821 TestSignatureConversionUndo, 1822 TestTestSignatureConversionNoConverter>(converter, 1823 &getContext()); 1824 patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>( 1825 &getContext()); 1826 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1827 converter); 1828 1829 if (failed(applyPartialConversion(getOperation(), target, 1830 std::move(patterns)))) 1831 signalPassFailure(); 1832 } 1833 }; 1834 } // namespace 1835 1836 //===----------------------------------------------------------------------===// 1837 // Test Target Materialization With No Uses 1838 //===----------------------------------------------------------------------===// 1839 1840 namespace { 1841 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> { 1842 using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern; 1843 1844 LogicalResult 1845 matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor, 1846 ConversionPatternRewriter &rewriter) const final { 1847 rewriter.replaceOp(op, adaptor.getOperands()); 1848 return success(); 1849 } 1850 }; 1851 1852 struct TestTargetMaterializationWithNoUses 1853 : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> { 1854 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 1855 TestTargetMaterializationWithNoUses) 1856 1857 StringRef getArgument() const final { 1858 return "test-target-materialization-with-no-uses"; 1859 } 1860 StringRef getDescription() const final { 1861 return "Test a special case of target materialization in DialectConversion"; 1862 } 1863 1864 void runOnOperation() override { 1865 TypeConverter converter; 1866 converter.addConversion([](Type t) { return t; }); 1867 converter.addConversion([](IntegerType intTy) -> Type { 1868 if (intTy.getWidth() == 16) 1869 return IntegerType::get(intTy.getContext(), 64); 1870 return intTy; 1871 }); 1872 converter.addTargetMaterialization( 1873 [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { 1874 return builder.create<TestCastOp>(loc, type, inputs).getResult(); 1875 }); 1876 1877 ConversionTarget target(getContext()); 1878 target.addIllegalOp<TestTypeChangerOp>(); 1879 1880 RewritePatternSet patterns(&getContext()); 1881 patterns.add<ForwardOperandPattern>(converter, &getContext()); 1882 1883 if (failed(applyPartialConversion(getOperation(), target, 1884 std::move(patterns)))) 1885 signalPassFailure(); 1886 } 1887 }; 1888 } // namespace 1889 1890 //===----------------------------------------------------------------------===// 1891 // Test Block Merging 1892 //===----------------------------------------------------------------------===// 1893 1894 namespace { 1895 /// A rewriter pattern that tests that blocks can be merged. 1896 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 1897 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 1898 1899 LogicalResult 1900 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, 1901 ConversionPatternRewriter &rewriter) const final { 1902 Block &firstBlock = op.getBody().front(); 1903 Operation *branchOp = firstBlock.getTerminator(); 1904 Block *secondBlock = &*(std::next(op.getBody().begin())); 1905 auto succOperands = branchOp->getOperands(); 1906 SmallVector<Value, 2> replacements(succOperands); 1907 rewriter.eraseOp(branchOp); 1908 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1909 rewriter.modifyOpInPlace(op, [] {}); 1910 return success(); 1911 } 1912 }; 1913 1914 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 1915 struct TestUndoBlocksMerge : public ConversionPattern { 1916 TestUndoBlocksMerge(MLIRContext *ctx) 1917 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 1918 LogicalResult 1919 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1920 ConversionPatternRewriter &rewriter) const final { 1921 Block &firstBlock = op->getRegion(0).front(); 1922 Operation *branchOp = firstBlock.getTerminator(); 1923 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 1924 rewriter.setInsertionPointToStart(secondBlock); 1925 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 1926 auto succOperands = branchOp->getOperands(); 1927 SmallVector<Value, 2> replacements(succOperands); 1928 rewriter.eraseOp(branchOp); 1929 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1930 rewriter.modifyOpInPlace(op, [] {}); 1931 return success(); 1932 } 1933 }; 1934 1935 /// A rewrite mechanism to inline the body of the op into its parent, when both 1936 /// ops can have a single block. 1937 struct TestMergeSingleBlockOps 1938 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 1939 using OpConversionPattern< 1940 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 1941 1942 LogicalResult 1943 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, 1944 ConversionPatternRewriter &rewriter) const final { 1945 SingleBlockImplicitTerminatorOp parentOp = 1946 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1947 if (!parentOp) 1948 return failure(); 1949 Block &innerBlock = op.getRegion().front(); 1950 TerminatorOp innerTerminator = 1951 cast<TerminatorOp>(innerBlock.getTerminator()); 1952 rewriter.inlineBlockBefore(&innerBlock, op); 1953 rewriter.eraseOp(innerTerminator); 1954 rewriter.eraseOp(op); 1955 return success(); 1956 } 1957 }; 1958 1959 struct TestMergeBlocksPatternDriver 1960 : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> { 1961 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver) 1962 1963 StringRef getArgument() const final { return "test-merge-blocks"; } 1964 StringRef getDescription() const final { 1965 return "Test Merging operation in ConversionPatternRewriter"; 1966 } 1967 void runOnOperation() override { 1968 MLIRContext *context = &getContext(); 1969 mlir::RewritePatternSet patterns(context); 1970 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 1971 context); 1972 ConversionTarget target(*context); 1973 target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1974 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1975 target.addIllegalOp<ILLegalOpF>(); 1976 1977 /// Expect the op to have a single block after legalization. 1978 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1979 [&](TestMergeBlocksOp op) -> bool { 1980 return llvm::hasSingleElement(op.getBody()); 1981 }); 1982 1983 /// Only allow `test.br` within test.merge_blocks op. 1984 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1985 return op->getParentOfType<TestMergeBlocksOp>(); 1986 }); 1987 1988 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1989 /// inlined. 1990 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1991 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1992 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1993 }); 1994 1995 DenseSet<Operation *> unlegalizedOps; 1996 ConversionConfig config; 1997 config.unlegalizedOps = &unlegalizedOps; 1998 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1999 config); 2000 for (auto *op : unlegalizedOps) 2001 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 2002 } 2003 }; 2004 } // namespace 2005 2006 //===----------------------------------------------------------------------===// 2007 // Test Selective Replacement 2008 //===----------------------------------------------------------------------===// 2009 2010 namespace { 2011 /// A rewrite mechanism to inline the body of the op into its parent, when both 2012 /// ops can have a single block. 2013 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 2014 using OpRewritePattern<TestCastOp>::OpRewritePattern; 2015 2016 LogicalResult matchAndRewrite(TestCastOp op, 2017 PatternRewriter &rewriter) const final { 2018 if (op.getNumOperands() != 2) 2019 return failure(); 2020 OperandRange operands = op.getOperands(); 2021 2022 // Replace non-terminator uses with the first operand. 2023 rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) { 2024 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 2025 }); 2026 // Replace everything else with the second operand if the operation isn't 2027 // dead. 2028 rewriter.replaceOp(op, op.getOperand(1)); 2029 return success(); 2030 } 2031 }; 2032 2033 struct TestSelectiveReplacementPatternDriver 2034 : public PassWrapper<TestSelectiveReplacementPatternDriver, 2035 OperationPass<>> { 2036 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 2037 TestSelectiveReplacementPatternDriver) 2038 2039 StringRef getArgument() const final { 2040 return "test-pattern-selective-replacement"; 2041 } 2042 StringRef getDescription() const final { 2043 return "Test selective replacement in the PatternRewriter"; 2044 } 2045 void runOnOperation() override { 2046 MLIRContext *context = &getContext(); 2047 mlir::RewritePatternSet patterns(context); 2048 patterns.add<TestSelectiveOpReplacementPattern>(context); 2049 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 2050 } 2051 }; 2052 } // namespace 2053 2054 //===----------------------------------------------------------------------===// 2055 // PassRegistration 2056 //===----------------------------------------------------------------------===// 2057 2058 namespace mlir { 2059 namespace test { 2060 void registerPatternsTestPass() { 2061 PassRegistration<TestReturnTypeDriver>(); 2062 2063 PassRegistration<TestDerivedAttributeDriver>(); 2064 2065 PassRegistration<TestGreedyPatternDriver>(); 2066 PassRegistration<TestStrictPatternDriver>(); 2067 PassRegistration<TestWalkPatternDriver>(); 2068 2069 PassRegistration<TestLegalizePatternDriver>([] { 2070 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 2071 }); 2072 2073 PassRegistration<TestRemappedValue>(); 2074 2075 PassRegistration<TestUnknownRootOpDriver>(); 2076 2077 PassRegistration<TestTypeConversionDriver>(); 2078 PassRegistration<TestTargetMaterializationWithNoUses>(); 2079 2080 PassRegistration<TestRewriteDynamicOpDriver>(); 2081 2082 PassRegistration<TestMergeBlocksPatternDriver>(); 2083 PassRegistration<TestSelectiveReplacementPatternDriver>(); 2084 } 2085 } // namespace test 2086 } // namespace mlir 2087