1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/Arith/IR/Arith.h" 12 #include "mlir/Dialect/Func/IR/FuncOps.h" 13 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/IR/Matchers.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include "mlir/Transforms/FoldUtils.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 #include "llvm/ADT/ScopeExit.h" 21 22 using namespace mlir; 23 using namespace test; 24 25 // Native function for testing NativeCodeCall 26 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 27 return choice.getValue() ? input1 : input2; 28 } 29 30 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 31 rewriter.create<OpI>(loc, input); 32 } 33 34 static void handleNoResultOp(PatternRewriter &rewriter, 35 OpSymbolBindingNoResult op) { 36 // Turn the no result op to a one-result op. 37 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), 38 op.getOperand()); 39 } 40 41 static bool getFirstI32Result(Operation *op, Value &value) { 42 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 43 return false; 44 value = op->getResult(0); 45 return true; 46 } 47 48 static Value bindNativeCodeCallResult(Value value) { return value; } 49 50 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 51 Value input2) { 52 return SmallVector<Value, 2>({input2, input1}); 53 } 54 55 // Test that natives calls are only called once during rewrites. 
56 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 57 // This let us check the number of times OpM_Test was called by inspecting 58 // the returned value in the MLIR output. 59 static int64_t opMIncreasingValue = 314159265; 60 static Attribute opMTest(PatternRewriter &rewriter, Value val) { 61 int64_t i = opMIncreasingValue++; 62 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 63 } 64 65 namespace { 66 #include "TestPatterns.inc" 67 } // namespace 68 69 //===----------------------------------------------------------------------===// 70 // Test Reduce Pattern Interface 71 //===----------------------------------------------------------------------===// 72 73 void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 74 populateWithGenerated(patterns); 75 } 76 77 //===----------------------------------------------------------------------===// 78 // Canonicalizer Driver. 79 //===----------------------------------------------------------------------===// 80 81 namespace { 82 struct FoldingPattern : public RewritePattern { 83 public: 84 FoldingPattern(MLIRContext *context) 85 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 86 /*benefit=*/1, context) {} 87 88 LogicalResult matchAndRewrite(Operation *op, 89 PatternRewriter &rewriter) const override { 90 // Exercise createOrFold API for a single-result operation that is folded 91 // upon construction. The operation being created has an in-place folder, 92 // and it should be still present in the output. Furthermore, the folder 93 // should not crash when attempting to recover the (unchanged) operation 94 // result. 95 Value result = rewriter.createOrFold<TestOpInPlaceFold>( 96 op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0)); 97 assert(result); 98 rewriter.replaceOp(op, result); 99 return success(); 100 } 101 }; 102 103 /// This pattern creates a foldable operation at the entry point of the block. 
104 /// This tests the situation where the operation folder will need to replace an 105 /// operation with a previously created constant that does not initially 106 /// dominate the operation to replace. 107 struct FolderInsertBeforePreviouslyFoldedConstantPattern 108 : public OpRewritePattern<TestCastOp> { 109 public: 110 using OpRewritePattern<TestCastOp>::OpRewritePattern; 111 112 LogicalResult matchAndRewrite(TestCastOp op, 113 PatternRewriter &rewriter) const override { 114 if (!op->hasAttr("test_fold_before_previously_folded_op")) 115 return failure(); 116 rewriter.setInsertionPointToStart(op->getBlock()); 117 118 auto constOp = rewriter.create<arith::ConstantOp>( 119 op.getLoc(), rewriter.getBoolAttr(true)); 120 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 121 Value(constOp)); 122 return success(); 123 } 124 }; 125 126 /// This pattern matches test.op_commutative2 with the first operand being 127 /// another test.op_commutative2 with a constant on the right side and fold it 128 /// away by propagating it as its result. This is intend to check that patterns 129 /// are applied after the commutative property moves constant to the right. 130 struct FolderCommutativeOp2WithConstant 131 : public OpRewritePattern<TestCommutative2Op> { 132 public: 133 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; 134 135 LogicalResult matchAndRewrite(TestCommutative2Op op, 136 PatternRewriter &rewriter) const override { 137 auto operand = 138 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); 139 if (!operand) 140 return failure(); 141 Attribute constInput; 142 if (!matchPattern(operand->getOperand(1), m_Constant(&constInput))) 143 return failure(); 144 rewriter.replaceOp(op, operand->getOperand(1)); 145 return success(); 146 } 147 }; 148 149 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer 150 /// attribute with value smaller than MaxVal, it increments the value by 1. 
151 template <int MaxVal> 152 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> { 153 using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern; 154 155 LogicalResult matchAndRewrite(AnyAttrOfOp op, 156 PatternRewriter &rewriter) const override { 157 auto intAttr = dyn_cast<IntegerAttr>(op.getAttr()); 158 if (!intAttr) 159 return failure(); 160 int64_t val = intAttr.getInt(); 161 if (val >= MaxVal) 162 return failure(); 163 rewriter.modifyOpInPlace( 164 op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); }); 165 return success(); 166 } 167 }; 168 169 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op". 170 struct MakeOpEligible : public RewritePattern { 171 MakeOpEligible(MLIRContext *context) 172 : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {} 173 174 LogicalResult matchAndRewrite(Operation *op, 175 PatternRewriter &rewriter) const override { 176 if (op->hasAttr("eligible")) 177 return failure(); 178 rewriter.modifyOpInPlace( 179 op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); }); 180 return success(); 181 } 182 }; 183 184 /// This pattern hoists eligible ops out of a "test.one_region_op". 185 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> { 186 using OpRewritePattern<test::OneRegionOp>::OpRewritePattern; 187 188 LogicalResult matchAndRewrite(test::OneRegionOp op, 189 PatternRewriter &rewriter) const override { 190 Operation *terminator = op.getRegion().front().getTerminator(); 191 Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp(); 192 if (toBeHoisted->getParentOp() != op) 193 return failure(); 194 if (!toBeHoisted->hasAttr("eligible")) 195 return failure(); 196 rewriter.moveOpBefore(toBeHoisted, op); 197 return success(); 198 } 199 }; 200 201 /// This pattern moves "test.move_before_parent_op" before the parent op. 
202 struct MoveBeforeParentOp : public RewritePattern { 203 MoveBeforeParentOp(MLIRContext *context) 204 : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {} 205 206 LogicalResult matchAndRewrite(Operation *op, 207 PatternRewriter &rewriter) const override { 208 // Do not hoist past functions. 209 if (isa<FunctionOpInterface>(op->getParentOp())) 210 return failure(); 211 rewriter.moveOpBefore(op, op->getParentOp()); 212 return success(); 213 } 214 }; 215 216 /// This pattern inlines blocks that are nested in 217 /// "test.inline_blocks_into_parent" into the parent block. 218 struct InlineBlocksIntoParent : public RewritePattern { 219 InlineBlocksIntoParent(MLIRContext *context) 220 : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1, 221 context) {} 222 223 LogicalResult matchAndRewrite(Operation *op, 224 PatternRewriter &rewriter) const override { 225 bool changed = false; 226 for (Region &r : op->getRegions()) { 227 while (!r.empty()) { 228 rewriter.inlineBlockBefore(&r.front(), op); 229 changed = true; 230 } 231 } 232 return success(changed); 233 } 234 }; 235 236 /// This pattern splits blocks at "test.split_block_here" and replaces the op 237 /// with a new op (to prevent an infinite loop of block splitting). 238 struct SplitBlockHere : public RewritePattern { 239 SplitBlockHere(MLIRContext *context) 240 : RewritePattern("test.split_block_here", /*benefit=*/1, context) {} 241 242 LogicalResult matchAndRewrite(Operation *op, 243 PatternRewriter &rewriter) const override { 244 rewriter.splitBlock(op->getBlock(), op->getIterator()); 245 Operation *newOp = rewriter.create( 246 op->getLoc(), 247 OperationName("test.new_op", op->getContext()).getIdentifier(), 248 op->getOperands(), op->getResultTypes()); 249 rewriter.replaceOp(op, newOp); 250 return success(); 251 } 252 }; 253 254 /// This pattern clones "test.clone_me" ops. 
255 struct CloneOp : public RewritePattern { 256 CloneOp(MLIRContext *context) 257 : RewritePattern("test.clone_me", /*benefit=*/1, context) {} 258 259 LogicalResult matchAndRewrite(Operation *op, 260 PatternRewriter &rewriter) const override { 261 // Do not clone already cloned ops to avoid going into an infinite loop. 262 if (op->hasAttr("was_cloned")) 263 return failure(); 264 Operation *cloned = rewriter.clone(*op); 265 cloned->setAttr("was_cloned", rewriter.getUnitAttr()); 266 return success(); 267 } 268 }; 269 270 /// This pattern clones regions of "test.clone_region_before" ops before the 271 /// parent block. 272 struct CloneRegionBeforeOp : public RewritePattern { 273 CloneRegionBeforeOp(MLIRContext *context) 274 : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {} 275 276 LogicalResult matchAndRewrite(Operation *op, 277 PatternRewriter &rewriter) const override { 278 // Do not clone already cloned ops to avoid going into an infinite loop. 279 if (op->hasAttr("was_cloned")) 280 return failure(); 281 for (Region &r : op->getRegions()) 282 rewriter.cloneRegionBefore(r, op->getBlock()); 283 op->setAttr("was_cloned", rewriter.getUnitAttr()); 284 return success(); 285 } 286 }; 287 288 struct TestPatternDriver 289 : public PassWrapper<TestPatternDriver, OperationPass<>> { 290 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver) 291 292 TestPatternDriver() = default; 293 TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {} 294 295 StringRef getArgument() const final { return "test-patterns"; } 296 StringRef getDescription() const final { return "Run test dialect patterns"; } 297 void runOnOperation() override { 298 mlir::RewritePatternSet patterns(&getContext()); 299 populateWithGenerated(patterns); 300 301 // Verify named pattern is generated with expected name. 
302 patterns.add<FoldingPattern, TestNamedPatternRule, 303 FolderInsertBeforePreviouslyFoldedConstantPattern, 304 FolderCommutativeOp2WithConstant, HoistEligibleOps, 305 MakeOpEligible>(&getContext()); 306 307 // Additional patterns for testing the GreedyPatternRewriteDriver. 308 patterns.insert<IncrementIntAttribute<3>>(&getContext()); 309 310 GreedyRewriteConfig config; 311 config.useTopDownTraversal = this->useTopDownTraversal; 312 config.maxIterations = this->maxIterations; 313 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), 314 config); 315 } 316 317 Option<bool> useTopDownTraversal{ 318 *this, "top-down", 319 llvm::cl::desc("Seed the worklist in general top-down order"), 320 llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)}; 321 Option<int> maxIterations{ 322 *this, "max-iterations", 323 llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"), 324 llvm::cl::init(GreedyRewriteConfig().maxIterations)}; 325 }; 326 327 struct DumpNotifications : public RewriterBase::Listener { 328 void notifyBlockInserted(Block *block, Region *previous, 329 Region::iterator previousIt) override { 330 llvm::outs() << "notifyBlockInserted into " 331 << block->getParentOp()->getName() << ": "; 332 if (previous == nullptr) { 333 llvm::outs() << "was unlinked\n"; 334 } else { 335 llvm::outs() << "was linked\n"; 336 } 337 } 338 void notifyOperationInserted(Operation *op, 339 OpBuilder::InsertPoint previous) override { 340 llvm::outs() << "notifyOperationInserted: " << op->getName(); 341 if (!previous.isSet()) { 342 llvm::outs() << ", was unlinked\n"; 343 } else { 344 if (previous.getPoint() == previous.getBlock()->end()) { 345 llvm::outs() << ", was last in block\n"; 346 } else { 347 llvm::outs() << ", previous = " << previous.getPoint()->getName() 348 << "\n"; 349 } 350 } 351 } 352 void notifyOperationErased(Operation *op) override { 353 llvm::outs() << "notifyOperationErased: " << op->getName() << "\n"; 354 } 355 }; 356 357 struct 
TestStrictPatternDriver 358 : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> { 359 public: 360 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver) 361 362 TestStrictPatternDriver() = default; 363 TestStrictPatternDriver(const TestStrictPatternDriver &other) 364 : PassWrapper(other) { 365 strictMode = other.strictMode; 366 } 367 368 StringRef getArgument() const final { return "test-strict-pattern-driver"; } 369 StringRef getDescription() const final { 370 return "Test strict mode of pattern driver"; 371 } 372 373 void runOnOperation() override { 374 MLIRContext *ctx = &getContext(); 375 mlir::RewritePatternSet patterns(ctx); 376 patterns.add< 377 // clang-format off 378 ChangeBlockOp, 379 CloneOp, 380 CloneRegionBeforeOp, 381 EraseOp, 382 ImplicitChangeOp, 383 InlineBlocksIntoParent, 384 InsertSameOp, 385 MoveBeforeParentOp, 386 ReplaceWithNewOp, 387 SplitBlockHere 388 // clang-format on 389 >(ctx); 390 SmallVector<Operation *> ops; 391 getOperation()->walk([&](Operation *op) { 392 StringRef opName = op->getName().getStringRef(); 393 if (opName == "test.insert_same_op" || opName == "test.change_block_op" || 394 opName == "test.replace_with_new_op" || opName == "test.erase_op" || 395 opName == "test.move_before_parent_op" || 396 opName == "test.inline_blocks_into_parent" || 397 opName == "test.split_block_here" || opName == "test.clone_me" || 398 opName == "test.clone_region_before") { 399 ops.push_back(op); 400 } 401 }); 402 403 DumpNotifications dumpNotifications; 404 GreedyRewriteConfig config; 405 config.listener = &dumpNotifications; 406 if (strictMode == "AnyOp") { 407 config.strictMode = GreedyRewriteStrictness::AnyOp; 408 } else if (strictMode == "ExistingAndNewOps") { 409 config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; 410 } else if (strictMode == "ExistingOps") { 411 config.strictMode = GreedyRewriteStrictness::ExistingOps; 412 } else { 413 llvm_unreachable("invalid strictness option"); 414 } 415 
416 // Check if these transformations introduce visiting of operations that 417 // are not in the `ops` set (The new created ops are valid). An invalid 418 // operation will trigger the assertion while processing. 419 bool changed = false; 420 bool allErased = false; 421 (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config, 422 &changed, &allErased); 423 Builder b(ctx); 424 getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed)); 425 getOperation()->setAttr("pattern_driver_all_erased", 426 b.getBoolAttr(allErased)); 427 } 428 429 Option<std::string> strictMode{ 430 *this, "strictness", 431 llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"), 432 llvm::cl::init("AnyOp")}; 433 434 private: 435 // New inserted operation is valid for further transformation. 436 class InsertSameOp : public RewritePattern { 437 public: 438 InsertSameOp(MLIRContext *context) 439 : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {} 440 441 LogicalResult matchAndRewrite(Operation *op, 442 PatternRewriter &rewriter) const override { 443 if (op->hasAttr("skip")) 444 return failure(); 445 446 Operation *newOp = 447 rewriter.create(op->getLoc(), op->getName().getIdentifier(), 448 op->getOperands(), op->getResultTypes()); 449 rewriter.modifyOpInPlace( 450 op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); }); 451 newOp->setAttr("skip", rewriter.getBoolAttr(true)); 452 453 return success(); 454 } 455 }; 456 457 // Replace an operation may introduce the re-visiting of its users. 
458 class ReplaceWithNewOp : public RewritePattern { 459 public: 460 ReplaceWithNewOp(MLIRContext *context) 461 : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {} 462 463 LogicalResult matchAndRewrite(Operation *op, 464 PatternRewriter &rewriter) const override { 465 Operation *newOp; 466 if (op->hasAttr("create_erase_op")) { 467 newOp = rewriter.create( 468 op->getLoc(), 469 OperationName("test.erase_op", op->getContext()).getIdentifier(), 470 ValueRange(), TypeRange()); 471 } else { 472 newOp = rewriter.create( 473 op->getLoc(), 474 OperationName("test.new_op", op->getContext()).getIdentifier(), 475 op->getOperands(), op->getResultTypes()); 476 } 477 rewriter.replaceOp(op, newOp->getResults()); 478 return success(); 479 } 480 }; 481 482 // Remove an operation may introduce the re-visiting of its operands. 483 class EraseOp : public RewritePattern { 484 public: 485 EraseOp(MLIRContext *context) 486 : RewritePattern("test.erase_op", /*benefit=*/1, context) {} 487 LogicalResult matchAndRewrite(Operation *op, 488 PatternRewriter &rewriter) const override { 489 rewriter.eraseOp(op); 490 return success(); 491 } 492 }; 493 494 // The following two patterns test RewriterBase::replaceAllUsesWith. 495 // 496 // That function replaces all usages of a Block (or a Value) with another one 497 // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver 498 // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its 499 // worklist: when an op is modified, it is added to the worklist. The two 500 // patterns below make the tracking observable: ChangeBlockOp replaces all 501 // usages of a block and that pattern is applied because the corresponding ops 502 // are put on the initial worklist (see above). 
ImplicitChangeOp does an 503 // unrelated change but ops of the corresponding type are *not* on the initial 504 // worklist, so the effect of the second pattern is only visible if the 505 // tracking and subsequent adding to the worklist actually works. 506 507 // Replace all usages of the first successor with the second successor. 508 class ChangeBlockOp : public RewritePattern { 509 public: 510 ChangeBlockOp(MLIRContext *context) 511 : RewritePattern("test.change_block_op", /*benefit=*/1, context) {} 512 LogicalResult matchAndRewrite(Operation *op, 513 PatternRewriter &rewriter) const override { 514 if (op->getNumSuccessors() < 2) 515 return failure(); 516 Block *firstSuccessor = op->getSuccessor(0); 517 Block *secondSuccessor = op->getSuccessor(1); 518 if (firstSuccessor == secondSuccessor) 519 return failure(); 520 // This is the function being tested: 521 rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor); 522 // Using the following line instead would make the test fail: 523 // firstSuccessor->replaceAllUsesWith(secondSuccessor); 524 return success(); 525 } 526 }; 527 528 // Changes the successor to the parent block. 529 class ImplicitChangeOp : public RewritePattern { 530 public: 531 ImplicitChangeOp(MLIRContext *context) 532 : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {} 533 LogicalResult matchAndRewrite(Operation *op, 534 PatternRewriter &rewriter) const override { 535 if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock()) 536 return failure(); 537 rewriter.modifyOpInPlace(op, 538 [&]() { op->setSuccessor(op->getBlock(), 0); }); 539 return success(); 540 } 541 }; 542 }; 543 544 } // namespace 545 546 //===----------------------------------------------------------------------===// 547 // ReturnType Driver. 548 //===----------------------------------------------------------------------===// 549 550 namespace { 551 // Generate ops for each instance where the type can be successfully inferred. 
552 template <typename OpTy> 553 static void invokeCreateWithInferredReturnType(Operation *op) { 554 auto *context = op->getContext(); 555 auto fop = op->getParentOfType<func::FuncOp>(); 556 auto location = UnknownLoc::get(context); 557 OpBuilder b(op); 558 b.setInsertionPointAfter(op); 559 560 // Use permutations of 2 args as operands. 561 assert(fop.getNumArguments() >= 2); 562 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 563 for (int j = 0; j < e; ++j) { 564 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 565 SmallVector<Type, 2> inferredReturnTypes; 566 if (succeeded(OpTy::inferReturnTypes( 567 context, std::nullopt, values, op->getDiscardableAttrDictionary(), 568 op->getPropertiesStorage(), op->getRegions(), 569 inferredReturnTypes))) { 570 OperationState state(location, OpTy::getOperationName()); 571 // TODO: Expand to regions. 572 OpTy::build(b, state, values, op->getAttrs()); 573 (void)b.create(state); 574 } 575 } 576 } 577 } 578 579 static void reifyReturnShape(Operation *op) { 580 OpBuilder b(op); 581 582 // Use permutations of 2 args as operands. 
583 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 584 SmallVector<Value, 2> shapes; 585 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 586 !llvm::hasSingleElement(shapes)) 587 return; 588 for (const auto &it : llvm::enumerate(shapes)) { 589 op->emitRemark() << "value " << it.index() << ": " 590 << it.value().getDefiningOp(); 591 } 592 } 593 594 struct TestReturnTypeDriver 595 : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> { 596 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver) 597 598 void getDependentDialects(DialectRegistry ®istry) const override { 599 registry.insert<tensor::TensorDialect>(); 600 } 601 StringRef getArgument() const final { return "test-return-type"; } 602 StringRef getDescription() const final { return "Run return type functions"; } 603 604 void runOnOperation() override { 605 if (getOperation().getName() == "testCreateFunctions") { 606 std::vector<Operation *> ops; 607 // Collect ops to avoid triggering on inserted ops. 608 for (auto &op : getOperation().getBody().front()) 609 ops.push_back(&op); 610 // Generate test patterns for each, but skip terminator. 611 for (auto *op : llvm::ArrayRef(ops).drop_back()) { 612 // Test create method of each of the Op classes below. The resultant 613 // output would be in reverse order underneath `op` from which 614 // the attributes and regions are used. 615 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 616 invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>( 617 op); 618 invokeCreateWithInferredReturnType< 619 OpWithShapedTypeInferTypeInterfaceOp>(op); 620 }; 621 return; 622 } 623 if (getOperation().getName() == "testReifyFunctions") { 624 std::vector<Operation *> ops; 625 // Collect ops to avoid triggering on inserted ops. 
626 for (auto &op : getOperation().getBody().front()) 627 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 628 ops.push_back(&op); 629 // Generate test patterns for each, but skip terminator. 630 for (auto *op : ops) 631 reifyReturnShape(op); 632 } 633 } 634 }; 635 } // namespace 636 637 namespace { 638 struct TestDerivedAttributeDriver 639 : public PassWrapper<TestDerivedAttributeDriver, 640 OperationPass<func::FuncOp>> { 641 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver) 642 643 StringRef getArgument() const final { return "test-derived-attr"; } 644 StringRef getDescription() const final { 645 return "Run test derived attributes"; 646 } 647 void runOnOperation() override; 648 }; 649 } // namespace 650 651 void TestDerivedAttributeDriver::runOnOperation() { 652 getOperation().walk([](DerivedAttributeOpInterface dOp) { 653 auto dAttr = dOp.materializeDerivedAttributes(); 654 if (!dAttr) 655 return; 656 for (auto d : dAttr) 657 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); 658 }); 659 } 660 661 //===----------------------------------------------------------------------===// 662 // Legalization Driver. 663 //===----------------------------------------------------------------------===// 664 665 namespace { 666 //===----------------------------------------------------------------------===// 667 // Region-Block Rewrite Testing 668 669 /// This pattern is a simple pattern that inlines the first region of a given 670 /// operation into the parent region. 671 struct TestRegionRewriteBlockMovement : public ConversionPattern { 672 TestRegionRewriteBlockMovement(MLIRContext *ctx) 673 : ConversionPattern("test.region", 1, ctx) {} 674 675 LogicalResult 676 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 677 ConversionPatternRewriter &rewriter) const final { 678 // Inline this region into the parent region. 
679 auto &parentRegion = *op->getParentRegion(); 680 auto &opRegion = op->getRegion(0); 681 if (op->getDiscardableAttr("legalizer.should_clone")) 682 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 683 else 684 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 685 686 if (op->getDiscardableAttr("legalizer.erase_old_blocks")) { 687 while (!opRegion.empty()) 688 rewriter.eraseBlock(&opRegion.front()); 689 } 690 691 // Drop this operation. 692 rewriter.eraseOp(op); 693 return success(); 694 } 695 }; 696 /// This pattern is a simple pattern that generates a region containing an 697 /// illegal operation. 698 struct TestRegionRewriteUndo : public RewritePattern { 699 TestRegionRewriteUndo(MLIRContext *ctx) 700 : RewritePattern("test.region_builder", 1, ctx) {} 701 702 LogicalResult matchAndRewrite(Operation *op, 703 PatternRewriter &rewriter) const final { 704 // Create the region operation with an entry block containing arguments. 705 OperationState newRegion(op->getLoc(), "test.region"); 706 newRegion.addRegion(); 707 auto *regionOp = rewriter.create(newRegion); 708 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 709 entryBlock->addArgument(rewriter.getIntegerType(64), 710 rewriter.getUnknownLoc()); 711 712 // Add an explicitly illegal operation to ensure the conversion fails. 713 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 714 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 715 716 // Drop this operation. 717 rewriter.eraseOp(op); 718 return success(); 719 } 720 }; 721 /// A simple pattern that creates a block at the end of the parent region of the 722 /// matched operation. 
723 struct TestCreateBlock : public RewritePattern { 724 TestCreateBlock(MLIRContext *ctx) 725 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 726 727 LogicalResult matchAndRewrite(Operation *op, 728 PatternRewriter &rewriter) const final { 729 Region ®ion = *op->getParentRegion(); 730 Type i32Type = rewriter.getIntegerType(32); 731 Location loc = op->getLoc(); 732 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 733 rewriter.create<TerminatorOp>(loc); 734 rewriter.eraseOp(op); 735 return success(); 736 } 737 }; 738 739 /// A simple pattern that creates a block containing an invalid operation in 740 /// order to trigger the block creation undo mechanism. 741 struct TestCreateIllegalBlock : public RewritePattern { 742 TestCreateIllegalBlock(MLIRContext *ctx) 743 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 744 745 LogicalResult matchAndRewrite(Operation *op, 746 PatternRewriter &rewriter) const final { 747 Region ®ion = *op->getParentRegion(); 748 Type i32Type = rewriter.getIntegerType(32); 749 Location loc = op->getLoc(); 750 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 751 // Create an illegal op to ensure the conversion fails. 752 rewriter.create<ILLegalOpF>(loc, i32Type); 753 rewriter.create<TerminatorOp>(loc); 754 rewriter.eraseOp(op); 755 return success(); 756 } 757 }; 758 759 /// A simple pattern that tests the undo mechanism when replacing the uses of a 760 /// block argument. 
761 struct TestUndoBlockArgReplace : public ConversionPattern { 762 TestUndoBlockArgReplace(MLIRContext *ctx) 763 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 764 765 LogicalResult 766 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 767 ConversionPatternRewriter &rewriter) const final { 768 auto illegalOp = 769 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 770 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 771 illegalOp->getResult(0)); 772 rewriter.modifyOpInPlace(op, [] {}); 773 return success(); 774 } 775 }; 776 777 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion. 778 /// This is to test the rollback logic. 779 struct TestUndoMoveOpBefore : public ConversionPattern { 780 TestUndoMoveOpBefore(MLIRContext *ctx) 781 : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {} 782 783 LogicalResult 784 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 785 ConversionPatternRewriter &rewriter) const override { 786 rewriter.moveOpBefore(op, op->getParentOp()); 787 // Replace with an illegal op to ensure the conversion fails. 788 rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type()); 789 return success(); 790 } 791 }; 792 793 /// A rewrite pattern that tests the undo mechanism when erasing a block. 
struct TestUndoBlockErase : public ConversionPattern {
  TestUndoBlockErase(MLIRContext *ctx)
      : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // NOTE(review): assumes region 0 holds at least two blocks; with a single
    // block, std::next yields end() and dereferencing it is invalid.
    Block *secondBlock = &*std::next(op->getRegion(0).begin());
    // Put an illegal op into the second block, then erase the whole block.
    rewriter.setInsertionPointToStart(secondBlock);
    rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
    rewriter.eraseBlock(secondBlock);
    // Empty in-place update of the matched op.
    rewriter.modifyOpInPlace(op, [] {});
    return success();
  }
};

/// A pattern that modifies a property in-place, but keeps the op illegal.
struct TestUndoPropertiesModification : public ConversionPattern {
  TestUndoPropertiesModification(MLIRContext *ctx)
      : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // Only ops tagged with "modify_inplace" are touched.
    if (!op->hasAttr("modify_inplace"))
      return failure();
    // Set property `a` to 42 through the rewriter's in-place modification API.
    rewriter.modifyOpInPlace(
        op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); });
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Type-Conversion Rewrite Testing

/// This pattern erases a region operation that has had a type conversion.
struct TestDropOpSignatureConversion : public ConversionPattern {
  TestDropOpSignatureConversion(MLIRContext *ctx,
                                const TypeConverter &converter)
      : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Region &region = op->getRegion(0);
    Block *entry = &region.front();

    // Convert the original entry arguments.
    const TypeConverter &converter = *getTypeConverter();
    TypeConverter::SignatureConversion result(entry->getNumArguments());
    // Compute the converted argument types, then apply them to the region's
    // blocks; fail the match if either step fails.
    if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
                                              result)) ||
        failed(rewriter.convertRegionTypes(&region, converter, &result)))
      return failure();

    // The region signature has been converted above; now just drop the
    // operation.
    rewriter.eraseOp(op);
    return success();
  }
};
/// This pattern simply updates the operands of the given operation.
struct TestPassthroughInvalidOp : public ConversionPattern {
  TestPassthroughInvalidOp(MLIRContext *ctx)
      : ConversionPattern("test.invalid", 1, ctx) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // Replace with a TestValidOp built from the remapped `operands`.
    rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, operands,
                                             std::nullopt);
    return success();
  }
};
/// This pattern handles the case of a split return value.
struct TestSplitReturnType : public ConversionPattern {
  TestSplitReturnType(MLIRContext *ctx)
      : ConversionPattern("test.return", 1, ctx) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // Check for a return of F32.
    if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
      return failure();

    // Check if the first operation is a cast operation, if it is we use the
    // results directly.
    auto *defOp = operands[0].getDefiningOp();
    if (auto packerOp =
            llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
      // Return the cast's inputs instead of its single F32 result.
      rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
      return success();
    }

    // Otherwise, fail to match.
    return failure();
  }
};

//===----------------------------------------------------------------------===//
// Multi-Level Type-Conversion Rewrite Testing
/// Rewrites a "test.type_producer" whose first result is i32 into one
/// producing f32.
struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
  TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
      : ConversionPattern("test.type_producer", 1, ctx) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // If the type is I32, change the type to F32.
    if (!Type(*op->result_type_begin()).isSignlessInteger(32))
      return failure();
    rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
    return success();
  }
};
/// Rewrites a "test.type_producer" whose first result is f32 into one
/// producing f64.
struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
  TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
      : ConversionPattern("test.type_producer", 1, ctx) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // If the type is F32, change the type to F64.
    // NOTE(review): the check is on the first *result* type, but the failure
    // message says "operand"; confirm the intended wording.
    if (!Type(*op->result_type_begin()).isF32())
      return rewriter.notifyMatchFailure(op, "expected single f32 operand");
    rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
    return success();
  }
};
struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
  TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
      : ConversionPattern("test.type_producer", 10, ctx) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // Always convert to B16, even though it is not a legal type. This tests
    // that values are unmapped correctly.
925 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 926 return success(); 927 } 928 }; 929 struct TestUpdateConsumerType : public ConversionPattern { 930 TestUpdateConsumerType(MLIRContext *ctx) 931 : ConversionPattern("test.type_consumer", 1, ctx) {} 932 LogicalResult 933 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 934 ConversionPatternRewriter &rewriter) const final { 935 // Verify that the incoming operand has been successfully remapped to F64. 936 if (!operands[0].getType().isF64()) 937 return failure(); 938 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 939 return success(); 940 } 941 }; 942 943 //===----------------------------------------------------------------------===// 944 // Non-Root Replacement Rewrite Testing 945 /// This pattern generates an invalid operation, but replaces it before the 946 /// pattern is finished. This checks that we don't need to legalize the 947 /// temporary op. 948 struct TestNonRootReplacement : public RewritePattern { 949 TestNonRootReplacement(MLIRContext *ctx) 950 : RewritePattern("test.replace_non_root", 1, ctx) {} 951 952 LogicalResult matchAndRewrite(Operation *op, 953 PatternRewriter &rewriter) const final { 954 auto resultType = *op->result_type_begin(); 955 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 956 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 957 958 rewriter.replaceOp(illegalOp, legalOp); 959 rewriter.replaceOp(op, illegalOp); 960 return success(); 961 } 962 }; 963 964 //===----------------------------------------------------------------------===// 965 // Recursive Rewrite Testing 966 /// This pattern is applied to the same operation multiple times, but has a 967 /// bounded recursion. 
968 struct TestBoundedRecursiveRewrite 969 : public OpRewritePattern<TestRecursiveRewriteOp> { 970 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 971 972 void initialize() { 973 // The conversion target handles bounding the recursion of this pattern. 974 setHasBoundedRewriteRecursion(); 975 } 976 977 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 978 PatternRewriter &rewriter) const final { 979 // Decrement the depth of the op in-place. 980 rewriter.modifyOpInPlace(op, [&] { 981 op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); 982 }); 983 return success(); 984 } 985 }; 986 987 struct TestNestedOpCreationUndoRewrite 988 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 989 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 990 991 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 992 PatternRewriter &rewriter) const final { 993 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 994 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 995 return success(); 996 }; 997 }; 998 999 // This pattern matches `test.blackhole` and delete this op and its producer. 1000 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 1001 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 1002 1003 LogicalResult matchAndRewrite(BlackHoleOp op, 1004 PatternRewriter &rewriter) const final { 1005 Operation *producer = op.getOperand().getDefiningOp(); 1006 // Always erase the user before the producer, the framework should handle 1007 // this correctly. 1008 rewriter.eraseOp(op); 1009 rewriter.eraseOp(producer); 1010 return success(); 1011 }; 1012 }; 1013 1014 // This pattern replaces explicitly illegal op with explicitly legal op, 1015 // but in addition creates unregistered operation. 
1016 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { 1017 using OpRewritePattern<ILLegalOpG>::OpRewritePattern; 1018 1019 LogicalResult matchAndRewrite(ILLegalOpG op, 1020 PatternRewriter &rewriter) const final { 1021 IntegerAttr attr = rewriter.getI32IntegerAttr(0); 1022 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); 1023 rewriter.replaceOpWithNewOp<LegalOpC>(op, val); 1024 return success(); 1025 }; 1026 }; 1027 } // namespace 1028 1029 namespace { 1030 struct TestTypeConverter : public TypeConverter { 1031 using TypeConverter::TypeConverter; 1032 TestTypeConverter() { 1033 addConversion(convertType); 1034 addArgumentMaterialization(materializeCast); 1035 addSourceMaterialization(materializeCast); 1036 } 1037 1038 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 1039 // Drop I16 types. 1040 if (t.isSignlessInteger(16)) 1041 return success(); 1042 1043 // Convert I64 to F64. 1044 if (t.isSignlessInteger(64)) { 1045 results.push_back(FloatType::getF64(t.getContext())); 1046 return success(); 1047 } 1048 1049 // Convert I42 to I43. 1050 if (t.isInteger(42)) { 1051 results.push_back(IntegerType::get(t.getContext(), 43)); 1052 return success(); 1053 } 1054 1055 // Split F32 into F16,F16. 1056 if (t.isF32()) { 1057 results.assign(2, FloatType::getF16(t.getContext())); 1058 return success(); 1059 } 1060 1061 // Otherwise, convert the type directly. 1062 results.push_back(t); 1063 return success(); 1064 } 1065 1066 /// Hook for materializing a conversion. This is necessary because we generate 1067 /// 1->N type mappings. 
1068 static std::optional<Value> materializeCast(OpBuilder &builder, 1069 Type resultType, 1070 ValueRange inputs, Location loc) { 1071 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1072 } 1073 }; 1074 1075 struct TestLegalizePatternDriver 1076 : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> { 1077 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) 1078 1079 StringRef getArgument() const final { return "test-legalize-patterns"; } 1080 StringRef getDescription() const final { 1081 return "Run test dialect legalization patterns"; 1082 } 1083 /// The mode of conversion to use with the driver. 1084 enum class ConversionMode { Analysis, Full, Partial }; 1085 1086 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 1087 1088 void getDependentDialects(DialectRegistry ®istry) const override { 1089 registry.insert<func::FuncDialect, test::TestDialect>(); 1090 } 1091 1092 void runOnOperation() override { 1093 TestTypeConverter converter; 1094 mlir::RewritePatternSet patterns(&getContext()); 1095 populateWithGenerated(patterns); 1096 patterns 1097 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 1098 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 1099 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 1100 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 1101 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 1102 TestNonRootReplacement, TestBoundedRecursiveRewrite, 1103 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, 1104 TestCreateUnregisteredOp, TestUndoMoveOpBefore, 1105 TestUndoPropertiesModification>(&getContext()); 1106 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 1107 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1108 converter); 1109 mlir::populateCallOpTypeConversionPattern(patterns, converter); 1110 1111 // Define the conversion target used for the test. 
1112 ConversionTarget target(getContext()); 1113 target.addLegalOp<ModuleOp>(); 1114 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, 1115 TerminatorOp, OneRegionOp>(); 1116 target 1117 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 1118 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 1119 // Don't allow F32 operands. 1120 return llvm::none_of(op.getOperandTypes(), 1121 [](Type type) { return type.isF32(); }); 1122 }); 1123 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1124 return converter.isSignatureLegal(op.getFunctionType()) && 1125 converter.isLegal(&op.getBody()); 1126 }); 1127 target.addDynamicallyLegalOp<func::CallOp>( 1128 [&](func::CallOp op) { return converter.isLegal(op); }); 1129 1130 // TestCreateUnregisteredOp creates `arith.constant` operation, 1131 // which was not added to target intentionally to test 1132 // correct error code from conversion driver. 1133 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); 1134 1135 // Expect the type_producer/type_consumer operations to only operate on f64. 1136 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1137 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1138 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1139 return op.getOperand().getType().isF64(); 1140 }); 1141 1142 // Check support for marking certain operations as recursively legal. 1143 target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) { 1144 return static_cast<bool>( 1145 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 1146 }); 1147 1148 // Mark the bound recursion operation as dynamically legal. 1149 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 1150 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); 1151 1152 // Handle a partial conversion. 
1153 if (mode == ConversionMode::Partial) { 1154 DenseSet<Operation *> unlegalizedOps; 1155 ConversionConfig config; 1156 config.unlegalizedOps = &unlegalizedOps; 1157 if (failed(applyPartialConversion(getOperation(), target, 1158 std::move(patterns), config))) { 1159 getOperation()->emitRemark() << "applyPartialConversion failed"; 1160 } 1161 // Emit remarks for each legalizable operation. 1162 for (auto *op : unlegalizedOps) 1163 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1164 return; 1165 } 1166 1167 // Handle a full conversion. 1168 if (mode == ConversionMode::Full) { 1169 // Check support for marking unknown operations as dynamically legal. 1170 target.markUnknownOpDynamicallyLegal([](Operation *op) { 1171 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 1172 }); 1173 1174 if (failed(applyFullConversion(getOperation(), target, 1175 std::move(patterns)))) { 1176 getOperation()->emitRemark() << "applyFullConversion failed"; 1177 } 1178 return; 1179 } 1180 1181 // Otherwise, handle an analysis conversion. 1182 assert(mode == ConversionMode::Analysis); 1183 1184 // Analyze the convertible operations. 1185 DenseSet<Operation *> legalizedOps; 1186 ConversionConfig config; 1187 config.legalizableOps = &legalizedOps; 1188 if (failed(applyAnalysisConversion(getOperation(), target, 1189 std::move(patterns), config))) 1190 return signalPassFailure(); 1191 1192 // Emit remarks for each legalizable operation. 1193 for (auto *op : legalizedOps) 1194 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 1195 } 1196 1197 /// The mode of conversion to use. 
1198 ConversionMode mode; 1199 }; 1200 } // namespace 1201 1202 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 1203 legalizerConversionMode( 1204 "test-legalize-mode", 1205 llvm::cl::desc("The legalization mode to use with the test driver"), 1206 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 1207 llvm::cl::values( 1208 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 1209 "analysis", "Perform an analysis conversion"), 1210 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 1211 "Perform a full conversion"), 1212 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 1213 "partial", "Perform a partial conversion"))); 1214 1215 //===----------------------------------------------------------------------===// 1216 // ConversionPatternRewriter::getRemappedValue testing. This method is used 1217 // to get the remapped value of an original value that was replaced using 1218 // ConversionPatternRewriter. 1219 namespace { 1220 struct TestRemapValueTypeConverter : public TypeConverter { 1221 using TypeConverter::TypeConverter; 1222 1223 TestRemapValueTypeConverter() { 1224 addConversion( 1225 [](Float32Type type) { return Float64Type::get(type.getContext()); }); 1226 addConversion([](Type type) { return type; }); 1227 } 1228 }; 1229 1230 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 1231 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 1232 /// operand twice. 
1233 /// 1234 /// Example: 1235 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 1236 /// is replaced with: 1237 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 1238 struct OneVResOneVOperandOp1Converter 1239 : public OpConversionPattern<OneVResOneVOperandOp1> { 1240 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 1241 1242 LogicalResult 1243 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, 1244 ConversionPatternRewriter &rewriter) const override { 1245 auto origOps = op.getOperands(); 1246 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 1247 "One operand expected"); 1248 Value origOp = *origOps.begin(); 1249 SmallVector<Value, 2> remappedOperands; 1250 // Replicate the remapped original operand twice. Note that we don't used 1251 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 1252 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1253 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1254 1255 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 1256 remappedOperands); 1257 return success(); 1258 } 1259 }; 1260 1261 /// A rewriter pattern that tests that blocks can be merged. 1262 struct TestRemapValueInRegion 1263 : public OpConversionPattern<TestRemappedValueRegionOp> { 1264 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; 1265 1266 LogicalResult 1267 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, 1268 ConversionPatternRewriter &rewriter) const final { 1269 Block &block = op.getBody().front(); 1270 Operation *terminator = block.getTerminator(); 1271 1272 // Merge the block into the parent region. 
1273 Block *parentBlock = op->getBlock(); 1274 Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator()); 1275 rewriter.mergeBlocks(&block, parentBlock, ValueRange()); 1276 rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange()); 1277 1278 // Replace the results of this operation with the remapped terminator 1279 // values. 1280 SmallVector<Value> terminatorOperands; 1281 if (failed(rewriter.getRemappedValues(terminator->getOperands(), 1282 terminatorOperands))) 1283 return failure(); 1284 1285 rewriter.eraseOp(terminator); 1286 rewriter.replaceOp(op, terminatorOperands); 1287 return success(); 1288 } 1289 }; 1290 1291 struct TestRemappedValue 1292 : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> { 1293 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue) 1294 1295 StringRef getArgument() const final { return "test-remapped-value"; } 1296 StringRef getDescription() const final { 1297 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1298 } 1299 void runOnOperation() override { 1300 TestRemapValueTypeConverter typeConverter; 1301 1302 mlir::RewritePatternSet patterns(&getContext()); 1303 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 1304 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( 1305 &getContext()); 1306 patterns.add<TestRemapValueInRegion>(typeConverter, &getContext()); 1307 1308 mlir::ConversionTarget target(getContext()); 1309 target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>(); 1310 1311 // Expect the type_producer/type_consumer operations to only operate on f64. 1312 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1313 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1314 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1315 return op.getOperand().getType().isF64(); 1316 }); 1317 1318 // We make OneVResOneVOperandOp1 legal only when it has more that one 1319 // operand. 
This will trigger the conversion that will replace one-operand 1320 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 1321 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 1322 [](Operation *op) { return op->getNumOperands() > 1; }); 1323 1324 if (failed(mlir::applyFullConversion(getOperation(), target, 1325 std::move(patterns)))) { 1326 signalPassFailure(); 1327 } 1328 } 1329 }; 1330 } // namespace 1331 1332 //===----------------------------------------------------------------------===// 1333 // Test patterns without a specific root operation kind 1334 //===----------------------------------------------------------------------===// 1335 1336 namespace { 1337 /// This pattern matches and removes any operation in the test dialect. 1338 struct RemoveTestDialectOps : public RewritePattern { 1339 RemoveTestDialectOps(MLIRContext *context) 1340 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 1341 1342 LogicalResult matchAndRewrite(Operation *op, 1343 PatternRewriter &rewriter) const override { 1344 if (!isa<TestDialect>(op->getDialect())) 1345 return failure(); 1346 rewriter.eraseOp(op); 1347 return success(); 1348 } 1349 }; 1350 1351 struct TestUnknownRootOpDriver 1352 : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> { 1353 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver) 1354 1355 StringRef getArgument() const final { 1356 return "test-legalize-unknown-root-patterns"; 1357 } 1358 StringRef getDescription() const final { 1359 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1360 } 1361 void runOnOperation() override { 1362 mlir::RewritePatternSet patterns(&getContext()); 1363 patterns.add<RemoveTestDialectOps>(&getContext()); 1364 1365 mlir::ConversionTarget target(getContext()); 1366 target.addIllegalDialect<TestDialect>(); 1367 if (failed(applyPartialConversion(getOperation(), target, 1368 std::move(patterns)))) 1369 signalPassFailure(); 1370 } 1371 }; 1372 } // 
namespace 1373 1374 //===----------------------------------------------------------------------===// 1375 // Test patterns that uses operations and types defined at runtime 1376 //===----------------------------------------------------------------------===// 1377 1378 namespace { 1379 /// This pattern matches dynamic operations 'test.one_operand_two_results' and 1380 /// replace them with dynamic operations 'test.generic_dynamic_op'. 1381 struct RewriteDynamicOp : public RewritePattern { 1382 RewriteDynamicOp(MLIRContext *context) 1383 : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1, 1384 context) {} 1385 1386 LogicalResult matchAndRewrite(Operation *op, 1387 PatternRewriter &rewriter) const override { 1388 assert(op->getName().getStringRef() == 1389 "test.dynamic_one_operand_two_results" && 1390 "rewrite pattern should only match operations with the right name"); 1391 1392 OperationState state(op->getLoc(), "test.dynamic_generic", 1393 op->getOperands(), op->getResultTypes(), 1394 op->getAttrs()); 1395 auto *newOp = rewriter.create(state); 1396 rewriter.replaceOp(op, newOp->getResults()); 1397 return success(); 1398 } 1399 }; 1400 1401 struct TestRewriteDynamicOpDriver 1402 : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> { 1403 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver) 1404 1405 void getDependentDialects(DialectRegistry ®istry) const override { 1406 registry.insert<TestDialect>(); 1407 } 1408 StringRef getArgument() const final { return "test-rewrite-dynamic-op"; } 1409 StringRef getDescription() const final { 1410 return "Test rewritting on dynamic operations"; 1411 } 1412 void runOnOperation() override { 1413 RewritePatternSet patterns(&getContext()); 1414 patterns.add<RewriteDynamicOp>(&getContext()); 1415 1416 ConversionTarget target(getContext()); 1417 target.addIllegalOp( 1418 OperationName("test.dynamic_one_operand_two_results", &getContext())); 1419 
target.addLegalOp(OperationName("test.dynamic_generic", &getContext())); 1420 if (failed(applyPartialConversion(getOperation(), target, 1421 std::move(patterns)))) 1422 signalPassFailure(); 1423 } 1424 }; 1425 } // end anonymous namespace 1426 1427 //===----------------------------------------------------------------------===// 1428 // Test type conversions 1429 //===----------------------------------------------------------------------===// 1430 1431 namespace { 1432 struct TestTypeConversionProducer 1433 : public OpConversionPattern<TestTypeProducerOp> { 1434 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 1435 LogicalResult 1436 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, 1437 ConversionPatternRewriter &rewriter) const final { 1438 Type resultType = op.getType(); 1439 Type convertedType = getTypeConverter() 1440 ? getTypeConverter()->convertType(resultType) 1441 : resultType; 1442 if (isa<FloatType>(resultType)) 1443 resultType = rewriter.getF64Type(); 1444 else if (resultType.isInteger(16)) 1445 resultType = rewriter.getIntegerType(64); 1446 else if (isa<test::TestRecursiveType>(resultType) && 1447 convertedType != resultType) 1448 resultType = convertedType; 1449 else 1450 return failure(); 1451 1452 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 1453 return success(); 1454 } 1455 }; 1456 1457 /// Call signature conversion and then fail the rewrite to trigger the undo 1458 /// mechanism. 
1459 struct TestSignatureConversionUndo 1460 : public OpConversionPattern<TestSignatureConversionUndoOp> { 1461 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 1462 1463 LogicalResult 1464 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, 1465 ConversionPatternRewriter &rewriter) const final { 1466 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 1467 return failure(); 1468 } 1469 }; 1470 1471 /// Call signature conversion without providing a type converter to handle 1472 /// materializations. 1473 struct TestTestSignatureConversionNoConverter 1474 : public OpConversionPattern<TestSignatureConversionNoConverterOp> { 1475 TestTestSignatureConversionNoConverter(const TypeConverter &converter, 1476 MLIRContext *context) 1477 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), 1478 converter(converter) {} 1479 1480 LogicalResult 1481 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, 1482 ConversionPatternRewriter &rewriter) const final { 1483 Region ®ion = op->getRegion(0); 1484 Block *entry = ®ion.front(); 1485 1486 // Convert the original entry arguments. 1487 TypeConverter::SignatureConversion result(entry->getNumArguments()); 1488 if (failed( 1489 converter.convertSignatureArgs(entry->getArgumentTypes(), result))) 1490 return failure(); 1491 rewriter.modifyOpInPlace( 1492 op, [&] { rewriter.applySignatureConversion(®ion, result); }); 1493 return success(); 1494 } 1495 1496 const TypeConverter &converter; 1497 }; 1498 1499 /// Just forward the operands to the root op. This is essentially a no-op 1500 /// pattern that is used to trigger target materialization. 
1501 struct TestTypeConsumerForward 1502 : public OpConversionPattern<TestTypeConsumerOp> { 1503 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 1504 1505 LogicalResult 1506 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, 1507 ConversionPatternRewriter &rewriter) const final { 1508 rewriter.modifyOpInPlace(op, 1509 [&] { op->setOperands(adaptor.getOperands()); }); 1510 return success(); 1511 } 1512 }; 1513 1514 struct TestTypeConversionAnotherProducer 1515 : public OpRewritePattern<TestAnotherTypeProducerOp> { 1516 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 1517 1518 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 1519 PatternRewriter &rewriter) const final { 1520 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 1521 return success(); 1522 } 1523 }; 1524 1525 struct TestTypeConversionDriver 1526 : public PassWrapper<TestTypeConversionDriver, OperationPass<>> { 1527 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver) 1528 1529 void getDependentDialects(DialectRegistry ®istry) const override { 1530 registry.insert<TestDialect>(); 1531 } 1532 StringRef getArgument() const final { 1533 return "test-legalize-type-conversion"; 1534 } 1535 StringRef getDescription() const final { 1536 return "Test various type conversion functionalities in DialectConversion"; 1537 } 1538 1539 void runOnOperation() override { 1540 // Initialize the type converter. 1541 SmallVector<Type, 2> conversionCallStack; 1542 TypeConverter converter; 1543 1544 /// Add the legal set of type conversions. 1545 converter.addConversion([](Type type) -> Type { 1546 // Treat F64 as legal. 1547 if (type.isF64()) 1548 return type; 1549 // Allow converting BF16/F16/F32 to F64. 1550 if (type.isBF16() || type.isF16() || type.isF32()) 1551 return FloatType::getF64(type.getContext()); 1552 // Otherwise, the type is illegal. 
1553 return nullptr; 1554 }); 1555 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 1556 // Drop all integer types. 1557 return success(); 1558 }); 1559 converter.addConversion( 1560 // Convert a recursive self-referring type into a non-self-referring 1561 // type named "outer_converted_type" that contains a SimpleAType. 1562 [&](test::TestRecursiveType type, 1563 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { 1564 // If the type is already converted, return it to indicate that it is 1565 // legal. 1566 if (type.getName() == "outer_converted_type") { 1567 results.push_back(type); 1568 return success(); 1569 } 1570 1571 conversionCallStack.push_back(type); 1572 auto popConversionCallStack = llvm::make_scope_exit( 1573 [&conversionCallStack]() { conversionCallStack.pop_back(); }); 1574 1575 // If the type is on the call stack more than once (it is there at 1576 // least once because of the _current_ call, which is always the last 1577 // element on the stack), we've hit the recursive case. Just return 1578 // SimpleAType here to create a non-recursive type as a result. 1579 if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(), 1580 type)) { 1581 results.push_back(test::SimpleAType::get(type.getContext())); 1582 return success(); 1583 } 1584 1585 // Convert the body recursively. 1586 auto result = test::TestRecursiveType::get(type.getContext(), 1587 "outer_converted_type"); 1588 if (failed(result.setBody(converter.convertType(type.getBody())))) 1589 return failure(); 1590 results.push_back(result); 1591 return success(); 1592 }); 1593 1594 /// Add the legal set of type materializations. 1595 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 1596 ValueRange inputs, 1597 Location loc) -> Value { 1598 // Allow casting from F64 back to F32. 
1599 if (!resultType.isF16() && inputs.size() == 1 && 1600 inputs[0].getType().isF64()) 1601 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1602 // Allow producing an i32 or i64 from nothing. 1603 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 1604 inputs.empty()) 1605 return builder.create<TestTypeProducerOp>(loc, resultType); 1606 // Allow producing an i64 from an integer. 1607 if (isa<IntegerType>(resultType) && inputs.size() == 1 && 1608 isa<IntegerType>(inputs[0].getType())) 1609 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1610 // Otherwise, fail. 1611 return nullptr; 1612 }); 1613 1614 // Initialize the conversion target. 1615 mlir::ConversionTarget target(getContext()); 1616 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 1617 auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType()); 1618 return op.getType().isF64() || op.getType().isInteger(64) || 1619 (recursiveType && 1620 recursiveType.getName() == "outer_converted_type"); 1621 }); 1622 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1623 return converter.isSignatureLegal(op.getFunctionType()) && 1624 converter.isLegal(&op.getBody()); 1625 }); 1626 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 1627 // Allow casts from F64 to F32. 1628 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 1629 }); 1630 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>( 1631 [&](TestSignatureConversionNoConverterOp op) { 1632 return converter.isLegal(op.getRegion().front().getArgumentTypes()); 1633 }); 1634 1635 // Initialize the set of rewrite patterns. 
1636 RewritePatternSet patterns(&getContext()); 1637 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 1638 TestSignatureConversionUndo, 1639 TestTestSignatureConversionNoConverter>(converter, 1640 &getContext()); 1641 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 1642 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1643 converter); 1644 1645 if (failed(applyPartialConversion(getOperation(), target, 1646 std::move(patterns)))) 1647 signalPassFailure(); 1648 } 1649 }; 1650 } // namespace 1651 1652 //===----------------------------------------------------------------------===// 1653 // Test Target Materialization With No Uses 1654 //===----------------------------------------------------------------------===// 1655 1656 namespace { 1657 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> { 1658 using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern; 1659 1660 LogicalResult 1661 matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor, 1662 ConversionPatternRewriter &rewriter) const final { 1663 rewriter.replaceOp(op, adaptor.getOperands()); 1664 return success(); 1665 } 1666 }; 1667 1668 struct TestTargetMaterializationWithNoUses 1669 : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> { 1670 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 1671 TestTargetMaterializationWithNoUses) 1672 1673 StringRef getArgument() const final { 1674 return "test-target-materialization-with-no-uses"; 1675 } 1676 StringRef getDescription() const final { 1677 return "Test a special case of target materialization in DialectConversion"; 1678 } 1679 1680 void runOnOperation() override { 1681 TypeConverter converter; 1682 converter.addConversion([](Type t) { return t; }); 1683 converter.addConversion([](IntegerType intTy) -> Type { 1684 if (intTy.getWidth() == 16) 1685 return IntegerType::get(intTy.getContext(), 64); 1686 return intTy; 1687 }); 1688 
converter.addTargetMaterialization( 1689 [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { 1690 return builder.create<TestCastOp>(loc, type, inputs).getResult(); 1691 }); 1692 1693 ConversionTarget target(getContext()); 1694 target.addIllegalOp<TestTypeChangerOp>(); 1695 1696 RewritePatternSet patterns(&getContext()); 1697 patterns.add<ForwardOperandPattern>(converter, &getContext()); 1698 1699 if (failed(applyPartialConversion(getOperation(), target, 1700 std::move(patterns)))) 1701 signalPassFailure(); 1702 } 1703 }; 1704 } // namespace 1705 1706 //===----------------------------------------------------------------------===// 1707 // Test Block Merging 1708 //===----------------------------------------------------------------------===// 1709 1710 namespace { 1711 /// A rewriter pattern that tests that blocks can be merged. 1712 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 1713 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 1714 1715 LogicalResult 1716 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, 1717 ConversionPatternRewriter &rewriter) const final { 1718 Block &firstBlock = op.getBody().front(); 1719 Operation *branchOp = firstBlock.getTerminator(); 1720 Block *secondBlock = &*(std::next(op.getBody().begin())); 1721 auto succOperands = branchOp->getOperands(); 1722 SmallVector<Value, 2> replacements(succOperands); 1723 rewriter.eraseOp(branchOp); 1724 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1725 rewriter.modifyOpInPlace(op, [] {}); 1726 return success(); 1727 } 1728 }; 1729 1730 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 
1731 struct TestUndoBlocksMerge : public ConversionPattern { 1732 TestUndoBlocksMerge(MLIRContext *ctx) 1733 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 1734 LogicalResult 1735 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1736 ConversionPatternRewriter &rewriter) const final { 1737 Block &firstBlock = op->getRegion(0).front(); 1738 Operation *branchOp = firstBlock.getTerminator(); 1739 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 1740 rewriter.setInsertionPointToStart(secondBlock); 1741 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 1742 auto succOperands = branchOp->getOperands(); 1743 SmallVector<Value, 2> replacements(succOperands); 1744 rewriter.eraseOp(branchOp); 1745 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1746 rewriter.modifyOpInPlace(op, [] {}); 1747 return success(); 1748 } 1749 }; 1750 1751 /// A rewrite mechanism to inline the body of the op into its parent, when both 1752 /// ops can have a single block. 
1753 struct TestMergeSingleBlockOps 1754 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 1755 using OpConversionPattern< 1756 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 1757 1758 LogicalResult 1759 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, 1760 ConversionPatternRewriter &rewriter) const final { 1761 SingleBlockImplicitTerminatorOp parentOp = 1762 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1763 if (!parentOp) 1764 return failure(); 1765 Block &innerBlock = op.getRegion().front(); 1766 TerminatorOp innerTerminator = 1767 cast<TerminatorOp>(innerBlock.getTerminator()); 1768 rewriter.inlineBlockBefore(&innerBlock, op); 1769 rewriter.eraseOp(innerTerminator); 1770 rewriter.eraseOp(op); 1771 return success(); 1772 } 1773 }; 1774 1775 struct TestMergeBlocksPatternDriver 1776 : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> { 1777 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver) 1778 1779 StringRef getArgument() const final { return "test-merge-blocks"; } 1780 StringRef getDescription() const final { 1781 return "Test Merging operation in ConversionPatternRewriter"; 1782 } 1783 void runOnOperation() override { 1784 MLIRContext *context = &getContext(); 1785 mlir::RewritePatternSet patterns(context); 1786 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 1787 context); 1788 ConversionTarget target(*context); 1789 target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1790 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1791 target.addIllegalOp<ILLegalOpF>(); 1792 1793 /// Expect the op to have a single block after legalization. 1794 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1795 [&](TestMergeBlocksOp op) -> bool { 1796 return llvm::hasSingleElement(op.getBody()); 1797 }); 1798 1799 /// Only allow `test.br` within test.merge_blocks op. 
1800 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1801 return op->getParentOfType<TestMergeBlocksOp>(); 1802 }); 1803 1804 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1805 /// inlined. 1806 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1807 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1808 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1809 }); 1810 1811 DenseSet<Operation *> unlegalizedOps; 1812 ConversionConfig config; 1813 config.unlegalizedOps = &unlegalizedOps; 1814 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1815 config); 1816 for (auto *op : unlegalizedOps) 1817 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1818 } 1819 }; 1820 } // namespace 1821 1822 //===----------------------------------------------------------------------===// 1823 // Test Selective Replacement 1824 //===----------------------------------------------------------------------===// 1825 1826 namespace { 1827 /// A rewrite mechanism to inline the body of the op into its parent, when both 1828 /// ops can have a single block. 1829 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1830 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1831 1832 LogicalResult matchAndRewrite(TestCastOp op, 1833 PatternRewriter &rewriter) const final { 1834 if (op.getNumOperands() != 2) 1835 return failure(); 1836 OperandRange operands = op.getOperands(); 1837 1838 // Replace non-terminator uses with the first operand. 1839 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1840 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1841 }); 1842 // Replace everything else with the second operand if the operation isn't 1843 // dead. 
1844 rewriter.replaceOp(op, op.getOperand(1)); 1845 return success(); 1846 } 1847 }; 1848 1849 struct TestSelectiveReplacementPatternDriver 1850 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1851 OperationPass<>> { 1852 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 1853 TestSelectiveReplacementPatternDriver) 1854 1855 StringRef getArgument() const final { 1856 return "test-pattern-selective-replacement"; 1857 } 1858 StringRef getDescription() const final { 1859 return "Test selective replacement in the PatternRewriter"; 1860 } 1861 void runOnOperation() override { 1862 MLIRContext *context = &getContext(); 1863 mlir::RewritePatternSet patterns(context); 1864 patterns.add<TestSelectiveOpReplacementPattern>(context); 1865 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 1866 } 1867 }; 1868 } // namespace 1869 1870 //===----------------------------------------------------------------------===// 1871 // PassRegistration 1872 //===----------------------------------------------------------------------===// 1873 1874 namespace mlir { 1875 namespace test { 1876 void registerPatternsTestPass() { 1877 PassRegistration<TestReturnTypeDriver>(); 1878 1879 PassRegistration<TestDerivedAttributeDriver>(); 1880 1881 PassRegistration<TestPatternDriver>(); 1882 PassRegistration<TestStrictPatternDriver>(); 1883 1884 PassRegistration<TestLegalizePatternDriver>([] { 1885 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 1886 }); 1887 1888 PassRegistration<TestRemappedValue>(); 1889 1890 PassRegistration<TestUnknownRootOpDriver>(); 1891 1892 PassRegistration<TestTypeConversionDriver>(); 1893 PassRegistration<TestTargetMaterializationWithNoUses>(); 1894 1895 PassRegistration<TestRewriteDynamicOpDriver>(); 1896 1897 PassRegistration<TestMergeBlocksPatternDriver>(); 1898 PassRegistration<TestSelectiveReplacementPatternDriver>(); 1899 } 1900 } // namespace test 1901 } // namespace mlir 1902