1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestOps.h" 11 #include "TestTypes.h" 12 #include "mlir/Dialect/Arith/IR/Arith.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/BuiltinAttributes.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/Visitors.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "mlir/Transforms/FoldUtils.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 #include "mlir/Transforms/WalkPatternRewriteDriver.h" 24 #include "llvm/ADT/ScopeExit.h" 25 #include <cstdint> 26 27 using namespace mlir; 28 using namespace test; 29 30 // Native function for testing NativeCodeCall 31 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 32 return choice.getValue() ? input1 : input2; 33 } 34 35 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 36 rewriter.create<OpI>(loc, input); 37 } 38 39 static void handleNoResultOp(PatternRewriter &rewriter, 40 OpSymbolBindingNoResult op) { 41 // Turn the no result op to a one-result op. 
42 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), 43 op.getOperand()); 44 } 45 46 static bool getFirstI32Result(Operation *op, Value &value) { 47 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 48 return false; 49 value = op->getResult(0); 50 return true; 51 } 52 53 static Value bindNativeCodeCallResult(Value value) { return value; } 54 55 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 56 Value input2) { 57 return SmallVector<Value, 2>({input2, input1}); 58 } 59 60 // Test that natives calls are only called once during rewrites. 61 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 62 // This let us check the number of times OpM_Test was called by inspecting 63 // the returned value in the MLIR output. 64 static int64_t opMIncreasingValue = 314159265; 65 static Attribute opMTest(PatternRewriter &rewriter, Value val) { 66 int64_t i = opMIncreasingValue++; 67 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 68 } 69 70 namespace { 71 #include "TestPatterns.inc" 72 } // namespace 73 74 //===----------------------------------------------------------------------===// 75 // Test Reduce Pattern Interface 76 //===----------------------------------------------------------------------===// 77 78 void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 79 populateWithGenerated(patterns); 80 } 81 82 //===----------------------------------------------------------------------===// 83 // Canonicalizer Driver. 
84 //===----------------------------------------------------------------------===// 85 86 namespace { 87 struct FoldingPattern : public RewritePattern { 88 public: 89 FoldingPattern(MLIRContext *context) 90 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 91 /*benefit=*/1, context) {} 92 93 LogicalResult matchAndRewrite(Operation *op, 94 PatternRewriter &rewriter) const override { 95 // Exercise createOrFold API for a single-result operation that is folded 96 // upon construction. The operation being created has an in-place folder, 97 // and it should be still present in the output. Furthermore, the folder 98 // should not crash when attempting to recover the (unchanged) operation 99 // result. 100 Value result = rewriter.createOrFold<TestOpInPlaceFold>( 101 op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0)); 102 assert(result); 103 rewriter.replaceOp(op, result); 104 return success(); 105 } 106 }; 107 108 /// This pattern creates a foldable operation at the entry point of the block. 109 /// This tests the situation where the operation folder will need to replace an 110 /// operation with a previously created constant that does not initially 111 /// dominate the operation to replace. 
112 struct FolderInsertBeforePreviouslyFoldedConstantPattern 113 : public OpRewritePattern<TestCastOp> { 114 public: 115 using OpRewritePattern<TestCastOp>::OpRewritePattern; 116 117 LogicalResult matchAndRewrite(TestCastOp op, 118 PatternRewriter &rewriter) const override { 119 if (!op->hasAttr("test_fold_before_previously_folded_op")) 120 return failure(); 121 rewriter.setInsertionPointToStart(op->getBlock()); 122 123 auto constOp = rewriter.create<arith::ConstantOp>( 124 op.getLoc(), rewriter.getBoolAttr(true)); 125 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 126 Value(constOp)); 127 return success(); 128 } 129 }; 130 131 /// This pattern matches test.op_commutative2 with the first operand being 132 /// another test.op_commutative2 with a constant on the right side and fold it 133 /// away by propagating it as its result. This is intend to check that patterns 134 /// are applied after the commutative property moves constant to the right. 135 struct FolderCommutativeOp2WithConstant 136 : public OpRewritePattern<TestCommutative2Op> { 137 public: 138 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; 139 140 LogicalResult matchAndRewrite(TestCommutative2Op op, 141 PatternRewriter &rewriter) const override { 142 auto operand = 143 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); 144 if (!operand) 145 return failure(); 146 Attribute constInput; 147 if (!matchPattern(operand->getOperand(1), m_Constant(&constInput))) 148 return failure(); 149 rewriter.replaceOp(op, operand->getOperand(1)); 150 return success(); 151 } 152 }; 153 154 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer 155 /// attribute with value smaller than MaxVal, it increments the value by 1. 
156 template <int MaxVal> 157 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> { 158 using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern; 159 160 LogicalResult matchAndRewrite(AnyAttrOfOp op, 161 PatternRewriter &rewriter) const override { 162 auto intAttr = dyn_cast<IntegerAttr>(op.getAttr()); 163 if (!intAttr) 164 return failure(); 165 int64_t val = intAttr.getInt(); 166 if (val >= MaxVal) 167 return failure(); 168 rewriter.modifyOpInPlace( 169 op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); }); 170 return success(); 171 } 172 }; 173 174 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op". 175 struct MakeOpEligible : public RewritePattern { 176 MakeOpEligible(MLIRContext *context) 177 : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {} 178 179 LogicalResult matchAndRewrite(Operation *op, 180 PatternRewriter &rewriter) const override { 181 if (op->hasAttr("eligible")) 182 return failure(); 183 rewriter.modifyOpInPlace( 184 op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); }); 185 return success(); 186 } 187 }; 188 189 /// This pattern hoists eligible ops out of a "test.one_region_op". 190 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> { 191 using OpRewritePattern<test::OneRegionOp>::OpRewritePattern; 192 193 LogicalResult matchAndRewrite(test::OneRegionOp op, 194 PatternRewriter &rewriter) const override { 195 Operation *terminator = op.getRegion().front().getTerminator(); 196 Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp(); 197 if (toBeHoisted->getParentOp() != op) 198 return failure(); 199 if (!toBeHoisted->hasAttr("eligible")) 200 return failure(); 201 rewriter.moveOpBefore(toBeHoisted, op); 202 return success(); 203 } 204 }; 205 206 /// This pattern moves "test.move_before_parent_op" before the parent op. 
207 struct MoveBeforeParentOp : public RewritePattern { 208 MoveBeforeParentOp(MLIRContext *context) 209 : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {} 210 211 LogicalResult matchAndRewrite(Operation *op, 212 PatternRewriter &rewriter) const override { 213 // Do not hoist past functions. 214 if (isa<FunctionOpInterface>(op->getParentOp())) 215 return failure(); 216 rewriter.moveOpBefore(op, op->getParentOp()); 217 return success(); 218 } 219 }; 220 221 /// This pattern moves "test.move_after_parent_op" after the parent op. 222 struct MoveAfterParentOp : public RewritePattern { 223 MoveAfterParentOp(MLIRContext *context) 224 : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {} 225 226 LogicalResult matchAndRewrite(Operation *op, 227 PatternRewriter &rewriter) const override { 228 // Do not hoist past functions. 229 if (isa<FunctionOpInterface>(op->getParentOp())) 230 return failure(); 231 232 int64_t moveForwardBy = 0; 233 if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance")) 234 moveForwardBy = advanceBy.getInt(); 235 236 Operation *moveAfter = op->getParentOp(); 237 for (int64_t i = 0; i < moveForwardBy; ++i) 238 moveAfter = moveAfter->getNextNode(); 239 240 rewriter.moveOpAfter(op, moveAfter); 241 return success(); 242 } 243 }; 244 245 /// This pattern inlines blocks that are nested in 246 /// "test.inline_blocks_into_parent" into the parent block. 
247 struct InlineBlocksIntoParent : public RewritePattern { 248 InlineBlocksIntoParent(MLIRContext *context) 249 : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1, 250 context) {} 251 252 LogicalResult matchAndRewrite(Operation *op, 253 PatternRewriter &rewriter) const override { 254 bool changed = false; 255 for (Region &r : op->getRegions()) { 256 while (!r.empty()) { 257 rewriter.inlineBlockBefore(&r.front(), op); 258 changed = true; 259 } 260 } 261 return success(changed); 262 } 263 }; 264 265 /// This pattern splits blocks at "test.split_block_here" and replaces the op 266 /// with a new op (to prevent an infinite loop of block splitting). 267 struct SplitBlockHere : public RewritePattern { 268 SplitBlockHere(MLIRContext *context) 269 : RewritePattern("test.split_block_here", /*benefit=*/1, context) {} 270 271 LogicalResult matchAndRewrite(Operation *op, 272 PatternRewriter &rewriter) const override { 273 rewriter.splitBlock(op->getBlock(), op->getIterator()); 274 Operation *newOp = rewriter.create( 275 op->getLoc(), 276 OperationName("test.new_op", op->getContext()).getIdentifier(), 277 op->getOperands(), op->getResultTypes()); 278 rewriter.replaceOp(op, newOp); 279 return success(); 280 } 281 }; 282 283 /// This pattern clones "test.clone_me" ops. 284 struct CloneOp : public RewritePattern { 285 CloneOp(MLIRContext *context) 286 : RewritePattern("test.clone_me", /*benefit=*/1, context) {} 287 288 LogicalResult matchAndRewrite(Operation *op, 289 PatternRewriter &rewriter) const override { 290 // Do not clone already cloned ops to avoid going into an infinite loop. 291 if (op->hasAttr("was_cloned")) 292 return failure(); 293 Operation *cloned = rewriter.clone(*op); 294 cloned->setAttr("was_cloned", rewriter.getUnitAttr()); 295 return success(); 296 } 297 }; 298 299 /// This pattern clones regions of "test.clone_region_before" ops before the 300 /// parent block. 
301 struct CloneRegionBeforeOp : public RewritePattern { 302 CloneRegionBeforeOp(MLIRContext *context) 303 : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {} 304 305 LogicalResult matchAndRewrite(Operation *op, 306 PatternRewriter &rewriter) const override { 307 // Do not clone already cloned ops to avoid going into an infinite loop. 308 if (op->hasAttr("was_cloned")) 309 return failure(); 310 for (Region &r : op->getRegions()) 311 rewriter.cloneRegionBefore(r, op->getBlock()); 312 op->setAttr("was_cloned", rewriter.getUnitAttr()); 313 return success(); 314 } 315 }; 316 317 /// Replace an operation may introduce the re-visiting of its users. 318 class ReplaceWithNewOp : public RewritePattern { 319 public: 320 ReplaceWithNewOp(MLIRContext *context) 321 : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {} 322 323 LogicalResult matchAndRewrite(Operation *op, 324 PatternRewriter &rewriter) const override { 325 Operation *newOp; 326 if (op->hasAttr("create_erase_op")) { 327 newOp = rewriter.create( 328 op->getLoc(), 329 OperationName("test.erase_op", op->getContext()).getIdentifier(), 330 ValueRange(), TypeRange()); 331 } else { 332 newOp = rewriter.create( 333 op->getLoc(), 334 OperationName("test.new_op", op->getContext()).getIdentifier(), 335 op->getOperands(), op->getResultTypes()); 336 } 337 // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp". 338 // A "notifyOperationReplaced" callback is triggered in either case. 339 rewriter.replaceAllOpUsesWith(op, newOp->getResults()); 340 rewriter.eraseOp(op); 341 return success(); 342 } 343 }; 344 345 /// Erases the first child block of the matched "test.erase_first_block" 346 /// operation. 
347 class EraseFirstBlock : public RewritePattern { 348 public: 349 EraseFirstBlock(MLIRContext *context) 350 : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {} 351 352 LogicalResult matchAndRewrite(Operation *op, 353 PatternRewriter &rewriter) const override { 354 for (Region &r : op->getRegions()) { 355 for (Block &b : r.getBlocks()) { 356 rewriter.eraseBlock(&b); 357 return success(); 358 } 359 } 360 361 return failure(); 362 } 363 }; 364 365 struct TestGreedyPatternDriver 366 : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> { 367 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver) 368 369 TestGreedyPatternDriver() = default; 370 TestGreedyPatternDriver(const TestGreedyPatternDriver &other) 371 : PassWrapper(other) {} 372 373 StringRef getArgument() const final { return "test-greedy-patterns"; } 374 StringRef getDescription() const final { return "Run test dialect patterns"; } 375 void runOnOperation() override { 376 mlir::RewritePatternSet patterns(&getContext()); 377 populateWithGenerated(patterns); 378 379 // Verify named pattern is generated with expected name. 380 patterns.add<FoldingPattern, TestNamedPatternRule, 381 FolderInsertBeforePreviouslyFoldedConstantPattern, 382 FolderCommutativeOp2WithConstant, HoistEligibleOps, 383 MakeOpEligible>(&getContext()); 384 385 // Additional patterns for testing the GreedyPatternRewriteDriver. 
386 patterns.insert<IncrementIntAttribute<3>>(&getContext()); 387 388 GreedyRewriteConfig config; 389 config.useTopDownTraversal = this->useTopDownTraversal; 390 config.maxIterations = this->maxIterations; 391 config.fold = this->fold; 392 config.cseConstants = this->cseConstants; 393 (void)applyPatternsGreedily(getOperation(), std::move(patterns), config); 394 } 395 396 Option<bool> useTopDownTraversal{ 397 *this, "top-down", 398 llvm::cl::desc("Seed the worklist in general top-down order"), 399 llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)}; 400 Option<int> maxIterations{ 401 *this, "max-iterations", 402 llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"), 403 llvm::cl::init(GreedyRewriteConfig().maxIterations)}; 404 Option<bool> fold{*this, "fold", llvm::cl::desc("Whether to fold"), 405 llvm::cl::init(GreedyRewriteConfig().fold)}; 406 Option<bool> cseConstants{*this, "cse-constants", 407 llvm::cl::desc("Whether to CSE constants"), 408 llvm::cl::init(GreedyRewriteConfig().cseConstants)}; 409 }; 410 411 struct DumpNotifications : public RewriterBase::Listener { 412 void notifyBlockInserted(Block *block, Region *previous, 413 Region::iterator previousIt) override { 414 llvm::outs() << "notifyBlockInserted"; 415 if (block->getParentOp()) { 416 llvm::outs() << " into " << block->getParentOp()->getName() << ": "; 417 } else { 418 llvm::outs() << " into unknown op: "; 419 } 420 if (previous == nullptr) { 421 llvm::outs() << "was unlinked\n"; 422 } else { 423 llvm::outs() << "was linked\n"; 424 } 425 } 426 void notifyOperationInserted(Operation *op, 427 OpBuilder::InsertPoint previous) override { 428 llvm::outs() << "notifyOperationInserted: " << op->getName(); 429 if (!previous.isSet()) { 430 llvm::outs() << ", was unlinked\n"; 431 } else { 432 if (!previous.getPoint().getNodePtr()) { 433 llvm::outs() << ", was linked, exact position unknown\n"; 434 } else if (previous.getPoint() == previous.getBlock()->end()) { 435 llvm::outs() << ", was last in 
block\n"; 436 } else { 437 llvm::outs() << ", previous = " << previous.getPoint()->getName() 438 << "\n"; 439 } 440 } 441 } 442 void notifyBlockErased(Block *block) override { 443 llvm::outs() << "notifyBlockErased\n"; 444 } 445 void notifyOperationErased(Operation *op) override { 446 llvm::outs() << "notifyOperationErased: " << op->getName() << "\n"; 447 } 448 void notifyOperationModified(Operation *op) override { 449 llvm::outs() << "notifyOperationModified: " << op->getName() << "\n"; 450 } 451 void notifyOperationReplaced(Operation *op, ValueRange values) override { 452 llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n"; 453 } 454 }; 455 456 struct TestStrictPatternDriver 457 : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> { 458 public: 459 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver) 460 461 TestStrictPatternDriver() = default; 462 TestStrictPatternDriver(const TestStrictPatternDriver &other) 463 : PassWrapper(other) { 464 strictMode = other.strictMode; 465 } 466 467 StringRef getArgument() const final { return "test-strict-pattern-driver"; } 468 StringRef getDescription() const final { 469 return "Test strict mode of pattern driver"; 470 } 471 472 void runOnOperation() override { 473 MLIRContext *ctx = &getContext(); 474 mlir::RewritePatternSet patterns(ctx); 475 patterns.add< 476 // clang-format off 477 ChangeBlockOp, 478 CloneOp, 479 CloneRegionBeforeOp, 480 EraseOp, 481 ImplicitChangeOp, 482 InlineBlocksIntoParent, 483 InsertSameOp, 484 MoveBeforeParentOp, 485 ReplaceWithNewOp, 486 SplitBlockHere 487 // clang-format on 488 >(ctx); 489 SmallVector<Operation *> ops; 490 getOperation()->walk([&](Operation *op) { 491 StringRef opName = op->getName().getStringRef(); 492 if (opName == "test.insert_same_op" || opName == "test.change_block_op" || 493 opName == "test.replace_with_new_op" || opName == "test.erase_op" || 494 opName == "test.move_before_parent_op" || 495 opName == 
"test.inline_blocks_into_parent" || 496 opName == "test.split_block_here" || opName == "test.clone_me" || 497 opName == "test.clone_region_before") { 498 ops.push_back(op); 499 } 500 }); 501 502 DumpNotifications dumpNotifications; 503 GreedyRewriteConfig config; 504 config.listener = &dumpNotifications; 505 if (strictMode == "AnyOp") { 506 config.strictMode = GreedyRewriteStrictness::AnyOp; 507 } else if (strictMode == "ExistingAndNewOps") { 508 config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; 509 } else if (strictMode == "ExistingOps") { 510 config.strictMode = GreedyRewriteStrictness::ExistingOps; 511 } else { 512 llvm_unreachable("invalid strictness option"); 513 } 514 515 // Check if these transformations introduce visiting of operations that 516 // are not in the `ops` set (The new created ops are valid). An invalid 517 // operation will trigger the assertion while processing. 518 bool changed = false; 519 bool allErased = false; 520 (void)applyOpPatternsGreedily(ArrayRef(ops), std::move(patterns), config, 521 &changed, &allErased); 522 Builder b(ctx); 523 getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed)); 524 getOperation()->setAttr("pattern_driver_all_erased", 525 b.getBoolAttr(allErased)); 526 } 527 528 Option<std::string> strictMode{ 529 *this, "strictness", 530 llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"), 531 llvm::cl::init("AnyOp")}; 532 533 private: 534 // New inserted operation is valid for further transformation. 
535 class InsertSameOp : public RewritePattern { 536 public: 537 InsertSameOp(MLIRContext *context) 538 : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {} 539 540 LogicalResult matchAndRewrite(Operation *op, 541 PatternRewriter &rewriter) const override { 542 if (op->hasAttr("skip")) 543 return failure(); 544 545 Operation *newOp = 546 rewriter.create(op->getLoc(), op->getName().getIdentifier(), 547 op->getOperands(), op->getResultTypes()); 548 rewriter.modifyOpInPlace( 549 op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); }); 550 newOp->setAttr("skip", rewriter.getBoolAttr(true)); 551 552 return success(); 553 } 554 }; 555 556 // Remove an operation may introduce the re-visiting of its operands. 557 class EraseOp : public RewritePattern { 558 public: 559 EraseOp(MLIRContext *context) 560 : RewritePattern("test.erase_op", /*benefit=*/1, context) {} 561 LogicalResult matchAndRewrite(Operation *op, 562 PatternRewriter &rewriter) const override { 563 rewriter.eraseOp(op); 564 return success(); 565 } 566 }; 567 568 // The following two patterns test RewriterBase::replaceAllUsesWith. 569 // 570 // That function replaces all usages of a Block (or a Value) with another one 571 // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver 572 // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its 573 // worklist: when an op is modified, it is added to the worklist. The two 574 // patterns below make the tracking observable: ChangeBlockOp replaces all 575 // usages of a block and that pattern is applied because the corresponding ops 576 // are put on the initial worklist (see above). ImplicitChangeOp does an 577 // unrelated change but ops of the corresponding type are *not* on the initial 578 // worklist, so the effect of the second pattern is only visible if the 579 // tracking and subsequent adding to the worklist actually works. 580 581 // Replace all usages of the first successor with the second successor. 
582 class ChangeBlockOp : public RewritePattern { 583 public: 584 ChangeBlockOp(MLIRContext *context) 585 : RewritePattern("test.change_block_op", /*benefit=*/1, context) {} 586 LogicalResult matchAndRewrite(Operation *op, 587 PatternRewriter &rewriter) const override { 588 if (op->getNumSuccessors() < 2) 589 return failure(); 590 Block *firstSuccessor = op->getSuccessor(0); 591 Block *secondSuccessor = op->getSuccessor(1); 592 if (firstSuccessor == secondSuccessor) 593 return failure(); 594 // This is the function being tested: 595 rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor); 596 // Using the following line instead would make the test fail: 597 // firstSuccessor->replaceAllUsesWith(secondSuccessor); 598 return success(); 599 } 600 }; 601 602 // Changes the successor to the parent block. 603 class ImplicitChangeOp : public RewritePattern { 604 public: 605 ImplicitChangeOp(MLIRContext *context) 606 : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {} 607 LogicalResult matchAndRewrite(Operation *op, 608 PatternRewriter &rewriter) const override { 609 if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock()) 610 return failure(); 611 rewriter.modifyOpInPlace(op, 612 [&]() { op->setSuccessor(op->getBlock(), 0); }); 613 return success(); 614 } 615 }; 616 }; 617 618 struct TestWalkPatternDriver final 619 : PassWrapper<TestWalkPatternDriver, OperationPass<>> { 620 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver) 621 622 TestWalkPatternDriver() = default; 623 TestWalkPatternDriver(const TestWalkPatternDriver &other) 624 : PassWrapper(other) {} 625 626 StringRef getArgument() const override { 627 return "test-walk-pattern-rewrite-driver"; 628 } 629 StringRef getDescription() const override { 630 return "Run test walk pattern rewrite driver"; 631 } 632 void runOnOperation() override { 633 mlir::RewritePatternSet patterns(&getContext()); 634 635 // Patterns for testing the WalkPatternRewriteDriver. 
636 patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp, 637 MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>( 638 &getContext()); 639 640 DumpNotifications dumpListener; 641 walkAndApplyPatterns(getOperation(), std::move(patterns), 642 dumpNotifications ? &dumpListener : nullptr); 643 } 644 645 Option<bool> dumpNotifications{ 646 *this, "dump-notifications", 647 llvm::cl::desc("Print rewrite listener notifications"), 648 llvm::cl::init(false)}; 649 }; 650 651 } // namespace 652 653 //===----------------------------------------------------------------------===// 654 // ReturnType Driver. 655 //===----------------------------------------------------------------------===// 656 657 namespace { 658 // Generate ops for each instance where the type can be successfully inferred. 659 template <typename OpTy> 660 static void invokeCreateWithInferredReturnType(Operation *op) { 661 auto *context = op->getContext(); 662 auto fop = op->getParentOfType<func::FuncOp>(); 663 auto location = UnknownLoc::get(context); 664 OpBuilder b(op); 665 b.setInsertionPointAfter(op); 666 667 // Use permutations of 2 args as operands. 668 assert(fop.getNumArguments() >= 2); 669 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 670 for (int j = 0; j < e; ++j) { 671 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 672 SmallVector<Type, 2> inferredReturnTypes; 673 if (succeeded(OpTy::inferReturnTypes( 674 context, std::nullopt, values, op->getDiscardableAttrDictionary(), 675 op->getPropertiesStorage(), op->getRegions(), 676 inferredReturnTypes))) { 677 OperationState state(location, OpTy::getOperationName()); 678 // TODO: Expand to regions. 679 OpTy::build(b, state, values, op->getAttrs()); 680 (void)b.create(state); 681 } 682 } 683 } 684 } 685 686 static void reifyReturnShape(Operation *op) { 687 OpBuilder b(op); 688 689 // Use permutations of 2 args as operands. 
690 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 691 SmallVector<Value, 2> shapes; 692 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 693 !llvm::hasSingleElement(shapes)) 694 return; 695 for (const auto &it : llvm::enumerate(shapes)) { 696 op->emitRemark() << "value " << it.index() << ": " 697 << it.value().getDefiningOp(); 698 } 699 } 700 701 struct TestReturnTypeDriver 702 : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> { 703 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver) 704 705 void getDependentDialects(DialectRegistry ®istry) const override { 706 registry.insert<tensor::TensorDialect>(); 707 } 708 StringRef getArgument() const final { return "test-return-type"; } 709 StringRef getDescription() const final { return "Run return type functions"; } 710 711 void runOnOperation() override { 712 if (getOperation().getName() == "testCreateFunctions") { 713 std::vector<Operation *> ops; 714 // Collect ops to avoid triggering on inserted ops. 715 for (auto &op : getOperation().getBody().front()) 716 ops.push_back(&op); 717 // Generate test patterns for each, but skip terminator. 718 for (auto *op : llvm::ArrayRef(ops).drop_back()) { 719 // Test create method of each of the Op classes below. The resultant 720 // output would be in reverse order underneath `op` from which 721 // the attributes and regions are used. 722 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 723 invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>( 724 op); 725 invokeCreateWithInferredReturnType< 726 OpWithShapedTypeInferTypeInterfaceOp>(op); 727 }; 728 return; 729 } 730 if (getOperation().getName() == "testReifyFunctions") { 731 std::vector<Operation *> ops; 732 // Collect ops to avoid triggering on inserted ops. 
733 for (auto &op : getOperation().getBody().front()) 734 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 735 ops.push_back(&op); 736 // Generate test patterns for each, but skip terminator. 737 for (auto *op : ops) 738 reifyReturnShape(op); 739 } 740 } 741 }; 742 } // namespace 743 744 namespace { 745 struct TestDerivedAttributeDriver 746 : public PassWrapper<TestDerivedAttributeDriver, 747 OperationPass<func::FuncOp>> { 748 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver) 749 750 StringRef getArgument() const final { return "test-derived-attr"; } 751 StringRef getDescription() const final { 752 return "Run test derived attributes"; 753 } 754 void runOnOperation() override; 755 }; 756 } // namespace 757 758 void TestDerivedAttributeDriver::runOnOperation() { 759 getOperation().walk([](DerivedAttributeOpInterface dOp) { 760 auto dAttr = dOp.materializeDerivedAttributes(); 761 if (!dAttr) 762 return; 763 for (auto d : dAttr) 764 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); 765 }); 766 } 767 768 //===----------------------------------------------------------------------===// 769 // Legalization Driver. 770 //===----------------------------------------------------------------------===// 771 772 namespace { 773 //===----------------------------------------------------------------------===// 774 // Region-Block Rewrite Testing 775 776 /// This pattern applies a signature conversion to a block inside a detached 777 /// region. 
778 struct TestDetachedSignatureConversion : public ConversionPattern { 779 TestDetachedSignatureConversion(MLIRContext *ctx) 780 : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1, 781 ctx) {} 782 783 LogicalResult 784 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 785 ConversionPatternRewriter &rewriter) const final { 786 if (op->getNumRegions() != 1) 787 return failure(); 788 OperationState state(op->getLoc(), "test.legal_op_with_region", operands, 789 op->getResultTypes(), {}, BlockRange()); 790 Region *newRegion = state.addRegion(); 791 rewriter.inlineRegionBefore(op->getRegion(0), *newRegion, 792 newRegion->begin()); 793 TypeConverter::SignatureConversion result(newRegion->getNumArguments()); 794 for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i) 795 result.addInputs(i, rewriter.getF64Type()); 796 rewriter.applySignatureConversion(&newRegion->front(), result); 797 Operation *newOp = rewriter.create(state); 798 rewriter.replaceOp(op, newOp->getResults()); 799 return success(); 800 } 801 }; 802 803 /// This pattern is a simple pattern that inlines the first region of a given 804 /// operation into the parent region. 805 struct TestRegionRewriteBlockMovement : public ConversionPattern { 806 TestRegionRewriteBlockMovement(MLIRContext *ctx) 807 : ConversionPattern("test.region", 1, ctx) {} 808 809 LogicalResult 810 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 811 ConversionPatternRewriter &rewriter) const final { 812 // Inline this region into the parent region. 
813 auto &parentRegion = *op->getParentRegion(); 814 auto &opRegion = op->getRegion(0); 815 if (op->getDiscardableAttr("legalizer.should_clone")) 816 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 817 else 818 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 819 820 if (op->getDiscardableAttr("legalizer.erase_old_blocks")) { 821 while (!opRegion.empty()) 822 rewriter.eraseBlock(&opRegion.front()); 823 } 824 825 // Drop this operation. 826 rewriter.eraseOp(op); 827 return success(); 828 } 829 }; 830 /// This pattern is a simple pattern that generates a region containing an 831 /// illegal operation. 832 struct TestRegionRewriteUndo : public RewritePattern { 833 TestRegionRewriteUndo(MLIRContext *ctx) 834 : RewritePattern("test.region_builder", 1, ctx) {} 835 836 LogicalResult matchAndRewrite(Operation *op, 837 PatternRewriter &rewriter) const final { 838 // Create the region operation with an entry block containing arguments. 839 OperationState newRegion(op->getLoc(), "test.region"); 840 newRegion.addRegion(); 841 auto *regionOp = rewriter.create(newRegion); 842 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 843 entryBlock->addArgument(rewriter.getIntegerType(64), 844 rewriter.getUnknownLoc()); 845 846 // Add an explicitly illegal operation to ensure the conversion fails. 847 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 848 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 849 850 // Drop this operation. 851 rewriter.eraseOp(op); 852 return success(); 853 } 854 }; 855 /// A simple pattern that creates a block at the end of the parent region of the 856 /// matched operation. 
857 struct TestCreateBlock : public RewritePattern { 858 TestCreateBlock(MLIRContext *ctx) 859 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 860 861 LogicalResult matchAndRewrite(Operation *op, 862 PatternRewriter &rewriter) const final { 863 Region ®ion = *op->getParentRegion(); 864 Type i32Type = rewriter.getIntegerType(32); 865 Location loc = op->getLoc(); 866 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 867 rewriter.create<TerminatorOp>(loc); 868 rewriter.eraseOp(op); 869 return success(); 870 } 871 }; 872 873 /// A simple pattern that creates a block containing an invalid operation in 874 /// order to trigger the block creation undo mechanism. 875 struct TestCreateIllegalBlock : public RewritePattern { 876 TestCreateIllegalBlock(MLIRContext *ctx) 877 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 878 879 LogicalResult matchAndRewrite(Operation *op, 880 PatternRewriter &rewriter) const final { 881 Region ®ion = *op->getParentRegion(); 882 Type i32Type = rewriter.getIntegerType(32); 883 Location loc = op->getLoc(); 884 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 885 // Create an illegal op to ensure the conversion fails. 886 rewriter.create<ILLegalOpF>(loc, i32Type); 887 rewriter.create<TerminatorOp>(loc); 888 rewriter.eraseOp(op); 889 return success(); 890 } 891 }; 892 893 /// A simple pattern that tests the undo mechanism when replacing the uses of a 894 /// block argument. 
895 struct TestUndoBlockArgReplace : public ConversionPattern { 896 TestUndoBlockArgReplace(MLIRContext *ctx) 897 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 898 899 LogicalResult 900 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 901 ConversionPatternRewriter &rewriter) const final { 902 auto illegalOp = 903 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 904 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 905 illegalOp->getResult(0)); 906 rewriter.modifyOpInPlace(op, [] {}); 907 return success(); 908 } 909 }; 910 911 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion. 912 /// This is to test the rollback logic. 913 struct TestUndoMoveOpBefore : public ConversionPattern { 914 TestUndoMoveOpBefore(MLIRContext *ctx) 915 : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {} 916 917 LogicalResult 918 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 919 ConversionPatternRewriter &rewriter) const override { 920 rewriter.moveOpBefore(op, op->getParentOp()); 921 // Replace with an illegal op to ensure the conversion fails. 922 rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type()); 923 return success(); 924 } 925 }; 926 927 /// A rewrite pattern that tests the undo mechanism when erasing a block. 
928 struct TestUndoBlockErase : public ConversionPattern { 929 TestUndoBlockErase(MLIRContext *ctx) 930 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 931 932 LogicalResult 933 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 934 ConversionPatternRewriter &rewriter) const final { 935 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 936 rewriter.setInsertionPointToStart(secondBlock); 937 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 938 rewriter.eraseBlock(secondBlock); 939 rewriter.modifyOpInPlace(op, [] {}); 940 return success(); 941 } 942 }; 943 944 /// A pattern that modifies a property in-place, but keeps the op illegal. 945 struct TestUndoPropertiesModification : public ConversionPattern { 946 TestUndoPropertiesModification(MLIRContext *ctx) 947 : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {} 948 LogicalResult 949 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 950 ConversionPatternRewriter &rewriter) const final { 951 if (!op->hasAttr("modify_inplace")) 952 return failure(); 953 rewriter.modifyOpInPlace( 954 op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); }); 955 return success(); 956 } 957 }; 958 959 //===----------------------------------------------------------------------===// 960 // Type-Conversion Rewrite Testing 961 962 /// This patterns erases a region operation that has had a type conversion. 963 struct TestDropOpSignatureConversion : public ConversionPattern { 964 TestDropOpSignatureConversion(MLIRContext *ctx, 965 const TypeConverter &converter) 966 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 967 LogicalResult 968 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 969 ConversionPatternRewriter &rewriter) const override { 970 Region ®ion = op->getRegion(0); 971 Block *entry = ®ion.front(); 972 973 // Convert the original entry arguments. 
974 const TypeConverter &converter = *getTypeConverter(); 975 TypeConverter::SignatureConversion result(entry->getNumArguments()); 976 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 977 result)) || 978 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 979 return failure(); 980 981 // Convert the region signature and just drop the operation. 982 rewriter.eraseOp(op); 983 return success(); 984 } 985 }; 986 /// This pattern simply updates the operands of the given operation. 987 struct TestPassthroughInvalidOp : public ConversionPattern { 988 TestPassthroughInvalidOp(MLIRContext *ctx) 989 : ConversionPattern("test.invalid", 1, ctx) {} 990 LogicalResult 991 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 992 ConversionPatternRewriter &rewriter) const final { 993 SmallVector<Value> flattened; 994 for (auto it : llvm::enumerate(operands)) { 995 ValueRange range = it.value(); 996 if (range.size() == 1) { 997 flattened.push_back(range.front()); 998 continue; 999 } 1000 1001 // This is a 1:N replacement. Insert a test.cast op. (That's what the 1002 // argument materialization used to do.) 1003 flattened.push_back( 1004 rewriter 1005 .create<TestCastOp>(op->getLoc(), 1006 op->getOperand(it.index()).getType(), range) 1007 .getResult()); 1008 } 1009 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened, 1010 std::nullopt); 1011 return success(); 1012 } 1013 }; 1014 /// Replace with valid op, but simply drop the operands. This is used in a 1015 /// regression where we used to generate circular unrealized_conversion_cast 1016 /// ops. 
1017 struct TestDropAndReplaceInvalidOp : public ConversionPattern { 1018 TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter) 1019 : ConversionPattern(converter, 1020 "test.drop_operands_and_replace_with_valid", 1, ctx) { 1021 } 1022 LogicalResult 1023 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1024 ConversionPatternRewriter &rewriter) const final { 1025 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(), 1026 std::nullopt); 1027 return success(); 1028 } 1029 }; 1030 /// This pattern handles the case of a split return value. 1031 struct TestSplitReturnType : public ConversionPattern { 1032 TestSplitReturnType(MLIRContext *ctx) 1033 : ConversionPattern("test.return", 1, ctx) {} 1034 LogicalResult 1035 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 1036 ConversionPatternRewriter &rewriter) const final { 1037 // Check for a return of F32. 1038 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 1039 return failure(); 1040 rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]); 1041 return success(); 1042 } 1043 }; 1044 1045 //===----------------------------------------------------------------------===// 1046 // Multi-Level Type-Conversion Rewrite Testing 1047 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 1048 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 1049 : ConversionPattern("test.type_producer", 1, ctx) {} 1050 LogicalResult 1051 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1052 ConversionPatternRewriter &rewriter) const final { 1053 // If the type is I32, change the type to F32. 
1054 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 1055 return failure(); 1056 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 1057 return success(); 1058 } 1059 }; 1060 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 1061 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 1062 : ConversionPattern("test.type_producer", 1, ctx) {} 1063 LogicalResult 1064 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1065 ConversionPatternRewriter &rewriter) const final { 1066 // If the type is F32, change the type to F64. 1067 if (!Type(*op->result_type_begin()).isF32()) 1068 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 1069 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 1070 return success(); 1071 } 1072 }; 1073 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 1074 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 1075 : ConversionPattern("test.type_producer", 10, ctx) {} 1076 LogicalResult 1077 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1078 ConversionPatternRewriter &rewriter) const final { 1079 // Always convert to B16, even though it is not a legal type. This tests 1080 // that values are unmapped correctly. 1081 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 1082 return success(); 1083 } 1084 }; 1085 struct TestUpdateConsumerType : public ConversionPattern { 1086 TestUpdateConsumerType(MLIRContext *ctx) 1087 : ConversionPattern("test.type_consumer", 1, ctx) {} 1088 LogicalResult 1089 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1090 ConversionPatternRewriter &rewriter) const final { 1091 // Verify that the incoming operand has been successfully remapped to F64. 
1092 if (!operands[0].getType().isF64()) 1093 return failure(); 1094 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 1095 return success(); 1096 } 1097 }; 1098 1099 //===----------------------------------------------------------------------===// 1100 // Non-Root Replacement Rewrite Testing 1101 /// This pattern generates an invalid operation, but replaces it before the 1102 /// pattern is finished. This checks that we don't need to legalize the 1103 /// temporary op. 1104 struct TestNonRootReplacement : public RewritePattern { 1105 TestNonRootReplacement(MLIRContext *ctx) 1106 : RewritePattern("test.replace_non_root", 1, ctx) {} 1107 1108 LogicalResult matchAndRewrite(Operation *op, 1109 PatternRewriter &rewriter) const final { 1110 auto resultType = *op->result_type_begin(); 1111 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 1112 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 1113 1114 rewriter.replaceOp(illegalOp, legalOp); 1115 rewriter.replaceOp(op, illegalOp); 1116 return success(); 1117 } 1118 }; 1119 1120 //===----------------------------------------------------------------------===// 1121 // Recursive Rewrite Testing 1122 /// This pattern is applied to the same operation multiple times, but has a 1123 /// bounded recursion. 1124 struct TestBoundedRecursiveRewrite 1125 : public OpRewritePattern<TestRecursiveRewriteOp> { 1126 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 1127 1128 void initialize() { 1129 // The conversion target handles bounding the recursion of this pattern. 1130 setHasBoundedRewriteRecursion(); 1131 } 1132 1133 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 1134 PatternRewriter &rewriter) const final { 1135 // Decrement the depth of the op in-place. 
1136 rewriter.modifyOpInPlace(op, [&] { 1137 op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); 1138 }); 1139 return success(); 1140 } 1141 }; 1142 1143 struct TestNestedOpCreationUndoRewrite 1144 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 1145 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 1146 1147 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 1148 PatternRewriter &rewriter) const final { 1149 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 1150 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 1151 return success(); 1152 }; 1153 }; 1154 1155 // This pattern matches `test.blackhole` and delete this op and its producer. 1156 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 1157 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 1158 1159 LogicalResult matchAndRewrite(BlackHoleOp op, 1160 PatternRewriter &rewriter) const final { 1161 Operation *producer = op.getOperand().getDefiningOp(); 1162 // Always erase the user before the producer, the framework should handle 1163 // this correctly. 1164 rewriter.eraseOp(op); 1165 rewriter.eraseOp(producer); 1166 return success(); 1167 }; 1168 }; 1169 1170 // This pattern replaces explicitly illegal op with explicitly legal op, 1171 // but in addition creates unregistered operation. 
1172 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { 1173 using OpRewritePattern<ILLegalOpG>::OpRewritePattern; 1174 1175 LogicalResult matchAndRewrite(ILLegalOpG op, 1176 PatternRewriter &rewriter) const final { 1177 IntegerAttr attr = rewriter.getI32IntegerAttr(0); 1178 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); 1179 rewriter.replaceOpWithNewOp<LegalOpC>(op, val); 1180 return success(); 1181 }; 1182 }; 1183 1184 class TestEraseOp : public ConversionPattern { 1185 public: 1186 TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {} 1187 LogicalResult 1188 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1189 ConversionPatternRewriter &rewriter) const final { 1190 // Erase op without replacements. 1191 rewriter.eraseOp(op); 1192 return success(); 1193 } 1194 }; 1195 1196 /// This pattern matches a test.duplicate_block_args op and duplicates all 1197 /// block arguments. 1198 class TestDuplicateBlockArgs 1199 : public OpConversionPattern<DuplicateBlockArgsOp> { 1200 using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern; 1201 1202 LogicalResult 1203 matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor, 1204 ConversionPatternRewriter &rewriter) const override { 1205 if (op.getIsLegal()) 1206 return failure(); 1207 rewriter.startOpModification(op); 1208 Block *body = &op.getBody().front(); 1209 TypeConverter::SignatureConversion result(body->getNumArguments()); 1210 for (auto it : llvm::enumerate(body->getArgumentTypes())) 1211 result.addInputs(it.index(), {it.value(), it.value()}); 1212 rewriter.applySignatureConversion(body, result, getTypeConverter()); 1213 op.setIsLegal(true); 1214 rewriter.finalizeOpModification(op); 1215 return success(); 1216 } 1217 }; 1218 1219 /// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid 1220 /// op. 
The pattern supports 1:N replacements and forwards the replacement 1221 /// values of the single operand as test.valid operands. 1222 class TestRepetitive1ToNConsumer : public ConversionPattern { 1223 public: 1224 TestRepetitive1ToNConsumer(MLIRContext *ctx) 1225 : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {} 1226 LogicalResult 1227 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 1228 ConversionPatternRewriter &rewriter) const final { 1229 // A single operand is expected. 1230 if (op->getNumOperands() != 1) 1231 return failure(); 1232 rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front()); 1233 return success(); 1234 } 1235 }; 1236 1237 } // namespace 1238 1239 namespace { 1240 struct TestTypeConverter : public TypeConverter { 1241 using TypeConverter::TypeConverter; 1242 TestTypeConverter() { 1243 addConversion(convertType); 1244 addArgumentMaterialization(materializeCast); 1245 addSourceMaterialization(materializeCast); 1246 } 1247 1248 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 1249 // Drop I16 types. 1250 if (t.isSignlessInteger(16)) 1251 return success(); 1252 1253 // Convert I64 to F64. 1254 if (t.isSignlessInteger(64)) { 1255 results.push_back(FloatType::getF64(t.getContext())); 1256 return success(); 1257 } 1258 1259 // Convert I42 to I43. 1260 if (t.isInteger(42)) { 1261 results.push_back(IntegerType::get(t.getContext(), 43)); 1262 return success(); 1263 } 1264 1265 // Split F32 into F16,F16. 1266 if (t.isF32()) { 1267 results.assign(2, FloatType::getF16(t.getContext())); 1268 return success(); 1269 } 1270 1271 // Drop I24 types. 1272 if (t.isInteger(24)) { 1273 return success(); 1274 } 1275 1276 // Otherwise, convert the type directly. 1277 results.push_back(t); 1278 return success(); 1279 } 1280 1281 /// Hook for materializing a conversion. This is necessary because we generate 1282 /// 1->N type mappings. 
1283 static Value materializeCast(OpBuilder &builder, Type resultType, 1284 ValueRange inputs, Location loc) { 1285 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1286 } 1287 }; 1288 1289 struct TestLegalizePatternDriver 1290 : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> { 1291 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) 1292 1293 StringRef getArgument() const final { return "test-legalize-patterns"; } 1294 StringRef getDescription() const final { 1295 return "Run test dialect legalization patterns"; 1296 } 1297 /// The mode of conversion to use with the driver. 1298 enum class ConversionMode { Analysis, Full, Partial }; 1299 1300 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 1301 1302 void getDependentDialects(DialectRegistry ®istry) const override { 1303 registry.insert<func::FuncDialect, test::TestDialect>(); 1304 } 1305 1306 void runOnOperation() override { 1307 TestTypeConverter converter; 1308 mlir::RewritePatternSet patterns(&getContext()); 1309 populateWithGenerated(patterns); 1310 patterns.add< 1311 TestRegionRewriteBlockMovement, TestDetachedSignatureConversion, 1312 TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock, 1313 TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp, 1314 TestSplitReturnType, TestChangeProducerTypeI32ToF32, 1315 TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid, 1316 TestUpdateConsumerType, TestNonRootReplacement, 1317 TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite, 1318 TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore, 1319 TestUndoPropertiesModification, TestEraseOp, 1320 TestRepetitive1ToNConsumer>(&getContext()); 1321 patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>( 1322 &getContext(), converter); 1323 patterns.add<TestDuplicateBlockArgs>(converter, &getContext()); 1324 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 
1325 converter); 1326 mlir::populateCallOpTypeConversionPattern(patterns, converter); 1327 1328 // Define the conversion target used for the test. 1329 ConversionTarget target(getContext()); 1330 target.addLegalOp<ModuleOp>(); 1331 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, 1332 TerminatorOp, OneRegionOp>(); 1333 target.addLegalOp( 1334 OperationName("test.legal_op_with_region", &getContext())); 1335 target 1336 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 1337 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 1338 // Don't allow F32 operands. 1339 return llvm::none_of(op.getOperandTypes(), 1340 [](Type type) { return type.isF32(); }); 1341 }); 1342 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1343 return converter.isSignatureLegal(op.getFunctionType()) && 1344 converter.isLegal(&op.getBody()); 1345 }); 1346 target.addDynamicallyLegalOp<func::CallOp>( 1347 [&](func::CallOp op) { return converter.isLegal(op); }); 1348 1349 // TestCreateUnregisteredOp creates `arith.constant` operation, 1350 // which was not added to target intentionally to test 1351 // correct error code from conversion driver. 1352 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); 1353 1354 // Expect the type_producer/type_consumer operations to only operate on f64. 1355 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1356 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1357 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1358 return op.getOperand().getType().isF64(); 1359 }); 1360 1361 // Check support for marking certain operations as recursively legal. 1362 target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) { 1363 return static_cast<bool>( 1364 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 1365 }); 1366 1367 // Mark the bound recursion operation as dynamically legal. 
1368 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 1369 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); 1370 1371 // Create a dynamically legal rule that can only be legalized by folding it. 1372 target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>( 1373 [](TestOpInPlaceSelfFold op) { return op.getFolded(); }); 1374 1375 target.addDynamicallyLegalOp<DuplicateBlockArgsOp>( 1376 [](DuplicateBlockArgsOp op) { return op.getIsLegal(); }); 1377 1378 // Handle a partial conversion. 1379 if (mode == ConversionMode::Partial) { 1380 DenseSet<Operation *> unlegalizedOps; 1381 ConversionConfig config; 1382 DumpNotifications dumpNotifications; 1383 config.listener = &dumpNotifications; 1384 config.unlegalizedOps = &unlegalizedOps; 1385 if (failed(applyPartialConversion(getOperation(), target, 1386 std::move(patterns), config))) { 1387 getOperation()->emitRemark() << "applyPartialConversion failed"; 1388 } 1389 // Emit remarks for each legalizable operation. 1390 for (auto *op : unlegalizedOps) 1391 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1392 return; 1393 } 1394 1395 // Handle a full conversion. 1396 if (mode == ConversionMode::Full) { 1397 // Check support for marking unknown operations as dynamically legal. 1398 target.markUnknownOpDynamicallyLegal([](Operation *op) { 1399 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 1400 }); 1401 1402 ConversionConfig config; 1403 DumpNotifications dumpNotifications; 1404 config.listener = &dumpNotifications; 1405 if (failed(applyFullConversion(getOperation(), target, 1406 std::move(patterns), config))) { 1407 getOperation()->emitRemark() << "applyFullConversion failed"; 1408 } 1409 return; 1410 } 1411 1412 // Otherwise, handle an analysis conversion. 1413 assert(mode == ConversionMode::Analysis); 1414 1415 // Analyze the convertible operations. 
1416 DenseSet<Operation *> legalizedOps; 1417 ConversionConfig config; 1418 config.legalizableOps = &legalizedOps; 1419 if (failed(applyAnalysisConversion(getOperation(), target, 1420 std::move(patterns), config))) 1421 return signalPassFailure(); 1422 1423 // Emit remarks for each legalizable operation. 1424 for (auto *op : legalizedOps) 1425 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 1426 } 1427 1428 /// The mode of conversion to use. 1429 ConversionMode mode; 1430 }; 1431 } // namespace 1432 1433 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 1434 legalizerConversionMode( 1435 "test-legalize-mode", 1436 llvm::cl::desc("The legalization mode to use with the test driver"), 1437 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 1438 llvm::cl::values( 1439 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 1440 "analysis", "Perform an analysis conversion"), 1441 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 1442 "Perform a full conversion"), 1443 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 1444 "partial", "Perform a partial conversion"))); 1445 1446 //===----------------------------------------------------------------------===// 1447 // ConversionPatternRewriter::getRemappedValue testing. This method is used 1448 // to get the remapped value of an original value that was replaced using 1449 // ConversionPatternRewriter. 1450 namespace { 1451 struct TestRemapValueTypeConverter : public TypeConverter { 1452 using TypeConverter::TypeConverter; 1453 1454 TestRemapValueTypeConverter() { 1455 addConversion( 1456 [](Float32Type type) { return Float64Type::get(type.getContext()); }); 1457 addConversion([](Type type) { return type; }); 1458 } 1459 }; 1460 1461 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 1462 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 1463 /// operand twice. 
1464 /// 1465 /// Example: 1466 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 1467 /// is replaced with: 1468 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 1469 struct OneVResOneVOperandOp1Converter 1470 : public OpConversionPattern<OneVResOneVOperandOp1> { 1471 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 1472 1473 LogicalResult 1474 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, 1475 ConversionPatternRewriter &rewriter) const override { 1476 auto origOps = op.getOperands(); 1477 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 1478 "One operand expected"); 1479 Value origOp = *origOps.begin(); 1480 SmallVector<Value, 2> remappedOperands; 1481 // Replicate the remapped original operand twice. Note that we don't used 1482 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 1483 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1484 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1485 1486 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 1487 remappedOperands); 1488 return success(); 1489 } 1490 }; 1491 1492 /// A rewriter pattern that tests that blocks can be merged. 1493 struct TestRemapValueInRegion 1494 : public OpConversionPattern<TestRemappedValueRegionOp> { 1495 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; 1496 1497 LogicalResult 1498 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, 1499 ConversionPatternRewriter &rewriter) const final { 1500 Block &block = op.getBody().front(); 1501 Operation *terminator = block.getTerminator(); 1502 1503 // Merge the block into the parent region. 
1504 Block *parentBlock = op->getBlock(); 1505 Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator()); 1506 rewriter.mergeBlocks(&block, parentBlock, ValueRange()); 1507 rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange()); 1508 1509 // Replace the results of this operation with the remapped terminator 1510 // values. 1511 SmallVector<Value> terminatorOperands; 1512 if (failed(rewriter.getRemappedValues(terminator->getOperands(), 1513 terminatorOperands))) 1514 return failure(); 1515 1516 rewriter.eraseOp(terminator); 1517 rewriter.replaceOp(op, terminatorOperands); 1518 return success(); 1519 } 1520 }; 1521 1522 struct TestRemappedValue 1523 : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> { 1524 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue) 1525 1526 StringRef getArgument() const final { return "test-remapped-value"; } 1527 StringRef getDescription() const final { 1528 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1529 } 1530 void runOnOperation() override { 1531 TestRemapValueTypeConverter typeConverter; 1532 1533 mlir::RewritePatternSet patterns(&getContext()); 1534 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 1535 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( 1536 &getContext()); 1537 patterns.add<TestRemapValueInRegion>(typeConverter, &getContext()); 1538 1539 mlir::ConversionTarget target(getContext()); 1540 target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>(); 1541 1542 // Expect the type_producer/type_consumer operations to only operate on f64. 1543 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1544 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1545 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1546 return op.getOperand().getType().isF64(); 1547 }); 1548 1549 // We make OneVResOneVOperandOp1 legal only when it has more that one 1550 // operand. 
This will trigger the conversion that will replace one-operand 1551 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 1552 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 1553 [](Operation *op) { return op->getNumOperands() > 1; }); 1554 1555 if (failed(mlir::applyFullConversion(getOperation(), target, 1556 std::move(patterns)))) { 1557 signalPassFailure(); 1558 } 1559 } 1560 }; 1561 } // namespace 1562 1563 //===----------------------------------------------------------------------===// 1564 // Test patterns without a specific root operation kind 1565 //===----------------------------------------------------------------------===// 1566 1567 namespace { 1568 /// This pattern matches and removes any operation in the test dialect. 1569 struct RemoveTestDialectOps : public RewritePattern { 1570 RemoveTestDialectOps(MLIRContext *context) 1571 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 1572 1573 LogicalResult matchAndRewrite(Operation *op, 1574 PatternRewriter &rewriter) const override { 1575 if (!isa<TestDialect>(op->getDialect())) 1576 return failure(); 1577 rewriter.eraseOp(op); 1578 return success(); 1579 } 1580 }; 1581 1582 struct TestUnknownRootOpDriver 1583 : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> { 1584 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver) 1585 1586 StringRef getArgument() const final { 1587 return "test-legalize-unknown-root-patterns"; 1588 } 1589 StringRef getDescription() const final { 1590 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1591 } 1592 void runOnOperation() override { 1593 mlir::RewritePatternSet patterns(&getContext()); 1594 patterns.add<RemoveTestDialectOps>(&getContext()); 1595 1596 mlir::ConversionTarget target(getContext()); 1597 target.addIllegalDialect<TestDialect>(); 1598 if (failed(applyPartialConversion(getOperation(), target, 1599 std::move(patterns)))) 1600 signalPassFailure(); 1601 } 1602 }; 1603 } // 
namespace 1604 1605 //===----------------------------------------------------------------------===// 1606 // Test patterns that uses operations and types defined at runtime 1607 //===----------------------------------------------------------------------===// 1608 1609 namespace { 1610 /// This pattern matches dynamic operations 'test.one_operand_two_results' and 1611 /// replace them with dynamic operations 'test.generic_dynamic_op'. 1612 struct RewriteDynamicOp : public RewritePattern { 1613 RewriteDynamicOp(MLIRContext *context) 1614 : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1, 1615 context) {} 1616 1617 LogicalResult matchAndRewrite(Operation *op, 1618 PatternRewriter &rewriter) const override { 1619 assert(op->getName().getStringRef() == 1620 "test.dynamic_one_operand_two_results" && 1621 "rewrite pattern should only match operations with the right name"); 1622 1623 OperationState state(op->getLoc(), "test.dynamic_generic", 1624 op->getOperands(), op->getResultTypes(), 1625 op->getAttrs()); 1626 auto *newOp = rewriter.create(state); 1627 rewriter.replaceOp(op, newOp->getResults()); 1628 return success(); 1629 } 1630 }; 1631 1632 struct TestRewriteDynamicOpDriver 1633 : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> { 1634 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver) 1635 1636 void getDependentDialects(DialectRegistry ®istry) const override { 1637 registry.insert<TestDialect>(); 1638 } 1639 StringRef getArgument() const final { return "test-rewrite-dynamic-op"; } 1640 StringRef getDescription() const final { 1641 return "Test rewritting on dynamic operations"; 1642 } 1643 void runOnOperation() override { 1644 RewritePatternSet patterns(&getContext()); 1645 patterns.add<RewriteDynamicOp>(&getContext()); 1646 1647 ConversionTarget target(getContext()); 1648 target.addIllegalOp( 1649 OperationName("test.dynamic_one_operand_two_results", &getContext())); 1650 
target.addLegalOp(OperationName("test.dynamic_generic", &getContext())); 1651 if (failed(applyPartialConversion(getOperation(), target, 1652 std::move(patterns)))) 1653 signalPassFailure(); 1654 } 1655 }; 1656 } // end anonymous namespace 1657 1658 //===----------------------------------------------------------------------===// 1659 // Test type conversions 1660 //===----------------------------------------------------------------------===// 1661 1662 namespace { 1663 struct TestTypeConversionProducer 1664 : public OpConversionPattern<TestTypeProducerOp> { 1665 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 1666 LogicalResult 1667 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, 1668 ConversionPatternRewriter &rewriter) const final { 1669 Type resultType = op.getType(); 1670 Type convertedType = getTypeConverter() 1671 ? getTypeConverter()->convertType(resultType) 1672 : resultType; 1673 if (isa<FloatType>(resultType)) 1674 resultType = rewriter.getF64Type(); 1675 else if (resultType.isInteger(16)) 1676 resultType = rewriter.getIntegerType(64); 1677 else if (isa<test::TestRecursiveType>(resultType) && 1678 convertedType != resultType) 1679 resultType = convertedType; 1680 else 1681 return failure(); 1682 1683 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 1684 return success(); 1685 } 1686 }; 1687 1688 /// Call signature conversion and then fail the rewrite to trigger the undo 1689 /// mechanism. 
1690 struct TestSignatureConversionUndo 1691 : public OpConversionPattern<TestSignatureConversionUndoOp> { 1692 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 1693 1694 LogicalResult 1695 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, 1696 ConversionPatternRewriter &rewriter) const final { 1697 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 1698 return failure(); 1699 } 1700 }; 1701 1702 /// Call signature conversion without providing a type converter to handle 1703 /// materializations. 1704 struct TestTestSignatureConversionNoConverter 1705 : public OpConversionPattern<TestSignatureConversionNoConverterOp> { 1706 TestTestSignatureConversionNoConverter(const TypeConverter &converter, 1707 MLIRContext *context) 1708 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), 1709 converter(converter) {} 1710 1711 LogicalResult 1712 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, 1713 ConversionPatternRewriter &rewriter) const final { 1714 Region ®ion = op->getRegion(0); 1715 Block *entry = ®ion.front(); 1716 1717 // Convert the original entry arguments. 1718 TypeConverter::SignatureConversion result(entry->getNumArguments()); 1719 if (failed( 1720 converter.convertSignatureArgs(entry->getArgumentTypes(), result))) 1721 return failure(); 1722 rewriter.modifyOpInPlace(op, [&] { 1723 rewriter.applySignatureConversion(®ion.front(), result); 1724 }); 1725 return success(); 1726 } 1727 1728 const TypeConverter &converter; 1729 }; 1730 1731 /// Just forward the operands to the root op. This is essentially a no-op 1732 /// pattern that is used to trigger target materialization. 
struct TestTypeConsumerForward
    : public OpConversionPattern<TestTypeConsumerOp> {
  using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Update the op in place so that it consumes the adaptor's operands.
    rewriter.modifyOpInPlace(op,
                             [&] { op->setOperands(adaptor.getOperands()); });
    return success();
  }
};

/// Replaces a TestAnotherTypeProducerOp with a TestTypeProducerOp of the same
/// result type. Unlike the conversion patterns here, this is a plain rewrite
/// pattern. The driver registers it without the type converter.
struct TestTypeConversionAnotherProducer
    : public OpRewritePattern<TestAnotherTypeProducerOp> {
  using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
    return success();
  }
};

/// Replaces any 'test.replace_with_legal_op' with a LegalOpD built from its
/// first operand.
/// NOTE(review): `operands[0]` is read unchecked, so this assumes the matched
/// op always has at least one operand.
struct TestReplaceWithLegalOp : public ConversionPattern {
  TestReplaceWithLegalOp(MLIRContext *ctx)
      : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
    return success();
  }
};

/// Pass that runs the type-conversion test patterns as a partial conversion.
/// Its type converter:
///  - keeps f64 and maps bf16/f16/f32 to f64;
///  - drops every integer type (converts it to zero types);
///  - rewrites self-referring TestRecursiveTypes into a non-recursive type
///    named "outer_converted_type".
struct TestTypeConversionDriver
    : public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<TestDialect>();
  }
  StringRef getArgument() const final {
    return "test-legalize-type-conversion";
  }
  StringRef getDescription() const final {
    return "Test various type conversion functionalities in DialectConversion";
  }

  void runOnOperation() override {
    // Initialize the type converter. `conversionCallStack` holds the
    // recursive types currently being converted. The recursive conversion
    // below uses it to detect cycles.
    SmallVector<Type, 2> conversionCallStack;
    TypeConverter converter;

    // Register the type conversions.
    converter.addConversion([](Type type) -> Type {
      // Treat F64 as legal.
      if (type.isF64())
        return type;
      // Allow converting BF16/F16/F32 to F64.
      if (type.isBF16() || type.isF16() || type.isF32())
        return FloatType::getF64(type.getContext());
      // Otherwise, the type is illegal.
      return nullptr;
    });
    converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
      // Drop all integer types: succeed without producing any result type.
      return success();
    });
    converter.addConversion(
        // Convert a recursive self-referring type into a non-self-referring
        // type named "outer_converted_type" that contains a SimpleAType.
        [&](test::TestRecursiveType type,
            SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
          // If the type is already converted, return it to indicate that it is
          // legal.
          if (type.getName() == "outer_converted_type") {
            results.push_back(type);
            return success();
          }

          // Push the current type; the scope-exit guard pops it on every
          // return path below.
          conversionCallStack.push_back(type);
          auto popConversionCallStack = llvm::make_scope_exit(
              [&conversionCallStack]() { conversionCallStack.pop_back(); });

          // If the type is on the call stack more than once (it is there at
          // least once because of the _current_ call, which is always the last
          // element on the stack), we've hit the recursive case. Just return
          // SimpleAType here to create a non-recursive type as a result.
          if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(),
                                 type)) {
            results.push_back(test::SimpleAType::get(type.getContext()));
            return success();
          }

          // Convert the body recursively (this re-enters the converter, hence
          // the `&` capture of `converter`).
          auto result = test::TestRecursiveType::get(type.getContext(),
                                                     "outer_converted_type");
          if (failed(result.setBody(converter.convertType(type.getBody()))))
            return failure();
          results.push_back(result);
          return success();
        });

    // Register the source materializations.
    converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
                                          ValueRange inputs,
                                          Location loc) -> Value {
      // Cast a single F64 input to any result type except F16 (e.g. back to
      // F32).
      if (!resultType.isF16() && inputs.size() == 1 &&
          inputs[0].getType().isF64())
        return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
      // Allow producing an i32 or i64 from nothing.
      if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
          inputs.empty())
        return builder.create<TestTypeProducerOp>(loc, resultType);
      // Cast a single integer input to an integer result of any width.
      if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
          isa<IntegerType>(inputs[0].getType()))
        return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
      // Otherwise, fail.
      return nullptr;
    });

    // Initialize the conversion target.
    mlir::ConversionTarget target(getContext());
    target.addLegalOp<LegalOpD>();
    // Producers are legal only on f64, i64, or an already-converted
    // recursive type.
    target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
      auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
      return op.getType().isF64() || op.getType().isInteger(64) ||
             (recursiveType &&
              recursiveType.getName() == "outer_converted_type");
    });
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType()) &&
             converter.isLegal(&op.getBody());
    });
    target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
      // Allow casts from F64 to F32.
      // NOTE(review): dereferences the first operand type, so this assumes
      // every TestCastOp has at least one operand.
      return (*op.operand_type_begin()).isF64() && op.getType().isF32();
    });
    target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
        [&](TestSignatureConversionNoConverterOp op) {
          return converter.isLegal(op.getRegion().front().getArgumentTypes());
        });

    // Initialize the set of rewrite patterns. Only the first group receives
    // the type converter.
    RewritePatternSet patterns(&getContext());
    patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
                 TestSignatureConversionUndo,
                 TestTestSignatureConversionNoConverter>(converter,
                                                         &getContext());
    patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
        &getContext());
    mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                              converter);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Test Target Materialization With No Uses
//===----------------------------------------------------------------------===//

namespace {
/// Replaces a TestTypeChangerOp with the adaptor's operands.
struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
  using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Pass that marks TestTypeChangerOp illegal and forwards its operands with
/// ForwardOperandPattern. The type converter maps i16 to i64 and uses
/// TestCastOp as the target materialization.
struct TestTargetMaterializationWithNoUses
    : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestTargetMaterializationWithNoUses)

  StringRef getArgument() const final {
    return "test-target-materialization-with-no-uses";
  }
  StringRef getDescription() const final {
    return "Test a special case of target materialization in DialectConversion";
  }

  void runOnOperation() override {
    TypeConverter converter;
    // Identity conversion for every type.
    converter.addConversion([](Type t) { return t; });
    // Map i16 to i64; all other integer types are kept as is.
    converter.addConversion([](IntegerType intTy) -> Type {
      if (intTy.getWidth() == 16)
        return IntegerType::get(intTy.getContext(), 64);
      return intTy;
    });
    converter.addTargetMaterialization(
        [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
          return builder.create<TestCastOp>(loc, type, inputs).getResult();
        });

    ConversionTarget target(getContext());
    target.addIllegalOp<TestTypeChangerOp>();

    RewritePatternSet patterns(&getContext());
    patterns.add<ForwardOperandPattern>(converter, &getContext());

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Test Block Merging
//===----------------------------------------------------------------------===//

namespace {
/// A conversion pattern that tests that blocks can be merged.
1957 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 1958 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 1959 1960 LogicalResult 1961 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, 1962 ConversionPatternRewriter &rewriter) const final { 1963 Block &firstBlock = op.getBody().front(); 1964 Operation *branchOp = firstBlock.getTerminator(); 1965 Block *secondBlock = &*(std::next(op.getBody().begin())); 1966 auto succOperands = branchOp->getOperands(); 1967 SmallVector<Value, 2> replacements(succOperands); 1968 rewriter.eraseOp(branchOp); 1969 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1970 rewriter.modifyOpInPlace(op, [] {}); 1971 return success(); 1972 } 1973 }; 1974 1975 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 1976 struct TestUndoBlocksMerge : public ConversionPattern { 1977 TestUndoBlocksMerge(MLIRContext *ctx) 1978 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 1979 LogicalResult 1980 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1981 ConversionPatternRewriter &rewriter) const final { 1982 Block &firstBlock = op->getRegion(0).front(); 1983 Operation *branchOp = firstBlock.getTerminator(); 1984 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 1985 rewriter.setInsertionPointToStart(secondBlock); 1986 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 1987 auto succOperands = branchOp->getOperands(); 1988 SmallVector<Value, 2> replacements(succOperands); 1989 rewriter.eraseOp(branchOp); 1990 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1991 rewriter.modifyOpInPlace(op, [] {}); 1992 return success(); 1993 } 1994 }; 1995 1996 /// A rewrite mechanism to inline the body of the op into its parent, when both 1997 /// ops can have a single block. 
1998 struct TestMergeSingleBlockOps 1999 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 2000 using OpConversionPattern< 2001 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 2002 2003 LogicalResult 2004 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, 2005 ConversionPatternRewriter &rewriter) const final { 2006 SingleBlockImplicitTerminatorOp parentOp = 2007 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 2008 if (!parentOp) 2009 return failure(); 2010 Block &innerBlock = op.getRegion().front(); 2011 TerminatorOp innerTerminator = 2012 cast<TerminatorOp>(innerBlock.getTerminator()); 2013 rewriter.inlineBlockBefore(&innerBlock, op); 2014 rewriter.eraseOp(innerTerminator); 2015 rewriter.eraseOp(op); 2016 return success(); 2017 } 2018 }; 2019 2020 struct TestMergeBlocksPatternDriver 2021 : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> { 2022 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver) 2023 2024 StringRef getArgument() const final { return "test-merge-blocks"; } 2025 StringRef getDescription() const final { 2026 return "Test Merging operation in ConversionPatternRewriter"; 2027 } 2028 void runOnOperation() override { 2029 MLIRContext *context = &getContext(); 2030 mlir::RewritePatternSet patterns(context); 2031 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 2032 context); 2033 ConversionTarget target(*context); 2034 target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 2035 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 2036 target.addIllegalOp<ILLegalOpF>(); 2037 2038 /// Expect the op to have a single block after legalization. 2039 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 2040 [&](TestMergeBlocksOp op) -> bool { 2041 return llvm::hasSingleElement(op.getBody()); 2042 }); 2043 2044 /// Only allow `test.br` within test.merge_blocks op. 
2045 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 2046 return op->getParentOfType<TestMergeBlocksOp>(); 2047 }); 2048 2049 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 2050 /// inlined. 2051 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 2052 [&](SingleBlockImplicitTerminatorOp op) -> bool { 2053 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 2054 }); 2055 2056 DenseSet<Operation *> unlegalizedOps; 2057 ConversionConfig config; 2058 config.unlegalizedOps = &unlegalizedOps; 2059 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 2060 config); 2061 for (auto *op : unlegalizedOps) 2062 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 2063 } 2064 }; 2065 } // namespace 2066 2067 //===----------------------------------------------------------------------===// 2068 // Test Selective Replacement 2069 //===----------------------------------------------------------------------===// 2070 2071 namespace { 2072 /// A rewrite mechanism to inline the body of the op into its parent, when both 2073 /// ops can have a single block. 2074 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 2075 using OpRewritePattern<TestCastOp>::OpRewritePattern; 2076 2077 LogicalResult matchAndRewrite(TestCastOp op, 2078 PatternRewriter &rewriter) const final { 2079 if (op.getNumOperands() != 2) 2080 return failure(); 2081 OperandRange operands = op.getOperands(); 2082 2083 // Replace non-terminator uses with the first operand. 2084 rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) { 2085 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 2086 }); 2087 // Replace everything else with the second operand if the operation isn't 2088 // dead. 
2089 rewriter.replaceOp(op, op.getOperand(1)); 2090 return success(); 2091 } 2092 }; 2093 2094 struct TestSelectiveReplacementPatternDriver 2095 : public PassWrapper<TestSelectiveReplacementPatternDriver, 2096 OperationPass<>> { 2097 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 2098 TestSelectiveReplacementPatternDriver) 2099 2100 StringRef getArgument() const final { 2101 return "test-pattern-selective-replacement"; 2102 } 2103 StringRef getDescription() const final { 2104 return "Test selective replacement in the PatternRewriter"; 2105 } 2106 void runOnOperation() override { 2107 MLIRContext *context = &getContext(); 2108 mlir::RewritePatternSet patterns(context); 2109 patterns.add<TestSelectiveOpReplacementPattern>(context); 2110 (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 2111 } 2112 }; 2113 } // namespace 2114 2115 //===----------------------------------------------------------------------===// 2116 // PassRegistration 2117 //===----------------------------------------------------------------------===// 2118 2119 namespace mlir { 2120 namespace test { 2121 void registerPatternsTestPass() { 2122 PassRegistration<TestReturnTypeDriver>(); 2123 2124 PassRegistration<TestDerivedAttributeDriver>(); 2125 2126 PassRegistration<TestGreedyPatternDriver>(); 2127 PassRegistration<TestStrictPatternDriver>(); 2128 PassRegistration<TestWalkPatternDriver>(); 2129 2130 PassRegistration<TestLegalizePatternDriver>([] { 2131 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 2132 }); 2133 2134 PassRegistration<TestRemappedValue>(); 2135 2136 PassRegistration<TestUnknownRootOpDriver>(); 2137 2138 PassRegistration<TestTypeConversionDriver>(); 2139 PassRegistration<TestTargetMaterializationWithNoUses>(); 2140 2141 PassRegistration<TestRewriteDynamicOpDriver>(); 2142 2143 PassRegistration<TestMergeBlocksPatternDriver>(); 2144 PassRegistration<TestSelectiveReplacementPatternDriver>(); 2145 } 2146 } // namespace test 2147 } // namespace 
mlir 2148