1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/Arith/IR/Arith.h" 12 #include "mlir/Dialect/Func/IR/FuncOps.h" 13 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/IR/Matchers.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include "mlir/Transforms/FoldUtils.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 #include "llvm/ADT/ScopeExit.h" 21 22 using namespace mlir; 23 using namespace test; 24 25 // Native function for testing NativeCodeCall 26 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 27 return choice.getValue() ? input1 : input2; 28 } 29 30 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 31 rewriter.create<OpI>(loc, input); 32 } 33 34 static void handleNoResultOp(PatternRewriter &rewriter, 35 OpSymbolBindingNoResult op) { 36 // Turn the no result op to a one-result op. 37 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), 38 op.getOperand()); 39 } 40 41 static bool getFirstI32Result(Operation *op, Value &value) { 42 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 43 return false; 44 value = op->getResult(0); 45 return true; 46 } 47 48 static Value bindNativeCodeCallResult(Value value) { return value; } 49 50 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 51 Value input2) { 52 return SmallVector<Value, 2>({input2, input1}); 53 } 54 55 // Test that natives calls are only called once during rewrites. 56 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 57 // This let us check the number of times OpM_Test was called by inspecting 58 // the returned value in the MLIR output. 59 static int64_t opMIncreasingValue = 314159265; 60 static Attribute opMTest(PatternRewriter &rewriter, Value val) { 61 int64_t i = opMIncreasingValue++; 62 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 63 } 64 65 namespace { 66 #include "TestPatterns.inc" 67 } // namespace 68 69 //===----------------------------------------------------------------------===// 70 // Test Reduce Pattern Interface 71 //===----------------------------------------------------------------------===// 72 73 void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 74 populateWithGenerated(patterns); 75 } 76 77 //===----------------------------------------------------------------------===// 78 // Canonicalizer Driver. 79 //===----------------------------------------------------------------------===// 80 81 namespace { 82 struct FoldingPattern : public RewritePattern { 83 public: 84 FoldingPattern(MLIRContext *context) 85 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 86 /*benefit=*/1, context) {} 87 88 LogicalResult matchAndRewrite(Operation *op, 89 PatternRewriter &rewriter) const override { 90 // Exercise createOrFold API for a single-result operation that is folded 91 // upon construction. The operation being created has an in-place folder, 92 // and it should be still present in the output. Furthermore, the folder 93 // should not crash when attempting to recover the (unchanged) operation 94 // result. 95 Value result = rewriter.createOrFold<TestOpInPlaceFold>( 96 op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0)); 97 assert(result); 98 rewriter.replaceOp(op, result); 99 return success(); 100 } 101 }; 102 103 /// This pattern creates a foldable operation at the entry point of the block. 104 /// This tests the situation where the operation folder will need to replace an 105 /// operation with a previously created constant that does not initially 106 /// dominate the operation to replace. 107 struct FolderInsertBeforePreviouslyFoldedConstantPattern 108 : public OpRewritePattern<TestCastOp> { 109 public: 110 using OpRewritePattern<TestCastOp>::OpRewritePattern; 111 112 LogicalResult matchAndRewrite(TestCastOp op, 113 PatternRewriter &rewriter) const override { 114 if (!op->hasAttr("test_fold_before_previously_folded_op")) 115 return failure(); 116 rewriter.setInsertionPointToStart(op->getBlock()); 117 118 auto constOp = rewriter.create<arith::ConstantOp>( 119 op.getLoc(), rewriter.getBoolAttr(true)); 120 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 121 Value(constOp)); 122 return success(); 123 } 124 }; 125 126 /// This pattern matches test.op_commutative2 with the first operand being 127 /// another test.op_commutative2 with a constant on the right side and fold it 128 /// away by propagating it as its result. This is intend to check that patterns 129 /// are applied after the commutative property moves constant to the right. 130 struct FolderCommutativeOp2WithConstant 131 : public OpRewritePattern<TestCommutative2Op> { 132 public: 133 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; 134 135 LogicalResult matchAndRewrite(TestCommutative2Op op, 136 PatternRewriter &rewriter) const override { 137 auto operand = 138 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); 139 if (!operand) 140 return failure(); 141 Attribute constInput; 142 if (!matchPattern(operand->getOperand(1), m_Constant(&constInput))) 143 return failure(); 144 rewriter.replaceOp(op, operand->getOperand(1)); 145 return success(); 146 } 147 }; 148 149 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer 150 /// attribute with value smaller than MaxVal, it increments the value by 1. 151 template <int MaxVal> 152 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> { 153 using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern; 154 155 LogicalResult matchAndRewrite(AnyAttrOfOp op, 156 PatternRewriter &rewriter) const override { 157 auto intAttr = dyn_cast<IntegerAttr>(op.getAttr()); 158 if (!intAttr) 159 return failure(); 160 int64_t val = intAttr.getInt(); 161 if (val >= MaxVal) 162 return failure(); 163 rewriter.modifyOpInPlace( 164 op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); }); 165 return success(); 166 } 167 }; 168 169 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op". 170 struct MakeOpEligible : public RewritePattern { 171 MakeOpEligible(MLIRContext *context) 172 : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {} 173 174 LogicalResult matchAndRewrite(Operation *op, 175 PatternRewriter &rewriter) const override { 176 if (op->hasAttr("eligible")) 177 return failure(); 178 rewriter.modifyOpInPlace( 179 op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); }); 180 return success(); 181 } 182 }; 183 184 /// This pattern hoists eligible ops out of a "test.one_region_op". 185 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> { 186 using OpRewritePattern<test::OneRegionOp>::OpRewritePattern; 187 188 LogicalResult matchAndRewrite(test::OneRegionOp op, 189 PatternRewriter &rewriter) const override { 190 Operation *terminator = op.getRegion().front().getTerminator(); 191 Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp(); 192 if (toBeHoisted->getParentOp() != op) 193 return failure(); 194 if (!toBeHoisted->hasAttr("eligible")) 195 return failure(); 196 rewriter.moveOpBefore(toBeHoisted, op); 197 return success(); 198 } 199 }; 200 201 /// This pattern moves "test.move_before_parent_op" before the parent op. 202 struct MoveBeforeParentOp : public RewritePattern { 203 MoveBeforeParentOp(MLIRContext *context) 204 : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {} 205 206 LogicalResult matchAndRewrite(Operation *op, 207 PatternRewriter &rewriter) const override { 208 // Do not hoist past functions. 209 if (isa<FunctionOpInterface>(op->getParentOp())) 210 return failure(); 211 rewriter.moveOpBefore(op, op->getParentOp()); 212 return success(); 213 } 214 }; 215 216 /// This pattern inlines blocks that are nested in 217 /// "test.inline_blocks_into_parent" into the parent block. 218 struct InlineBlocksIntoParent : public RewritePattern { 219 InlineBlocksIntoParent(MLIRContext *context) 220 : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1, 221 context) {} 222 223 LogicalResult matchAndRewrite(Operation *op, 224 PatternRewriter &rewriter) const override { 225 bool changed = false; 226 for (Region &r : op->getRegions()) { 227 while (!r.empty()) { 228 rewriter.inlineBlockBefore(&r.front(), op); 229 changed = true; 230 } 231 } 232 return success(changed); 233 } 234 }; 235 236 /// This pattern splits blocks at "test.split_block_here" and replaces the op 237 /// with a new op (to prevent an infinite loop of block splitting). 238 struct SplitBlockHere : public RewritePattern { 239 SplitBlockHere(MLIRContext *context) 240 : RewritePattern("test.split_block_here", /*benefit=*/1, context) {} 241 242 LogicalResult matchAndRewrite(Operation *op, 243 PatternRewriter &rewriter) const override { 244 rewriter.splitBlock(op->getBlock(), op->getIterator()); 245 Operation *newOp = rewriter.create( 246 op->getLoc(), 247 OperationName("test.new_op", op->getContext()).getIdentifier(), 248 op->getOperands(), op->getResultTypes()); 249 rewriter.replaceOp(op, newOp); 250 return success(); 251 } 252 }; 253 254 /// This pattern clones "test.clone_me" ops. 255 struct CloneOp : public RewritePattern { 256 CloneOp(MLIRContext *context) 257 : RewritePattern("test.clone_me", /*benefit=*/1, context) {} 258 259 LogicalResult matchAndRewrite(Operation *op, 260 PatternRewriter &rewriter) const override { 261 // Do not clone already cloned ops to avoid going into an infinite loop. 262 if (op->hasAttr("was_cloned")) 263 return failure(); 264 Operation *cloned = rewriter.clone(*op); 265 cloned->setAttr("was_cloned", rewriter.getUnitAttr()); 266 return success(); 267 } 268 }; 269 270 /// This pattern clones regions of "test.clone_region_before" ops before the 271 /// parent block. 272 struct CloneRegionBeforeOp : public RewritePattern { 273 CloneRegionBeforeOp(MLIRContext *context) 274 : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {} 275 276 LogicalResult matchAndRewrite(Operation *op, 277 PatternRewriter &rewriter) const override { 278 // Do not clone already cloned ops to avoid going into an infinite loop. 279 if (op->hasAttr("was_cloned")) 280 return failure(); 281 for (Region &r : op->getRegions()) 282 rewriter.cloneRegionBefore(r, op->getBlock()); 283 op->setAttr("was_cloned", rewriter.getUnitAttr()); 284 return success(); 285 } 286 }; 287 288 struct TestPatternDriver 289 : public PassWrapper<TestPatternDriver, OperationPass<>> { 290 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver) 291 292 TestPatternDriver() = default; 293 TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {} 294 295 StringRef getArgument() const final { return "test-patterns"; } 296 StringRef getDescription() const final { return "Run test dialect patterns"; } 297 void runOnOperation() override { 298 mlir::RewritePatternSet patterns(&getContext()); 299 populateWithGenerated(patterns); 300 301 // Verify named pattern is generated with expected name. 302 patterns.add<FoldingPattern, TestNamedPatternRule, 303 FolderInsertBeforePreviouslyFoldedConstantPattern, 304 FolderCommutativeOp2WithConstant, HoistEligibleOps, 305 MakeOpEligible>(&getContext()); 306 307 // Additional patterns for testing the GreedyPatternRewriteDriver. 308 patterns.insert<IncrementIntAttribute<3>>(&getContext()); 309 310 GreedyRewriteConfig config; 311 config.useTopDownTraversal = this->useTopDownTraversal; 312 config.maxIterations = this->maxIterations; 313 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), 314 config); 315 } 316 317 Option<bool> useTopDownTraversal{ 318 *this, "top-down", 319 llvm::cl::desc("Seed the worklist in general top-down order"), 320 llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)}; 321 Option<int> maxIterations{ 322 *this, "max-iterations", 323 llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"), 324 llvm::cl::init(GreedyRewriteConfig().maxIterations)}; 325 }; 326 327 struct DumpNotifications : public RewriterBase::Listener { 328 void notifyBlockInserted(Block *block, Region *previous, 329 Region::iterator previousIt) override { 330 llvm::outs() << "notifyBlockInserted"; 331 if (block->getParentOp()) { 332 llvm::outs() << " into " << block->getParentOp()->getName() << ": "; 333 } else { 334 llvm::outs() << " into unknown op: "; 335 } 336 if (previous == nullptr) { 337 llvm::outs() << "was unlinked\n"; 338 } else { 339 llvm::outs() << "was linked\n"; 340 } 341 } 342 void notifyOperationInserted(Operation *op, 343 OpBuilder::InsertPoint previous) override { 344 llvm::outs() << "notifyOperationInserted: " << op->getName(); 345 if (!previous.isSet()) { 346 llvm::outs() << ", was unlinked\n"; 347 } else { 348 if (!previous.getPoint().getNodePtr()) { 349 llvm::outs() << ", was linked, exact position unknown\n"; 350 } else if (previous.getPoint() == previous.getBlock()->end()) { 351 llvm::outs() << ", was last in block\n"; 352 } else { 353 llvm::outs() << ", previous = " << previous.getPoint()->getName() 354 << "\n"; 355 } 356 } 357 } 358 void notifyBlockErased(Block *block) override { 359 llvm::outs() << "notifyBlockErased\n"; 360 } 361 void notifyOperationErased(Operation *op) override { 362 llvm::outs() << "notifyOperationErased: " << op->getName() << "\n"; 363 } 364 void notifyOperationModified(Operation *op) override { 365 llvm::outs() << "notifyOperationModified: " << op->getName() << "\n"; 366 } 367 void notifyOperationReplaced(Operation *op, ValueRange values) override { 368 llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n"; 369 } 370 }; 371 372 struct TestStrictPatternDriver 373 : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> { 374 public: 375 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver) 376 377 TestStrictPatternDriver() = default; 378 TestStrictPatternDriver(const TestStrictPatternDriver &other) 379 : PassWrapper(other) { 380 strictMode = other.strictMode; 381 } 382 383 StringRef getArgument() const final { return "test-strict-pattern-driver"; } 384 StringRef getDescription() const final { 385 return "Test strict mode of pattern driver"; 386 } 387 388 void runOnOperation() override { 389 MLIRContext *ctx = &getContext(); 390 mlir::RewritePatternSet patterns(ctx); 391 patterns.add< 392 // clang-format off 393 ChangeBlockOp, 394 CloneOp, 395 CloneRegionBeforeOp, 396 EraseOp, 397 ImplicitChangeOp, 398 InlineBlocksIntoParent, 399 InsertSameOp, 400 MoveBeforeParentOp, 401 ReplaceWithNewOp, 402 SplitBlockHere 403 // clang-format on 404 >(ctx); 405 SmallVector<Operation *> ops; 406 getOperation()->walk([&](Operation *op) { 407 StringRef opName = op->getName().getStringRef(); 408 if (opName == "test.insert_same_op" || opName == "test.change_block_op" || 409 opName == "test.replace_with_new_op" || opName == "test.erase_op" || 410 opName == "test.move_before_parent_op" || 411 opName == "test.inline_blocks_into_parent" || 412 opName == "test.split_block_here" || opName == "test.clone_me" || 413 opName == "test.clone_region_before") { 414 ops.push_back(op); 415 } 416 }); 417 418 DumpNotifications dumpNotifications; 419 GreedyRewriteConfig config; 420 config.listener = &dumpNotifications; 421 if (strictMode == "AnyOp") { 422 config.strictMode = GreedyRewriteStrictness::AnyOp; 423 } else if (strictMode == "ExistingAndNewOps") { 424 config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; 425 } else if (strictMode == "ExistingOps") { 426 config.strictMode = GreedyRewriteStrictness::ExistingOps; 427 } else { 428 llvm_unreachable("invalid strictness option"); 429 } 430 431 // Check if these transformations introduce visiting of operations that 432 // are not in the `ops` set (The new created ops are valid). An invalid 433 // operation will trigger the assertion while processing. 434 bool changed = false; 435 bool allErased = false; 436 (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config, 437 &changed, &allErased); 438 Builder b(ctx); 439 getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed)); 440 getOperation()->setAttr("pattern_driver_all_erased", 441 b.getBoolAttr(allErased)); 442 } 443 444 Option<std::string> strictMode{ 445 *this, "strictness", 446 llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"), 447 llvm::cl::init("AnyOp")}; 448 449 private: 450 // New inserted operation is valid for further transformation. 451 class InsertSameOp : public RewritePattern { 452 public: 453 InsertSameOp(MLIRContext *context) 454 : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {} 455 456 LogicalResult matchAndRewrite(Operation *op, 457 PatternRewriter &rewriter) const override { 458 if (op->hasAttr("skip")) 459 return failure(); 460 461 Operation *newOp = 462 rewriter.create(op->getLoc(), op->getName().getIdentifier(), 463 op->getOperands(), op->getResultTypes()); 464 rewriter.modifyOpInPlace( 465 op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); }); 466 newOp->setAttr("skip", rewriter.getBoolAttr(true)); 467 468 return success(); 469 } 470 }; 471 472 // Replace an operation may introduce the re-visiting of its users. 473 class ReplaceWithNewOp : public RewritePattern { 474 public: 475 ReplaceWithNewOp(MLIRContext *context) 476 : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {} 477 478 LogicalResult matchAndRewrite(Operation *op, 479 PatternRewriter &rewriter) const override { 480 Operation *newOp; 481 if (op->hasAttr("create_erase_op")) { 482 newOp = rewriter.create( 483 op->getLoc(), 484 OperationName("test.erase_op", op->getContext()).getIdentifier(), 485 ValueRange(), TypeRange()); 486 } else { 487 newOp = rewriter.create( 488 op->getLoc(), 489 OperationName("test.new_op", op->getContext()).getIdentifier(), 490 op->getOperands(), op->getResultTypes()); 491 } 492 rewriter.replaceOp(op, newOp->getResults()); 493 return success(); 494 } 495 }; 496 497 // Remove an operation may introduce the re-visiting of its operands. 498 class EraseOp : public RewritePattern { 499 public: 500 EraseOp(MLIRContext *context) 501 : RewritePattern("test.erase_op", /*benefit=*/1, context) {} 502 LogicalResult matchAndRewrite(Operation *op, 503 PatternRewriter &rewriter) const override { 504 rewriter.eraseOp(op); 505 return success(); 506 } 507 }; 508 509 // The following two patterns test RewriterBase::replaceAllUsesWith. 510 // 511 // That function replaces all usages of a Block (or a Value) with another one 512 // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver 513 // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its 514 // worklist: when an op is modified, it is added to the worklist. The two 515 // patterns below make the tracking observable: ChangeBlockOp replaces all 516 // usages of a block and that pattern is applied because the corresponding ops 517 // are put on the initial worklist (see above). ImplicitChangeOp does an 518 // unrelated change but ops of the corresponding type are *not* on the initial 519 // worklist, so the effect of the second pattern is only visible if the 520 // tracking and subsequent adding to the worklist actually works. 521 522 // Replace all usages of the first successor with the second successor. 523 class ChangeBlockOp : public RewritePattern { 524 public: 525 ChangeBlockOp(MLIRContext *context) 526 : RewritePattern("test.change_block_op", /*benefit=*/1, context) {} 527 LogicalResult matchAndRewrite(Operation *op, 528 PatternRewriter &rewriter) const override { 529 if (op->getNumSuccessors() < 2) 530 return failure(); 531 Block *firstSuccessor = op->getSuccessor(0); 532 Block *secondSuccessor = op->getSuccessor(1); 533 if (firstSuccessor == secondSuccessor) 534 return failure(); 535 // This is the function being tested: 536 rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor); 537 // Using the following line instead would make the test fail: 538 // firstSuccessor->replaceAllUsesWith(secondSuccessor); 539 return success(); 540 } 541 }; 542 543 // Changes the successor to the parent block. 544 class ImplicitChangeOp : public RewritePattern { 545 public: 546 ImplicitChangeOp(MLIRContext *context) 547 : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {} 548 LogicalResult matchAndRewrite(Operation *op, 549 PatternRewriter &rewriter) const override { 550 if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock()) 551 return failure(); 552 rewriter.modifyOpInPlace(op, 553 [&]() { op->setSuccessor(op->getBlock(), 0); }); 554 return success(); 555 } 556 }; 557 }; 558 559 } // namespace 560 561 //===----------------------------------------------------------------------===// 562 // ReturnType Driver. 563 //===----------------------------------------------------------------------===// 564 565 namespace { 566 // Generate ops for each instance where the type can be successfully inferred. 567 template <typename OpTy> 568 static void invokeCreateWithInferredReturnType(Operation *op) { 569 auto *context = op->getContext(); 570 auto fop = op->getParentOfType<func::FuncOp>(); 571 auto location = UnknownLoc::get(context); 572 OpBuilder b(op); 573 b.setInsertionPointAfter(op); 574 575 // Use permutations of 2 args as operands. 576 assert(fop.getNumArguments() >= 2); 577 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 578 for (int j = 0; j < e; ++j) { 579 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 580 SmallVector<Type, 2> inferredReturnTypes; 581 if (succeeded(OpTy::inferReturnTypes( 582 context, std::nullopt, values, op->getDiscardableAttrDictionary(), 583 op->getPropertiesStorage(), op->getRegions(), 584 inferredReturnTypes))) { 585 OperationState state(location, OpTy::getOperationName()); 586 // TODO: Expand to regions. 587 OpTy::build(b, state, values, op->getAttrs()); 588 (void)b.create(state); 589 } 590 } 591 } 592 } 593 594 static void reifyReturnShape(Operation *op) { 595 OpBuilder b(op); 596 597 // Use permutations of 2 args as operands. 598 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 599 SmallVector<Value, 2> shapes; 600 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 601 !llvm::hasSingleElement(shapes)) 602 return; 603 for (const auto &it : llvm::enumerate(shapes)) { 604 op->emitRemark() << "value " << it.index() << ": " 605 << it.value().getDefiningOp(); 606 } 607 } 608 609 struct TestReturnTypeDriver 610 : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> { 611 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver) 612 613 void getDependentDialects(DialectRegistry ®istry) const override { 614 registry.insert<tensor::TensorDialect>(); 615 } 616 StringRef getArgument() const final { return "test-return-type"; } 617 StringRef getDescription() const final { return "Run return type functions"; } 618 619 void runOnOperation() override { 620 if (getOperation().getName() == "testCreateFunctions") { 621 std::vector<Operation *> ops; 622 // Collect ops to avoid triggering on inserted ops. 623 for (auto &op : getOperation().getBody().front()) 624 ops.push_back(&op); 625 // Generate test patterns for each, but skip terminator. 626 for (auto *op : llvm::ArrayRef(ops).drop_back()) { 627 // Test create method of each of the Op classes below. The resultant 628 // output would be in reverse order underneath `op` from which 629 // the attributes and regions are used. 630 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 631 invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>( 632 op); 633 invokeCreateWithInferredReturnType< 634 OpWithShapedTypeInferTypeInterfaceOp>(op); 635 }; 636 return; 637 } 638 if (getOperation().getName() == "testReifyFunctions") { 639 std::vector<Operation *> ops; 640 // Collect ops to avoid triggering on inserted ops. 641 for (auto &op : getOperation().getBody().front()) 642 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 643 ops.push_back(&op); 644 // Generate test patterns for each, but skip terminator. 645 for (auto *op : ops) 646 reifyReturnShape(op); 647 } 648 } 649 }; 650 } // namespace 651 652 namespace { 653 struct TestDerivedAttributeDriver 654 : public PassWrapper<TestDerivedAttributeDriver, 655 OperationPass<func::FuncOp>> { 656 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver) 657 658 StringRef getArgument() const final { return "test-derived-attr"; } 659 StringRef getDescription() const final { 660 return "Run test derived attributes"; 661 } 662 void runOnOperation() override; 663 }; 664 } // namespace 665 666 void TestDerivedAttributeDriver::runOnOperation() { 667 getOperation().walk([](DerivedAttributeOpInterface dOp) { 668 auto dAttr = dOp.materializeDerivedAttributes(); 669 if (!dAttr) 670 return; 671 for (auto d : dAttr) 672 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); 673 }); 674 } 675 676 //===----------------------------------------------------------------------===// 677 // Legalization Driver. 678 //===----------------------------------------------------------------------===// 679 680 namespace { 681 //===----------------------------------------------------------------------===// 682 // Region-Block Rewrite Testing 683 684 /// This pattern is a simple pattern that inlines the first region of a given 685 /// operation into the parent region. 686 struct TestRegionRewriteBlockMovement : public ConversionPattern { 687 TestRegionRewriteBlockMovement(MLIRContext *ctx) 688 : ConversionPattern("test.region", 1, ctx) {} 689 690 LogicalResult 691 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 692 ConversionPatternRewriter &rewriter) const final { 693 // Inline this region into the parent region. 694 auto &parentRegion = *op->getParentRegion(); 695 auto &opRegion = op->getRegion(0); 696 if (op->getDiscardableAttr("legalizer.should_clone")) 697 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 698 else 699 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 700 701 if (op->getDiscardableAttr("legalizer.erase_old_blocks")) { 702 while (!opRegion.empty()) 703 rewriter.eraseBlock(&opRegion.front()); 704 } 705 706 // Drop this operation. 707 rewriter.eraseOp(op); 708 return success(); 709 } 710 }; 711 /// This pattern is a simple pattern that generates a region containing an 712 /// illegal operation. 713 struct TestRegionRewriteUndo : public RewritePattern { 714 TestRegionRewriteUndo(MLIRContext *ctx) 715 : RewritePattern("test.region_builder", 1, ctx) {} 716 717 LogicalResult matchAndRewrite(Operation *op, 718 PatternRewriter &rewriter) const final { 719 // Create the region operation with an entry block containing arguments. 720 OperationState newRegion(op->getLoc(), "test.region"); 721 newRegion.addRegion(); 722 auto *regionOp = rewriter.create(newRegion); 723 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 724 entryBlock->addArgument(rewriter.getIntegerType(64), 725 rewriter.getUnknownLoc()); 726 727 // Add an explicitly illegal operation to ensure the conversion fails. 728 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 729 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 730 731 // Drop this operation. 732 rewriter.eraseOp(op); 733 return success(); 734 } 735 }; 736 /// A simple pattern that creates a block at the end of the parent region of the 737 /// matched operation. 738 struct TestCreateBlock : public RewritePattern { 739 TestCreateBlock(MLIRContext *ctx) 740 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 741 742 LogicalResult matchAndRewrite(Operation *op, 743 PatternRewriter &rewriter) const final { 744 Region ®ion = *op->getParentRegion(); 745 Type i32Type = rewriter.getIntegerType(32); 746 Location loc = op->getLoc(); 747 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 748 rewriter.create<TerminatorOp>(loc); 749 rewriter.eraseOp(op); 750 return success(); 751 } 752 }; 753 754 /// A simple pattern that creates a block containing an invalid operation in 755 /// order to trigger the block creation undo mechanism. 756 struct TestCreateIllegalBlock : public RewritePattern { 757 TestCreateIllegalBlock(MLIRContext *ctx) 758 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 759 760 LogicalResult matchAndRewrite(Operation *op, 761 PatternRewriter &rewriter) const final { 762 Region ®ion = *op->getParentRegion(); 763 Type i32Type = rewriter.getIntegerType(32); 764 Location loc = op->getLoc(); 765 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 766 // Create an illegal op to ensure the conversion fails. 767 rewriter.create<ILLegalOpF>(loc, i32Type); 768 rewriter.create<TerminatorOp>(loc); 769 rewriter.eraseOp(op); 770 return success(); 771 } 772 }; 773 774 /// A simple pattern that tests the undo mechanism when replacing the uses of a 775 /// block argument. 776 struct TestUndoBlockArgReplace : public ConversionPattern { 777 TestUndoBlockArgReplace(MLIRContext *ctx) 778 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 779 780 LogicalResult 781 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 782 ConversionPatternRewriter &rewriter) const final { 783 auto illegalOp = 784 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 785 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 786 illegalOp->getResult(0)); 787 rewriter.modifyOpInPlace(op, [] {}); 788 return success(); 789 } 790 }; 791 792 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion. 793 /// This is to test the rollback logic. 794 struct TestUndoMoveOpBefore : public ConversionPattern { 795 TestUndoMoveOpBefore(MLIRContext *ctx) 796 : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {} 797 798 LogicalResult 799 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 800 ConversionPatternRewriter &rewriter) const override { 801 rewriter.moveOpBefore(op, op->getParentOp()); 802 // Replace with an illegal op to ensure the conversion fails. 803 rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type()); 804 return success(); 805 } 806 }; 807 808 /// A rewrite pattern that tests the undo mechanism when erasing a block. 809 struct TestUndoBlockErase : public ConversionPattern { 810 TestUndoBlockErase(MLIRContext *ctx) 811 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 812 813 LogicalResult 814 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 815 ConversionPatternRewriter &rewriter) const final { 816 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 817 rewriter.setInsertionPointToStart(secondBlock); 818 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 819 rewriter.eraseBlock(secondBlock); 820 rewriter.modifyOpInPlace(op, [] {}); 821 return success(); 822 } 823 }; 824 825 /// A pattern that modifies a property in-place, but keeps the op illegal. 826 struct TestUndoPropertiesModification : public ConversionPattern { 827 TestUndoPropertiesModification(MLIRContext *ctx) 828 : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {} 829 LogicalResult 830 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 831 ConversionPatternRewriter &rewriter) const final { 832 if (!op->hasAttr("modify_inplace")) 833 return failure(); 834 rewriter.modifyOpInPlace( 835 op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); }); 836 return success(); 837 } 838 }; 839 840 //===----------------------------------------------------------------------===// 841 // Type-Conversion Rewrite Testing 842 843 /// This patterns erases a region operation that has had a type conversion. 844 struct TestDropOpSignatureConversion : public ConversionPattern { 845 TestDropOpSignatureConversion(MLIRContext *ctx, 846 const TypeConverter &converter) 847 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 848 LogicalResult 849 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 850 ConversionPatternRewriter &rewriter) const override { 851 Region ®ion = op->getRegion(0); 852 Block *entry = ®ion.front(); 853 854 // Convert the original entry arguments. 855 const TypeConverter &converter = *getTypeConverter(); 856 TypeConverter::SignatureConversion result(entry->getNumArguments()); 857 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 858 result)) || 859 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 860 return failure(); 861 862 // Convert the region signature and just drop the operation. 863 rewriter.eraseOp(op); 864 return success(); 865 } 866 }; 867 /// This pattern simply updates the operands of the given operation. 868 struct TestPassthroughInvalidOp : public ConversionPattern { 869 TestPassthroughInvalidOp(MLIRContext *ctx) 870 : ConversionPattern("test.invalid", 1, ctx) {} 871 LogicalResult 872 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 873 ConversionPatternRewriter &rewriter) const final { 874 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, operands, 875 std::nullopt); 876 return success(); 877 } 878 }; 879 /// This pattern handles the case of a split return value. 880 struct TestSplitReturnType : public ConversionPattern { 881 TestSplitReturnType(MLIRContext *ctx) 882 : ConversionPattern("test.return", 1, ctx) {} 883 LogicalResult 884 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 885 ConversionPatternRewriter &rewriter) const final { 886 // Check for a return of F32. 887 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 888 return failure(); 889 890 // Check if the first operation is a cast operation, if it is we use the 891 // results directly. 892 auto *defOp = operands[0].getDefiningOp(); 893 if (auto packerOp = 894 llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) { 895 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 896 return success(); 897 } 898 899 // Otherwise, fail to match. 900 return failure(); 901 } 902 }; 903 904 //===----------------------------------------------------------------------===// 905 // Multi-Level Type-Conversion Rewrite Testing 906 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 907 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 908 : ConversionPattern("test.type_producer", 1, ctx) {} 909 LogicalResult 910 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 911 ConversionPatternRewriter &rewriter) const final { 912 // If the type is I32, change the type to F32. 913 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 914 return failure(); 915 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 916 return success(); 917 } 918 }; 919 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 920 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 921 : ConversionPattern("test.type_producer", 1, ctx) {} 922 LogicalResult 923 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 924 ConversionPatternRewriter &rewriter) const final { 925 // If the type is F32, change the type to F64. 926 if (!Type(*op->result_type_begin()).isF32()) 927 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 928 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 929 return success(); 930 } 931 }; 932 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 933 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 934 : ConversionPattern("test.type_producer", 10, ctx) {} 935 LogicalResult 936 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 937 ConversionPatternRewriter &rewriter) const final { 938 // Always convert to B16, even though it is not a legal type. This tests 939 // that values are unmapped correctly. 940 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 941 return success(); 942 } 943 }; 944 struct TestUpdateConsumerType : public ConversionPattern { 945 TestUpdateConsumerType(MLIRContext *ctx) 946 : ConversionPattern("test.type_consumer", 1, ctx) {} 947 LogicalResult 948 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 949 ConversionPatternRewriter &rewriter) const final { 950 // Verify that the incoming operand has been successfully remapped to F64. 951 if (!operands[0].getType().isF64()) 952 return failure(); 953 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 954 return success(); 955 } 956 }; 957 958 //===----------------------------------------------------------------------===// 959 // Non-Root Replacement Rewrite Testing 960 /// This pattern generates an invalid operation, but replaces it before the 961 /// pattern is finished. This checks that we don't need to legalize the 962 /// temporary op. 963 struct TestNonRootReplacement : public RewritePattern { 964 TestNonRootReplacement(MLIRContext *ctx) 965 : RewritePattern("test.replace_non_root", 1, ctx) {} 966 967 LogicalResult matchAndRewrite(Operation *op, 968 PatternRewriter &rewriter) const final { 969 auto resultType = *op->result_type_begin(); 970 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 971 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 972 973 rewriter.replaceOp(illegalOp, legalOp); 974 rewriter.replaceOp(op, illegalOp); 975 return success(); 976 } 977 }; 978 979 //===----------------------------------------------------------------------===// 980 // Recursive Rewrite Testing 981 /// This pattern is applied to the same operation multiple times, but has a 982 /// bounded recursion. 983 struct TestBoundedRecursiveRewrite 984 : public OpRewritePattern<TestRecursiveRewriteOp> { 985 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 986 987 void initialize() { 988 // The conversion target handles bounding the recursion of this pattern. 989 setHasBoundedRewriteRecursion(); 990 } 991 992 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 993 PatternRewriter &rewriter) const final { 994 // Decrement the depth of the op in-place. 995 rewriter.modifyOpInPlace(op, [&] { 996 op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); 997 }); 998 return success(); 999 } 1000 }; 1001 1002 struct TestNestedOpCreationUndoRewrite 1003 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 1004 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 1005 1006 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 1007 PatternRewriter &rewriter) const final { 1008 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 1009 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 1010 return success(); 1011 }; 1012 }; 1013 1014 // This pattern matches `test.blackhole` and delete this op and its producer. 1015 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 1016 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 1017 1018 LogicalResult matchAndRewrite(BlackHoleOp op, 1019 PatternRewriter &rewriter) const final { 1020 Operation *producer = op.getOperand().getDefiningOp(); 1021 // Always erase the user before the producer, the framework should handle 1022 // this correctly. 1023 rewriter.eraseOp(op); 1024 rewriter.eraseOp(producer); 1025 return success(); 1026 }; 1027 }; 1028 1029 // This pattern replaces explicitly illegal op with explicitly legal op, 1030 // but in addition creates unregistered operation. 1031 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { 1032 using OpRewritePattern<ILLegalOpG>::OpRewritePattern; 1033 1034 LogicalResult matchAndRewrite(ILLegalOpG op, 1035 PatternRewriter &rewriter) const final { 1036 IntegerAttr attr = rewriter.getI32IntegerAttr(0); 1037 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); 1038 rewriter.replaceOpWithNewOp<LegalOpC>(op, val); 1039 return success(); 1040 }; 1041 }; 1042 } // namespace 1043 1044 namespace { 1045 struct TestTypeConverter : public TypeConverter { 1046 using TypeConverter::TypeConverter; 1047 TestTypeConverter() { 1048 addConversion(convertType); 1049 addArgumentMaterialization(materializeCast); 1050 addSourceMaterialization(materializeCast); 1051 } 1052 1053 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 1054 // Drop I16 types. 1055 if (t.isSignlessInteger(16)) 1056 return success(); 1057 1058 // Convert I64 to F64. 1059 if (t.isSignlessInteger(64)) { 1060 results.push_back(FloatType::getF64(t.getContext())); 1061 return success(); 1062 } 1063 1064 // Convert I42 to I43. 1065 if (t.isInteger(42)) { 1066 results.push_back(IntegerType::get(t.getContext(), 43)); 1067 return success(); 1068 } 1069 1070 // Split F32 into F16,F16. 1071 if (t.isF32()) { 1072 results.assign(2, FloatType::getF16(t.getContext())); 1073 return success(); 1074 } 1075 1076 // Otherwise, convert the type directly. 1077 results.push_back(t); 1078 return success(); 1079 } 1080 1081 /// Hook for materializing a conversion. This is necessary because we generate 1082 /// 1->N type mappings. 1083 static std::optional<Value> materializeCast(OpBuilder &builder, 1084 Type resultType, 1085 ValueRange inputs, Location loc) { 1086 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1087 } 1088 }; 1089 1090 struct TestLegalizePatternDriver 1091 : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> { 1092 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) 1093 1094 StringRef getArgument() const final { return "test-legalize-patterns"; } 1095 StringRef getDescription() const final { 1096 return "Run test dialect legalization patterns"; 1097 } 1098 /// The mode of conversion to use with the driver. 1099 enum class ConversionMode { Analysis, Full, Partial }; 1100 1101 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 1102 1103 void getDependentDialects(DialectRegistry ®istry) const override { 1104 registry.insert<func::FuncDialect, test::TestDialect>(); 1105 } 1106 1107 void runOnOperation() override { 1108 TestTypeConverter converter; 1109 mlir::RewritePatternSet patterns(&getContext()); 1110 populateWithGenerated(patterns); 1111 patterns 1112 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 1113 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 1114 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 1115 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 1116 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 1117 TestNonRootReplacement, TestBoundedRecursiveRewrite, 1118 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, 1119 TestCreateUnregisteredOp, TestUndoMoveOpBefore, 1120 TestUndoPropertiesModification>(&getContext()); 1121 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 1122 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1123 converter); 1124 mlir::populateCallOpTypeConversionPattern(patterns, converter); 1125 1126 // Define the conversion target used for the test. 1127 ConversionTarget target(getContext()); 1128 target.addLegalOp<ModuleOp>(); 1129 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, 1130 TerminatorOp, OneRegionOp>(); 1131 target 1132 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 1133 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 1134 // Don't allow F32 operands. 1135 return llvm::none_of(op.getOperandTypes(), 1136 [](Type type) { return type.isF32(); }); 1137 }); 1138 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1139 return converter.isSignatureLegal(op.getFunctionType()) && 1140 converter.isLegal(&op.getBody()); 1141 }); 1142 target.addDynamicallyLegalOp<func::CallOp>( 1143 [&](func::CallOp op) { return converter.isLegal(op); }); 1144 1145 // TestCreateUnregisteredOp creates `arith.constant` operation, 1146 // which was not added to target intentionally to test 1147 // correct error code from conversion driver. 1148 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); 1149 1150 // Expect the type_producer/type_consumer operations to only operate on f64. 1151 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1152 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1153 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1154 return op.getOperand().getType().isF64(); 1155 }); 1156 1157 // Check support for marking certain operations as recursively legal. 1158 target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) { 1159 return static_cast<bool>( 1160 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 1161 }); 1162 1163 // Mark the bound recursion operation as dynamically legal. 1164 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 1165 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); 1166 1167 // Handle a partial conversion. 1168 if (mode == ConversionMode::Partial) { 1169 DenseSet<Operation *> unlegalizedOps; 1170 ConversionConfig config; 1171 DumpNotifications dumpNotifications; 1172 config.listener = &dumpNotifications; 1173 config.unlegalizedOps = &unlegalizedOps; 1174 if (failed(applyPartialConversion(getOperation(), target, 1175 std::move(patterns), config))) { 1176 getOperation()->emitRemark() << "applyPartialConversion failed"; 1177 } 1178 // Emit remarks for each legalizable operation. 1179 for (auto *op : unlegalizedOps) 1180 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1181 return; 1182 } 1183 1184 // Handle a full conversion. 1185 if (mode == ConversionMode::Full) { 1186 // Check support for marking unknown operations as dynamically legal. 1187 target.markUnknownOpDynamicallyLegal([](Operation *op) { 1188 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 1189 }); 1190 1191 ConversionConfig config; 1192 DumpNotifications dumpNotifications; 1193 config.listener = &dumpNotifications; 1194 if (failed(applyFullConversion(getOperation(), target, 1195 std::move(patterns), config))) { 1196 getOperation()->emitRemark() << "applyFullConversion failed"; 1197 } 1198 return; 1199 } 1200 1201 // Otherwise, handle an analysis conversion. 1202 assert(mode == ConversionMode::Analysis); 1203 1204 // Analyze the convertible operations. 1205 DenseSet<Operation *> legalizedOps; 1206 ConversionConfig config; 1207 config.legalizableOps = &legalizedOps; 1208 if (failed(applyAnalysisConversion(getOperation(), target, 1209 std::move(patterns), config))) 1210 return signalPassFailure(); 1211 1212 // Emit remarks for each legalizable operation. 1213 for (auto *op : legalizedOps) 1214 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 1215 } 1216 1217 /// The mode of conversion to use. 1218 ConversionMode mode; 1219 }; 1220 } // namespace 1221 1222 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 1223 legalizerConversionMode( 1224 "test-legalize-mode", 1225 llvm::cl::desc("The legalization mode to use with the test driver"), 1226 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 1227 llvm::cl::values( 1228 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 1229 "analysis", "Perform an analysis conversion"), 1230 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 1231 "Perform a full conversion"), 1232 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 1233 "partial", "Perform a partial conversion"))); 1234 1235 //===----------------------------------------------------------------------===// 1236 // ConversionPatternRewriter::getRemappedValue testing. This method is used 1237 // to get the remapped value of an original value that was replaced using 1238 // ConversionPatternRewriter. 1239 namespace { 1240 struct TestRemapValueTypeConverter : public TypeConverter { 1241 using TypeConverter::TypeConverter; 1242 1243 TestRemapValueTypeConverter() { 1244 addConversion( 1245 [](Float32Type type) { return Float64Type::get(type.getContext()); }); 1246 addConversion([](Type type) { return type; }); 1247 } 1248 }; 1249 1250 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 1251 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 1252 /// operand twice. 1253 /// 1254 /// Example: 1255 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 1256 /// is replaced with: 1257 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 1258 struct OneVResOneVOperandOp1Converter 1259 : public OpConversionPattern<OneVResOneVOperandOp1> { 1260 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 1261 1262 LogicalResult 1263 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, 1264 ConversionPatternRewriter &rewriter) const override { 1265 auto origOps = op.getOperands(); 1266 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 1267 "One operand expected"); 1268 Value origOp = *origOps.begin(); 1269 SmallVector<Value, 2> remappedOperands; 1270 // Replicate the remapped original operand twice. Note that we don't used 1271 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 1272 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1273 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1274 1275 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 1276 remappedOperands); 1277 return success(); 1278 } 1279 }; 1280 1281 /// A rewriter pattern that tests that blocks can be merged. 1282 struct TestRemapValueInRegion 1283 : public OpConversionPattern<TestRemappedValueRegionOp> { 1284 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; 1285 1286 LogicalResult 1287 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, 1288 ConversionPatternRewriter &rewriter) const final { 1289 Block &block = op.getBody().front(); 1290 Operation *terminator = block.getTerminator(); 1291 1292 // Merge the block into the parent region. 1293 Block *parentBlock = op->getBlock(); 1294 Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator()); 1295 rewriter.mergeBlocks(&block, parentBlock, ValueRange()); 1296 rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange()); 1297 1298 // Replace the results of this operation with the remapped terminator 1299 // values. 1300 SmallVector<Value> terminatorOperands; 1301 if (failed(rewriter.getRemappedValues(terminator->getOperands(), 1302 terminatorOperands))) 1303 return failure(); 1304 1305 rewriter.eraseOp(terminator); 1306 rewriter.replaceOp(op, terminatorOperands); 1307 return success(); 1308 } 1309 }; 1310 1311 struct TestRemappedValue 1312 : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> { 1313 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue) 1314 1315 StringRef getArgument() const final { return "test-remapped-value"; } 1316 StringRef getDescription() const final { 1317 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1318 } 1319 void runOnOperation() override { 1320 TestRemapValueTypeConverter typeConverter; 1321 1322 mlir::RewritePatternSet patterns(&getContext()); 1323 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 1324 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( 1325 &getContext()); 1326 patterns.add<TestRemapValueInRegion>(typeConverter, &getContext()); 1327 1328 mlir::ConversionTarget target(getContext()); 1329 target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>(); 1330 1331 // Expect the type_producer/type_consumer operations to only operate on f64. 1332 target.addDynamicallyLegalOp<TestTypeProducerOp>( 1333 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1334 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1335 return op.getOperand().getType().isF64(); 1336 }); 1337 1338 // We make OneVResOneVOperandOp1 legal only when it has more that one 1339 // operand. This will trigger the conversion that will replace one-operand 1340 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 1341 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 1342 [](Operation *op) { return op->getNumOperands() > 1; }); 1343 1344 if (failed(mlir::applyFullConversion(getOperation(), target, 1345 std::move(patterns)))) { 1346 signalPassFailure(); 1347 } 1348 } 1349 }; 1350 } // namespace 1351 1352 //===----------------------------------------------------------------------===// 1353 // Test patterns without a specific root operation kind 1354 //===----------------------------------------------------------------------===// 1355 1356 namespace { 1357 /// This pattern matches and removes any operation in the test dialect. 1358 struct RemoveTestDialectOps : public RewritePattern { 1359 RemoveTestDialectOps(MLIRContext *context) 1360 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 1361 1362 LogicalResult matchAndRewrite(Operation *op, 1363 PatternRewriter &rewriter) const override { 1364 if (!isa<TestDialect>(op->getDialect())) 1365 return failure(); 1366 rewriter.eraseOp(op); 1367 return success(); 1368 } 1369 }; 1370 1371 struct TestUnknownRootOpDriver 1372 : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> { 1373 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver) 1374 1375 StringRef getArgument() const final { 1376 return "test-legalize-unknown-root-patterns"; 1377 } 1378 StringRef getDescription() const final { 1379 return "Test public remapped value mechanism in ConversionPatternRewriter"; 1380 } 1381 void runOnOperation() override { 1382 mlir::RewritePatternSet patterns(&getContext()); 1383 patterns.add<RemoveTestDialectOps>(&getContext()); 1384 1385 mlir::ConversionTarget target(getContext()); 1386 target.addIllegalDialect<TestDialect>(); 1387 if (failed(applyPartialConversion(getOperation(), target, 1388 std::move(patterns)))) 1389 signalPassFailure(); 1390 } 1391 }; 1392 } // namespace 1393 1394 //===----------------------------------------------------------------------===// 1395 // Test patterns that uses operations and types defined at runtime 1396 //===----------------------------------------------------------------------===// 1397 1398 namespace { 1399 /// This pattern matches dynamic operations 'test.one_operand_two_results' and 1400 /// replace them with dynamic operations 'test.generic_dynamic_op'. 1401 struct RewriteDynamicOp : public RewritePattern { 1402 RewriteDynamicOp(MLIRContext *context) 1403 : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1, 1404 context) {} 1405 1406 LogicalResult matchAndRewrite(Operation *op, 1407 PatternRewriter &rewriter) const override { 1408 assert(op->getName().getStringRef() == 1409 "test.dynamic_one_operand_two_results" && 1410 "rewrite pattern should only match operations with the right name"); 1411 1412 OperationState state(op->getLoc(), "test.dynamic_generic", 1413 op->getOperands(), op->getResultTypes(), 1414 op->getAttrs()); 1415 auto *newOp = rewriter.create(state); 1416 rewriter.replaceOp(op, newOp->getResults()); 1417 return success(); 1418 } 1419 }; 1420 1421 struct TestRewriteDynamicOpDriver 1422 : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> { 1423 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver) 1424 1425 void getDependentDialects(DialectRegistry ®istry) const override { 1426 registry.insert<TestDialect>(); 1427 } 1428 StringRef getArgument() const final { return "test-rewrite-dynamic-op"; } 1429 StringRef getDescription() const final { 1430 return "Test rewritting on dynamic operations"; 1431 } 1432 void runOnOperation() override { 1433 RewritePatternSet patterns(&getContext()); 1434 patterns.add<RewriteDynamicOp>(&getContext()); 1435 1436 ConversionTarget target(getContext()); 1437 target.addIllegalOp( 1438 OperationName("test.dynamic_one_operand_two_results", &getContext())); 1439 target.addLegalOp(OperationName("test.dynamic_generic", &getContext())); 1440 if (failed(applyPartialConversion(getOperation(), target, 1441 std::move(patterns)))) 1442 signalPassFailure(); 1443 } 1444 }; 1445 } // end anonymous namespace 1446 1447 //===----------------------------------------------------------------------===// 1448 // Test type conversions 1449 //===----------------------------------------------------------------------===// 1450 1451 namespace { 1452 struct TestTypeConversionProducer 1453 : public OpConversionPattern<TestTypeProducerOp> { 1454 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 1455 LogicalResult 1456 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, 1457 ConversionPatternRewriter &rewriter) const final { 1458 Type resultType = op.getType(); 1459 Type convertedType = getTypeConverter() 1460 ? getTypeConverter()->convertType(resultType) 1461 : resultType; 1462 if (isa<FloatType>(resultType)) 1463 resultType = rewriter.getF64Type(); 1464 else if (resultType.isInteger(16)) 1465 resultType = rewriter.getIntegerType(64); 1466 else if (isa<test::TestRecursiveType>(resultType) && 1467 convertedType != resultType) 1468 resultType = convertedType; 1469 else 1470 return failure(); 1471 1472 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 1473 return success(); 1474 } 1475 }; 1476 1477 /// Call signature conversion and then fail the rewrite to trigger the undo 1478 /// mechanism. 1479 struct TestSignatureConversionUndo 1480 : public OpConversionPattern<TestSignatureConversionUndoOp> { 1481 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 1482 1483 LogicalResult 1484 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, 1485 ConversionPatternRewriter &rewriter) const final { 1486 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 1487 return failure(); 1488 } 1489 }; 1490 1491 /// Call signature conversion without providing a type converter to handle 1492 /// materializations. 1493 struct TestTestSignatureConversionNoConverter 1494 : public OpConversionPattern<TestSignatureConversionNoConverterOp> { 1495 TestTestSignatureConversionNoConverter(const TypeConverter &converter, 1496 MLIRContext *context) 1497 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), 1498 converter(converter) {} 1499 1500 LogicalResult 1501 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, 1502 ConversionPatternRewriter &rewriter) const final { 1503 Region ®ion = op->getRegion(0); 1504 Block *entry = ®ion.front(); 1505 1506 // Convert the original entry arguments. 1507 TypeConverter::SignatureConversion result(entry->getNumArguments()); 1508 if (failed( 1509 converter.convertSignatureArgs(entry->getArgumentTypes(), result))) 1510 return failure(); 1511 rewriter.modifyOpInPlace( 1512 op, [&] { rewriter.applySignatureConversion(®ion, result); }); 1513 return success(); 1514 } 1515 1516 const TypeConverter &converter; 1517 }; 1518 1519 /// Just forward the operands to the root op. This is essentially a no-op 1520 /// pattern that is used to trigger target materialization. 1521 struct TestTypeConsumerForward 1522 : public OpConversionPattern<TestTypeConsumerOp> { 1523 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 1524 1525 LogicalResult 1526 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, 1527 ConversionPatternRewriter &rewriter) const final { 1528 rewriter.modifyOpInPlace(op, 1529 [&] { op->setOperands(adaptor.getOperands()); }); 1530 return success(); 1531 } 1532 }; 1533 1534 struct TestTypeConversionAnotherProducer 1535 : public OpRewritePattern<TestAnotherTypeProducerOp> { 1536 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 1537 1538 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 1539 PatternRewriter &rewriter) const final { 1540 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 1541 return success(); 1542 } 1543 }; 1544 1545 struct TestTypeConversionDriver 1546 : public PassWrapper<TestTypeConversionDriver, OperationPass<>> { 1547 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver) 1548 1549 void getDependentDialects(DialectRegistry ®istry) const override { 1550 registry.insert<TestDialect>(); 1551 } 1552 StringRef getArgument() const final { 1553 return "test-legalize-type-conversion"; 1554 } 1555 StringRef getDescription() const final { 1556 return "Test various type conversion functionalities in DialectConversion"; 1557 } 1558 1559 void runOnOperation() override { 1560 // Initialize the type converter. 1561 SmallVector<Type, 2> conversionCallStack; 1562 TypeConverter converter; 1563 1564 /// Add the legal set of type conversions. 1565 converter.addConversion([](Type type) -> Type { 1566 // Treat F64 as legal. 1567 if (type.isF64()) 1568 return type; 1569 // Allow converting BF16/F16/F32 to F64. 1570 if (type.isBF16() || type.isF16() || type.isF32()) 1571 return FloatType::getF64(type.getContext()); 1572 // Otherwise, the type is illegal. 1573 return nullptr; 1574 }); 1575 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 1576 // Drop all integer types. 1577 return success(); 1578 }); 1579 converter.addConversion( 1580 // Convert a recursive self-referring type into a non-self-referring 1581 // type named "outer_converted_type" that contains a SimpleAType. 1582 [&](test::TestRecursiveType type, 1583 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { 1584 // If the type is already converted, return it to indicate that it is 1585 // legal. 1586 if (type.getName() == "outer_converted_type") { 1587 results.push_back(type); 1588 return success(); 1589 } 1590 1591 conversionCallStack.push_back(type); 1592 auto popConversionCallStack = llvm::make_scope_exit( 1593 [&conversionCallStack]() { conversionCallStack.pop_back(); }); 1594 1595 // If the type is on the call stack more than once (it is there at 1596 // least once because of the _current_ call, which is always the last 1597 // element on the stack), we've hit the recursive case. Just return 1598 // SimpleAType here to create a non-recursive type as a result. 1599 if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(), 1600 type)) { 1601 results.push_back(test::SimpleAType::get(type.getContext())); 1602 return success(); 1603 } 1604 1605 // Convert the body recursively. 1606 auto result = test::TestRecursiveType::get(type.getContext(), 1607 "outer_converted_type"); 1608 if (failed(result.setBody(converter.convertType(type.getBody())))) 1609 return failure(); 1610 results.push_back(result); 1611 return success(); 1612 }); 1613 1614 /// Add the legal set of type materializations. 1615 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 1616 ValueRange inputs, 1617 Location loc) -> Value { 1618 // Allow casting from F64 back to F32. 1619 if (!resultType.isF16() && inputs.size() == 1 && 1620 inputs[0].getType().isF64()) 1621 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1622 // Allow producing an i32 or i64 from nothing. 1623 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 1624 inputs.empty()) 1625 return builder.create<TestTypeProducerOp>(loc, resultType); 1626 // Allow producing an i64 from an integer. 1627 if (isa<IntegerType>(resultType) && inputs.size() == 1 && 1628 isa<IntegerType>(inputs[0].getType())) 1629 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1630 // Otherwise, fail. 1631 return nullptr; 1632 }); 1633 1634 // Initialize the conversion target. 1635 mlir::ConversionTarget target(getContext()); 1636 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 1637 auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType()); 1638 return op.getType().isF64() || op.getType().isInteger(64) || 1639 (recursiveType && 1640 recursiveType.getName() == "outer_converted_type"); 1641 }); 1642 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1643 return converter.isSignatureLegal(op.getFunctionType()) && 1644 converter.isLegal(&op.getBody()); 1645 }); 1646 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 1647 // Allow casts from F64 to F32. 1648 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 1649 }); 1650 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>( 1651 [&](TestSignatureConversionNoConverterOp op) { 1652 return converter.isLegal(op.getRegion().front().getArgumentTypes()); 1653 }); 1654 1655 // Initialize the set of rewrite patterns. 1656 RewritePatternSet patterns(&getContext()); 1657 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 1658 TestSignatureConversionUndo, 1659 TestTestSignatureConversionNoConverter>(converter, 1660 &getContext()); 1661 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 1662 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1663 converter); 1664 1665 if (failed(applyPartialConversion(getOperation(), target, 1666 std::move(patterns)))) 1667 signalPassFailure(); 1668 } 1669 }; 1670 } // namespace 1671 1672 //===----------------------------------------------------------------------===// 1673 // Test Target Materialization With No Uses 1674 //===----------------------------------------------------------------------===// 1675 1676 namespace { 1677 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> { 1678 using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern; 1679 1680 LogicalResult 1681 matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor, 1682 ConversionPatternRewriter &rewriter) const final { 1683 rewriter.replaceOp(op, adaptor.getOperands()); 1684 return success(); 1685 } 1686 }; 1687 1688 struct TestTargetMaterializationWithNoUses 1689 : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> { 1690 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 1691 TestTargetMaterializationWithNoUses) 1692 1693 StringRef getArgument() const final { 1694 return "test-target-materialization-with-no-uses"; 1695 } 1696 StringRef getDescription() const final { 1697 return "Test a special case of target materialization in DialectConversion"; 1698 } 1699 1700 void runOnOperation() override { 1701 TypeConverter converter; 1702 converter.addConversion([](Type t) { return t; }); 1703 converter.addConversion([](IntegerType intTy) -> Type { 1704 if (intTy.getWidth() == 16) 1705 return IntegerType::get(intTy.getContext(), 64); 1706 return intTy; 1707 }); 1708 converter.addTargetMaterialization( 1709 [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { 1710 return builder.create<TestCastOp>(loc, type, inputs).getResult(); 1711 }); 1712 1713 ConversionTarget target(getContext()); 1714 target.addIllegalOp<TestTypeChangerOp>(); 1715 1716 RewritePatternSet patterns(&getContext()); 1717 patterns.add<ForwardOperandPattern>(converter, &getContext()); 1718 1719 if (failed(applyPartialConversion(getOperation(), target, 1720 std::move(patterns)))) 1721 signalPassFailure(); 1722 } 1723 }; 1724 } // namespace 1725 1726 //===----------------------------------------------------------------------===// 1727 // Test Block Merging 1728 //===----------------------------------------------------------------------===// 1729 1730 namespace { 1731 /// A rewriter pattern that tests that blocks can be merged. 1732 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 1733 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 1734 1735 LogicalResult 1736 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, 1737 ConversionPatternRewriter &rewriter) const final { 1738 Block &firstBlock = op.getBody().front(); 1739 Operation *branchOp = firstBlock.getTerminator(); 1740 Block *secondBlock = &*(std::next(op.getBody().begin())); 1741 auto succOperands = branchOp->getOperands(); 1742 SmallVector<Value, 2> replacements(succOperands); 1743 rewriter.eraseOp(branchOp); 1744 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1745 rewriter.modifyOpInPlace(op, [] {}); 1746 return success(); 1747 } 1748 }; 1749 1750 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 1751 struct TestUndoBlocksMerge : public ConversionPattern { 1752 TestUndoBlocksMerge(MLIRContext *ctx) 1753 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 1754 LogicalResult 1755 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1756 ConversionPatternRewriter &rewriter) const final { 1757 Block &firstBlock = op->getRegion(0).front(); 1758 Operation *branchOp = firstBlock.getTerminator(); 1759 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 1760 rewriter.setInsertionPointToStart(secondBlock); 1761 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 1762 auto succOperands = branchOp->getOperands(); 1763 SmallVector<Value, 2> replacements(succOperands); 1764 rewriter.eraseOp(branchOp); 1765 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1766 rewriter.modifyOpInPlace(op, [] {}); 1767 return success(); 1768 } 1769 }; 1770 1771 /// A rewrite mechanism to inline the body of the op into its parent, when both 1772 /// ops can have a single block. 1773 struct TestMergeSingleBlockOps 1774 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 1775 using OpConversionPattern< 1776 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 1777 1778 LogicalResult 1779 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, 1780 ConversionPatternRewriter &rewriter) const final { 1781 SingleBlockImplicitTerminatorOp parentOp = 1782 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1783 if (!parentOp) 1784 return failure(); 1785 Block &innerBlock = op.getRegion().front(); 1786 TerminatorOp innerTerminator = 1787 cast<TerminatorOp>(innerBlock.getTerminator()); 1788 rewriter.inlineBlockBefore(&innerBlock, op); 1789 rewriter.eraseOp(innerTerminator); 1790 rewriter.eraseOp(op); 1791 return success(); 1792 } 1793 }; 1794 1795 struct TestMergeBlocksPatternDriver 1796 : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> { 1797 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver) 1798 1799 StringRef getArgument() const final { return "test-merge-blocks"; } 1800 StringRef getDescription() const final { 1801 return "Test Merging operation in ConversionPatternRewriter"; 1802 } 1803 void runOnOperation() override { 1804 MLIRContext *context = &getContext(); 1805 mlir::RewritePatternSet patterns(context); 1806 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 1807 context); 1808 ConversionTarget target(*context); 1809 target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1810 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1811 target.addIllegalOp<ILLegalOpF>(); 1812 1813 /// Expect the op to have a single block after legalization. 1814 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1815 [&](TestMergeBlocksOp op) -> bool { 1816 return llvm::hasSingleElement(op.getBody()); 1817 }); 1818 1819 /// Only allow `test.br` within test.merge_blocks op. 1820 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1821 return op->getParentOfType<TestMergeBlocksOp>(); 1822 }); 1823 1824 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1825 /// inlined. 1826 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1827 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1828 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1829 }); 1830 1831 DenseSet<Operation *> unlegalizedOps; 1832 ConversionConfig config; 1833 config.unlegalizedOps = &unlegalizedOps; 1834 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1835 config); 1836 for (auto *op : unlegalizedOps) 1837 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1838 } 1839 }; 1840 } // namespace 1841 1842 //===----------------------------------------------------------------------===// 1843 // Test Selective Replacement 1844 //===----------------------------------------------------------------------===// 1845 1846 namespace { 1847 /// A rewrite mechanism to inline the body of the op into its parent, when both 1848 /// ops can have a single block. 1849 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1850 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1851 1852 LogicalResult matchAndRewrite(TestCastOp op, 1853 PatternRewriter &rewriter) const final { 1854 if (op.getNumOperands() != 2) 1855 return failure(); 1856 OperandRange operands = op.getOperands(); 1857 1858 // Replace non-terminator uses with the first operand. 1859 rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) { 1860 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1861 }); 1862 // Replace everything else with the second operand if the operation isn't 1863 // dead. 1864 rewriter.replaceOp(op, op.getOperand(1)); 1865 return success(); 1866 } 1867 }; 1868 1869 struct TestSelectiveReplacementPatternDriver 1870 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1871 OperationPass<>> { 1872 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 1873 TestSelectiveReplacementPatternDriver) 1874 1875 StringRef getArgument() const final { 1876 return "test-pattern-selective-replacement"; 1877 } 1878 StringRef getDescription() const final { 1879 return "Test selective replacement in the PatternRewriter"; 1880 } 1881 void runOnOperation() override { 1882 MLIRContext *context = &getContext(); 1883 mlir::RewritePatternSet patterns(context); 1884 patterns.add<TestSelectiveOpReplacementPattern>(context); 1885 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 1886 } 1887 }; 1888 } // namespace 1889 1890 //===----------------------------------------------------------------------===// 1891 // PassRegistration 1892 //===----------------------------------------------------------------------===// 1893 1894 namespace mlir { 1895 namespace test { 1896 void registerPatternsTestPass() { 1897 PassRegistration<TestReturnTypeDriver>(); 1898 1899 PassRegistration<TestDerivedAttributeDriver>(); 1900 1901 PassRegistration<TestPatternDriver>(); 1902 PassRegistration<TestStrictPatternDriver>(); 1903 1904 PassRegistration<TestLegalizePatternDriver>([] { 1905 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 1906 }); 1907 1908 PassRegistration<TestRemappedValue>(); 1909 1910 PassRegistration<TestUnknownRootOpDriver>(); 1911 1912 PassRegistration<TestTypeConversionDriver>(); 1913 PassRegistration<TestTargetMaterializationWithNoUses>(); 1914 1915 PassRegistration<TestRewriteDynamicOpDriver>(); 1916 1917 PassRegistration<TestMergeBlocksPatternDriver>(); 1918 PassRegistration<TestSelectiveReplacementPatternDriver>(); 1919 } 1920 } // namespace test 1921 } // namespace mlir 1922