1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestOps.h" 11 #include "TestTypes.h" 12 #include "mlir/Dialect/Arith/IR/Arith.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/BuiltinAttributes.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/Visitors.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "mlir/Transforms/FoldUtils.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 #include "mlir/Transforms/WalkPatternRewriteDriver.h" 24 #include "llvm/ADT/ScopeExit.h" 25 #include <cstdint> 26 27 using namespace mlir; 28 using namespace test; 29 30 // Native function for testing NativeCodeCall 31 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 32 return choice.getValue() ? input1 : input2; 33 } 34 35 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 36 rewriter.create<OpI>(loc, input); 37 } 38 39 static void handleNoResultOp(PatternRewriter &rewriter, 40 OpSymbolBindingNoResult op) { 41 // Turn the no result op to a one-result op. 42 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), 43 op.getOperand()); 44 } 45 46 static bool getFirstI32Result(Operation *op, Value &value) { 47 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 48 return false; 49 value = op->getResult(0); 50 return true; 51 } 52 53 static Value bindNativeCodeCallResult(Value value) { return value; } 54 55 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 56 Value input2) { 57 return SmallVector<Value, 2>({input2, input1}); 58 } 59 60 // Test that natives calls are only called once during rewrites. 61 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 62 // This let us check the number of times OpM_Test was called by inspecting 63 // the returned value in the MLIR output. 64 static int64_t opMIncreasingValue = 314159265; 65 static Attribute opMTest(PatternRewriter &rewriter, Value val) { 66 int64_t i = opMIncreasingValue++; 67 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 68 } 69 70 namespace { 71 #include "TestPatterns.inc" 72 } // namespace 73 74 //===----------------------------------------------------------------------===// 75 // Test Reduce Pattern Interface 76 //===----------------------------------------------------------------------===// 77 78 void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 79 populateWithGenerated(patterns); 80 } 81 82 //===----------------------------------------------------------------------===// 83 // Canonicalizer Driver. 84 //===----------------------------------------------------------------------===// 85 86 namespace { 87 struct FoldingPattern : public RewritePattern { 88 public: 89 FoldingPattern(MLIRContext *context) 90 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 91 /*benefit=*/1, context) {} 92 93 LogicalResult matchAndRewrite(Operation *op, 94 PatternRewriter &rewriter) const override { 95 // Exercise createOrFold API for a single-result operation that is folded 96 // upon construction. The operation being created has an in-place folder, 97 // and it should be still present in the output. Furthermore, the folder 98 // should not crash when attempting to recover the (unchanged) operation 99 // result. 100 Value result = rewriter.createOrFold<TestOpInPlaceFold>( 101 op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0)); 102 assert(result); 103 rewriter.replaceOp(op, result); 104 return success(); 105 } 106 }; 107 108 /// This pattern creates a foldable operation at the entry point of the block. 109 /// This tests the situation where the operation folder will need to replace an 110 /// operation with a previously created constant that does not initially 111 /// dominate the operation to replace. 112 struct FolderInsertBeforePreviouslyFoldedConstantPattern 113 : public OpRewritePattern<TestCastOp> { 114 public: 115 using OpRewritePattern<TestCastOp>::OpRewritePattern; 116 117 LogicalResult matchAndRewrite(TestCastOp op, 118 PatternRewriter &rewriter) const override { 119 if (!op->hasAttr("test_fold_before_previously_folded_op")) 120 return failure(); 121 rewriter.setInsertionPointToStart(op->getBlock()); 122 123 auto constOp = rewriter.create<arith::ConstantOp>( 124 op.getLoc(), rewriter.getBoolAttr(true)); 125 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 126 Value(constOp)); 127 return success(); 128 } 129 }; 130 131 /// This pattern matches test.op_commutative2 with the first operand being 132 /// another test.op_commutative2 with a constant on the right side and fold it 133 /// away by propagating it as its result. This is intend to check that patterns 134 /// are applied after the commutative property moves constant to the right. 135 struct FolderCommutativeOp2WithConstant 136 : public OpRewritePattern<TestCommutative2Op> { 137 public: 138 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; 139 140 LogicalResult matchAndRewrite(TestCommutative2Op op, 141 PatternRewriter &rewriter) const override { 142 auto operand = 143 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); 144 if (!operand) 145 return failure(); 146 Attribute constInput; 147 if (!matchPattern(operand->getOperand(1), m_Constant(&constInput))) 148 return failure(); 149 rewriter.replaceOp(op, operand->getOperand(1)); 150 return success(); 151 } 152 }; 153 154 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer 155 /// attribute with value smaller than MaxVal, it increments the value by 1. 156 template <int MaxVal> 157 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> { 158 using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern; 159 160 LogicalResult matchAndRewrite(AnyAttrOfOp op, 161 PatternRewriter &rewriter) const override { 162 auto intAttr = dyn_cast<IntegerAttr>(op.getAttr()); 163 if (!intAttr) 164 return failure(); 165 int64_t val = intAttr.getInt(); 166 if (val >= MaxVal) 167 return failure(); 168 rewriter.modifyOpInPlace( 169 op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); }); 170 return success(); 171 } 172 }; 173 174 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op". 175 struct MakeOpEligible : public RewritePattern { 176 MakeOpEligible(MLIRContext *context) 177 : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {} 178 179 LogicalResult matchAndRewrite(Operation *op, 180 PatternRewriter &rewriter) const override { 181 if (op->hasAttr("eligible")) 182 return failure(); 183 rewriter.modifyOpInPlace( 184 op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); }); 185 return success(); 186 } 187 }; 188 189 /// This pattern hoists eligible ops out of a "test.one_region_op". 190 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> { 191 using OpRewritePattern<test::OneRegionOp>::OpRewritePattern; 192 193 LogicalResult matchAndRewrite(test::OneRegionOp op, 194 PatternRewriter &rewriter) const override { 195 Operation *terminator = op.getRegion().front().getTerminator(); 196 Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp(); 197 if (toBeHoisted->getParentOp() != op) 198 return failure(); 199 if (!toBeHoisted->hasAttr("eligible")) 200 return failure(); 201 rewriter.moveOpBefore(toBeHoisted, op); 202 return success(); 203 } 204 }; 205 206 /// This pattern moves "test.move_before_parent_op" before the parent op. 207 struct MoveBeforeParentOp : public RewritePattern { 208 MoveBeforeParentOp(MLIRContext *context) 209 : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {} 210 211 LogicalResult matchAndRewrite(Operation *op, 212 PatternRewriter &rewriter) const override { 213 // Do not hoist past functions. 214 if (isa<FunctionOpInterface>(op->getParentOp())) 215 return failure(); 216 rewriter.moveOpBefore(op, op->getParentOp()); 217 return success(); 218 } 219 }; 220 221 /// This pattern moves "test.move_after_parent_op" after the parent op. 222 struct MoveAfterParentOp : public RewritePattern { 223 MoveAfterParentOp(MLIRContext *context) 224 : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {} 225 226 LogicalResult matchAndRewrite(Operation *op, 227 PatternRewriter &rewriter) const override { 228 // Do not hoist past functions. 229 if (isa<FunctionOpInterface>(op->getParentOp())) 230 return failure(); 231 232 int64_t moveForwardBy = 0; 233 if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance")) 234 moveForwardBy = advanceBy.getInt(); 235 236 Operation *moveAfter = op->getParentOp(); 237 for (int64_t i = 0; i < moveForwardBy; ++i) 238 moveAfter = moveAfter->getNextNode(); 239 240 rewriter.moveOpAfter(op, moveAfter); 241 return success(); 242 } 243 }; 244 245 /// This pattern inlines blocks that are nested in 246 /// "test.inline_blocks_into_parent" into the parent block. 247 struct InlineBlocksIntoParent : public RewritePattern { 248 InlineBlocksIntoParent(MLIRContext *context) 249 : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1, 250 context) {} 251 252 LogicalResult matchAndRewrite(Operation *op, 253 PatternRewriter &rewriter) const override { 254 bool changed = false; 255 for (Region &r : op->getRegions()) { 256 while (!r.empty()) { 257 rewriter.inlineBlockBefore(&r.front(), op); 258 changed = true; 259 } 260 } 261 return success(changed); 262 } 263 }; 264 265 /// This pattern splits blocks at "test.split_block_here" and replaces the op 266 /// with a new op (to prevent an infinite loop of block splitting). 267 struct SplitBlockHere : public RewritePattern { 268 SplitBlockHere(MLIRContext *context) 269 : RewritePattern("test.split_block_here", /*benefit=*/1, context) {} 270 271 LogicalResult matchAndRewrite(Operation *op, 272 PatternRewriter &rewriter) const override { 273 rewriter.splitBlock(op->getBlock(), op->getIterator()); 274 Operation *newOp = rewriter.create( 275 op->getLoc(), 276 OperationName("test.new_op", op->getContext()).getIdentifier(), 277 op->getOperands(), op->getResultTypes()); 278 rewriter.replaceOp(op, newOp); 279 return success(); 280 } 281 }; 282 283 /// This pattern clones "test.clone_me" ops. 284 struct CloneOp : public RewritePattern { 285 CloneOp(MLIRContext *context) 286 : RewritePattern("test.clone_me", /*benefit=*/1, context) {} 287 288 LogicalResult matchAndRewrite(Operation *op, 289 PatternRewriter &rewriter) const override { 290 // Do not clone already cloned ops to avoid going into an infinite loop. 291 if (op->hasAttr("was_cloned")) 292 return failure(); 293 Operation *cloned = rewriter.clone(*op); 294 cloned->setAttr("was_cloned", rewriter.getUnitAttr()); 295 return success(); 296 } 297 }; 298 299 /// This pattern clones regions of "test.clone_region_before" ops before the 300 /// parent block. 301 struct CloneRegionBeforeOp : public RewritePattern { 302 CloneRegionBeforeOp(MLIRContext *context) 303 : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {} 304 305 LogicalResult matchAndRewrite(Operation *op, 306 PatternRewriter &rewriter) const override { 307 // Do not clone already cloned ops to avoid going into an infinite loop. 308 if (op->hasAttr("was_cloned")) 309 return failure(); 310 for (Region &r : op->getRegions()) 311 rewriter.cloneRegionBefore(r, op->getBlock()); 312 op->setAttr("was_cloned", rewriter.getUnitAttr()); 313 return success(); 314 } 315 }; 316 317 /// Replace an operation may introduce the re-visiting of its users. 318 class ReplaceWithNewOp : public RewritePattern { 319 public: 320 ReplaceWithNewOp(MLIRContext *context) 321 : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {} 322 323 LogicalResult matchAndRewrite(Operation *op, 324 PatternRewriter &rewriter) const override { 325 Operation *newOp; 326 if (op->hasAttr("create_erase_op")) { 327 newOp = rewriter.create( 328 op->getLoc(), 329 OperationName("test.erase_op", op->getContext()).getIdentifier(), 330 ValueRange(), TypeRange()); 331 } else { 332 newOp = rewriter.create( 333 op->getLoc(), 334 OperationName("test.new_op", op->getContext()).getIdentifier(), 335 op->getOperands(), op->getResultTypes()); 336 } 337 // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp". 338 // A "notifyOperationReplaced" callback is triggered in either case. 339 rewriter.replaceAllOpUsesWith(op, newOp->getResults()); 340 rewriter.eraseOp(op); 341 return success(); 342 } 343 }; 344 345 /// Erases the first child block of the matched "test.erase_first_block" 346 /// operation. 347 class EraseFirstBlock : public RewritePattern { 348 public: 349 EraseFirstBlock(MLIRContext *context) 350 : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {} 351 352 LogicalResult matchAndRewrite(Operation *op, 353 PatternRewriter &rewriter) const override { 354 for (Region &r : op->getRegions()) { 355 for (Block &b : r.getBlocks()) { 356 rewriter.eraseBlock(&b); 357 return success(); 358 } 359 } 360 361 return failure(); 362 } 363 }; 364 365 struct TestGreedyPatternDriver 366 : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> { 367 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver) 368 369 TestGreedyPatternDriver() = default; 370 TestGreedyPatternDriver(const TestGreedyPatternDriver &other) 371 : PassWrapper(other) {} 372 373 StringRef getArgument() const final { return "test-greedy-patterns"; } 374 StringRef getDescription() const final { return "Run test dialect patterns"; } 375 void runOnOperation() override { 376 mlir::RewritePatternSet patterns(&getContext()); 377 populateWithGenerated(patterns); 378 379 // Verify named pattern is generated with expected name. 380 patterns.add<FoldingPattern, TestNamedPatternRule, 381 FolderInsertBeforePreviouslyFoldedConstantPattern, 382 FolderCommutativeOp2WithConstant, HoistEligibleOps, 383 MakeOpEligible>(&getContext()); 384 385 // Additional patterns for testing the GreedyPatternRewriteDriver. 386 patterns.insert<IncrementIntAttribute<3>>(&getContext()); 387 388 GreedyRewriteConfig config; 389 config.useTopDownTraversal = this->useTopDownTraversal; 390 config.maxIterations = this->maxIterations; 391 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), 392 config); 393 } 394 395 Option<bool> useTopDownTraversal{ 396 *this, "top-down", 397 llvm::cl::desc("Seed the worklist in general top-down order"), 398 llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)}; 399 Option<int> maxIterations{ 400 *this, "max-iterations", 401 llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"), 402 llvm::cl::init(GreedyRewriteConfig().maxIterations)}; 403 }; 404 405 struct DumpNotifications : public RewriterBase::Listener { 406 void notifyBlockInserted(Block *block, Region *previous, 407 Region::iterator previousIt) override { 408 llvm::outs() << "notifyBlockInserted"; 409 if (block->getParentOp()) { 410 llvm::outs() << " into " << block->getParentOp()->getName() << ": "; 411 } else { 412 llvm::outs() << " into unknown op: "; 413 } 414 if (previous == nullptr) { 415 llvm::outs() << "was unlinked\n"; 416 } else { 417 llvm::outs() << "was linked\n"; 418 } 419 } 420 void notifyOperationInserted(Operation *op, 421 OpBuilder::InsertPoint previous) override { 422 llvm::outs() << "notifyOperationInserted: " << op->getName(); 423 if (!previous.isSet()) { 424 llvm::outs() << ", was unlinked\n"; 425 } else { 426 if (!previous.getPoint().getNodePtr()) { 427 llvm::outs() << ", was linked, exact position unknown\n"; 428 } else if (previous.getPoint() == previous.getBlock()->end()) { 429 llvm::outs() << ", was last in block\n"; 430 } else { 431 llvm::outs() << ", previous = " << previous.getPoint()->getName() 432 << "\n"; 433 } 434 } 435 } 436 void notifyBlockErased(Block *block) override { 437 llvm::outs() << "notifyBlockErased\n"; 438 } 439 void notifyOperationErased(Operation *op) override { 440 llvm::outs() << "notifyOperationErased: " << op->getName() << "\n"; 441 } 442 void notifyOperationModified(Operation *op) override { 443 llvm::outs() << "notifyOperationModified: " << op->getName() << "\n"; 444 } 445 void notifyOperationReplaced(Operation *op, ValueRange values) override { 446 llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n"; 447 } 448 }; 449 450 struct TestStrictPatternDriver 451 : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> { 452 public: 453 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver) 454 455 TestStrictPatternDriver() = default; 456 TestStrictPatternDriver(const TestStrictPatternDriver &other) 457 : PassWrapper(other) { 458 strictMode = other.strictMode; 459 } 460 461 StringRef getArgument() const final { return "test-strict-pattern-driver"; } 462 StringRef getDescription() const final { 463 return "Test strict mode of pattern driver"; 464 } 465 466 void runOnOperation() override { 467 MLIRContext *ctx = &getContext(); 468 mlir::RewritePatternSet patterns(ctx); 469 patterns.add< 470 // clang-format off 471 ChangeBlockOp, 472 CloneOp, 473 CloneRegionBeforeOp, 474 EraseOp, 475 ImplicitChangeOp, 476 InlineBlocksIntoParent, 477 InsertSameOp, 478 MoveBeforeParentOp, 479 ReplaceWithNewOp, 480 SplitBlockHere 481 // clang-format on 482 >(ctx); 483 SmallVector<Operation *> ops; 484 getOperation()->walk([&](Operation *op) { 485 StringRef opName = op->getName().getStringRef(); 486 if (opName == "test.insert_same_op" || opName == "test.change_block_op" || 487 opName == "test.replace_with_new_op" || opName == "test.erase_op" || 488 opName == "test.move_before_parent_op" || 489 opName == "test.inline_blocks_into_parent" || 490 opName == "test.split_block_here" || opName == "test.clone_me" || 491 opName == "test.clone_region_before") { 492 ops.push_back(op); 493 } 494 }); 495 496 DumpNotifications dumpNotifications; 497 GreedyRewriteConfig config; 498 config.listener = &dumpNotifications; 499 if (strictMode == "AnyOp") { 500 config.strictMode = GreedyRewriteStrictness::AnyOp; 501 } else if (strictMode == "ExistingAndNewOps") { 502 config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; 503 } else if (strictMode == "ExistingOps") { 504 config.strictMode = GreedyRewriteStrictness::ExistingOps; 505 } else { 506 llvm_unreachable("invalid strictness option"); 507 } 508 509 // Check if these transformations introduce visiting of operations that 510 // are not in the `ops` set (The new created ops are valid). An invalid 511 // operation will trigger the assertion while processing. 512 bool changed = false; 513 bool allErased = false; 514 (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config, 515 &changed, &allErased); 516 Builder b(ctx); 517 getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed)); 518 getOperation()->setAttr("pattern_driver_all_erased", 519 b.getBoolAttr(allErased)); 520 } 521 522 Option<std::string> strictMode{ 523 *this, "strictness", 524 llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"), 525 llvm::cl::init("AnyOp")}; 526 527 private: 528 // New inserted operation is valid for further transformation. 529 class InsertSameOp : public RewritePattern { 530 public: 531 InsertSameOp(MLIRContext *context) 532 : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {} 533 534 LogicalResult matchAndRewrite(Operation *op, 535 PatternRewriter &rewriter) const override { 536 if (op->hasAttr("skip")) 537 return failure(); 538 539 Operation *newOp = 540 rewriter.create(op->getLoc(), op->getName().getIdentifier(), 541 op->getOperands(), op->getResultTypes()); 542 rewriter.modifyOpInPlace( 543 op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); }); 544 newOp->setAttr("skip", rewriter.getBoolAttr(true)); 545 546 return success(); 547 } 548 }; 549 550 // Remove an operation may introduce the re-visiting of its operands. 551 class EraseOp : public RewritePattern { 552 public: 553 EraseOp(MLIRContext *context) 554 : RewritePattern("test.erase_op", /*benefit=*/1, context) {} 555 LogicalResult matchAndRewrite(Operation *op, 556 PatternRewriter &rewriter) const override { 557 rewriter.eraseOp(op); 558 return success(); 559 } 560 }; 561 562 // The following two patterns test RewriterBase::replaceAllUsesWith. 563 // 564 // That function replaces all usages of a Block (or a Value) with another one 565 // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver 566 // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its 567 // worklist: when an op is modified, it is added to the worklist. The two 568 // patterns below make the tracking observable: ChangeBlockOp replaces all 569 // usages of a block and that pattern is applied because the corresponding ops 570 // are put on the initial worklist (see above). ImplicitChangeOp does an 571 // unrelated change but ops of the corresponding type are *not* on the initial 572 // worklist, so the effect of the second pattern is only visible if the 573 // tracking and subsequent adding to the worklist actually works. 574 575 // Replace all usages of the first successor with the second successor. 576 class ChangeBlockOp : public RewritePattern { 577 public: 578 ChangeBlockOp(MLIRContext *context) 579 : RewritePattern("test.change_block_op", /*benefit=*/1, context) {} 580 LogicalResult matchAndRewrite(Operation *op, 581 PatternRewriter &rewriter) const override { 582 if (op->getNumSuccessors() < 2) 583 return failure(); 584 Block *firstSuccessor = op->getSuccessor(0); 585 Block *secondSuccessor = op->getSuccessor(1); 586 if (firstSuccessor == secondSuccessor) 587 return failure(); 588 // This is the function being tested: 589 rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor); 590 // Using the following line instead would make the test fail: 591 // firstSuccessor->replaceAllUsesWith(secondSuccessor); 592 return success(); 593 } 594 }; 595 596 // Changes the successor to the parent block. 597 class ImplicitChangeOp : public RewritePattern { 598 public: 599 ImplicitChangeOp(MLIRContext *context) 600 : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {} 601 LogicalResult matchAndRewrite(Operation *op, 602 PatternRewriter &rewriter) const override { 603 if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock()) 604 return failure(); 605 rewriter.modifyOpInPlace(op, 606 [&]() { op->setSuccessor(op->getBlock(), 0); }); 607 return success(); 608 } 609 }; 610 }; 611 612 struct TestWalkPatternDriver final 613 : PassWrapper<TestWalkPatternDriver, OperationPass<>> { 614 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver) 615 616 TestWalkPatternDriver() = default; 617 TestWalkPatternDriver(const TestWalkPatternDriver &other) 618 : PassWrapper(other) {} 619 620 StringRef getArgument() const override { 621 return "test-walk-pattern-rewrite-driver"; 622 } 623 StringRef getDescription() const override { 624 return "Run test walk pattern rewrite driver"; 625 } 626 void runOnOperation() override { 627 mlir::RewritePatternSet patterns(&getContext()); 628 629 // Patterns for testing the WalkPatternRewriteDriver. 630 patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp, 631 MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>( 632 &getContext()); 633 634 DumpNotifications dumpListener; 635 walkAndApplyPatterns(getOperation(), std::move(patterns), 636 dumpNotifications ? &dumpListener : nullptr); 637 } 638 639 Option<bool> dumpNotifications{ 640 *this, "dump-notifications", 641 llvm::cl::desc("Print rewrite listener notifications"), 642 llvm::cl::init(false)}; 643 }; 644 645 } // namespace 646 647 //===----------------------------------------------------------------------===// 648 // ReturnType Driver. 649 //===----------------------------------------------------------------------===// 650 651 namespace { 652 // Generate ops for each instance where the type can be successfully inferred. 653 template <typename OpTy> 654 static void invokeCreateWithInferredReturnType(Operation *op) { 655 auto *context = op->getContext(); 656 auto fop = op->getParentOfType<func::FuncOp>(); 657 auto location = UnknownLoc::get(context); 658 OpBuilder b(op); 659 b.setInsertionPointAfter(op); 660 661 // Use permutations of 2 args as operands. 662 assert(fop.getNumArguments() >= 2); 663 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 664 for (int j = 0; j < e; ++j) { 665 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 666 SmallVector<Type, 2> inferredReturnTypes; 667 if (succeeded(OpTy::inferReturnTypes( 668 context, std::nullopt, values, op->getDiscardableAttrDictionary(), 669 op->getPropertiesStorage(), op->getRegions(), 670 inferredReturnTypes))) { 671 OperationState state(location, OpTy::getOperationName()); 672 // TODO: Expand to regions. 673 OpTy::build(b, state, values, op->getAttrs()); 674 (void)b.create(state); 675 } 676 } 677 } 678 } 679 680 static void reifyReturnShape(Operation *op) { 681 OpBuilder b(op); 682 683 // Use permutations of 2 args as operands. 684 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 685 SmallVector<Value, 2> shapes; 686 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 687 !llvm::hasSingleElement(shapes)) 688 return; 689 for (const auto &it : llvm::enumerate(shapes)) { 690 op->emitRemark() << "value " << it.index() << ": " 691 << it.value().getDefiningOp(); 692 } 693 } 694 695 struct TestReturnTypeDriver 696 : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> { 697 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver) 698 699 void getDependentDialects(DialectRegistry ®istry) const override { 700 registry.insert<tensor::TensorDialect>(); 701 } 702 StringRef getArgument() const final { return "test-return-type"; } 703 StringRef getDescription() const final { return "Run return type functions"; } 704 705 void runOnOperation() override { 706 if (getOperation().getName() == "testCreateFunctions") { 707 std::vector<Operation *> ops; 708 // Collect ops to avoid triggering on inserted ops. 709 for (auto &op : getOperation().getBody().front()) 710 ops.push_back(&op); 711 // Generate test patterns for each, but skip terminator. 712 for (auto *op : llvm::ArrayRef(ops).drop_back()) { 713 // Test create method of each of the Op classes below. The resultant 714 // output would be in reverse order underneath `op` from which 715 // the attributes and regions are used. 716 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 717 invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>( 718 op); 719 invokeCreateWithInferredReturnType< 720 OpWithShapedTypeInferTypeInterfaceOp>(op); 721 }; 722 return; 723 } 724 if (getOperation().getName() == "testReifyFunctions") { 725 std::vector<Operation *> ops; 726 // Collect ops to avoid triggering on inserted ops. 727 for (auto &op : getOperation().getBody().front()) 728 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 729 ops.push_back(&op); 730 // Generate test patterns for each, but skip terminator. 731 for (auto *op : ops) 732 reifyReturnShape(op); 733 } 734 } 735 }; 736 } // namespace 737 738 namespace { 739 struct TestDerivedAttributeDriver 740 : public PassWrapper<TestDerivedAttributeDriver, 741 OperationPass<func::FuncOp>> { 742 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver) 743 744 StringRef getArgument() const final { return "test-derived-attr"; } 745 StringRef getDescription() const final { 746 return "Run test derived attributes"; 747 } 748 void runOnOperation() override; 749 }; 750 } // namespace 751 752 void TestDerivedAttributeDriver::runOnOperation() { 753 getOperation().walk([](DerivedAttributeOpInterface dOp) { 754 auto dAttr = dOp.materializeDerivedAttributes(); 755 if (!dAttr) 756 return; 757 for (auto d : dAttr) 758 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); 759 }); 760 } 761 762 //===----------------------------------------------------------------------===// 763 // Legalization Driver. 764 //===----------------------------------------------------------------------===// 765 766 namespace { 767 //===----------------------------------------------------------------------===// 768 // Region-Block Rewrite Testing 769 770 /// This pattern applies a signature conversion to a block inside a detached 771 /// region. 772 struct TestDetachedSignatureConversion : public ConversionPattern { 773 TestDetachedSignatureConversion(MLIRContext *ctx) 774 : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1, 775 ctx) {} 776 777 LogicalResult 778 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 779 ConversionPatternRewriter &rewriter) const final { 780 if (op->getNumRegions() != 1) 781 return failure(); 782 OperationState state(op->getLoc(), "test.legal_op_with_region", operands, 783 op->getResultTypes(), {}, BlockRange()); 784 Region *newRegion = state.addRegion(); 785 rewriter.inlineRegionBefore(op->getRegion(0), *newRegion, 786 newRegion->begin()); 787 TypeConverter::SignatureConversion result(newRegion->getNumArguments()); 788 for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i) 789 result.addInputs(i, rewriter.getF64Type()); 790 rewriter.applySignatureConversion(&newRegion->front(), result); 791 Operation *newOp = rewriter.create(state); 792 rewriter.replaceOp(op, newOp->getResults()); 793 return success(); 794 } 795 }; 796 797 /// This pattern is a simple pattern that inlines the first region of a given 798 /// operation into the parent region. 799 struct TestRegionRewriteBlockMovement : public ConversionPattern { 800 TestRegionRewriteBlockMovement(MLIRContext *ctx) 801 : ConversionPattern("test.region", 1, ctx) {} 802 803 LogicalResult 804 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 805 ConversionPatternRewriter &rewriter) const final { 806 // Inline this region into the parent region. 807 auto &parentRegion = *op->getParentRegion(); 808 auto &opRegion = op->getRegion(0); 809 if (op->getDiscardableAttr("legalizer.should_clone")) 810 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 811 else 812 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 813 814 if (op->getDiscardableAttr("legalizer.erase_old_blocks")) { 815 while (!opRegion.empty()) 816 rewriter.eraseBlock(&opRegion.front()); 817 } 818 819 // Drop this operation. 820 rewriter.eraseOp(op); 821 return success(); 822 } 823 }; 824 /// This pattern is a simple pattern that generates a region containing an 825 /// illegal operation. 826 struct TestRegionRewriteUndo : public RewritePattern { 827 TestRegionRewriteUndo(MLIRContext *ctx) 828 : RewritePattern("test.region_builder", 1, ctx) {} 829 830 LogicalResult matchAndRewrite(Operation *op, 831 PatternRewriter &rewriter) const final { 832 // Create the region operation with an entry block containing arguments. 833 OperationState newRegion(op->getLoc(), "test.region"); 834 newRegion.addRegion(); 835 auto *regionOp = rewriter.create(newRegion); 836 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 837 entryBlock->addArgument(rewriter.getIntegerType(64), 838 rewriter.getUnknownLoc()); 839 840 // Add an explicitly illegal operation to ensure the conversion fails. 841 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 842 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 843 844 // Drop this operation. 845 rewriter.eraseOp(op); 846 return success(); 847 } 848 }; 849 /// A simple pattern that creates a block at the end of the parent region of the 850 /// matched operation. 851 struct TestCreateBlock : public RewritePattern { 852 TestCreateBlock(MLIRContext *ctx) 853 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 854 855 LogicalResult matchAndRewrite(Operation *op, 856 PatternRewriter &rewriter) const final { 857 Region ®ion = *op->getParentRegion(); 858 Type i32Type = rewriter.getIntegerType(32); 859 Location loc = op->getLoc(); 860 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 861 rewriter.create<TerminatorOp>(loc); 862 rewriter.eraseOp(op); 863 return success(); 864 } 865 }; 866 867 /// A simple pattern that creates a block containing an invalid operation in 868 /// order to trigger the block creation undo mechanism. 869 struct TestCreateIllegalBlock : public RewritePattern { 870 TestCreateIllegalBlock(MLIRContext *ctx) 871 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 872 873 LogicalResult matchAndRewrite(Operation *op, 874 PatternRewriter &rewriter) const final { 875 Region ®ion = *op->getParentRegion(); 876 Type i32Type = rewriter.getIntegerType(32); 877 Location loc = op->getLoc(); 878 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 879 // Create an illegal op to ensure the conversion fails. 880 rewriter.create<ILLegalOpF>(loc, i32Type); 881 rewriter.create<TerminatorOp>(loc); 882 rewriter.eraseOp(op); 883 return success(); 884 } 885 }; 886 887 /// A simple pattern that tests the undo mechanism when replacing the uses of a 888 /// block argument. 889 struct TestUndoBlockArgReplace : public ConversionPattern { 890 TestUndoBlockArgReplace(MLIRContext *ctx) 891 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 892 893 LogicalResult 894 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 895 ConversionPatternRewriter &rewriter) const final { 896 auto illegalOp = 897 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 898 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 899 illegalOp->getResult(0)); 900 rewriter.modifyOpInPlace(op, [] {}); 901 return success(); 902 } 903 }; 904 905 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion. 906 /// This is to test the rollback logic. 907 struct TestUndoMoveOpBefore : public ConversionPattern { 908 TestUndoMoveOpBefore(MLIRContext *ctx) 909 : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {} 910 911 LogicalResult 912 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 913 ConversionPatternRewriter &rewriter) const override { 914 rewriter.moveOpBefore(op, op->getParentOp()); 915 // Replace with an illegal op to ensure the conversion fails. 916 rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type()); 917 return success(); 918 } 919 }; 920 921 /// A rewrite pattern that tests the undo mechanism when erasing a block. 922 struct TestUndoBlockErase : public ConversionPattern { 923 TestUndoBlockErase(MLIRContext *ctx) 924 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 925 926 LogicalResult 927 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 928 ConversionPatternRewriter &rewriter) const final { 929 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 930 rewriter.setInsertionPointToStart(secondBlock); 931 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 932 rewriter.eraseBlock(secondBlock); 933 rewriter.modifyOpInPlace(op, [] {}); 934 return success(); 935 } 936 }; 937 938 /// A pattern that modifies a property in-place, but keeps the op illegal. 939 struct TestUndoPropertiesModification : public ConversionPattern { 940 TestUndoPropertiesModification(MLIRContext *ctx) 941 : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {} 942 LogicalResult 943 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 944 ConversionPatternRewriter &rewriter) const final { 945 if (!op->hasAttr("modify_inplace")) 946 return failure(); 947 rewriter.modifyOpInPlace( 948 op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); }); 949 return success(); 950 } 951 }; 952 953 //===----------------------------------------------------------------------===// 954 // Type-Conversion Rewrite Testing 955 956 /// This patterns erases a region operation that has had a type conversion. 957 struct TestDropOpSignatureConversion : public ConversionPattern { 958 TestDropOpSignatureConversion(MLIRContext *ctx, 959 const TypeConverter &converter) 960 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 961 LogicalResult 962 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 963 ConversionPatternRewriter &rewriter) const override { 964 Region ®ion = op->getRegion(0); 965 Block *entry = ®ion.front(); 966 967 // Convert the original entry arguments. 968 const TypeConverter &converter = *getTypeConverter(); 969 TypeConverter::SignatureConversion result(entry->getNumArguments()); 970 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 971 result)) || 972 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 973 return failure(); 974 975 // Convert the region signature and just drop the operation. 976 rewriter.eraseOp(op); 977 return success(); 978 } 979 }; 980 /// This pattern simply updates the operands of the given operation. 981 struct TestPassthroughInvalidOp : public ConversionPattern { 982 TestPassthroughInvalidOp(MLIRContext *ctx) 983 : ConversionPattern("test.invalid", 1, ctx) {} 984 LogicalResult 985 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 986 ConversionPatternRewriter &rewriter) const final { 987 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, operands, 988 std::nullopt); 989 return success(); 990 } 991 }; 992 /// Replace with valid op, but simply drop the operands. This is used in a 993 /// regression where we used to generate circular unrealized_conversion_cast 994 /// ops. 995 struct TestDropAndReplaceInvalidOp : public ConversionPattern { 996 TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter) 997 : ConversionPattern(converter, 998 "test.drop_operands_and_replace_with_valid", 1, ctx) { 999 } 1000 LogicalResult 1001 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1002 ConversionPatternRewriter &rewriter) const final { 1003 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(), 1004 std::nullopt); 1005 return success(); 1006 } 1007 }; 1008 /// This pattern handles the case of a split return value. 1009 struct TestSplitReturnType : public ConversionPattern { 1010 TestSplitReturnType(MLIRContext *ctx) 1011 : ConversionPattern("test.return", 1, ctx) {} 1012 LogicalResult 1013 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1014 ConversionPatternRewriter &rewriter) const final { 1015 // Check for a return of F32. 1016 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 1017 return failure(); 1018 1019 // Check if the first operation is a cast operation, if it is we use the 1020 // results directly. 1021 auto *defOp = operands[0].getDefiningOp(); 1022 if (auto packerOp = 1023 llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) { 1024 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 1025 return success(); 1026 } 1027 1028 // Otherwise, fail to match. 1029 return failure(); 1030 } 1031 }; 1032 1033 //===----------------------------------------------------------------------===// 1034 // Multi-Level Type-Conversion Rewrite Testing 1035 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 1036 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 1037 : ConversionPattern("test.type_producer", 1, ctx) {} 1038 LogicalResult 1039 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1040 ConversionPatternRewriter &rewriter) const final { 1041 // If the type is I32, change the type to F32. 1042 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 1043 return failure(); 1044 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 1045 return success(); 1046 } 1047 }; 1048 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 1049 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 1050 : ConversionPattern("test.type_producer", 1, ctx) {} 1051 LogicalResult 1052 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1053 ConversionPatternRewriter &rewriter) const final { 1054 // If the type is F32, change the type to F64. 1055 if (!Type(*op->result_type_begin()).isF32()) 1056 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 1057 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 1058 return success(); 1059 } 1060 }; 1061 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 1062 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 1063 : ConversionPattern("test.type_producer", 10, ctx) {} 1064 LogicalResult 1065 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1066 ConversionPatternRewriter &rewriter) const final { 1067 // Always convert to B16, even though it is not a legal type. This tests 1068 // that values are unmapped correctly. 1069 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 1070 return success(); 1071 } 1072 }; 1073 struct TestUpdateConsumerType : public ConversionPattern { 1074 TestUpdateConsumerType(MLIRContext *ctx) 1075 : ConversionPattern("test.type_consumer", 1, ctx) {} 1076 LogicalResult 1077 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1078 ConversionPatternRewriter &rewriter) const final { 1079 // Verify that the incoming operand has been successfully remapped to F64. 1080 if (!operands[0].getType().isF64()) 1081 return failure(); 1082 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 1083 return success(); 1084 } 1085 }; 1086 1087 //===----------------------------------------------------------------------===// 1088 // Non-Root Replacement Rewrite Testing 1089 /// This pattern generates an invalid operation, but replaces it before the 1090 /// pattern is finished. This checks that we don't need to legalize the 1091 /// temporary op. 1092 struct TestNonRootReplacement : public RewritePattern { 1093 TestNonRootReplacement(MLIRContext *ctx) 1094 : RewritePattern("test.replace_non_root", 1, ctx) {} 1095 1096 LogicalResult matchAndRewrite(Operation *op, 1097 PatternRewriter &rewriter) const final { 1098 auto resultType = *op->result_type_begin(); 1099 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 1100 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 1101 1102 rewriter.replaceOp(illegalOp, legalOp); 1103 rewriter.replaceOp(op, illegalOp); 1104 return success(); 1105 } 1106 }; 1107 1108 //===----------------------------------------------------------------------===// 1109 // Recursive Rewrite Testing 1110 /// This pattern is applied to the same operation multiple times, but has a 1111 /// bounded recursion. 1112 struct TestBoundedRecursiveRewrite 1113 : public OpRewritePattern<TestRecursiveRewriteOp> { 1114 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 1115 1116 void initialize() { 1117 // The conversion target handles bounding the recursion of this pattern. 1118 setHasBoundedRewriteRecursion(); 1119 } 1120 1121 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 1122 PatternRewriter &rewriter) const final { 1123 // Decrement the depth of the op in-place. 1124 rewriter.modifyOpInPlace(op, [&] { 1125 op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); 1126 }); 1127 return success(); 1128 } 1129 }; 1130 1131 struct TestNestedOpCreationUndoRewrite 1132 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 1133 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 1134 1135 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 1136 PatternRewriter &rewriter) const final { 1137 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 1138 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 1139 return success(); 1140 }; 1141 }; 1142 1143 // This pattern matches `test.blackhole` and delete this op and its producer. 1144 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 1145 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 1146 1147 LogicalResult matchAndRewrite(BlackHoleOp op, 1148 PatternRewriter &rewriter) const final { 1149 Operation *producer = op.getOperand().getDefiningOp(); 1150 // Always erase the user before the producer, the framework should handle 1151 // this correctly. 1152 rewriter.eraseOp(op); 1153 rewriter.eraseOp(producer); 1154 return success(); 1155 }; 1156 }; 1157 1158 // This pattern replaces explicitly illegal op with explicitly legal op, 1159 // but in addition creates unregistered operation. 1160 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { 1161 using OpRewritePattern<ILLegalOpG>::OpRewritePattern; 1162 1163 LogicalResult matchAndRewrite(ILLegalOpG op, 1164 PatternRewriter &rewriter) const final { 1165 IntegerAttr attr = rewriter.getI32IntegerAttr(0); 1166 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); 1167 rewriter.replaceOpWithNewOp<LegalOpC>(op, val); 1168 return success(); 1169 }; 1170 }; 1171 1172 class TestEraseOp : public ConversionPattern { 1173 public: 1174 TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {} 1175 LogicalResult 1176 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1177 ConversionPatternRewriter &rewriter) const final { 1178 // Erase op without replacements. 1179 rewriter.eraseOp(op); 1180 return success(); 1181 } 1182 }; 1183 1184 } // namespace 1185 1186 namespace { 1187 struct TestTypeConverter : public TypeConverter { 1188 using TypeConverter::TypeConverter; 1189 TestTypeConverter() { 1190 addConversion(convertType); 1191 addArgumentMaterialization(materializeCast); 1192 addSourceMaterialization(materializeCast); 1193 } 1194 1195 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 1196 // Drop I16 types. 1197 if (t.isSignlessInteger(16)) 1198 return success(); 1199 1200 // Convert I64 to F64. 1201 if (t.isSignlessInteger(64)) { 1202 results.push_back(FloatType::getF64(t.getContext())); 1203 return success(); 1204 } 1205 1206 // Convert I42 to I43. 1207 if (t.isInteger(42)) { 1208 results.push_back(IntegerType::get(t.getContext(), 43)); 1209 return success(); 1210 } 1211 1212 // Split F32 into F16,F16. 1213 if (t.isF32()) { 1214 results.assign(2, FloatType::getF16(t.getContext())); 1215 return success(); 1216 } 1217 1218 // Otherwise, convert the type directly. 1219 results.push_back(t); 1220 return success(); 1221 } 1222 1223 /// Hook for materializing a conversion. This is necessary because we generate 1224 /// 1->N type mappings. 1225 static Value materializeCast(OpBuilder &builder, Type resultType, 1226 ValueRange inputs, Location loc) { 1227 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1228 } 1229 }; 1230 1231 struct TestLegalizePatternDriver 1232 : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> { 1233 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) 1234 1235 StringRef getArgument() const final { return "test-legalize-patterns"; } 1236 StringRef getDescription() const final { 1237 return "Run test dialect legalization patterns"; 1238 } 1239 /// The mode of conversion to use with the driver. 1240 enum class ConversionMode { Analysis, Full, Partial }; 1241 1242 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 1243 1244 void getDependentDialects(DialectRegistry ®istry) const override { 1245 registry.insert<func::FuncDialect, test::TestDialect>(); 1246 } 1247 1248 void runOnOperation() override { 1249 TestTypeConverter converter; 1250 mlir::RewritePatternSet patterns(&getContext()); 1251 populateWithGenerated(patterns); 1252 patterns.add< 1253 TestRegionRewriteBlockMovement, TestDetachedSignatureConversion, 1254 TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock, 1255 TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp, 1256 TestSplitReturnType, TestChangeProducerTypeI32ToF32, 1257 TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid, 1258 TestUpdateConsumerType, TestNonRootReplacement, 1259 TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite, 1260 TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore, 1261 TestUndoPropertiesModification, TestEraseOp>(&getContext()); 1262 patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>( 1263 &getContext(), converter); 1264 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1265 converter); 1266 mlir::populateCallOpTypeConversionPattern(patterns, converter); 1267 1268 // Define the conversion target used for the test. 1269 ConversionTarget target(getContext()); 1270 target.addLegalOp<ModuleOp>(); 1271 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, 1272 TerminatorOp, OneRegionOp>(); 1273 target.addLegalOp( 1274 OperationName("test.legal_op_with_region", &getContext())); 1275 target 1276 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 1277 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 1278 // Don't allow F32 operands. 1279 return llvm::none_of(op.getOperandTypes(), 1280 [](Type type) { return type.isF32(); }); 1281 }); 1282 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1283 return converter.isSignatureLegal(op.getFunctionType()) && 1284 converter.isLegal(&op.getBody()); 1285 }); 1286 target.addDynamicallyLegalOp<func::CallOp>( 1287 [&](func::CallOp op) { return converter.isLegal(op); }); 1288 1289 // TestCreateUnregisteredOp creates `arith.constant` operation, 1290 // which was not added to target intentionally to test 1291 // correct error code from conversion driver. 1292 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); 1293 1294 // Expect the type_producer/type_consumer operations to only operate on f64. 1295 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1296 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1297 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1298 return op.getOperand().getType().isF64(); 1299 }); 1300 1301 // Check support for marking certain operations as recursively legal. 1302 target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) { 1303 return static_cast<bool>( 1304 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 1305 }); 1306 1307 // Mark the bound recursion operation as dynamically legal. 1308 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 1309 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); 1310 1311 // Create a dynamically legal rule that can only be legalized by folding it. 1312 target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>( 1313 [](TestOpInPlaceSelfFold op) { return op.getFolded(); }); 1314 1315 // Handle a partial conversion. 1316 if (mode == ConversionMode::Partial) { 1317 DenseSet<Operation *> unlegalizedOps; 1318 ConversionConfig config; 1319 DumpNotifications dumpNotifications; 1320 config.listener = &dumpNotifications; 1321 config.unlegalizedOps = &unlegalizedOps; 1322 if (failed(applyPartialConversion(getOperation(), target, 1323 std::move(patterns), config))) { 1324 getOperation()->emitRemark() << "applyPartialConversion failed"; 1325 } 1326 // Emit remarks for each legalizable operation. 1327 for (auto *op : unlegalizedOps) 1328 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1329 return; 1330 } 1331 1332 // Handle a full conversion. 1333 if (mode == ConversionMode::Full) { 1334 // Check support for marking unknown operations as dynamically legal. 1335 target.markUnknownOpDynamicallyLegal([](Operation *op) { 1336 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 1337 }); 1338 1339 ConversionConfig config; 1340 DumpNotifications dumpNotifications; 1341 config.listener = &dumpNotifications; 1342 if (failed(applyFullConversion(getOperation(), target, 1343 std::move(patterns), config))) { 1344 getOperation()->emitRemark() << "applyFullConversion failed"; 1345 } 1346 return; 1347 } 1348 1349 // Otherwise, handle an analysis conversion. 1350 assert(mode == ConversionMode::Analysis); 1351 1352 // Analyze the convertible operations. 1353 DenseSet<Operation *> legalizedOps; 1354 ConversionConfig config; 1355 config.legalizableOps = &legalizedOps; 1356 if (failed(applyAnalysisConversion(getOperation(), target, 1357 std::move(patterns), config))) 1358 return signalPassFailure(); 1359 1360 // Emit remarks for each legalizable operation. 1361 for (auto *op : legalizedOps) 1362 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 1363 } 1364 1365 /// The mode of conversion to use. 1366 ConversionMode mode; 1367 }; 1368 } // namespace 1369 1370 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 1371 legalizerConversionMode( 1372 "test-legalize-mode", 1373 llvm::cl::desc("The legalization mode to use with the test driver"), 1374 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 1375 llvm::cl::values( 1376 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 1377 "analysis", "Perform an analysis conversion"), 1378 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 1379 "Perform a full conversion"), 1380 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 1381 "partial", "Perform a partial conversion"))); 1382 1383 //===----------------------------------------------------------------------===// 1384 // ConversionPatternRewriter::getRemappedValue testing. This method is used 1385 // to get the remapped value of an original value that was replaced using 1386 // ConversionPatternRewriter. 1387 namespace { 1388 struct TestRemapValueTypeConverter : public TypeConverter { 1389 using TypeConverter::TypeConverter; 1390 1391 TestRemapValueTypeConverter() { 1392 addConversion( 1393 [](Float32Type type) { return Float64Type::get(type.getContext()); }); 1394 addConversion([](Type type) { return type; }); 1395 } 1396 }; 1397 1398 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 1399 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 1400 /// operand twice. 1401 /// 1402 /// Example: 1403 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 1404 /// is replaced with: 1405 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 1406 struct OneVResOneVOperandOp1Converter 1407 : public OpConversionPattern<OneVResOneVOperandOp1> { 1408 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 1409 1410 LogicalResult 1411 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, 1412 ConversionPatternRewriter &rewriter) const override { 1413 auto origOps = op.getOperands(); 1414 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 1415 "One operand expected"); 1416 Value origOp = *origOps.begin(); 1417 SmallVector<Value, 2> remappedOperands; 1418 // Replicate the remapped original operand twice. Note that we don't used 1419 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 1420 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1421 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1422 1423 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 1424 remappedOperands); 1425 return success(); 1426 } 1427 }; 1428 1429 /// A rewriter pattern that tests that blocks can be merged. 1430 struct TestRemapValueInRegion 1431 : public OpConversionPattern<TestRemappedValueRegionOp> { 1432 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; 1433 1434 LogicalResult 1435 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, 1436 ConversionPatternRewriter &rewriter) const final { 1437 Block &block = op.getBody().front(); 1438 Operation *terminator = block.getTerminator(); 1439 1440 // Merge the block into the parent region. 1441 Block *parentBlock = op->getBlock(); 1442 Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator()); 1443 rewriter.mergeBlocks(&block, parentBlock, ValueRange()); 1444 rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange()); 1445 1446 // Replace the results of this operation with the remapped terminator 1447 // values. 1448 SmallVector<Value> terminatorOperands; 1449 if (failed(rewriter.getRemappedValues(terminator->getOperands(), 1450 terminatorOperands))) 1451 return failure(); 1452 1453 rewriter.eraseOp(terminator); 1454 rewriter.replaceOp(op, terminatorOperands); 1455 return success(); 1456 } 1457 }; 1458 1459 struct TestRemappedValue 1460 : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> { 1461 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue) 1462 1463 StringRef getArgument() const final { return "test-remapped-value"; } 1464 StringRef getDescription() const final { 1465 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1466 } 1467 void runOnOperation() override { 1468 TestRemapValueTypeConverter typeConverter; 1469 1470 mlir::RewritePatternSet patterns(&getContext()); 1471 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 1472 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( 1473 &getContext()); 1474 patterns.add<TestRemapValueInRegion>(typeConverter, &getContext()); 1475 1476 mlir::ConversionTarget target(getContext()); 1477 target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>(); 1478 1479 // Expect the type_producer/type_consumer operations to only operate on f64. 1480 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1481 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1482 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1483 return op.getOperand().getType().isF64(); 1484 }); 1485 1486 // We make OneVResOneVOperandOp1 legal only when it has more that one 1487 // operand. This will trigger the conversion that will replace one-operand 1488 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 1489 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 1490 [](Operation *op) { return op->getNumOperands() > 1; }); 1491 1492 if (failed(mlir::applyFullConversion(getOperation(), target, 1493 std::move(patterns)))) { 1494 signalPassFailure(); 1495 } 1496 } 1497 }; 1498 } // namespace 1499 1500 //===----------------------------------------------------------------------===// 1501 // Test patterns without a specific root operation kind 1502 //===----------------------------------------------------------------------===// 1503 1504 namespace { 1505 /// This pattern matches and removes any operation in the test dialect. 1506 struct RemoveTestDialectOps : public RewritePattern { 1507 RemoveTestDialectOps(MLIRContext *context) 1508 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 1509 1510 LogicalResult matchAndRewrite(Operation *op, 1511 PatternRewriter &rewriter) const override { 1512 if (!isa<TestDialect>(op->getDialect())) 1513 return failure(); 1514 rewriter.eraseOp(op); 1515 return success(); 1516 } 1517 }; 1518 1519 struct TestUnknownRootOpDriver 1520 : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> { 1521 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver) 1522 1523 StringRef getArgument() const final { 1524 return "test-legalize-unknown-root-patterns"; 1525 } 1526 StringRef getDescription() const final { 1527 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1528 } 1529 void runOnOperation() override { 1530 mlir::RewritePatternSet patterns(&getContext()); 1531 patterns.add<RemoveTestDialectOps>(&getContext()); 1532 1533 mlir::ConversionTarget target(getContext()); 1534 target.addIllegalDialect<TestDialect>(); 1535 if (failed(applyPartialConversion(getOperation(), target, 1536 std::move(patterns)))) 1537 signalPassFailure(); 1538 } 1539 }; 1540 } // namespace 1541 1542 //===----------------------------------------------------------------------===// 1543 // Test patterns that uses operations and types defined at runtime 1544 //===----------------------------------------------------------------------===// 1545 1546 namespace { 1547 /// This pattern matches dynamic operations 'test.one_operand_two_results' and 1548 /// replace them with dynamic operations 'test.generic_dynamic_op'. 1549 struct RewriteDynamicOp : public RewritePattern { 1550 RewriteDynamicOp(MLIRContext *context) 1551 : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1, 1552 context) {} 1553 1554 LogicalResult matchAndRewrite(Operation *op, 1555 PatternRewriter &rewriter) const override { 1556 assert(op->getName().getStringRef() == 1557 "test.dynamic_one_operand_two_results" && 1558 "rewrite pattern should only match operations with the right name"); 1559 1560 OperationState state(op->getLoc(), "test.dynamic_generic", 1561 op->getOperands(), op->getResultTypes(), 1562 op->getAttrs()); 1563 auto *newOp = rewriter.create(state); 1564 rewriter.replaceOp(op, newOp->getResults()); 1565 return success(); 1566 } 1567 }; 1568 1569 struct TestRewriteDynamicOpDriver 1570 : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> { 1571 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver) 1572 1573 void getDependentDialects(DialectRegistry ®istry) const override { 1574 registry.insert<TestDialect>(); 1575 } 1576 StringRef getArgument() const final { return "test-rewrite-dynamic-op"; } 1577 StringRef getDescription() const final { 1578 return "Test rewritting on dynamic operations"; 1579 } 1580 void runOnOperation() override { 1581 RewritePatternSet patterns(&getContext()); 1582 patterns.add<RewriteDynamicOp>(&getContext()); 1583 1584 ConversionTarget target(getContext()); 1585 target.addIllegalOp( 1586 OperationName("test.dynamic_one_operand_two_results", &getContext())); 1587 target.addLegalOp(OperationName("test.dynamic_generic", &getContext())); 1588 if (failed(applyPartialConversion(getOperation(), target, 1589 std::move(patterns)))) 1590 signalPassFailure(); 1591 } 1592 }; 1593 } // end anonymous namespace 1594 1595 //===----------------------------------------------------------------------===// 1596 // Test type conversions 1597 //===----------------------------------------------------------------------===// 1598 1599 namespace { 1600 struct TestTypeConversionProducer 1601 : public OpConversionPattern<TestTypeProducerOp> { 1602 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 1603 LogicalResult 1604 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, 1605 ConversionPatternRewriter &rewriter) const final { 1606 Type resultType = op.getType(); 1607 Type convertedType = getTypeConverter() 1608 ? getTypeConverter()->convertType(resultType) 1609 : resultType; 1610 if (isa<FloatType>(resultType)) 1611 resultType = rewriter.getF64Type(); 1612 else if (resultType.isInteger(16)) 1613 resultType = rewriter.getIntegerType(64); 1614 else if (isa<test::TestRecursiveType>(resultType) && 1615 convertedType != resultType) 1616 resultType = convertedType; 1617 else 1618 return failure(); 1619 1620 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 1621 return success(); 1622 } 1623 }; 1624 1625 /// Call signature conversion and then fail the rewrite to trigger the undo 1626 /// mechanism. 1627 struct TestSignatureConversionUndo 1628 : public OpConversionPattern<TestSignatureConversionUndoOp> { 1629 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 1630 1631 LogicalResult 1632 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, 1633 ConversionPatternRewriter &rewriter) const final { 1634 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 1635 return failure(); 1636 } 1637 }; 1638 1639 /// Call signature conversion without providing a type converter to handle 1640 /// materializations. 1641 struct TestTestSignatureConversionNoConverter 1642 : public OpConversionPattern<TestSignatureConversionNoConverterOp> { 1643 TestTestSignatureConversionNoConverter(const TypeConverter &converter, 1644 MLIRContext *context) 1645 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), 1646 converter(converter) {} 1647 1648 LogicalResult 1649 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, 1650 ConversionPatternRewriter &rewriter) const final { 1651 Region ®ion = op->getRegion(0); 1652 Block *entry = ®ion.front(); 1653 1654 // Convert the original entry arguments. 1655 TypeConverter::SignatureConversion result(entry->getNumArguments()); 1656 if (failed( 1657 converter.convertSignatureArgs(entry->getArgumentTypes(), result))) 1658 return failure(); 1659 rewriter.modifyOpInPlace(op, [&] { 1660 rewriter.applySignatureConversion(®ion.front(), result); 1661 }); 1662 return success(); 1663 } 1664 1665 const TypeConverter &converter; 1666 }; 1667 1668 /// Just forward the operands to the root op. This is essentially a no-op 1669 /// pattern that is used to trigger target materialization. 1670 struct TestTypeConsumerForward 1671 : public OpConversionPattern<TestTypeConsumerOp> { 1672 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 1673 1674 LogicalResult 1675 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, 1676 ConversionPatternRewriter &rewriter) const final { 1677 rewriter.modifyOpInPlace(op, 1678 [&] { op->setOperands(adaptor.getOperands()); }); 1679 return success(); 1680 } 1681 }; 1682 1683 struct TestTypeConversionAnotherProducer 1684 : public OpRewritePattern<TestAnotherTypeProducerOp> { 1685 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 1686 1687 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 1688 PatternRewriter &rewriter) const final { 1689 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 1690 return success(); 1691 } 1692 }; 1693 1694 struct TestReplaceWithLegalOp : public ConversionPattern { 1695 TestReplaceWithLegalOp(MLIRContext *ctx) 1696 : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {} 1697 LogicalResult 1698 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1699 ConversionPatternRewriter &rewriter) const final { 1700 rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]); 1701 return success(); 1702 } 1703 }; 1704 1705 struct TestTypeConversionDriver 1706 : public PassWrapper<TestTypeConversionDriver, OperationPass<>> { 1707 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver) 1708 1709 void getDependentDialects(DialectRegistry ®istry) const override { 1710 registry.insert<TestDialect>(); 1711 } 1712 StringRef getArgument() const final { 1713 return "test-legalize-type-conversion"; 1714 } 1715 StringRef getDescription() const final { 1716 return "Test various type conversion functionalities in DialectConversion"; 1717 } 1718 1719 void runOnOperation() override { 1720 // Initialize the type converter. 1721 SmallVector<Type, 2> conversionCallStack; 1722 TypeConverter converter; 1723 1724 /// Add the legal set of type conversions. 1725 converter.addConversion([](Type type) -> Type { 1726 // Treat F64 as legal. 1727 if (type.isF64()) 1728 return type; 1729 // Allow converting BF16/F16/F32 to F64. 1730 if (type.isBF16() || type.isF16() || type.isF32()) 1731 return FloatType::getF64(type.getContext()); 1732 // Otherwise, the type is illegal. 1733 return nullptr; 1734 }); 1735 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 1736 // Drop all integer types. 1737 return success(); 1738 }); 1739 converter.addConversion( 1740 // Convert a recursive self-referring type into a non-self-referring 1741 // type named "outer_converted_type" that contains a SimpleAType. 1742 [&](test::TestRecursiveType type, 1743 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { 1744 // If the type is already converted, return it to indicate that it is 1745 // legal. 1746 if (type.getName() == "outer_converted_type") { 1747 results.push_back(type); 1748 return success(); 1749 } 1750 1751 conversionCallStack.push_back(type); 1752 auto popConversionCallStack = llvm::make_scope_exit( 1753 [&conversionCallStack]() { conversionCallStack.pop_back(); }); 1754 1755 // If the type is on the call stack more than once (it is there at 1756 // least once because of the _current_ call, which is always the last 1757 // element on the stack), we've hit the recursive case. Just return 1758 // SimpleAType here to create a non-recursive type as a result. 1759 if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(), 1760 type)) { 1761 results.push_back(test::SimpleAType::get(type.getContext())); 1762 return success(); 1763 } 1764 1765 // Convert the body recursively. 1766 auto result = test::TestRecursiveType::get(type.getContext(), 1767 "outer_converted_type"); 1768 if (failed(result.setBody(converter.convertType(type.getBody())))) 1769 return failure(); 1770 results.push_back(result); 1771 return success(); 1772 }); 1773 1774 /// Add the legal set of type materializations. 1775 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 1776 ValueRange inputs, 1777 Location loc) -> Value { 1778 // Allow casting from F64 back to F32. 1779 if (!resultType.isF16() && inputs.size() == 1 && 1780 inputs[0].getType().isF64()) 1781 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1782 // Allow producing an i32 or i64 from nothing. 1783 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 1784 inputs.empty()) 1785 return builder.create<TestTypeProducerOp>(loc, resultType); 1786 // Allow producing an i64 from an integer. 1787 if (isa<IntegerType>(resultType) && inputs.size() == 1 && 1788 isa<IntegerType>(inputs[0].getType())) 1789 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1790 // Otherwise, fail. 1791 return nullptr; 1792 }); 1793 1794 // Initialize the conversion target. 1795 mlir::ConversionTarget target(getContext()); 1796 target.addLegalOp<LegalOpD>(); 1797 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 1798 auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType()); 1799 return op.getType().isF64() || op.getType().isInteger(64) || 1800 (recursiveType && 1801 recursiveType.getName() == "outer_converted_type"); 1802 }); 1803 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1804 return converter.isSignatureLegal(op.getFunctionType()) && 1805 converter.isLegal(&op.getBody()); 1806 }); 1807 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 1808 // Allow casts from F64 to F32. 1809 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 1810 }); 1811 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>( 1812 [&](TestSignatureConversionNoConverterOp op) { 1813 return converter.isLegal(op.getRegion().front().getArgumentTypes()); 1814 }); 1815 1816 // Initialize the set of rewrite patterns. 1817 RewritePatternSet patterns(&getContext()); 1818 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 1819 TestSignatureConversionUndo, 1820 TestTestSignatureConversionNoConverter>(converter, 1821 &getContext()); 1822 patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>( 1823 &getContext()); 1824 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1825 converter); 1826 1827 if (failed(applyPartialConversion(getOperation(), target, 1828 std::move(patterns)))) 1829 signalPassFailure(); 1830 } 1831 }; 1832 } // namespace 1833 1834 //===----------------------------------------------------------------------===// 1835 // Test Target Materialization With No Uses 1836 //===----------------------------------------------------------------------===// 1837 1838 namespace { 1839 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> { 1840 using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern; 1841 1842 LogicalResult 1843 matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor, 1844 ConversionPatternRewriter &rewriter) const final { 1845 rewriter.replaceOp(op, adaptor.getOperands()); 1846 return success(); 1847 } 1848 }; 1849 1850 struct TestTargetMaterializationWithNoUses 1851 : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> { 1852 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 1853 TestTargetMaterializationWithNoUses) 1854 1855 StringRef getArgument() const final { 1856 return "test-target-materialization-with-no-uses"; 1857 } 1858 StringRef getDescription() const final { 1859 return "Test a special case of target materialization in DialectConversion"; 1860 } 1861 1862 void runOnOperation() override { 1863 TypeConverter converter; 1864 converter.addConversion([](Type t) { return t; }); 1865 converter.addConversion([](IntegerType intTy) -> Type { 1866 if (intTy.getWidth() == 16) 1867 return IntegerType::get(intTy.getContext(), 64); 1868 return intTy; 1869 }); 1870 converter.addTargetMaterialization( 1871 [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { 1872 return builder.create<TestCastOp>(loc, type, inputs).getResult(); 1873 }); 1874 1875 ConversionTarget target(getContext()); 1876 target.addIllegalOp<TestTypeChangerOp>(); 1877 1878 RewritePatternSet patterns(&getContext()); 1879 patterns.add<ForwardOperandPattern>(converter, &getContext()); 1880 1881 if (failed(applyPartialConversion(getOperation(), target, 1882 std::move(patterns)))) 1883 signalPassFailure(); 1884 } 1885 }; 1886 } // namespace 1887 1888 //===----------------------------------------------------------------------===// 1889 // Test Block Merging 1890 //===----------------------------------------------------------------------===// 1891 1892 namespace { 1893 /// A rewriter pattern that tests that blocks can be merged. 1894 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 1895 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 1896 1897 LogicalResult 1898 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, 1899 ConversionPatternRewriter &rewriter) const final { 1900 Block &firstBlock = op.getBody().front(); 1901 Operation *branchOp = firstBlock.getTerminator(); 1902 Block *secondBlock = &*(std::next(op.getBody().begin())); 1903 auto succOperands = branchOp->getOperands(); 1904 SmallVector<Value, 2> replacements(succOperands); 1905 rewriter.eraseOp(branchOp); 1906 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1907 rewriter.modifyOpInPlace(op, [] {}); 1908 return success(); 1909 } 1910 }; 1911 1912 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 1913 struct TestUndoBlocksMerge : public ConversionPattern { 1914 TestUndoBlocksMerge(MLIRContext *ctx) 1915 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 1916 LogicalResult 1917 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1918 ConversionPatternRewriter &rewriter) const final { 1919 Block &firstBlock = op->getRegion(0).front(); 1920 Operation *branchOp = firstBlock.getTerminator(); 1921 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 1922 rewriter.setInsertionPointToStart(secondBlock); 1923 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 1924 auto succOperands = branchOp->getOperands(); 1925 SmallVector<Value, 2> replacements(succOperands); 1926 rewriter.eraseOp(branchOp); 1927 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1928 rewriter.modifyOpInPlace(op, [] {}); 1929 return success(); 1930 } 1931 }; 1932 1933 /// A rewrite mechanism to inline the body of the op into its parent, when both 1934 /// ops can have a single block. 1935 struct TestMergeSingleBlockOps 1936 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 1937 using OpConversionPattern< 1938 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 1939 1940 LogicalResult 1941 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, 1942 ConversionPatternRewriter &rewriter) const final { 1943 SingleBlockImplicitTerminatorOp parentOp = 1944 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1945 if (!parentOp) 1946 return failure(); 1947 Block &innerBlock = op.getRegion().front(); 1948 TerminatorOp innerTerminator = 1949 cast<TerminatorOp>(innerBlock.getTerminator()); 1950 rewriter.inlineBlockBefore(&innerBlock, op); 1951 rewriter.eraseOp(innerTerminator); 1952 rewriter.eraseOp(op); 1953 return success(); 1954 } 1955 }; 1956 1957 struct TestMergeBlocksPatternDriver 1958 : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> { 1959 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver) 1960 1961 StringRef getArgument() const final { return "test-merge-blocks"; } 1962 StringRef getDescription() const final { 1963 return "Test Merging operation in ConversionPatternRewriter"; 1964 } 1965 void runOnOperation() override { 1966 MLIRContext *context = &getContext(); 1967 mlir::RewritePatternSet patterns(context); 1968 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 1969 context); 1970 ConversionTarget target(*context); 1971 target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1972 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1973 target.addIllegalOp<ILLegalOpF>(); 1974 1975 /// Expect the op to have a single block after legalization. 1976 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1977 [&](TestMergeBlocksOp op) -> bool { 1978 return llvm::hasSingleElement(op.getBody()); 1979 }); 1980 1981 /// Only allow `test.br` within test.merge_blocks op. 1982 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1983 return op->getParentOfType<TestMergeBlocksOp>(); 1984 }); 1985 1986 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1987 /// inlined. 1988 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1989 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1990 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1991 }); 1992 1993 DenseSet<Operation *> unlegalizedOps; 1994 ConversionConfig config; 1995 config.unlegalizedOps = &unlegalizedOps; 1996 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1997 config); 1998 for (auto *op : unlegalizedOps) 1999 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 2000 } 2001 }; 2002 } // namespace 2003 2004 //===----------------------------------------------------------------------===// 2005 // Test Selective Replacement 2006 //===----------------------------------------------------------------------===// 2007 2008 namespace { 2009 /// A rewrite mechanism to inline the body of the op into its parent, when both 2010 /// ops can have a single block. 2011 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 2012 using OpRewritePattern<TestCastOp>::OpRewritePattern; 2013 2014 LogicalResult matchAndRewrite(TestCastOp op, 2015 PatternRewriter &rewriter) const final { 2016 if (op.getNumOperands() != 2) 2017 return failure(); 2018 OperandRange operands = op.getOperands(); 2019 2020 // Replace non-terminator uses with the first operand. 2021 rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) { 2022 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 2023 }); 2024 // Replace everything else with the second operand if the operation isn't 2025 // dead. 2026 rewriter.replaceOp(op, op.getOperand(1)); 2027 return success(); 2028 } 2029 }; 2030 2031 struct TestSelectiveReplacementPatternDriver 2032 : public PassWrapper<TestSelectiveReplacementPatternDriver, 2033 OperationPass<>> { 2034 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 2035 TestSelectiveReplacementPatternDriver) 2036 2037 StringRef getArgument() const final { 2038 return "test-pattern-selective-replacement"; 2039 } 2040 StringRef getDescription() const final { 2041 return "Test selective replacement in the PatternRewriter"; 2042 } 2043 void runOnOperation() override { 2044 MLIRContext *context = &getContext(); 2045 mlir::RewritePatternSet patterns(context); 2046 patterns.add<TestSelectiveOpReplacementPattern>(context); 2047 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 2048 } 2049 }; 2050 } // namespace 2051 2052 //===----------------------------------------------------------------------===// 2053 // PassRegistration 2054 //===----------------------------------------------------------------------===// 2055 2056 namespace mlir { 2057 namespace test { 2058 void registerPatternsTestPass() { 2059 PassRegistration<TestReturnTypeDriver>(); 2060 2061 PassRegistration<TestDerivedAttributeDriver>(); 2062 2063 PassRegistration<TestGreedyPatternDriver>(); 2064 PassRegistration<TestStrictPatternDriver>(); 2065 PassRegistration<TestWalkPatternDriver>(); 2066 2067 PassRegistration<TestLegalizePatternDriver>([] { 2068 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 2069 }); 2070 2071 PassRegistration<TestRemappedValue>(); 2072 2073 PassRegistration<TestUnknownRootOpDriver>(); 2074 2075 PassRegistration<TestTypeConversionDriver>(); 2076 PassRegistration<TestTargetMaterializationWithNoUses>(); 2077 2078 PassRegistration<TestRewriteDynamicOpDriver>(); 2079 2080 PassRegistration<TestMergeBlocksPatternDriver>(); 2081 PassRegistration<TestSelectiveReplacementPatternDriver>(); 2082 } 2083 } // namespace test 2084 } // namespace mlir 2085