1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/Arith/IR/Arith.h" 12 #include "mlir/Dialect/Func/IR/FuncOps.h" 13 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/IR/Matchers.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include "mlir/Transforms/FoldUtils.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 #include "llvm/ADT/ScopeExit.h" 21 22 using namespace mlir; 23 using namespace test; 24 25 // Native function for testing NativeCodeCall 26 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 27 return choice.getValue() ? input1 : input2; 28 } 29 30 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 31 rewriter.create<OpI>(loc, input); 32 } 33 34 static void handleNoResultOp(PatternRewriter &rewriter, 35 OpSymbolBindingNoResult op) { 36 // Turn the no result op to a one-result op. 37 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), 38 op.getOperand()); 39 } 40 41 static bool getFirstI32Result(Operation *op, Value &value) { 42 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 43 return false; 44 value = op->getResult(0); 45 return true; 46 } 47 48 static Value bindNativeCodeCallResult(Value value) { return value; } 49 50 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 51 Value input2) { 52 return SmallVector<Value, 2>({input2, input1}); 53 } 54 55 // Test that natives calls are only called once during rewrites. 
56 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 57 // This let us check the number of times OpM_Test was called by inspecting 58 // the returned value in the MLIR output. 59 static int64_t opMIncreasingValue = 314159265; 60 static Attribute opMTest(PatternRewriter &rewriter, Value val) { 61 int64_t i = opMIncreasingValue++; 62 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 63 } 64 65 namespace { 66 #include "TestPatterns.inc" 67 } // namespace 68 69 //===----------------------------------------------------------------------===// 70 // Test Reduce Pattern Interface 71 //===----------------------------------------------------------------------===// 72 73 void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 74 populateWithGenerated(patterns); 75 } 76 77 //===----------------------------------------------------------------------===// 78 // Canonicalizer Driver. 79 //===----------------------------------------------------------------------===// 80 81 namespace { 82 struct FoldingPattern : public RewritePattern { 83 public: 84 FoldingPattern(MLIRContext *context) 85 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 86 /*benefit=*/1, context) {} 87 88 LogicalResult matchAndRewrite(Operation *op, 89 PatternRewriter &rewriter) const override { 90 // Exercise createOrFold API for a single-result operation that is folded 91 // upon construction. The operation being created has an in-place folder, 92 // and it should be still present in the output. Furthermore, the folder 93 // should not crash when attempting to recover the (unchanged) operation 94 // result. 95 Value result = rewriter.createOrFold<TestOpInPlaceFold>( 96 op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0)); 97 assert(result); 98 rewriter.replaceOp(op, result); 99 return success(); 100 } 101 }; 102 103 /// This pattern creates a foldable operation at the entry point of the block. 
104 /// This tests the situation where the operation folder will need to replace an 105 /// operation with a previously created constant that does not initially 106 /// dominate the operation to replace. 107 struct FolderInsertBeforePreviouslyFoldedConstantPattern 108 : public OpRewritePattern<TestCastOp> { 109 public: 110 using OpRewritePattern<TestCastOp>::OpRewritePattern; 111 112 LogicalResult matchAndRewrite(TestCastOp op, 113 PatternRewriter &rewriter) const override { 114 if (!op->hasAttr("test_fold_before_previously_folded_op")) 115 return failure(); 116 rewriter.setInsertionPointToStart(op->getBlock()); 117 118 auto constOp = rewriter.create<arith::ConstantOp>( 119 op.getLoc(), rewriter.getBoolAttr(true)); 120 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 121 Value(constOp)); 122 return success(); 123 } 124 }; 125 126 /// This pattern matches test.op_commutative2 with the first operand being 127 /// another test.op_commutative2 with a constant on the right side and fold it 128 /// away by propagating it as its result. This is intend to check that patterns 129 /// are applied after the commutative property moves constant to the right. 130 struct FolderCommutativeOp2WithConstant 131 : public OpRewritePattern<TestCommutative2Op> { 132 public: 133 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; 134 135 LogicalResult matchAndRewrite(TestCommutative2Op op, 136 PatternRewriter &rewriter) const override { 137 auto operand = 138 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); 139 if (!operand) 140 return failure(); 141 Attribute constInput; 142 if (!matchPattern(operand->getOperand(1), m_Constant(&constInput))) 143 return failure(); 144 rewriter.replaceOp(op, operand->getOperand(1)); 145 return success(); 146 } 147 }; 148 149 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer 150 /// attribute with value smaller than MaxVal, it increments the value by 1. 
151 template <int MaxVal> 152 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> { 153 using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern; 154 155 LogicalResult matchAndRewrite(AnyAttrOfOp op, 156 PatternRewriter &rewriter) const override { 157 auto intAttr = dyn_cast<IntegerAttr>(op.getAttr()); 158 if (!intAttr) 159 return failure(); 160 int64_t val = intAttr.getInt(); 161 if (val >= MaxVal) 162 return failure(); 163 rewriter.modifyOpInPlace( 164 op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); }); 165 return success(); 166 } 167 }; 168 169 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op". 170 struct MakeOpEligible : public RewritePattern { 171 MakeOpEligible(MLIRContext *context) 172 : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {} 173 174 LogicalResult matchAndRewrite(Operation *op, 175 PatternRewriter &rewriter) const override { 176 if (op->hasAttr("eligible")) 177 return failure(); 178 rewriter.modifyOpInPlace( 179 op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); }); 180 return success(); 181 } 182 }; 183 184 /// This pattern hoists eligible ops out of a "test.one_region_op". 185 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> { 186 using OpRewritePattern<test::OneRegionOp>::OpRewritePattern; 187 188 LogicalResult matchAndRewrite(test::OneRegionOp op, 189 PatternRewriter &rewriter) const override { 190 Operation *terminator = op.getRegion().front().getTerminator(); 191 Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp(); 192 if (toBeHoisted->getParentOp() != op) 193 return failure(); 194 if (!toBeHoisted->hasAttr("eligible")) 195 return failure(); 196 rewriter.moveOpBefore(toBeHoisted, op); 197 return success(); 198 } 199 }; 200 201 /// This pattern moves "test.move_before_parent_op" before the parent op. 
202 struct MoveBeforeParentOp : public RewritePattern { 203 MoveBeforeParentOp(MLIRContext *context) 204 : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {} 205 206 LogicalResult matchAndRewrite(Operation *op, 207 PatternRewriter &rewriter) const override { 208 // Do not hoist past functions. 209 if (isa<FunctionOpInterface>(op->getParentOp())) 210 return failure(); 211 rewriter.moveOpBefore(op, op->getParentOp()); 212 return success(); 213 } 214 }; 215 216 /// This pattern inlines blocks that are nested in 217 /// "test.inline_blocks_into_parent" into the parent block. 218 struct InlineBlocksIntoParent : public RewritePattern { 219 InlineBlocksIntoParent(MLIRContext *context) 220 : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1, 221 context) {} 222 223 LogicalResult matchAndRewrite(Operation *op, 224 PatternRewriter &rewriter) const override { 225 bool changed = false; 226 for (Region &r : op->getRegions()) { 227 while (!r.empty()) { 228 rewriter.inlineBlockBefore(&r.front(), op); 229 changed = true; 230 } 231 } 232 return success(changed); 233 } 234 }; 235 236 /// This pattern splits blocks at "test.split_block_here" and replaces the op 237 /// with a new op (to prevent an infinite loop of block splitting). 238 struct SplitBlockHere : public RewritePattern { 239 SplitBlockHere(MLIRContext *context) 240 : RewritePattern("test.split_block_here", /*benefit=*/1, context) {} 241 242 LogicalResult matchAndRewrite(Operation *op, 243 PatternRewriter &rewriter) const override { 244 rewriter.splitBlock(op->getBlock(), op->getIterator()); 245 Operation *newOp = rewriter.create( 246 op->getLoc(), 247 OperationName("test.new_op", op->getContext()).getIdentifier(), 248 op->getOperands(), op->getResultTypes()); 249 rewriter.replaceOp(op, newOp); 250 return success(); 251 } 252 }; 253 254 /// This pattern clones "test.clone_me" ops. 
255 struct CloneOp : public RewritePattern { 256 CloneOp(MLIRContext *context) 257 : RewritePattern("test.clone_me", /*benefit=*/1, context) {} 258 259 LogicalResult matchAndRewrite(Operation *op, 260 PatternRewriter &rewriter) const override { 261 // Do not clone already cloned ops to avoid going into an infinite loop. 262 if (op->hasAttr("was_cloned")) 263 return failure(); 264 Operation *cloned = rewriter.clone(*op); 265 cloned->setAttr("was_cloned", rewriter.getUnitAttr()); 266 return success(); 267 } 268 }; 269 270 /// This pattern clones regions of "test.clone_region_before" ops before the 271 /// parent block. 272 struct CloneRegionBeforeOp : public RewritePattern { 273 CloneRegionBeforeOp(MLIRContext *context) 274 : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {} 275 276 LogicalResult matchAndRewrite(Operation *op, 277 PatternRewriter &rewriter) const override { 278 // Do not clone already cloned ops to avoid going into an infinite loop. 279 if (op->hasAttr("was_cloned")) 280 return failure(); 281 for (Region &r : op->getRegions()) 282 rewriter.cloneRegionBefore(r, op->getBlock()); 283 op->setAttr("was_cloned", rewriter.getUnitAttr()); 284 return success(); 285 } 286 }; 287 288 struct TestPatternDriver 289 : public PassWrapper<TestPatternDriver, OperationPass<>> { 290 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver) 291 292 TestPatternDriver() = default; 293 TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {} 294 295 StringRef getArgument() const final { return "test-patterns"; } 296 StringRef getDescription() const final { return "Run test dialect patterns"; } 297 void runOnOperation() override { 298 mlir::RewritePatternSet patterns(&getContext()); 299 populateWithGenerated(patterns); 300 301 // Verify named pattern is generated with expected name. 
302 patterns.add<FoldingPattern, TestNamedPatternRule, 303 FolderInsertBeforePreviouslyFoldedConstantPattern, 304 FolderCommutativeOp2WithConstant, HoistEligibleOps, 305 MakeOpEligible>(&getContext()); 306 307 // Additional patterns for testing the GreedyPatternRewriteDriver. 308 patterns.insert<IncrementIntAttribute<3>>(&getContext()); 309 310 GreedyRewriteConfig config; 311 config.useTopDownTraversal = this->useTopDownTraversal; 312 config.maxIterations = this->maxIterations; 313 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), 314 config); 315 } 316 317 Option<bool> useTopDownTraversal{ 318 *this, "top-down", 319 llvm::cl::desc("Seed the worklist in general top-down order"), 320 llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)}; 321 Option<int> maxIterations{ 322 *this, "max-iterations", 323 llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"), 324 llvm::cl::init(GreedyRewriteConfig().maxIterations)}; 325 }; 326 327 struct DumpNotifications : public RewriterBase::Listener { 328 void notifyBlockInserted(Block *block, Region *previous, 329 Region::iterator previousIt) override { 330 llvm::outs() << "notifyBlockInserted into " 331 << block->getParentOp()->getName() << ": "; 332 if (previous == nullptr) { 333 llvm::outs() << "was unlinked\n"; 334 } else { 335 llvm::outs() << "was linked\n"; 336 } 337 } 338 void notifyOperationInserted(Operation *op, 339 OpBuilder::InsertPoint previous) override { 340 llvm::outs() << "notifyOperationInserted: " << op->getName(); 341 if (!previous.isSet()) { 342 llvm::outs() << ", was unlinked\n"; 343 } else { 344 if (previous.getPoint() == previous.getBlock()->end()) { 345 llvm::outs() << ", was last in block\n"; 346 } else { 347 llvm::outs() << ", previous = " << previous.getPoint()->getName() 348 << "\n"; 349 } 350 } 351 } 352 void notifyOperationErased(Operation *op) override { 353 llvm::outs() << "notifyOperationErased: " << op->getName() << "\n"; 354 } 355 }; 356 357 struct 
TestStrictPatternDriver 358 : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> { 359 public: 360 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver) 361 362 TestStrictPatternDriver() = default; 363 TestStrictPatternDriver(const TestStrictPatternDriver &other) 364 : PassWrapper(other) { 365 strictMode = other.strictMode; 366 } 367 368 StringRef getArgument() const final { return "test-strict-pattern-driver"; } 369 StringRef getDescription() const final { 370 return "Test strict mode of pattern driver"; 371 } 372 373 void runOnOperation() override { 374 MLIRContext *ctx = &getContext(); 375 mlir::RewritePatternSet patterns(ctx); 376 patterns.add< 377 // clang-format off 378 ChangeBlockOp, 379 CloneOp, 380 CloneRegionBeforeOp, 381 EraseOp, 382 ImplicitChangeOp, 383 InlineBlocksIntoParent, 384 InsertSameOp, 385 MoveBeforeParentOp, 386 ReplaceWithNewOp, 387 SplitBlockHere 388 // clang-format on 389 >(ctx); 390 SmallVector<Operation *> ops; 391 getOperation()->walk([&](Operation *op) { 392 StringRef opName = op->getName().getStringRef(); 393 if (opName == "test.insert_same_op" || opName == "test.change_block_op" || 394 opName == "test.replace_with_new_op" || opName == "test.erase_op" || 395 opName == "test.move_before_parent_op" || 396 opName == "test.inline_blocks_into_parent" || 397 opName == "test.split_block_here" || opName == "test.clone_me" || 398 opName == "test.clone_region_before") { 399 ops.push_back(op); 400 } 401 }); 402 403 DumpNotifications dumpNotifications; 404 GreedyRewriteConfig config; 405 config.listener = &dumpNotifications; 406 if (strictMode == "AnyOp") { 407 config.strictMode = GreedyRewriteStrictness::AnyOp; 408 } else if (strictMode == "ExistingAndNewOps") { 409 config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; 410 } else if (strictMode == "ExistingOps") { 411 config.strictMode = GreedyRewriteStrictness::ExistingOps; 412 } else { 413 llvm_unreachable("invalid strictness option"); 414 } 415 
416 // Check if these transformations introduce visiting of operations that 417 // are not in the `ops` set (The new created ops are valid). An invalid 418 // operation will trigger the assertion while processing. 419 bool changed = false; 420 bool allErased = false; 421 (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config, 422 &changed, &allErased); 423 Builder b(ctx); 424 getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed)); 425 getOperation()->setAttr("pattern_driver_all_erased", 426 b.getBoolAttr(allErased)); 427 } 428 429 Option<std::string> strictMode{ 430 *this, "strictness", 431 llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"), 432 llvm::cl::init("AnyOp")}; 433 434 private: 435 // New inserted operation is valid for further transformation. 436 class InsertSameOp : public RewritePattern { 437 public: 438 InsertSameOp(MLIRContext *context) 439 : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {} 440 441 LogicalResult matchAndRewrite(Operation *op, 442 PatternRewriter &rewriter) const override { 443 if (op->hasAttr("skip")) 444 return failure(); 445 446 Operation *newOp = 447 rewriter.create(op->getLoc(), op->getName().getIdentifier(), 448 op->getOperands(), op->getResultTypes()); 449 rewriter.modifyOpInPlace( 450 op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); }); 451 newOp->setAttr("skip", rewriter.getBoolAttr(true)); 452 453 return success(); 454 } 455 }; 456 457 // Replace an operation may introduce the re-visiting of its users. 
458 class ReplaceWithNewOp : public RewritePattern { 459 public: 460 ReplaceWithNewOp(MLIRContext *context) 461 : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {} 462 463 LogicalResult matchAndRewrite(Operation *op, 464 PatternRewriter &rewriter) const override { 465 Operation *newOp; 466 if (op->hasAttr("create_erase_op")) { 467 newOp = rewriter.create( 468 op->getLoc(), 469 OperationName("test.erase_op", op->getContext()).getIdentifier(), 470 ValueRange(), TypeRange()); 471 } else { 472 newOp = rewriter.create( 473 op->getLoc(), 474 OperationName("test.new_op", op->getContext()).getIdentifier(), 475 op->getOperands(), op->getResultTypes()); 476 } 477 rewriter.replaceOp(op, newOp->getResults()); 478 return success(); 479 } 480 }; 481 482 // Remove an operation may introduce the re-visiting of its operands. 483 class EraseOp : public RewritePattern { 484 public: 485 EraseOp(MLIRContext *context) 486 : RewritePattern("test.erase_op", /*benefit=*/1, context) {} 487 LogicalResult matchAndRewrite(Operation *op, 488 PatternRewriter &rewriter) const override { 489 rewriter.eraseOp(op); 490 return success(); 491 } 492 }; 493 494 // The following two patterns test RewriterBase::replaceAllUsesWith. 495 // 496 // That function replaces all usages of a Block (or a Value) with another one 497 // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver 498 // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its 499 // worklist: when an op is modified, it is added to the worklist. The two 500 // patterns below make the tracking observable: ChangeBlockOp replaces all 501 // usages of a block and that pattern is applied because the corresponding ops 502 // are put on the initial worklist (see above). 
ImplicitChangeOp does an 503 // unrelated change but ops of the corresponding type are *not* on the initial 504 // worklist, so the effect of the second pattern is only visible if the 505 // tracking and subsequent adding to the worklist actually works. 506 507 // Replace all usages of the first successor with the second successor. 508 class ChangeBlockOp : public RewritePattern { 509 public: 510 ChangeBlockOp(MLIRContext *context) 511 : RewritePattern("test.change_block_op", /*benefit=*/1, context) {} 512 LogicalResult matchAndRewrite(Operation *op, 513 PatternRewriter &rewriter) const override { 514 if (op->getNumSuccessors() < 2) 515 return failure(); 516 Block *firstSuccessor = op->getSuccessor(0); 517 Block *secondSuccessor = op->getSuccessor(1); 518 if (firstSuccessor == secondSuccessor) 519 return failure(); 520 // This is the function being tested: 521 rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor); 522 // Using the following line instead would make the test fail: 523 // firstSuccessor->replaceAllUsesWith(secondSuccessor); 524 return success(); 525 } 526 }; 527 528 // Changes the successor to the parent block. 529 class ImplicitChangeOp : public RewritePattern { 530 public: 531 ImplicitChangeOp(MLIRContext *context) 532 : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {} 533 LogicalResult matchAndRewrite(Operation *op, 534 PatternRewriter &rewriter) const override { 535 if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock()) 536 return failure(); 537 rewriter.modifyOpInPlace(op, 538 [&]() { op->setSuccessor(op->getBlock(), 0); }); 539 return success(); 540 } 541 }; 542 }; 543 544 } // namespace 545 546 //===----------------------------------------------------------------------===// 547 // ReturnType Driver. 548 //===----------------------------------------------------------------------===// 549 550 namespace { 551 // Generate ops for each instance where the type can be successfully inferred. 
552 template <typename OpTy> 553 static void invokeCreateWithInferredReturnType(Operation *op) { 554 auto *context = op->getContext(); 555 auto fop = op->getParentOfType<func::FuncOp>(); 556 auto location = UnknownLoc::get(context); 557 OpBuilder b(op); 558 b.setInsertionPointAfter(op); 559 560 // Use permutations of 2 args as operands. 561 assert(fop.getNumArguments() >= 2); 562 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 563 for (int j = 0; j < e; ++j) { 564 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 565 SmallVector<Type, 2> inferredReturnTypes; 566 if (succeeded(OpTy::inferReturnTypes( 567 context, std::nullopt, values, op->getDiscardableAttrDictionary(), 568 op->getPropertiesStorage(), op->getRegions(), 569 inferredReturnTypes))) { 570 OperationState state(location, OpTy::getOperationName()); 571 // TODO: Expand to regions. 572 OpTy::build(b, state, values, op->getAttrs()); 573 (void)b.create(state); 574 } 575 } 576 } 577 } 578 579 static void reifyReturnShape(Operation *op) { 580 OpBuilder b(op); 581 582 // Use permutations of 2 args as operands. 
583 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 584 SmallVector<Value, 2> shapes; 585 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 586 !llvm::hasSingleElement(shapes)) 587 return; 588 for (const auto &it : llvm::enumerate(shapes)) { 589 op->emitRemark() << "value " << it.index() << ": " 590 << it.value().getDefiningOp(); 591 } 592 } 593 594 struct TestReturnTypeDriver 595 : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> { 596 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver) 597 598 void getDependentDialects(DialectRegistry ®istry) const override { 599 registry.insert<tensor::TensorDialect>(); 600 } 601 StringRef getArgument() const final { return "test-return-type"; } 602 StringRef getDescription() const final { return "Run return type functions"; } 603 604 void runOnOperation() override { 605 if (getOperation().getName() == "testCreateFunctions") { 606 std::vector<Operation *> ops; 607 // Collect ops to avoid triggering on inserted ops. 608 for (auto &op : getOperation().getBody().front()) 609 ops.push_back(&op); 610 // Generate test patterns for each, but skip terminator. 611 for (auto *op : llvm::ArrayRef(ops).drop_back()) { 612 // Test create method of each of the Op classes below. The resultant 613 // output would be in reverse order underneath `op` from which 614 // the attributes and regions are used. 615 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 616 invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>( 617 op); 618 invokeCreateWithInferredReturnType< 619 OpWithShapedTypeInferTypeInterfaceOp>(op); 620 }; 621 return; 622 } 623 if (getOperation().getName() == "testReifyFunctions") { 624 std::vector<Operation *> ops; 625 // Collect ops to avoid triggering on inserted ops. 
626 for (auto &op : getOperation().getBody().front()) 627 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 628 ops.push_back(&op); 629 // Generate test patterns for each, but skip terminator. 630 for (auto *op : ops) 631 reifyReturnShape(op); 632 } 633 } 634 }; 635 } // namespace 636 637 namespace { 638 struct TestDerivedAttributeDriver 639 : public PassWrapper<TestDerivedAttributeDriver, 640 OperationPass<func::FuncOp>> { 641 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver) 642 643 StringRef getArgument() const final { return "test-derived-attr"; } 644 StringRef getDescription() const final { 645 return "Run test derived attributes"; 646 } 647 void runOnOperation() override; 648 }; 649 } // namespace 650 651 void TestDerivedAttributeDriver::runOnOperation() { 652 getOperation().walk([](DerivedAttributeOpInterface dOp) { 653 auto dAttr = dOp.materializeDerivedAttributes(); 654 if (!dAttr) 655 return; 656 for (auto d : dAttr) 657 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); 658 }); 659 } 660 661 //===----------------------------------------------------------------------===// 662 // Legalization Driver. 663 //===----------------------------------------------------------------------===// 664 665 namespace { 666 //===----------------------------------------------------------------------===// 667 // Region-Block Rewrite Testing 668 669 /// This pattern is a simple pattern that inlines the first region of a given 670 /// operation into the parent region. 671 struct TestRegionRewriteBlockMovement : public ConversionPattern { 672 TestRegionRewriteBlockMovement(MLIRContext *ctx) 673 : ConversionPattern("test.region", 1, ctx) {} 674 675 LogicalResult 676 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 677 ConversionPatternRewriter &rewriter) const final { 678 // Inline this region into the parent region. 
679 auto &parentRegion = *op->getParentRegion(); 680 auto &opRegion = op->getRegion(0); 681 if (op->getDiscardableAttr("legalizer.should_clone")) 682 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 683 else 684 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 685 686 if (op->getDiscardableAttr("legalizer.erase_old_blocks")) { 687 while (!opRegion.empty()) 688 rewriter.eraseBlock(&opRegion.front()); 689 } 690 691 // Drop this operation. 692 rewriter.eraseOp(op); 693 return success(); 694 } 695 }; 696 /// This pattern is a simple pattern that generates a region containing an 697 /// illegal operation. 698 struct TestRegionRewriteUndo : public RewritePattern { 699 TestRegionRewriteUndo(MLIRContext *ctx) 700 : RewritePattern("test.region_builder", 1, ctx) {} 701 702 LogicalResult matchAndRewrite(Operation *op, 703 PatternRewriter &rewriter) const final { 704 // Create the region operation with an entry block containing arguments. 705 OperationState newRegion(op->getLoc(), "test.region"); 706 newRegion.addRegion(); 707 auto *regionOp = rewriter.create(newRegion); 708 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 709 entryBlock->addArgument(rewriter.getIntegerType(64), 710 rewriter.getUnknownLoc()); 711 712 // Add an explicitly illegal operation to ensure the conversion fails. 713 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 714 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 715 716 // Drop this operation. 717 rewriter.eraseOp(op); 718 return success(); 719 } 720 }; 721 /// A simple pattern that creates a block at the end of the parent region of the 722 /// matched operation. 
723 struct TestCreateBlock : public RewritePattern { 724 TestCreateBlock(MLIRContext *ctx) 725 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 726 727 LogicalResult matchAndRewrite(Operation *op, 728 PatternRewriter &rewriter) const final { 729 Region ®ion = *op->getParentRegion(); 730 Type i32Type = rewriter.getIntegerType(32); 731 Location loc = op->getLoc(); 732 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 733 rewriter.create<TerminatorOp>(loc); 734 rewriter.eraseOp(op); 735 return success(); 736 } 737 }; 738 739 /// A simple pattern that creates a block containing an invalid operation in 740 /// order to trigger the block creation undo mechanism. 741 struct TestCreateIllegalBlock : public RewritePattern { 742 TestCreateIllegalBlock(MLIRContext *ctx) 743 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 744 745 LogicalResult matchAndRewrite(Operation *op, 746 PatternRewriter &rewriter) const final { 747 Region ®ion = *op->getParentRegion(); 748 Type i32Type = rewriter.getIntegerType(32); 749 Location loc = op->getLoc(); 750 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 751 // Create an illegal op to ensure the conversion fails. 752 rewriter.create<ILLegalOpF>(loc, i32Type); 753 rewriter.create<TerminatorOp>(loc); 754 rewriter.eraseOp(op); 755 return success(); 756 } 757 }; 758 759 /// A simple pattern that tests the undo mechanism when replacing the uses of a 760 /// block argument. 
761 struct TestUndoBlockArgReplace : public ConversionPattern { 762 TestUndoBlockArgReplace(MLIRContext *ctx) 763 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 764 765 LogicalResult 766 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 767 ConversionPatternRewriter &rewriter) const final { 768 auto illegalOp = 769 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 770 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 771 illegalOp->getResult(0)); 772 rewriter.modifyOpInPlace(op, [] {}); 773 return success(); 774 } 775 }; 776 777 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion. 778 /// This is to test the rollback logic. 779 struct TestUndoMoveOpBefore : public ConversionPattern { 780 TestUndoMoveOpBefore(MLIRContext *ctx) 781 : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {} 782 783 LogicalResult 784 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 785 ConversionPatternRewriter &rewriter) const override { 786 rewriter.moveOpBefore(op, op->getParentOp()); 787 // Replace with an illegal op to ensure the conversion fails. 788 rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type()); 789 return success(); 790 } 791 }; 792 793 /// A rewrite pattern that tests the undo mechanism when erasing a block. 
794 struct TestUndoBlockErase : public ConversionPattern { 795 TestUndoBlockErase(MLIRContext *ctx) 796 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 797 798 LogicalResult 799 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 800 ConversionPatternRewriter &rewriter) const final { 801 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 802 rewriter.setInsertionPointToStart(secondBlock); 803 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 804 rewriter.eraseBlock(secondBlock); 805 rewriter.modifyOpInPlace(op, [] {}); 806 return success(); 807 } 808 }; 809 810 //===----------------------------------------------------------------------===// 811 // Type-Conversion Rewrite Testing 812 813 /// This patterns erases a region operation that has had a type conversion. 814 struct TestDropOpSignatureConversion : public ConversionPattern { 815 TestDropOpSignatureConversion(MLIRContext *ctx, 816 const TypeConverter &converter) 817 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 818 LogicalResult 819 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 820 ConversionPatternRewriter &rewriter) const override { 821 Region ®ion = op->getRegion(0); 822 Block *entry = ®ion.front(); 823 824 // Convert the original entry arguments. 825 const TypeConverter &converter = *getTypeConverter(); 826 TypeConverter::SignatureConversion result(entry->getNumArguments()); 827 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 828 result)) || 829 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 830 return failure(); 831 832 // Convert the region signature and just drop the operation. 833 rewriter.eraseOp(op); 834 return success(); 835 } 836 }; 837 /// This pattern simply updates the operands of the given operation. 
838 struct TestPassthroughInvalidOp : public ConversionPattern { 839 TestPassthroughInvalidOp(MLIRContext *ctx) 840 : ConversionPattern("test.invalid", 1, ctx) {} 841 LogicalResult 842 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 843 ConversionPatternRewriter &rewriter) const final { 844 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, operands, 845 std::nullopt); 846 return success(); 847 } 848 }; 849 /// This pattern handles the case of a split return value. 850 struct TestSplitReturnType : public ConversionPattern { 851 TestSplitReturnType(MLIRContext *ctx) 852 : ConversionPattern("test.return", 1, ctx) {} 853 LogicalResult 854 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 855 ConversionPatternRewriter &rewriter) const final { 856 // Check for a return of F32. 857 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 858 return failure(); 859 860 // Check if the first operation is a cast operation, if it is we use the 861 // results directly. 862 auto *defOp = operands[0].getDefiningOp(); 863 if (auto packerOp = 864 llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) { 865 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 866 return success(); 867 } 868 869 // Otherwise, fail to match. 870 return failure(); 871 } 872 }; 873 874 //===----------------------------------------------------------------------===// 875 // Multi-Level Type-Conversion Rewrite Testing 876 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 877 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 878 : ConversionPattern("test.type_producer", 1, ctx) {} 879 LogicalResult 880 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 881 ConversionPatternRewriter &rewriter) const final { 882 // If the type is I32, change the type to F32. 
883 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 884 return failure(); 885 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 886 return success(); 887 } 888 }; 889 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 890 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 891 : ConversionPattern("test.type_producer", 1, ctx) {} 892 LogicalResult 893 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 894 ConversionPatternRewriter &rewriter) const final { 895 // If the type is F32, change the type to F64. 896 if (!Type(*op->result_type_begin()).isF32()) 897 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 898 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 899 return success(); 900 } 901 }; 902 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 903 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 904 : ConversionPattern("test.type_producer", 10, ctx) {} 905 LogicalResult 906 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 907 ConversionPatternRewriter &rewriter) const final { 908 // Always convert to B16, even though it is not a legal type. This tests 909 // that values are unmapped correctly. 910 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 911 return success(); 912 } 913 }; 914 struct TestUpdateConsumerType : public ConversionPattern { 915 TestUpdateConsumerType(MLIRContext *ctx) 916 : ConversionPattern("test.type_consumer", 1, ctx) {} 917 LogicalResult 918 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 919 ConversionPatternRewriter &rewriter) const final { 920 // Verify that the incoming operand has been successfully remapped to F64. 
921 if (!operands[0].getType().isF64()) 922 return failure(); 923 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 924 return success(); 925 } 926 }; 927 928 //===----------------------------------------------------------------------===// 929 // Non-Root Replacement Rewrite Testing 930 /// This pattern generates an invalid operation, but replaces it before the 931 /// pattern is finished. This checks that we don't need to legalize the 932 /// temporary op. 933 struct TestNonRootReplacement : public RewritePattern { 934 TestNonRootReplacement(MLIRContext *ctx) 935 : RewritePattern("test.replace_non_root", 1, ctx) {} 936 937 LogicalResult matchAndRewrite(Operation *op, 938 PatternRewriter &rewriter) const final { 939 auto resultType = *op->result_type_begin(); 940 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 941 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 942 943 rewriter.replaceOp(illegalOp, legalOp); 944 rewriter.replaceOp(op, illegalOp); 945 return success(); 946 } 947 }; 948 949 //===----------------------------------------------------------------------===// 950 // Recursive Rewrite Testing 951 /// This pattern is applied to the same operation multiple times, but has a 952 /// bounded recursion. 953 struct TestBoundedRecursiveRewrite 954 : public OpRewritePattern<TestRecursiveRewriteOp> { 955 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 956 957 void initialize() { 958 // The conversion target handles bounding the recursion of this pattern. 959 setHasBoundedRewriteRecursion(); 960 } 961 962 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 963 PatternRewriter &rewriter) const final { 964 // Decrement the depth of the op in-place. 
965 rewriter.modifyOpInPlace(op, [&] { 966 op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); 967 }); 968 return success(); 969 } 970 }; 971 972 struct TestNestedOpCreationUndoRewrite 973 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 974 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 975 976 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 977 PatternRewriter &rewriter) const final { 978 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 979 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 980 return success(); 981 }; 982 }; 983 984 // This pattern matches `test.blackhole` and delete this op and its producer. 985 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 986 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 987 988 LogicalResult matchAndRewrite(BlackHoleOp op, 989 PatternRewriter &rewriter) const final { 990 Operation *producer = op.getOperand().getDefiningOp(); 991 // Always erase the user before the producer, the framework should handle 992 // this correctly. 993 rewriter.eraseOp(op); 994 rewriter.eraseOp(producer); 995 return success(); 996 }; 997 }; 998 999 // This pattern replaces explicitly illegal op with explicitly legal op, 1000 // but in addition creates unregistered operation. 
1001 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { 1002 using OpRewritePattern<ILLegalOpG>::OpRewritePattern; 1003 1004 LogicalResult matchAndRewrite(ILLegalOpG op, 1005 PatternRewriter &rewriter) const final { 1006 IntegerAttr attr = rewriter.getI32IntegerAttr(0); 1007 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); 1008 rewriter.replaceOpWithNewOp<LegalOpC>(op, val); 1009 return success(); 1010 }; 1011 }; 1012 } // namespace 1013 1014 namespace { 1015 struct TestTypeConverter : public TypeConverter { 1016 using TypeConverter::TypeConverter; 1017 TestTypeConverter() { 1018 addConversion(convertType); 1019 addArgumentMaterialization(materializeCast); 1020 addSourceMaterialization(materializeCast); 1021 } 1022 1023 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 1024 // Drop I16 types. 1025 if (t.isSignlessInteger(16)) 1026 return success(); 1027 1028 // Convert I64 to F64. 1029 if (t.isSignlessInteger(64)) { 1030 results.push_back(FloatType::getF64(t.getContext())); 1031 return success(); 1032 } 1033 1034 // Convert I42 to I43. 1035 if (t.isInteger(42)) { 1036 results.push_back(IntegerType::get(t.getContext(), 43)); 1037 return success(); 1038 } 1039 1040 // Split F32 into F16,F16. 1041 if (t.isF32()) { 1042 results.assign(2, FloatType::getF16(t.getContext())); 1043 return success(); 1044 } 1045 1046 // Otherwise, convert the type directly. 1047 results.push_back(t); 1048 return success(); 1049 } 1050 1051 /// Hook for materializing a conversion. This is necessary because we generate 1052 /// 1->N type mappings. 
1053 static std::optional<Value> materializeCast(OpBuilder &builder, 1054 Type resultType, 1055 ValueRange inputs, Location loc) { 1056 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1057 } 1058 }; 1059 1060 struct TestLegalizePatternDriver 1061 : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> { 1062 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) 1063 1064 StringRef getArgument() const final { return "test-legalize-patterns"; } 1065 StringRef getDescription() const final { 1066 return "Run test dialect legalization patterns"; 1067 } 1068 /// The mode of conversion to use with the driver. 1069 enum class ConversionMode { Analysis, Full, Partial }; 1070 1071 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 1072 1073 void getDependentDialects(DialectRegistry ®istry) const override { 1074 registry.insert<func::FuncDialect, test::TestDialect>(); 1075 } 1076 1077 void runOnOperation() override { 1078 TestTypeConverter converter; 1079 mlir::RewritePatternSet patterns(&getContext()); 1080 populateWithGenerated(patterns); 1081 patterns 1082 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 1083 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 1084 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 1085 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 1086 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 1087 TestNonRootReplacement, TestBoundedRecursiveRewrite, 1088 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, 1089 TestCreateUnregisteredOp, TestUndoMoveOpBefore>(&getContext()); 1090 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 1091 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1092 converter); 1093 mlir::populateCallOpTypeConversionPattern(patterns, converter); 1094 1095 // Define the conversion target used for the test. 
1096 ConversionTarget target(getContext()); 1097 target.addLegalOp<ModuleOp>(); 1098 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, 1099 TerminatorOp, OneRegionOp>(); 1100 target 1101 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 1102 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 1103 // Don't allow F32 operands. 1104 return llvm::none_of(op.getOperandTypes(), 1105 [](Type type) { return type.isF32(); }); 1106 }); 1107 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1108 return converter.isSignatureLegal(op.getFunctionType()) && 1109 converter.isLegal(&op.getBody()); 1110 }); 1111 target.addDynamicallyLegalOp<func::CallOp>( 1112 [&](func::CallOp op) { return converter.isLegal(op); }); 1113 1114 // TestCreateUnregisteredOp creates `arith.constant` operation, 1115 // which was not added to target intentionally to test 1116 // correct error code from conversion driver. 1117 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); 1118 1119 // Expect the type_producer/type_consumer operations to only operate on f64. 1120 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1121 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1122 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1123 return op.getOperand().getType().isF64(); 1124 }); 1125 1126 // Check support for marking certain operations as recursively legal. 1127 target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) { 1128 return static_cast<bool>( 1129 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 1130 }); 1131 1132 // Mark the bound recursion operation as dynamically legal. 1133 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 1134 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); 1135 1136 // Handle a partial conversion. 
1137 if (mode == ConversionMode::Partial) { 1138 DenseSet<Operation *> unlegalizedOps; 1139 if (failed(applyPartialConversion( 1140 getOperation(), target, std::move(patterns), &unlegalizedOps))) { 1141 getOperation()->emitRemark() << "applyPartialConversion failed"; 1142 } 1143 // Emit remarks for each legalizable operation. 1144 for (auto *op : unlegalizedOps) 1145 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1146 return; 1147 } 1148 1149 // Handle a full conversion. 1150 if (mode == ConversionMode::Full) { 1151 // Check support for marking unknown operations as dynamically legal. 1152 target.markUnknownOpDynamicallyLegal([](Operation *op) { 1153 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 1154 }); 1155 1156 if (failed(applyFullConversion(getOperation(), target, 1157 std::move(patterns)))) { 1158 getOperation()->emitRemark() << "applyFullConversion failed"; 1159 } 1160 return; 1161 } 1162 1163 // Otherwise, handle an analysis conversion. 1164 assert(mode == ConversionMode::Analysis); 1165 1166 // Analyze the convertible operations. 1167 DenseSet<Operation *> legalizedOps; 1168 if (failed(applyAnalysisConversion(getOperation(), target, 1169 std::move(patterns), legalizedOps))) 1170 return signalPassFailure(); 1171 1172 // Emit remarks for each legalizable operation. 1173 for (auto *op : legalizedOps) 1174 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 1175 } 1176 1177 /// The mode of conversion to use. 
1178 ConversionMode mode; 1179 }; 1180 } // namespace 1181 1182 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 1183 legalizerConversionMode( 1184 "test-legalize-mode", 1185 llvm::cl::desc("The legalization mode to use with the test driver"), 1186 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 1187 llvm::cl::values( 1188 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 1189 "analysis", "Perform an analysis conversion"), 1190 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 1191 "Perform a full conversion"), 1192 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 1193 "partial", "Perform a partial conversion"))); 1194 1195 //===----------------------------------------------------------------------===// 1196 // ConversionPatternRewriter::getRemappedValue testing. This method is used 1197 // to get the remapped value of an original value that was replaced using 1198 // ConversionPatternRewriter. 1199 namespace { 1200 struct TestRemapValueTypeConverter : public TypeConverter { 1201 using TypeConverter::TypeConverter; 1202 1203 TestRemapValueTypeConverter() { 1204 addConversion( 1205 [](Float32Type type) { return Float64Type::get(type.getContext()); }); 1206 addConversion([](Type type) { return type; }); 1207 } 1208 }; 1209 1210 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 1211 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 1212 /// operand twice. 
1213 /// 1214 /// Example: 1215 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 1216 /// is replaced with: 1217 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 1218 struct OneVResOneVOperandOp1Converter 1219 : public OpConversionPattern<OneVResOneVOperandOp1> { 1220 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 1221 1222 LogicalResult 1223 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, 1224 ConversionPatternRewriter &rewriter) const override { 1225 auto origOps = op.getOperands(); 1226 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 1227 "One operand expected"); 1228 Value origOp = *origOps.begin(); 1229 SmallVector<Value, 2> remappedOperands; 1230 // Replicate the remapped original operand twice. Note that we don't used 1231 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 1232 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1233 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1234 1235 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 1236 remappedOperands); 1237 return success(); 1238 } 1239 }; 1240 1241 /// A rewriter pattern that tests that blocks can be merged. 1242 struct TestRemapValueInRegion 1243 : public OpConversionPattern<TestRemappedValueRegionOp> { 1244 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; 1245 1246 LogicalResult 1247 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, 1248 ConversionPatternRewriter &rewriter) const final { 1249 Block &block = op.getBody().front(); 1250 Operation *terminator = block.getTerminator(); 1251 1252 // Merge the block into the parent region. 
1253 Block *parentBlock = op->getBlock(); 1254 Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator()); 1255 rewriter.mergeBlocks(&block, parentBlock, ValueRange()); 1256 rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange()); 1257 1258 // Replace the results of this operation with the remapped terminator 1259 // values. 1260 SmallVector<Value> terminatorOperands; 1261 if (failed(rewriter.getRemappedValues(terminator->getOperands(), 1262 terminatorOperands))) 1263 return failure(); 1264 1265 rewriter.eraseOp(terminator); 1266 rewriter.replaceOp(op, terminatorOperands); 1267 return success(); 1268 } 1269 }; 1270 1271 struct TestRemappedValue 1272 : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> { 1273 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue) 1274 1275 StringRef getArgument() const final { return "test-remapped-value"; } 1276 StringRef getDescription() const final { 1277 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1278 } 1279 void runOnOperation() override { 1280 TestRemapValueTypeConverter typeConverter; 1281 1282 mlir::RewritePatternSet patterns(&getContext()); 1283 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 1284 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( 1285 &getContext()); 1286 patterns.add<TestRemapValueInRegion>(typeConverter, &getContext()); 1287 1288 mlir::ConversionTarget target(getContext()); 1289 target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>(); 1290 1291 // Expect the type_producer/type_consumer operations to only operate on f64. 1292 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1293 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1294 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1295 return op.getOperand().getType().isF64(); 1296 }); 1297 1298 // We make OneVResOneVOperandOp1 legal only when it has more that one 1299 // operand. 
This will trigger the conversion that will replace one-operand 1300 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 1301 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 1302 [](Operation *op) { return op->getNumOperands() > 1; }); 1303 1304 if (failed(mlir::applyFullConversion(getOperation(), target, 1305 std::move(patterns)))) { 1306 signalPassFailure(); 1307 } 1308 } 1309 }; 1310 } // namespace 1311 1312 //===----------------------------------------------------------------------===// 1313 // Test patterns without a specific root operation kind 1314 //===----------------------------------------------------------------------===// 1315 1316 namespace { 1317 /// This pattern matches and removes any operation in the test dialect. 1318 struct RemoveTestDialectOps : public RewritePattern { 1319 RemoveTestDialectOps(MLIRContext *context) 1320 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 1321 1322 LogicalResult matchAndRewrite(Operation *op, 1323 PatternRewriter &rewriter) const override { 1324 if (!isa<TestDialect>(op->getDialect())) 1325 return failure(); 1326 rewriter.eraseOp(op); 1327 return success(); 1328 } 1329 }; 1330 1331 struct TestUnknownRootOpDriver 1332 : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> { 1333 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver) 1334 1335 StringRef getArgument() const final { 1336 return "test-legalize-unknown-root-patterns"; 1337 } 1338 StringRef getDescription() const final { 1339 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1340 } 1341 void runOnOperation() override { 1342 mlir::RewritePatternSet patterns(&getContext()); 1343 patterns.add<RemoveTestDialectOps>(&getContext()); 1344 1345 mlir::ConversionTarget target(getContext()); 1346 target.addIllegalDialect<TestDialect>(); 1347 if (failed(applyPartialConversion(getOperation(), target, 1348 std::move(patterns)))) 1349 signalPassFailure(); 1350 } 1351 }; 1352 } // 
namespace 1353 1354 //===----------------------------------------------------------------------===// 1355 // Test patterns that uses operations and types defined at runtime 1356 //===----------------------------------------------------------------------===// 1357 1358 namespace { 1359 /// This pattern matches dynamic operations 'test.one_operand_two_results' and 1360 /// replace them with dynamic operations 'test.generic_dynamic_op'. 1361 struct RewriteDynamicOp : public RewritePattern { 1362 RewriteDynamicOp(MLIRContext *context) 1363 : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1, 1364 context) {} 1365 1366 LogicalResult matchAndRewrite(Operation *op, 1367 PatternRewriter &rewriter) const override { 1368 assert(op->getName().getStringRef() == 1369 "test.dynamic_one_operand_two_results" && 1370 "rewrite pattern should only match operations with the right name"); 1371 1372 OperationState state(op->getLoc(), "test.dynamic_generic", 1373 op->getOperands(), op->getResultTypes(), 1374 op->getAttrs()); 1375 auto *newOp = rewriter.create(state); 1376 rewriter.replaceOp(op, newOp->getResults()); 1377 return success(); 1378 } 1379 }; 1380 1381 struct TestRewriteDynamicOpDriver 1382 : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> { 1383 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver) 1384 1385 void getDependentDialects(DialectRegistry ®istry) const override { 1386 registry.insert<TestDialect>(); 1387 } 1388 StringRef getArgument() const final { return "test-rewrite-dynamic-op"; } 1389 StringRef getDescription() const final { 1390 return "Test rewritting on dynamic operations"; 1391 } 1392 void runOnOperation() override { 1393 RewritePatternSet patterns(&getContext()); 1394 patterns.add<RewriteDynamicOp>(&getContext()); 1395 1396 ConversionTarget target(getContext()); 1397 target.addIllegalOp( 1398 OperationName("test.dynamic_one_operand_two_results", &getContext())); 1399 
target.addLegalOp(OperationName("test.dynamic_generic", &getContext())); 1400 if (failed(applyPartialConversion(getOperation(), target, 1401 std::move(patterns)))) 1402 signalPassFailure(); 1403 } 1404 }; 1405 } // end anonymous namespace 1406 1407 //===----------------------------------------------------------------------===// 1408 // Test type conversions 1409 //===----------------------------------------------------------------------===// 1410 1411 namespace { 1412 struct TestTypeConversionProducer 1413 : public OpConversionPattern<TestTypeProducerOp> { 1414 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 1415 LogicalResult 1416 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, 1417 ConversionPatternRewriter &rewriter) const final { 1418 Type resultType = op.getType(); 1419 Type convertedType = getTypeConverter() 1420 ? getTypeConverter()->convertType(resultType) 1421 : resultType; 1422 if (isa<FloatType>(resultType)) 1423 resultType = rewriter.getF64Type(); 1424 else if (resultType.isInteger(16)) 1425 resultType = rewriter.getIntegerType(64); 1426 else if (isa<test::TestRecursiveType>(resultType) && 1427 convertedType != resultType) 1428 resultType = convertedType; 1429 else 1430 return failure(); 1431 1432 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 1433 return success(); 1434 } 1435 }; 1436 1437 /// Call signature conversion and then fail the rewrite to trigger the undo 1438 /// mechanism. 
1439 struct TestSignatureConversionUndo 1440 : public OpConversionPattern<TestSignatureConversionUndoOp> { 1441 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 1442 1443 LogicalResult 1444 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, 1445 ConversionPatternRewriter &rewriter) const final { 1446 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 1447 return failure(); 1448 } 1449 }; 1450 1451 /// Call signature conversion without providing a type converter to handle 1452 /// materializations. 1453 struct TestTestSignatureConversionNoConverter 1454 : public OpConversionPattern<TestSignatureConversionNoConverterOp> { 1455 TestTestSignatureConversionNoConverter(const TypeConverter &converter, 1456 MLIRContext *context) 1457 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), 1458 converter(converter) {} 1459 1460 LogicalResult 1461 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, 1462 ConversionPatternRewriter &rewriter) const final { 1463 Region ®ion = op->getRegion(0); 1464 Block *entry = ®ion.front(); 1465 1466 // Convert the original entry arguments. 1467 TypeConverter::SignatureConversion result(entry->getNumArguments()); 1468 if (failed( 1469 converter.convertSignatureArgs(entry->getArgumentTypes(), result))) 1470 return failure(); 1471 rewriter.modifyOpInPlace( 1472 op, [&] { rewriter.applySignatureConversion(®ion, result); }); 1473 return success(); 1474 } 1475 1476 const TypeConverter &converter; 1477 }; 1478 1479 /// Just forward the operands to the root op. This is essentially a no-op 1480 /// pattern that is used to trigger target materialization. 
struct TestTypeConsumerForward
    : public OpConversionPattern<TestTypeConsumerOp> {
  using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;

  /// Replaces the op's operands in place with the adaptor's remapped operands.
  LogicalResult
  matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.modifyOpInPlace(op,
                             [&] { op->setOperands(adaptor.getOperands()); });
    return success();
  }
};

/// Rewrites a `TestAnotherTypeProducerOp` into a `TestTypeProducerOp` with the
/// same result type.
struct TestTypeConversionAnotherProducer
    : public OpRewritePattern<TestAnotherTypeProducerOp> {
  using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
    return success();
  }
};

/// Pass exercising type conversions, source materializations and signature
/// conversions through a partial conversion.
struct TestTypeConversionDriver
    : public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<TestDialect>();
  }
  StringRef getArgument() const final {
    return "test-legalize-type-conversion";
  }
  StringRef getDescription() const final {
    return "Test various type conversion functionalities in DialectConversion";
  }

  void runOnOperation() override {
    // Initialize the type converter. `conversionCallStack` tracks the
    // recursive types currently being converted by the recursive-type
    // conversion below, so that self-references can be detected.
    SmallVector<Type, 2> conversionCallStack;
    TypeConverter converter;

    // Add the legal set of type conversions.
    // NOTE(review): TypeConverter is documented to try the most recently added
    // conversion first, so the registration order below likely matters; do
    // not reorder without checking.
    converter.addConversion([](Type type) -> Type {
      // Treat F64 as legal.
      if (type.isF64())
        return type;
      // Allow converting BF16/F16/F32 to F64.
      if (type.isBF16() || type.isF16() || type.isF32())
        return FloatType::getF64(type.getContext());
      // Otherwise, the type is illegal.
      return nullptr;
    });
    converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
      // Drop all integer types (success with no result types).
      return success();
    });
    converter.addConversion(
        // Convert a recursive self-referring type into a non-self-referring
        // type named "outer_converted_type" that contains a SimpleAType.
        [&](test::TestRecursiveType type,
            SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
          // If the type is already converted, return it to indicate that it is
          // legal.
          if (type.getName() == "outer_converted_type") {
            results.push_back(type);
            return success();
          }

          // Push the current type; the scope_exit pops it on every return path.
          conversionCallStack.push_back(type);
          auto popConversionCallStack = llvm::make_scope_exit(
              [&conversionCallStack]() { conversionCallStack.pop_back(); });

          // If the type is on the call stack more than once (it is there at
          // least once because of the _current_ call, which is always the last
          // element on the stack), we've hit the recursive case. Just return
          // SimpleAType here to create a non-recursive type as a result.
          if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(),
                                 type)) {
            results.push_back(test::SimpleAType::get(type.getContext()));
            return success();
          }

          // Convert the body recursively.
          auto result = test::TestRecursiveType::get(type.getContext(),
                                                     "outer_converted_type");
          if (failed(result.setBody(converter.convertType(type.getBody()))))
            return failure();
          results.push_back(result);
          return success();
        });

    // Add the legal set of type materializations.
    converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
                                          ValueRange inputs,
                                          Location loc) -> Value {
      // Allow casting a single F64 value to any non-F16 result type (e.g. back
      // to F32).
      if (!resultType.isF16() && inputs.size() == 1 &&
          inputs[0].getType().isF64())
        return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
      // Allow producing an i32 or i64 from nothing.
      if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
          inputs.empty())
        return builder.create<TestTypeProducerOp>(loc, resultType);
      // Allow casting a single integer value to any integer result type.
      if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
          isa<IntegerType>(inputs[0].getType()))
        return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
      // Otherwise, fail.
      return nullptr;
    });

    // Initialize the conversion target.
    mlir::ConversionTarget target(getContext());
    // Producers are legal when they yield f64, i64, or an already-converted
    // recursive type.
    target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
      auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
      return op.getType().isF64() || op.getType().isInteger(64) ||
             (recursiveType &&
              recursiveType.getName() == "outer_converted_type");
    });
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType()) &&
             converter.isLegal(&op.getBody());
    });
    target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
      // Allow casts from F64 to F32.
      return (*op.operand_type_begin()).isF64() && op.getType().isF32();
    });
    target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
        [&](TestSignatureConversionNoConverterOp op) {
          return converter.isLegal(op.getRegion().front().getArgumentTypes());
        });

    // Initialize the set of rewrite patterns.
    RewritePatternSet patterns(&getContext());
    patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
                 TestSignatureConversionUndo,
                 TestTestSignatureConversionNoConverter>(converter,
                                                         &getContext());
    patterns.add<TestTypeConversionAnotherProducer>(&getContext());
    mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                              converter);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Test Target Materialization With No Uses
//===----------------------------------------------------------------------===//

namespace {
/// Replaces a `TestTypeChangerOp` with the adaptor's (remapped) operands.
struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
  using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Pass that marks `TestTypeChangerOp` illegal, converts i16 to i64, and
/// materializes target conversions with `TestCastOp`.
struct TestTargetMaterializationWithNoUses
    : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestTargetMaterializationWithNoUses)

  StringRef getArgument() const final {
    return "test-target-materialization-with-no-uses";
  }
  StringRef getDescription() const final {
    return "Test a special case of target materialization in DialectConversion";
  }

  void runOnOperation() override {
    TypeConverter converter;
    // Identity fallback for every type.
    converter.addConversion([](Type t) { return t; });
    // i16 -> i64; other integer widths are left unchanged.
    converter.addConversion([](IntegerType intTy) -> Type {
      if (intTy.getWidth() == 16)
        return IntegerType::get(intTy.getContext(), 64);
      return intTy;
    });
    converter.addTargetMaterialization(
        [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
          return builder.create<TestCastOp>(loc, type, inputs).getResult();
        });

    ConversionTarget target(getContext());
    target.addIllegalOp<TestTypeChangerOp>();

    RewritePatternSet patterns(&getContext());
    patterns.add<ForwardOperandPattern>(converter, &getContext());

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Test Block Merging
//===----------------------------------------------------------------------===//

namespace {
/// A rewriter pattern that tests that blocks can be merged.
struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
  using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Block &firstBlock = op.getBody().front();
    Operation *branchOp = firstBlock.getTerminator();
    Block *secondBlock = &*(std::next(op.getBody().begin()));
    // The terminator's operands become the replacements for the second
    // block's arguments when it is merged into the first block.
    auto succOperands = branchOp->getOperands();
    SmallVector<Value, 2> replacements(succOperands);
    rewriter.eraseOp(branchOp);
    rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
    rewriter.modifyOpInPlace(op, [] {});
    return success();
  }
};

/// A rewrite pattern to test the undo mechanism of blocks being merged.
1711 struct TestUndoBlocksMerge : public ConversionPattern { 1712 TestUndoBlocksMerge(MLIRContext *ctx) 1713 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 1714 LogicalResult 1715 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1716 ConversionPatternRewriter &rewriter) const final { 1717 Block &firstBlock = op->getRegion(0).front(); 1718 Operation *branchOp = firstBlock.getTerminator(); 1719 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 1720 rewriter.setInsertionPointToStart(secondBlock); 1721 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 1722 auto succOperands = branchOp->getOperands(); 1723 SmallVector<Value, 2> replacements(succOperands); 1724 rewriter.eraseOp(branchOp); 1725 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1726 rewriter.modifyOpInPlace(op, [] {}); 1727 return success(); 1728 } 1729 }; 1730 1731 /// A rewrite mechanism to inline the body of the op into its parent, when both 1732 /// ops can have a single block. 
1733 struct TestMergeSingleBlockOps 1734 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 1735 using OpConversionPattern< 1736 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 1737 1738 LogicalResult 1739 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, 1740 ConversionPatternRewriter &rewriter) const final { 1741 SingleBlockImplicitTerminatorOp parentOp = 1742 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1743 if (!parentOp) 1744 return failure(); 1745 Block &innerBlock = op.getRegion().front(); 1746 TerminatorOp innerTerminator = 1747 cast<TerminatorOp>(innerBlock.getTerminator()); 1748 rewriter.inlineBlockBefore(&innerBlock, op); 1749 rewriter.eraseOp(innerTerminator); 1750 rewriter.eraseOp(op); 1751 rewriter.modifyOpInPlace(op, [] {}); 1752 return success(); 1753 } 1754 }; 1755 1756 struct TestMergeBlocksPatternDriver 1757 : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> { 1758 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver) 1759 1760 StringRef getArgument() const final { return "test-merge-blocks"; } 1761 StringRef getDescription() const final { 1762 return "Test Merging operation in ConversionPatternRewriter"; 1763 } 1764 void runOnOperation() override { 1765 MLIRContext *context = &getContext(); 1766 mlir::RewritePatternSet patterns(context); 1767 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 1768 context); 1769 ConversionTarget target(*context); 1770 target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1771 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1772 target.addIllegalOp<ILLegalOpF>(); 1773 1774 /// Expect the op to have a single block after legalization. 1775 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1776 [&](TestMergeBlocksOp op) -> bool { 1777 return llvm::hasSingleElement(op.getBody()); 1778 }); 1779 1780 /// Only allow `test.br` within test.merge_blocks op. 
1781 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1782 return op->getParentOfType<TestMergeBlocksOp>(); 1783 }); 1784 1785 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1786 /// inlined. 1787 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1788 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1789 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1790 }); 1791 1792 DenseSet<Operation *> unlegalizedOps; 1793 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1794 &unlegalizedOps); 1795 for (auto *op : unlegalizedOps) 1796 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1797 } 1798 }; 1799 } // namespace 1800 1801 //===----------------------------------------------------------------------===// 1802 // Test Selective Replacement 1803 //===----------------------------------------------------------------------===// 1804 1805 namespace { 1806 /// A rewrite mechanism to inline the body of the op into its parent, when both 1807 /// ops can have a single block. 1808 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1809 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1810 1811 LogicalResult matchAndRewrite(TestCastOp op, 1812 PatternRewriter &rewriter) const final { 1813 if (op.getNumOperands() != 2) 1814 return failure(); 1815 OperandRange operands = op.getOperands(); 1816 1817 // Replace non-terminator uses with the first operand. 1818 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1819 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1820 }); 1821 // Replace everything else with the second operand if the operation isn't 1822 // dead. 
1823 rewriter.replaceOp(op, op.getOperand(1)); 1824 return success(); 1825 } 1826 }; 1827 1828 struct TestSelectiveReplacementPatternDriver 1829 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1830 OperationPass<>> { 1831 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 1832 TestSelectiveReplacementPatternDriver) 1833 1834 StringRef getArgument() const final { 1835 return "test-pattern-selective-replacement"; 1836 } 1837 StringRef getDescription() const final { 1838 return "Test selective replacement in the PatternRewriter"; 1839 } 1840 void runOnOperation() override { 1841 MLIRContext *context = &getContext(); 1842 mlir::RewritePatternSet patterns(context); 1843 patterns.add<TestSelectiveOpReplacementPattern>(context); 1844 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 1845 } 1846 }; 1847 } // namespace 1848 1849 //===----------------------------------------------------------------------===// 1850 // PassRegistration 1851 //===----------------------------------------------------------------------===// 1852 1853 namespace mlir { 1854 namespace test { 1855 void registerPatternsTestPass() { 1856 PassRegistration<TestReturnTypeDriver>(); 1857 1858 PassRegistration<TestDerivedAttributeDriver>(); 1859 1860 PassRegistration<TestPatternDriver>(); 1861 PassRegistration<TestStrictPatternDriver>(); 1862 1863 PassRegistration<TestLegalizePatternDriver>([] { 1864 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 1865 }); 1866 1867 PassRegistration<TestRemappedValue>(); 1868 1869 PassRegistration<TestUnknownRootOpDriver>(); 1870 1871 PassRegistration<TestTypeConversionDriver>(); 1872 PassRegistration<TestTargetMaterializationWithNoUses>(); 1873 1874 PassRegistration<TestRewriteDynamicOpDriver>(); 1875 1876 PassRegistration<TestMergeBlocksPatternDriver>(); 1877 PassRegistration<TestSelectiveReplacementPatternDriver>(); 1878 } 1879 } // namespace test 1880 } // namespace mlir 1881