1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestOps.h" 11 #include "TestTypes.h" 12 #include "mlir/Dialect/Arith/IR/Arith.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/BuiltinAttributes.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/Visitors.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "mlir/Transforms/FoldUtils.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 #include "mlir/Transforms/WalkPatternRewriteDriver.h" 24 #include "llvm/ADT/ScopeExit.h" 25 #include <cstdint> 26 27 using namespace mlir; 28 using namespace test; 29 30 // Native function for testing NativeCodeCall 31 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 32 return choice.getValue() ? input1 : input2; 33 } 34 35 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 36 rewriter.create<OpI>(loc, input); 37 } 38 39 static void handleNoResultOp(PatternRewriter &rewriter, 40 OpSymbolBindingNoResult op) { 41 // Turn the no result op to a one-result op. 42 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), 43 op.getOperand()); 44 } 45 46 static bool getFirstI32Result(Operation *op, Value &value) { 47 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 48 return false; 49 value = op->getResult(0); 50 return true; 51 } 52 53 static Value bindNativeCodeCallResult(Value value) { return value; } 54 55 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 56 Value input2) { 57 return SmallVector<Value, 2>({input2, input1}); 58 } 59 60 // Test that natives calls are only called once during rewrites. 61 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 62 // This let us check the number of times OpM_Test was called by inspecting 63 // the returned value in the MLIR output. 64 static int64_t opMIncreasingValue = 314159265; 65 static Attribute opMTest(PatternRewriter &rewriter, Value val) { 66 int64_t i = opMIncreasingValue++; 67 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 68 } 69 70 namespace { 71 #include "TestPatterns.inc" 72 } // namespace 73 74 //===----------------------------------------------------------------------===// 75 // Test Reduce Pattern Interface 76 //===----------------------------------------------------------------------===// 77 78 void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 79 populateWithGenerated(patterns); 80 } 81 82 //===----------------------------------------------------------------------===// 83 // Canonicalizer Driver. 84 //===----------------------------------------------------------------------===// 85 86 namespace { 87 struct FoldingPattern : public RewritePattern { 88 public: 89 FoldingPattern(MLIRContext *context) 90 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 91 /*benefit=*/1, context) {} 92 93 LogicalResult matchAndRewrite(Operation *op, 94 PatternRewriter &rewriter) const override { 95 // Exercise createOrFold API for a single-result operation that is folded 96 // upon construction. The operation being created has an in-place folder, 97 // and it should be still present in the output. Furthermore, the folder 98 // should not crash when attempting to recover the (unchanged) operation 99 // result. 100 Value result = rewriter.createOrFold<TestOpInPlaceFold>( 101 op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0)); 102 assert(result); 103 rewriter.replaceOp(op, result); 104 return success(); 105 } 106 }; 107 108 /// This pattern creates a foldable operation at the entry point of the block. 109 /// This tests the situation where the operation folder will need to replace an 110 /// operation with a previously created constant that does not initially 111 /// dominate the operation to replace. 112 struct FolderInsertBeforePreviouslyFoldedConstantPattern 113 : public OpRewritePattern<TestCastOp> { 114 public: 115 using OpRewritePattern<TestCastOp>::OpRewritePattern; 116 117 LogicalResult matchAndRewrite(TestCastOp op, 118 PatternRewriter &rewriter) const override { 119 if (!op->hasAttr("test_fold_before_previously_folded_op")) 120 return failure(); 121 rewriter.setInsertionPointToStart(op->getBlock()); 122 123 auto constOp = rewriter.create<arith::ConstantOp>( 124 op.getLoc(), rewriter.getBoolAttr(true)); 125 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 126 Value(constOp)); 127 return success(); 128 } 129 }; 130 131 /// This pattern matches test.op_commutative2 with the first operand being 132 /// another test.op_commutative2 with a constant on the right side and fold it 133 /// away by propagating it as its result. This is intend to check that patterns 134 /// are applied after the commutative property moves constant to the right. 135 struct FolderCommutativeOp2WithConstant 136 : public OpRewritePattern<TestCommutative2Op> { 137 public: 138 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; 139 140 LogicalResult matchAndRewrite(TestCommutative2Op op, 141 PatternRewriter &rewriter) const override { 142 auto operand = 143 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); 144 if (!operand) 145 return failure(); 146 Attribute constInput; 147 if (!matchPattern(operand->getOperand(1), m_Constant(&constInput))) 148 return failure(); 149 rewriter.replaceOp(op, operand->getOperand(1)); 150 return success(); 151 } 152 }; 153 154 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer 155 /// attribute with value smaller than MaxVal, it increments the value by 1. 156 template <int MaxVal> 157 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> { 158 using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern; 159 160 LogicalResult matchAndRewrite(AnyAttrOfOp op, 161 PatternRewriter &rewriter) const override { 162 auto intAttr = dyn_cast<IntegerAttr>(op.getAttr()); 163 if (!intAttr) 164 return failure(); 165 int64_t val = intAttr.getInt(); 166 if (val >= MaxVal) 167 return failure(); 168 rewriter.modifyOpInPlace( 169 op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); }); 170 return success(); 171 } 172 }; 173 174 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op". 175 struct MakeOpEligible : public RewritePattern { 176 MakeOpEligible(MLIRContext *context) 177 : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {} 178 179 LogicalResult matchAndRewrite(Operation *op, 180 PatternRewriter &rewriter) const override { 181 if (op->hasAttr("eligible")) 182 return failure(); 183 rewriter.modifyOpInPlace( 184 op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); }); 185 return success(); 186 } 187 }; 188 189 /// This pattern hoists eligible ops out of a "test.one_region_op". 190 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> { 191 using OpRewritePattern<test::OneRegionOp>::OpRewritePattern; 192 193 LogicalResult matchAndRewrite(test::OneRegionOp op, 194 PatternRewriter &rewriter) const override { 195 Operation *terminator = op.getRegion().front().getTerminator(); 196 Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp(); 197 if (toBeHoisted->getParentOp() != op) 198 return failure(); 199 if (!toBeHoisted->hasAttr("eligible")) 200 return failure(); 201 rewriter.moveOpBefore(toBeHoisted, op); 202 return success(); 203 } 204 }; 205 206 /// This pattern moves "test.move_before_parent_op" before the parent op. 207 struct MoveBeforeParentOp : public RewritePattern { 208 MoveBeforeParentOp(MLIRContext *context) 209 : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {} 210 211 LogicalResult matchAndRewrite(Operation *op, 212 PatternRewriter &rewriter) const override { 213 // Do not hoist past functions. 214 if (isa<FunctionOpInterface>(op->getParentOp())) 215 return failure(); 216 rewriter.moveOpBefore(op, op->getParentOp()); 217 return success(); 218 } 219 }; 220 221 /// This pattern moves "test.move_after_parent_op" after the parent op. 222 struct MoveAfterParentOp : public RewritePattern { 223 MoveAfterParentOp(MLIRContext *context) 224 : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {} 225 226 LogicalResult matchAndRewrite(Operation *op, 227 PatternRewriter &rewriter) const override { 228 // Do not hoist past functions. 229 if (isa<FunctionOpInterface>(op->getParentOp())) 230 return failure(); 231 232 int64_t moveForwardBy = 0; 233 if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance")) 234 moveForwardBy = advanceBy.getInt(); 235 236 Operation *moveAfter = op->getParentOp(); 237 for (int64_t i = 0; i < moveForwardBy; ++i) 238 moveAfter = moveAfter->getNextNode(); 239 240 rewriter.moveOpAfter(op, moveAfter); 241 return success(); 242 } 243 }; 244 245 /// This pattern inlines blocks that are nested in 246 /// "test.inline_blocks_into_parent" into the parent block. 247 struct InlineBlocksIntoParent : public RewritePattern { 248 InlineBlocksIntoParent(MLIRContext *context) 249 : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1, 250 context) {} 251 252 LogicalResult matchAndRewrite(Operation *op, 253 PatternRewriter &rewriter) const override { 254 bool changed = false; 255 for (Region &r : op->getRegions()) { 256 while (!r.empty()) { 257 rewriter.inlineBlockBefore(&r.front(), op); 258 changed = true; 259 } 260 } 261 return success(changed); 262 } 263 }; 264 265 /// This pattern splits blocks at "test.split_block_here" and replaces the op 266 /// with a new op (to prevent an infinite loop of block splitting). 267 struct SplitBlockHere : public RewritePattern { 268 SplitBlockHere(MLIRContext *context) 269 : RewritePattern("test.split_block_here", /*benefit=*/1, context) {} 270 271 LogicalResult matchAndRewrite(Operation *op, 272 PatternRewriter &rewriter) const override { 273 rewriter.splitBlock(op->getBlock(), op->getIterator()); 274 Operation *newOp = rewriter.create( 275 op->getLoc(), 276 OperationName("test.new_op", op->getContext()).getIdentifier(), 277 op->getOperands(), op->getResultTypes()); 278 rewriter.replaceOp(op, newOp); 279 return success(); 280 } 281 }; 282 283 /// This pattern clones "test.clone_me" ops. 284 struct CloneOp : public RewritePattern { 285 CloneOp(MLIRContext *context) 286 : RewritePattern("test.clone_me", /*benefit=*/1, context) {} 287 288 LogicalResult matchAndRewrite(Operation *op, 289 PatternRewriter &rewriter) const override { 290 // Do not clone already cloned ops to avoid going into an infinite loop. 291 if (op->hasAttr("was_cloned")) 292 return failure(); 293 Operation *cloned = rewriter.clone(*op); 294 cloned->setAttr("was_cloned", rewriter.getUnitAttr()); 295 return success(); 296 } 297 }; 298 299 /// This pattern clones regions of "test.clone_region_before" ops before the 300 /// parent block. 301 struct CloneRegionBeforeOp : public RewritePattern { 302 CloneRegionBeforeOp(MLIRContext *context) 303 : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {} 304 305 LogicalResult matchAndRewrite(Operation *op, 306 PatternRewriter &rewriter) const override { 307 // Do not clone already cloned ops to avoid going into an infinite loop. 308 if (op->hasAttr("was_cloned")) 309 return failure(); 310 for (Region &r : op->getRegions()) 311 rewriter.cloneRegionBefore(r, op->getBlock()); 312 op->setAttr("was_cloned", rewriter.getUnitAttr()); 313 return success(); 314 } 315 }; 316 317 /// Replace an operation may introduce the re-visiting of its users. 318 class ReplaceWithNewOp : public RewritePattern { 319 public: 320 ReplaceWithNewOp(MLIRContext *context) 321 : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {} 322 323 LogicalResult matchAndRewrite(Operation *op, 324 PatternRewriter &rewriter) const override { 325 Operation *newOp; 326 if (op->hasAttr("create_erase_op")) { 327 newOp = rewriter.create( 328 op->getLoc(), 329 OperationName("test.erase_op", op->getContext()).getIdentifier(), 330 ValueRange(), TypeRange()); 331 } else { 332 newOp = rewriter.create( 333 op->getLoc(), 334 OperationName("test.new_op", op->getContext()).getIdentifier(), 335 op->getOperands(), op->getResultTypes()); 336 } 337 // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp". 338 // A "notifyOperationReplaced" callback is triggered in either case. 339 rewriter.replaceAllOpUsesWith(op, newOp->getResults()); 340 rewriter.eraseOp(op); 341 return success(); 342 } 343 }; 344 345 /// Erases the first child block of the matched "test.erase_first_block" 346 /// operation. 347 class EraseFirstBlock : public RewritePattern { 348 public: 349 EraseFirstBlock(MLIRContext *context) 350 : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {} 351 352 LogicalResult matchAndRewrite(Operation *op, 353 PatternRewriter &rewriter) const override { 354 for (Region &r : op->getRegions()) { 355 for (Block &b : r.getBlocks()) { 356 rewriter.eraseBlock(&b); 357 return success(); 358 } 359 } 360 361 return failure(); 362 } 363 }; 364 365 struct TestGreedyPatternDriver 366 : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> { 367 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver) 368 369 TestGreedyPatternDriver() = default; 370 TestGreedyPatternDriver(const TestGreedyPatternDriver &other) 371 : PassWrapper(other) {} 372 373 StringRef getArgument() const final { return "test-greedy-patterns"; } 374 StringRef getDescription() const final { return "Run test dialect patterns"; } 375 void runOnOperation() override { 376 mlir::RewritePatternSet patterns(&getContext()); 377 populateWithGenerated(patterns); 378 379 // Verify named pattern is generated with expected name. 380 patterns.add<FoldingPattern, TestNamedPatternRule, 381 FolderInsertBeforePreviouslyFoldedConstantPattern, 382 FolderCommutativeOp2WithConstant, HoistEligibleOps, 383 MakeOpEligible>(&getContext()); 384 385 // Additional patterns for testing the GreedyPatternRewriteDriver. 386 patterns.insert<IncrementIntAttribute<3>>(&getContext()); 387 388 GreedyRewriteConfig config; 389 config.useTopDownTraversal = this->useTopDownTraversal; 390 config.maxIterations = this->maxIterations; 391 config.fold = this->fold; 392 config.cseConstants = this->cseConstants; 393 (void)applyPatternsGreedily(getOperation(), std::move(patterns), config); 394 } 395 396 Option<bool> useTopDownTraversal{ 397 *this, "top-down", 398 llvm::cl::desc("Seed the worklist in general top-down order"), 399 llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)}; 400 Option<int> maxIterations{ 401 *this, "max-iterations", 402 llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"), 403 llvm::cl::init(GreedyRewriteConfig().maxIterations)}; 404 Option<bool> fold{*this, "fold", llvm::cl::desc("Whether to fold"), 405 llvm::cl::init(GreedyRewriteConfig().fold)}; 406 Option<bool> cseConstants{*this, "cse-constants", 407 llvm::cl::desc("Whether to CSE constants"), 408 llvm::cl::init(GreedyRewriteConfig().cseConstants)}; 409 }; 410 411 struct DumpNotifications : public RewriterBase::Listener { 412 void notifyBlockInserted(Block *block, Region *previous, 413 Region::iterator previousIt) override { 414 llvm::outs() << "notifyBlockInserted"; 415 if (block->getParentOp()) { 416 llvm::outs() << " into " << block->getParentOp()->getName() << ": "; 417 } else { 418 llvm::outs() << " into unknown op: "; 419 } 420 if (previous == nullptr) { 421 llvm::outs() << "was unlinked\n"; 422 } else { 423 llvm::outs() << "was linked\n"; 424 } 425 } 426 void notifyOperationInserted(Operation *op, 427 OpBuilder::InsertPoint previous) override { 428 llvm::outs() << "notifyOperationInserted: " << op->getName(); 429 if (!previous.isSet()) { 430 llvm::outs() << ", was unlinked\n"; 431 } else { 432 if (!previous.getPoint().getNodePtr()) { 433 llvm::outs() << ", was linked, exact position unknown\n"; 434 } else if (previous.getPoint() == previous.getBlock()->end()) { 435 llvm::outs() << ", was last in block\n"; 436 } else { 437 llvm::outs() << ", previous = " << previous.getPoint()->getName() 438 << "\n"; 439 } 440 } 441 } 442 void notifyBlockErased(Block *block) override { 443 llvm::outs() << "notifyBlockErased\n"; 444 } 445 void notifyOperationErased(Operation *op) override { 446 llvm::outs() << "notifyOperationErased: " << op->getName() << "\n"; 447 } 448 void notifyOperationModified(Operation *op) override { 449 llvm::outs() << "notifyOperationModified: " << op->getName() << "\n"; 450 } 451 void notifyOperationReplaced(Operation *op, ValueRange values) override { 452 llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n"; 453 } 454 }; 455 456 struct TestStrictPatternDriver 457 : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> { 458 public: 459 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver) 460 461 TestStrictPatternDriver() = default; 462 TestStrictPatternDriver(const TestStrictPatternDriver &other) 463 : PassWrapper(other) { 464 strictMode = other.strictMode; 465 } 466 467 StringRef getArgument() const final { return "test-strict-pattern-driver"; } 468 StringRef getDescription() const final { 469 return "Test strict mode of pattern driver"; 470 } 471 472 void runOnOperation() override { 473 MLIRContext *ctx = &getContext(); 474 mlir::RewritePatternSet patterns(ctx); 475 patterns.add< 476 // clang-format off 477 ChangeBlockOp, 478 CloneOp, 479 CloneRegionBeforeOp, 480 EraseOp, 481 ImplicitChangeOp, 482 InlineBlocksIntoParent, 483 InsertSameOp, 484 MoveBeforeParentOp, 485 ReplaceWithNewOp, 486 SplitBlockHere 487 // clang-format on 488 >(ctx); 489 SmallVector<Operation *> ops; 490 getOperation()->walk([&](Operation *op) { 491 StringRef opName = op->getName().getStringRef(); 492 if (opName == "test.insert_same_op" || opName == "test.change_block_op" || 493 opName == "test.replace_with_new_op" || opName == "test.erase_op" || 494 opName == "test.move_before_parent_op" || 495 opName == "test.inline_blocks_into_parent" || 496 opName == "test.split_block_here" || opName == "test.clone_me" || 497 opName == "test.clone_region_before") { 498 ops.push_back(op); 499 } 500 }); 501 502 DumpNotifications dumpNotifications; 503 GreedyRewriteConfig config; 504 config.listener = &dumpNotifications; 505 if (strictMode == "AnyOp") { 506 config.strictMode = GreedyRewriteStrictness::AnyOp; 507 } else if (strictMode == "ExistingAndNewOps") { 508 config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; 509 } else if (strictMode == "ExistingOps") { 510 config.strictMode = GreedyRewriteStrictness::ExistingOps; 511 } else { 512 llvm_unreachable("invalid strictness option"); 513 } 514 515 // Check if these transformations introduce visiting of operations that 516 // are not in the `ops` set (The new created ops are valid). An invalid 517 // operation will trigger the assertion while processing. 518 bool changed = false; 519 bool allErased = false; 520 (void)applyOpPatternsGreedily(ArrayRef(ops), std::move(patterns), config, 521 &changed, &allErased); 522 Builder b(ctx); 523 getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed)); 524 getOperation()->setAttr("pattern_driver_all_erased", 525 b.getBoolAttr(allErased)); 526 } 527 528 Option<std::string> strictMode{ 529 *this, "strictness", 530 llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"), 531 llvm::cl::init("AnyOp")}; 532 533 private: 534 // New inserted operation is valid for further transformation. 535 class InsertSameOp : public RewritePattern { 536 public: 537 InsertSameOp(MLIRContext *context) 538 : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {} 539 540 LogicalResult matchAndRewrite(Operation *op, 541 PatternRewriter &rewriter) const override { 542 if (op->hasAttr("skip")) 543 return failure(); 544 545 Operation *newOp = 546 rewriter.create(op->getLoc(), op->getName().getIdentifier(), 547 op->getOperands(), op->getResultTypes()); 548 rewriter.modifyOpInPlace( 549 op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); }); 550 newOp->setAttr("skip", rewriter.getBoolAttr(true)); 551 552 return success(); 553 } 554 }; 555 556 // Remove an operation may introduce the re-visiting of its operands. 557 class EraseOp : public RewritePattern { 558 public: 559 EraseOp(MLIRContext *context) 560 : RewritePattern("test.erase_op", /*benefit=*/1, context) {} 561 LogicalResult matchAndRewrite(Operation *op, 562 PatternRewriter &rewriter) const override { 563 rewriter.eraseOp(op); 564 return success(); 565 } 566 }; 567 568 // The following two patterns test RewriterBase::replaceAllUsesWith. 569 // 570 // That function replaces all usages of a Block (or a Value) with another one 571 // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver 572 // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its 573 // worklist: when an op is modified, it is added to the worklist. The two 574 // patterns below make the tracking observable: ChangeBlockOp replaces all 575 // usages of a block and that pattern is applied because the corresponding ops 576 // are put on the initial worklist (see above). ImplicitChangeOp does an 577 // unrelated change but ops of the corresponding type are *not* on the initial 578 // worklist, so the effect of the second pattern is only visible if the 579 // tracking and subsequent adding to the worklist actually works. 580 581 // Replace all usages of the first successor with the second successor. 582 class ChangeBlockOp : public RewritePattern { 583 public: 584 ChangeBlockOp(MLIRContext *context) 585 : RewritePattern("test.change_block_op", /*benefit=*/1, context) {} 586 LogicalResult matchAndRewrite(Operation *op, 587 PatternRewriter &rewriter) const override { 588 if (op->getNumSuccessors() < 2) 589 return failure(); 590 Block *firstSuccessor = op->getSuccessor(0); 591 Block *secondSuccessor = op->getSuccessor(1); 592 if (firstSuccessor == secondSuccessor) 593 return failure(); 594 // This is the function being tested: 595 rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor); 596 // Using the following line instead would make the test fail: 597 // firstSuccessor->replaceAllUsesWith(secondSuccessor); 598 return success(); 599 } 600 }; 601 602 // Changes the successor to the parent block. 603 class ImplicitChangeOp : public RewritePattern { 604 public: 605 ImplicitChangeOp(MLIRContext *context) 606 : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {} 607 LogicalResult matchAndRewrite(Operation *op, 608 PatternRewriter &rewriter) const override { 609 if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock()) 610 return failure(); 611 rewriter.modifyOpInPlace(op, 612 [&]() { op->setSuccessor(op->getBlock(), 0); }); 613 return success(); 614 } 615 }; 616 }; 617 618 struct TestWalkPatternDriver final 619 : PassWrapper<TestWalkPatternDriver, OperationPass<>> { 620 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver) 621 622 TestWalkPatternDriver() = default; 623 TestWalkPatternDriver(const TestWalkPatternDriver &other) 624 : PassWrapper(other) {} 625 626 StringRef getArgument() const override { 627 return "test-walk-pattern-rewrite-driver"; 628 } 629 StringRef getDescription() const override { 630 return "Run test walk pattern rewrite driver"; 631 } 632 void runOnOperation() override { 633 mlir::RewritePatternSet patterns(&getContext()); 634 635 // Patterns for testing the WalkPatternRewriteDriver. 636 patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp, 637 MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>( 638 &getContext()); 639 640 DumpNotifications dumpListener; 641 walkAndApplyPatterns(getOperation(), std::move(patterns), 642 dumpNotifications ? &dumpListener : nullptr); 643 } 644 645 Option<bool> dumpNotifications{ 646 *this, "dump-notifications", 647 llvm::cl::desc("Print rewrite listener notifications"), 648 llvm::cl::init(false)}; 649 }; 650 651 } // namespace 652 653 //===----------------------------------------------------------------------===// 654 // ReturnType Driver. 655 //===----------------------------------------------------------------------===// 656 657 namespace { 658 // Generate ops for each instance where the type can be successfully inferred. 659 template <typename OpTy> 660 static void invokeCreateWithInferredReturnType(Operation *op) { 661 auto *context = op->getContext(); 662 auto fop = op->getParentOfType<func::FuncOp>(); 663 auto location = UnknownLoc::get(context); 664 OpBuilder b(op); 665 b.setInsertionPointAfter(op); 666 667 // Use permutations of 2 args as operands. 668 assert(fop.getNumArguments() >= 2); 669 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 670 for (int j = 0; j < e; ++j) { 671 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 672 SmallVector<Type, 2> inferredReturnTypes; 673 if (succeeded(OpTy::inferReturnTypes( 674 context, std::nullopt, values, op->getDiscardableAttrDictionary(), 675 op->getPropertiesStorage(), op->getRegions(), 676 inferredReturnTypes))) { 677 OperationState state(location, OpTy::getOperationName()); 678 // TODO: Expand to regions. 679 OpTy::build(b, state, values, op->getAttrs()); 680 (void)b.create(state); 681 } 682 } 683 } 684 } 685 686 static void reifyReturnShape(Operation *op) { 687 OpBuilder b(op); 688 689 // Use permutations of 2 args as operands. 690 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 691 SmallVector<Value, 2> shapes; 692 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 693 !llvm::hasSingleElement(shapes)) 694 return; 695 for (const auto &it : llvm::enumerate(shapes)) { 696 op->emitRemark() << "value " << it.index() << ": " 697 << it.value().getDefiningOp(); 698 } 699 } 700 701 struct TestReturnTypeDriver 702 : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> { 703 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver) 704 705 void getDependentDialects(DialectRegistry ®istry) const override { 706 registry.insert<tensor::TensorDialect>(); 707 } 708 StringRef getArgument() const final { return "test-return-type"; } 709 StringRef getDescription() const final { return "Run return type functions"; } 710 711 void runOnOperation() override { 712 if (getOperation().getName() == "testCreateFunctions") { 713 std::vector<Operation *> ops; 714 // Collect ops to avoid triggering on inserted ops. 715 for (auto &op : getOperation().getBody().front()) 716 ops.push_back(&op); 717 // Generate test patterns for each, but skip terminator. 718 for (auto *op : llvm::ArrayRef(ops).drop_back()) { 719 // Test create method of each of the Op classes below. The resultant 720 // output would be in reverse order underneath `op` from which 721 // the attributes and regions are used. 722 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 723 invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>( 724 op); 725 invokeCreateWithInferredReturnType< 726 OpWithShapedTypeInferTypeInterfaceOp>(op); 727 }; 728 return; 729 } 730 if (getOperation().getName() == "testReifyFunctions") { 731 std::vector<Operation *> ops; 732 // Collect ops to avoid triggering on inserted ops. 733 for (auto &op : getOperation().getBody().front()) 734 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 735 ops.push_back(&op); 736 // Generate test patterns for each, but skip terminator. 737 for (auto *op : ops) 738 reifyReturnShape(op); 739 } 740 } 741 }; 742 } // namespace 743 744 namespace { 745 struct TestDerivedAttributeDriver 746 : public PassWrapper<TestDerivedAttributeDriver, 747 OperationPass<func::FuncOp>> { 748 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver) 749 750 StringRef getArgument() const final { return "test-derived-attr"; } 751 StringRef getDescription() const final { 752 return "Run test derived attributes"; 753 } 754 void runOnOperation() override; 755 }; 756 } // namespace 757 758 void TestDerivedAttributeDriver::runOnOperation() { 759 getOperation().walk([](DerivedAttributeOpInterface dOp) { 760 auto dAttr = dOp.materializeDerivedAttributes(); 761 if (!dAttr) 762 return; 763 for (auto d : dAttr) 764 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); 765 }); 766 } 767 768 //===----------------------------------------------------------------------===// 769 // Legalization Driver. 770 //===----------------------------------------------------------------------===// 771 772 namespace { 773 //===----------------------------------------------------------------------===// 774 // Region-Block Rewrite Testing 775 776 /// This pattern applies a signature conversion to a block inside a detached 777 /// region. 778 struct TestDetachedSignatureConversion : public ConversionPattern { 779 TestDetachedSignatureConversion(MLIRContext *ctx) 780 : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1, 781 ctx) {} 782 783 LogicalResult 784 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 785 ConversionPatternRewriter &rewriter) const final { 786 if (op->getNumRegions() != 1) 787 return failure(); 788 OperationState state(op->getLoc(), "test.legal_op", operands, 789 op->getResultTypes(), {}, BlockRange()); 790 Region *newRegion = state.addRegion(); 791 rewriter.inlineRegionBefore(op->getRegion(0), *newRegion, 792 newRegion->begin()); 793 TypeConverter::SignatureConversion result(newRegion->getNumArguments()); 794 for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i) 795 result.addInputs(i, rewriter.getF64Type()); 796 rewriter.applySignatureConversion(&newRegion->front(), result); 797 Operation *newOp = rewriter.create(state); 798 rewriter.replaceOp(op, newOp->getResults()); 799 return success(); 800 } 801 }; 802 803 /// This pattern is a simple pattern that inlines the first region of a given 804 /// operation into the parent region. 805 struct TestRegionRewriteBlockMovement : public ConversionPattern { 806 TestRegionRewriteBlockMovement(MLIRContext *ctx) 807 : ConversionPattern("test.region", 1, ctx) {} 808 809 LogicalResult 810 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 811 ConversionPatternRewriter &rewriter) const final { 812 // Inline this region into the parent region. 813 auto &parentRegion = *op->getParentRegion(); 814 auto &opRegion = op->getRegion(0); 815 if (op->getDiscardableAttr("legalizer.should_clone")) 816 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 817 else 818 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 819 820 if (op->getDiscardableAttr("legalizer.erase_old_blocks")) { 821 while (!opRegion.empty()) 822 rewriter.eraseBlock(&opRegion.front()); 823 } 824 825 // Drop this operation. 826 rewriter.eraseOp(op); 827 return success(); 828 } 829 }; 830 /// This pattern is a simple pattern that generates a region containing an 831 /// illegal operation. 832 struct TestRegionRewriteUndo : public RewritePattern { 833 TestRegionRewriteUndo(MLIRContext *ctx) 834 : RewritePattern("test.region_builder", 1, ctx) {} 835 836 LogicalResult matchAndRewrite(Operation *op, 837 PatternRewriter &rewriter) const final { 838 // Create the region operation with an entry block containing arguments. 839 OperationState newRegion(op->getLoc(), "test.region"); 840 newRegion.addRegion(); 841 auto *regionOp = rewriter.create(newRegion); 842 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 843 entryBlock->addArgument(rewriter.getIntegerType(64), 844 rewriter.getUnknownLoc()); 845 846 // Add an explicitly illegal operation to ensure the conversion fails. 847 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 848 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 849 850 // Drop this operation. 851 rewriter.eraseOp(op); 852 return success(); 853 } 854 }; 855 /// A simple pattern that creates a block at the end of the parent region of the 856 /// matched operation. 857 struct TestCreateBlock : public RewritePattern { 858 TestCreateBlock(MLIRContext *ctx) 859 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 860 861 LogicalResult matchAndRewrite(Operation *op, 862 PatternRewriter &rewriter) const final { 863 Region ®ion = *op->getParentRegion(); 864 Type i32Type = rewriter.getIntegerType(32); 865 Location loc = op->getLoc(); 866 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 867 rewriter.create<TerminatorOp>(loc); 868 rewriter.eraseOp(op); 869 return success(); 870 } 871 }; 872 873 /// A simple pattern that creates a block containing an invalid operation in 874 /// order to trigger the block creation undo mechanism. 875 struct TestCreateIllegalBlock : public RewritePattern { 876 TestCreateIllegalBlock(MLIRContext *ctx) 877 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 878 879 LogicalResult matchAndRewrite(Operation *op, 880 PatternRewriter &rewriter) const final { 881 Region ®ion = *op->getParentRegion(); 882 Type i32Type = rewriter.getIntegerType(32); 883 Location loc = op->getLoc(); 884 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 885 // Create an illegal op to ensure the conversion fails. 886 rewriter.create<ILLegalOpF>(loc, i32Type); 887 rewriter.create<TerminatorOp>(loc); 888 rewriter.eraseOp(op); 889 return success(); 890 } 891 }; 892 893 /// A simple pattern that tests the undo mechanism when replacing the uses of a 894 /// block argument. 895 struct TestUndoBlockArgReplace : public ConversionPattern { 896 TestUndoBlockArgReplace(MLIRContext *ctx) 897 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 898 899 LogicalResult 900 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 901 ConversionPatternRewriter &rewriter) const final { 902 auto illegalOp = 903 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 904 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 905 illegalOp->getResult(0)); 906 rewriter.modifyOpInPlace(op, [] {}); 907 return success(); 908 } 909 }; 910 911 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion. 912 /// This is to test the rollback logic. 913 struct TestUndoMoveOpBefore : public ConversionPattern { 914 TestUndoMoveOpBefore(MLIRContext *ctx) 915 : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {} 916 917 LogicalResult 918 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 919 ConversionPatternRewriter &rewriter) const override { 920 rewriter.moveOpBefore(op, op->getParentOp()); 921 // Replace with an illegal op to ensure the conversion fails. 922 rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type()); 923 return success(); 924 } 925 }; 926 927 /// A rewrite pattern that tests the undo mechanism when erasing a block. 928 struct TestUndoBlockErase : public ConversionPattern { 929 TestUndoBlockErase(MLIRContext *ctx) 930 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 931 932 LogicalResult 933 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 934 ConversionPatternRewriter &rewriter) const final { 935 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 936 rewriter.setInsertionPointToStart(secondBlock); 937 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 938 rewriter.eraseBlock(secondBlock); 939 rewriter.modifyOpInPlace(op, [] {}); 940 return success(); 941 } 942 }; 943 944 /// A pattern that modifies a property in-place, but keeps the op illegal. 945 struct TestUndoPropertiesModification : public ConversionPattern { 946 TestUndoPropertiesModification(MLIRContext *ctx) 947 : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {} 948 LogicalResult 949 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 950 ConversionPatternRewriter &rewriter) const final { 951 if (!op->hasAttr("modify_inplace")) 952 return failure(); 953 rewriter.modifyOpInPlace( 954 op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); }); 955 return success(); 956 } 957 }; 958 959 //===----------------------------------------------------------------------===// 960 // Type-Conversion Rewrite Testing 961 962 /// This patterns erases a region operation that has had a type conversion. 963 struct TestDropOpSignatureConversion : public ConversionPattern { 964 TestDropOpSignatureConversion(MLIRContext *ctx, 965 const TypeConverter &converter) 966 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 967 LogicalResult 968 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 969 ConversionPatternRewriter &rewriter) const override { 970 Region ®ion = op->getRegion(0); 971 Block *entry = ®ion.front(); 972 973 // Convert the original entry arguments. 974 const TypeConverter &converter = *getTypeConverter(); 975 TypeConverter::SignatureConversion result(entry->getNumArguments()); 976 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 977 result)) || 978 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 979 return failure(); 980 981 // Convert the region signature and just drop the operation. 982 rewriter.eraseOp(op); 983 return success(); 984 } 985 }; 986 /// This pattern simply updates the operands of the given operation. 987 struct TestPassthroughInvalidOp : public ConversionPattern { 988 TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter) 989 : ConversionPattern(converter, "test.invalid", 1, ctx) {} 990 LogicalResult 991 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 992 ConversionPatternRewriter &rewriter) const final { 993 SmallVector<Value> flattened; 994 for (auto it : llvm::enumerate(operands)) { 995 ValueRange range = it.value(); 996 if (range.size() == 1) { 997 flattened.push_back(range.front()); 998 continue; 999 } 1000 1001 // This is a 1:N replacement. Insert a test.cast op. (That's what the 1002 // argument materialization used to do.) 1003 flattened.push_back( 1004 rewriter 1005 .create<TestCastOp>(op->getLoc(), 1006 op->getOperand(it.index()).getType(), range) 1007 .getResult()); 1008 } 1009 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened, 1010 std::nullopt); 1011 return success(); 1012 } 1013 }; 1014 /// Replace with valid op, but simply drop the operands. This is used in a 1015 /// regression where we used to generate circular unrealized_conversion_cast 1016 /// ops. 1017 struct TestDropAndReplaceInvalidOp : public ConversionPattern { 1018 TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter) 1019 : ConversionPattern(converter, 1020 "test.drop_operands_and_replace_with_valid", 1, ctx) { 1021 } 1022 LogicalResult 1023 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1024 ConversionPatternRewriter &rewriter) const final { 1025 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(), 1026 std::nullopt); 1027 return success(); 1028 } 1029 }; 1030 /// This pattern handles the case of a split return value. 1031 struct TestSplitReturnType : public ConversionPattern { 1032 TestSplitReturnType(MLIRContext *ctx) 1033 : ConversionPattern("test.return", 1, ctx) {} 1034 LogicalResult 1035 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 1036 ConversionPatternRewriter &rewriter) const final { 1037 // Check for a return of F32. 1038 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 1039 return failure(); 1040 rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]); 1041 return success(); 1042 } 1043 }; 1044 1045 //===----------------------------------------------------------------------===// 1046 // Multi-Level Type-Conversion Rewrite Testing 1047 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 1048 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 1049 : ConversionPattern("test.type_producer", 1, ctx) {} 1050 LogicalResult 1051 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1052 ConversionPatternRewriter &rewriter) const final { 1053 // If the type is I32, change the type to F32. 1054 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 1055 return failure(); 1056 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 1057 return success(); 1058 } 1059 }; 1060 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 1061 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 1062 : ConversionPattern("test.type_producer", 1, ctx) {} 1063 LogicalResult 1064 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1065 ConversionPatternRewriter &rewriter) const final { 1066 // If the type is F32, change the type to F64. 1067 if (!Type(*op->result_type_begin()).isF32()) 1068 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 1069 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 1070 return success(); 1071 } 1072 }; 1073 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 1074 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 1075 : ConversionPattern("test.type_producer", 10, ctx) {} 1076 LogicalResult 1077 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1078 ConversionPatternRewriter &rewriter) const final { 1079 // Always convert to B16, even though it is not a legal type. This tests 1080 // that values are unmapped correctly. 1081 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 1082 return success(); 1083 } 1084 }; 1085 struct TestUpdateConsumerType : public ConversionPattern { 1086 TestUpdateConsumerType(MLIRContext *ctx) 1087 : ConversionPattern("test.type_consumer", 1, ctx) {} 1088 LogicalResult 1089 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1090 ConversionPatternRewriter &rewriter) const final { 1091 // Verify that the incoming operand has been successfully remapped to F64. 1092 if (!operands[0].getType().isF64()) 1093 return failure(); 1094 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 1095 return success(); 1096 } 1097 }; 1098 1099 //===----------------------------------------------------------------------===// 1100 // Non-Root Replacement Rewrite Testing 1101 /// This pattern generates an invalid operation, but replaces it before the 1102 /// pattern is finished. This checks that we don't need to legalize the 1103 /// temporary op. 1104 struct TestNonRootReplacement : public RewritePattern { 1105 TestNonRootReplacement(MLIRContext *ctx) 1106 : RewritePattern("test.replace_non_root", 1, ctx) {} 1107 1108 LogicalResult matchAndRewrite(Operation *op, 1109 PatternRewriter &rewriter) const final { 1110 auto resultType = *op->result_type_begin(); 1111 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 1112 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 1113 1114 rewriter.replaceOp(illegalOp, legalOp); 1115 rewriter.replaceOp(op, illegalOp); 1116 return success(); 1117 } 1118 }; 1119 1120 //===----------------------------------------------------------------------===// 1121 // Recursive Rewrite Testing 1122 /// This pattern is applied to the same operation multiple times, but has a 1123 /// bounded recursion. 1124 struct TestBoundedRecursiveRewrite 1125 : public OpRewritePattern<TestRecursiveRewriteOp> { 1126 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 1127 1128 void initialize() { 1129 // The conversion target handles bounding the recursion of this pattern. 1130 setHasBoundedRewriteRecursion(); 1131 } 1132 1133 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 1134 PatternRewriter &rewriter) const final { 1135 // Decrement the depth of the op in-place. 1136 rewriter.modifyOpInPlace(op, [&] { 1137 op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); 1138 }); 1139 return success(); 1140 } 1141 }; 1142 1143 struct TestNestedOpCreationUndoRewrite 1144 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 1145 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 1146 1147 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 1148 PatternRewriter &rewriter) const final { 1149 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 1150 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 1151 return success(); 1152 }; 1153 }; 1154 1155 // This pattern matches `test.blackhole` and delete this op and its producer. 1156 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 1157 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 1158 1159 LogicalResult matchAndRewrite(BlackHoleOp op, 1160 PatternRewriter &rewriter) const final { 1161 Operation *producer = op.getOperand().getDefiningOp(); 1162 // Always erase the user before the producer, the framework should handle 1163 // this correctly. 1164 rewriter.eraseOp(op); 1165 rewriter.eraseOp(producer); 1166 return success(); 1167 }; 1168 }; 1169 1170 // This pattern replaces explicitly illegal op with explicitly legal op, 1171 // but in addition creates unregistered operation. 1172 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { 1173 using OpRewritePattern<ILLegalOpG>::OpRewritePattern; 1174 1175 LogicalResult matchAndRewrite(ILLegalOpG op, 1176 PatternRewriter &rewriter) const final { 1177 IntegerAttr attr = rewriter.getI32IntegerAttr(0); 1178 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); 1179 rewriter.replaceOpWithNewOp<LegalOpC>(op, val); 1180 return success(); 1181 }; 1182 }; 1183 1184 class TestEraseOp : public ConversionPattern { 1185 public: 1186 TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {} 1187 LogicalResult 1188 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1189 ConversionPatternRewriter &rewriter) const final { 1190 // Erase op without replacements. 1191 rewriter.eraseOp(op); 1192 return success(); 1193 } 1194 }; 1195 1196 /// This pattern matches a test.duplicate_block_args op and duplicates all 1197 /// block arguments. 1198 class TestDuplicateBlockArgs 1199 : public OpConversionPattern<DuplicateBlockArgsOp> { 1200 using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern; 1201 1202 LogicalResult 1203 matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor, 1204 ConversionPatternRewriter &rewriter) const override { 1205 if (op.getIsLegal()) 1206 return failure(); 1207 rewriter.startOpModification(op); 1208 Block *body = &op.getBody().front(); 1209 TypeConverter::SignatureConversion result(body->getNumArguments()); 1210 for (auto it : llvm::enumerate(body->getArgumentTypes())) 1211 result.addInputs(it.index(), {it.value(), it.value()}); 1212 rewriter.applySignatureConversion(body, result, getTypeConverter()); 1213 op.setIsLegal(true); 1214 rewriter.finalizeOpModification(op); 1215 return success(); 1216 } 1217 }; 1218 1219 /// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid 1220 /// op. The pattern supports 1:N replacements and forwards the replacement 1221 /// values of the single operand as test.valid operands. 1222 class TestRepetitive1ToNConsumer : public ConversionPattern { 1223 public: 1224 TestRepetitive1ToNConsumer(MLIRContext *ctx) 1225 : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {} 1226 LogicalResult 1227 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 1228 ConversionPatternRewriter &rewriter) const final { 1229 // A single operand is expected. 1230 if (op->getNumOperands() != 1) 1231 return failure(); 1232 rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front()); 1233 return success(); 1234 } 1235 }; 1236 1237 /// A pattern that tests two back-to-back 1 -> 2 op replacements. 1238 class TestMultiple1ToNReplacement : public ConversionPattern { 1239 public: 1240 TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter) 1241 : ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1, 1242 ctx) {} 1243 LogicalResult 1244 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 1245 ConversionPatternRewriter &rewriter) const final { 1246 // Helper function that replaces the given op with a new op of the given 1247 // name and doubles each result (1 -> 2 replacement of each result). 1248 auto replaceWithDoubleResults = [&](Operation *op, StringRef name) { 1249 SmallVector<Type> types; 1250 for (Type t : op->getResultTypes()) { 1251 types.push_back(t); 1252 types.push_back(t); 1253 } 1254 OperationState state(op->getLoc(), name, 1255 /*operands=*/{}, types, op->getAttrs()); 1256 auto *newOp = rewriter.create(state); 1257 SmallVector<ValueRange> repls; 1258 for (size_t i = 0, e = op->getNumResults(); i < e; ++i) 1259 repls.push_back(newOp->getResults().slice(2 * i, 2)); 1260 rewriter.replaceOpWithMultiple(op, repls); 1261 return newOp; 1262 }; 1263 1264 // Replace test.multiple_1_to_n_replacement with test.step_1. 1265 Operation *repl1 = replaceWithDoubleResults(op, "test.step_1"); 1266 // Now replace test.step_1 with test.legal_op. 1267 // TODO: Ideally, it should not be necessary to reset the insertion point 1268 // here. Based on the API calls, it looks like test.step_1 is entirely 1269 // erased. But that's not the case: an argument materialization will 1270 // survive. And that argument materialization will be used by the users of 1271 // `op`. If we don't reset the insertion point here, we get dominance 1272 // errors. This will be fixed when we have 1:N support in the conversion 1273 // value mapping. 1274 rewriter.setInsertionPoint(repl1); 1275 replaceWithDoubleResults(repl1, "test.legal_op"); 1276 return success(); 1277 } 1278 }; 1279 1280 } // namespace 1281 1282 namespace { 1283 struct TestTypeConverter : public TypeConverter { 1284 using TypeConverter::TypeConverter; 1285 TestTypeConverter() { 1286 addConversion(convertType); 1287 addArgumentMaterialization(materializeCast); 1288 addSourceMaterialization(materializeCast); 1289 } 1290 1291 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 1292 // Drop I16 types. 1293 if (t.isSignlessInteger(16)) 1294 return success(); 1295 1296 // Convert I64 to F64. 1297 if (t.isSignlessInteger(64)) { 1298 results.push_back(FloatType::getF64(t.getContext())); 1299 return success(); 1300 } 1301 1302 // Convert I42 to I43. 1303 if (t.isInteger(42)) { 1304 results.push_back(IntegerType::get(t.getContext(), 43)); 1305 return success(); 1306 } 1307 1308 // Split F32 into F16,F16. 1309 if (t.isF32()) { 1310 results.assign(2, FloatType::getF16(t.getContext())); 1311 return success(); 1312 } 1313 1314 // Drop I24 types. 1315 if (t.isInteger(24)) { 1316 return success(); 1317 } 1318 1319 // Otherwise, convert the type directly. 1320 results.push_back(t); 1321 return success(); 1322 } 1323 1324 /// Hook for materializing a conversion. This is necessary because we generate 1325 /// 1->N type mappings. 1326 static Value materializeCast(OpBuilder &builder, Type resultType, 1327 ValueRange inputs, Location loc) { 1328 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1329 } 1330 }; 1331 1332 struct TestLegalizePatternDriver 1333 : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> { 1334 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) 1335 1336 StringRef getArgument() const final { return "test-legalize-patterns"; } 1337 StringRef getDescription() const final { 1338 return "Run test dialect legalization patterns"; 1339 } 1340 /// The mode of conversion to use with the driver. 1341 enum class ConversionMode { Analysis, Full, Partial }; 1342 1343 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 1344 1345 void getDependentDialects(DialectRegistry ®istry) const override { 1346 registry.insert<func::FuncDialect, test::TestDialect>(); 1347 } 1348 1349 void runOnOperation() override { 1350 TestTypeConverter converter; 1351 mlir::RewritePatternSet patterns(&getContext()); 1352 populateWithGenerated(patterns); 1353 patterns 1354 .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion, 1355 TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock, 1356 TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType, 1357 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 1358 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 1359 TestNonRootReplacement, TestBoundedRecursiveRewrite, 1360 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, 1361 TestCreateUnregisteredOp, TestUndoMoveOpBefore, 1362 TestUndoPropertiesModification, TestEraseOp, 1363 TestRepetitive1ToNConsumer>(&getContext()); 1364 patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp, 1365 TestPassthroughInvalidOp, TestMultiple1ToNReplacement>( 1366 &getContext(), converter); 1367 patterns.add<TestDuplicateBlockArgs>(converter, &getContext()); 1368 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1369 converter); 1370 mlir::populateCallOpTypeConversionPattern(patterns, converter); 1371 1372 // Define the conversion target used for the test. 1373 ConversionTarget target(getContext()); 1374 target.addLegalOp<ModuleOp>(); 1375 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, 1376 TerminatorOp, OneRegionOp>(); 1377 target.addLegalOp(OperationName("test.legal_op", &getContext())); 1378 target 1379 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 1380 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 1381 // Don't allow F32 operands. 1382 return llvm::none_of(op.getOperandTypes(), 1383 [](Type type) { return type.isF32(); }); 1384 }); 1385 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1386 return converter.isSignatureLegal(op.getFunctionType()) && 1387 converter.isLegal(&op.getBody()); 1388 }); 1389 target.addDynamicallyLegalOp<func::CallOp>( 1390 [&](func::CallOp op) { return converter.isLegal(op); }); 1391 1392 // TestCreateUnregisteredOp creates `arith.constant` operation, 1393 // which was not added to target intentionally to test 1394 // correct error code from conversion driver. 1395 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); 1396 1397 // Expect the type_producer/type_consumer operations to only operate on f64. 1398 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1399 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1400 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1401 return op.getOperand().getType().isF64(); 1402 }); 1403 1404 // Check support for marking certain operations as recursively legal. 1405 target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) { 1406 return static_cast<bool>( 1407 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 1408 }); 1409 1410 // Mark the bound recursion operation as dynamically legal. 1411 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 1412 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); 1413 1414 // Create a dynamically legal rule that can only be legalized by folding it. 1415 target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>( 1416 [](TestOpInPlaceSelfFold op) { return op.getFolded(); }); 1417 1418 target.addDynamicallyLegalOp<DuplicateBlockArgsOp>( 1419 [](DuplicateBlockArgsOp op) { return op.getIsLegal(); }); 1420 1421 // Handle a partial conversion. 1422 if (mode == ConversionMode::Partial) { 1423 DenseSet<Operation *> unlegalizedOps; 1424 ConversionConfig config; 1425 DumpNotifications dumpNotifications; 1426 config.listener = &dumpNotifications; 1427 config.unlegalizedOps = &unlegalizedOps; 1428 if (failed(applyPartialConversion(getOperation(), target, 1429 std::move(patterns), config))) { 1430 getOperation()->emitRemark() << "applyPartialConversion failed"; 1431 } 1432 // Emit remarks for each legalizable operation. 1433 for (auto *op : unlegalizedOps) 1434 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1435 return; 1436 } 1437 1438 // Handle a full conversion. 1439 if (mode == ConversionMode::Full) { 1440 // Check support for marking unknown operations as dynamically legal. 1441 target.markUnknownOpDynamicallyLegal([](Operation *op) { 1442 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 1443 }); 1444 1445 ConversionConfig config; 1446 DumpNotifications dumpNotifications; 1447 config.listener = &dumpNotifications; 1448 if (failed(applyFullConversion(getOperation(), target, 1449 std::move(patterns), config))) { 1450 getOperation()->emitRemark() << "applyFullConversion failed"; 1451 } 1452 return; 1453 } 1454 1455 // Otherwise, handle an analysis conversion. 1456 assert(mode == ConversionMode::Analysis); 1457 1458 // Analyze the convertible operations. 1459 DenseSet<Operation *> legalizedOps; 1460 ConversionConfig config; 1461 config.legalizableOps = &legalizedOps; 1462 if (failed(applyAnalysisConversion(getOperation(), target, 1463 std::move(patterns), config))) 1464 return signalPassFailure(); 1465 1466 // Emit remarks for each legalizable operation. 1467 for (auto *op : legalizedOps) 1468 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 1469 } 1470 1471 /// The mode of conversion to use. 1472 ConversionMode mode; 1473 }; 1474 } // namespace 1475 1476 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 1477 legalizerConversionMode( 1478 "test-legalize-mode", 1479 llvm::cl::desc("The legalization mode to use with the test driver"), 1480 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 1481 llvm::cl::values( 1482 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 1483 "analysis", "Perform an analysis conversion"), 1484 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 1485 "Perform a full conversion"), 1486 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 1487 "partial", "Perform a partial conversion"))); 1488 1489 //===----------------------------------------------------------------------===// 1490 // ConversionPatternRewriter::getRemappedValue testing. This method is used 1491 // to get the remapped value of an original value that was replaced using 1492 // ConversionPatternRewriter. 1493 namespace { 1494 struct TestRemapValueTypeConverter : public TypeConverter { 1495 using TypeConverter::TypeConverter; 1496 1497 TestRemapValueTypeConverter() { 1498 addConversion( 1499 [](Float32Type type) { return Float64Type::get(type.getContext()); }); 1500 addConversion([](Type type) { return type; }); 1501 } 1502 }; 1503 1504 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 1505 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 1506 /// operand twice. 1507 /// 1508 /// Example: 1509 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 1510 /// is replaced with: 1511 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 1512 struct OneVResOneVOperandOp1Converter 1513 : public OpConversionPattern<OneVResOneVOperandOp1> { 1514 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 1515 1516 LogicalResult 1517 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, 1518 ConversionPatternRewriter &rewriter) const override { 1519 auto origOps = op.getOperands(); 1520 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 1521 "One operand expected"); 1522 Value origOp = *origOps.begin(); 1523 SmallVector<Value, 2> remappedOperands; 1524 // Replicate the remapped original operand twice. Note that we don't used 1525 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 1526 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1527 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1528 1529 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 1530 remappedOperands); 1531 return success(); 1532 } 1533 }; 1534 1535 /// A rewriter pattern that tests that blocks can be merged. 1536 struct TestRemapValueInRegion 1537 : public OpConversionPattern<TestRemappedValueRegionOp> { 1538 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; 1539 1540 LogicalResult 1541 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, 1542 ConversionPatternRewriter &rewriter) const final { 1543 Block &block = op.getBody().front(); 1544 Operation *terminator = block.getTerminator(); 1545 1546 // Merge the block into the parent region. 1547 Block *parentBlock = op->getBlock(); 1548 Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator()); 1549 rewriter.mergeBlocks(&block, parentBlock, ValueRange()); 1550 rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange()); 1551 1552 // Replace the results of this operation with the remapped terminator 1553 // values. 1554 SmallVector<Value> terminatorOperands; 1555 if (failed(rewriter.getRemappedValues(terminator->getOperands(), 1556 terminatorOperands))) 1557 return failure(); 1558 1559 rewriter.eraseOp(terminator); 1560 rewriter.replaceOp(op, terminatorOperands); 1561 return success(); 1562 } 1563 }; 1564 1565 struct TestRemappedValue 1566 : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> { 1567 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue) 1568 1569 StringRef getArgument() const final { return "test-remapped-value"; } 1570 StringRef getDescription() const final { 1571 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1572 } 1573 void runOnOperation() override { 1574 TestRemapValueTypeConverter typeConverter; 1575 1576 mlir::RewritePatternSet patterns(&getContext()); 1577 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 1578 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( 1579 &getContext()); 1580 patterns.add<TestRemapValueInRegion>(typeConverter, &getContext()); 1581 1582 mlir::ConversionTarget target(getContext()); 1583 target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>(); 1584 1585 // Expect the type_producer/type_consumer operations to only operate on f64. 1586 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1587 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1588 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1589 return op.getOperand().getType().isF64(); 1590 }); 1591 1592 // We make OneVResOneVOperandOp1 legal only when it has more that one 1593 // operand. This will trigger the conversion that will replace one-operand 1594 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 1595 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 1596 [](Operation *op) { return op->getNumOperands() > 1; }); 1597 1598 if (failed(mlir::applyFullConversion(getOperation(), target, 1599 std::move(patterns)))) { 1600 signalPassFailure(); 1601 } 1602 } 1603 }; 1604 } // namespace 1605 1606 //===----------------------------------------------------------------------===// 1607 // Test patterns without a specific root operation kind 1608 //===----------------------------------------------------------------------===// 1609 1610 namespace { 1611 /// This pattern matches and removes any operation in the test dialect. 1612 struct RemoveTestDialectOps : public RewritePattern { 1613 RemoveTestDialectOps(MLIRContext *context) 1614 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 1615 1616 LogicalResult matchAndRewrite(Operation *op, 1617 PatternRewriter &rewriter) const override { 1618 if (!isa<TestDialect>(op->getDialect())) 1619 return failure(); 1620 rewriter.eraseOp(op); 1621 return success(); 1622 } 1623 }; 1624 1625 struct TestUnknownRootOpDriver 1626 : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> { 1627 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver) 1628 1629 StringRef getArgument() const final { 1630 return "test-legalize-unknown-root-patterns"; 1631 } 1632 StringRef getDescription() const final { 1633 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1634 } 1635 void runOnOperation() override { 1636 mlir::RewritePatternSet patterns(&getContext()); 1637 patterns.add<RemoveTestDialectOps>(&getContext()); 1638 1639 mlir::ConversionTarget target(getContext()); 1640 target.addIllegalDialect<TestDialect>(); 1641 if (failed(applyPartialConversion(getOperation(), target, 1642 std::move(patterns)))) 1643 signalPassFailure(); 1644 } 1645 }; 1646 } // namespace 1647 1648 //===----------------------------------------------------------------------===// 1649 // Test patterns that uses operations and types defined at runtime 1650 //===----------------------------------------------------------------------===// 1651 1652 namespace { 1653 /// This pattern matches dynamic operations 'test.one_operand_two_results' and 1654 /// replace them with dynamic operations 'test.generic_dynamic_op'. 1655 struct RewriteDynamicOp : public RewritePattern { 1656 RewriteDynamicOp(MLIRContext *context) 1657 : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1, 1658 context) {} 1659 1660 LogicalResult matchAndRewrite(Operation *op, 1661 PatternRewriter &rewriter) const override { 1662 assert(op->getName().getStringRef() == 1663 "test.dynamic_one_operand_two_results" && 1664 "rewrite pattern should only match operations with the right name"); 1665 1666 OperationState state(op->getLoc(), "test.dynamic_generic", 1667 op->getOperands(), op->getResultTypes(), 1668 op->getAttrs()); 1669 auto *newOp = rewriter.create(state); 1670 rewriter.replaceOp(op, newOp->getResults()); 1671 return success(); 1672 } 1673 }; 1674 1675 struct TestRewriteDynamicOpDriver 1676 : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> { 1677 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver) 1678 1679 void getDependentDialects(DialectRegistry ®istry) const override { 1680 registry.insert<TestDialect>(); 1681 } 1682 StringRef getArgument() const final { return "test-rewrite-dynamic-op"; } 1683 StringRef getDescription() const final { 1684 return "Test rewritting on dynamic operations"; 1685 } 1686 void runOnOperation() override { 1687 RewritePatternSet patterns(&getContext()); 1688 patterns.add<RewriteDynamicOp>(&getContext()); 1689 1690 ConversionTarget target(getContext()); 1691 target.addIllegalOp( 1692 OperationName("test.dynamic_one_operand_two_results", &getContext())); 1693 target.addLegalOp(OperationName("test.dynamic_generic", &getContext())); 1694 if (failed(applyPartialConversion(getOperation(), target, 1695 std::move(patterns)))) 1696 signalPassFailure(); 1697 } 1698 }; 1699 } // end anonymous namespace 1700 1701 //===----------------------------------------------------------------------===// 1702 // Test type conversions 1703 //===----------------------------------------------------------------------===// 1704 1705 namespace { 1706 struct TestTypeConversionProducer 1707 : public OpConversionPattern<TestTypeProducerOp> { 1708 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 1709 LogicalResult 1710 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, 1711 ConversionPatternRewriter &rewriter) const final { 1712 Type resultType = op.getType(); 1713 Type convertedType = getTypeConverter() 1714 ? getTypeConverter()->convertType(resultType) 1715 : resultType; 1716 if (isa<FloatType>(resultType)) 1717 resultType = rewriter.getF64Type(); 1718 else if (resultType.isInteger(16)) 1719 resultType = rewriter.getIntegerType(64); 1720 else if (isa<test::TestRecursiveType>(resultType) && 1721 convertedType != resultType) 1722 resultType = convertedType; 1723 else 1724 return failure(); 1725 1726 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 1727 return success(); 1728 } 1729 }; 1730 1731 /// Call signature conversion and then fail the rewrite to trigger the undo 1732 /// mechanism. 1733 struct TestSignatureConversionUndo 1734 : public OpConversionPattern<TestSignatureConversionUndoOp> { 1735 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 1736 1737 LogicalResult 1738 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, 1739 ConversionPatternRewriter &rewriter) const final { 1740 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 1741 return failure(); 1742 } 1743 }; 1744 1745 /// Call signature conversion without providing a type converter to handle 1746 /// materializations. 1747 struct TestTestSignatureConversionNoConverter 1748 : public OpConversionPattern<TestSignatureConversionNoConverterOp> { 1749 TestTestSignatureConversionNoConverter(const TypeConverter &converter, 1750 MLIRContext *context) 1751 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), 1752 converter(converter) {} 1753 1754 LogicalResult 1755 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, 1756 ConversionPatternRewriter &rewriter) const final { 1757 Region ®ion = op->getRegion(0); 1758 Block *entry = ®ion.front(); 1759 1760 // Convert the original entry arguments. 1761 TypeConverter::SignatureConversion result(entry->getNumArguments()); 1762 if (failed( 1763 converter.convertSignatureArgs(entry->getArgumentTypes(), result))) 1764 return failure(); 1765 rewriter.modifyOpInPlace(op, [&] { 1766 rewriter.applySignatureConversion(®ion.front(), result); 1767 }); 1768 return success(); 1769 } 1770 1771 const TypeConverter &converter; 1772 }; 1773 1774 /// Just forward the operands to the root op. This is essentially a no-op 1775 /// pattern that is used to trigger target materialization. 1776 struct TestTypeConsumerForward 1777 : public OpConversionPattern<TestTypeConsumerOp> { 1778 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 1779 1780 LogicalResult 1781 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, 1782 ConversionPatternRewriter &rewriter) const final { 1783 rewriter.modifyOpInPlace(op, 1784 [&] { op->setOperands(adaptor.getOperands()); }); 1785 return success(); 1786 } 1787 }; 1788 1789 struct TestTypeConversionAnotherProducer 1790 : public OpRewritePattern<TestAnotherTypeProducerOp> { 1791 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 1792 1793 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 1794 PatternRewriter &rewriter) const final { 1795 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 1796 return success(); 1797 } 1798 }; 1799 1800 struct TestReplaceWithLegalOp : public ConversionPattern { 1801 TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx) 1802 : ConversionPattern(converter, "test.replace_with_legal_op", 1803 /*benefit=*/1, ctx) {} 1804 LogicalResult 1805 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1806 ConversionPatternRewriter &rewriter) const final { 1807 rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]); 1808 return success(); 1809 } 1810 }; 1811 1812 struct TestTypeConversionDriver 1813 : public PassWrapper<TestTypeConversionDriver, OperationPass<>> { 1814 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver) 1815 1816 void getDependentDialects(DialectRegistry ®istry) const override { 1817 registry.insert<TestDialect>(); 1818 } 1819 StringRef getArgument() const final { 1820 return "test-legalize-type-conversion"; 1821 } 1822 StringRef getDescription() const final { 1823 return "Test various type conversion functionalities in DialectConversion"; 1824 } 1825 1826 void runOnOperation() override { 1827 // Initialize the type converter. 1828 SmallVector<Type, 2> conversionCallStack; 1829 TypeConverter converter; 1830 1831 /// Add the legal set of type conversions. 1832 converter.addConversion([](Type type) -> Type { 1833 // Treat F64 as legal. 1834 if (type.isF64()) 1835 return type; 1836 // Allow converting BF16/F16/F32 to F64. 1837 if (type.isBF16() || type.isF16() || type.isF32()) 1838 return FloatType::getF64(type.getContext()); 1839 // Otherwise, the type is illegal. 1840 return nullptr; 1841 }); 1842 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 1843 // Drop all integer types. 1844 return success(); 1845 }); 1846 converter.addConversion( 1847 // Convert a recursive self-referring type into a non-self-referring 1848 // type named "outer_converted_type" that contains a SimpleAType. 1849 [&](test::TestRecursiveType type, 1850 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { 1851 // If the type is already converted, return it to indicate that it is 1852 // legal. 1853 if (type.getName() == "outer_converted_type") { 1854 results.push_back(type); 1855 return success(); 1856 } 1857 1858 conversionCallStack.push_back(type); 1859 auto popConversionCallStack = llvm::make_scope_exit( 1860 [&conversionCallStack]() { conversionCallStack.pop_back(); }); 1861 1862 // If the type is on the call stack more than once (it is there at 1863 // least once because of the _current_ call, which is always the last 1864 // element on the stack), we've hit the recursive case. Just return 1865 // SimpleAType here to create a non-recursive type as a result. 1866 if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(), 1867 type)) { 1868 results.push_back(test::SimpleAType::get(type.getContext())); 1869 return success(); 1870 } 1871 1872 // Convert the body recursively. 1873 auto result = test::TestRecursiveType::get(type.getContext(), 1874 "outer_converted_type"); 1875 if (failed(result.setBody(converter.convertType(type.getBody())))) 1876 return failure(); 1877 results.push_back(result); 1878 return success(); 1879 }); 1880 1881 /// Add the legal set of type materializations. 1882 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 1883 ValueRange inputs, 1884 Location loc) -> Value { 1885 // Allow casting from F64 back to F32. 1886 if (!resultType.isF16() && inputs.size() == 1 && 1887 inputs[0].getType().isF64()) 1888 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1889 // Allow producing an i32 or i64 from nothing. 1890 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 1891 inputs.empty()) 1892 return builder.create<TestTypeProducerOp>(loc, resultType); 1893 // Allow producing an i64 from an integer. 1894 if (isa<IntegerType>(resultType) && inputs.size() == 1 && 1895 isa<IntegerType>(inputs[0].getType())) 1896 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1897 // Otherwise, fail. 1898 return nullptr; 1899 }); 1900 1901 // Initialize the conversion target. 1902 mlir::ConversionTarget target(getContext()); 1903 target.addLegalOp<LegalOpD>(); 1904 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 1905 auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType()); 1906 return op.getType().isF64() || op.getType().isInteger(64) || 1907 (recursiveType && 1908 recursiveType.getName() == "outer_converted_type"); 1909 }); 1910 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1911 return converter.isSignatureLegal(op.getFunctionType()) && 1912 converter.isLegal(&op.getBody()); 1913 }); 1914 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 1915 // Allow casts from F64 to F32. 1916 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 1917 }); 1918 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>( 1919 [&](TestSignatureConversionNoConverterOp op) { 1920 return converter.isLegal(op.getRegion().front().getArgumentTypes()); 1921 }); 1922 1923 // Initialize the set of rewrite patterns. 1924 RewritePatternSet patterns(&getContext()); 1925 patterns 1926 .add<TestTypeConsumerForward, TestTypeConversionProducer, 1927 TestSignatureConversionUndo, 1928 TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>( 1929 converter, &getContext()); 1930 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 1931 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1932 converter); 1933 1934 if (failed(applyPartialConversion(getOperation(), target, 1935 std::move(patterns)))) 1936 signalPassFailure(); 1937 } 1938 }; 1939 } // namespace 1940 1941 //===----------------------------------------------------------------------===// 1942 // Test Target Materialization With No Uses 1943 //===----------------------------------------------------------------------===// 1944 1945 namespace { 1946 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> { 1947 using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern; 1948 1949 LogicalResult 1950 matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor, 1951 ConversionPatternRewriter &rewriter) const final { 1952 rewriter.replaceOp(op, adaptor.getOperands()); 1953 return success(); 1954 } 1955 }; 1956 1957 struct TestTargetMaterializationWithNoUses 1958 : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> { 1959 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 1960 TestTargetMaterializationWithNoUses) 1961 1962 StringRef getArgument() const final { 1963 return "test-target-materialization-with-no-uses"; 1964 } 1965 StringRef getDescription() const final { 1966 return "Test a special case of target materialization in DialectConversion"; 1967 } 1968 1969 void runOnOperation() override { 1970 TypeConverter converter; 1971 converter.addConversion([](Type t) { return t; }); 1972 converter.addConversion([](IntegerType intTy) -> Type { 1973 if (intTy.getWidth() == 16) 1974 return IntegerType::get(intTy.getContext(), 64); 1975 return intTy; 1976 }); 1977 converter.addTargetMaterialization( 1978 [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { 1979 return builder.create<TestCastOp>(loc, type, inputs).getResult(); 1980 }); 1981 1982 ConversionTarget target(getContext()); 1983 target.addIllegalOp<TestTypeChangerOp>(); 1984 1985 RewritePatternSet patterns(&getContext()); 1986 patterns.add<ForwardOperandPattern>(converter, &getContext()); 1987 1988 if (failed(applyPartialConversion(getOperation(), target, 1989 std::move(patterns)))) 1990 signalPassFailure(); 1991 } 1992 }; 1993 } // namespace 1994 1995 //===----------------------------------------------------------------------===// 1996 // Test Block Merging 1997 //===----------------------------------------------------------------------===// 1998 1999 namespace { 2000 /// A rewriter pattern that tests that blocks can be merged. 2001 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 2002 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 2003 2004 LogicalResult 2005 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, 2006 ConversionPatternRewriter &rewriter) const final { 2007 Block &firstBlock = op.getBody().front(); 2008 Operation *branchOp = firstBlock.getTerminator(); 2009 Block *secondBlock = &*(std::next(op.getBody().begin())); 2010 auto succOperands = branchOp->getOperands(); 2011 SmallVector<Value, 2> replacements(succOperands); 2012 rewriter.eraseOp(branchOp); 2013 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 2014 rewriter.modifyOpInPlace(op, [] {}); 2015 return success(); 2016 } 2017 }; 2018 2019 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 2020 struct TestUndoBlocksMerge : public ConversionPattern { 2021 TestUndoBlocksMerge(MLIRContext *ctx) 2022 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 2023 LogicalResult 2024 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 2025 ConversionPatternRewriter &rewriter) const final { 2026 Block &firstBlock = op->getRegion(0).front(); 2027 Operation *branchOp = firstBlock.getTerminator(); 2028 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 2029 rewriter.setInsertionPointToStart(secondBlock); 2030 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 2031 auto succOperands = branchOp->getOperands(); 2032 SmallVector<Value, 2> replacements(succOperands); 2033 rewriter.eraseOp(branchOp); 2034 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 2035 rewriter.modifyOpInPlace(op, [] {}); 2036 return success(); 2037 } 2038 }; 2039 2040 /// A rewrite mechanism to inline the body of the op into its parent, when both 2041 /// ops can have a single block. 2042 struct TestMergeSingleBlockOps 2043 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 2044 using OpConversionPattern< 2045 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 2046 2047 LogicalResult 2048 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, 2049 ConversionPatternRewriter &rewriter) const final { 2050 SingleBlockImplicitTerminatorOp parentOp = 2051 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 2052 if (!parentOp) 2053 return failure(); 2054 Block &innerBlock = op.getRegion().front(); 2055 TerminatorOp innerTerminator = 2056 cast<TerminatorOp>(innerBlock.getTerminator()); 2057 rewriter.inlineBlockBefore(&innerBlock, op); 2058 rewriter.eraseOp(innerTerminator); 2059 rewriter.eraseOp(op); 2060 return success(); 2061 } 2062 }; 2063 2064 struct TestMergeBlocksPatternDriver 2065 : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> { 2066 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver) 2067 2068 StringRef getArgument() const final { return "test-merge-blocks"; } 2069 StringRef getDescription() const final { 2070 return "Test Merging operation in ConversionPatternRewriter"; 2071 } 2072 void runOnOperation() override { 2073 MLIRContext *context = &getContext(); 2074 mlir::RewritePatternSet patterns(context); 2075 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 2076 context); 2077 ConversionTarget target(*context); 2078 target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 2079 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 2080 target.addIllegalOp<ILLegalOpF>(); 2081 2082 /// Expect the op to have a single block after legalization. 2083 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 2084 [&](TestMergeBlocksOp op) -> bool { 2085 return llvm::hasSingleElement(op.getBody()); 2086 }); 2087 2088 /// Only allow `test.br` within test.merge_blocks op. 2089 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 2090 return op->getParentOfType<TestMergeBlocksOp>(); 2091 }); 2092 2093 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 2094 /// inlined. 2095 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 2096 [&](SingleBlockImplicitTerminatorOp op) -> bool { 2097 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 2098 }); 2099 2100 DenseSet<Operation *> unlegalizedOps; 2101 ConversionConfig config; 2102 config.unlegalizedOps = &unlegalizedOps; 2103 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 2104 config); 2105 for (auto *op : unlegalizedOps) 2106 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 2107 } 2108 }; 2109 } // namespace 2110 2111 //===----------------------------------------------------------------------===// 2112 // Test Selective Replacement 2113 //===----------------------------------------------------------------------===// 2114 2115 namespace { 2116 /// A rewrite mechanism to inline the body of the op into its parent, when both 2117 /// ops can have a single block. 2118 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 2119 using OpRewritePattern<TestCastOp>::OpRewritePattern; 2120 2121 LogicalResult matchAndRewrite(TestCastOp op, 2122 PatternRewriter &rewriter) const final { 2123 if (op.getNumOperands() != 2) 2124 return failure(); 2125 OperandRange operands = op.getOperands(); 2126 2127 // Replace non-terminator uses with the first operand. 2128 rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) { 2129 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 2130 }); 2131 // Replace everything else with the second operand if the operation isn't 2132 // dead. 2133 rewriter.replaceOp(op, op.getOperand(1)); 2134 return success(); 2135 } 2136 }; 2137 2138 struct TestSelectiveReplacementPatternDriver 2139 : public PassWrapper<TestSelectiveReplacementPatternDriver, 2140 OperationPass<>> { 2141 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 2142 TestSelectiveReplacementPatternDriver) 2143 2144 StringRef getArgument() const final { 2145 return "test-pattern-selective-replacement"; 2146 } 2147 StringRef getDescription() const final { 2148 return "Test selective replacement in the PatternRewriter"; 2149 } 2150 void runOnOperation() override { 2151 MLIRContext *context = &getContext(); 2152 mlir::RewritePatternSet patterns(context); 2153 patterns.add<TestSelectiveOpReplacementPattern>(context); 2154 (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 2155 } 2156 }; 2157 } // namespace 2158 2159 //===----------------------------------------------------------------------===// 2160 // PassRegistration 2161 //===----------------------------------------------------------------------===// 2162 2163 namespace mlir { 2164 namespace test { 2165 void registerPatternsTestPass() { 2166 PassRegistration<TestReturnTypeDriver>(); 2167 2168 PassRegistration<TestDerivedAttributeDriver>(); 2169 2170 PassRegistration<TestGreedyPatternDriver>(); 2171 PassRegistration<TestStrictPatternDriver>(); 2172 PassRegistration<TestWalkPatternDriver>(); 2173 2174 PassRegistration<TestLegalizePatternDriver>([] { 2175 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 2176 }); 2177 2178 PassRegistration<TestRemappedValue>(); 2179 2180 PassRegistration<TestUnknownRootOpDriver>(); 2181 2182 PassRegistration<TestTypeConversionDriver>(); 2183 PassRegistration<TestTargetMaterializationWithNoUses>(); 2184 2185 PassRegistration<TestRewriteDynamicOpDriver>(); 2186 2187 PassRegistration<TestMergeBlocksPatternDriver>(); 2188 PassRegistration<TestSelectiveReplacementPatternDriver>(); 2189 } 2190 } // namespace test 2191 } // namespace mlir 2192