xref: /llvm-project/mlir/test/lib/Dialect/Test/TestPatterns.cpp (revision f023da12d12635f5fba436e825cbfc999e28e623)
1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "TestTypes.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BuiltinAttributes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Visitors.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include <cstdint>
26 
27 using namespace mlir;
28 using namespace test;
29 
30 // Native function for testing NativeCodeCall
31 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
32   return choice.getValue() ? input1 : input2;
33 }
34 
35 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
36   rewriter.create<OpI>(loc, input);
37 }
38 
39 static void handleNoResultOp(PatternRewriter &rewriter,
40                              OpSymbolBindingNoResult op) {
41   // Turn the no result op to a one-result op.
42   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
43                                     op.getOperand());
44 }
45 
46 static bool getFirstI32Result(Operation *op, Value &value) {
47   if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
48     return false;
49   value = op->getResult(0);
50   return true;
51 }
52 
53 static Value bindNativeCodeCallResult(Value value) { return value; }
54 
55 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
56                                                               Value input2) {
57   return SmallVector<Value, 2>({input2, input1});
58 }
59 
60 // Test that natives calls are only called once during rewrites.
61 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
62 // This let us check the number of times OpM_Test was called by inspecting
63 // the returned value in the MLIR output.
64 static int64_t opMIncreasingValue = 314159265;
65 static Attribute opMTest(PatternRewriter &rewriter, Value val) {
66   int64_t i = opMIncreasingValue++;
67   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
68 }
69 
70 namespace {
71 #include "TestPatterns.inc"
72 } // namespace
73 
74 //===----------------------------------------------------------------------===//
75 // Test Reduce Pattern Interface
76 //===----------------------------------------------------------------------===//
77 
78 void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
79   populateWithGenerated(patterns);
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // Canonicalizer Driver.
84 //===----------------------------------------------------------------------===//
85 
86 namespace {
87 struct FoldingPattern : public RewritePattern {
88 public:
89   FoldingPattern(MLIRContext *context)
90       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
91                        /*benefit=*/1, context) {}
92 
93   LogicalResult matchAndRewrite(Operation *op,
94                                 PatternRewriter &rewriter) const override {
95     // Exercise createOrFold API for a single-result operation that is folded
96     // upon construction. The operation being created has an in-place folder,
97     // and it should be still present in the output. Furthermore, the folder
98     // should not crash when attempting to recover the (unchanged) operation
99     // result.
100     Value result = rewriter.createOrFold<TestOpInPlaceFold>(
101         op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0));
102     assert(result);
103     rewriter.replaceOp(op, result);
104     return success();
105   }
106 };
107 
108 /// This pattern creates a foldable operation at the entry point of the block.
109 /// This tests the situation where the operation folder will need to replace an
110 /// operation with a previously created constant that does not initially
111 /// dominate the operation to replace.
112 struct FolderInsertBeforePreviouslyFoldedConstantPattern
113     : public OpRewritePattern<TestCastOp> {
114 public:
115   using OpRewritePattern<TestCastOp>::OpRewritePattern;
116 
117   LogicalResult matchAndRewrite(TestCastOp op,
118                                 PatternRewriter &rewriter) const override {
119     if (!op->hasAttr("test_fold_before_previously_folded_op"))
120       return failure();
121     rewriter.setInsertionPointToStart(op->getBlock());
122 
123     auto constOp = rewriter.create<arith::ConstantOp>(
124         op.getLoc(), rewriter.getBoolAttr(true));
125     rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
126                                             Value(constOp));
127     return success();
128   }
129 };
130 
131 /// This pattern matches test.op_commutative2 with the first operand being
132 /// another test.op_commutative2 with a constant on the right side and fold it
133 /// away by propagating it as its result. This is intend to check that patterns
134 /// are applied after the commutative property moves constant to the right.
135 struct FolderCommutativeOp2WithConstant
136     : public OpRewritePattern<TestCommutative2Op> {
137 public:
138   using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
139 
140   LogicalResult matchAndRewrite(TestCommutative2Op op,
141                                 PatternRewriter &rewriter) const override {
142     auto operand =
143         dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
144     if (!operand)
145       return failure();
146     Attribute constInput;
147     if (!matchPattern(operand->getOperand(1), m_Constant(&constInput)))
148       return failure();
149     rewriter.replaceOp(op, operand->getOperand(1));
150     return success();
151   }
152 };
153 
154 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
155 /// attribute with value smaller than MaxVal, it increments the value by 1.
156 template <int MaxVal>
157 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
158   using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern;
159 
160   LogicalResult matchAndRewrite(AnyAttrOfOp op,
161                                 PatternRewriter &rewriter) const override {
162     auto intAttr = dyn_cast<IntegerAttr>(op.getAttr());
163     if (!intAttr)
164       return failure();
165     int64_t val = intAttr.getInt();
166     if (val >= MaxVal)
167       return failure();
168     rewriter.modifyOpInPlace(
169         op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
170     return success();
171   }
172 };
173 
174 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op".
175 struct MakeOpEligible : public RewritePattern {
176   MakeOpEligible(MLIRContext *context)
177       : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {}
178 
179   LogicalResult matchAndRewrite(Operation *op,
180                                 PatternRewriter &rewriter) const override {
181     if (op->hasAttr("eligible"))
182       return failure();
183     rewriter.modifyOpInPlace(
184         op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); });
185     return success();
186   }
187 };
188 
189 /// This pattern hoists eligible ops out of a "test.one_region_op".
190 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
191   using OpRewritePattern<test::OneRegionOp>::OpRewritePattern;
192 
193   LogicalResult matchAndRewrite(test::OneRegionOp op,
194                                 PatternRewriter &rewriter) const override {
195     Operation *terminator = op.getRegion().front().getTerminator();
196     Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp();
197     if (toBeHoisted->getParentOp() != op)
198       return failure();
199     if (!toBeHoisted->hasAttr("eligible"))
200       return failure();
201     rewriter.moveOpBefore(toBeHoisted, op);
202     return success();
203   }
204 };
205 
206 /// This pattern moves "test.move_before_parent_op" before the parent op.
207 struct MoveBeforeParentOp : public RewritePattern {
208   MoveBeforeParentOp(MLIRContext *context)
209       : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
210 
211   LogicalResult matchAndRewrite(Operation *op,
212                                 PatternRewriter &rewriter) const override {
213     // Do not hoist past functions.
214     if (isa<FunctionOpInterface>(op->getParentOp()))
215       return failure();
216     rewriter.moveOpBefore(op, op->getParentOp());
217     return success();
218   }
219 };
220 
221 /// This pattern moves "test.move_after_parent_op" after the parent op.
222 struct MoveAfterParentOp : public RewritePattern {
223   MoveAfterParentOp(MLIRContext *context)
224       : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {}
225 
226   LogicalResult matchAndRewrite(Operation *op,
227                                 PatternRewriter &rewriter) const override {
228     // Do not hoist past functions.
229     if (isa<FunctionOpInterface>(op->getParentOp()))
230       return failure();
231 
232     int64_t moveForwardBy = 0;
233     if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance"))
234       moveForwardBy = advanceBy.getInt();
235 
236     Operation *moveAfter = op->getParentOp();
237     for (int64_t i = 0; i < moveForwardBy; ++i)
238       moveAfter = moveAfter->getNextNode();
239 
240     rewriter.moveOpAfter(op, moveAfter);
241     return success();
242   }
243 };
244 
245 /// This pattern inlines blocks that are nested in
246 /// "test.inline_blocks_into_parent" into the parent block.
247 struct InlineBlocksIntoParent : public RewritePattern {
248   InlineBlocksIntoParent(MLIRContext *context)
249       : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
250                        context) {}
251 
252   LogicalResult matchAndRewrite(Operation *op,
253                                 PatternRewriter &rewriter) const override {
254     bool changed = false;
255     for (Region &r : op->getRegions()) {
256       while (!r.empty()) {
257         rewriter.inlineBlockBefore(&r.front(), op);
258         changed = true;
259       }
260     }
261     return success(changed);
262   }
263 };
264 
265 /// This pattern splits blocks at "test.split_block_here" and replaces the op
266 /// with a new op (to prevent an infinite loop of block splitting).
267 struct SplitBlockHere : public RewritePattern {
268   SplitBlockHere(MLIRContext *context)
269       : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
270 
271   LogicalResult matchAndRewrite(Operation *op,
272                                 PatternRewriter &rewriter) const override {
273     rewriter.splitBlock(op->getBlock(), op->getIterator());
274     Operation *newOp = rewriter.create(
275         op->getLoc(),
276         OperationName("test.new_op", op->getContext()).getIdentifier(),
277         op->getOperands(), op->getResultTypes());
278     rewriter.replaceOp(op, newOp);
279     return success();
280   }
281 };
282 
283 /// This pattern clones "test.clone_me" ops.
284 struct CloneOp : public RewritePattern {
285   CloneOp(MLIRContext *context)
286       : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
287 
288   LogicalResult matchAndRewrite(Operation *op,
289                                 PatternRewriter &rewriter) const override {
290     // Do not clone already cloned ops to avoid going into an infinite loop.
291     if (op->hasAttr("was_cloned"))
292       return failure();
293     Operation *cloned = rewriter.clone(*op);
294     cloned->setAttr("was_cloned", rewriter.getUnitAttr());
295     return success();
296   }
297 };
298 
299 /// This pattern clones regions of "test.clone_region_before" ops before the
300 /// parent block.
301 struct CloneRegionBeforeOp : public RewritePattern {
302   CloneRegionBeforeOp(MLIRContext *context)
303       : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
304 
305   LogicalResult matchAndRewrite(Operation *op,
306                                 PatternRewriter &rewriter) const override {
307     // Do not clone already cloned ops to avoid going into an infinite loop.
308     if (op->hasAttr("was_cloned"))
309       return failure();
310     for (Region &r : op->getRegions())
311       rewriter.cloneRegionBefore(r, op->getBlock());
312     op->setAttr("was_cloned", rewriter.getUnitAttr());
313     return success();
314   }
315 };
316 
317 /// Replace an operation may introduce the re-visiting of its users.
318 class ReplaceWithNewOp : public RewritePattern {
319 public:
320   ReplaceWithNewOp(MLIRContext *context)
321       : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
322 
323   LogicalResult matchAndRewrite(Operation *op,
324                                 PatternRewriter &rewriter) const override {
325     Operation *newOp;
326     if (op->hasAttr("create_erase_op")) {
327       newOp = rewriter.create(
328           op->getLoc(),
329           OperationName("test.erase_op", op->getContext()).getIdentifier(),
330           ValueRange(), TypeRange());
331     } else {
332       newOp = rewriter.create(
333           op->getLoc(),
334           OperationName("test.new_op", op->getContext()).getIdentifier(),
335           op->getOperands(), op->getResultTypes());
336     }
337     // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
338     // A "notifyOperationReplaced" callback is triggered in either case.
339     rewriter.replaceAllOpUsesWith(op, newOp->getResults());
340     rewriter.eraseOp(op);
341     return success();
342   }
343 };
344 
345 /// Erases the first child block of the matched "test.erase_first_block"
346 /// operation.
347 class EraseFirstBlock : public RewritePattern {
348 public:
349   EraseFirstBlock(MLIRContext *context)
350       : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {}
351 
352   LogicalResult matchAndRewrite(Operation *op,
353                                 PatternRewriter &rewriter) const override {
354     for (Region &r : op->getRegions()) {
355       for (Block &b : r.getBlocks()) {
356         rewriter.eraseBlock(&b);
357         return success();
358       }
359     }
360 
361     return failure();
362   }
363 };
364 
365 struct TestGreedyPatternDriver
366     : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
367   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
368 
369   TestGreedyPatternDriver() = default;
370   TestGreedyPatternDriver(const TestGreedyPatternDriver &other)
371       : PassWrapper(other) {}
372 
373   StringRef getArgument() const final { return "test-greedy-patterns"; }
374   StringRef getDescription() const final { return "Run test dialect patterns"; }
375   void runOnOperation() override {
376     mlir::RewritePatternSet patterns(&getContext());
377     populateWithGenerated(patterns);
378 
379     // Verify named pattern is generated with expected name.
380     patterns.add<FoldingPattern, TestNamedPatternRule,
381                  FolderInsertBeforePreviouslyFoldedConstantPattern,
382                  FolderCommutativeOp2WithConstant, HoistEligibleOps,
383                  MakeOpEligible>(&getContext());
384 
385     // Additional patterns for testing the GreedyPatternRewriteDriver.
386     patterns.insert<IncrementIntAttribute<3>>(&getContext());
387 
388     GreedyRewriteConfig config;
389     config.useTopDownTraversal = this->useTopDownTraversal;
390     config.maxIterations = this->maxIterations;
391     config.fold = this->fold;
392     config.cseConstants = this->cseConstants;
393     (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
394   }
395 
396   Option<bool> useTopDownTraversal{
397       *this, "top-down",
398       llvm::cl::desc("Seed the worklist in general top-down order"),
399       llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
400   Option<int> maxIterations{
401       *this, "max-iterations",
402       llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
403       llvm::cl::init(GreedyRewriteConfig().maxIterations)};
404   Option<bool> fold{*this, "fold", llvm::cl::desc("Whether to fold"),
405                     llvm::cl::init(GreedyRewriteConfig().fold)};
406   Option<bool> cseConstants{*this, "cse-constants",
407                             llvm::cl::desc("Whether to CSE constants"),
408                             llvm::cl::init(GreedyRewriteConfig().cseConstants)};
409 };
410 
411 struct DumpNotifications : public RewriterBase::Listener {
412   void notifyBlockInserted(Block *block, Region *previous,
413                            Region::iterator previousIt) override {
414     llvm::outs() << "notifyBlockInserted";
415     if (block->getParentOp()) {
416       llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
417     } else {
418       llvm::outs() << " into unknown op: ";
419     }
420     if (previous == nullptr) {
421       llvm::outs() << "was unlinked\n";
422     } else {
423       llvm::outs() << "was linked\n";
424     }
425   }
426   void notifyOperationInserted(Operation *op,
427                                OpBuilder::InsertPoint previous) override {
428     llvm::outs() << "notifyOperationInserted: " << op->getName();
429     if (!previous.isSet()) {
430       llvm::outs() << ", was unlinked\n";
431     } else {
432       if (!previous.getPoint().getNodePtr()) {
433         llvm::outs() << ", was linked, exact position unknown\n";
434       } else if (previous.getPoint() == previous.getBlock()->end()) {
435         llvm::outs() << ", was last in block\n";
436       } else {
437         llvm::outs() << ", previous = " << previous.getPoint()->getName()
438                      << "\n";
439       }
440     }
441   }
442   void notifyBlockErased(Block *block) override {
443     llvm::outs() << "notifyBlockErased\n";
444   }
445   void notifyOperationErased(Operation *op) override {
446     llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
447   }
448   void notifyOperationModified(Operation *op) override {
449     llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
450   }
451   void notifyOperationReplaced(Operation *op, ValueRange values) override {
452     llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
453   }
454 };
455 
456 struct TestStrictPatternDriver
457     : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
458 public:
459   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
460 
461   TestStrictPatternDriver() = default;
462   TestStrictPatternDriver(const TestStrictPatternDriver &other)
463       : PassWrapper(other) {
464     strictMode = other.strictMode;
465   }
466 
467   StringRef getArgument() const final { return "test-strict-pattern-driver"; }
468   StringRef getDescription() const final {
469     return "Test strict mode of pattern driver";
470   }
471 
472   void runOnOperation() override {
473     MLIRContext *ctx = &getContext();
474     mlir::RewritePatternSet patterns(ctx);
475     patterns.add<
476         // clang-format off
477         ChangeBlockOp,
478         CloneOp,
479         CloneRegionBeforeOp,
480         EraseOp,
481         ImplicitChangeOp,
482         InlineBlocksIntoParent,
483         InsertSameOp,
484         MoveBeforeParentOp,
485         ReplaceWithNewOp,
486         SplitBlockHere
487         // clang-format on
488         >(ctx);
489     SmallVector<Operation *> ops;
490     getOperation()->walk([&](Operation *op) {
491       StringRef opName = op->getName().getStringRef();
492       if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
493           opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
494           opName == "test.move_before_parent_op" ||
495           opName == "test.inline_blocks_into_parent" ||
496           opName == "test.split_block_here" || opName == "test.clone_me" ||
497           opName == "test.clone_region_before") {
498         ops.push_back(op);
499       }
500     });
501 
502     DumpNotifications dumpNotifications;
503     GreedyRewriteConfig config;
504     config.listener = &dumpNotifications;
505     if (strictMode == "AnyOp") {
506       config.strictMode = GreedyRewriteStrictness::AnyOp;
507     } else if (strictMode == "ExistingAndNewOps") {
508       config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
509     } else if (strictMode == "ExistingOps") {
510       config.strictMode = GreedyRewriteStrictness::ExistingOps;
511     } else {
512       llvm_unreachable("invalid strictness option");
513     }
514 
515     // Check if these transformations introduce visiting of operations that
516     // are not in the `ops` set (The new created ops are valid). An invalid
517     // operation will trigger the assertion while processing.
518     bool changed = false;
519     bool allErased = false;
520     (void)applyOpPatternsGreedily(ArrayRef(ops), std::move(patterns), config,
521                                   &changed, &allErased);
522     Builder b(ctx);
523     getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
524     getOperation()->setAttr("pattern_driver_all_erased",
525                             b.getBoolAttr(allErased));
526   }
527 
528   Option<std::string> strictMode{
529       *this, "strictness",
530       llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"),
531       llvm::cl::init("AnyOp")};
532 
533 private:
534   // New inserted operation is valid for further transformation.
535   class InsertSameOp : public RewritePattern {
536   public:
537     InsertSameOp(MLIRContext *context)
538         : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
539 
540     LogicalResult matchAndRewrite(Operation *op,
541                                   PatternRewriter &rewriter) const override {
542       if (op->hasAttr("skip"))
543         return failure();
544 
545       Operation *newOp =
546           rewriter.create(op->getLoc(), op->getName().getIdentifier(),
547                           op->getOperands(), op->getResultTypes());
548       rewriter.modifyOpInPlace(
549           op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); });
550       newOp->setAttr("skip", rewriter.getBoolAttr(true));
551 
552       return success();
553     }
554   };
555 
556   // Remove an operation may introduce the re-visiting of its operands.
557   class EraseOp : public RewritePattern {
558   public:
559     EraseOp(MLIRContext *context)
560         : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
561     LogicalResult matchAndRewrite(Operation *op,
562                                   PatternRewriter &rewriter) const override {
563       rewriter.eraseOp(op);
564       return success();
565     }
566   };
567 
568   // The following two patterns test RewriterBase::replaceAllUsesWith.
569   //
570   // That function replaces all usages of a Block (or a Value) with another one
571   // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
572   // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
573   // worklist: when an op is modified, it is added to the worklist. The two
574   // patterns below make the tracking observable: ChangeBlockOp replaces all
575   // usages of a block and that pattern is applied because the corresponding ops
576   // are put on the initial worklist (see above). ImplicitChangeOp does an
577   // unrelated change but ops of the corresponding type are *not* on the initial
578   // worklist, so the effect of the second pattern is only visible if the
579   // tracking and subsequent adding to the worklist actually works.
580 
581   // Replace all usages of the first successor with the second successor.
582   class ChangeBlockOp : public RewritePattern {
583   public:
584     ChangeBlockOp(MLIRContext *context)
585         : RewritePattern("test.change_block_op", /*benefit=*/1, context) {}
586     LogicalResult matchAndRewrite(Operation *op,
587                                   PatternRewriter &rewriter) const override {
588       if (op->getNumSuccessors() < 2)
589         return failure();
590       Block *firstSuccessor = op->getSuccessor(0);
591       Block *secondSuccessor = op->getSuccessor(1);
592       if (firstSuccessor == secondSuccessor)
593         return failure();
594       // This is the function being tested:
595       rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor);
596       // Using the following line instead would make the test fail:
597       // firstSuccessor->replaceAllUsesWith(secondSuccessor);
598       return success();
599     }
600   };
601 
602   // Changes the successor to the parent block.
603   class ImplicitChangeOp : public RewritePattern {
604   public:
605     ImplicitChangeOp(MLIRContext *context)
606         : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {}
607     LogicalResult matchAndRewrite(Operation *op,
608                                   PatternRewriter &rewriter) const override {
609       if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock())
610         return failure();
611       rewriter.modifyOpInPlace(op,
612                                [&]() { op->setSuccessor(op->getBlock(), 0); });
613       return success();
614     }
615   };
616 };
617 
618 struct TestWalkPatternDriver final
619     : PassWrapper<TestWalkPatternDriver, OperationPass<>> {
620   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver)
621 
622   TestWalkPatternDriver() = default;
623   TestWalkPatternDriver(const TestWalkPatternDriver &other)
624       : PassWrapper(other) {}
625 
626   StringRef getArgument() const override {
627     return "test-walk-pattern-rewrite-driver";
628   }
629   StringRef getDescription() const override {
630     return "Run test walk pattern rewrite driver";
631   }
632   void runOnOperation() override {
633     mlir::RewritePatternSet patterns(&getContext());
634 
635     // Patterns for testing the WalkPatternRewriteDriver.
636     patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
637                  MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>(
638         &getContext());
639 
640     DumpNotifications dumpListener;
641     walkAndApplyPatterns(getOperation(), std::move(patterns),
642                          dumpNotifications ? &dumpListener : nullptr);
643   }
644 
645   Option<bool> dumpNotifications{
646       *this, "dump-notifications",
647       llvm::cl::desc("Print rewrite listener notifications"),
648       llvm::cl::init(false)};
649 };
650 
651 } // namespace
652 
653 //===----------------------------------------------------------------------===//
654 // ReturnType Driver.
655 //===----------------------------------------------------------------------===//
656 
657 namespace {
658 // Generate ops for each instance where the type can be successfully inferred.
659 template <typename OpTy>
660 static void invokeCreateWithInferredReturnType(Operation *op) {
661   auto *context = op->getContext();
662   auto fop = op->getParentOfType<func::FuncOp>();
663   auto location = UnknownLoc::get(context);
664   OpBuilder b(op);
665   b.setInsertionPointAfter(op);
666 
667   // Use permutations of 2 args as operands.
668   assert(fop.getNumArguments() >= 2);
669   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
670     for (int j = 0; j < e; ++j) {
671       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
672       SmallVector<Type, 2> inferredReturnTypes;
673       if (succeeded(OpTy::inferReturnTypes(
674               context, std::nullopt, values, op->getDiscardableAttrDictionary(),
675               op->getPropertiesStorage(), op->getRegions(),
676               inferredReturnTypes))) {
677         OperationState state(location, OpTy::getOperationName());
678         // TODO: Expand to regions.
679         OpTy::build(b, state, values, op->getAttrs());
680         (void)b.create(state);
681       }
682     }
683   }
684 }
685 
686 static void reifyReturnShape(Operation *op) {
687   OpBuilder b(op);
688 
689   // Use permutations of 2 args as operands.
690   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
691   SmallVector<Value, 2> shapes;
692   if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
693       !llvm::hasSingleElement(shapes))
694     return;
695   for (const auto &it : llvm::enumerate(shapes)) {
696     op->emitRemark() << "value " << it.index() << ": "
697                      << it.value().getDefiningOp();
698   }
699 }
700 
701 struct TestReturnTypeDriver
702     : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> {
703   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver)
704 
705   void getDependentDialects(DialectRegistry &registry) const override {
706     registry.insert<tensor::TensorDialect>();
707   }
708   StringRef getArgument() const final { return "test-return-type"; }
709   StringRef getDescription() const final { return "Run return type functions"; }
710 
711   void runOnOperation() override {
712     if (getOperation().getName() == "testCreateFunctions") {
713       std::vector<Operation *> ops;
714       // Collect ops to avoid triggering on inserted ops.
715       for (auto &op : getOperation().getBody().front())
716         ops.push_back(&op);
717       // Generate test patterns for each, but skip terminator.
718       for (auto *op : llvm::ArrayRef(ops).drop_back()) {
719         // Test create method of each of the Op classes below. The resultant
720         // output would be in reverse order underneath `op` from which
721         // the attributes and regions are used.
722         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
723         invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>(
724             op);
725         invokeCreateWithInferredReturnType<
726             OpWithShapedTypeInferTypeInterfaceOp>(op);
727       };
728       return;
729     }
730     if (getOperation().getName() == "testReifyFunctions") {
731       std::vector<Operation *> ops;
732       // Collect ops to avoid triggering on inserted ops.
733       for (auto &op : getOperation().getBody().front())
734         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
735           ops.push_back(&op);
736       // Generate test patterns for each, but skip terminator.
737       for (auto *op : ops)
738         reifyReturnShape(op);
739     }
740   }
741 };
742 } // namespace
743 
744 namespace {
745 struct TestDerivedAttributeDriver
746     : public PassWrapper<TestDerivedAttributeDriver,
747                          OperationPass<func::FuncOp>> {
748   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver)
749 
750   StringRef getArgument() const final { return "test-derived-attr"; }
751   StringRef getDescription() const final {
752     return "Run test derived attributes";
753   }
754   void runOnOperation() override;
755 };
756 } // namespace
757 
758 void TestDerivedAttributeDriver::runOnOperation() {
759   getOperation().walk([](DerivedAttributeOpInterface dOp) {
760     auto dAttr = dOp.materializeDerivedAttributes();
761     if (!dAttr)
762       return;
763     for (auto d : dAttr)
764       dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
765   });
766 }
767 
768 //===----------------------------------------------------------------------===//
769 // Legalization Driver.
770 //===----------------------------------------------------------------------===//
771 
772 namespace {
773 //===----------------------------------------------------------------------===//
774 // Region-Block Rewrite Testing
775 
776 /// This pattern applies a signature conversion to a block inside a detached
777 /// region.
778 struct TestDetachedSignatureConversion : public ConversionPattern {
779   TestDetachedSignatureConversion(MLIRContext *ctx)
780       : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1,
781                           ctx) {}
782 
783   LogicalResult
784   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
785                   ConversionPatternRewriter &rewriter) const final {
786     if (op->getNumRegions() != 1)
787       return failure();
788     OperationState state(op->getLoc(), "test.legal_op", operands,
789                          op->getResultTypes(), {}, BlockRange());
790     Region *newRegion = state.addRegion();
791     rewriter.inlineRegionBefore(op->getRegion(0), *newRegion,
792                                 newRegion->begin());
793     TypeConverter::SignatureConversion result(newRegion->getNumArguments());
794     for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i)
795       result.addInputs(i, rewriter.getF64Type());
796     rewriter.applySignatureConversion(&newRegion->front(), result);
797     Operation *newOp = rewriter.create(state);
798     rewriter.replaceOp(op, newOp->getResults());
799     return success();
800   }
801 };
802 
803 /// This pattern is a simple pattern that inlines the first region of a given
804 /// operation into the parent region.
805 struct TestRegionRewriteBlockMovement : public ConversionPattern {
806   TestRegionRewriteBlockMovement(MLIRContext *ctx)
807       : ConversionPattern("test.region", 1, ctx) {}
808 
809   LogicalResult
810   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
811                   ConversionPatternRewriter &rewriter) const final {
812     // Inline this region into the parent region.
813     auto &parentRegion = *op->getParentRegion();
814     auto &opRegion = op->getRegion(0);
815     if (op->getDiscardableAttr("legalizer.should_clone"))
816       rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
817     else
818       rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
819 
820     if (op->getDiscardableAttr("legalizer.erase_old_blocks")) {
821       while (!opRegion.empty())
822         rewriter.eraseBlock(&opRegion.front());
823     }
824 
825     // Drop this operation.
826     rewriter.eraseOp(op);
827     return success();
828   }
829 };
830 /// This pattern is a simple pattern that generates a region containing an
831 /// illegal operation.
832 struct TestRegionRewriteUndo : public RewritePattern {
833   TestRegionRewriteUndo(MLIRContext *ctx)
834       : RewritePattern("test.region_builder", 1, ctx) {}
835 
836   LogicalResult matchAndRewrite(Operation *op,
837                                 PatternRewriter &rewriter) const final {
838     // Create the region operation with an entry block containing arguments.
839     OperationState newRegion(op->getLoc(), "test.region");
840     newRegion.addRegion();
841     auto *regionOp = rewriter.create(newRegion);
842     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
843     entryBlock->addArgument(rewriter.getIntegerType(64),
844                             rewriter.getUnknownLoc());
845 
846     // Add an explicitly illegal operation to ensure the conversion fails.
847     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
848     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
849 
850     // Drop this operation.
851     rewriter.eraseOp(op);
852     return success();
853   }
854 };
855 /// A simple pattern that creates a block at the end of the parent region of the
856 /// matched operation.
857 struct TestCreateBlock : public RewritePattern {
858   TestCreateBlock(MLIRContext *ctx)
859       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
860 
861   LogicalResult matchAndRewrite(Operation *op,
862                                 PatternRewriter &rewriter) const final {
863     Region &region = *op->getParentRegion();
864     Type i32Type = rewriter.getIntegerType(32);
865     Location loc = op->getLoc();
866     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
867     rewriter.create<TerminatorOp>(loc);
868     rewriter.eraseOp(op);
869     return success();
870   }
871 };
872 
873 /// A simple pattern that creates a block containing an invalid operation in
874 /// order to trigger the block creation undo mechanism.
875 struct TestCreateIllegalBlock : public RewritePattern {
876   TestCreateIllegalBlock(MLIRContext *ctx)
877       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
878 
879   LogicalResult matchAndRewrite(Operation *op,
880                                 PatternRewriter &rewriter) const final {
881     Region &region = *op->getParentRegion();
882     Type i32Type = rewriter.getIntegerType(32);
883     Location loc = op->getLoc();
884     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
885     // Create an illegal op to ensure the conversion fails.
886     rewriter.create<ILLegalOpF>(loc, i32Type);
887     rewriter.create<TerminatorOp>(loc);
888     rewriter.eraseOp(op);
889     return success();
890   }
891 };
892 
893 /// A simple pattern that tests the undo mechanism when replacing the uses of a
894 /// block argument.
895 struct TestUndoBlockArgReplace : public ConversionPattern {
896   TestUndoBlockArgReplace(MLIRContext *ctx)
897       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
898 
899   LogicalResult
900   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
901                   ConversionPatternRewriter &rewriter) const final {
902     auto illegalOp =
903         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
904     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
905                                         illegalOp->getResult(0));
906     rewriter.modifyOpInPlace(op, [] {});
907     return success();
908   }
909 };
910 
911 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
912 /// This is to test the rollback logic.
913 struct TestUndoMoveOpBefore : public ConversionPattern {
914   TestUndoMoveOpBefore(MLIRContext *ctx)
915       : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
916 
917   LogicalResult
918   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
919                   ConversionPatternRewriter &rewriter) const override {
920     rewriter.moveOpBefore(op, op->getParentOp());
921     // Replace with an illegal op to ensure the conversion fails.
922     rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
923     return success();
924   }
925 };
926 
927 /// A rewrite pattern that tests the undo mechanism when erasing a block.
928 struct TestUndoBlockErase : public ConversionPattern {
929   TestUndoBlockErase(MLIRContext *ctx)
930       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
931 
932   LogicalResult
933   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
934                   ConversionPatternRewriter &rewriter) const final {
935     Block *secondBlock = &*std::next(op->getRegion(0).begin());
936     rewriter.setInsertionPointToStart(secondBlock);
937     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
938     rewriter.eraseBlock(secondBlock);
939     rewriter.modifyOpInPlace(op, [] {});
940     return success();
941   }
942 };
943 
944 /// A pattern that modifies a property in-place, but keeps the op illegal.
945 struct TestUndoPropertiesModification : public ConversionPattern {
946   TestUndoPropertiesModification(MLIRContext *ctx)
947       : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {}
948   LogicalResult
949   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
950                   ConversionPatternRewriter &rewriter) const final {
951     if (!op->hasAttr("modify_inplace"))
952       return failure();
953     rewriter.modifyOpInPlace(
954         op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); });
955     return success();
956   }
957 };
958 
959 //===----------------------------------------------------------------------===//
960 // Type-Conversion Rewrite Testing
961 
962 /// This patterns erases a region operation that has had a type conversion.
963 struct TestDropOpSignatureConversion : public ConversionPattern {
964   TestDropOpSignatureConversion(MLIRContext *ctx,
965                                 const TypeConverter &converter)
966       : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
967   LogicalResult
968   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
969                   ConversionPatternRewriter &rewriter) const override {
970     Region &region = op->getRegion(0);
971     Block *entry = &region.front();
972 
973     // Convert the original entry arguments.
974     const TypeConverter &converter = *getTypeConverter();
975     TypeConverter::SignatureConversion result(entry->getNumArguments());
976     if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
977                                               result)) ||
978         failed(rewriter.convertRegionTypes(&region, converter, &result)))
979       return failure();
980 
981     // Convert the region signature and just drop the operation.
982     rewriter.eraseOp(op);
983     return success();
984   }
985 };
986 /// This pattern simply updates the operands of the given operation.
987 struct TestPassthroughInvalidOp : public ConversionPattern {
988   TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
989       : ConversionPattern(converter, "test.invalid", 1, ctx) {}
990   LogicalResult
991   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
992                   ConversionPatternRewriter &rewriter) const final {
993     SmallVector<Value> flattened;
994     for (auto it : llvm::enumerate(operands)) {
995       ValueRange range = it.value();
996       if (range.size() == 1) {
997         flattened.push_back(range.front());
998         continue;
999       }
1000 
1001       // This is a 1:N replacement. Insert a test.cast op. (That's what the
1002       // argument materialization used to do.)
1003       flattened.push_back(
1004           rewriter
1005               .create<TestCastOp>(op->getLoc(),
1006                                   op->getOperand(it.index()).getType(), range)
1007               .getResult());
1008     }
1009     rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened,
1010                                              std::nullopt);
1011     return success();
1012   }
1013 };
1014 /// Replace with valid op, but simply drop the operands. This is used in a
1015 /// regression where we used to generate circular unrealized_conversion_cast
1016 /// ops.
1017 struct TestDropAndReplaceInvalidOp : public ConversionPattern {
1018   TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
1019       : ConversionPattern(converter,
1020                           "test.drop_operands_and_replace_with_valid", 1, ctx) {
1021   }
1022   LogicalResult
1023   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1024                   ConversionPatternRewriter &rewriter) const final {
1025     rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(),
1026                                              std::nullopt);
1027     return success();
1028   }
1029 };
1030 /// This pattern handles the case of a split return value.
1031 struct TestSplitReturnType : public ConversionPattern {
1032   TestSplitReturnType(MLIRContext *ctx)
1033       : ConversionPattern("test.return", 1, ctx) {}
1034   LogicalResult
1035   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1036                   ConversionPatternRewriter &rewriter) const final {
1037     // Check for a return of F32.
1038     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
1039       return failure();
1040     rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]);
1041     return success();
1042   }
1043 };
1044 
1045 //===----------------------------------------------------------------------===//
1046 // Multi-Level Type-Conversion Rewrite Testing
1047 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
1048   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
1049       : ConversionPattern("test.type_producer", 1, ctx) {}
1050   LogicalResult
1051   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1052                   ConversionPatternRewriter &rewriter) const final {
1053     // If the type is I32, change the type to F32.
1054     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
1055       return failure();
1056     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
1057     return success();
1058   }
1059 };
1060 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
1061   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
1062       : ConversionPattern("test.type_producer", 1, ctx) {}
1063   LogicalResult
1064   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1065                   ConversionPatternRewriter &rewriter) const final {
1066     // If the type is F32, change the type to F64.
1067     if (!Type(*op->result_type_begin()).isF32())
1068       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
1069     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
1070     return success();
1071   }
1072 };
1073 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
1074   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
1075       : ConversionPattern("test.type_producer", 10, ctx) {}
1076   LogicalResult
1077   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1078                   ConversionPatternRewriter &rewriter) const final {
1079     // Always convert to B16, even though it is not a legal type. This tests
1080     // that values are unmapped correctly.
1081     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
1082     return success();
1083   }
1084 };
1085 struct TestUpdateConsumerType : public ConversionPattern {
1086   TestUpdateConsumerType(MLIRContext *ctx)
1087       : ConversionPattern("test.type_consumer", 1, ctx) {}
1088   LogicalResult
1089   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1090                   ConversionPatternRewriter &rewriter) const final {
1091     // Verify that the incoming operand has been successfully remapped to F64.
1092     if (!operands[0].getType().isF64())
1093       return failure();
1094     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
1095     return success();
1096   }
1097 };
1098 
1099 //===----------------------------------------------------------------------===//
1100 // Non-Root Replacement Rewrite Testing
1101 /// This pattern generates an invalid operation, but replaces it before the
1102 /// pattern is finished. This checks that we don't need to legalize the
1103 /// temporary op.
1104 struct TestNonRootReplacement : public RewritePattern {
1105   TestNonRootReplacement(MLIRContext *ctx)
1106       : RewritePattern("test.replace_non_root", 1, ctx) {}
1107 
1108   LogicalResult matchAndRewrite(Operation *op,
1109                                 PatternRewriter &rewriter) const final {
1110     auto resultType = *op->result_type_begin();
1111     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
1112     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
1113 
1114     rewriter.replaceOp(illegalOp, legalOp);
1115     rewriter.replaceOp(op, illegalOp);
1116     return success();
1117   }
1118 };
1119 
1120 //===----------------------------------------------------------------------===//
1121 // Recursive Rewrite Testing
1122 /// This pattern is applied to the same operation multiple times, but has a
1123 /// bounded recursion.
1124 struct TestBoundedRecursiveRewrite
1125     : public OpRewritePattern<TestRecursiveRewriteOp> {
1126   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
1127 
1128   void initialize() {
1129     // The conversion target handles bounding the recursion of this pattern.
1130     setHasBoundedRewriteRecursion();
1131   }
1132 
1133   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
1134                                 PatternRewriter &rewriter) const final {
1135     // Decrement the depth of the op in-place.
1136     rewriter.modifyOpInPlace(op, [&] {
1137       op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
1138     });
1139     return success();
1140   }
1141 };
1142 
1143 struct TestNestedOpCreationUndoRewrite
1144     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
1145   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
1146 
1147   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
1148                                 PatternRewriter &rewriter) const final {
1149     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1150     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1151     return success();
1152   };
1153 };
1154 
1155 // This pattern matches `test.blackhole` and delete this op and its producer.
1156 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
1157   using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
1158 
1159   LogicalResult matchAndRewrite(BlackHoleOp op,
1160                                 PatternRewriter &rewriter) const final {
1161     Operation *producer = op.getOperand().getDefiningOp();
1162     // Always erase the user before the producer, the framework should handle
1163     // this correctly.
1164     rewriter.eraseOp(op);
1165     rewriter.eraseOp(producer);
1166     return success();
1167   };
1168 };
1169 
1170 // This pattern replaces explicitly illegal op with explicitly legal op,
1171 // but in addition creates unregistered operation.
1172 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
1173   using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
1174 
1175   LogicalResult matchAndRewrite(ILLegalOpG op,
1176                                 PatternRewriter &rewriter) const final {
1177     IntegerAttr attr = rewriter.getI32IntegerAttr(0);
1178     Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
1179     rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
1180     return success();
1181   };
1182 };
1183 
1184 class TestEraseOp : public ConversionPattern {
1185 public:
1186   TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {}
1187   LogicalResult
1188   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1189                   ConversionPatternRewriter &rewriter) const final {
1190     // Erase op without replacements.
1191     rewriter.eraseOp(op);
1192     return success();
1193   }
1194 };
1195 
1196 /// This pattern matches a test.duplicate_block_args op and duplicates all
1197 /// block arguments.
1198 class TestDuplicateBlockArgs
1199     : public OpConversionPattern<DuplicateBlockArgsOp> {
1200   using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern;
1201 
1202   LogicalResult
1203   matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor,
1204                   ConversionPatternRewriter &rewriter) const override {
1205     if (op.getIsLegal())
1206       return failure();
1207     rewriter.startOpModification(op);
1208     Block *body = &op.getBody().front();
1209     TypeConverter::SignatureConversion result(body->getNumArguments());
1210     for (auto it : llvm::enumerate(body->getArgumentTypes()))
1211       result.addInputs(it.index(), {it.value(), it.value()});
1212     rewriter.applySignatureConversion(body, result, getTypeConverter());
1213     op.setIsLegal(true);
1214     rewriter.finalizeOpModification(op);
1215     return success();
1216   }
1217 };
1218 
1219 /// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid
1220 /// op. The pattern supports 1:N replacements and forwards the replacement
1221 /// values of the single operand as test.valid operands.
1222 class TestRepetitive1ToNConsumer : public ConversionPattern {
1223 public:
1224   TestRepetitive1ToNConsumer(MLIRContext *ctx)
1225       : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {}
1226   LogicalResult
1227   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1228                   ConversionPatternRewriter &rewriter) const final {
1229     // A single operand is expected.
1230     if (op->getNumOperands() != 1)
1231       return failure();
1232     rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front());
1233     return success();
1234   }
1235 };
1236 
1237 /// A pattern that tests two back-to-back 1 -> 2 op replacements.
1238 class TestMultiple1ToNReplacement : public ConversionPattern {
1239 public:
1240   TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter)
1241       : ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1,
1242                           ctx) {}
1243   LogicalResult
1244   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1245                   ConversionPatternRewriter &rewriter) const final {
1246     // Helper function that replaces the given op with a new op of the given
1247     // name and doubles each result (1 -> 2 replacement of each result).
1248     auto replaceWithDoubleResults = [&](Operation *op, StringRef name) {
1249       SmallVector<Type> types;
1250       for (Type t : op->getResultTypes()) {
1251         types.push_back(t);
1252         types.push_back(t);
1253       }
1254       OperationState state(op->getLoc(), name,
1255                            /*operands=*/{}, types, op->getAttrs());
1256       auto *newOp = rewriter.create(state);
1257       SmallVector<ValueRange> repls;
1258       for (size_t i = 0, e = op->getNumResults(); i < e; ++i)
1259         repls.push_back(newOp->getResults().slice(2 * i, 2));
1260       rewriter.replaceOpWithMultiple(op, repls);
1261       return newOp;
1262     };
1263 
1264     // Replace test.multiple_1_to_n_replacement with test.step_1.
1265     Operation *repl1 = replaceWithDoubleResults(op, "test.step_1");
1266     // Now replace test.step_1 with test.legal_op.
1267     replaceWithDoubleResults(repl1, "test.legal_op");
1268     return success();
1269   }
1270 };
1271 
1272 } // namespace
1273 
1274 namespace {
1275 struct TestTypeConverter : public TypeConverter {
1276   using TypeConverter::TypeConverter;
1277   TestTypeConverter() {
1278     addConversion(convertType);
1279     addSourceMaterialization(materializeCast);
1280   }
1281 
1282   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
1283     // Drop I16 types.
1284     if (t.isSignlessInteger(16))
1285       return success();
1286 
1287     // Convert I64 to F64.
1288     if (t.isSignlessInteger(64)) {
1289       results.push_back(Float64Type::get(t.getContext()));
1290       return success();
1291     }
1292 
1293     // Convert I42 to I43.
1294     if (t.isInteger(42)) {
1295       results.push_back(IntegerType::get(t.getContext(), 43));
1296       return success();
1297     }
1298 
1299     // Split F32 into F16,F16.
1300     if (t.isF32()) {
1301       results.assign(2, Float16Type::get(t.getContext()));
1302       return success();
1303     }
1304 
1305     // Drop I24 types.
1306     if (t.isInteger(24)) {
1307       return success();
1308     }
1309 
1310     // Otherwise, convert the type directly.
1311     results.push_back(t);
1312     return success();
1313   }
1314 
1315   /// Hook for materializing a conversion. This is necessary because we generate
1316   /// 1->N type mappings.
1317   static Value materializeCast(OpBuilder &builder, Type resultType,
1318                                ValueRange inputs, Location loc) {
1319     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1320   }
1321 };
1322 
1323 struct TestLegalizePatternDriver
1324     : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> {
1325   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)
1326 
1327   StringRef getArgument() const final { return "test-legalize-patterns"; }
1328   StringRef getDescription() const final {
1329     return "Run test dialect legalization patterns";
1330   }
1331   /// The mode of conversion to use with the driver.
1332   enum class ConversionMode { Analysis, Full, Partial };
1333 
1334   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
1335 
1336   void getDependentDialects(DialectRegistry &registry) const override {
1337     registry.insert<func::FuncDialect, test::TestDialect>();
1338   }
1339 
1340   void runOnOperation() override {
1341     TestTypeConverter converter;
1342     mlir::RewritePatternSet patterns(&getContext());
1343     populateWithGenerated(patterns);
1344     patterns
1345         .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1346              TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1347              TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
1348              TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
1349              TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
1350              TestNonRootReplacement, TestBoundedRecursiveRewrite,
1351              TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
1352              TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1353              TestUndoPropertiesModification, TestEraseOp,
1354              TestRepetitive1ToNConsumer>(&getContext());
1355     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
1356                  TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
1357         &getContext(), converter);
1358     patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
1359     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1360                                                               converter);
1361     mlir::populateCallOpTypeConversionPattern(patterns, converter);
1362 
1363     // Define the conversion target used for the test.
1364     ConversionTarget target(getContext());
1365     target.addLegalOp<ModuleOp>();
1366     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
1367                       TerminatorOp, OneRegionOp>();
1368     target.addLegalOp(OperationName("test.legal_op", &getContext()));
1369     target
1370         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
1371     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
1372       // Don't allow F32 operands.
1373       return llvm::none_of(op.getOperandTypes(),
1374                            [](Type type) { return type.isF32(); });
1375     });
1376     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1377       return converter.isSignatureLegal(op.getFunctionType()) &&
1378              converter.isLegal(&op.getBody());
1379     });
1380     target.addDynamicallyLegalOp<func::CallOp>(
1381         [&](func::CallOp op) { return converter.isLegal(op); });
1382 
1383     // TestCreateUnregisteredOp creates `arith.constant` operation,
1384     // which was not added to target intentionally to test
1385     // correct error code from conversion driver.
1386     target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });
1387 
1388     // Expect the type_producer/type_consumer operations to only operate on f64.
1389     target.addDynamicallyLegalOp<TestTypeProducerOp>(
1390         [](TestTypeProducerOp op) { return op.getType().isF64(); });
1391     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1392       return op.getOperand().getType().isF64();
1393     });
1394 
1395     // Check support for marking certain operations as recursively legal.
1396     target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) {
1397       return static_cast<bool>(
1398           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
1399     });
1400 
1401     // Mark the bound recursion operation as dynamically legal.
1402     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
1403         [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
1404 
1405     // Create a dynamically legal rule that can only be legalized by folding it.
1406     target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
1407         [](TestOpInPlaceSelfFold op) { return op.getFolded(); });
1408 
1409     target.addDynamicallyLegalOp<DuplicateBlockArgsOp>(
1410         [](DuplicateBlockArgsOp op) { return op.getIsLegal(); });
1411 
1412     // Handle a partial conversion.
1413     if (mode == ConversionMode::Partial) {
1414       DenseSet<Operation *> unlegalizedOps;
1415       ConversionConfig config;
1416       DumpNotifications dumpNotifications;
1417       config.listener = &dumpNotifications;
1418       config.unlegalizedOps = &unlegalizedOps;
1419       if (failed(applyPartialConversion(getOperation(), target,
1420                                         std::move(patterns), config))) {
1421         getOperation()->emitRemark() << "applyPartialConversion failed";
1422       }
1423       // Emit remarks for each legalizable operation.
1424       for (auto *op : unlegalizedOps)
1425         op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1426       return;
1427     }
1428 
1429     // Handle a full conversion.
1430     if (mode == ConversionMode::Full) {
1431       // Check support for marking unknown operations as dynamically legal.
1432       target.markUnknownOpDynamicallyLegal([](Operation *op) {
1433         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
1434       });
1435 
1436       ConversionConfig config;
1437       DumpNotifications dumpNotifications;
1438       config.listener = &dumpNotifications;
1439       if (failed(applyFullConversion(getOperation(), target,
1440                                      std::move(patterns), config))) {
1441         getOperation()->emitRemark() << "applyFullConversion failed";
1442       }
1443       return;
1444     }
1445 
1446     // Otherwise, handle an analysis conversion.
1447     assert(mode == ConversionMode::Analysis);
1448 
1449     // Analyze the convertible operations.
1450     DenseSet<Operation *> legalizedOps;
1451     ConversionConfig config;
1452     config.legalizableOps = &legalizedOps;
1453     if (failed(applyAnalysisConversion(getOperation(), target,
1454                                        std::move(patterns), config)))
1455       return signalPassFailure();
1456 
1457     // Emit remarks for each legalizable operation.
1458     for (auto *op : legalizedOps)
1459       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
1460   }
1461 
1462   /// The mode of conversion to use.
1463   ConversionMode mode;
1464 };
1465 } // namespace
1466 
/// Command-line option (`-test-legalize-mode=analysis|full|partial`) selecting
/// the conversion mode in which TestLegalizePatternDriver runs. Defaults to a
/// partial conversion. Read when the pass is constructed at registration time.
static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
    legalizerConversionMode(
        "test-legalize-mode",
        llvm::cl::desc("The legalization mode to use with the test driver"),
        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
        llvm::cl::values(
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
                       "analysis", "Perform an analysis conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
                       "Perform a full conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
                       "partial", "Perform a partial conversion")));
1479 
1480 //===----------------------------------------------------------------------===//
1481 // ConversionPatternRewriter::getRemappedValue testing. This method is used
1482 // to get the remapped value of an original value that was replaced using
1483 // ConversionPatternRewriter.
1484 namespace {
1485 struct TestRemapValueTypeConverter : public TypeConverter {
1486   using TypeConverter::TypeConverter;
1487 
1488   TestRemapValueTypeConverter() {
1489     addConversion(
1490         [](Float32Type type) { return Float64Type::get(type.getContext()); });
1491     addConversion([](Type type) { return type; });
1492   }
1493 };
1494 
1495 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
1496 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
1497 /// operand twice.
1498 ///
1499 /// Example:
1500 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
1501 /// is replaced with:
1502 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
1503 struct OneVResOneVOperandOp1Converter
1504     : public OpConversionPattern<OneVResOneVOperandOp1> {
1505   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
1506 
1507   LogicalResult
1508   matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
1509                   ConversionPatternRewriter &rewriter) const override {
1510     auto origOps = op.getOperands();
1511     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
1512            "One operand expected");
1513     Value origOp = *origOps.begin();
1514     SmallVector<Value, 2> remappedOperands;
1515     // Replicate the remapped original operand twice. Note that we don't used
1516     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
1517     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1518     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1519 
1520     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
1521                                                        remappedOperands);
1522     return success();
1523   }
1524 };
1525 
/// A conversion pattern that tests ConversionPatternRewriter::getRemappedValues
/// on values defined inside a region. The first block of the op's body is
/// inlined into the parent block at the op's position. The op's results are
/// then replaced with the remapped operands of that block's terminator, and
/// the terminator is erased.
struct TestRemapValueInRegion
    : public OpConversionPattern<TestRemappedValueRegionOp> {
  using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Block &block = op.getBody().front();
    Operation *terminator = block.getTerminator();

    // Merge the block into the parent region: split the parent block right
    // before `op`, append the body block to the first half, then append the
    // second half (which starts with `op`) back onto it.
    Block *parentBlock = op->getBlock();
    Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
    rewriter.mergeBlocks(&block, parentBlock, ValueRange());
    rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());

    // Replace the results of this operation with the remapped terminator
    // values.
    // NOTE(review): this failure path is reached after the blocks have already
    // been merged; it presumably relies on the conversion driver undoing the
    // pattern's changes on failure.
    SmallVector<Value> terminatorOperands;
    if (failed(rewriter.getRemappedValues(terminator->getOperands(),
                                          terminatorOperands)))
      return failure();

    rewriter.eraseOp(terminator);
    rewriter.replaceOp(op, terminatorOperands);
    return success();
  }
};
1555 
1556 struct TestRemappedValue
1557     : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> {
1558   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue)
1559 
1560   StringRef getArgument() const final { return "test-remapped-value"; }
1561   StringRef getDescription() const final {
1562     return "Test public remapped value mechanism in ConversionPatternRewriter";
1563   }
1564   void runOnOperation() override {
1565     TestRemapValueTypeConverter typeConverter;
1566 
1567     mlir::RewritePatternSet patterns(&getContext());
1568     patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
1569     patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
1570         &getContext());
1571     patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
1572 
1573     mlir::ConversionTarget target(getContext());
1574     target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>();
1575 
1576     // Expect the type_producer/type_consumer operations to only operate on f64.
1577     target.addDynamicallyLegalOp<TestTypeProducerOp>(
1578         [](TestTypeProducerOp op) { return op.getType().isF64(); });
1579     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1580       return op.getOperand().getType().isF64();
1581     });
1582 
1583     // We make OneVResOneVOperandOp1 legal only when it has more that one
1584     // operand. This will trigger the conversion that will replace one-operand
1585     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
1586     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
1587         [](Operation *op) { return op->getNumOperands() > 1; });
1588 
1589     if (failed(mlir::applyFullConversion(getOperation(), target,
1590                                          std::move(patterns)))) {
1591       signalPassFailure();
1592     }
1593   }
1594 };
1595 } // namespace
1596 
1597 //===----------------------------------------------------------------------===//
1598 // Test patterns without a specific root operation kind
1599 //===----------------------------------------------------------------------===//
1600 
1601 namespace {
1602 /// This pattern matches and removes any operation in the test dialect.
1603 struct RemoveTestDialectOps : public RewritePattern {
1604   RemoveTestDialectOps(MLIRContext *context)
1605       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
1606 
1607   LogicalResult matchAndRewrite(Operation *op,
1608                                 PatternRewriter &rewriter) const override {
1609     if (!isa<TestDialect>(op->getDialect()))
1610       return failure();
1611     rewriter.eraseOp(op);
1612     return success();
1613   }
1614 };
1615 
1616 struct TestUnknownRootOpDriver
1617     : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> {
1618   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver)
1619 
1620   StringRef getArgument() const final {
1621     return "test-legalize-unknown-root-patterns";
1622   }
1623   StringRef getDescription() const final {
1624     return "Test public remapped value mechanism in ConversionPatternRewriter";
1625   }
1626   void runOnOperation() override {
1627     mlir::RewritePatternSet patterns(&getContext());
1628     patterns.add<RemoveTestDialectOps>(&getContext());
1629 
1630     mlir::ConversionTarget target(getContext());
1631     target.addIllegalDialect<TestDialect>();
1632     if (failed(applyPartialConversion(getOperation(), target,
1633                                       std::move(patterns))))
1634       signalPassFailure();
1635   }
1636 };
1637 } // namespace
1638 
1639 //===----------------------------------------------------------------------===//
1640 // Test patterns that uses operations and types defined at runtime
1641 //===----------------------------------------------------------------------===//
1642 
1643 namespace {
1644 /// This pattern matches dynamic operations 'test.one_operand_two_results' and
1645 /// replace them with dynamic operations 'test.generic_dynamic_op'.
1646 struct RewriteDynamicOp : public RewritePattern {
1647   RewriteDynamicOp(MLIRContext *context)
1648       : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
1649                        context) {}
1650 
1651   LogicalResult matchAndRewrite(Operation *op,
1652                                 PatternRewriter &rewriter) const override {
1653     assert(op->getName().getStringRef() ==
1654                "test.dynamic_one_operand_two_results" &&
1655            "rewrite pattern should only match operations with the right name");
1656 
1657     OperationState state(op->getLoc(), "test.dynamic_generic",
1658                          op->getOperands(), op->getResultTypes(),
1659                          op->getAttrs());
1660     auto *newOp = rewriter.create(state);
1661     rewriter.replaceOp(op, newOp->getResults());
1662     return success();
1663   }
1664 };
1665 
1666 struct TestRewriteDynamicOpDriver
1667     : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> {
1668   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
1669 
1670   void getDependentDialects(DialectRegistry &registry) const override {
1671     registry.insert<TestDialect>();
1672   }
1673   StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
1674   StringRef getDescription() const final {
1675     return "Test rewritting on dynamic operations";
1676   }
1677   void runOnOperation() override {
1678     RewritePatternSet patterns(&getContext());
1679     patterns.add<RewriteDynamicOp>(&getContext());
1680 
1681     ConversionTarget target(getContext());
1682     target.addIllegalOp(
1683         OperationName("test.dynamic_one_operand_two_results", &getContext()));
1684     target.addLegalOp(OperationName("test.dynamic_generic", &getContext()));
1685     if (failed(applyPartialConversion(getOperation(), target,
1686                                       std::move(patterns))))
1687       signalPassFailure();
1688   }
1689 };
1690 } // end anonymous namespace
1691 
1692 //===----------------------------------------------------------------------===//
1693 // Test type conversions
1694 //===----------------------------------------------------------------------===//
1695 
1696 namespace {
1697 struct TestTypeConversionProducer
1698     : public OpConversionPattern<TestTypeProducerOp> {
1699   using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
1700   LogicalResult
1701   matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
1702                   ConversionPatternRewriter &rewriter) const final {
1703     Type resultType = op.getType();
1704     Type convertedType = getTypeConverter()
1705                              ? getTypeConverter()->convertType(resultType)
1706                              : resultType;
1707     if (isa<FloatType>(resultType))
1708       resultType = rewriter.getF64Type();
1709     else if (resultType.isInteger(16))
1710       resultType = rewriter.getIntegerType(64);
1711     else if (isa<test::TestRecursiveType>(resultType) &&
1712              convertedType != resultType)
1713       resultType = convertedType;
1714     else
1715       return failure();
1716 
1717     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
1718     return success();
1719   }
1720 };
1721 
1722 /// Call signature conversion and then fail the rewrite to trigger the undo
1723 /// mechanism.
1724 struct TestSignatureConversionUndo
1725     : public OpConversionPattern<TestSignatureConversionUndoOp> {
1726   using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
1727 
1728   LogicalResult
1729   matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
1730                   ConversionPatternRewriter &rewriter) const final {
1731     (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
1732     return failure();
1733   }
1734 };
1735 
1736 /// Call signature conversion without providing a type converter to handle
1737 /// materializations.
1738 struct TestTestSignatureConversionNoConverter
1739     : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
1740   TestTestSignatureConversionNoConverter(const TypeConverter &converter,
1741                                          MLIRContext *context)
1742       : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
1743         converter(converter) {}
1744 
1745   LogicalResult
1746   matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
1747                   ConversionPatternRewriter &rewriter) const final {
1748     Region &region = op->getRegion(0);
1749     Block *entry = &region.front();
1750 
1751     // Convert the original entry arguments.
1752     TypeConverter::SignatureConversion result(entry->getNumArguments());
1753     if (failed(
1754             converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
1755       return failure();
1756     rewriter.modifyOpInPlace(op, [&] {
1757       rewriter.applySignatureConversion(&region.front(), result);
1758     });
1759     return success();
1760   }
1761 
1762   const TypeConverter &converter;
1763 };
1764 
1765 /// Just forward the operands to the root op. This is essentially a no-op
1766 /// pattern that is used to trigger target materialization.
1767 struct TestTypeConsumerForward
1768     : public OpConversionPattern<TestTypeConsumerOp> {
1769   using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
1770 
1771   LogicalResult
1772   matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
1773                   ConversionPatternRewriter &rewriter) const final {
1774     rewriter.modifyOpInPlace(op,
1775                              [&] { op->setOperands(adaptor.getOperands()); });
1776     return success();
1777   }
1778 };
1779 
1780 struct TestTypeConversionAnotherProducer
1781     : public OpRewritePattern<TestAnotherTypeProducerOp> {
1782   using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
1783 
1784   LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
1785                                 PatternRewriter &rewriter) const final {
1786     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
1787     return success();
1788   }
1789 };
1790 
1791 struct TestReplaceWithLegalOp : public ConversionPattern {
1792   TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
1793       : ConversionPattern(converter, "test.replace_with_legal_op",
1794                           /*benefit=*/1, ctx) {}
1795   LogicalResult
1796   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1797                   ConversionPatternRewriter &rewriter) const final {
1798     rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
1799     return success();
1800   }
1801 };
1802 
1803 struct TestTypeConversionDriver
1804     : public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
1805   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
1806 
1807   void getDependentDialects(DialectRegistry &registry) const override {
1808     registry.insert<TestDialect>();
1809   }
1810   StringRef getArgument() const final {
1811     return "test-legalize-type-conversion";
1812   }
1813   StringRef getDescription() const final {
1814     return "Test various type conversion functionalities in DialectConversion";
1815   }
1816 
1817   void runOnOperation() override {
1818     // Initialize the type converter.
1819     SmallVector<Type, 2> conversionCallStack;
1820     TypeConverter converter;
1821 
1822     /// Add the legal set of type conversions.
1823     converter.addConversion([](Type type) -> Type {
1824       // Treat F64 as legal.
1825       if (type.isF64())
1826         return type;
1827       // Allow converting BF16/F16/F32 to F64.
1828       if (type.isBF16() || type.isF16() || type.isF32())
1829         return Float64Type::get(type.getContext());
1830       // Otherwise, the type is illegal.
1831       return nullptr;
1832     });
1833     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
1834       // Drop all integer types.
1835       return success();
1836     });
1837     converter.addConversion(
1838         // Convert a recursive self-referring type into a non-self-referring
1839         // type named "outer_converted_type" that contains a SimpleAType.
1840         [&](test::TestRecursiveType type,
1841             SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
1842           // If the type is already converted, return it to indicate that it is
1843           // legal.
1844           if (type.getName() == "outer_converted_type") {
1845             results.push_back(type);
1846             return success();
1847           }
1848 
1849           conversionCallStack.push_back(type);
1850           auto popConversionCallStack = llvm::make_scope_exit(
1851               [&conversionCallStack]() { conversionCallStack.pop_back(); });
1852 
1853           // If the type is on the call stack more than once (it is there at
1854           // least once because of the _current_ call, which is always the last
1855           // element on the stack), we've hit the recursive case. Just return
1856           // SimpleAType here to create a non-recursive type as a result.
1857           if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(),
1858                                  type)) {
1859             results.push_back(test::SimpleAType::get(type.getContext()));
1860             return success();
1861           }
1862 
1863           // Convert the body recursively.
1864           auto result = test::TestRecursiveType::get(type.getContext(),
1865                                                      "outer_converted_type");
1866           if (failed(result.setBody(converter.convertType(type.getBody()))))
1867             return failure();
1868           results.push_back(result);
1869           return success();
1870         });
1871 
1872     /// Add the legal set of type materializations.
1873     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
1874                                           ValueRange inputs,
1875                                           Location loc) -> Value {
1876       // Allow casting from F64 back to F32.
1877       if (!resultType.isF16() && inputs.size() == 1 &&
1878           inputs[0].getType().isF64())
1879         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1880       // Allow producing an i32 or i64 from nothing.
1881       if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
1882           inputs.empty())
1883         return builder.create<TestTypeProducerOp>(loc, resultType);
1884       // Allow producing an i64 from an integer.
1885       if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
1886           isa<IntegerType>(inputs[0].getType()))
1887         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1888       // Otherwise, fail.
1889       return nullptr;
1890     });
1891 
1892     // Initialize the conversion target.
1893     mlir::ConversionTarget target(getContext());
1894     target.addLegalOp<LegalOpD>();
1895     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
1896       auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
1897       return op.getType().isF64() || op.getType().isInteger(64) ||
1898              (recursiveType &&
1899               recursiveType.getName() == "outer_converted_type");
1900     });
1901     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1902       return converter.isSignatureLegal(op.getFunctionType()) &&
1903              converter.isLegal(&op.getBody());
1904     });
1905     target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
1906       // Allow casts from F64 to F32.
1907       return (*op.operand_type_begin()).isF64() && op.getType().isF32();
1908     });
1909     target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
1910         [&](TestSignatureConversionNoConverterOp op) {
1911           return converter.isLegal(op.getRegion().front().getArgumentTypes());
1912         });
1913 
1914     // Initialize the set of rewrite patterns.
1915     RewritePatternSet patterns(&getContext());
1916     patterns
1917         .add<TestTypeConsumerForward, TestTypeConversionProducer,
1918              TestSignatureConversionUndo,
1919              TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
1920             converter, &getContext());
1921     patterns.add<TestTypeConversionAnotherProducer>(&getContext());
1922     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1923                                                               converter);
1924 
1925     if (failed(applyPartialConversion(getOperation(), target,
1926                                       std::move(patterns))))
1927       signalPassFailure();
1928   }
1929 };
1930 } // namespace
1931 
1932 //===----------------------------------------------------------------------===//
1933 // Test Target Materialization With No Uses
1934 //===----------------------------------------------------------------------===//
1935 
1936 namespace {
1937 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1938   using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1939 
1940   LogicalResult
1941   matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1942                   ConversionPatternRewriter &rewriter) const final {
1943     rewriter.replaceOp(op, adaptor.getOperands());
1944     return success();
1945   }
1946 };
1947 
1948 struct TestTargetMaterializationWithNoUses
1949     : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> {
1950   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1951       TestTargetMaterializationWithNoUses)
1952 
1953   StringRef getArgument() const final {
1954     return "test-target-materialization-with-no-uses";
1955   }
1956   StringRef getDescription() const final {
1957     return "Test a special case of target materialization in DialectConversion";
1958   }
1959 
1960   void runOnOperation() override {
1961     TypeConverter converter;
1962     converter.addConversion([](Type t) { return t; });
1963     converter.addConversion([](IntegerType intTy) -> Type {
1964       if (intTy.getWidth() == 16)
1965         return IntegerType::get(intTy.getContext(), 64);
1966       return intTy;
1967     });
1968     converter.addTargetMaterialization(
1969         [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1970           return builder.create<TestCastOp>(loc, type, inputs).getResult();
1971         });
1972 
1973     ConversionTarget target(getContext());
1974     target.addIllegalOp<TestTypeChangerOp>();
1975 
1976     RewritePatternSet patterns(&getContext());
1977     patterns.add<ForwardOperandPattern>(converter, &getContext());
1978 
1979     if (failed(applyPartialConversion(getOperation(), target,
1980                                       std::move(patterns))))
1981       signalPassFailure();
1982   }
1983 };
1984 } // namespace
1985 
1986 //===----------------------------------------------------------------------===//
1987 // Test Block Merging
1988 //===----------------------------------------------------------------------===//
1989 
1990 namespace {
1991 /// A rewriter pattern that tests that blocks can be merged.
1992 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
1993   using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
1994 
1995   LogicalResult
1996   matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
1997                   ConversionPatternRewriter &rewriter) const final {
1998     Block &firstBlock = op.getBody().front();
1999     Operation *branchOp = firstBlock.getTerminator();
2000     Block *secondBlock = &*(std::next(op.getBody().begin()));
2001     auto succOperands = branchOp->getOperands();
2002     SmallVector<Value, 2> replacements(succOperands);
2003     rewriter.eraseOp(branchOp);
2004     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
2005     rewriter.modifyOpInPlace(op, [] {});
2006     return success();
2007   }
2008 };
2009 
2010 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
2011 struct TestUndoBlocksMerge : public ConversionPattern {
2012   TestUndoBlocksMerge(MLIRContext *ctx)
2013       : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
2014   LogicalResult
2015   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
2016                   ConversionPatternRewriter &rewriter) const final {
2017     Block &firstBlock = op->getRegion(0).front();
2018     Operation *branchOp = firstBlock.getTerminator();
2019     Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
2020     rewriter.setInsertionPointToStart(secondBlock);
2021     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
2022     auto succOperands = branchOp->getOperands();
2023     SmallVector<Value, 2> replacements(succOperands);
2024     rewriter.eraseOp(branchOp);
2025     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
2026     rewriter.modifyOpInPlace(op, [] {});
2027     return success();
2028   }
2029 };
2030 
2031 /// A rewrite mechanism to inline the body of the op into its parent, when both
2032 /// ops can have a single block.
2033 struct TestMergeSingleBlockOps
2034     : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
2035   using OpConversionPattern<
2036       SingleBlockImplicitTerminatorOp>::OpConversionPattern;
2037 
2038   LogicalResult
2039   matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
2040                   ConversionPatternRewriter &rewriter) const final {
2041     SingleBlockImplicitTerminatorOp parentOp =
2042         op->getParentOfType<SingleBlockImplicitTerminatorOp>();
2043     if (!parentOp)
2044       return failure();
2045     Block &innerBlock = op.getRegion().front();
2046     TerminatorOp innerTerminator =
2047         cast<TerminatorOp>(innerBlock.getTerminator());
2048     rewriter.inlineBlockBefore(&innerBlock, op);
2049     rewriter.eraseOp(innerTerminator);
2050     rewriter.eraseOp(op);
2051     return success();
2052   }
2053 };
2054 
2055 struct TestMergeBlocksPatternDriver
2056     : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> {
2057   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver)
2058 
2059   StringRef getArgument() const final { return "test-merge-blocks"; }
2060   StringRef getDescription() const final {
2061     return "Test Merging operation in ConversionPatternRewriter";
2062   }
2063   void runOnOperation() override {
2064     MLIRContext *context = &getContext();
2065     mlir::RewritePatternSet patterns(context);
2066     patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
2067         context);
2068     ConversionTarget target(*context);
2069     target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
2070                       TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
2071     target.addIllegalOp<ILLegalOpF>();
2072 
2073     /// Expect the op to have a single block after legalization.
2074     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
2075         [&](TestMergeBlocksOp op) -> bool {
2076           return llvm::hasSingleElement(op.getBody());
2077         });
2078 
2079     /// Only allow `test.br` within test.merge_blocks op.
2080     target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
2081       return op->getParentOfType<TestMergeBlocksOp>();
2082     });
2083 
2084     /// Expect that all nested test.SingleBlockImplicitTerminator ops are
2085     /// inlined.
2086     target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
2087         [&](SingleBlockImplicitTerminatorOp op) -> bool {
2088           return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
2089         });
2090 
2091     DenseSet<Operation *> unlegalizedOps;
2092     ConversionConfig config;
2093     config.unlegalizedOps = &unlegalizedOps;
2094     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
2095                                  config);
2096     for (auto *op : unlegalizedOps)
2097       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
2098   }
2099 };
2100 } // namespace
2101 
2102 //===----------------------------------------------------------------------===//
2103 // Test Selective Replacement
2104 //===----------------------------------------------------------------------===//
2105 
2106 namespace {
2107 /// A rewrite mechanism to inline the body of the op into its parent, when both
2108 /// ops can have a single block.
2109 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
2110   using OpRewritePattern<TestCastOp>::OpRewritePattern;
2111 
2112   LogicalResult matchAndRewrite(TestCastOp op,
2113                                 PatternRewriter &rewriter) const final {
2114     if (op.getNumOperands() != 2)
2115       return failure();
2116     OperandRange operands = op.getOperands();
2117 
2118     // Replace non-terminator uses with the first operand.
2119     rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
2120       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
2121     });
2122     // Replace everything else with the second operand if the operation isn't
2123     // dead.
2124     rewriter.replaceOp(op, op.getOperand(1));
2125     return success();
2126   }
2127 };
2128 
2129 struct TestSelectiveReplacementPatternDriver
2130     : public PassWrapper<TestSelectiveReplacementPatternDriver,
2131                          OperationPass<>> {
2132   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
2133       TestSelectiveReplacementPatternDriver)
2134 
2135   StringRef getArgument() const final {
2136     return "test-pattern-selective-replacement";
2137   }
2138   StringRef getDescription() const final {
2139     return "Test selective replacement in the PatternRewriter";
2140   }
2141   void runOnOperation() override {
2142     MLIRContext *context = &getContext();
2143     mlir::RewritePatternSet patterns(context);
2144     patterns.add<TestSelectiveOpReplacementPattern>(context);
2145     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
2146   }
2147 };
2148 } // namespace
2149 
2150 //===----------------------------------------------------------------------===//
2151 // PassRegistration
2152 //===----------------------------------------------------------------------===//
2153 
2154 namespace mlir {
2155 namespace test {
2156 void registerPatternsTestPass() {
2157   PassRegistration<TestReturnTypeDriver>();
2158 
2159   PassRegistration<TestDerivedAttributeDriver>();
2160 
2161   PassRegistration<TestGreedyPatternDriver>();
2162   PassRegistration<TestStrictPatternDriver>();
2163   PassRegistration<TestWalkPatternDriver>();
2164 
2165   PassRegistration<TestLegalizePatternDriver>([] {
2166     return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
2167   });
2168 
2169   PassRegistration<TestRemappedValue>();
2170 
2171   PassRegistration<TestUnknownRootOpDriver>();
2172 
2173   PassRegistration<TestTypeConversionDriver>();
2174   PassRegistration<TestTargetMaterializationWithNoUses>();
2175 
2176   PassRegistration<TestRewriteDynamicOpDriver>();
2177 
2178   PassRegistration<TestMergeBlocksPatternDriver>();
2179   PassRegistration<TestSelectiveReplacementPatternDriver>();
2180 }
2181 } // namespace test
2182 } // namespace mlir
2183