1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestOps.h" 11 #include "TestTypes.h" 12 #include "mlir/Dialect/Arith/IR/Arith.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/BuiltinAttributes.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/Visitors.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "mlir/Transforms/FoldUtils.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 #include "mlir/Transforms/WalkPatternRewriteDriver.h" 24 #include "llvm/ADT/ScopeExit.h" 25 #include <cstdint> 26 27 using namespace mlir; 28 using namespace test; 29 30 // Native function for testing NativeCodeCall 31 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 32 return choice.getValue() ? input1 : input2; 33 } 34 35 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 36 rewriter.create<OpI>(loc, input); 37 } 38 39 static void handleNoResultOp(PatternRewriter &rewriter, 40 OpSymbolBindingNoResult op) { 41 // Turn the no result op to a one-result op. 42 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), 43 op.getOperand()); 44 } 45 46 static bool getFirstI32Result(Operation *op, Value &value) { 47 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 48 return false; 49 value = op->getResult(0); 50 return true; 51 } 52 53 static Value bindNativeCodeCallResult(Value value) { return value; } 54 55 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 56 Value input2) { 57 return SmallVector<Value, 2>({input2, input1}); 58 } 59 60 // Test that natives calls are only called once during rewrites. 61 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 62 // This let us check the number of times OpM_Test was called by inspecting 63 // the returned value in the MLIR output. 64 static int64_t opMIncreasingValue = 314159265; 65 static Attribute opMTest(PatternRewriter &rewriter, Value val) { 66 int64_t i = opMIncreasingValue++; 67 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 68 } 69 70 namespace { 71 #include "TestPatterns.inc" 72 } // namespace 73 74 //===----------------------------------------------------------------------===// 75 // Test Reduce Pattern Interface 76 //===----------------------------------------------------------------------===// 77 78 void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 79 populateWithGenerated(patterns); 80 } 81 82 //===----------------------------------------------------------------------===// 83 // Canonicalizer Driver. 84 //===----------------------------------------------------------------------===// 85 86 namespace { 87 struct FoldingPattern : public RewritePattern { 88 public: 89 FoldingPattern(MLIRContext *context) 90 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 91 /*benefit=*/1, context) {} 92 93 LogicalResult matchAndRewrite(Operation *op, 94 PatternRewriter &rewriter) const override { 95 // Exercise createOrFold API for a single-result operation that is folded 96 // upon construction. The operation being created has an in-place folder, 97 // and it should be still present in the output. Furthermore, the folder 98 // should not crash when attempting to recover the (unchanged) operation 99 // result. 100 Value result = rewriter.createOrFold<TestOpInPlaceFold>( 101 op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0)); 102 assert(result); 103 rewriter.replaceOp(op, result); 104 return success(); 105 } 106 }; 107 108 /// This pattern creates a foldable operation at the entry point of the block. 109 /// This tests the situation where the operation folder will need to replace an 110 /// operation with a previously created constant that does not initially 111 /// dominate the operation to replace. 112 struct FolderInsertBeforePreviouslyFoldedConstantPattern 113 : public OpRewritePattern<TestCastOp> { 114 public: 115 using OpRewritePattern<TestCastOp>::OpRewritePattern; 116 117 LogicalResult matchAndRewrite(TestCastOp op, 118 PatternRewriter &rewriter) const override { 119 if (!op->hasAttr("test_fold_before_previously_folded_op")) 120 return failure(); 121 rewriter.setInsertionPointToStart(op->getBlock()); 122 123 auto constOp = rewriter.create<arith::ConstantOp>( 124 op.getLoc(), rewriter.getBoolAttr(true)); 125 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 126 Value(constOp)); 127 return success(); 128 } 129 }; 130 131 /// This pattern matches test.op_commutative2 with the first operand being 132 /// another test.op_commutative2 with a constant on the right side and fold it 133 /// away by propagating it as its result. This is intend to check that patterns 134 /// are applied after the commutative property moves constant to the right. 135 struct FolderCommutativeOp2WithConstant 136 : public OpRewritePattern<TestCommutative2Op> { 137 public: 138 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; 139 140 LogicalResult matchAndRewrite(TestCommutative2Op op, 141 PatternRewriter &rewriter) const override { 142 auto operand = 143 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); 144 if (!operand) 145 return failure(); 146 Attribute constInput; 147 if (!matchPattern(operand->getOperand(1), m_Constant(&constInput))) 148 return failure(); 149 rewriter.replaceOp(op, operand->getOperand(1)); 150 return success(); 151 } 152 }; 153 154 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer 155 /// attribute with value smaller than MaxVal, it increments the value by 1. 156 template <int MaxVal> 157 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> { 158 using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern; 159 160 LogicalResult matchAndRewrite(AnyAttrOfOp op, 161 PatternRewriter &rewriter) const override { 162 auto intAttr = dyn_cast<IntegerAttr>(op.getAttr()); 163 if (!intAttr) 164 return failure(); 165 int64_t val = intAttr.getInt(); 166 if (val >= MaxVal) 167 return failure(); 168 rewriter.modifyOpInPlace( 169 op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); }); 170 return success(); 171 } 172 }; 173 174 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op". 175 struct MakeOpEligible : public RewritePattern { 176 MakeOpEligible(MLIRContext *context) 177 : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {} 178 179 LogicalResult matchAndRewrite(Operation *op, 180 PatternRewriter &rewriter) const override { 181 if (op->hasAttr("eligible")) 182 return failure(); 183 rewriter.modifyOpInPlace( 184 op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); }); 185 return success(); 186 } 187 }; 188 189 /// This pattern hoists eligible ops out of a "test.one_region_op". 190 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> { 191 using OpRewritePattern<test::OneRegionOp>::OpRewritePattern; 192 193 LogicalResult matchAndRewrite(test::OneRegionOp op, 194 PatternRewriter &rewriter) const override { 195 Operation *terminator = op.getRegion().front().getTerminator(); 196 Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp(); 197 if (toBeHoisted->getParentOp() != op) 198 return failure(); 199 if (!toBeHoisted->hasAttr("eligible")) 200 return failure(); 201 rewriter.moveOpBefore(toBeHoisted, op); 202 return success(); 203 } 204 }; 205 206 /// This pattern moves "test.move_before_parent_op" before the parent op. 207 struct MoveBeforeParentOp : public RewritePattern { 208 MoveBeforeParentOp(MLIRContext *context) 209 : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {} 210 211 LogicalResult matchAndRewrite(Operation *op, 212 PatternRewriter &rewriter) const override { 213 // Do not hoist past functions. 214 if (isa<FunctionOpInterface>(op->getParentOp())) 215 return failure(); 216 rewriter.moveOpBefore(op, op->getParentOp()); 217 return success(); 218 } 219 }; 220 221 /// This pattern moves "test.move_after_parent_op" after the parent op. 222 struct MoveAfterParentOp : public RewritePattern { 223 MoveAfterParentOp(MLIRContext *context) 224 : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {} 225 226 LogicalResult matchAndRewrite(Operation *op, 227 PatternRewriter &rewriter) const override { 228 // Do not hoist past functions. 229 if (isa<FunctionOpInterface>(op->getParentOp())) 230 return failure(); 231 232 int64_t moveForwardBy = 0; 233 if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance")) 234 moveForwardBy = advanceBy.getInt(); 235 236 Operation *moveAfter = op->getParentOp(); 237 for (int64_t i = 0; i < moveForwardBy; ++i) 238 moveAfter = moveAfter->getNextNode(); 239 240 rewriter.moveOpAfter(op, moveAfter); 241 return success(); 242 } 243 }; 244 245 /// This pattern inlines blocks that are nested in 246 /// "test.inline_blocks_into_parent" into the parent block. 247 struct InlineBlocksIntoParent : public RewritePattern { 248 InlineBlocksIntoParent(MLIRContext *context) 249 : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1, 250 context) {} 251 252 LogicalResult matchAndRewrite(Operation *op, 253 PatternRewriter &rewriter) const override { 254 bool changed = false; 255 for (Region &r : op->getRegions()) { 256 while (!r.empty()) { 257 rewriter.inlineBlockBefore(&r.front(), op); 258 changed = true; 259 } 260 } 261 return success(changed); 262 } 263 }; 264 265 /// This pattern splits blocks at "test.split_block_here" and replaces the op 266 /// with a new op (to prevent an infinite loop of block splitting). 267 struct SplitBlockHere : public RewritePattern { 268 SplitBlockHere(MLIRContext *context) 269 : RewritePattern("test.split_block_here", /*benefit=*/1, context) {} 270 271 LogicalResult matchAndRewrite(Operation *op, 272 PatternRewriter &rewriter) const override { 273 rewriter.splitBlock(op->getBlock(), op->getIterator()); 274 Operation *newOp = rewriter.create( 275 op->getLoc(), 276 OperationName("test.new_op", op->getContext()).getIdentifier(), 277 op->getOperands(), op->getResultTypes()); 278 rewriter.replaceOp(op, newOp); 279 return success(); 280 } 281 }; 282 283 /// This pattern clones "test.clone_me" ops. 284 struct CloneOp : public RewritePattern { 285 CloneOp(MLIRContext *context) 286 : RewritePattern("test.clone_me", /*benefit=*/1, context) {} 287 288 LogicalResult matchAndRewrite(Operation *op, 289 PatternRewriter &rewriter) const override { 290 // Do not clone already cloned ops to avoid going into an infinite loop. 291 if (op->hasAttr("was_cloned")) 292 return failure(); 293 Operation *cloned = rewriter.clone(*op); 294 cloned->setAttr("was_cloned", rewriter.getUnitAttr()); 295 return success(); 296 } 297 }; 298 299 /// This pattern clones regions of "test.clone_region_before" ops before the 300 /// parent block. 301 struct CloneRegionBeforeOp : public RewritePattern { 302 CloneRegionBeforeOp(MLIRContext *context) 303 : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {} 304 305 LogicalResult matchAndRewrite(Operation *op, 306 PatternRewriter &rewriter) const override { 307 // Do not clone already cloned ops to avoid going into an infinite loop. 308 if (op->hasAttr("was_cloned")) 309 return failure(); 310 for (Region &r : op->getRegions()) 311 rewriter.cloneRegionBefore(r, op->getBlock()); 312 op->setAttr("was_cloned", rewriter.getUnitAttr()); 313 return success(); 314 } 315 }; 316 317 /// Replace an operation may introduce the re-visiting of its users. 318 class ReplaceWithNewOp : public RewritePattern { 319 public: 320 ReplaceWithNewOp(MLIRContext *context) 321 : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {} 322 323 LogicalResult matchAndRewrite(Operation *op, 324 PatternRewriter &rewriter) const override { 325 Operation *newOp; 326 if (op->hasAttr("create_erase_op")) { 327 newOp = rewriter.create( 328 op->getLoc(), 329 OperationName("test.erase_op", op->getContext()).getIdentifier(), 330 ValueRange(), TypeRange()); 331 } else { 332 newOp = rewriter.create( 333 op->getLoc(), 334 OperationName("test.new_op", op->getContext()).getIdentifier(), 335 op->getOperands(), op->getResultTypes()); 336 } 337 // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp". 338 // A "notifyOperationReplaced" callback is triggered in either case. 339 rewriter.replaceAllOpUsesWith(op, newOp->getResults()); 340 rewriter.eraseOp(op); 341 return success(); 342 } 343 }; 344 345 /// Erases the first child block of the matched "test.erase_first_block" 346 /// operation. 347 class EraseFirstBlock : public RewritePattern { 348 public: 349 EraseFirstBlock(MLIRContext *context) 350 : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {} 351 352 LogicalResult matchAndRewrite(Operation *op, 353 PatternRewriter &rewriter) const override { 354 for (Region &r : op->getRegions()) { 355 for (Block &b : r.getBlocks()) { 356 rewriter.eraseBlock(&b); 357 return success(); 358 } 359 } 360 361 return failure(); 362 } 363 }; 364 365 struct TestGreedyPatternDriver 366 : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> { 367 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver) 368 369 TestGreedyPatternDriver() = default; 370 TestGreedyPatternDriver(const TestGreedyPatternDriver &other) 371 : PassWrapper(other) {} 372 373 StringRef getArgument() const final { return "test-greedy-patterns"; } 374 StringRef getDescription() const final { return "Run test dialect patterns"; } 375 void runOnOperation() override { 376 mlir::RewritePatternSet patterns(&getContext()); 377 populateWithGenerated(patterns); 378 379 // Verify named pattern is generated with expected name. 380 patterns.add<FoldingPattern, TestNamedPatternRule, 381 FolderInsertBeforePreviouslyFoldedConstantPattern, 382 FolderCommutativeOp2WithConstant, HoistEligibleOps, 383 MakeOpEligible>(&getContext()); 384 385 // Additional patterns for testing the GreedyPatternRewriteDriver. 386 patterns.insert<IncrementIntAttribute<3>>(&getContext()); 387 388 GreedyRewriteConfig config; 389 config.useTopDownTraversal = this->useTopDownTraversal; 390 config.maxIterations = this->maxIterations; 391 config.fold = this->fold; 392 config.cseConstants = this->cseConstants; 393 (void)applyPatternsGreedily(getOperation(), std::move(patterns), config); 394 } 395 396 Option<bool> useTopDownTraversal{ 397 *this, "top-down", 398 llvm::cl::desc("Seed the worklist in general top-down order"), 399 llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)}; 400 Option<int> maxIterations{ 401 *this, "max-iterations", 402 llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"), 403 llvm::cl::init(GreedyRewriteConfig().maxIterations)}; 404 Option<bool> fold{*this, "fold", llvm::cl::desc("Whether to fold"), 405 llvm::cl::init(GreedyRewriteConfig().fold)}; 406 Option<bool> cseConstants{*this, "cse-constants", 407 llvm::cl::desc("Whether to CSE constants"), 408 llvm::cl::init(GreedyRewriteConfig().cseConstants)}; 409 }; 410 411 struct DumpNotifications : public RewriterBase::Listener { 412 void notifyBlockInserted(Block *block, Region *previous, 413 Region::iterator previousIt) override { 414 llvm::outs() << "notifyBlockInserted"; 415 if (block->getParentOp()) { 416 llvm::outs() << " into " << block->getParentOp()->getName() << ": "; 417 } else { 418 llvm::outs() << " into unknown op: "; 419 } 420 if (previous == nullptr) { 421 llvm::outs() << "was unlinked\n"; 422 } else { 423 llvm::outs() << "was linked\n"; 424 } 425 } 426 void notifyOperationInserted(Operation *op, 427 OpBuilder::InsertPoint previous) override { 428 llvm::outs() << "notifyOperationInserted: " << op->getName(); 429 if (!previous.isSet()) { 430 llvm::outs() << ", was unlinked\n"; 431 } else { 432 if (!previous.getPoint().getNodePtr()) { 433 llvm::outs() << ", was linked, exact position unknown\n"; 434 } else if (previous.getPoint() == previous.getBlock()->end()) { 435 llvm::outs() << ", was last in block\n"; 436 } else { 437 llvm::outs() << ", previous = " << previous.getPoint()->getName() 438 << "\n"; 439 } 440 } 441 } 442 void notifyBlockErased(Block *block) override { 443 llvm::outs() << "notifyBlockErased\n"; 444 } 445 void notifyOperationErased(Operation *op) override { 446 llvm::outs() << "notifyOperationErased: " << op->getName() << "\n"; 447 } 448 void notifyOperationModified(Operation *op) override { 449 llvm::outs() << "notifyOperationModified: " << op->getName() << "\n"; 450 } 451 void notifyOperationReplaced(Operation *op, ValueRange values) override { 452 llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n"; 453 } 454 }; 455 456 struct TestStrictPatternDriver 457 : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> { 458 public: 459 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver) 460 461 TestStrictPatternDriver() = default; 462 TestStrictPatternDriver(const TestStrictPatternDriver &other) 463 : PassWrapper(other) { 464 strictMode = other.strictMode; 465 } 466 467 StringRef getArgument() const final { return "test-strict-pattern-driver"; } 468 StringRef getDescription() const final { 469 return "Test strict mode of pattern driver"; 470 } 471 472 void runOnOperation() override { 473 MLIRContext *ctx = &getContext(); 474 mlir::RewritePatternSet patterns(ctx); 475 patterns.add< 476 // clang-format off 477 ChangeBlockOp, 478 CloneOp, 479 CloneRegionBeforeOp, 480 EraseOp, 481 ImplicitChangeOp, 482 InlineBlocksIntoParent, 483 InsertSameOp, 484 MoveBeforeParentOp, 485 ReplaceWithNewOp, 486 SplitBlockHere 487 // clang-format on 488 >(ctx); 489 SmallVector<Operation *> ops; 490 getOperation()->walk([&](Operation *op) { 491 StringRef opName = op->getName().getStringRef(); 492 if (opName == "test.insert_same_op" || opName == "test.change_block_op" || 493 opName == "test.replace_with_new_op" || opName == "test.erase_op" || 494 opName == "test.move_before_parent_op" || 495 opName == "test.inline_blocks_into_parent" || 496 opName == "test.split_block_here" || opName == "test.clone_me" || 497 opName == "test.clone_region_before") { 498 ops.push_back(op); 499 } 500 }); 501 502 DumpNotifications dumpNotifications; 503 GreedyRewriteConfig config; 504 config.listener = &dumpNotifications; 505 if (strictMode == "AnyOp") { 506 config.strictMode = GreedyRewriteStrictness::AnyOp; 507 } else if (strictMode == "ExistingAndNewOps") { 508 config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; 509 } else if (strictMode == "ExistingOps") { 510 config.strictMode = GreedyRewriteStrictness::ExistingOps; 511 } else { 512 llvm_unreachable("invalid strictness option"); 513 } 514 515 // Check if these transformations introduce visiting of operations that 516 // are not in the `ops` set (The new created ops are valid). An invalid 517 // operation will trigger the assertion while processing. 518 bool changed = false; 519 bool allErased = false; 520 (void)applyOpPatternsGreedily(ArrayRef(ops), std::move(patterns), config, 521 &changed, &allErased); 522 Builder b(ctx); 523 getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed)); 524 getOperation()->setAttr("pattern_driver_all_erased", 525 b.getBoolAttr(allErased)); 526 } 527 528 Option<std::string> strictMode{ 529 *this, "strictness", 530 llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"), 531 llvm::cl::init("AnyOp")}; 532 533 private: 534 // New inserted operation is valid for further transformation. 535 class InsertSameOp : public RewritePattern { 536 public: 537 InsertSameOp(MLIRContext *context) 538 : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {} 539 540 LogicalResult matchAndRewrite(Operation *op, 541 PatternRewriter &rewriter) const override { 542 if (op->hasAttr("skip")) 543 return failure(); 544 545 Operation *newOp = 546 rewriter.create(op->getLoc(), op->getName().getIdentifier(), 547 op->getOperands(), op->getResultTypes()); 548 rewriter.modifyOpInPlace( 549 op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); }); 550 newOp->setAttr("skip", rewriter.getBoolAttr(true)); 551 552 return success(); 553 } 554 }; 555 556 // Remove an operation may introduce the re-visiting of its operands. 557 class EraseOp : public RewritePattern { 558 public: 559 EraseOp(MLIRContext *context) 560 : RewritePattern("test.erase_op", /*benefit=*/1, context) {} 561 LogicalResult matchAndRewrite(Operation *op, 562 PatternRewriter &rewriter) const override { 563 rewriter.eraseOp(op); 564 return success(); 565 } 566 }; 567 568 // The following two patterns test RewriterBase::replaceAllUsesWith. 569 // 570 // That function replaces all usages of a Block (or a Value) with another one 571 // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver 572 // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its 573 // worklist: when an op is modified, it is added to the worklist. The two 574 // patterns below make the tracking observable: ChangeBlockOp replaces all 575 // usages of a block and that pattern is applied because the corresponding ops 576 // are put on the initial worklist (see above). ImplicitChangeOp does an 577 // unrelated change but ops of the corresponding type are *not* on the initial 578 // worklist, so the effect of the second pattern is only visible if the 579 // tracking and subsequent adding to the worklist actually works. 580 581 // Replace all usages of the first successor with the second successor. 582 class ChangeBlockOp : public RewritePattern { 583 public: 584 ChangeBlockOp(MLIRContext *context) 585 : RewritePattern("test.change_block_op", /*benefit=*/1, context) {} 586 LogicalResult matchAndRewrite(Operation *op, 587 PatternRewriter &rewriter) const override { 588 if (op->getNumSuccessors() < 2) 589 return failure(); 590 Block *firstSuccessor = op->getSuccessor(0); 591 Block *secondSuccessor = op->getSuccessor(1); 592 if (firstSuccessor == secondSuccessor) 593 return failure(); 594 // This is the function being tested: 595 rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor); 596 // Using the following line instead would make the test fail: 597 // firstSuccessor->replaceAllUsesWith(secondSuccessor); 598 return success(); 599 } 600 }; 601 602 // Changes the successor to the parent block. 603 class ImplicitChangeOp : public RewritePattern { 604 public: 605 ImplicitChangeOp(MLIRContext *context) 606 : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {} 607 LogicalResult matchAndRewrite(Operation *op, 608 PatternRewriter &rewriter) const override { 609 if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock()) 610 return failure(); 611 rewriter.modifyOpInPlace(op, 612 [&]() { op->setSuccessor(op->getBlock(), 0); }); 613 return success(); 614 } 615 }; 616 }; 617 618 struct TestWalkPatternDriver final 619 : PassWrapper<TestWalkPatternDriver, OperationPass<>> { 620 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver) 621 622 TestWalkPatternDriver() = default; 623 TestWalkPatternDriver(const TestWalkPatternDriver &other) 624 : PassWrapper(other) {} 625 626 StringRef getArgument() const override { 627 return "test-walk-pattern-rewrite-driver"; 628 } 629 StringRef getDescription() const override { 630 return "Run test walk pattern rewrite driver"; 631 } 632 void runOnOperation() override { 633 mlir::RewritePatternSet patterns(&getContext()); 634 635 // Patterns for testing the WalkPatternRewriteDriver. 636 patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp, 637 MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>( 638 &getContext()); 639 640 DumpNotifications dumpListener; 641 walkAndApplyPatterns(getOperation(), std::move(patterns), 642 dumpNotifications ? &dumpListener : nullptr); 643 } 644 645 Option<bool> dumpNotifications{ 646 *this, "dump-notifications", 647 llvm::cl::desc("Print rewrite listener notifications"), 648 llvm::cl::init(false)}; 649 }; 650 651 } // namespace 652 653 //===----------------------------------------------------------------------===// 654 // ReturnType Driver. 655 //===----------------------------------------------------------------------===// 656 657 namespace { 658 // Generate ops for each instance where the type can be successfully inferred. 659 template <typename OpTy> 660 static void invokeCreateWithInferredReturnType(Operation *op) { 661 auto *context = op->getContext(); 662 auto fop = op->getParentOfType<func::FuncOp>(); 663 auto location = UnknownLoc::get(context); 664 OpBuilder b(op); 665 b.setInsertionPointAfter(op); 666 667 // Use permutations of 2 args as operands. 668 assert(fop.getNumArguments() >= 2); 669 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 670 for (int j = 0; j < e; ++j) { 671 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 672 SmallVector<Type, 2> inferredReturnTypes; 673 if (succeeded(OpTy::inferReturnTypes( 674 context, std::nullopt, values, op->getDiscardableAttrDictionary(), 675 op->getPropertiesStorage(), op->getRegions(), 676 inferredReturnTypes))) { 677 OperationState state(location, OpTy::getOperationName()); 678 // TODO: Expand to regions. 679 OpTy::build(b, state, values, op->getAttrs()); 680 (void)b.create(state); 681 } 682 } 683 } 684 } 685 686 static void reifyReturnShape(Operation *op) { 687 OpBuilder b(op); 688 689 // Use permutations of 2 args as operands. 690 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 691 SmallVector<Value, 2> shapes; 692 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 693 !llvm::hasSingleElement(shapes)) 694 return; 695 for (const auto &it : llvm::enumerate(shapes)) { 696 op->emitRemark() << "value " << it.index() << ": " 697 << it.value().getDefiningOp(); 698 } 699 } 700 701 struct TestReturnTypeDriver 702 : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> { 703 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver) 704 705 void getDependentDialects(DialectRegistry ®istry) const override { 706 registry.insert<tensor::TensorDialect>(); 707 } 708 StringRef getArgument() const final { return "test-return-type"; } 709 StringRef getDescription() const final { return "Run return type functions"; } 710 711 void runOnOperation() override { 712 if (getOperation().getName() == "testCreateFunctions") { 713 std::vector<Operation *> ops; 714 // Collect ops to avoid triggering on inserted ops. 715 for (auto &op : getOperation().getBody().front()) 716 ops.push_back(&op); 717 // Generate test patterns for each, but skip terminator. 718 for (auto *op : llvm::ArrayRef(ops).drop_back()) { 719 // Test create method of each of the Op classes below. The resultant 720 // output would be in reverse order underneath `op` from which 721 // the attributes and regions are used. 722 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 723 invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>( 724 op); 725 invokeCreateWithInferredReturnType< 726 OpWithShapedTypeInferTypeInterfaceOp>(op); 727 }; 728 return; 729 } 730 if (getOperation().getName() == "testReifyFunctions") { 731 std::vector<Operation *> ops; 732 // Collect ops to avoid triggering on inserted ops. 733 for (auto &op : getOperation().getBody().front()) 734 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 735 ops.push_back(&op); 736 // Generate test patterns for each, but skip terminator. 737 for (auto *op : ops) 738 reifyReturnShape(op); 739 } 740 } 741 }; 742 } // namespace 743 744 namespace { 745 struct TestDerivedAttributeDriver 746 : public PassWrapper<TestDerivedAttributeDriver, 747 OperationPass<func::FuncOp>> { 748 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver) 749 750 StringRef getArgument() const final { return "test-derived-attr"; } 751 StringRef getDescription() const final { 752 return "Run test derived attributes"; 753 } 754 void runOnOperation() override; 755 }; 756 } // namespace 757 758 void TestDerivedAttributeDriver::runOnOperation() { 759 getOperation().walk([](DerivedAttributeOpInterface dOp) { 760 auto dAttr = dOp.materializeDerivedAttributes(); 761 if (!dAttr) 762 return; 763 for (auto d : dAttr) 764 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); 765 }); 766 } 767 768 //===----------------------------------------------------------------------===// 769 // Legalization Driver. 770 //===----------------------------------------------------------------------===// 771 772 namespace { 773 //===----------------------------------------------------------------------===// 774 // Region-Block Rewrite Testing 775 776 /// This pattern applies a signature conversion to a block inside a detached 777 /// region. 778 struct TestDetachedSignatureConversion : public ConversionPattern { 779 TestDetachedSignatureConversion(MLIRContext *ctx) 780 : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1, 781 ctx) {} 782 783 LogicalResult 784 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 785 ConversionPatternRewriter &rewriter) const final { 786 if (op->getNumRegions() != 1) 787 return failure(); 788 OperationState state(op->getLoc(), "test.legal_op_with_region", operands, 789 op->getResultTypes(), {}, BlockRange()); 790 Region *newRegion = state.addRegion(); 791 rewriter.inlineRegionBefore(op->getRegion(0), *newRegion, 792 newRegion->begin()); 793 TypeConverter::SignatureConversion result(newRegion->getNumArguments()); 794 for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i) 795 result.addInputs(i, rewriter.getF64Type()); 796 rewriter.applySignatureConversion(&newRegion->front(), result); 797 Operation *newOp = rewriter.create(state); 798 rewriter.replaceOp(op, newOp->getResults()); 799 return success(); 800 } 801 }; 802 803 /// This pattern is a simple pattern that inlines the first region of a given 804 /// operation into the parent region. 805 struct TestRegionRewriteBlockMovement : public ConversionPattern { 806 TestRegionRewriteBlockMovement(MLIRContext *ctx) 807 : ConversionPattern("test.region", 1, ctx) {} 808 809 LogicalResult 810 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 811 ConversionPatternRewriter &rewriter) const final { 812 // Inline this region into the parent region. 813 auto &parentRegion = *op->getParentRegion(); 814 auto &opRegion = op->getRegion(0); 815 if (op->getDiscardableAttr("legalizer.should_clone")) 816 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 817 else 818 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 819 820 if (op->getDiscardableAttr("legalizer.erase_old_blocks")) { 821 while (!opRegion.empty()) 822 rewriter.eraseBlock(&opRegion.front()); 823 } 824 825 // Drop this operation. 826 rewriter.eraseOp(op); 827 return success(); 828 } 829 }; 830 /// This pattern is a simple pattern that generates a region containing an 831 /// illegal operation. 832 struct TestRegionRewriteUndo : public RewritePattern { 833 TestRegionRewriteUndo(MLIRContext *ctx) 834 : RewritePattern("test.region_builder", 1, ctx) {} 835 836 LogicalResult matchAndRewrite(Operation *op, 837 PatternRewriter &rewriter) const final { 838 // Create the region operation with an entry block containing arguments. 839 OperationState newRegion(op->getLoc(), "test.region"); 840 newRegion.addRegion(); 841 auto *regionOp = rewriter.create(newRegion); 842 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 843 entryBlock->addArgument(rewriter.getIntegerType(64), 844 rewriter.getUnknownLoc()); 845 846 // Add an explicitly illegal operation to ensure the conversion fails. 847 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 848 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 849 850 // Drop this operation. 851 rewriter.eraseOp(op); 852 return success(); 853 } 854 }; 855 /// A simple pattern that creates a block at the end of the parent region of the 856 /// matched operation. 857 struct TestCreateBlock : public RewritePattern { 858 TestCreateBlock(MLIRContext *ctx) 859 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 860 861 LogicalResult matchAndRewrite(Operation *op, 862 PatternRewriter &rewriter) const final { 863 Region ®ion = *op->getParentRegion(); 864 Type i32Type = rewriter.getIntegerType(32); 865 Location loc = op->getLoc(); 866 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 867 rewriter.create<TerminatorOp>(loc); 868 rewriter.eraseOp(op); 869 return success(); 870 } 871 }; 872 873 /// A simple pattern that creates a block containing an invalid operation in 874 /// order to trigger the block creation undo mechanism. 875 struct TestCreateIllegalBlock : public RewritePattern { 876 TestCreateIllegalBlock(MLIRContext *ctx) 877 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 878 879 LogicalResult matchAndRewrite(Operation *op, 880 PatternRewriter &rewriter) const final { 881 Region ®ion = *op->getParentRegion(); 882 Type i32Type = rewriter.getIntegerType(32); 883 Location loc = op->getLoc(); 884 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 885 // Create an illegal op to ensure the conversion fails. 886 rewriter.create<ILLegalOpF>(loc, i32Type); 887 rewriter.create<TerminatorOp>(loc); 888 rewriter.eraseOp(op); 889 return success(); 890 } 891 }; 892 893 /// A simple pattern that tests the undo mechanism when replacing the uses of a 894 /// block argument. 895 struct TestUndoBlockArgReplace : public ConversionPattern { 896 TestUndoBlockArgReplace(MLIRContext *ctx) 897 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 898 899 LogicalResult 900 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 901 ConversionPatternRewriter &rewriter) const final { 902 auto illegalOp = 903 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 904 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 905 illegalOp->getResult(0)); 906 rewriter.modifyOpInPlace(op, [] {}); 907 return success(); 908 } 909 }; 910 911 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion. 912 /// This is to test the rollback logic. 913 struct TestUndoMoveOpBefore : public ConversionPattern { 914 TestUndoMoveOpBefore(MLIRContext *ctx) 915 : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {} 916 917 LogicalResult 918 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 919 ConversionPatternRewriter &rewriter) const override { 920 rewriter.moveOpBefore(op, op->getParentOp()); 921 // Replace with an illegal op to ensure the conversion fails. 922 rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type()); 923 return success(); 924 } 925 }; 926 927 /// A rewrite pattern that tests the undo mechanism when erasing a block. 928 struct TestUndoBlockErase : public ConversionPattern { 929 TestUndoBlockErase(MLIRContext *ctx) 930 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 931 932 LogicalResult 933 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 934 ConversionPatternRewriter &rewriter) const final { 935 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 936 rewriter.setInsertionPointToStart(secondBlock); 937 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 938 rewriter.eraseBlock(secondBlock); 939 rewriter.modifyOpInPlace(op, [] {}); 940 return success(); 941 } 942 }; 943 944 /// A pattern that modifies a property in-place, but keeps the op illegal. 945 struct TestUndoPropertiesModification : public ConversionPattern { 946 TestUndoPropertiesModification(MLIRContext *ctx) 947 : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {} 948 LogicalResult 949 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 950 ConversionPatternRewriter &rewriter) const final { 951 if (!op->hasAttr("modify_inplace")) 952 return failure(); 953 rewriter.modifyOpInPlace( 954 op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); }); 955 return success(); 956 } 957 }; 958 959 //===----------------------------------------------------------------------===// 960 // Type-Conversion Rewrite Testing 961 962 /// This patterns erases a region operation that has had a type conversion. 963 struct TestDropOpSignatureConversion : public ConversionPattern { 964 TestDropOpSignatureConversion(MLIRContext *ctx, 965 const TypeConverter &converter) 966 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 967 LogicalResult 968 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 969 ConversionPatternRewriter &rewriter) const override { 970 Region ®ion = op->getRegion(0); 971 Block *entry = ®ion.front(); 972 973 // Convert the original entry arguments. 974 const TypeConverter &converter = *getTypeConverter(); 975 TypeConverter::SignatureConversion result(entry->getNumArguments()); 976 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 977 result)) || 978 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 979 return failure(); 980 981 // Convert the region signature and just drop the operation. 982 rewriter.eraseOp(op); 983 return success(); 984 } 985 }; 986 /// This pattern simply updates the operands of the given operation. 987 struct TestPassthroughInvalidOp : public ConversionPattern { 988 TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter) 989 : ConversionPattern(converter, "test.invalid", 1, ctx) {} 990 LogicalResult 991 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 992 ConversionPatternRewriter &rewriter) const final { 993 SmallVector<Value> flattened; 994 for (auto it : llvm::enumerate(operands)) { 995 ValueRange range = it.value(); 996 if (range.size() == 1) { 997 flattened.push_back(range.front()); 998 continue; 999 } 1000 1001 // This is a 1:N replacement. Insert a test.cast op. (That's what the 1002 // argument materialization used to do.) 1003 flattened.push_back( 1004 rewriter 1005 .create<TestCastOp>(op->getLoc(), 1006 op->getOperand(it.index()).getType(), range) 1007 .getResult()); 1008 } 1009 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened, 1010 std::nullopt); 1011 return success(); 1012 } 1013 }; 1014 /// Replace with valid op, but simply drop the operands. This is used in a 1015 /// regression where we used to generate circular unrealized_conversion_cast 1016 /// ops. 1017 struct TestDropAndReplaceInvalidOp : public ConversionPattern { 1018 TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter) 1019 : ConversionPattern(converter, 1020 "test.drop_operands_and_replace_with_valid", 1, ctx) { 1021 } 1022 LogicalResult 1023 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1024 ConversionPatternRewriter &rewriter) const final { 1025 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(), 1026 std::nullopt); 1027 return success(); 1028 } 1029 }; 1030 /// This pattern handles the case of a split return value. 1031 struct TestSplitReturnType : public ConversionPattern { 1032 TestSplitReturnType(MLIRContext *ctx) 1033 : ConversionPattern("test.return", 1, ctx) {} 1034 LogicalResult 1035 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 1036 ConversionPatternRewriter &rewriter) const final { 1037 // Check for a return of F32. 1038 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 1039 return failure(); 1040 rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]); 1041 return success(); 1042 } 1043 }; 1044 1045 //===----------------------------------------------------------------------===// 1046 // Multi-Level Type-Conversion Rewrite Testing 1047 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 1048 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 1049 : ConversionPattern("test.type_producer", 1, ctx) {} 1050 LogicalResult 1051 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1052 ConversionPatternRewriter &rewriter) const final { 1053 // If the type is I32, change the type to F32. 1054 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 1055 return failure(); 1056 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 1057 return success(); 1058 } 1059 }; 1060 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 1061 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 1062 : ConversionPattern("test.type_producer", 1, ctx) {} 1063 LogicalResult 1064 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1065 ConversionPatternRewriter &rewriter) const final { 1066 // If the type is F32, change the type to F64. 1067 if (!Type(*op->result_type_begin()).isF32()) 1068 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 1069 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 1070 return success(); 1071 } 1072 }; 1073 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 1074 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 1075 : ConversionPattern("test.type_producer", 10, ctx) {} 1076 LogicalResult 1077 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1078 ConversionPatternRewriter &rewriter) const final { 1079 // Always convert to B16, even though it is not a legal type. This tests 1080 // that values are unmapped correctly. 1081 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 1082 return success(); 1083 } 1084 }; 1085 struct TestUpdateConsumerType : public ConversionPattern { 1086 TestUpdateConsumerType(MLIRContext *ctx) 1087 : ConversionPattern("test.type_consumer", 1, ctx) {} 1088 LogicalResult 1089 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1090 ConversionPatternRewriter &rewriter) const final { 1091 // Verify that the incoming operand has been successfully remapped to F64. 1092 if (!operands[0].getType().isF64()) 1093 return failure(); 1094 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 1095 return success(); 1096 } 1097 }; 1098 1099 //===----------------------------------------------------------------------===// 1100 // Non-Root Replacement Rewrite Testing 1101 /// This pattern generates an invalid operation, but replaces it before the 1102 /// pattern is finished. This checks that we don't need to legalize the 1103 /// temporary op. 1104 struct TestNonRootReplacement : public RewritePattern { 1105 TestNonRootReplacement(MLIRContext *ctx) 1106 : RewritePattern("test.replace_non_root", 1, ctx) {} 1107 1108 LogicalResult matchAndRewrite(Operation *op, 1109 PatternRewriter &rewriter) const final { 1110 auto resultType = *op->result_type_begin(); 1111 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 1112 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 1113 1114 rewriter.replaceOp(illegalOp, legalOp); 1115 rewriter.replaceOp(op, illegalOp); 1116 return success(); 1117 } 1118 }; 1119 1120 //===----------------------------------------------------------------------===// 1121 // Recursive Rewrite Testing 1122 /// This pattern is applied to the same operation multiple times, but has a 1123 /// bounded recursion. 1124 struct TestBoundedRecursiveRewrite 1125 : public OpRewritePattern<TestRecursiveRewriteOp> { 1126 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 1127 1128 void initialize() { 1129 // The conversion target handles bounding the recursion of this pattern. 1130 setHasBoundedRewriteRecursion(); 1131 } 1132 1133 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 1134 PatternRewriter &rewriter) const final { 1135 // Decrement the depth of the op in-place. 1136 rewriter.modifyOpInPlace(op, [&] { 1137 op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); 1138 }); 1139 return success(); 1140 } 1141 }; 1142 1143 struct TestNestedOpCreationUndoRewrite 1144 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 1145 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 1146 1147 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 1148 PatternRewriter &rewriter) const final { 1149 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 1150 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 1151 return success(); 1152 }; 1153 }; 1154 1155 // This pattern matches `test.blackhole` and delete this op and its producer. 1156 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 1157 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 1158 1159 LogicalResult matchAndRewrite(BlackHoleOp op, 1160 PatternRewriter &rewriter) const final { 1161 Operation *producer = op.getOperand().getDefiningOp(); 1162 // Always erase the user before the producer, the framework should handle 1163 // this correctly. 1164 rewriter.eraseOp(op); 1165 rewriter.eraseOp(producer); 1166 return success(); 1167 }; 1168 }; 1169 1170 // This pattern replaces explicitly illegal op with explicitly legal op, 1171 // but in addition creates unregistered operation. 1172 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { 1173 using OpRewritePattern<ILLegalOpG>::OpRewritePattern; 1174 1175 LogicalResult matchAndRewrite(ILLegalOpG op, 1176 PatternRewriter &rewriter) const final { 1177 IntegerAttr attr = rewriter.getI32IntegerAttr(0); 1178 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); 1179 rewriter.replaceOpWithNewOp<LegalOpC>(op, val); 1180 return success(); 1181 }; 1182 }; 1183 1184 class TestEraseOp : public ConversionPattern { 1185 public: 1186 TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {} 1187 LogicalResult 1188 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1189 ConversionPatternRewriter &rewriter) const final { 1190 // Erase op without replacements. 1191 rewriter.eraseOp(op); 1192 return success(); 1193 } 1194 }; 1195 1196 /// This pattern matches a test.duplicate_block_args op and duplicates all 1197 /// block arguments. 1198 class TestDuplicateBlockArgs 1199 : public OpConversionPattern<DuplicateBlockArgsOp> { 1200 using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern; 1201 1202 LogicalResult 1203 matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor, 1204 ConversionPatternRewriter &rewriter) const override { 1205 if (op.getIsLegal()) 1206 return failure(); 1207 rewriter.startOpModification(op); 1208 Block *body = &op.getBody().front(); 1209 TypeConverter::SignatureConversion result(body->getNumArguments()); 1210 for (auto it : llvm::enumerate(body->getArgumentTypes())) 1211 result.addInputs(it.index(), {it.value(), it.value()}); 1212 rewriter.applySignatureConversion(body, result, getTypeConverter()); 1213 op.setIsLegal(true); 1214 rewriter.finalizeOpModification(op); 1215 return success(); 1216 } 1217 }; 1218 1219 /// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid 1220 /// op. The pattern supports 1:N replacements and forwards the replacement 1221 /// values of the single operand as test.valid operands. 1222 class TestRepetitive1ToNConsumer : public ConversionPattern { 1223 public: 1224 TestRepetitive1ToNConsumer(MLIRContext *ctx) 1225 : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {} 1226 LogicalResult 1227 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 1228 ConversionPatternRewriter &rewriter) const final { 1229 // A single operand is expected. 1230 if (op->getNumOperands() != 1) 1231 return failure(); 1232 rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front()); 1233 return success(); 1234 } 1235 }; 1236 1237 } // namespace 1238 1239 namespace { 1240 struct TestTypeConverter : public TypeConverter { 1241 using TypeConverter::TypeConverter; 1242 TestTypeConverter() { 1243 addConversion(convertType); 1244 addArgumentMaterialization(materializeCast); 1245 addSourceMaterialization(materializeCast); 1246 } 1247 1248 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 1249 // Drop I16 types. 1250 if (t.isSignlessInteger(16)) 1251 return success(); 1252 1253 // Convert I64 to F64. 1254 if (t.isSignlessInteger(64)) { 1255 results.push_back(FloatType::getF64(t.getContext())); 1256 return success(); 1257 } 1258 1259 // Convert I42 to I43. 1260 if (t.isInteger(42)) { 1261 results.push_back(IntegerType::get(t.getContext(), 43)); 1262 return success(); 1263 } 1264 1265 // Split F32 into F16,F16. 1266 if (t.isF32()) { 1267 results.assign(2, FloatType::getF16(t.getContext())); 1268 return success(); 1269 } 1270 1271 // Drop I24 types. 1272 if (t.isInteger(24)) { 1273 return success(); 1274 } 1275 1276 // Otherwise, convert the type directly. 1277 results.push_back(t); 1278 return success(); 1279 } 1280 1281 /// Hook for materializing a conversion. This is necessary because we generate 1282 /// 1->N type mappings. 1283 static Value materializeCast(OpBuilder &builder, Type resultType, 1284 ValueRange inputs, Location loc) { 1285 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1286 } 1287 }; 1288 1289 struct TestLegalizePatternDriver 1290 : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> { 1291 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) 1292 1293 StringRef getArgument() const final { return "test-legalize-patterns"; } 1294 StringRef getDescription() const final { 1295 return "Run test dialect legalization patterns"; 1296 } 1297 /// The mode of conversion to use with the driver. 1298 enum class ConversionMode { Analysis, Full, Partial }; 1299 1300 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 1301 1302 void getDependentDialects(DialectRegistry ®istry) const override { 1303 registry.insert<func::FuncDialect, test::TestDialect>(); 1304 } 1305 1306 void runOnOperation() override { 1307 TestTypeConverter converter; 1308 mlir::RewritePatternSet patterns(&getContext()); 1309 populateWithGenerated(patterns); 1310 patterns 1311 .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion, 1312 TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock, 1313 TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType, 1314 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 1315 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 1316 TestNonRootReplacement, TestBoundedRecursiveRewrite, 1317 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, 1318 TestCreateUnregisteredOp, TestUndoMoveOpBefore, 1319 TestUndoPropertiesModification, TestEraseOp, 1320 TestRepetitive1ToNConsumer>(&getContext()); 1321 patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp, 1322 TestPassthroughInvalidOp>(&getContext(), converter); 1323 patterns.add<TestDuplicateBlockArgs>(converter, &getContext()); 1324 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1325 converter); 1326 mlir::populateCallOpTypeConversionPattern(patterns, converter); 1327 1328 // Define the conversion target used for the test. 1329 ConversionTarget target(getContext()); 1330 target.addLegalOp<ModuleOp>(); 1331 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, 1332 TerminatorOp, OneRegionOp>(); 1333 target.addLegalOp( 1334 OperationName("test.legal_op_with_region", &getContext())); 1335 target 1336 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 1337 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 1338 // Don't allow F32 operands. 1339 return llvm::none_of(op.getOperandTypes(), 1340 [](Type type) { return type.isF32(); }); 1341 }); 1342 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1343 return converter.isSignatureLegal(op.getFunctionType()) && 1344 converter.isLegal(&op.getBody()); 1345 }); 1346 target.addDynamicallyLegalOp<func::CallOp>( 1347 [&](func::CallOp op) { return converter.isLegal(op); }); 1348 1349 // TestCreateUnregisteredOp creates `arith.constant` operation, 1350 // which was not added to target intentionally to test 1351 // correct error code from conversion driver. 1352 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); 1353 1354 // Expect the type_producer/type_consumer operations to only operate on f64. 1355 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1356 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1357 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1358 return op.getOperand().getType().isF64(); 1359 }); 1360 1361 // Check support for marking certain operations as recursively legal. 1362 target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) { 1363 return static_cast<bool>( 1364 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 1365 }); 1366 1367 // Mark the bound recursion operation as dynamically legal. 1368 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 1369 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); 1370 1371 // Create a dynamically legal rule that can only be legalized by folding it. 1372 target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>( 1373 [](TestOpInPlaceSelfFold op) { return op.getFolded(); }); 1374 1375 target.addDynamicallyLegalOp<DuplicateBlockArgsOp>( 1376 [](DuplicateBlockArgsOp op) { return op.getIsLegal(); }); 1377 1378 // Handle a partial conversion. 1379 if (mode == ConversionMode::Partial) { 1380 DenseSet<Operation *> unlegalizedOps; 1381 ConversionConfig config; 1382 DumpNotifications dumpNotifications; 1383 config.listener = &dumpNotifications; 1384 config.unlegalizedOps = &unlegalizedOps; 1385 if (failed(applyPartialConversion(getOperation(), target, 1386 std::move(patterns), config))) { 1387 getOperation()->emitRemark() << "applyPartialConversion failed"; 1388 } 1389 // Emit remarks for each legalizable operation. 1390 for (auto *op : unlegalizedOps) 1391 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1392 return; 1393 } 1394 1395 // Handle a full conversion. 1396 if (mode == ConversionMode::Full) { 1397 // Check support for marking unknown operations as dynamically legal. 1398 target.markUnknownOpDynamicallyLegal([](Operation *op) { 1399 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 1400 }); 1401 1402 ConversionConfig config; 1403 DumpNotifications dumpNotifications; 1404 config.listener = &dumpNotifications; 1405 if (failed(applyFullConversion(getOperation(), target, 1406 std::move(patterns), config))) { 1407 getOperation()->emitRemark() << "applyFullConversion failed"; 1408 } 1409 return; 1410 } 1411 1412 // Otherwise, handle an analysis conversion. 1413 assert(mode == ConversionMode::Analysis); 1414 1415 // Analyze the convertible operations. 1416 DenseSet<Operation *> legalizedOps; 1417 ConversionConfig config; 1418 config.legalizableOps = &legalizedOps; 1419 if (failed(applyAnalysisConversion(getOperation(), target, 1420 std::move(patterns), config))) 1421 return signalPassFailure(); 1422 1423 // Emit remarks for each legalizable operation. 1424 for (auto *op : legalizedOps) 1425 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 1426 } 1427 1428 /// The mode of conversion to use. 1429 ConversionMode mode; 1430 }; 1431 } // namespace 1432 1433 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 1434 legalizerConversionMode( 1435 "test-legalize-mode", 1436 llvm::cl::desc("The legalization mode to use with the test driver"), 1437 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 1438 llvm::cl::values( 1439 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 1440 "analysis", "Perform an analysis conversion"), 1441 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 1442 "Perform a full conversion"), 1443 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 1444 "partial", "Perform a partial conversion"))); 1445 1446 //===----------------------------------------------------------------------===// 1447 // ConversionPatternRewriter::getRemappedValue testing. This method is used 1448 // to get the remapped value of an original value that was replaced using 1449 // ConversionPatternRewriter. 1450 namespace { 1451 struct TestRemapValueTypeConverter : public TypeConverter { 1452 using TypeConverter::TypeConverter; 1453 1454 TestRemapValueTypeConverter() { 1455 addConversion( 1456 [](Float32Type type) { return Float64Type::get(type.getContext()); }); 1457 addConversion([](Type type) { return type; }); 1458 } 1459 }; 1460 1461 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 1462 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 1463 /// operand twice. 1464 /// 1465 /// Example: 1466 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 1467 /// is replaced with: 1468 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 1469 struct OneVResOneVOperandOp1Converter 1470 : public OpConversionPattern<OneVResOneVOperandOp1> { 1471 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 1472 1473 LogicalResult 1474 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, 1475 ConversionPatternRewriter &rewriter) const override { 1476 auto origOps = op.getOperands(); 1477 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 1478 "One operand expected"); 1479 Value origOp = *origOps.begin(); 1480 SmallVector<Value, 2> remappedOperands; 1481 // Replicate the remapped original operand twice. Note that we don't used 1482 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 1483 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1484 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1485 1486 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 1487 remappedOperands); 1488 return success(); 1489 } 1490 }; 1491 1492 /// A rewriter pattern that tests that blocks can be merged. 1493 struct TestRemapValueInRegion 1494 : public OpConversionPattern<TestRemappedValueRegionOp> { 1495 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; 1496 1497 LogicalResult 1498 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, 1499 ConversionPatternRewriter &rewriter) const final { 1500 Block &block = op.getBody().front(); 1501 Operation *terminator = block.getTerminator(); 1502 1503 // Merge the block into the parent region. 1504 Block *parentBlock = op->getBlock(); 1505 Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator()); 1506 rewriter.mergeBlocks(&block, parentBlock, ValueRange()); 1507 rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange()); 1508 1509 // Replace the results of this operation with the remapped terminator 1510 // values. 1511 SmallVector<Value> terminatorOperands; 1512 if (failed(rewriter.getRemappedValues(terminator->getOperands(), 1513 terminatorOperands))) 1514 return failure(); 1515 1516 rewriter.eraseOp(terminator); 1517 rewriter.replaceOp(op, terminatorOperands); 1518 return success(); 1519 } 1520 }; 1521 1522 struct TestRemappedValue 1523 : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> { 1524 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue) 1525 1526 StringRef getArgument() const final { return "test-remapped-value"; } 1527 StringRef getDescription() const final { 1528 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1529 } 1530 void runOnOperation() override { 1531 TestRemapValueTypeConverter typeConverter; 1532 1533 mlir::RewritePatternSet patterns(&getContext()); 1534 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 1535 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( 1536 &getContext()); 1537 patterns.add<TestRemapValueInRegion>(typeConverter, &getContext()); 1538 1539 mlir::ConversionTarget target(getContext()); 1540 target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>(); 1541 1542 // Expect the type_producer/type_consumer operations to only operate on f64. 1543 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1544 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1545 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1546 return op.getOperand().getType().isF64(); 1547 }); 1548 1549 // We make OneVResOneVOperandOp1 legal only when it has more that one 1550 // operand. This will trigger the conversion that will replace one-operand 1551 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 1552 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 1553 [](Operation *op) { return op->getNumOperands() > 1; }); 1554 1555 if (failed(mlir::applyFullConversion(getOperation(), target, 1556 std::move(patterns)))) { 1557 signalPassFailure(); 1558 } 1559 } 1560 }; 1561 } // namespace 1562 1563 //===----------------------------------------------------------------------===// 1564 // Test patterns without a specific root operation kind 1565 //===----------------------------------------------------------------------===// 1566 1567 namespace { 1568 /// This pattern matches and removes any operation in the test dialect. 1569 struct RemoveTestDialectOps : public RewritePattern { 1570 RemoveTestDialectOps(MLIRContext *context) 1571 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 1572 1573 LogicalResult matchAndRewrite(Operation *op, 1574 PatternRewriter &rewriter) const override { 1575 if (!isa<TestDialect>(op->getDialect())) 1576 return failure(); 1577 rewriter.eraseOp(op); 1578 return success(); 1579 } 1580 }; 1581 1582 struct TestUnknownRootOpDriver 1583 : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> { 1584 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver) 1585 1586 StringRef getArgument() const final { 1587 return "test-legalize-unknown-root-patterns"; 1588 } 1589 StringRef getDescription() const final { 1590 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1591 } 1592 void runOnOperation() override { 1593 mlir::RewritePatternSet patterns(&getContext()); 1594 patterns.add<RemoveTestDialectOps>(&getContext()); 1595 1596 mlir::ConversionTarget target(getContext()); 1597 target.addIllegalDialect<TestDialect>(); 1598 if (failed(applyPartialConversion(getOperation(), target, 1599 std::move(patterns)))) 1600 signalPassFailure(); 1601 } 1602 }; 1603 } // namespace 1604 1605 //===----------------------------------------------------------------------===// 1606 // Test patterns that uses operations and types defined at runtime 1607 //===----------------------------------------------------------------------===// 1608 1609 namespace { 1610 /// This pattern matches dynamic operations 'test.one_operand_two_results' and 1611 /// replace them with dynamic operations 'test.generic_dynamic_op'. 1612 struct RewriteDynamicOp : public RewritePattern { 1613 RewriteDynamicOp(MLIRContext *context) 1614 : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1, 1615 context) {} 1616 1617 LogicalResult matchAndRewrite(Operation *op, 1618 PatternRewriter &rewriter) const override { 1619 assert(op->getName().getStringRef() == 1620 "test.dynamic_one_operand_two_results" && 1621 "rewrite pattern should only match operations with the right name"); 1622 1623 OperationState state(op->getLoc(), "test.dynamic_generic", 1624 op->getOperands(), op->getResultTypes(), 1625 op->getAttrs()); 1626 auto *newOp = rewriter.create(state); 1627 rewriter.replaceOp(op, newOp->getResults()); 1628 return success(); 1629 } 1630 }; 1631 1632 struct TestRewriteDynamicOpDriver 1633 : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> { 1634 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver) 1635 1636 void getDependentDialects(DialectRegistry ®istry) const override { 1637 registry.insert<TestDialect>(); 1638 } 1639 StringRef getArgument() const final { return "test-rewrite-dynamic-op"; } 1640 StringRef getDescription() const final { 1641 return "Test rewritting on dynamic operations"; 1642 } 1643 void runOnOperation() override { 1644 RewritePatternSet patterns(&getContext()); 1645 patterns.add<RewriteDynamicOp>(&getContext()); 1646 1647 ConversionTarget target(getContext()); 1648 target.addIllegalOp( 1649 OperationName("test.dynamic_one_operand_two_results", &getContext())); 1650 target.addLegalOp(OperationName("test.dynamic_generic", &getContext())); 1651 if (failed(applyPartialConversion(getOperation(), target, 1652 std::move(patterns)))) 1653 signalPassFailure(); 1654 } 1655 }; 1656 } // end anonymous namespace 1657 1658 //===----------------------------------------------------------------------===// 1659 // Test type conversions 1660 //===----------------------------------------------------------------------===// 1661 1662 namespace { 1663 struct TestTypeConversionProducer 1664 : public OpConversionPattern<TestTypeProducerOp> { 1665 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 1666 LogicalResult 1667 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, 1668 ConversionPatternRewriter &rewriter) const final { 1669 Type resultType = op.getType(); 1670 Type convertedType = getTypeConverter() 1671 ? getTypeConverter()->convertType(resultType) 1672 : resultType; 1673 if (isa<FloatType>(resultType)) 1674 resultType = rewriter.getF64Type(); 1675 else if (resultType.isInteger(16)) 1676 resultType = rewriter.getIntegerType(64); 1677 else if (isa<test::TestRecursiveType>(resultType) && 1678 convertedType != resultType) 1679 resultType = convertedType; 1680 else 1681 return failure(); 1682 1683 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 1684 return success(); 1685 } 1686 }; 1687 1688 /// Call signature conversion and then fail the rewrite to trigger the undo 1689 /// mechanism. 1690 struct TestSignatureConversionUndo 1691 : public OpConversionPattern<TestSignatureConversionUndoOp> { 1692 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 1693 1694 LogicalResult 1695 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, 1696 ConversionPatternRewriter &rewriter) const final { 1697 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 1698 return failure(); 1699 } 1700 }; 1701 1702 /// Call signature conversion without providing a type converter to handle 1703 /// materializations. 1704 struct TestTestSignatureConversionNoConverter 1705 : public OpConversionPattern<TestSignatureConversionNoConverterOp> { 1706 TestTestSignatureConversionNoConverter(const TypeConverter &converter, 1707 MLIRContext *context) 1708 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), 1709 converter(converter) {} 1710 1711 LogicalResult 1712 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, 1713 ConversionPatternRewriter &rewriter) const final { 1714 Region ®ion = op->getRegion(0); 1715 Block *entry = ®ion.front(); 1716 1717 // Convert the original entry arguments. 1718 TypeConverter::SignatureConversion result(entry->getNumArguments()); 1719 if (failed( 1720 converter.convertSignatureArgs(entry->getArgumentTypes(), result))) 1721 return failure(); 1722 rewriter.modifyOpInPlace(op, [&] { 1723 rewriter.applySignatureConversion(®ion.front(), result); 1724 }); 1725 return success(); 1726 } 1727 1728 const TypeConverter &converter; 1729 }; 1730 1731 /// Just forward the operands to the root op. This is essentially a no-op 1732 /// pattern that is used to trigger target materialization. 1733 struct TestTypeConsumerForward 1734 : public OpConversionPattern<TestTypeConsumerOp> { 1735 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 1736 1737 LogicalResult 1738 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, 1739 ConversionPatternRewriter &rewriter) const final { 1740 rewriter.modifyOpInPlace(op, 1741 [&] { op->setOperands(adaptor.getOperands()); }); 1742 return success(); 1743 } 1744 }; 1745 1746 struct TestTypeConversionAnotherProducer 1747 : public OpRewritePattern<TestAnotherTypeProducerOp> { 1748 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 1749 1750 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 1751 PatternRewriter &rewriter) const final { 1752 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 1753 return success(); 1754 } 1755 }; 1756 1757 struct TestReplaceWithLegalOp : public ConversionPattern { 1758 TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx) 1759 : ConversionPattern(converter, "test.replace_with_legal_op", 1760 /*benefit=*/1, ctx) {} 1761 LogicalResult 1762 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1763 ConversionPatternRewriter &rewriter) const final { 1764 rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]); 1765 return success(); 1766 } 1767 }; 1768 1769 struct TestTypeConversionDriver 1770 : public PassWrapper<TestTypeConversionDriver, OperationPass<>> { 1771 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver) 1772 1773 void getDependentDialects(DialectRegistry ®istry) const override { 1774 registry.insert<TestDialect>(); 1775 } 1776 StringRef getArgument() const final { 1777 return "test-legalize-type-conversion"; 1778 } 1779 StringRef getDescription() const final { 1780 return "Test various type conversion functionalities in DialectConversion"; 1781 } 1782 1783 void runOnOperation() override { 1784 // Initialize the type converter. 1785 SmallVector<Type, 2> conversionCallStack; 1786 TypeConverter converter; 1787 1788 /// Add the legal set of type conversions. 1789 converter.addConversion([](Type type) -> Type { 1790 // Treat F64 as legal. 1791 if (type.isF64()) 1792 return type; 1793 // Allow converting BF16/F16/F32 to F64. 1794 if (type.isBF16() || type.isF16() || type.isF32()) 1795 return FloatType::getF64(type.getContext()); 1796 // Otherwise, the type is illegal. 1797 return nullptr; 1798 }); 1799 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 1800 // Drop all integer types. 1801 return success(); 1802 }); 1803 converter.addConversion( 1804 // Convert a recursive self-referring type into a non-self-referring 1805 // type named "outer_converted_type" that contains a SimpleAType. 1806 [&](test::TestRecursiveType type, 1807 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { 1808 // If the type is already converted, return it to indicate that it is 1809 // legal. 1810 if (type.getName() == "outer_converted_type") { 1811 results.push_back(type); 1812 return success(); 1813 } 1814 1815 conversionCallStack.push_back(type); 1816 auto popConversionCallStack = llvm::make_scope_exit( 1817 [&conversionCallStack]() { conversionCallStack.pop_back(); }); 1818 1819 // If the type is on the call stack more than once (it is there at 1820 // least once because of the _current_ call, which is always the last 1821 // element on the stack), we've hit the recursive case. Just return 1822 // SimpleAType here to create a non-recursive type as a result. 1823 if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(), 1824 type)) { 1825 results.push_back(test::SimpleAType::get(type.getContext())); 1826 return success(); 1827 } 1828 1829 // Convert the body recursively. 1830 auto result = test::TestRecursiveType::get(type.getContext(), 1831 "outer_converted_type"); 1832 if (failed(result.setBody(converter.convertType(type.getBody())))) 1833 return failure(); 1834 results.push_back(result); 1835 return success(); 1836 }); 1837 1838 /// Add the legal set of type materializations. 1839 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 1840 ValueRange inputs, 1841 Location loc) -> Value { 1842 // Allow casting from F64 back to F32. 1843 if (!resultType.isF16() && inputs.size() == 1 && 1844 inputs[0].getType().isF64()) 1845 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1846 // Allow producing an i32 or i64 from nothing. 1847 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 1848 inputs.empty()) 1849 return builder.create<TestTypeProducerOp>(loc, resultType); 1850 // Allow producing an i64 from an integer. 1851 if (isa<IntegerType>(resultType) && inputs.size() == 1 && 1852 isa<IntegerType>(inputs[0].getType())) 1853 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1854 // Otherwise, fail. 1855 return nullptr; 1856 }); 1857 1858 // Initialize the conversion target. 1859 mlir::ConversionTarget target(getContext()); 1860 target.addLegalOp<LegalOpD>(); 1861 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 1862 auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType()); 1863 return op.getType().isF64() || op.getType().isInteger(64) || 1864 (recursiveType && 1865 recursiveType.getName() == "outer_converted_type"); 1866 }); 1867 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1868 return converter.isSignatureLegal(op.getFunctionType()) && 1869 converter.isLegal(&op.getBody()); 1870 }); 1871 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 1872 // Allow casts from F64 to F32. 1873 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 1874 }); 1875 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>( 1876 [&](TestSignatureConversionNoConverterOp op) { 1877 return converter.isLegal(op.getRegion().front().getArgumentTypes()); 1878 }); 1879 1880 // Initialize the set of rewrite patterns. 1881 RewritePatternSet patterns(&getContext()); 1882 patterns 1883 .add<TestTypeConsumerForward, TestTypeConversionProducer, 1884 TestSignatureConversionUndo, 1885 TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>( 1886 converter, &getContext()); 1887 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 1888 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1889 converter); 1890 1891 if (failed(applyPartialConversion(getOperation(), target, 1892 std::move(patterns)))) 1893 signalPassFailure(); 1894 } 1895 }; 1896 } // namespace 1897 1898 //===----------------------------------------------------------------------===// 1899 // Test Target Materialization With No Uses 1900 //===----------------------------------------------------------------------===// 1901 1902 namespace { 1903 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> { 1904 using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern; 1905 1906 LogicalResult 1907 matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor, 1908 ConversionPatternRewriter &rewriter) const final { 1909 rewriter.replaceOp(op, adaptor.getOperands()); 1910 return success(); 1911 } 1912 }; 1913 1914 struct TestTargetMaterializationWithNoUses 1915 : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> { 1916 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 1917 TestTargetMaterializationWithNoUses) 1918 1919 StringRef getArgument() const final { 1920 return "test-target-materialization-with-no-uses"; 1921 } 1922 StringRef getDescription() const final { 1923 return "Test a special case of target materialization in DialectConversion"; 1924 } 1925 1926 void runOnOperation() override { 1927 TypeConverter converter; 1928 converter.addConversion([](Type t) { return t; }); 1929 converter.addConversion([](IntegerType intTy) -> Type { 1930 if (intTy.getWidth() == 16) 1931 return IntegerType::get(intTy.getContext(), 64); 1932 return intTy; 1933 }); 1934 converter.addTargetMaterialization( 1935 [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { 1936 return builder.create<TestCastOp>(loc, type, inputs).getResult(); 1937 }); 1938 1939 ConversionTarget target(getContext()); 1940 target.addIllegalOp<TestTypeChangerOp>(); 1941 1942 RewritePatternSet patterns(&getContext()); 1943 patterns.add<ForwardOperandPattern>(converter, &getContext()); 1944 1945 if (failed(applyPartialConversion(getOperation(), target, 1946 std::move(patterns)))) 1947 signalPassFailure(); 1948 } 1949 }; 1950 } // namespace 1951 1952 //===----------------------------------------------------------------------===// 1953 // Test Block Merging 1954 //===----------------------------------------------------------------------===// 1955 1956 namespace { 1957 /// A rewriter pattern that tests that blocks can be merged. 1958 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 1959 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 1960 1961 LogicalResult 1962 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, 1963 ConversionPatternRewriter &rewriter) const final { 1964 Block &firstBlock = op.getBody().front(); 1965 Operation *branchOp = firstBlock.getTerminator(); 1966 Block *secondBlock = &*(std::next(op.getBody().begin())); 1967 auto succOperands = branchOp->getOperands(); 1968 SmallVector<Value, 2> replacements(succOperands); 1969 rewriter.eraseOp(branchOp); 1970 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1971 rewriter.modifyOpInPlace(op, [] {}); 1972 return success(); 1973 } 1974 }; 1975 1976 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 1977 struct TestUndoBlocksMerge : public ConversionPattern { 1978 TestUndoBlocksMerge(MLIRContext *ctx) 1979 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 1980 LogicalResult 1981 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1982 ConversionPatternRewriter &rewriter) const final { 1983 Block &firstBlock = op->getRegion(0).front(); 1984 Operation *branchOp = firstBlock.getTerminator(); 1985 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 1986 rewriter.setInsertionPointToStart(secondBlock); 1987 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 1988 auto succOperands = branchOp->getOperands(); 1989 SmallVector<Value, 2> replacements(succOperands); 1990 rewriter.eraseOp(branchOp); 1991 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1992 rewriter.modifyOpInPlace(op, [] {}); 1993 return success(); 1994 } 1995 }; 1996 1997 /// A rewrite mechanism to inline the body of the op into its parent, when both 1998 /// ops can have a single block. 1999 struct TestMergeSingleBlockOps 2000 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 2001 using OpConversionPattern< 2002 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 2003 2004 LogicalResult 2005 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, 2006 ConversionPatternRewriter &rewriter) const final { 2007 SingleBlockImplicitTerminatorOp parentOp = 2008 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 2009 if (!parentOp) 2010 return failure(); 2011 Block &innerBlock = op.getRegion().front(); 2012 TerminatorOp innerTerminator = 2013 cast<TerminatorOp>(innerBlock.getTerminator()); 2014 rewriter.inlineBlockBefore(&innerBlock, op); 2015 rewriter.eraseOp(innerTerminator); 2016 rewriter.eraseOp(op); 2017 return success(); 2018 } 2019 }; 2020 2021 struct TestMergeBlocksPatternDriver 2022 : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> { 2023 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver) 2024 2025 StringRef getArgument() const final { return "test-merge-blocks"; } 2026 StringRef getDescription() const final { 2027 return "Test Merging operation in ConversionPatternRewriter"; 2028 } 2029 void runOnOperation() override { 2030 MLIRContext *context = &getContext(); 2031 mlir::RewritePatternSet patterns(context); 2032 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 2033 context); 2034 ConversionTarget target(*context); 2035 target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 2036 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 2037 target.addIllegalOp<ILLegalOpF>(); 2038 2039 /// Expect the op to have a single block after legalization. 2040 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 2041 [&](TestMergeBlocksOp op) -> bool { 2042 return llvm::hasSingleElement(op.getBody()); 2043 }); 2044 2045 /// Only allow `test.br` within test.merge_blocks op. 2046 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 2047 return op->getParentOfType<TestMergeBlocksOp>(); 2048 }); 2049 2050 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 2051 /// inlined. 2052 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 2053 [&](SingleBlockImplicitTerminatorOp op) -> bool { 2054 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 2055 }); 2056 2057 DenseSet<Operation *> unlegalizedOps; 2058 ConversionConfig config; 2059 config.unlegalizedOps = &unlegalizedOps; 2060 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 2061 config); 2062 for (auto *op : unlegalizedOps) 2063 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 2064 } 2065 }; 2066 } // namespace 2067 2068 //===----------------------------------------------------------------------===// 2069 // Test Selective Replacement 2070 //===----------------------------------------------------------------------===// 2071 2072 namespace { 2073 /// A rewrite mechanism to inline the body of the op into its parent, when both 2074 /// ops can have a single block. 2075 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 2076 using OpRewritePattern<TestCastOp>::OpRewritePattern; 2077 2078 LogicalResult matchAndRewrite(TestCastOp op, 2079 PatternRewriter &rewriter) const final { 2080 if (op.getNumOperands() != 2) 2081 return failure(); 2082 OperandRange operands = op.getOperands(); 2083 2084 // Replace non-terminator uses with the first operand. 2085 rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) { 2086 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 2087 }); 2088 // Replace everything else with the second operand if the operation isn't 2089 // dead. 2090 rewriter.replaceOp(op, op.getOperand(1)); 2091 return success(); 2092 } 2093 }; 2094 2095 struct TestSelectiveReplacementPatternDriver 2096 : public PassWrapper<TestSelectiveReplacementPatternDriver, 2097 OperationPass<>> { 2098 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 2099 TestSelectiveReplacementPatternDriver) 2100 2101 StringRef getArgument() const final { 2102 return "test-pattern-selective-replacement"; 2103 } 2104 StringRef getDescription() const final { 2105 return "Test selective replacement in the PatternRewriter"; 2106 } 2107 void runOnOperation() override { 2108 MLIRContext *context = &getContext(); 2109 mlir::RewritePatternSet patterns(context); 2110 patterns.add<TestSelectiveOpReplacementPattern>(context); 2111 (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 2112 } 2113 }; 2114 } // namespace 2115 2116 //===----------------------------------------------------------------------===// 2117 // PassRegistration 2118 //===----------------------------------------------------------------------===// 2119 2120 namespace mlir { 2121 namespace test { 2122 void registerPatternsTestPass() { 2123 PassRegistration<TestReturnTypeDriver>(); 2124 2125 PassRegistration<TestDerivedAttributeDriver>(); 2126 2127 PassRegistration<TestGreedyPatternDriver>(); 2128 PassRegistration<TestStrictPatternDriver>(); 2129 PassRegistration<TestWalkPatternDriver>(); 2130 2131 PassRegistration<TestLegalizePatternDriver>([] { 2132 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 2133 }); 2134 2135 PassRegistration<TestRemappedValue>(); 2136 2137 PassRegistration<TestUnknownRootOpDriver>(); 2138 2139 PassRegistration<TestTypeConversionDriver>(); 2140 PassRegistration<TestTargetMaterializationWithNoUses>(); 2141 2142 PassRegistration<TestRewriteDynamicOpDriver>(); 2143 2144 PassRegistration<TestMergeBlocksPatternDriver>(); 2145 PassRegistration<TestSelectiveReplacementPatternDriver>(); 2146 } 2147 } // namespace test 2148 } // namespace mlir 2149