1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestOps.h" 11 #include "TestTypes.h" 12 #include "mlir/Dialect/Arith/IR/Arith.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/BuiltinAttributes.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/Visitors.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "mlir/Transforms/FoldUtils.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 #include "mlir/Transforms/WalkPatternRewriteDriver.h" 24 #include "llvm/ADT/ScopeExit.h" 25 #include <cstdint> 26 27 using namespace mlir; 28 using namespace test; 29 30 // Native function for testing NativeCodeCall 31 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 32 return choice.getValue() ? input1 : input2; 33 } 34 35 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 36 rewriter.create<OpI>(loc, input); 37 } 38 39 static void handleNoResultOp(PatternRewriter &rewriter, 40 OpSymbolBindingNoResult op) { 41 // Turn the no result op to a one-result op. 42 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), 43 op.getOperand()); 44 } 45 46 static bool getFirstI32Result(Operation *op, Value &value) { 47 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 48 return false; 49 value = op->getResult(0); 50 return true; 51 } 52 53 static Value bindNativeCodeCallResult(Value value) { return value; } 54 55 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 56 Value input2) { 57 return SmallVector<Value, 2>({input2, input1}); 58 } 59 60 // Test that natives calls are only called once during rewrites. 61 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 62 // This let us check the number of times OpM_Test was called by inspecting 63 // the returned value in the MLIR output. 64 static int64_t opMIncreasingValue = 314159265; 65 static Attribute opMTest(PatternRewriter &rewriter, Value val) { 66 int64_t i = opMIncreasingValue++; 67 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 68 } 69 70 namespace { 71 #include "TestPatterns.inc" 72 } // namespace 73 74 //===----------------------------------------------------------------------===// 75 // Test Reduce Pattern Interface 76 //===----------------------------------------------------------------------===// 77 78 void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 79 populateWithGenerated(patterns); 80 } 81 82 //===----------------------------------------------------------------------===// 83 // Canonicalizer Driver. 84 //===----------------------------------------------------------------------===// 85 86 namespace { 87 struct FoldingPattern : public RewritePattern { 88 public: 89 FoldingPattern(MLIRContext *context) 90 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 91 /*benefit=*/1, context) {} 92 93 LogicalResult matchAndRewrite(Operation *op, 94 PatternRewriter &rewriter) const override { 95 // Exercise createOrFold API for a single-result operation that is folded 96 // upon construction. The operation being created has an in-place folder, 97 // and it should be still present in the output. Furthermore, the folder 98 // should not crash when attempting to recover the (unchanged) operation 99 // result. 100 Value result = rewriter.createOrFold<TestOpInPlaceFold>( 101 op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0)); 102 assert(result); 103 rewriter.replaceOp(op, result); 104 return success(); 105 } 106 }; 107 108 /// This pattern creates a foldable operation at the entry point of the block. 109 /// This tests the situation where the operation folder will need to replace an 110 /// operation with a previously created constant that does not initially 111 /// dominate the operation to replace. 112 struct FolderInsertBeforePreviouslyFoldedConstantPattern 113 : public OpRewritePattern<TestCastOp> { 114 public: 115 using OpRewritePattern<TestCastOp>::OpRewritePattern; 116 117 LogicalResult matchAndRewrite(TestCastOp op, 118 PatternRewriter &rewriter) const override { 119 if (!op->hasAttr("test_fold_before_previously_folded_op")) 120 return failure(); 121 rewriter.setInsertionPointToStart(op->getBlock()); 122 123 auto constOp = rewriter.create<arith::ConstantOp>( 124 op.getLoc(), rewriter.getBoolAttr(true)); 125 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 126 Value(constOp)); 127 return success(); 128 } 129 }; 130 131 /// This pattern matches test.op_commutative2 with the first operand being 132 /// another test.op_commutative2 with a constant on the right side and fold it 133 /// away by propagating it as its result. This is intend to check that patterns 134 /// are applied after the commutative property moves constant to the right. 135 struct FolderCommutativeOp2WithConstant 136 : public OpRewritePattern<TestCommutative2Op> { 137 public: 138 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; 139 140 LogicalResult matchAndRewrite(TestCommutative2Op op, 141 PatternRewriter &rewriter) const override { 142 auto operand = 143 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); 144 if (!operand) 145 return failure(); 146 Attribute constInput; 147 if (!matchPattern(operand->getOperand(1), m_Constant(&constInput))) 148 return failure(); 149 rewriter.replaceOp(op, operand->getOperand(1)); 150 return success(); 151 } 152 }; 153 154 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer 155 /// attribute with value smaller than MaxVal, it increments the value by 1. 156 template <int MaxVal> 157 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> { 158 using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern; 159 160 LogicalResult matchAndRewrite(AnyAttrOfOp op, 161 PatternRewriter &rewriter) const override { 162 auto intAttr = dyn_cast<IntegerAttr>(op.getAttr()); 163 if (!intAttr) 164 return failure(); 165 int64_t val = intAttr.getInt(); 166 if (val >= MaxVal) 167 return failure(); 168 rewriter.modifyOpInPlace( 169 op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); }); 170 return success(); 171 } 172 }; 173 174 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op". 175 struct MakeOpEligible : public RewritePattern { 176 MakeOpEligible(MLIRContext *context) 177 : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {} 178 179 LogicalResult matchAndRewrite(Operation *op, 180 PatternRewriter &rewriter) const override { 181 if (op->hasAttr("eligible")) 182 return failure(); 183 rewriter.modifyOpInPlace( 184 op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); }); 185 return success(); 186 } 187 }; 188 189 /// This pattern hoists eligible ops out of a "test.one_region_op". 190 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> { 191 using OpRewritePattern<test::OneRegionOp>::OpRewritePattern; 192 193 LogicalResult matchAndRewrite(test::OneRegionOp op, 194 PatternRewriter &rewriter) const override { 195 Operation *terminator = op.getRegion().front().getTerminator(); 196 Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp(); 197 if (toBeHoisted->getParentOp() != op) 198 return failure(); 199 if (!toBeHoisted->hasAttr("eligible")) 200 return failure(); 201 rewriter.moveOpBefore(toBeHoisted, op); 202 return success(); 203 } 204 }; 205 206 /// This pattern moves "test.move_before_parent_op" before the parent op. 207 struct MoveBeforeParentOp : public RewritePattern { 208 MoveBeforeParentOp(MLIRContext *context) 209 : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {} 210 211 LogicalResult matchAndRewrite(Operation *op, 212 PatternRewriter &rewriter) const override { 213 // Do not hoist past functions. 214 if (isa<FunctionOpInterface>(op->getParentOp())) 215 return failure(); 216 rewriter.moveOpBefore(op, op->getParentOp()); 217 return success(); 218 } 219 }; 220 221 /// This pattern moves "test.move_after_parent_op" after the parent op. 222 struct MoveAfterParentOp : public RewritePattern { 223 MoveAfterParentOp(MLIRContext *context) 224 : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {} 225 226 LogicalResult matchAndRewrite(Operation *op, 227 PatternRewriter &rewriter) const override { 228 // Do not hoist past functions. 229 if (isa<FunctionOpInterface>(op->getParentOp())) 230 return failure(); 231 232 int64_t moveForwardBy = 0; 233 if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance")) 234 moveForwardBy = advanceBy.getInt(); 235 236 Operation *moveAfter = op->getParentOp(); 237 for (int64_t i = 0; i < moveForwardBy; ++i) 238 moveAfter = moveAfter->getNextNode(); 239 240 rewriter.moveOpAfter(op, moveAfter); 241 return success(); 242 } 243 }; 244 245 /// This pattern inlines blocks that are nested in 246 /// "test.inline_blocks_into_parent" into the parent block. 247 struct InlineBlocksIntoParent : public RewritePattern { 248 InlineBlocksIntoParent(MLIRContext *context) 249 : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1, 250 context) {} 251 252 LogicalResult matchAndRewrite(Operation *op, 253 PatternRewriter &rewriter) const override { 254 bool changed = false; 255 for (Region &r : op->getRegions()) { 256 while (!r.empty()) { 257 rewriter.inlineBlockBefore(&r.front(), op); 258 changed = true; 259 } 260 } 261 return success(changed); 262 } 263 }; 264 265 /// This pattern splits blocks at "test.split_block_here" and replaces the op 266 /// with a new op (to prevent an infinite loop of block splitting). 267 struct SplitBlockHere : public RewritePattern { 268 SplitBlockHere(MLIRContext *context) 269 : RewritePattern("test.split_block_here", /*benefit=*/1, context) {} 270 271 LogicalResult matchAndRewrite(Operation *op, 272 PatternRewriter &rewriter) const override { 273 rewriter.splitBlock(op->getBlock(), op->getIterator()); 274 Operation *newOp = rewriter.create( 275 op->getLoc(), 276 OperationName("test.new_op", op->getContext()).getIdentifier(), 277 op->getOperands(), op->getResultTypes()); 278 rewriter.replaceOp(op, newOp); 279 return success(); 280 } 281 }; 282 283 /// This pattern clones "test.clone_me" ops. 284 struct CloneOp : public RewritePattern { 285 CloneOp(MLIRContext *context) 286 : RewritePattern("test.clone_me", /*benefit=*/1, context) {} 287 288 LogicalResult matchAndRewrite(Operation *op, 289 PatternRewriter &rewriter) const override { 290 // Do not clone already cloned ops to avoid going into an infinite loop. 291 if (op->hasAttr("was_cloned")) 292 return failure(); 293 Operation *cloned = rewriter.clone(*op); 294 cloned->setAttr("was_cloned", rewriter.getUnitAttr()); 295 return success(); 296 } 297 }; 298 299 /// This pattern clones regions of "test.clone_region_before" ops before the 300 /// parent block. 301 struct CloneRegionBeforeOp : public RewritePattern { 302 CloneRegionBeforeOp(MLIRContext *context) 303 : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {} 304 305 LogicalResult matchAndRewrite(Operation *op, 306 PatternRewriter &rewriter) const override { 307 // Do not clone already cloned ops to avoid going into an infinite loop. 308 if (op->hasAttr("was_cloned")) 309 return failure(); 310 for (Region &r : op->getRegions()) 311 rewriter.cloneRegionBefore(r, op->getBlock()); 312 op->setAttr("was_cloned", rewriter.getUnitAttr()); 313 return success(); 314 } 315 }; 316 317 /// Replace an operation may introduce the re-visiting of its users. 318 class ReplaceWithNewOp : public RewritePattern { 319 public: 320 ReplaceWithNewOp(MLIRContext *context) 321 : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {} 322 323 LogicalResult matchAndRewrite(Operation *op, 324 PatternRewriter &rewriter) const override { 325 Operation *newOp; 326 if (op->hasAttr("create_erase_op")) { 327 newOp = rewriter.create( 328 op->getLoc(), 329 OperationName("test.erase_op", op->getContext()).getIdentifier(), 330 ValueRange(), TypeRange()); 331 } else { 332 newOp = rewriter.create( 333 op->getLoc(), 334 OperationName("test.new_op", op->getContext()).getIdentifier(), 335 op->getOperands(), op->getResultTypes()); 336 } 337 // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp". 338 // A "notifyOperationReplaced" callback is triggered in either case. 339 rewriter.replaceAllOpUsesWith(op, newOp->getResults()); 340 rewriter.eraseOp(op); 341 return success(); 342 } 343 }; 344 345 /// Erases the first child block of the matched "test.erase_first_block" 346 /// operation. 347 class EraseFirstBlock : public RewritePattern { 348 public: 349 EraseFirstBlock(MLIRContext *context) 350 : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {} 351 352 LogicalResult matchAndRewrite(Operation *op, 353 PatternRewriter &rewriter) const override { 354 for (Region &r : op->getRegions()) { 355 for (Block &b : r.getBlocks()) { 356 rewriter.eraseBlock(&b); 357 return success(); 358 } 359 } 360 361 return failure(); 362 } 363 }; 364 365 struct TestGreedyPatternDriver 366 : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> { 367 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver) 368 369 TestGreedyPatternDriver() = default; 370 TestGreedyPatternDriver(const TestGreedyPatternDriver &other) 371 : PassWrapper(other) {} 372 373 StringRef getArgument() const final { return "test-greedy-patterns"; } 374 StringRef getDescription() const final { return "Run test dialect patterns"; } 375 void runOnOperation() override { 376 mlir::RewritePatternSet patterns(&getContext()); 377 populateWithGenerated(patterns); 378 379 // Verify named pattern is generated with expected name. 380 patterns.add<FoldingPattern, TestNamedPatternRule, 381 FolderInsertBeforePreviouslyFoldedConstantPattern, 382 FolderCommutativeOp2WithConstant, HoistEligibleOps, 383 MakeOpEligible>(&getContext()); 384 385 // Additional patterns for testing the GreedyPatternRewriteDriver. 386 patterns.insert<IncrementIntAttribute<3>>(&getContext()); 387 388 GreedyRewriteConfig config; 389 config.useTopDownTraversal = this->useTopDownTraversal; 390 config.maxIterations = this->maxIterations; 391 config.fold = this->fold; 392 config.cseConstants = this->cseConstants; 393 (void)applyPatternsGreedily(getOperation(), std::move(patterns), config); 394 } 395 396 Option<bool> useTopDownTraversal{ 397 *this, "top-down", 398 llvm::cl::desc("Seed the worklist in general top-down order"), 399 llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)}; 400 Option<int> maxIterations{ 401 *this, "max-iterations", 402 llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"), 403 llvm::cl::init(GreedyRewriteConfig().maxIterations)}; 404 Option<bool> fold{*this, "fold", llvm::cl::desc("Whether to fold"), 405 llvm::cl::init(GreedyRewriteConfig().fold)}; 406 Option<bool> cseConstants{*this, "cse-constants", 407 llvm::cl::desc("Whether to CSE constants"), 408 llvm::cl::init(GreedyRewriteConfig().cseConstants)}; 409 }; 410 411 struct DumpNotifications : public RewriterBase::Listener { 412 void notifyBlockInserted(Block *block, Region *previous, 413 Region::iterator previousIt) override { 414 llvm::outs() << "notifyBlockInserted"; 415 if (block->getParentOp()) { 416 llvm::outs() << " into " << block->getParentOp()->getName() << ": "; 417 } else { 418 llvm::outs() << " into unknown op: "; 419 } 420 if (previous == nullptr) { 421 llvm::outs() << "was unlinked\n"; 422 } else { 423 llvm::outs() << "was linked\n"; 424 } 425 } 426 void notifyOperationInserted(Operation *op, 427 OpBuilder::InsertPoint previous) override { 428 llvm::outs() << "notifyOperationInserted: " << op->getName(); 429 if (!previous.isSet()) { 430 llvm::outs() << ", was unlinked\n"; 431 } else { 432 if (!previous.getPoint().getNodePtr()) { 433 llvm::outs() << ", was linked, exact position unknown\n"; 434 } else if (previous.getPoint() == previous.getBlock()->end()) { 435 llvm::outs() << ", was last in block\n"; 436 } else { 437 llvm::outs() << ", previous = " << previous.getPoint()->getName() 438 << "\n"; 439 } 440 } 441 } 442 void notifyBlockErased(Block *block) override { 443 llvm::outs() << "notifyBlockErased\n"; 444 } 445 void notifyOperationErased(Operation *op) override { 446 llvm::outs() << "notifyOperationErased: " << op->getName() << "\n"; 447 } 448 void notifyOperationModified(Operation *op) override { 449 llvm::outs() << "notifyOperationModified: " << op->getName() << "\n"; 450 } 451 void notifyOperationReplaced(Operation *op, ValueRange values) override { 452 llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n"; 453 } 454 }; 455 456 struct TestStrictPatternDriver 457 : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> { 458 public: 459 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver) 460 461 TestStrictPatternDriver() = default; 462 TestStrictPatternDriver(const TestStrictPatternDriver &other) 463 : PassWrapper(other) { 464 strictMode = other.strictMode; 465 } 466 467 StringRef getArgument() const final { return "test-strict-pattern-driver"; } 468 StringRef getDescription() const final { 469 return "Test strict mode of pattern driver"; 470 } 471 472 void runOnOperation() override { 473 MLIRContext *ctx = &getContext(); 474 mlir::RewritePatternSet patterns(ctx); 475 patterns.add< 476 // clang-format off 477 ChangeBlockOp, 478 CloneOp, 479 CloneRegionBeforeOp, 480 EraseOp, 481 ImplicitChangeOp, 482 InlineBlocksIntoParent, 483 InsertSameOp, 484 MoveBeforeParentOp, 485 ReplaceWithNewOp, 486 SplitBlockHere 487 // clang-format on 488 >(ctx); 489 SmallVector<Operation *> ops; 490 getOperation()->walk([&](Operation *op) { 491 StringRef opName = op->getName().getStringRef(); 492 if (opName == "test.insert_same_op" || opName == "test.change_block_op" || 493 opName == "test.replace_with_new_op" || opName == "test.erase_op" || 494 opName == "test.move_before_parent_op" || 495 opName == "test.inline_blocks_into_parent" || 496 opName == "test.split_block_here" || opName == "test.clone_me" || 497 opName == "test.clone_region_before") { 498 ops.push_back(op); 499 } 500 }); 501 502 DumpNotifications dumpNotifications; 503 GreedyRewriteConfig config; 504 config.listener = &dumpNotifications; 505 if (strictMode == "AnyOp") { 506 config.strictMode = GreedyRewriteStrictness::AnyOp; 507 } else if (strictMode == "ExistingAndNewOps") { 508 config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; 509 } else if (strictMode == "ExistingOps") { 510 config.strictMode = GreedyRewriteStrictness::ExistingOps; 511 } else { 512 llvm_unreachable("invalid strictness option"); 513 } 514 515 // Check if these transformations introduce visiting of operations that 516 // are not in the `ops` set (The new created ops are valid). An invalid 517 // operation will trigger the assertion while processing. 518 bool changed = false; 519 bool allErased = false; 520 (void)applyOpPatternsGreedily(ArrayRef(ops), std::move(patterns), config, 521 &changed, &allErased); 522 Builder b(ctx); 523 getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed)); 524 getOperation()->setAttr("pattern_driver_all_erased", 525 b.getBoolAttr(allErased)); 526 } 527 528 Option<std::string> strictMode{ 529 *this, "strictness", 530 llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"), 531 llvm::cl::init("AnyOp")}; 532 533 private: 534 // New inserted operation is valid for further transformation. 535 class InsertSameOp : public RewritePattern { 536 public: 537 InsertSameOp(MLIRContext *context) 538 : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {} 539 540 LogicalResult matchAndRewrite(Operation *op, 541 PatternRewriter &rewriter) const override { 542 if (op->hasAttr("skip")) 543 return failure(); 544 545 Operation *newOp = 546 rewriter.create(op->getLoc(), op->getName().getIdentifier(), 547 op->getOperands(), op->getResultTypes()); 548 rewriter.modifyOpInPlace( 549 op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); }); 550 newOp->setAttr("skip", rewriter.getBoolAttr(true)); 551 552 return success(); 553 } 554 }; 555 556 // Remove an operation may introduce the re-visiting of its operands. 557 class EraseOp : public RewritePattern { 558 public: 559 EraseOp(MLIRContext *context) 560 : RewritePattern("test.erase_op", /*benefit=*/1, context) {} 561 LogicalResult matchAndRewrite(Operation *op, 562 PatternRewriter &rewriter) const override { 563 rewriter.eraseOp(op); 564 return success(); 565 } 566 }; 567 568 // The following two patterns test RewriterBase::replaceAllUsesWith. 569 // 570 // That function replaces all usages of a Block (or a Value) with another one 571 // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver 572 // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its 573 // worklist: when an op is modified, it is added to the worklist. The two 574 // patterns below make the tracking observable: ChangeBlockOp replaces all 575 // usages of a block and that pattern is applied because the corresponding ops 576 // are put on the initial worklist (see above). ImplicitChangeOp does an 577 // unrelated change but ops of the corresponding type are *not* on the initial 578 // worklist, so the effect of the second pattern is only visible if the 579 // tracking and subsequent adding to the worklist actually works. 580 581 // Replace all usages of the first successor with the second successor. 582 class ChangeBlockOp : public RewritePattern { 583 public: 584 ChangeBlockOp(MLIRContext *context) 585 : RewritePattern("test.change_block_op", /*benefit=*/1, context) {} 586 LogicalResult matchAndRewrite(Operation *op, 587 PatternRewriter &rewriter) const override { 588 if (op->getNumSuccessors() < 2) 589 return failure(); 590 Block *firstSuccessor = op->getSuccessor(0); 591 Block *secondSuccessor = op->getSuccessor(1); 592 if (firstSuccessor == secondSuccessor) 593 return failure(); 594 // This is the function being tested: 595 rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor); 596 // Using the following line instead would make the test fail: 597 // firstSuccessor->replaceAllUsesWith(secondSuccessor); 598 return success(); 599 } 600 }; 601 602 // Changes the successor to the parent block. 603 class ImplicitChangeOp : public RewritePattern { 604 public: 605 ImplicitChangeOp(MLIRContext *context) 606 : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {} 607 LogicalResult matchAndRewrite(Operation *op, 608 PatternRewriter &rewriter) const override { 609 if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock()) 610 return failure(); 611 rewriter.modifyOpInPlace(op, 612 [&]() { op->setSuccessor(op->getBlock(), 0); }); 613 return success(); 614 } 615 }; 616 }; 617 618 struct TestWalkPatternDriver final 619 : PassWrapper<TestWalkPatternDriver, OperationPass<>> { 620 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver) 621 622 TestWalkPatternDriver() = default; 623 TestWalkPatternDriver(const TestWalkPatternDriver &other) 624 : PassWrapper(other) {} 625 626 StringRef getArgument() const override { 627 return "test-walk-pattern-rewrite-driver"; 628 } 629 StringRef getDescription() const override { 630 return "Run test walk pattern rewrite driver"; 631 } 632 void runOnOperation() override { 633 mlir::RewritePatternSet patterns(&getContext()); 634 635 // Patterns for testing the WalkPatternRewriteDriver. 636 patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp, 637 MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>( 638 &getContext()); 639 640 DumpNotifications dumpListener; 641 walkAndApplyPatterns(getOperation(), std::move(patterns), 642 dumpNotifications ? &dumpListener : nullptr); 643 } 644 645 Option<bool> dumpNotifications{ 646 *this, "dump-notifications", 647 llvm::cl::desc("Print rewrite listener notifications"), 648 llvm::cl::init(false)}; 649 }; 650 651 } // namespace 652 653 //===----------------------------------------------------------------------===// 654 // ReturnType Driver. 655 //===----------------------------------------------------------------------===// 656 657 namespace { 658 // Generate ops for each instance where the type can be successfully inferred. 659 template <typename OpTy> 660 static void invokeCreateWithInferredReturnType(Operation *op) { 661 auto *context = op->getContext(); 662 auto fop = op->getParentOfType<func::FuncOp>(); 663 auto location = UnknownLoc::get(context); 664 OpBuilder b(op); 665 b.setInsertionPointAfter(op); 666 667 // Use permutations of 2 args as operands. 668 assert(fop.getNumArguments() >= 2); 669 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 670 for (int j = 0; j < e; ++j) { 671 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 672 SmallVector<Type, 2> inferredReturnTypes; 673 if (succeeded(OpTy::inferReturnTypes( 674 context, std::nullopt, values, op->getDiscardableAttrDictionary(), 675 op->getPropertiesStorage(), op->getRegions(), 676 inferredReturnTypes))) { 677 OperationState state(location, OpTy::getOperationName()); 678 // TODO: Expand to regions. 679 OpTy::build(b, state, values, op->getAttrs()); 680 (void)b.create(state); 681 } 682 } 683 } 684 } 685 686 static void reifyReturnShape(Operation *op) { 687 OpBuilder b(op); 688 689 // Use permutations of 2 args as operands. 690 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 691 SmallVector<Value, 2> shapes; 692 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 693 !llvm::hasSingleElement(shapes)) 694 return; 695 for (const auto &it : llvm::enumerate(shapes)) { 696 op->emitRemark() << "value " << it.index() << ": " 697 << it.value().getDefiningOp(); 698 } 699 } 700 701 struct TestReturnTypeDriver 702 : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> { 703 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver) 704 705 void getDependentDialects(DialectRegistry ®istry) const override { 706 registry.insert<tensor::TensorDialect>(); 707 } 708 StringRef getArgument() const final { return "test-return-type"; } 709 StringRef getDescription() const final { return "Run return type functions"; } 710 711 void runOnOperation() override { 712 if (getOperation().getName() == "testCreateFunctions") { 713 std::vector<Operation *> ops; 714 // Collect ops to avoid triggering on inserted ops. 715 for (auto &op : getOperation().getBody().front()) 716 ops.push_back(&op); 717 // Generate test patterns for each, but skip terminator. 718 for (auto *op : llvm::ArrayRef(ops).drop_back()) { 719 // Test create method of each of the Op classes below. The resultant 720 // output would be in reverse order underneath `op` from which 721 // the attributes and regions are used. 722 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 723 invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>( 724 op); 725 invokeCreateWithInferredReturnType< 726 OpWithShapedTypeInferTypeInterfaceOp>(op); 727 }; 728 return; 729 } 730 if (getOperation().getName() == "testReifyFunctions") { 731 std::vector<Operation *> ops; 732 // Collect ops to avoid triggering on inserted ops. 733 for (auto &op : getOperation().getBody().front()) 734 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 735 ops.push_back(&op); 736 // Generate test patterns for each, but skip terminator. 737 for (auto *op : ops) 738 reifyReturnShape(op); 739 } 740 } 741 }; 742 } // namespace 743 744 namespace { 745 struct TestDerivedAttributeDriver 746 : public PassWrapper<TestDerivedAttributeDriver, 747 OperationPass<func::FuncOp>> { 748 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver) 749 750 StringRef getArgument() const final { return "test-derived-attr"; } 751 StringRef getDescription() const final { 752 return "Run test derived attributes"; 753 } 754 void runOnOperation() override; 755 }; 756 } // namespace 757 758 void TestDerivedAttributeDriver::runOnOperation() { 759 getOperation().walk([](DerivedAttributeOpInterface dOp) { 760 auto dAttr = dOp.materializeDerivedAttributes(); 761 if (!dAttr) 762 return; 763 for (auto d : dAttr) 764 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); 765 }); 766 } 767 768 //===----------------------------------------------------------------------===// 769 // Legalization Driver. 770 //===----------------------------------------------------------------------===// 771 772 namespace { 773 //===----------------------------------------------------------------------===// 774 // Region-Block Rewrite Testing 775 776 /// This pattern applies a signature conversion to a block inside a detached 777 /// region. 778 struct TestDetachedSignatureConversion : public ConversionPattern { 779 TestDetachedSignatureConversion(MLIRContext *ctx) 780 : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1, 781 ctx) {} 782 783 LogicalResult 784 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 785 ConversionPatternRewriter &rewriter) const final { 786 if (op->getNumRegions() != 1) 787 return failure(); 788 OperationState state(op->getLoc(), "test.legal_op", operands, 789 op->getResultTypes(), {}, BlockRange()); 790 Region *newRegion = state.addRegion(); 791 rewriter.inlineRegionBefore(op->getRegion(0), *newRegion, 792 newRegion->begin()); 793 TypeConverter::SignatureConversion result(newRegion->getNumArguments()); 794 for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i) 795 result.addInputs(i, rewriter.getF64Type()); 796 rewriter.applySignatureConversion(&newRegion->front(), result); 797 Operation *newOp = rewriter.create(state); 798 rewriter.replaceOp(op, newOp->getResults()); 799 return success(); 800 } 801 }; 802 803 /// This pattern is a simple pattern that inlines the first region of a given 804 /// operation into the parent region. 805 struct TestRegionRewriteBlockMovement : public ConversionPattern { 806 TestRegionRewriteBlockMovement(MLIRContext *ctx) 807 : ConversionPattern("test.region", 1, ctx) {} 808 809 LogicalResult 810 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 811 ConversionPatternRewriter &rewriter) const final { 812 // Inline this region into the parent region. 813 auto &parentRegion = *op->getParentRegion(); 814 auto &opRegion = op->getRegion(0); 815 if (op->getDiscardableAttr("legalizer.should_clone")) 816 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 817 else 818 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 819 820 if (op->getDiscardableAttr("legalizer.erase_old_blocks")) { 821 while (!opRegion.empty()) 822 rewriter.eraseBlock(&opRegion.front()); 823 } 824 825 // Drop this operation. 826 rewriter.eraseOp(op); 827 return success(); 828 } 829 }; 830 /// This pattern is a simple pattern that generates a region containing an 831 /// illegal operation. 832 struct TestRegionRewriteUndo : public RewritePattern { 833 TestRegionRewriteUndo(MLIRContext *ctx) 834 : RewritePattern("test.region_builder", 1, ctx) {} 835 836 LogicalResult matchAndRewrite(Operation *op, 837 PatternRewriter &rewriter) const final { 838 // Create the region operation with an entry block containing arguments. 839 OperationState newRegion(op->getLoc(), "test.region"); 840 newRegion.addRegion(); 841 auto *regionOp = rewriter.create(newRegion); 842 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 843 entryBlock->addArgument(rewriter.getIntegerType(64), 844 rewriter.getUnknownLoc()); 845 846 // Add an explicitly illegal operation to ensure the conversion fails. 847 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 848 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 849 850 // Drop this operation. 851 rewriter.eraseOp(op); 852 return success(); 853 } 854 }; 855 /// A simple pattern that creates a block at the end of the parent region of the 856 /// matched operation. 857 struct TestCreateBlock : public RewritePattern { 858 TestCreateBlock(MLIRContext *ctx) 859 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 860 861 LogicalResult matchAndRewrite(Operation *op, 862 PatternRewriter &rewriter) const final { 863 Region ®ion = *op->getParentRegion(); 864 Type i32Type = rewriter.getIntegerType(32); 865 Location loc = op->getLoc(); 866 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 867 rewriter.create<TerminatorOp>(loc); 868 rewriter.eraseOp(op); 869 return success(); 870 } 871 }; 872 873 /// A simple pattern that creates a block containing an invalid operation in 874 /// order to trigger the block creation undo mechanism. 875 struct TestCreateIllegalBlock : public RewritePattern { 876 TestCreateIllegalBlock(MLIRContext *ctx) 877 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 878 879 LogicalResult matchAndRewrite(Operation *op, 880 PatternRewriter &rewriter) const final { 881 Region ®ion = *op->getParentRegion(); 882 Type i32Type = rewriter.getIntegerType(32); 883 Location loc = op->getLoc(); 884 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 885 // Create an illegal op to ensure the conversion fails. 886 rewriter.create<ILLegalOpF>(loc, i32Type); 887 rewriter.create<TerminatorOp>(loc); 888 rewriter.eraseOp(op); 889 return success(); 890 } 891 }; 892 893 /// A simple pattern that tests the undo mechanism when replacing the uses of a 894 /// block argument. 895 struct TestUndoBlockArgReplace : public ConversionPattern { 896 TestUndoBlockArgReplace(MLIRContext *ctx) 897 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 898 899 LogicalResult 900 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 901 ConversionPatternRewriter &rewriter) const final { 902 auto illegalOp = 903 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 904 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 905 illegalOp->getResult(0)); 906 rewriter.modifyOpInPlace(op, [] {}); 907 return success(); 908 } 909 }; 910 911 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion. 912 /// This is to test the rollback logic. 913 struct TestUndoMoveOpBefore : public ConversionPattern { 914 TestUndoMoveOpBefore(MLIRContext *ctx) 915 : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {} 916 917 LogicalResult 918 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 919 ConversionPatternRewriter &rewriter) const override { 920 rewriter.moveOpBefore(op, op->getParentOp()); 921 // Replace with an illegal op to ensure the conversion fails. 922 rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type()); 923 return success(); 924 } 925 }; 926 927 /// A rewrite pattern that tests the undo mechanism when erasing a block. 928 struct TestUndoBlockErase : public ConversionPattern { 929 TestUndoBlockErase(MLIRContext *ctx) 930 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 931 932 LogicalResult 933 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 934 ConversionPatternRewriter &rewriter) const final { 935 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 936 rewriter.setInsertionPointToStart(secondBlock); 937 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 938 rewriter.eraseBlock(secondBlock); 939 rewriter.modifyOpInPlace(op, [] {}); 940 return success(); 941 } 942 }; 943 944 /// A pattern that modifies a property in-place, but keeps the op illegal. 945 struct TestUndoPropertiesModification : public ConversionPattern { 946 TestUndoPropertiesModification(MLIRContext *ctx) 947 : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {} 948 LogicalResult 949 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 950 ConversionPatternRewriter &rewriter) const final { 951 if (!op->hasAttr("modify_inplace")) 952 return failure(); 953 rewriter.modifyOpInPlace( 954 op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); }); 955 return success(); 956 } 957 }; 958 959 //===----------------------------------------------------------------------===// 960 // Type-Conversion Rewrite Testing 961 962 /// This patterns erases a region operation that has had a type conversion. 963 struct TestDropOpSignatureConversion : public ConversionPattern { 964 TestDropOpSignatureConversion(MLIRContext *ctx, 965 const TypeConverter &converter) 966 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 967 LogicalResult 968 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 969 ConversionPatternRewriter &rewriter) const override { 970 Region ®ion = op->getRegion(0); 971 Block *entry = ®ion.front(); 972 973 // Convert the original entry arguments. 974 const TypeConverter &converter = *getTypeConverter(); 975 TypeConverter::SignatureConversion result(entry->getNumArguments()); 976 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 977 result)) || 978 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 979 return failure(); 980 981 // Convert the region signature and just drop the operation. 982 rewriter.eraseOp(op); 983 return success(); 984 } 985 }; 986 /// This pattern simply updates the operands of the given operation. 987 struct TestPassthroughInvalidOp : public ConversionPattern { 988 TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter) 989 : ConversionPattern(converter, "test.invalid", 1, ctx) {} 990 LogicalResult 991 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 992 ConversionPatternRewriter &rewriter) const final { 993 SmallVector<Value> flattened; 994 for (auto it : llvm::enumerate(operands)) { 995 ValueRange range = it.value(); 996 if (range.size() == 1) { 997 flattened.push_back(range.front()); 998 continue; 999 } 1000 1001 // This is a 1:N replacement. Insert a test.cast op. (That's what the 1002 // argument materialization used to do.) 1003 flattened.push_back( 1004 rewriter 1005 .create<TestCastOp>(op->getLoc(), 1006 op->getOperand(it.index()).getType(), range) 1007 .getResult()); 1008 } 1009 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened, 1010 std::nullopt); 1011 return success(); 1012 } 1013 }; 1014 /// Replace with valid op, but simply drop the operands. This is used in a 1015 /// regression where we used to generate circular unrealized_conversion_cast 1016 /// ops. 1017 struct TestDropAndReplaceInvalidOp : public ConversionPattern { 1018 TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter) 1019 : ConversionPattern(converter, 1020 "test.drop_operands_and_replace_with_valid", 1, ctx) { 1021 } 1022 LogicalResult 1023 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1024 ConversionPatternRewriter &rewriter) const final { 1025 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(), 1026 std::nullopt); 1027 return success(); 1028 } 1029 }; 1030 /// This pattern handles the case of a split return value. 1031 struct TestSplitReturnType : public ConversionPattern { 1032 TestSplitReturnType(MLIRContext *ctx) 1033 : ConversionPattern("test.return", 1, ctx) {} 1034 LogicalResult 1035 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 1036 ConversionPatternRewriter &rewriter) const final { 1037 // Check for a return of F32. 1038 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 1039 return failure(); 1040 rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]); 1041 return success(); 1042 } 1043 }; 1044 1045 //===----------------------------------------------------------------------===// 1046 // Multi-Level Type-Conversion Rewrite Testing 1047 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 1048 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 1049 : ConversionPattern("test.type_producer", 1, ctx) {} 1050 LogicalResult 1051 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1052 ConversionPatternRewriter &rewriter) const final { 1053 // If the type is I32, change the type to F32. 1054 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 1055 return failure(); 1056 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 1057 return success(); 1058 } 1059 }; 1060 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 1061 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 1062 : ConversionPattern("test.type_producer", 1, ctx) {} 1063 LogicalResult 1064 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1065 ConversionPatternRewriter &rewriter) const final { 1066 // If the type is F32, change the type to F64. 1067 if (!Type(*op->result_type_begin()).isF32()) 1068 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 1069 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 1070 return success(); 1071 } 1072 }; 1073 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 1074 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 1075 : ConversionPattern("test.type_producer", 10, ctx) {} 1076 LogicalResult 1077 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1078 ConversionPatternRewriter &rewriter) const final { 1079 // Always convert to B16, even though it is not a legal type. This tests 1080 // that values are unmapped correctly. 1081 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 1082 return success(); 1083 } 1084 }; 1085 struct TestUpdateConsumerType : public ConversionPattern { 1086 TestUpdateConsumerType(MLIRContext *ctx) 1087 : ConversionPattern("test.type_consumer", 1, ctx) {} 1088 LogicalResult 1089 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1090 ConversionPatternRewriter &rewriter) const final { 1091 // Verify that the incoming operand has been successfully remapped to F64. 1092 if (!operands[0].getType().isF64()) 1093 return failure(); 1094 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 1095 return success(); 1096 } 1097 }; 1098 1099 //===----------------------------------------------------------------------===// 1100 // Non-Root Replacement Rewrite Testing 1101 /// This pattern generates an invalid operation, but replaces it before the 1102 /// pattern is finished. This checks that we don't need to legalize the 1103 /// temporary op. 1104 struct TestNonRootReplacement : public RewritePattern { 1105 TestNonRootReplacement(MLIRContext *ctx) 1106 : RewritePattern("test.replace_non_root", 1, ctx) {} 1107 1108 LogicalResult matchAndRewrite(Operation *op, 1109 PatternRewriter &rewriter) const final { 1110 auto resultType = *op->result_type_begin(); 1111 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 1112 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 1113 1114 rewriter.replaceOp(illegalOp, legalOp); 1115 rewriter.replaceOp(op, illegalOp); 1116 return success(); 1117 } 1118 }; 1119 1120 //===----------------------------------------------------------------------===// 1121 // Recursive Rewrite Testing 1122 /// This pattern is applied to the same operation multiple times, but has a 1123 /// bounded recursion. 1124 struct TestBoundedRecursiveRewrite 1125 : public OpRewritePattern<TestRecursiveRewriteOp> { 1126 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 1127 1128 void initialize() { 1129 // The conversion target handles bounding the recursion of this pattern. 1130 setHasBoundedRewriteRecursion(); 1131 } 1132 1133 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 1134 PatternRewriter &rewriter) const final { 1135 // Decrement the depth of the op in-place. 1136 rewriter.modifyOpInPlace(op, [&] { 1137 op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); 1138 }); 1139 return success(); 1140 } 1141 }; 1142 1143 struct TestNestedOpCreationUndoRewrite 1144 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 1145 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 1146 1147 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 1148 PatternRewriter &rewriter) const final { 1149 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 1150 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 1151 return success(); 1152 }; 1153 }; 1154 1155 // This pattern matches `test.blackhole` and delete this op and its producer. 1156 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 1157 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 1158 1159 LogicalResult matchAndRewrite(BlackHoleOp op, 1160 PatternRewriter &rewriter) const final { 1161 Operation *producer = op.getOperand().getDefiningOp(); 1162 // Always erase the user before the producer, the framework should handle 1163 // this correctly. 1164 rewriter.eraseOp(op); 1165 rewriter.eraseOp(producer); 1166 return success(); 1167 }; 1168 }; 1169 1170 // This pattern replaces explicitly illegal op with explicitly legal op, 1171 // but in addition creates unregistered operation. 1172 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { 1173 using OpRewritePattern<ILLegalOpG>::OpRewritePattern; 1174 1175 LogicalResult matchAndRewrite(ILLegalOpG op, 1176 PatternRewriter &rewriter) const final { 1177 IntegerAttr attr = rewriter.getI32IntegerAttr(0); 1178 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); 1179 rewriter.replaceOpWithNewOp<LegalOpC>(op, val); 1180 return success(); 1181 }; 1182 }; 1183 1184 class TestEraseOp : public ConversionPattern { 1185 public: 1186 TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {} 1187 LogicalResult 1188 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1189 ConversionPatternRewriter &rewriter) const final { 1190 // Erase op without replacements. 1191 rewriter.eraseOp(op); 1192 return success(); 1193 } 1194 }; 1195 1196 /// This pattern matches a test.duplicate_block_args op and duplicates all 1197 /// block arguments. 1198 class TestDuplicateBlockArgs 1199 : public OpConversionPattern<DuplicateBlockArgsOp> { 1200 using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern; 1201 1202 LogicalResult 1203 matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor, 1204 ConversionPatternRewriter &rewriter) const override { 1205 if (op.getIsLegal()) 1206 return failure(); 1207 rewriter.startOpModification(op); 1208 Block *body = &op.getBody().front(); 1209 TypeConverter::SignatureConversion result(body->getNumArguments()); 1210 for (auto it : llvm::enumerate(body->getArgumentTypes())) 1211 result.addInputs(it.index(), {it.value(), it.value()}); 1212 rewriter.applySignatureConversion(body, result, getTypeConverter()); 1213 op.setIsLegal(true); 1214 rewriter.finalizeOpModification(op); 1215 return success(); 1216 } 1217 }; 1218 1219 /// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid 1220 /// op. The pattern supports 1:N replacements and forwards the replacement 1221 /// values of the single operand as test.valid operands. 1222 class TestRepetitive1ToNConsumer : public ConversionPattern { 1223 public: 1224 TestRepetitive1ToNConsumer(MLIRContext *ctx) 1225 : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {} 1226 LogicalResult 1227 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 1228 ConversionPatternRewriter &rewriter) const final { 1229 // A single operand is expected. 1230 if (op->getNumOperands() != 1) 1231 return failure(); 1232 rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front()); 1233 return success(); 1234 } 1235 }; 1236 1237 /// A pattern that tests two back-to-back 1 -> 2 op replacements. 1238 class TestMultiple1ToNReplacement : public ConversionPattern { 1239 public: 1240 TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter) 1241 : ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1, 1242 ctx) {} 1243 LogicalResult 1244 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 1245 ConversionPatternRewriter &rewriter) const final { 1246 // Helper function that replaces the given op with a new op of the given 1247 // name and doubles each result (1 -> 2 replacement of each result). 1248 auto replaceWithDoubleResults = [&](Operation *op, StringRef name) { 1249 SmallVector<Type> types; 1250 for (Type t : op->getResultTypes()) { 1251 types.push_back(t); 1252 types.push_back(t); 1253 } 1254 OperationState state(op->getLoc(), name, 1255 /*operands=*/{}, types, op->getAttrs()); 1256 auto *newOp = rewriter.create(state); 1257 SmallVector<ValueRange> repls; 1258 for (size_t i = 0, e = op->getNumResults(); i < e; ++i) 1259 repls.push_back(newOp->getResults().slice(2 * i, 2)); 1260 rewriter.replaceOpWithMultiple(op, repls); 1261 return newOp; 1262 }; 1263 1264 // Replace test.multiple_1_to_n_replacement with test.step_1. 1265 Operation *repl1 = replaceWithDoubleResults(op, "test.step_1"); 1266 // Now replace test.step_1 with test.legal_op. 1267 replaceWithDoubleResults(repl1, "test.legal_op"); 1268 return success(); 1269 } 1270 }; 1271 1272 } // namespace 1273 1274 namespace { 1275 struct TestTypeConverter : public TypeConverter { 1276 using TypeConverter::TypeConverter; 1277 TestTypeConverter() { 1278 addConversion(convertType); 1279 addSourceMaterialization(materializeCast); 1280 } 1281 1282 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 1283 // Drop I16 types. 1284 if (t.isSignlessInteger(16)) 1285 return success(); 1286 1287 // Convert I64 to F64. 1288 if (t.isSignlessInteger(64)) { 1289 results.push_back(Float64Type::get(t.getContext())); 1290 return success(); 1291 } 1292 1293 // Convert I42 to I43. 1294 if (t.isInteger(42)) { 1295 results.push_back(IntegerType::get(t.getContext(), 43)); 1296 return success(); 1297 } 1298 1299 // Split F32 into F16,F16. 1300 if (t.isF32()) { 1301 results.assign(2, Float16Type::get(t.getContext())); 1302 return success(); 1303 } 1304 1305 // Drop I24 types. 1306 if (t.isInteger(24)) { 1307 return success(); 1308 } 1309 1310 // Otherwise, convert the type directly. 1311 results.push_back(t); 1312 return success(); 1313 } 1314 1315 /// Hook for materializing a conversion. This is necessary because we generate 1316 /// 1->N type mappings. 1317 static Value materializeCast(OpBuilder &builder, Type resultType, 1318 ValueRange inputs, Location loc) { 1319 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1320 } 1321 }; 1322 1323 struct TestLegalizePatternDriver 1324 : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> { 1325 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) 1326 1327 StringRef getArgument() const final { return "test-legalize-patterns"; } 1328 StringRef getDescription() const final { 1329 return "Run test dialect legalization patterns"; 1330 } 1331 /// The mode of conversion to use with the driver. 1332 enum class ConversionMode { Analysis, Full, Partial }; 1333 1334 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 1335 1336 void getDependentDialects(DialectRegistry ®istry) const override { 1337 registry.insert<func::FuncDialect, test::TestDialect>(); 1338 } 1339 1340 void runOnOperation() override { 1341 TestTypeConverter converter; 1342 mlir::RewritePatternSet patterns(&getContext()); 1343 populateWithGenerated(patterns); 1344 patterns 1345 .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion, 1346 TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock, 1347 TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType, 1348 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 1349 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 1350 TestNonRootReplacement, TestBoundedRecursiveRewrite, 1351 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, 1352 TestCreateUnregisteredOp, TestUndoMoveOpBefore, 1353 TestUndoPropertiesModification, TestEraseOp, 1354 TestRepetitive1ToNConsumer>(&getContext()); 1355 patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp, 1356 TestPassthroughInvalidOp, TestMultiple1ToNReplacement>( 1357 &getContext(), converter); 1358 patterns.add<TestDuplicateBlockArgs>(converter, &getContext()); 1359 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1360 converter); 1361 mlir::populateCallOpTypeConversionPattern(patterns, converter); 1362 1363 // Define the conversion target used for the test. 1364 ConversionTarget target(getContext()); 1365 target.addLegalOp<ModuleOp>(); 1366 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, 1367 TerminatorOp, OneRegionOp>(); 1368 target.addLegalOp(OperationName("test.legal_op", &getContext())); 1369 target 1370 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 1371 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 1372 // Don't allow F32 operands. 1373 return llvm::none_of(op.getOperandTypes(), 1374 [](Type type) { return type.isF32(); }); 1375 }); 1376 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1377 return converter.isSignatureLegal(op.getFunctionType()) && 1378 converter.isLegal(&op.getBody()); 1379 }); 1380 target.addDynamicallyLegalOp<func::CallOp>( 1381 [&](func::CallOp op) { return converter.isLegal(op); }); 1382 1383 // TestCreateUnregisteredOp creates `arith.constant` operation, 1384 // which was not added to target intentionally to test 1385 // correct error code from conversion driver. 1386 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); 1387 1388 // Expect the type_producer/type_consumer operations to only operate on f64. 1389 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1390 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1391 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1392 return op.getOperand().getType().isF64(); 1393 }); 1394 1395 // Check support for marking certain operations as recursively legal. 1396 target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) { 1397 return static_cast<bool>( 1398 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 1399 }); 1400 1401 // Mark the bound recursion operation as dynamically legal. 1402 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 1403 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); 1404 1405 // Create a dynamically legal rule that can only be legalized by folding it. 1406 target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>( 1407 [](TestOpInPlaceSelfFold op) { return op.getFolded(); }); 1408 1409 target.addDynamicallyLegalOp<DuplicateBlockArgsOp>( 1410 [](DuplicateBlockArgsOp op) { return op.getIsLegal(); }); 1411 1412 // Handle a partial conversion. 1413 if (mode == ConversionMode::Partial) { 1414 DenseSet<Operation *> unlegalizedOps; 1415 ConversionConfig config; 1416 DumpNotifications dumpNotifications; 1417 config.listener = &dumpNotifications; 1418 config.unlegalizedOps = &unlegalizedOps; 1419 if (failed(applyPartialConversion(getOperation(), target, 1420 std::move(patterns), config))) { 1421 getOperation()->emitRemark() << "applyPartialConversion failed"; 1422 } 1423 // Emit remarks for each legalizable operation. 1424 for (auto *op : unlegalizedOps) 1425 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1426 return; 1427 } 1428 1429 // Handle a full conversion. 1430 if (mode == ConversionMode::Full) { 1431 // Check support for marking unknown operations as dynamically legal. 1432 target.markUnknownOpDynamicallyLegal([](Operation *op) { 1433 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 1434 }); 1435 1436 ConversionConfig config; 1437 DumpNotifications dumpNotifications; 1438 config.listener = &dumpNotifications; 1439 if (failed(applyFullConversion(getOperation(), target, 1440 std::move(patterns), config))) { 1441 getOperation()->emitRemark() << "applyFullConversion failed"; 1442 } 1443 return; 1444 } 1445 1446 // Otherwise, handle an analysis conversion. 1447 assert(mode == ConversionMode::Analysis); 1448 1449 // Analyze the convertible operations. 1450 DenseSet<Operation *> legalizedOps; 1451 ConversionConfig config; 1452 config.legalizableOps = &legalizedOps; 1453 if (failed(applyAnalysisConversion(getOperation(), target, 1454 std::move(patterns), config))) 1455 return signalPassFailure(); 1456 1457 // Emit remarks for each legalizable operation. 1458 for (auto *op : legalizedOps) 1459 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 1460 } 1461 1462 /// The mode of conversion to use. 1463 ConversionMode mode; 1464 }; 1465 } // namespace 1466 1467 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 1468 legalizerConversionMode( 1469 "test-legalize-mode", 1470 llvm::cl::desc("The legalization mode to use with the test driver"), 1471 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 1472 llvm::cl::values( 1473 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 1474 "analysis", "Perform an analysis conversion"), 1475 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 1476 "Perform a full conversion"), 1477 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 1478 "partial", "Perform a partial conversion"))); 1479 1480 //===----------------------------------------------------------------------===// 1481 // ConversionPatternRewriter::getRemappedValue testing. This method is used 1482 // to get the remapped value of an original value that was replaced using 1483 // ConversionPatternRewriter. 1484 namespace { 1485 struct TestRemapValueTypeConverter : public TypeConverter { 1486 using TypeConverter::TypeConverter; 1487 1488 TestRemapValueTypeConverter() { 1489 addConversion( 1490 [](Float32Type type) { return Float64Type::get(type.getContext()); }); 1491 addConversion([](Type type) { return type; }); 1492 } 1493 }; 1494 1495 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 1496 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 1497 /// operand twice. 1498 /// 1499 /// Example: 1500 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 1501 /// is replaced with: 1502 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 1503 struct OneVResOneVOperandOp1Converter 1504 : public OpConversionPattern<OneVResOneVOperandOp1> { 1505 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 1506 1507 LogicalResult 1508 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, 1509 ConversionPatternRewriter &rewriter) const override { 1510 auto origOps = op.getOperands(); 1511 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 1512 "One operand expected"); 1513 Value origOp = *origOps.begin(); 1514 SmallVector<Value, 2> remappedOperands; 1515 // Replicate the remapped original operand twice. Note that we don't used 1516 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 1517 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1518 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1519 1520 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 1521 remappedOperands); 1522 return success(); 1523 } 1524 }; 1525 1526 /// A rewriter pattern that tests that blocks can be merged. 1527 struct TestRemapValueInRegion 1528 : public OpConversionPattern<TestRemappedValueRegionOp> { 1529 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; 1530 1531 LogicalResult 1532 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, 1533 ConversionPatternRewriter &rewriter) const final { 1534 Block &block = op.getBody().front(); 1535 Operation *terminator = block.getTerminator(); 1536 1537 // Merge the block into the parent region. 1538 Block *parentBlock = op->getBlock(); 1539 Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator()); 1540 rewriter.mergeBlocks(&block, parentBlock, ValueRange()); 1541 rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange()); 1542 1543 // Replace the results of this operation with the remapped terminator 1544 // values. 1545 SmallVector<Value> terminatorOperands; 1546 if (failed(rewriter.getRemappedValues(terminator->getOperands(), 1547 terminatorOperands))) 1548 return failure(); 1549 1550 rewriter.eraseOp(terminator); 1551 rewriter.replaceOp(op, terminatorOperands); 1552 return success(); 1553 } 1554 }; 1555 1556 struct TestRemappedValue 1557 : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> { 1558 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue) 1559 1560 StringRef getArgument() const final { return "test-remapped-value"; } 1561 StringRef getDescription() const final { 1562 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1563 } 1564 void runOnOperation() override { 1565 TestRemapValueTypeConverter typeConverter; 1566 1567 mlir::RewritePatternSet patterns(&getContext()); 1568 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 1569 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( 1570 &getContext()); 1571 patterns.add<TestRemapValueInRegion>(typeConverter, &getContext()); 1572 1573 mlir::ConversionTarget target(getContext()); 1574 target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>(); 1575 1576 // Expect the type_producer/type_consumer operations to only operate on f64. 1577 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1578 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1579 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1580 return op.getOperand().getType().isF64(); 1581 }); 1582 1583 // We make OneVResOneVOperandOp1 legal only when it has more that one 1584 // operand. This will trigger the conversion that will replace one-operand 1585 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 1586 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 1587 [](Operation *op) { return op->getNumOperands() > 1; }); 1588 1589 if (failed(mlir::applyFullConversion(getOperation(), target, 1590 std::move(patterns)))) { 1591 signalPassFailure(); 1592 } 1593 } 1594 }; 1595 } // namespace 1596 1597 //===----------------------------------------------------------------------===// 1598 // Test patterns without a specific root operation kind 1599 //===----------------------------------------------------------------------===// 1600 1601 namespace { 1602 /// This pattern matches and removes any operation in the test dialect. 1603 struct RemoveTestDialectOps : public RewritePattern { 1604 RemoveTestDialectOps(MLIRContext *context) 1605 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 1606 1607 LogicalResult matchAndRewrite(Operation *op, 1608 PatternRewriter &rewriter) const override { 1609 if (!isa<TestDialect>(op->getDialect())) 1610 return failure(); 1611 rewriter.eraseOp(op); 1612 return success(); 1613 } 1614 }; 1615 1616 struct TestUnknownRootOpDriver 1617 : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> { 1618 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver) 1619 1620 StringRef getArgument() const final { 1621 return "test-legalize-unknown-root-patterns"; 1622 } 1623 StringRef getDescription() const final { 1624 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1625 } 1626 void runOnOperation() override { 1627 mlir::RewritePatternSet patterns(&getContext()); 1628 patterns.add<RemoveTestDialectOps>(&getContext()); 1629 1630 mlir::ConversionTarget target(getContext()); 1631 target.addIllegalDialect<TestDialect>(); 1632 if (failed(applyPartialConversion(getOperation(), target, 1633 std::move(patterns)))) 1634 signalPassFailure(); 1635 } 1636 }; 1637 } // namespace 1638 1639 //===----------------------------------------------------------------------===// 1640 // Test patterns that uses operations and types defined at runtime 1641 //===----------------------------------------------------------------------===// 1642 1643 namespace { 1644 /// This pattern matches dynamic operations 'test.one_operand_two_results' and 1645 /// replace them with dynamic operations 'test.generic_dynamic_op'. 1646 struct RewriteDynamicOp : public RewritePattern { 1647 RewriteDynamicOp(MLIRContext *context) 1648 : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1, 1649 context) {} 1650 1651 LogicalResult matchAndRewrite(Operation *op, 1652 PatternRewriter &rewriter) const override { 1653 assert(op->getName().getStringRef() == 1654 "test.dynamic_one_operand_two_results" && 1655 "rewrite pattern should only match operations with the right name"); 1656 1657 OperationState state(op->getLoc(), "test.dynamic_generic", 1658 op->getOperands(), op->getResultTypes(), 1659 op->getAttrs()); 1660 auto *newOp = rewriter.create(state); 1661 rewriter.replaceOp(op, newOp->getResults()); 1662 return success(); 1663 } 1664 }; 1665 1666 struct TestRewriteDynamicOpDriver 1667 : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> { 1668 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver) 1669 1670 void getDependentDialects(DialectRegistry ®istry) const override { 1671 registry.insert<TestDialect>(); 1672 } 1673 StringRef getArgument() const final { return "test-rewrite-dynamic-op"; } 1674 StringRef getDescription() const final { 1675 return "Test rewritting on dynamic operations"; 1676 } 1677 void runOnOperation() override { 1678 RewritePatternSet patterns(&getContext()); 1679 patterns.add<RewriteDynamicOp>(&getContext()); 1680 1681 ConversionTarget target(getContext()); 1682 target.addIllegalOp( 1683 OperationName("test.dynamic_one_operand_two_results", &getContext())); 1684 target.addLegalOp(OperationName("test.dynamic_generic", &getContext())); 1685 if (failed(applyPartialConversion(getOperation(), target, 1686 std::move(patterns)))) 1687 signalPassFailure(); 1688 } 1689 }; 1690 } // end anonymous namespace 1691 1692 //===----------------------------------------------------------------------===// 1693 // Test type conversions 1694 //===----------------------------------------------------------------------===// 1695 1696 namespace { 1697 struct TestTypeConversionProducer 1698 : public OpConversionPattern<TestTypeProducerOp> { 1699 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 1700 LogicalResult 1701 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, 1702 ConversionPatternRewriter &rewriter) const final { 1703 Type resultType = op.getType(); 1704 Type convertedType = getTypeConverter() 1705 ? getTypeConverter()->convertType(resultType) 1706 : resultType; 1707 if (isa<FloatType>(resultType)) 1708 resultType = rewriter.getF64Type(); 1709 else if (resultType.isInteger(16)) 1710 resultType = rewriter.getIntegerType(64); 1711 else if (isa<test::TestRecursiveType>(resultType) && 1712 convertedType != resultType) 1713 resultType = convertedType; 1714 else 1715 return failure(); 1716 1717 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 1718 return success(); 1719 } 1720 }; 1721 1722 /// Call signature conversion and then fail the rewrite to trigger the undo 1723 /// mechanism. 1724 struct TestSignatureConversionUndo 1725 : public OpConversionPattern<TestSignatureConversionUndoOp> { 1726 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 1727 1728 LogicalResult 1729 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, 1730 ConversionPatternRewriter &rewriter) const final { 1731 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 1732 return failure(); 1733 } 1734 }; 1735 1736 /// Call signature conversion without providing a type converter to handle 1737 /// materializations. 1738 struct TestTestSignatureConversionNoConverter 1739 : public OpConversionPattern<TestSignatureConversionNoConverterOp> { 1740 TestTestSignatureConversionNoConverter(const TypeConverter &converter, 1741 MLIRContext *context) 1742 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), 1743 converter(converter) {} 1744 1745 LogicalResult 1746 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, 1747 ConversionPatternRewriter &rewriter) const final { 1748 Region ®ion = op->getRegion(0); 1749 Block *entry = ®ion.front(); 1750 1751 // Convert the original entry arguments. 1752 TypeConverter::SignatureConversion result(entry->getNumArguments()); 1753 if (failed( 1754 converter.convertSignatureArgs(entry->getArgumentTypes(), result))) 1755 return failure(); 1756 rewriter.modifyOpInPlace(op, [&] { 1757 rewriter.applySignatureConversion(®ion.front(), result); 1758 }); 1759 return success(); 1760 } 1761 1762 const TypeConverter &converter; 1763 }; 1764 1765 /// Just forward the operands to the root op. This is essentially a no-op 1766 /// pattern that is used to trigger target materialization. 1767 struct TestTypeConsumerForward 1768 : public OpConversionPattern<TestTypeConsumerOp> { 1769 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 1770 1771 LogicalResult 1772 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, 1773 ConversionPatternRewriter &rewriter) const final { 1774 rewriter.modifyOpInPlace(op, 1775 [&] { op->setOperands(adaptor.getOperands()); }); 1776 return success(); 1777 } 1778 }; 1779 1780 struct TestTypeConversionAnotherProducer 1781 : public OpRewritePattern<TestAnotherTypeProducerOp> { 1782 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 1783 1784 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 1785 PatternRewriter &rewriter) const final { 1786 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 1787 return success(); 1788 } 1789 }; 1790 1791 struct TestReplaceWithLegalOp : public ConversionPattern { 1792 TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx) 1793 : ConversionPattern(converter, "test.replace_with_legal_op", 1794 /*benefit=*/1, ctx) {} 1795 LogicalResult 1796 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1797 ConversionPatternRewriter &rewriter) const final { 1798 rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]); 1799 return success(); 1800 } 1801 }; 1802 1803 struct TestTypeConversionDriver 1804 : public PassWrapper<TestTypeConversionDriver, OperationPass<>> { 1805 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver) 1806 1807 void getDependentDialects(DialectRegistry ®istry) const override { 1808 registry.insert<TestDialect>(); 1809 } 1810 StringRef getArgument() const final { 1811 return "test-legalize-type-conversion"; 1812 } 1813 StringRef getDescription() const final { 1814 return "Test various type conversion functionalities in DialectConversion"; 1815 } 1816 1817 void runOnOperation() override { 1818 // Initialize the type converter. 1819 SmallVector<Type, 2> conversionCallStack; 1820 TypeConverter converter; 1821 1822 /// Add the legal set of type conversions. 1823 converter.addConversion([](Type type) -> Type { 1824 // Treat F64 as legal. 1825 if (type.isF64()) 1826 return type; 1827 // Allow converting BF16/F16/F32 to F64. 1828 if (type.isBF16() || type.isF16() || type.isF32()) 1829 return Float64Type::get(type.getContext()); 1830 // Otherwise, the type is illegal. 1831 return nullptr; 1832 }); 1833 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 1834 // Drop all integer types. 1835 return success(); 1836 }); 1837 converter.addConversion( 1838 // Convert a recursive self-referring type into a non-self-referring 1839 // type named "outer_converted_type" that contains a SimpleAType. 1840 [&](test::TestRecursiveType type, 1841 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { 1842 // If the type is already converted, return it to indicate that it is 1843 // legal. 1844 if (type.getName() == "outer_converted_type") { 1845 results.push_back(type); 1846 return success(); 1847 } 1848 1849 conversionCallStack.push_back(type); 1850 auto popConversionCallStack = llvm::make_scope_exit( 1851 [&conversionCallStack]() { conversionCallStack.pop_back(); }); 1852 1853 // If the type is on the call stack more than once (it is there at 1854 // least once because of the _current_ call, which is always the last 1855 // element on the stack), we've hit the recursive case. Just return 1856 // SimpleAType here to create a non-recursive type as a result. 1857 if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(), 1858 type)) { 1859 results.push_back(test::SimpleAType::get(type.getContext())); 1860 return success(); 1861 } 1862 1863 // Convert the body recursively. 1864 auto result = test::TestRecursiveType::get(type.getContext(), 1865 "outer_converted_type"); 1866 if (failed(result.setBody(converter.convertType(type.getBody())))) 1867 return failure(); 1868 results.push_back(result); 1869 return success(); 1870 }); 1871 1872 /// Add the legal set of type materializations. 1873 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 1874 ValueRange inputs, 1875 Location loc) -> Value { 1876 // Allow casting from F64 back to F32. 1877 if (!resultType.isF16() && inputs.size() == 1 && 1878 inputs[0].getType().isF64()) 1879 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1880 // Allow producing an i32 or i64 from nothing. 1881 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 1882 inputs.empty()) 1883 return builder.create<TestTypeProducerOp>(loc, resultType); 1884 // Allow producing an i64 from an integer. 1885 if (isa<IntegerType>(resultType) && inputs.size() == 1 && 1886 isa<IntegerType>(inputs[0].getType())) 1887 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1888 // Otherwise, fail. 1889 return nullptr; 1890 }); 1891 1892 // Initialize the conversion target. 1893 mlir::ConversionTarget target(getContext()); 1894 target.addLegalOp<LegalOpD>(); 1895 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 1896 auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType()); 1897 return op.getType().isF64() || op.getType().isInteger(64) || 1898 (recursiveType && 1899 recursiveType.getName() == "outer_converted_type"); 1900 }); 1901 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1902 return converter.isSignatureLegal(op.getFunctionType()) && 1903 converter.isLegal(&op.getBody()); 1904 }); 1905 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 1906 // Allow casts from F64 to F32. 1907 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 1908 }); 1909 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>( 1910 [&](TestSignatureConversionNoConverterOp op) { 1911 return converter.isLegal(op.getRegion().front().getArgumentTypes()); 1912 }); 1913 1914 // Initialize the set of rewrite patterns. 1915 RewritePatternSet patterns(&getContext()); 1916 patterns 1917 .add<TestTypeConsumerForward, TestTypeConversionProducer, 1918 TestSignatureConversionUndo, 1919 TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>( 1920 converter, &getContext()); 1921 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 1922 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1923 converter); 1924 1925 if (failed(applyPartialConversion(getOperation(), target, 1926 std::move(patterns)))) 1927 signalPassFailure(); 1928 } 1929 }; 1930 } // namespace 1931 1932 //===----------------------------------------------------------------------===// 1933 // Test Target Materialization With No Uses 1934 //===----------------------------------------------------------------------===// 1935 1936 namespace { 1937 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> { 1938 using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern; 1939 1940 LogicalResult 1941 matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor, 1942 ConversionPatternRewriter &rewriter) const final { 1943 rewriter.replaceOp(op, adaptor.getOperands()); 1944 return success(); 1945 } 1946 }; 1947 1948 struct TestTargetMaterializationWithNoUses 1949 : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> { 1950 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 1951 TestTargetMaterializationWithNoUses) 1952 1953 StringRef getArgument() const final { 1954 return "test-target-materialization-with-no-uses"; 1955 } 1956 StringRef getDescription() const final { 1957 return "Test a special case of target materialization in DialectConversion"; 1958 } 1959 1960 void runOnOperation() override { 1961 TypeConverter converter; 1962 converter.addConversion([](Type t) { return t; }); 1963 converter.addConversion([](IntegerType intTy) -> Type { 1964 if (intTy.getWidth() == 16) 1965 return IntegerType::get(intTy.getContext(), 64); 1966 return intTy; 1967 }); 1968 converter.addTargetMaterialization( 1969 [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { 1970 return builder.create<TestCastOp>(loc, type, inputs).getResult(); 1971 }); 1972 1973 ConversionTarget target(getContext()); 1974 target.addIllegalOp<TestTypeChangerOp>(); 1975 1976 RewritePatternSet patterns(&getContext()); 1977 patterns.add<ForwardOperandPattern>(converter, &getContext()); 1978 1979 if (failed(applyPartialConversion(getOperation(), target, 1980 std::move(patterns)))) 1981 signalPassFailure(); 1982 } 1983 }; 1984 } // namespace 1985 1986 //===----------------------------------------------------------------------===// 1987 // Test Block Merging 1988 //===----------------------------------------------------------------------===// 1989 1990 namespace { 1991 /// A rewriter pattern that tests that blocks can be merged. 1992 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 1993 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 1994 1995 LogicalResult 1996 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, 1997 ConversionPatternRewriter &rewriter) const final { 1998 Block &firstBlock = op.getBody().front(); 1999 Operation *branchOp = firstBlock.getTerminator(); 2000 Block *secondBlock = &*(std::next(op.getBody().begin())); 2001 auto succOperands = branchOp->getOperands(); 2002 SmallVector<Value, 2> replacements(succOperands); 2003 rewriter.eraseOp(branchOp); 2004 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 2005 rewriter.modifyOpInPlace(op, [] {}); 2006 return success(); 2007 } 2008 }; 2009 2010 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 2011 struct TestUndoBlocksMerge : public ConversionPattern { 2012 TestUndoBlocksMerge(MLIRContext *ctx) 2013 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 2014 LogicalResult 2015 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 2016 ConversionPatternRewriter &rewriter) const final { 2017 Block &firstBlock = op->getRegion(0).front(); 2018 Operation *branchOp = firstBlock.getTerminator(); 2019 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 2020 rewriter.setInsertionPointToStart(secondBlock); 2021 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 2022 auto succOperands = branchOp->getOperands(); 2023 SmallVector<Value, 2> replacements(succOperands); 2024 rewriter.eraseOp(branchOp); 2025 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 2026 rewriter.modifyOpInPlace(op, [] {}); 2027 return success(); 2028 } 2029 }; 2030 2031 /// A rewrite mechanism to inline the body of the op into its parent, when both 2032 /// ops can have a single block. 2033 struct TestMergeSingleBlockOps 2034 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 2035 using OpConversionPattern< 2036 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 2037 2038 LogicalResult 2039 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, 2040 ConversionPatternRewriter &rewriter) const final { 2041 SingleBlockImplicitTerminatorOp parentOp = 2042 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 2043 if (!parentOp) 2044 return failure(); 2045 Block &innerBlock = op.getRegion().front(); 2046 TerminatorOp innerTerminator = 2047 cast<TerminatorOp>(innerBlock.getTerminator()); 2048 rewriter.inlineBlockBefore(&innerBlock, op); 2049 rewriter.eraseOp(innerTerminator); 2050 rewriter.eraseOp(op); 2051 return success(); 2052 } 2053 }; 2054 2055 struct TestMergeBlocksPatternDriver 2056 : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> { 2057 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver) 2058 2059 StringRef getArgument() const final { return "test-merge-blocks"; } 2060 StringRef getDescription() const final { 2061 return "Test Merging operation in ConversionPatternRewriter"; 2062 } 2063 void runOnOperation() override { 2064 MLIRContext *context = &getContext(); 2065 mlir::RewritePatternSet patterns(context); 2066 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 2067 context); 2068 ConversionTarget target(*context); 2069 target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 2070 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 2071 target.addIllegalOp<ILLegalOpF>(); 2072 2073 /// Expect the op to have a single block after legalization. 2074 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 2075 [&](TestMergeBlocksOp op) -> bool { 2076 return llvm::hasSingleElement(op.getBody()); 2077 }); 2078 2079 /// Only allow `test.br` within test.merge_blocks op. 2080 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 2081 return op->getParentOfType<TestMergeBlocksOp>(); 2082 }); 2083 2084 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 2085 /// inlined. 2086 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 2087 [&](SingleBlockImplicitTerminatorOp op) -> bool { 2088 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 2089 }); 2090 2091 DenseSet<Operation *> unlegalizedOps; 2092 ConversionConfig config; 2093 config.unlegalizedOps = &unlegalizedOps; 2094 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 2095 config); 2096 for (auto *op : unlegalizedOps) 2097 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 2098 } 2099 }; 2100 } // namespace 2101 2102 //===----------------------------------------------------------------------===// 2103 // Test Selective Replacement 2104 //===----------------------------------------------------------------------===// 2105 2106 namespace { 2107 /// A rewrite mechanism to inline the body of the op into its parent, when both 2108 /// ops can have a single block. 2109 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 2110 using OpRewritePattern<TestCastOp>::OpRewritePattern; 2111 2112 LogicalResult matchAndRewrite(TestCastOp op, 2113 PatternRewriter &rewriter) const final { 2114 if (op.getNumOperands() != 2) 2115 return failure(); 2116 OperandRange operands = op.getOperands(); 2117 2118 // Replace non-terminator uses with the first operand. 2119 rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) { 2120 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 2121 }); 2122 // Replace everything else with the second operand if the operation isn't 2123 // dead. 2124 rewriter.replaceOp(op, op.getOperand(1)); 2125 return success(); 2126 } 2127 }; 2128 2129 struct TestSelectiveReplacementPatternDriver 2130 : public PassWrapper<TestSelectiveReplacementPatternDriver, 2131 OperationPass<>> { 2132 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 2133 TestSelectiveReplacementPatternDriver) 2134 2135 StringRef getArgument() const final { 2136 return "test-pattern-selective-replacement"; 2137 } 2138 StringRef getDescription() const final { 2139 return "Test selective replacement in the PatternRewriter"; 2140 } 2141 void runOnOperation() override { 2142 MLIRContext *context = &getContext(); 2143 mlir::RewritePatternSet patterns(context); 2144 patterns.add<TestSelectiveOpReplacementPattern>(context); 2145 (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 2146 } 2147 }; 2148 } // namespace 2149 2150 //===----------------------------------------------------------------------===// 2151 // PassRegistration 2152 //===----------------------------------------------------------------------===// 2153 2154 namespace mlir { 2155 namespace test { 2156 void registerPatternsTestPass() { 2157 PassRegistration<TestReturnTypeDriver>(); 2158 2159 PassRegistration<TestDerivedAttributeDriver>(); 2160 2161 PassRegistration<TestGreedyPatternDriver>(); 2162 PassRegistration<TestStrictPatternDriver>(); 2163 PassRegistration<TestWalkPatternDriver>(); 2164 2165 PassRegistration<TestLegalizePatternDriver>([] { 2166 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 2167 }); 2168 2169 PassRegistration<TestRemappedValue>(); 2170 2171 PassRegistration<TestUnknownRootOpDriver>(); 2172 2173 PassRegistration<TestTypeConversionDriver>(); 2174 PassRegistration<TestTargetMaterializationWithNoUses>(); 2175 2176 PassRegistration<TestRewriteDynamicOpDriver>(); 2177 2178 PassRegistration<TestMergeBlocksPatternDriver>(); 2179 PassRegistration<TestSelectiveReplacementPatternDriver>(); 2180 } 2181 } // namespace test 2182 } // namespace mlir 2183