1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/Arith/IR/Arith.h" 12 #include "mlir/Dialect/Func/IR/FuncOps.h" 13 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/IR/Matchers.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include "mlir/Transforms/FoldUtils.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 #include "llvm/ADT/ScopeExit.h" 21 22 using namespace mlir; 23 using namespace test; 24 25 // Native function for testing NativeCodeCall 26 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 27 return choice.getValue() ? input1 : input2; 28 } 29 30 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 31 rewriter.create<OpI>(loc, input); 32 } 33 34 static void handleNoResultOp(PatternRewriter &rewriter, 35 OpSymbolBindingNoResult op) { 36 // Turn the no result op to a one-result op. 37 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), 38 op.getOperand()); 39 } 40 41 static bool getFirstI32Result(Operation *op, Value &value) { 42 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 43 return false; 44 value = op->getResult(0); 45 return true; 46 } 47 48 static Value bindNativeCodeCallResult(Value value) { return value; } 49 50 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 51 Value input2) { 52 return SmallVector<Value, 2>({input2, input1}); 53 } 54 55 // Test that natives calls are only called once during rewrites. 56 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 57 // This let us check the number of times OpM_Test was called by inspecting 58 // the returned value in the MLIR output. 59 static int64_t opMIncreasingValue = 314159265; 60 static Attribute opMTest(PatternRewriter &rewriter, Value val) { 61 int64_t i = opMIncreasingValue++; 62 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 63 } 64 65 namespace { 66 #include "TestPatterns.inc" 67 } // namespace 68 69 //===----------------------------------------------------------------------===// 70 // Test Reduce Pattern Interface 71 //===----------------------------------------------------------------------===// 72 73 void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 74 populateWithGenerated(patterns); 75 } 76 77 //===----------------------------------------------------------------------===// 78 // Canonicalizer Driver. 79 //===----------------------------------------------------------------------===// 80 81 namespace { 82 struct FoldingPattern : public RewritePattern { 83 public: 84 FoldingPattern(MLIRContext *context) 85 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 86 /*benefit=*/1, context) {} 87 88 LogicalResult matchAndRewrite(Operation *op, 89 PatternRewriter &rewriter) const override { 90 // Exercise createOrFold API for a single-result operation that is folded 91 // upon construction. The operation being created has an in-place folder, 92 // and it should be still present in the output. Furthermore, the folder 93 // should not crash when attempting to recover the (unchanged) operation 94 // result. 95 Value result = rewriter.createOrFold<TestOpInPlaceFold>( 96 op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0)); 97 assert(result); 98 rewriter.replaceOp(op, result); 99 return success(); 100 } 101 }; 102 103 /// This pattern creates a foldable operation at the entry point of the block. 104 /// This tests the situation where the operation folder will need to replace an 105 /// operation with a previously created constant that does not initially 106 /// dominate the operation to replace. 107 struct FolderInsertBeforePreviouslyFoldedConstantPattern 108 : public OpRewritePattern<TestCastOp> { 109 public: 110 using OpRewritePattern<TestCastOp>::OpRewritePattern; 111 112 LogicalResult matchAndRewrite(TestCastOp op, 113 PatternRewriter &rewriter) const override { 114 if (!op->hasAttr("test_fold_before_previously_folded_op")) 115 return failure(); 116 rewriter.setInsertionPointToStart(op->getBlock()); 117 118 auto constOp = rewriter.create<arith::ConstantOp>( 119 op.getLoc(), rewriter.getBoolAttr(true)); 120 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 121 Value(constOp)); 122 return success(); 123 } 124 }; 125 126 /// This pattern matches test.op_commutative2 with the first operand being 127 /// another test.op_commutative2 with a constant on the right side and fold it 128 /// away by propagating it as its result. This is intend to check that patterns 129 /// are applied after the commutative property moves constant to the right. 130 struct FolderCommutativeOp2WithConstant 131 : public OpRewritePattern<TestCommutative2Op> { 132 public: 133 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; 134 135 LogicalResult matchAndRewrite(TestCommutative2Op op, 136 PatternRewriter &rewriter) const override { 137 auto operand = 138 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); 139 if (!operand) 140 return failure(); 141 Attribute constInput; 142 if (!matchPattern(operand->getOperand(1), m_Constant(&constInput))) 143 return failure(); 144 rewriter.replaceOp(op, operand->getOperand(1)); 145 return success(); 146 } 147 }; 148 149 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer 150 /// attribute with value smaller than MaxVal, it increments the value by 1. 151 template <int MaxVal> 152 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> { 153 using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern; 154 155 LogicalResult matchAndRewrite(AnyAttrOfOp op, 156 PatternRewriter &rewriter) const override { 157 auto intAttr = dyn_cast<IntegerAttr>(op.getAttr()); 158 if (!intAttr) 159 return failure(); 160 int64_t val = intAttr.getInt(); 161 if (val >= MaxVal) 162 return failure(); 163 rewriter.modifyOpInPlace( 164 op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); }); 165 return success(); 166 } 167 }; 168 169 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op". 170 struct MakeOpEligible : public RewritePattern { 171 MakeOpEligible(MLIRContext *context) 172 : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {} 173 174 LogicalResult matchAndRewrite(Operation *op, 175 PatternRewriter &rewriter) const override { 176 if (op->hasAttr("eligible")) 177 return failure(); 178 rewriter.modifyOpInPlace( 179 op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); }); 180 return success(); 181 } 182 }; 183 184 /// This pattern hoists eligible ops out of a "test.one_region_op". 185 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> { 186 using OpRewritePattern<test::OneRegionOp>::OpRewritePattern; 187 188 LogicalResult matchAndRewrite(test::OneRegionOp op, 189 PatternRewriter &rewriter) const override { 190 Operation *terminator = op.getRegion().front().getTerminator(); 191 Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp(); 192 if (toBeHoisted->getParentOp() != op) 193 return failure(); 194 if (!toBeHoisted->hasAttr("eligible")) 195 return failure(); 196 rewriter.moveOpBefore(toBeHoisted, op); 197 return success(); 198 } 199 }; 200 201 /// This pattern moves "test.move_before_parent_op" before the parent op. 202 struct MoveBeforeParentOp : public RewritePattern { 203 MoveBeforeParentOp(MLIRContext *context) 204 : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {} 205 206 LogicalResult matchAndRewrite(Operation *op, 207 PatternRewriter &rewriter) const override { 208 // Do not hoist past functions. 209 if (isa<FunctionOpInterface>(op->getParentOp())) 210 return failure(); 211 rewriter.moveOpBefore(op, op->getParentOp()); 212 return success(); 213 } 214 }; 215 216 /// This pattern inlines blocks that are nested in 217 /// "test.inline_blocks_into_parent" into the parent block. 218 struct InlineBlocksIntoParent : public RewritePattern { 219 InlineBlocksIntoParent(MLIRContext *context) 220 : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1, 221 context) {} 222 223 LogicalResult matchAndRewrite(Operation *op, 224 PatternRewriter &rewriter) const override { 225 bool changed = false; 226 for (Region &r : op->getRegions()) { 227 while (!r.empty()) { 228 rewriter.inlineBlockBefore(&r.front(), op); 229 changed = true; 230 } 231 } 232 return success(changed); 233 } 234 }; 235 236 /// This pattern splits blocks at "test.split_block_here" and replaces the op 237 /// with a new op (to prevent an infinite loop of block splitting). 238 struct SplitBlockHere : public RewritePattern { 239 SplitBlockHere(MLIRContext *context) 240 : RewritePattern("test.split_block_here", /*benefit=*/1, context) {} 241 242 LogicalResult matchAndRewrite(Operation *op, 243 PatternRewriter &rewriter) const override { 244 rewriter.splitBlock(op->getBlock(), op->getIterator()); 245 Operation *newOp = rewriter.create( 246 op->getLoc(), 247 OperationName("test.new_op", op->getContext()).getIdentifier(), 248 op->getOperands(), op->getResultTypes()); 249 rewriter.replaceOp(op, newOp); 250 return success(); 251 } 252 }; 253 254 /// This pattern clones "test.clone_me" ops. 255 struct CloneOp : public RewritePattern { 256 CloneOp(MLIRContext *context) 257 : RewritePattern("test.clone_me", /*benefit=*/1, context) {} 258 259 LogicalResult matchAndRewrite(Operation *op, 260 PatternRewriter &rewriter) const override { 261 // Do not clone already cloned ops to avoid going into an infinite loop. 262 if (op->hasAttr("was_cloned")) 263 return failure(); 264 Operation *cloned = rewriter.clone(*op); 265 cloned->setAttr("was_cloned", rewriter.getUnitAttr()); 266 return success(); 267 } 268 }; 269 270 /// This pattern clones regions of "test.clone_region_before" ops before the 271 /// parent block. 272 struct CloneRegionBeforeOp : public RewritePattern { 273 CloneRegionBeforeOp(MLIRContext *context) 274 : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {} 275 276 LogicalResult matchAndRewrite(Operation *op, 277 PatternRewriter &rewriter) const override { 278 // Do not clone already cloned ops to avoid going into an infinite loop. 279 if (op->hasAttr("was_cloned")) 280 return failure(); 281 for (Region &r : op->getRegions()) 282 rewriter.cloneRegionBefore(r, op->getBlock()); 283 op->setAttr("was_cloned", rewriter.getUnitAttr()); 284 return success(); 285 } 286 }; 287 288 struct TestPatternDriver 289 : public PassWrapper<TestPatternDriver, OperationPass<>> { 290 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver) 291 292 TestPatternDriver() = default; 293 TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {} 294 295 StringRef getArgument() const final { return "test-patterns"; } 296 StringRef getDescription() const final { return "Run test dialect patterns"; } 297 void runOnOperation() override { 298 mlir::RewritePatternSet patterns(&getContext()); 299 populateWithGenerated(patterns); 300 301 // Verify named pattern is generated with expected name. 302 patterns.add<FoldingPattern, TestNamedPatternRule, 303 FolderInsertBeforePreviouslyFoldedConstantPattern, 304 FolderCommutativeOp2WithConstant, HoistEligibleOps, 305 MakeOpEligible>(&getContext()); 306 307 // Additional patterns for testing the GreedyPatternRewriteDriver. 308 patterns.insert<IncrementIntAttribute<3>>(&getContext()); 309 310 GreedyRewriteConfig config; 311 config.useTopDownTraversal = this->useTopDownTraversal; 312 config.maxIterations = this->maxIterations; 313 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), 314 config); 315 } 316 317 Option<bool> useTopDownTraversal{ 318 *this, "top-down", 319 llvm::cl::desc("Seed the worklist in general top-down order"), 320 llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)}; 321 Option<int> maxIterations{ 322 *this, "max-iterations", 323 llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"), 324 llvm::cl::init(GreedyRewriteConfig().maxIterations)}; 325 }; 326 327 struct DumpNotifications : public RewriterBase::Listener { 328 void notifyBlockInserted(Block *block, Region *previous, 329 Region::iterator previousIt) override { 330 llvm::outs() << "notifyBlockInserted"; 331 if (block->getParentOp()) { 332 llvm::outs() << " into " << block->getParentOp()->getName() << ": "; 333 } else { 334 llvm::outs() << " into unknown op: "; 335 } 336 if (previous == nullptr) { 337 llvm::outs() << "was unlinked\n"; 338 } else { 339 llvm::outs() << "was linked\n"; 340 } 341 } 342 void notifyOperationInserted(Operation *op, 343 OpBuilder::InsertPoint previous) override { 344 llvm::outs() << "notifyOperationInserted: " << op->getName(); 345 if (!previous.isSet()) { 346 llvm::outs() << ", was unlinked\n"; 347 } else { 348 if (!previous.getPoint().getNodePtr()) { 349 llvm::outs() << ", was linked, exact position unknown\n"; 350 } else if (previous.getPoint() == previous.getBlock()->end()) { 351 llvm::outs() << ", was last in block\n"; 352 } else { 353 llvm::outs() << ", previous = " << previous.getPoint()->getName() 354 << "\n"; 355 } 356 } 357 } 358 void notifyBlockErased(Block *block) override { 359 llvm::outs() << "notifyBlockErased\n"; 360 } 361 void notifyOperationErased(Operation *op) override { 362 llvm::outs() << "notifyOperationErased: " << op->getName() << "\n"; 363 } 364 void notifyOperationModified(Operation *op) override { 365 llvm::outs() << "notifyOperationModified: " << op->getName() << "\n"; 366 } 367 void notifyOperationReplaced(Operation *op, ValueRange values) override { 368 llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n"; 369 } 370 }; 371 372 struct TestStrictPatternDriver 373 : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> { 374 public: 375 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver) 376 377 TestStrictPatternDriver() = default; 378 TestStrictPatternDriver(const TestStrictPatternDriver &other) 379 : PassWrapper(other) { 380 strictMode = other.strictMode; 381 } 382 383 StringRef getArgument() const final { return "test-strict-pattern-driver"; } 384 StringRef getDescription() const final { 385 return "Test strict mode of pattern driver"; 386 } 387 388 void runOnOperation() override { 389 MLIRContext *ctx = &getContext(); 390 mlir::RewritePatternSet patterns(ctx); 391 patterns.add< 392 // clang-format off 393 ChangeBlockOp, 394 CloneOp, 395 CloneRegionBeforeOp, 396 EraseOp, 397 ImplicitChangeOp, 398 InlineBlocksIntoParent, 399 InsertSameOp, 400 MoveBeforeParentOp, 401 ReplaceWithNewOp, 402 SplitBlockHere 403 // clang-format on 404 >(ctx); 405 SmallVector<Operation *> ops; 406 getOperation()->walk([&](Operation *op) { 407 StringRef opName = op->getName().getStringRef(); 408 if (opName == "test.insert_same_op" || opName == "test.change_block_op" || 409 opName == "test.replace_with_new_op" || opName == "test.erase_op" || 410 opName == "test.move_before_parent_op" || 411 opName == "test.inline_blocks_into_parent" || 412 opName == "test.split_block_here" || opName == "test.clone_me" || 413 opName == "test.clone_region_before") { 414 ops.push_back(op); 415 } 416 }); 417 418 DumpNotifications dumpNotifications; 419 GreedyRewriteConfig config; 420 config.listener = &dumpNotifications; 421 if (strictMode == "AnyOp") { 422 config.strictMode = GreedyRewriteStrictness::AnyOp; 423 } else if (strictMode == "ExistingAndNewOps") { 424 config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; 425 } else if (strictMode == "ExistingOps") { 426 config.strictMode = GreedyRewriteStrictness::ExistingOps; 427 } else { 428 llvm_unreachable("invalid strictness option"); 429 } 430 431 // Check if these transformations introduce visiting of operations that 432 // are not in the `ops` set (The new created ops are valid). An invalid 433 // operation will trigger the assertion while processing. 434 bool changed = false; 435 bool allErased = false; 436 (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config, 437 &changed, &allErased); 438 Builder b(ctx); 439 getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed)); 440 getOperation()->setAttr("pattern_driver_all_erased", 441 b.getBoolAttr(allErased)); 442 } 443 444 Option<std::string> strictMode{ 445 *this, "strictness", 446 llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"), 447 llvm::cl::init("AnyOp")}; 448 449 private: 450 // New inserted operation is valid for further transformation. 451 class InsertSameOp : public RewritePattern { 452 public: 453 InsertSameOp(MLIRContext *context) 454 : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {} 455 456 LogicalResult matchAndRewrite(Operation *op, 457 PatternRewriter &rewriter) const override { 458 if (op->hasAttr("skip")) 459 return failure(); 460 461 Operation *newOp = 462 rewriter.create(op->getLoc(), op->getName().getIdentifier(), 463 op->getOperands(), op->getResultTypes()); 464 rewriter.modifyOpInPlace( 465 op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); }); 466 newOp->setAttr("skip", rewriter.getBoolAttr(true)); 467 468 return success(); 469 } 470 }; 471 472 // Replace an operation may introduce the re-visiting of its users. 473 class ReplaceWithNewOp : public RewritePattern { 474 public: 475 ReplaceWithNewOp(MLIRContext *context) 476 : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {} 477 478 LogicalResult matchAndRewrite(Operation *op, 479 PatternRewriter &rewriter) const override { 480 Operation *newOp; 481 if (op->hasAttr("create_erase_op")) { 482 newOp = rewriter.create( 483 op->getLoc(), 484 OperationName("test.erase_op", op->getContext()).getIdentifier(), 485 ValueRange(), TypeRange()); 486 } else { 487 newOp = rewriter.create( 488 op->getLoc(), 489 OperationName("test.new_op", op->getContext()).getIdentifier(), 490 op->getOperands(), op->getResultTypes()); 491 } 492 // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp". 493 // A "notifyOperationReplaced" callback is triggered in either case. 494 rewriter.replaceAllOpUsesWith(op, newOp->getResults()); 495 rewriter.eraseOp(op); 496 return success(); 497 } 498 }; 499 500 // Remove an operation may introduce the re-visiting of its operands. 501 class EraseOp : public RewritePattern { 502 public: 503 EraseOp(MLIRContext *context) 504 : RewritePattern("test.erase_op", /*benefit=*/1, context) {} 505 LogicalResult matchAndRewrite(Operation *op, 506 PatternRewriter &rewriter) const override { 507 rewriter.eraseOp(op); 508 return success(); 509 } 510 }; 511 512 // The following two patterns test RewriterBase::replaceAllUsesWith. 513 // 514 // That function replaces all usages of a Block (or a Value) with another one 515 // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver 516 // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its 517 // worklist: when an op is modified, it is added to the worklist. The two 518 // patterns below make the tracking observable: ChangeBlockOp replaces all 519 // usages of a block and that pattern is applied because the corresponding ops 520 // are put on the initial worklist (see above). ImplicitChangeOp does an 521 // unrelated change but ops of the corresponding type are *not* on the initial 522 // worklist, so the effect of the second pattern is only visible if the 523 // tracking and subsequent adding to the worklist actually works. 524 525 // Replace all usages of the first successor with the second successor. 526 class ChangeBlockOp : public RewritePattern { 527 public: 528 ChangeBlockOp(MLIRContext *context) 529 : RewritePattern("test.change_block_op", /*benefit=*/1, context) {} 530 LogicalResult matchAndRewrite(Operation *op, 531 PatternRewriter &rewriter) const override { 532 if (op->getNumSuccessors() < 2) 533 return failure(); 534 Block *firstSuccessor = op->getSuccessor(0); 535 Block *secondSuccessor = op->getSuccessor(1); 536 if (firstSuccessor == secondSuccessor) 537 return failure(); 538 // This is the function being tested: 539 rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor); 540 // Using the following line instead would make the test fail: 541 // firstSuccessor->replaceAllUsesWith(secondSuccessor); 542 return success(); 543 } 544 }; 545 546 // Changes the successor to the parent block. 547 class ImplicitChangeOp : public RewritePattern { 548 public: 549 ImplicitChangeOp(MLIRContext *context) 550 : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {} 551 LogicalResult matchAndRewrite(Operation *op, 552 PatternRewriter &rewriter) const override { 553 if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock()) 554 return failure(); 555 rewriter.modifyOpInPlace(op, 556 [&]() { op->setSuccessor(op->getBlock(), 0); }); 557 return success(); 558 } 559 }; 560 }; 561 562 } // namespace 563 564 //===----------------------------------------------------------------------===// 565 // ReturnType Driver. 566 //===----------------------------------------------------------------------===// 567 568 namespace { 569 // Generate ops for each instance where the type can be successfully inferred. 570 template <typename OpTy> 571 static void invokeCreateWithInferredReturnType(Operation *op) { 572 auto *context = op->getContext(); 573 auto fop = op->getParentOfType<func::FuncOp>(); 574 auto location = UnknownLoc::get(context); 575 OpBuilder b(op); 576 b.setInsertionPointAfter(op); 577 578 // Use permutations of 2 args as operands. 579 assert(fop.getNumArguments() >= 2); 580 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 581 for (int j = 0; j < e; ++j) { 582 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 583 SmallVector<Type, 2> inferredReturnTypes; 584 if (succeeded(OpTy::inferReturnTypes( 585 context, std::nullopt, values, op->getDiscardableAttrDictionary(), 586 op->getPropertiesStorage(), op->getRegions(), 587 inferredReturnTypes))) { 588 OperationState state(location, OpTy::getOperationName()); 589 // TODO: Expand to regions. 590 OpTy::build(b, state, values, op->getAttrs()); 591 (void)b.create(state); 592 } 593 } 594 } 595 } 596 597 static void reifyReturnShape(Operation *op) { 598 OpBuilder b(op); 599 600 // Use permutations of 2 args as operands. 601 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 602 SmallVector<Value, 2> shapes; 603 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 604 !llvm::hasSingleElement(shapes)) 605 return; 606 for (const auto &it : llvm::enumerate(shapes)) { 607 op->emitRemark() << "value " << it.index() << ": " 608 << it.value().getDefiningOp(); 609 } 610 } 611 612 struct TestReturnTypeDriver 613 : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> { 614 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver) 615 616 void getDependentDialects(DialectRegistry ®istry) const override { 617 registry.insert<tensor::TensorDialect>(); 618 } 619 StringRef getArgument() const final { return "test-return-type"; } 620 StringRef getDescription() const final { return "Run return type functions"; } 621 622 void runOnOperation() override { 623 if (getOperation().getName() == "testCreateFunctions") { 624 std::vector<Operation *> ops; 625 // Collect ops to avoid triggering on inserted ops. 626 for (auto &op : getOperation().getBody().front()) 627 ops.push_back(&op); 628 // Generate test patterns for each, but skip terminator. 629 for (auto *op : llvm::ArrayRef(ops).drop_back()) { 630 // Test create method of each of the Op classes below. The resultant 631 // output would be in reverse order underneath `op` from which 632 // the attributes and regions are used. 633 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 634 invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>( 635 op); 636 invokeCreateWithInferredReturnType< 637 OpWithShapedTypeInferTypeInterfaceOp>(op); 638 }; 639 return; 640 } 641 if (getOperation().getName() == "testReifyFunctions") { 642 std::vector<Operation *> ops; 643 // Collect ops to avoid triggering on inserted ops. 644 for (auto &op : getOperation().getBody().front()) 645 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 646 ops.push_back(&op); 647 // Generate test patterns for each, but skip terminator. 648 for (auto *op : ops) 649 reifyReturnShape(op); 650 } 651 } 652 }; 653 } // namespace 654 655 namespace { 656 struct TestDerivedAttributeDriver 657 : public PassWrapper<TestDerivedAttributeDriver, 658 OperationPass<func::FuncOp>> { 659 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver) 660 661 StringRef getArgument() const final { return "test-derived-attr"; } 662 StringRef getDescription() const final { 663 return "Run test derived attributes"; 664 } 665 void runOnOperation() override; 666 }; 667 } // namespace 668 669 void TestDerivedAttributeDriver::runOnOperation() { 670 getOperation().walk([](DerivedAttributeOpInterface dOp) { 671 auto dAttr = dOp.materializeDerivedAttributes(); 672 if (!dAttr) 673 return; 674 for (auto d : dAttr) 675 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); 676 }); 677 } 678 679 //===----------------------------------------------------------------------===// 680 // Legalization Driver. 681 //===----------------------------------------------------------------------===// 682 683 namespace { 684 //===----------------------------------------------------------------------===// 685 // Region-Block Rewrite Testing 686 687 /// This pattern is a simple pattern that inlines the first region of a given 688 /// operation into the parent region. 689 struct TestRegionRewriteBlockMovement : public ConversionPattern { 690 TestRegionRewriteBlockMovement(MLIRContext *ctx) 691 : ConversionPattern("test.region", 1, ctx) {} 692 693 LogicalResult 694 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 695 ConversionPatternRewriter &rewriter) const final { 696 // Inline this region into the parent region. 697 auto &parentRegion = *op->getParentRegion(); 698 auto &opRegion = op->getRegion(0); 699 if (op->getDiscardableAttr("legalizer.should_clone")) 700 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 701 else 702 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 703 704 if (op->getDiscardableAttr("legalizer.erase_old_blocks")) { 705 while (!opRegion.empty()) 706 rewriter.eraseBlock(&opRegion.front()); 707 } 708 709 // Drop this operation. 710 rewriter.eraseOp(op); 711 return success(); 712 } 713 }; 714 /// This pattern is a simple pattern that generates a region containing an 715 /// illegal operation. 716 struct TestRegionRewriteUndo : public RewritePattern { 717 TestRegionRewriteUndo(MLIRContext *ctx) 718 : RewritePattern("test.region_builder", 1, ctx) {} 719 720 LogicalResult matchAndRewrite(Operation *op, 721 PatternRewriter &rewriter) const final { 722 // Create the region operation with an entry block containing arguments. 723 OperationState newRegion(op->getLoc(), "test.region"); 724 newRegion.addRegion(); 725 auto *regionOp = rewriter.create(newRegion); 726 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 727 entryBlock->addArgument(rewriter.getIntegerType(64), 728 rewriter.getUnknownLoc()); 729 730 // Add an explicitly illegal operation to ensure the conversion fails. 731 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 732 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 733 734 // Drop this operation. 735 rewriter.eraseOp(op); 736 return success(); 737 } 738 }; 739 /// A simple pattern that creates a block at the end of the parent region of the 740 /// matched operation. 741 struct TestCreateBlock : public RewritePattern { 742 TestCreateBlock(MLIRContext *ctx) 743 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 744 745 LogicalResult matchAndRewrite(Operation *op, 746 PatternRewriter &rewriter) const final { 747 Region ®ion = *op->getParentRegion(); 748 Type i32Type = rewriter.getIntegerType(32); 749 Location loc = op->getLoc(); 750 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 751 rewriter.create<TerminatorOp>(loc); 752 rewriter.eraseOp(op); 753 return success(); 754 } 755 }; 756 757 /// A simple pattern that creates a block containing an invalid operation in 758 /// order to trigger the block creation undo mechanism. 759 struct TestCreateIllegalBlock : public RewritePattern { 760 TestCreateIllegalBlock(MLIRContext *ctx) 761 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 762 763 LogicalResult matchAndRewrite(Operation *op, 764 PatternRewriter &rewriter) const final { 765 Region ®ion = *op->getParentRegion(); 766 Type i32Type = rewriter.getIntegerType(32); 767 Location loc = op->getLoc(); 768 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 769 // Create an illegal op to ensure the conversion fails. 770 rewriter.create<ILLegalOpF>(loc, i32Type); 771 rewriter.create<TerminatorOp>(loc); 772 rewriter.eraseOp(op); 773 return success(); 774 } 775 }; 776 777 /// A simple pattern that tests the undo mechanism when replacing the uses of a 778 /// block argument. 779 struct TestUndoBlockArgReplace : public ConversionPattern { 780 TestUndoBlockArgReplace(MLIRContext *ctx) 781 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 782 783 LogicalResult 784 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 785 ConversionPatternRewriter &rewriter) const final { 786 auto illegalOp = 787 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 788 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 789 illegalOp->getResult(0)); 790 rewriter.modifyOpInPlace(op, [] {}); 791 return success(); 792 } 793 }; 794 795 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion. 796 /// This is to test the rollback logic. 797 struct TestUndoMoveOpBefore : public ConversionPattern { 798 TestUndoMoveOpBefore(MLIRContext *ctx) 799 : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {} 800 801 LogicalResult 802 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 803 ConversionPatternRewriter &rewriter) const override { 804 rewriter.moveOpBefore(op, op->getParentOp()); 805 // Replace with an illegal op to ensure the conversion fails. 806 rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type()); 807 return success(); 808 } 809 }; 810 811 /// A rewrite pattern that tests the undo mechanism when erasing a block. 812 struct TestUndoBlockErase : public ConversionPattern { 813 TestUndoBlockErase(MLIRContext *ctx) 814 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 815 816 LogicalResult 817 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 818 ConversionPatternRewriter &rewriter) const final { 819 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 820 rewriter.setInsertionPointToStart(secondBlock); 821 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 822 rewriter.eraseBlock(secondBlock); 823 rewriter.modifyOpInPlace(op, [] {}); 824 return success(); 825 } 826 }; 827 828 /// A pattern that modifies a property in-place, but keeps the op illegal. 829 struct TestUndoPropertiesModification : public ConversionPattern { 830 TestUndoPropertiesModification(MLIRContext *ctx) 831 : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {} 832 LogicalResult 833 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 834 ConversionPatternRewriter &rewriter) const final { 835 if (!op->hasAttr("modify_inplace")) 836 return failure(); 837 rewriter.modifyOpInPlace( 838 op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); }); 839 return success(); 840 } 841 }; 842 843 //===----------------------------------------------------------------------===// 844 // Type-Conversion Rewrite Testing 845 846 /// This patterns erases a region operation that has had a type conversion. 847 struct TestDropOpSignatureConversion : public ConversionPattern { 848 TestDropOpSignatureConversion(MLIRContext *ctx, 849 const TypeConverter &converter) 850 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 851 LogicalResult 852 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 853 ConversionPatternRewriter &rewriter) const override { 854 Region ®ion = op->getRegion(0); 855 Block *entry = ®ion.front(); 856 857 // Convert the original entry arguments. 858 const TypeConverter &converter = *getTypeConverter(); 859 TypeConverter::SignatureConversion result(entry->getNumArguments()); 860 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 861 result)) || 862 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 863 return failure(); 864 865 // Convert the region signature and just drop the operation. 866 rewriter.eraseOp(op); 867 return success(); 868 } 869 }; 870 /// This pattern simply updates the operands of the given operation. 871 struct TestPassthroughInvalidOp : public ConversionPattern { 872 TestPassthroughInvalidOp(MLIRContext *ctx) 873 : ConversionPattern("test.invalid", 1, ctx) {} 874 LogicalResult 875 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 876 ConversionPatternRewriter &rewriter) const final { 877 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, operands, 878 std::nullopt); 879 return success(); 880 } 881 }; 882 /// This pattern handles the case of a split return value. 883 struct TestSplitReturnType : public ConversionPattern { 884 TestSplitReturnType(MLIRContext *ctx) 885 : ConversionPattern("test.return", 1, ctx) {} 886 LogicalResult 887 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 888 ConversionPatternRewriter &rewriter) const final { 889 // Check for a return of F32. 890 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 891 return failure(); 892 893 // Check if the first operation is a cast operation, if it is we use the 894 // results directly. 895 auto *defOp = operands[0].getDefiningOp(); 896 if (auto packerOp = 897 llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) { 898 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 899 return success(); 900 } 901 902 // Otherwise, fail to match. 903 return failure(); 904 } 905 }; 906 907 //===----------------------------------------------------------------------===// 908 // Multi-Level Type-Conversion Rewrite Testing 909 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 910 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 911 : ConversionPattern("test.type_producer", 1, ctx) {} 912 LogicalResult 913 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 914 ConversionPatternRewriter &rewriter) const final { 915 // If the type is I32, change the type to F32. 916 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 917 return failure(); 918 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 919 return success(); 920 } 921 }; 922 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 923 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 924 : ConversionPattern("test.type_producer", 1, ctx) {} 925 LogicalResult 926 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 927 ConversionPatternRewriter &rewriter) const final { 928 // If the type is F32, change the type to F64. 929 if (!Type(*op->result_type_begin()).isF32()) 930 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 931 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 932 return success(); 933 } 934 }; 935 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 936 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 937 : ConversionPattern("test.type_producer", 10, ctx) {} 938 LogicalResult 939 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 940 ConversionPatternRewriter &rewriter) const final { 941 // Always convert to B16, even though it is not a legal type. This tests 942 // that values are unmapped correctly. 943 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 944 return success(); 945 } 946 }; 947 struct TestUpdateConsumerType : public ConversionPattern { 948 TestUpdateConsumerType(MLIRContext *ctx) 949 : ConversionPattern("test.type_consumer", 1, ctx) {} 950 LogicalResult 951 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 952 ConversionPatternRewriter &rewriter) const final { 953 // Verify that the incoming operand has been successfully remapped to F64. 954 if (!operands[0].getType().isF64()) 955 return failure(); 956 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 957 return success(); 958 } 959 }; 960 961 //===----------------------------------------------------------------------===// 962 // Non-Root Replacement Rewrite Testing 963 /// This pattern generates an invalid operation, but replaces it before the 964 /// pattern is finished. This checks that we don't need to legalize the 965 /// temporary op. 966 struct TestNonRootReplacement : public RewritePattern { 967 TestNonRootReplacement(MLIRContext *ctx) 968 : RewritePattern("test.replace_non_root", 1, ctx) {} 969 970 LogicalResult matchAndRewrite(Operation *op, 971 PatternRewriter &rewriter) const final { 972 auto resultType = *op->result_type_begin(); 973 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 974 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 975 976 rewriter.replaceOp(illegalOp, legalOp); 977 rewriter.replaceOp(op, illegalOp); 978 return success(); 979 } 980 }; 981 982 //===----------------------------------------------------------------------===// 983 // Recursive Rewrite Testing 984 /// This pattern is applied to the same operation multiple times, but has a 985 /// bounded recursion. 986 struct TestBoundedRecursiveRewrite 987 : public OpRewritePattern<TestRecursiveRewriteOp> { 988 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 989 990 void initialize() { 991 // The conversion target handles bounding the recursion of this pattern. 992 setHasBoundedRewriteRecursion(); 993 } 994 995 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 996 PatternRewriter &rewriter) const final { 997 // Decrement the depth of the op in-place. 998 rewriter.modifyOpInPlace(op, [&] { 999 op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); 1000 }); 1001 return success(); 1002 } 1003 }; 1004 1005 struct TestNestedOpCreationUndoRewrite 1006 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 1007 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 1008 1009 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 1010 PatternRewriter &rewriter) const final { 1011 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 1012 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 1013 return success(); 1014 }; 1015 }; 1016 1017 // This pattern matches `test.blackhole` and delete this op and its producer. 1018 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 1019 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 1020 1021 LogicalResult matchAndRewrite(BlackHoleOp op, 1022 PatternRewriter &rewriter) const final { 1023 Operation *producer = op.getOperand().getDefiningOp(); 1024 // Always erase the user before the producer, the framework should handle 1025 // this correctly. 1026 rewriter.eraseOp(op); 1027 rewriter.eraseOp(producer); 1028 return success(); 1029 }; 1030 }; 1031 1032 // This pattern replaces explicitly illegal op with explicitly legal op, 1033 // but in addition creates unregistered operation. 1034 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { 1035 using OpRewritePattern<ILLegalOpG>::OpRewritePattern; 1036 1037 LogicalResult matchAndRewrite(ILLegalOpG op, 1038 PatternRewriter &rewriter) const final { 1039 IntegerAttr attr = rewriter.getI32IntegerAttr(0); 1040 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); 1041 rewriter.replaceOpWithNewOp<LegalOpC>(op, val); 1042 return success(); 1043 }; 1044 }; 1045 } // namespace 1046 1047 namespace { 1048 struct TestTypeConverter : public TypeConverter { 1049 using TypeConverter::TypeConverter; 1050 TestTypeConverter() { 1051 addConversion(convertType); 1052 addArgumentMaterialization(materializeCast); 1053 addSourceMaterialization(materializeCast); 1054 } 1055 1056 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 1057 // Drop I16 types. 1058 if (t.isSignlessInteger(16)) 1059 return success(); 1060 1061 // Convert I64 to F64. 1062 if (t.isSignlessInteger(64)) { 1063 results.push_back(FloatType::getF64(t.getContext())); 1064 return success(); 1065 } 1066 1067 // Convert I42 to I43. 1068 if (t.isInteger(42)) { 1069 results.push_back(IntegerType::get(t.getContext(), 43)); 1070 return success(); 1071 } 1072 1073 // Split F32 into F16,F16. 1074 if (t.isF32()) { 1075 results.assign(2, FloatType::getF16(t.getContext())); 1076 return success(); 1077 } 1078 1079 // Otherwise, convert the type directly. 1080 results.push_back(t); 1081 return success(); 1082 } 1083 1084 /// Hook for materializing a conversion. This is necessary because we generate 1085 /// 1->N type mappings. 1086 static std::optional<Value> materializeCast(OpBuilder &builder, 1087 Type resultType, 1088 ValueRange inputs, Location loc) { 1089 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1090 } 1091 }; 1092 1093 struct TestLegalizePatternDriver 1094 : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> { 1095 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) 1096 1097 StringRef getArgument() const final { return "test-legalize-patterns"; } 1098 StringRef getDescription() const final { 1099 return "Run test dialect legalization patterns"; 1100 } 1101 /// The mode of conversion to use with the driver. 1102 enum class ConversionMode { Analysis, Full, Partial }; 1103 1104 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 1105 1106 void getDependentDialects(DialectRegistry ®istry) const override { 1107 registry.insert<func::FuncDialect, test::TestDialect>(); 1108 } 1109 1110 void runOnOperation() override { 1111 TestTypeConverter converter; 1112 mlir::RewritePatternSet patterns(&getContext()); 1113 populateWithGenerated(patterns); 1114 patterns 1115 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 1116 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 1117 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 1118 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 1119 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 1120 TestNonRootReplacement, TestBoundedRecursiveRewrite, 1121 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, 1122 TestCreateUnregisteredOp, TestUndoMoveOpBefore, 1123 TestUndoPropertiesModification>(&getContext()); 1124 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 1125 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1126 converter); 1127 mlir::populateCallOpTypeConversionPattern(patterns, converter); 1128 1129 // Define the conversion target used for the test. 1130 ConversionTarget target(getContext()); 1131 target.addLegalOp<ModuleOp>(); 1132 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, 1133 TerminatorOp, OneRegionOp>(); 1134 target 1135 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 1136 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 1137 // Don't allow F32 operands. 1138 return llvm::none_of(op.getOperandTypes(), 1139 [](Type type) { return type.isF32(); }); 1140 }); 1141 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1142 return converter.isSignatureLegal(op.getFunctionType()) && 1143 converter.isLegal(&op.getBody()); 1144 }); 1145 target.addDynamicallyLegalOp<func::CallOp>( 1146 [&](func::CallOp op) { return converter.isLegal(op); }); 1147 1148 // TestCreateUnregisteredOp creates `arith.constant` operation, 1149 // which was not added to target intentionally to test 1150 // correct error code from conversion driver. 1151 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); 1152 1153 // Expect the type_producer/type_consumer operations to only operate on f64. 1154 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1155 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1156 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1157 return op.getOperand().getType().isF64(); 1158 }); 1159 1160 // Check support for marking certain operations as recursively legal. 1161 target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) { 1162 return static_cast<bool>( 1163 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 1164 }); 1165 1166 // Mark the bound recursion operation as dynamically legal. 1167 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 1168 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); 1169 1170 // Handle a partial conversion. 1171 if (mode == ConversionMode::Partial) { 1172 DenseSet<Operation *> unlegalizedOps; 1173 ConversionConfig config; 1174 DumpNotifications dumpNotifications; 1175 config.listener = &dumpNotifications; 1176 config.unlegalizedOps = &unlegalizedOps; 1177 if (failed(applyPartialConversion(getOperation(), target, 1178 std::move(patterns), config))) { 1179 getOperation()->emitRemark() << "applyPartialConversion failed"; 1180 } 1181 // Emit remarks for each legalizable operation. 1182 for (auto *op : unlegalizedOps) 1183 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1184 return; 1185 } 1186 1187 // Handle a full conversion. 1188 if (mode == ConversionMode::Full) { 1189 // Check support for marking unknown operations as dynamically legal. 1190 target.markUnknownOpDynamicallyLegal([](Operation *op) { 1191 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 1192 }); 1193 1194 ConversionConfig config; 1195 DumpNotifications dumpNotifications; 1196 config.listener = &dumpNotifications; 1197 if (failed(applyFullConversion(getOperation(), target, 1198 std::move(patterns), config))) { 1199 getOperation()->emitRemark() << "applyFullConversion failed"; 1200 } 1201 return; 1202 } 1203 1204 // Otherwise, handle an analysis conversion. 1205 assert(mode == ConversionMode::Analysis); 1206 1207 // Analyze the convertible operations. 1208 DenseSet<Operation *> legalizedOps; 1209 ConversionConfig config; 1210 config.legalizableOps = &legalizedOps; 1211 if (failed(applyAnalysisConversion(getOperation(), target, 1212 std::move(patterns), config))) 1213 return signalPassFailure(); 1214 1215 // Emit remarks for each legalizable operation. 1216 for (auto *op : legalizedOps) 1217 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 1218 } 1219 1220 /// The mode of conversion to use. 1221 ConversionMode mode; 1222 }; 1223 } // namespace 1224 1225 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 1226 legalizerConversionMode( 1227 "test-legalize-mode", 1228 llvm::cl::desc("The legalization mode to use with the test driver"), 1229 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 1230 llvm::cl::values( 1231 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 1232 "analysis", "Perform an analysis conversion"), 1233 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 1234 "Perform a full conversion"), 1235 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 1236 "partial", "Perform a partial conversion"))); 1237 1238 //===----------------------------------------------------------------------===// 1239 // ConversionPatternRewriter::getRemappedValue testing. This method is used 1240 // to get the remapped value of an original value that was replaced using 1241 // ConversionPatternRewriter. 1242 namespace { 1243 struct TestRemapValueTypeConverter : public TypeConverter { 1244 using TypeConverter::TypeConverter; 1245 1246 TestRemapValueTypeConverter() { 1247 addConversion( 1248 [](Float32Type type) { return Float64Type::get(type.getContext()); }); 1249 addConversion([](Type type) { return type; }); 1250 } 1251 }; 1252 1253 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 1254 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 1255 /// operand twice. 1256 /// 1257 /// Example: 1258 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 1259 /// is replaced with: 1260 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 1261 struct OneVResOneVOperandOp1Converter 1262 : public OpConversionPattern<OneVResOneVOperandOp1> { 1263 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 1264 1265 LogicalResult 1266 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, 1267 ConversionPatternRewriter &rewriter) const override { 1268 auto origOps = op.getOperands(); 1269 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 1270 "One operand expected"); 1271 Value origOp = *origOps.begin(); 1272 SmallVector<Value, 2> remappedOperands; 1273 // Replicate the remapped original operand twice. Note that we don't used 1274 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 1275 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1276 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1277 1278 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 1279 remappedOperands); 1280 return success(); 1281 } 1282 }; 1283 1284 /// A rewriter pattern that tests that blocks can be merged. 1285 struct TestRemapValueInRegion 1286 : public OpConversionPattern<TestRemappedValueRegionOp> { 1287 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; 1288 1289 LogicalResult 1290 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, 1291 ConversionPatternRewriter &rewriter) const final { 1292 Block &block = op.getBody().front(); 1293 Operation *terminator = block.getTerminator(); 1294 1295 // Merge the block into the parent region. 1296 Block *parentBlock = op->getBlock(); 1297 Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator()); 1298 rewriter.mergeBlocks(&block, parentBlock, ValueRange()); 1299 rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange()); 1300 1301 // Replace the results of this operation with the remapped terminator 1302 // values. 1303 SmallVector<Value> terminatorOperands; 1304 if (failed(rewriter.getRemappedValues(terminator->getOperands(), 1305 terminatorOperands))) 1306 return failure(); 1307 1308 rewriter.eraseOp(terminator); 1309 rewriter.replaceOp(op, terminatorOperands); 1310 return success(); 1311 } 1312 }; 1313 1314 struct TestRemappedValue 1315 : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> { 1316 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue) 1317 1318 StringRef getArgument() const final { return "test-remapped-value"; } 1319 StringRef getDescription() const final { 1320 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1321 } 1322 void runOnOperation() override { 1323 TestRemapValueTypeConverter typeConverter; 1324 1325 mlir::RewritePatternSet patterns(&getContext()); 1326 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 1327 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( 1328 &getContext()); 1329 patterns.add<TestRemapValueInRegion>(typeConverter, &getContext()); 1330 1331 mlir::ConversionTarget target(getContext()); 1332 target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>(); 1333 1334 // Expect the type_producer/type_consumer operations to only operate on f64. 1335 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1336 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1337 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1338 return op.getOperand().getType().isF64(); 1339 }); 1340 1341 // We make OneVResOneVOperandOp1 legal only when it has more that one 1342 // operand. This will trigger the conversion that will replace one-operand 1343 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 1344 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 1345 [](Operation *op) { return op->getNumOperands() > 1; }); 1346 1347 if (failed(mlir::applyFullConversion(getOperation(), target, 1348 std::move(patterns)))) { 1349 signalPassFailure(); 1350 } 1351 } 1352 }; 1353 } // namespace 1354 1355 //===----------------------------------------------------------------------===// 1356 // Test patterns without a specific root operation kind 1357 //===----------------------------------------------------------------------===// 1358 1359 namespace { 1360 /// This pattern matches and removes any operation in the test dialect. 1361 struct RemoveTestDialectOps : public RewritePattern { 1362 RemoveTestDialectOps(MLIRContext *context) 1363 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 1364 1365 LogicalResult matchAndRewrite(Operation *op, 1366 PatternRewriter &rewriter) const override { 1367 if (!isa<TestDialect>(op->getDialect())) 1368 return failure(); 1369 rewriter.eraseOp(op); 1370 return success(); 1371 } 1372 }; 1373 1374 struct TestUnknownRootOpDriver 1375 : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> { 1376 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver) 1377 1378 StringRef getArgument() const final { 1379 return "test-legalize-unknown-root-patterns"; 1380 } 1381 StringRef getDescription() const final { 1382 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1383 } 1384 void runOnOperation() override { 1385 mlir::RewritePatternSet patterns(&getContext()); 1386 patterns.add<RemoveTestDialectOps>(&getContext()); 1387 1388 mlir::ConversionTarget target(getContext()); 1389 target.addIllegalDialect<TestDialect>(); 1390 if (failed(applyPartialConversion(getOperation(), target, 1391 std::move(patterns)))) 1392 signalPassFailure(); 1393 } 1394 }; 1395 } // namespace 1396 1397 //===----------------------------------------------------------------------===// 1398 // Test patterns that uses operations and types defined at runtime 1399 //===----------------------------------------------------------------------===// 1400 1401 namespace { 1402 /// This pattern matches dynamic operations 'test.one_operand_two_results' and 1403 /// replace them with dynamic operations 'test.generic_dynamic_op'. 1404 struct RewriteDynamicOp : public RewritePattern { 1405 RewriteDynamicOp(MLIRContext *context) 1406 : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1, 1407 context) {} 1408 1409 LogicalResult matchAndRewrite(Operation *op, 1410 PatternRewriter &rewriter) const override { 1411 assert(op->getName().getStringRef() == 1412 "test.dynamic_one_operand_two_results" && 1413 "rewrite pattern should only match operations with the right name"); 1414 1415 OperationState state(op->getLoc(), "test.dynamic_generic", 1416 op->getOperands(), op->getResultTypes(), 1417 op->getAttrs()); 1418 auto *newOp = rewriter.create(state); 1419 rewriter.replaceOp(op, newOp->getResults()); 1420 return success(); 1421 } 1422 }; 1423 1424 struct TestRewriteDynamicOpDriver 1425 : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> { 1426 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver) 1427 1428 void getDependentDialects(DialectRegistry ®istry) const override { 1429 registry.insert<TestDialect>(); 1430 } 1431 StringRef getArgument() const final { return "test-rewrite-dynamic-op"; } 1432 StringRef getDescription() const final { 1433 return "Test rewritting on dynamic operations"; 1434 } 1435 void runOnOperation() override { 1436 RewritePatternSet patterns(&getContext()); 1437 patterns.add<RewriteDynamicOp>(&getContext()); 1438 1439 ConversionTarget target(getContext()); 1440 target.addIllegalOp( 1441 OperationName("test.dynamic_one_operand_two_results", &getContext())); 1442 target.addLegalOp(OperationName("test.dynamic_generic", &getContext())); 1443 if (failed(applyPartialConversion(getOperation(), target, 1444 std::move(patterns)))) 1445 signalPassFailure(); 1446 } 1447 }; 1448 } // end anonymous namespace 1449 1450 //===----------------------------------------------------------------------===// 1451 // Test type conversions 1452 //===----------------------------------------------------------------------===// 1453 1454 namespace { 1455 struct TestTypeConversionProducer 1456 : public OpConversionPattern<TestTypeProducerOp> { 1457 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 1458 LogicalResult 1459 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, 1460 ConversionPatternRewriter &rewriter) const final { 1461 Type resultType = op.getType(); 1462 Type convertedType = getTypeConverter() 1463 ? getTypeConverter()->convertType(resultType) 1464 : resultType; 1465 if (isa<FloatType>(resultType)) 1466 resultType = rewriter.getF64Type(); 1467 else if (resultType.isInteger(16)) 1468 resultType = rewriter.getIntegerType(64); 1469 else if (isa<test::TestRecursiveType>(resultType) && 1470 convertedType != resultType) 1471 resultType = convertedType; 1472 else 1473 return failure(); 1474 1475 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 1476 return success(); 1477 } 1478 }; 1479 1480 /// Call signature conversion and then fail the rewrite to trigger the undo 1481 /// mechanism. 1482 struct TestSignatureConversionUndo 1483 : public OpConversionPattern<TestSignatureConversionUndoOp> { 1484 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 1485 1486 LogicalResult 1487 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, 1488 ConversionPatternRewriter &rewriter) const final { 1489 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 1490 return failure(); 1491 } 1492 }; 1493 1494 /// Call signature conversion without providing a type converter to handle 1495 /// materializations. 1496 struct TestTestSignatureConversionNoConverter 1497 : public OpConversionPattern<TestSignatureConversionNoConverterOp> { 1498 TestTestSignatureConversionNoConverter(const TypeConverter &converter, 1499 MLIRContext *context) 1500 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), 1501 converter(converter) {} 1502 1503 LogicalResult 1504 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, 1505 ConversionPatternRewriter &rewriter) const final { 1506 Region ®ion = op->getRegion(0); 1507 Block *entry = ®ion.front(); 1508 1509 // Convert the original entry arguments. 1510 TypeConverter::SignatureConversion result(entry->getNumArguments()); 1511 if (failed( 1512 converter.convertSignatureArgs(entry->getArgumentTypes(), result))) 1513 return failure(); 1514 rewriter.modifyOpInPlace( 1515 op, [&] { rewriter.applySignatureConversion(®ion, result); }); 1516 return success(); 1517 } 1518 1519 const TypeConverter &converter; 1520 }; 1521 1522 /// Just forward the operands to the root op. This is essentially a no-op 1523 /// pattern that is used to trigger target materialization. 1524 struct TestTypeConsumerForward 1525 : public OpConversionPattern<TestTypeConsumerOp> { 1526 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 1527 1528 LogicalResult 1529 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, 1530 ConversionPatternRewriter &rewriter) const final { 1531 rewriter.modifyOpInPlace(op, 1532 [&] { op->setOperands(adaptor.getOperands()); }); 1533 return success(); 1534 } 1535 }; 1536 1537 struct TestTypeConversionAnotherProducer 1538 : public OpRewritePattern<TestAnotherTypeProducerOp> { 1539 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 1540 1541 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 1542 PatternRewriter &rewriter) const final { 1543 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 1544 return success(); 1545 } 1546 }; 1547 1548 struct TestTypeConversionDriver 1549 : public PassWrapper<TestTypeConversionDriver, OperationPass<>> { 1550 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver) 1551 1552 void getDependentDialects(DialectRegistry ®istry) const override { 1553 registry.insert<TestDialect>(); 1554 } 1555 StringRef getArgument() const final { 1556 return "test-legalize-type-conversion"; 1557 } 1558 StringRef getDescription() const final { 1559 return "Test various type conversion functionalities in DialectConversion"; 1560 } 1561 1562 void runOnOperation() override { 1563 // Initialize the type converter. 1564 SmallVector<Type, 2> conversionCallStack; 1565 TypeConverter converter; 1566 1567 /// Add the legal set of type conversions. 1568 converter.addConversion([](Type type) -> Type { 1569 // Treat F64 as legal. 1570 if (type.isF64()) 1571 return type; 1572 // Allow converting BF16/F16/F32 to F64. 1573 if (type.isBF16() || type.isF16() || type.isF32()) 1574 return FloatType::getF64(type.getContext()); 1575 // Otherwise, the type is illegal. 1576 return nullptr; 1577 }); 1578 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 1579 // Drop all integer types. 1580 return success(); 1581 }); 1582 converter.addConversion( 1583 // Convert a recursive self-referring type into a non-self-referring 1584 // type named "outer_converted_type" that contains a SimpleAType. 1585 [&](test::TestRecursiveType type, 1586 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { 1587 // If the type is already converted, return it to indicate that it is 1588 // legal. 1589 if (type.getName() == "outer_converted_type") { 1590 results.push_back(type); 1591 return success(); 1592 } 1593 1594 conversionCallStack.push_back(type); 1595 auto popConversionCallStack = llvm::make_scope_exit( 1596 [&conversionCallStack]() { conversionCallStack.pop_back(); }); 1597 1598 // If the type is on the call stack more than once (it is there at 1599 // least once because of the _current_ call, which is always the last 1600 // element on the stack), we've hit the recursive case. Just return 1601 // SimpleAType here to create a non-recursive type as a result. 1602 if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(), 1603 type)) { 1604 results.push_back(test::SimpleAType::get(type.getContext())); 1605 return success(); 1606 } 1607 1608 // Convert the body recursively. 1609 auto result = test::TestRecursiveType::get(type.getContext(), 1610 "outer_converted_type"); 1611 if (failed(result.setBody(converter.convertType(type.getBody())))) 1612 return failure(); 1613 results.push_back(result); 1614 return success(); 1615 }); 1616 1617 /// Add the legal set of type materializations. 1618 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 1619 ValueRange inputs, 1620 Location loc) -> Value { 1621 // Allow casting from F64 back to F32. 1622 if (!resultType.isF16() && inputs.size() == 1 && 1623 inputs[0].getType().isF64()) 1624 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1625 // Allow producing an i32 or i64 from nothing. 1626 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 1627 inputs.empty()) 1628 return builder.create<TestTypeProducerOp>(loc, resultType); 1629 // Allow producing an i64 from an integer. 1630 if (isa<IntegerType>(resultType) && inputs.size() == 1 && 1631 isa<IntegerType>(inputs[0].getType())) 1632 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1633 // Otherwise, fail. 1634 return nullptr; 1635 }); 1636 1637 // Initialize the conversion target. 1638 mlir::ConversionTarget target(getContext()); 1639 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 1640 auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType()); 1641 return op.getType().isF64() || op.getType().isInteger(64) || 1642 (recursiveType && 1643 recursiveType.getName() == "outer_converted_type"); 1644 }); 1645 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1646 return converter.isSignatureLegal(op.getFunctionType()) && 1647 converter.isLegal(&op.getBody()); 1648 }); 1649 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 1650 // Allow casts from F64 to F32. 1651 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 1652 }); 1653 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>( 1654 [&](TestSignatureConversionNoConverterOp op) { 1655 return converter.isLegal(op.getRegion().front().getArgumentTypes()); 1656 }); 1657 1658 // Initialize the set of rewrite patterns. 1659 RewritePatternSet patterns(&getContext()); 1660 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 1661 TestSignatureConversionUndo, 1662 TestTestSignatureConversionNoConverter>(converter, 1663 &getContext()); 1664 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 1665 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1666 converter); 1667 1668 if (failed(applyPartialConversion(getOperation(), target, 1669 std::move(patterns)))) 1670 signalPassFailure(); 1671 } 1672 }; 1673 } // namespace 1674 1675 //===----------------------------------------------------------------------===// 1676 // Test Target Materialization With No Uses 1677 //===----------------------------------------------------------------------===// 1678 1679 namespace { 1680 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> { 1681 using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern; 1682 1683 LogicalResult 1684 matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor, 1685 ConversionPatternRewriter &rewriter) const final { 1686 rewriter.replaceOp(op, adaptor.getOperands()); 1687 return success(); 1688 } 1689 }; 1690 1691 struct TestTargetMaterializationWithNoUses 1692 : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> { 1693 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 1694 TestTargetMaterializationWithNoUses) 1695 1696 StringRef getArgument() const final { 1697 return "test-target-materialization-with-no-uses"; 1698 } 1699 StringRef getDescription() const final { 1700 return "Test a special case of target materialization in DialectConversion"; 1701 } 1702 1703 void runOnOperation() override { 1704 TypeConverter converter; 1705 converter.addConversion([](Type t) { return t; }); 1706 converter.addConversion([](IntegerType intTy) -> Type { 1707 if (intTy.getWidth() == 16) 1708 return IntegerType::get(intTy.getContext(), 64); 1709 return intTy; 1710 }); 1711 converter.addTargetMaterialization( 1712 [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { 1713 return builder.create<TestCastOp>(loc, type, inputs).getResult(); 1714 }); 1715 1716 ConversionTarget target(getContext()); 1717 target.addIllegalOp<TestTypeChangerOp>(); 1718 1719 RewritePatternSet patterns(&getContext()); 1720 patterns.add<ForwardOperandPattern>(converter, &getContext()); 1721 1722 if (failed(applyPartialConversion(getOperation(), target, 1723 std::move(patterns)))) 1724 signalPassFailure(); 1725 } 1726 }; 1727 } // namespace 1728 1729 //===----------------------------------------------------------------------===// 1730 // Test Block Merging 1731 //===----------------------------------------------------------------------===// 1732 1733 namespace { 1734 /// A rewriter pattern that tests that blocks can be merged. 1735 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 1736 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 1737 1738 LogicalResult 1739 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, 1740 ConversionPatternRewriter &rewriter) const final { 1741 Block &firstBlock = op.getBody().front(); 1742 Operation *branchOp = firstBlock.getTerminator(); 1743 Block *secondBlock = &*(std::next(op.getBody().begin())); 1744 auto succOperands = branchOp->getOperands(); 1745 SmallVector<Value, 2> replacements(succOperands); 1746 rewriter.eraseOp(branchOp); 1747 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1748 rewriter.modifyOpInPlace(op, [] {}); 1749 return success(); 1750 } 1751 }; 1752 1753 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 1754 struct TestUndoBlocksMerge : public ConversionPattern { 1755 TestUndoBlocksMerge(MLIRContext *ctx) 1756 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 1757 LogicalResult 1758 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1759 ConversionPatternRewriter &rewriter) const final { 1760 Block &firstBlock = op->getRegion(0).front(); 1761 Operation *branchOp = firstBlock.getTerminator(); 1762 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 1763 rewriter.setInsertionPointToStart(secondBlock); 1764 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 1765 auto succOperands = branchOp->getOperands(); 1766 SmallVector<Value, 2> replacements(succOperands); 1767 rewriter.eraseOp(branchOp); 1768 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1769 rewriter.modifyOpInPlace(op, [] {}); 1770 return success(); 1771 } 1772 }; 1773 1774 /// A rewrite mechanism to inline the body of the op into its parent, when both 1775 /// ops can have a single block. 1776 struct TestMergeSingleBlockOps 1777 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 1778 using OpConversionPattern< 1779 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 1780 1781 LogicalResult 1782 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, 1783 ConversionPatternRewriter &rewriter) const final { 1784 SingleBlockImplicitTerminatorOp parentOp = 1785 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1786 if (!parentOp) 1787 return failure(); 1788 Block &innerBlock = op.getRegion().front(); 1789 TerminatorOp innerTerminator = 1790 cast<TerminatorOp>(innerBlock.getTerminator()); 1791 rewriter.inlineBlockBefore(&innerBlock, op); 1792 rewriter.eraseOp(innerTerminator); 1793 rewriter.eraseOp(op); 1794 return success(); 1795 } 1796 }; 1797 1798 struct TestMergeBlocksPatternDriver 1799 : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> { 1800 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver) 1801 1802 StringRef getArgument() const final { return "test-merge-blocks"; } 1803 StringRef getDescription() const final { 1804 return "Test Merging operation in ConversionPatternRewriter"; 1805 } 1806 void runOnOperation() override { 1807 MLIRContext *context = &getContext(); 1808 mlir::RewritePatternSet patterns(context); 1809 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 1810 context); 1811 ConversionTarget target(*context); 1812 target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1813 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1814 target.addIllegalOp<ILLegalOpF>(); 1815 1816 /// Expect the op to have a single block after legalization. 1817 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1818 [&](TestMergeBlocksOp op) -> bool { 1819 return llvm::hasSingleElement(op.getBody()); 1820 }); 1821 1822 /// Only allow `test.br` within test.merge_blocks op. 1823 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1824 return op->getParentOfType<TestMergeBlocksOp>(); 1825 }); 1826 1827 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1828 /// inlined. 1829 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1830 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1831 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1832 }); 1833 1834 DenseSet<Operation *> unlegalizedOps; 1835 ConversionConfig config; 1836 config.unlegalizedOps = &unlegalizedOps; 1837 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1838 config); 1839 for (auto *op : unlegalizedOps) 1840 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1841 } 1842 }; 1843 } // namespace 1844 1845 //===----------------------------------------------------------------------===// 1846 // Test Selective Replacement 1847 //===----------------------------------------------------------------------===// 1848 1849 namespace { 1850 /// A rewrite mechanism to inline the body of the op into its parent, when both 1851 /// ops can have a single block. 1852 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1853 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1854 1855 LogicalResult matchAndRewrite(TestCastOp op, 1856 PatternRewriter &rewriter) const final { 1857 if (op.getNumOperands() != 2) 1858 return failure(); 1859 OperandRange operands = op.getOperands(); 1860 1861 // Replace non-terminator uses with the first operand. 1862 rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) { 1863 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1864 }); 1865 // Replace everything else with the second operand if the operation isn't 1866 // dead. 1867 rewriter.replaceOp(op, op.getOperand(1)); 1868 return success(); 1869 } 1870 }; 1871 1872 struct TestSelectiveReplacementPatternDriver 1873 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1874 OperationPass<>> { 1875 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 1876 TestSelectiveReplacementPatternDriver) 1877 1878 StringRef getArgument() const final { 1879 return "test-pattern-selective-replacement"; 1880 } 1881 StringRef getDescription() const final { 1882 return "Test selective replacement in the PatternRewriter"; 1883 } 1884 void runOnOperation() override { 1885 MLIRContext *context = &getContext(); 1886 mlir::RewritePatternSet patterns(context); 1887 patterns.add<TestSelectiveOpReplacementPattern>(context); 1888 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 1889 } 1890 }; 1891 } // namespace 1892 1893 //===----------------------------------------------------------------------===// 1894 // PassRegistration 1895 //===----------------------------------------------------------------------===// 1896 1897 namespace mlir { 1898 namespace test { 1899 void registerPatternsTestPass() { 1900 PassRegistration<TestReturnTypeDriver>(); 1901 1902 PassRegistration<TestDerivedAttributeDriver>(); 1903 1904 PassRegistration<TestPatternDriver>(); 1905 PassRegistration<TestStrictPatternDriver>(); 1906 1907 PassRegistration<TestLegalizePatternDriver>([] { 1908 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 1909 }); 1910 1911 PassRegistration<TestRemappedValue>(); 1912 1913 PassRegistration<TestUnknownRootOpDriver>(); 1914 1915 PassRegistration<TestTypeConversionDriver>(); 1916 PassRegistration<TestTargetMaterializationWithNoUses>(); 1917 1918 PassRegistration<TestRewriteDynamicOpDriver>(); 1919 1920 PassRegistration<TestMergeBlocksPatternDriver>(); 1921 PassRegistration<TestSelectiveReplacementPatternDriver>(); 1922 } 1923 } // namespace test 1924 } // namespace mlir 1925