xref: /llvm-project/mlir/test/lib/Dialect/Test/TestPatterns.cpp (revision a8b4cb185c135de006d2a012b986b7655960c414)
1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "TestTypes.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BuiltinAttributes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Visitors.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include <cstdint>
26 
27 using namespace mlir;
28 using namespace test;
29 
30 // Native function for testing NativeCodeCall
31 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
32   return choice.getValue() ? input1 : input2;
33 }
34 
35 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
36   rewriter.create<OpI>(loc, input);
37 }
38 
39 static void handleNoResultOp(PatternRewriter &rewriter,
40                              OpSymbolBindingNoResult op) {
41   // Turn the no result op to a one-result op.
42   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
43                                     op.getOperand());
44 }
45 
46 static bool getFirstI32Result(Operation *op, Value &value) {
47   if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
48     return false;
49   value = op->getResult(0);
50   return true;
51 }
52 
53 static Value bindNativeCodeCallResult(Value value) { return value; }
54 
55 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
56                                                               Value input2) {
57   return SmallVector<Value, 2>({input2, input1});
58 }
59 
60 // Test that natives calls are only called once during rewrites.
61 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
62 // This let us check the number of times OpM_Test was called by inspecting
63 // the returned value in the MLIR output.
64 static int64_t opMIncreasingValue = 314159265;
65 static Attribute opMTest(PatternRewriter &rewriter, Value val) {
66   int64_t i = opMIncreasingValue++;
67   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
68 }
69 
70 namespace {
71 #include "TestPatterns.inc"
72 } // namespace
73 
74 //===----------------------------------------------------------------------===//
75 // Test Reduce Pattern Interface
76 //===----------------------------------------------------------------------===//
77 
78 void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
79   populateWithGenerated(patterns);
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // Canonicalizer Driver.
84 //===----------------------------------------------------------------------===//
85 
86 namespace {
87 struct FoldingPattern : public RewritePattern {
88 public:
89   FoldingPattern(MLIRContext *context)
90       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
91                        /*benefit=*/1, context) {}
92 
93   LogicalResult matchAndRewrite(Operation *op,
94                                 PatternRewriter &rewriter) const override {
95     // Exercise createOrFold API for a single-result operation that is folded
96     // upon construction. The operation being created has an in-place folder,
97     // and it should be still present in the output. Furthermore, the folder
98     // should not crash when attempting to recover the (unchanged) operation
99     // result.
100     Value result = rewriter.createOrFold<TestOpInPlaceFold>(
101         op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0));
102     assert(result);
103     rewriter.replaceOp(op, result);
104     return success();
105   }
106 };
107 
108 /// This pattern creates a foldable operation at the entry point of the block.
109 /// This tests the situation where the operation folder will need to replace an
110 /// operation with a previously created constant that does not initially
111 /// dominate the operation to replace.
112 struct FolderInsertBeforePreviouslyFoldedConstantPattern
113     : public OpRewritePattern<TestCastOp> {
114 public:
115   using OpRewritePattern<TestCastOp>::OpRewritePattern;
116 
117   LogicalResult matchAndRewrite(TestCastOp op,
118                                 PatternRewriter &rewriter) const override {
119     if (!op->hasAttr("test_fold_before_previously_folded_op"))
120       return failure();
121     rewriter.setInsertionPointToStart(op->getBlock());
122 
123     auto constOp = rewriter.create<arith::ConstantOp>(
124         op.getLoc(), rewriter.getBoolAttr(true));
125     rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
126                                             Value(constOp));
127     return success();
128   }
129 };
130 
131 /// This pattern matches test.op_commutative2 with the first operand being
132 /// another test.op_commutative2 with a constant on the right side and fold it
133 /// away by propagating it as its result. This is intend to check that patterns
134 /// are applied after the commutative property moves constant to the right.
135 struct FolderCommutativeOp2WithConstant
136     : public OpRewritePattern<TestCommutative2Op> {
137 public:
138   using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
139 
140   LogicalResult matchAndRewrite(TestCommutative2Op op,
141                                 PatternRewriter &rewriter) const override {
142     auto operand =
143         dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
144     if (!operand)
145       return failure();
146     Attribute constInput;
147     if (!matchPattern(operand->getOperand(1), m_Constant(&constInput)))
148       return failure();
149     rewriter.replaceOp(op, operand->getOperand(1));
150     return success();
151   }
152 };
153 
154 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
155 /// attribute with value smaller than MaxVal, it increments the value by 1.
156 template <int MaxVal>
157 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
158   using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern;
159 
160   LogicalResult matchAndRewrite(AnyAttrOfOp op,
161                                 PatternRewriter &rewriter) const override {
162     auto intAttr = dyn_cast<IntegerAttr>(op.getAttr());
163     if (!intAttr)
164       return failure();
165     int64_t val = intAttr.getInt();
166     if (val >= MaxVal)
167       return failure();
168     rewriter.modifyOpInPlace(
169         op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
170     return success();
171   }
172 };
173 
174 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op".
175 struct MakeOpEligible : public RewritePattern {
176   MakeOpEligible(MLIRContext *context)
177       : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {}
178 
179   LogicalResult matchAndRewrite(Operation *op,
180                                 PatternRewriter &rewriter) const override {
181     if (op->hasAttr("eligible"))
182       return failure();
183     rewriter.modifyOpInPlace(
184         op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); });
185     return success();
186   }
187 };
188 
189 /// This pattern hoists eligible ops out of a "test.one_region_op".
190 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
191   using OpRewritePattern<test::OneRegionOp>::OpRewritePattern;
192 
193   LogicalResult matchAndRewrite(test::OneRegionOp op,
194                                 PatternRewriter &rewriter) const override {
195     Operation *terminator = op.getRegion().front().getTerminator();
196     Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp();
197     if (toBeHoisted->getParentOp() != op)
198       return failure();
199     if (!toBeHoisted->hasAttr("eligible"))
200       return failure();
201     rewriter.moveOpBefore(toBeHoisted, op);
202     return success();
203   }
204 };
205 
206 /// This pattern moves "test.move_before_parent_op" before the parent op.
207 struct MoveBeforeParentOp : public RewritePattern {
208   MoveBeforeParentOp(MLIRContext *context)
209       : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
210 
211   LogicalResult matchAndRewrite(Operation *op,
212                                 PatternRewriter &rewriter) const override {
213     // Do not hoist past functions.
214     if (isa<FunctionOpInterface>(op->getParentOp()))
215       return failure();
216     rewriter.moveOpBefore(op, op->getParentOp());
217     return success();
218   }
219 };
220 
221 /// This pattern moves "test.move_after_parent_op" after the parent op.
222 struct MoveAfterParentOp : public RewritePattern {
223   MoveAfterParentOp(MLIRContext *context)
224       : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {}
225 
226   LogicalResult matchAndRewrite(Operation *op,
227                                 PatternRewriter &rewriter) const override {
228     // Do not hoist past functions.
229     if (isa<FunctionOpInterface>(op->getParentOp()))
230       return failure();
231 
232     int64_t moveForwardBy = 0;
233     if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance"))
234       moveForwardBy = advanceBy.getInt();
235 
236     Operation *moveAfter = op->getParentOp();
237     for (int64_t i = 0; i < moveForwardBy; ++i)
238       moveAfter = moveAfter->getNextNode();
239 
240     rewriter.moveOpAfter(op, moveAfter);
241     return success();
242   }
243 };
244 
245 /// This pattern inlines blocks that are nested in
246 /// "test.inline_blocks_into_parent" into the parent block.
247 struct InlineBlocksIntoParent : public RewritePattern {
248   InlineBlocksIntoParent(MLIRContext *context)
249       : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
250                        context) {}
251 
252   LogicalResult matchAndRewrite(Operation *op,
253                                 PatternRewriter &rewriter) const override {
254     bool changed = false;
255     for (Region &r : op->getRegions()) {
256       while (!r.empty()) {
257         rewriter.inlineBlockBefore(&r.front(), op);
258         changed = true;
259       }
260     }
261     return success(changed);
262   }
263 };
264 
265 /// This pattern splits blocks at "test.split_block_here" and replaces the op
266 /// with a new op (to prevent an infinite loop of block splitting).
267 struct SplitBlockHere : public RewritePattern {
268   SplitBlockHere(MLIRContext *context)
269       : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
270 
271   LogicalResult matchAndRewrite(Operation *op,
272                                 PatternRewriter &rewriter) const override {
273     rewriter.splitBlock(op->getBlock(), op->getIterator());
274     Operation *newOp = rewriter.create(
275         op->getLoc(),
276         OperationName("test.new_op", op->getContext()).getIdentifier(),
277         op->getOperands(), op->getResultTypes());
278     rewriter.replaceOp(op, newOp);
279     return success();
280   }
281 };
282 
283 /// This pattern clones "test.clone_me" ops.
284 struct CloneOp : public RewritePattern {
285   CloneOp(MLIRContext *context)
286       : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
287 
288   LogicalResult matchAndRewrite(Operation *op,
289                                 PatternRewriter &rewriter) const override {
290     // Do not clone already cloned ops to avoid going into an infinite loop.
291     if (op->hasAttr("was_cloned"))
292       return failure();
293     Operation *cloned = rewriter.clone(*op);
294     cloned->setAttr("was_cloned", rewriter.getUnitAttr());
295     return success();
296   }
297 };
298 
299 /// This pattern clones regions of "test.clone_region_before" ops before the
300 /// parent block.
301 struct CloneRegionBeforeOp : public RewritePattern {
302   CloneRegionBeforeOp(MLIRContext *context)
303       : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
304 
305   LogicalResult matchAndRewrite(Operation *op,
306                                 PatternRewriter &rewriter) const override {
307     // Do not clone already cloned ops to avoid going into an infinite loop.
308     if (op->hasAttr("was_cloned"))
309       return failure();
310     for (Region &r : op->getRegions())
311       rewriter.cloneRegionBefore(r, op->getBlock());
312     op->setAttr("was_cloned", rewriter.getUnitAttr());
313     return success();
314   }
315 };
316 
317 /// Replace an operation may introduce the re-visiting of its users.
318 class ReplaceWithNewOp : public RewritePattern {
319 public:
320   ReplaceWithNewOp(MLIRContext *context)
321       : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
322 
323   LogicalResult matchAndRewrite(Operation *op,
324                                 PatternRewriter &rewriter) const override {
325     Operation *newOp;
326     if (op->hasAttr("create_erase_op")) {
327       newOp = rewriter.create(
328           op->getLoc(),
329           OperationName("test.erase_op", op->getContext()).getIdentifier(),
330           ValueRange(), TypeRange());
331     } else {
332       newOp = rewriter.create(
333           op->getLoc(),
334           OperationName("test.new_op", op->getContext()).getIdentifier(),
335           op->getOperands(), op->getResultTypes());
336     }
337     // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
338     // A "notifyOperationReplaced" callback is triggered in either case.
339     rewriter.replaceAllOpUsesWith(op, newOp->getResults());
340     rewriter.eraseOp(op);
341     return success();
342   }
343 };
344 
345 /// Erases the first child block of the matched "test.erase_first_block"
346 /// operation.
347 class EraseFirstBlock : public RewritePattern {
348 public:
349   EraseFirstBlock(MLIRContext *context)
350       : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {}
351 
352   LogicalResult matchAndRewrite(Operation *op,
353                                 PatternRewriter &rewriter) const override {
354     for (Region &r : op->getRegions()) {
355       for (Block &b : r.getBlocks()) {
356         rewriter.eraseBlock(&b);
357         return success();
358       }
359     }
360 
361     return failure();
362   }
363 };
364 
365 struct TestGreedyPatternDriver
366     : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
367   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
368 
369   TestGreedyPatternDriver() = default;
370   TestGreedyPatternDriver(const TestGreedyPatternDriver &other)
371       : PassWrapper(other) {}
372 
373   StringRef getArgument() const final { return "test-greedy-patterns"; }
374   StringRef getDescription() const final { return "Run test dialect patterns"; }
375   void runOnOperation() override {
376     mlir::RewritePatternSet patterns(&getContext());
377     populateWithGenerated(patterns);
378 
379     // Verify named pattern is generated with expected name.
380     patterns.add<FoldingPattern, TestNamedPatternRule,
381                  FolderInsertBeforePreviouslyFoldedConstantPattern,
382                  FolderCommutativeOp2WithConstant, HoistEligibleOps,
383                  MakeOpEligible>(&getContext());
384 
385     // Additional patterns for testing the GreedyPatternRewriteDriver.
386     patterns.insert<IncrementIntAttribute<3>>(&getContext());
387 
388     GreedyRewriteConfig config;
389     config.useTopDownTraversal = this->useTopDownTraversal;
390     config.maxIterations = this->maxIterations;
391     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
392                                        config);
393   }
394 
395   Option<bool> useTopDownTraversal{
396       *this, "top-down",
397       llvm::cl::desc("Seed the worklist in general top-down order"),
398       llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
399   Option<int> maxIterations{
400       *this, "max-iterations",
401       llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
402       llvm::cl::init(GreedyRewriteConfig().maxIterations)};
403 };
404 
405 struct DumpNotifications : public RewriterBase::Listener {
406   void notifyBlockInserted(Block *block, Region *previous,
407                            Region::iterator previousIt) override {
408     llvm::outs() << "notifyBlockInserted";
409     if (block->getParentOp()) {
410       llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
411     } else {
412       llvm::outs() << " into unknown op: ";
413     }
414     if (previous == nullptr) {
415       llvm::outs() << "was unlinked\n";
416     } else {
417       llvm::outs() << "was linked\n";
418     }
419   }
420   void notifyOperationInserted(Operation *op,
421                                OpBuilder::InsertPoint previous) override {
422     llvm::outs() << "notifyOperationInserted: " << op->getName();
423     if (!previous.isSet()) {
424       llvm::outs() << ", was unlinked\n";
425     } else {
426       if (!previous.getPoint().getNodePtr()) {
427         llvm::outs() << ", was linked, exact position unknown\n";
428       } else if (previous.getPoint() == previous.getBlock()->end()) {
429         llvm::outs() << ", was last in block\n";
430       } else {
431         llvm::outs() << ", previous = " << previous.getPoint()->getName()
432                      << "\n";
433       }
434     }
435   }
436   void notifyBlockErased(Block *block) override {
437     llvm::outs() << "notifyBlockErased\n";
438   }
439   void notifyOperationErased(Operation *op) override {
440     llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
441   }
442   void notifyOperationModified(Operation *op) override {
443     llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
444   }
445   void notifyOperationReplaced(Operation *op, ValueRange values) override {
446     llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
447   }
448 };
449 
450 struct TestStrictPatternDriver
451     : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
452 public:
453   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
454 
455   TestStrictPatternDriver() = default;
456   TestStrictPatternDriver(const TestStrictPatternDriver &other)
457       : PassWrapper(other) {
458     strictMode = other.strictMode;
459   }
460 
461   StringRef getArgument() const final { return "test-strict-pattern-driver"; }
462   StringRef getDescription() const final {
463     return "Test strict mode of pattern driver";
464   }
465 
466   void runOnOperation() override {
467     MLIRContext *ctx = &getContext();
468     mlir::RewritePatternSet patterns(ctx);
469     patterns.add<
470         // clang-format off
471         ChangeBlockOp,
472         CloneOp,
473         CloneRegionBeforeOp,
474         EraseOp,
475         ImplicitChangeOp,
476         InlineBlocksIntoParent,
477         InsertSameOp,
478         MoveBeforeParentOp,
479         ReplaceWithNewOp,
480         SplitBlockHere
481         // clang-format on
482         >(ctx);
483     SmallVector<Operation *> ops;
484     getOperation()->walk([&](Operation *op) {
485       StringRef opName = op->getName().getStringRef();
486       if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
487           opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
488           opName == "test.move_before_parent_op" ||
489           opName == "test.inline_blocks_into_parent" ||
490           opName == "test.split_block_here" || opName == "test.clone_me" ||
491           opName == "test.clone_region_before") {
492         ops.push_back(op);
493       }
494     });
495 
496     DumpNotifications dumpNotifications;
497     GreedyRewriteConfig config;
498     config.listener = &dumpNotifications;
499     if (strictMode == "AnyOp") {
500       config.strictMode = GreedyRewriteStrictness::AnyOp;
501     } else if (strictMode == "ExistingAndNewOps") {
502       config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
503     } else if (strictMode == "ExistingOps") {
504       config.strictMode = GreedyRewriteStrictness::ExistingOps;
505     } else {
506       llvm_unreachable("invalid strictness option");
507     }
508 
509     // Check if these transformations introduce visiting of operations that
510     // are not in the `ops` set (The new created ops are valid). An invalid
511     // operation will trigger the assertion while processing.
512     bool changed = false;
513     bool allErased = false;
514     (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config,
515                                  &changed, &allErased);
516     Builder b(ctx);
517     getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
518     getOperation()->setAttr("pattern_driver_all_erased",
519                             b.getBoolAttr(allErased));
520   }
521 
522   Option<std::string> strictMode{
523       *this, "strictness",
524       llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"),
525       llvm::cl::init("AnyOp")};
526 
527 private:
528   // New inserted operation is valid for further transformation.
529   class InsertSameOp : public RewritePattern {
530   public:
531     InsertSameOp(MLIRContext *context)
532         : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
533 
534     LogicalResult matchAndRewrite(Operation *op,
535                                   PatternRewriter &rewriter) const override {
536       if (op->hasAttr("skip"))
537         return failure();
538 
539       Operation *newOp =
540           rewriter.create(op->getLoc(), op->getName().getIdentifier(),
541                           op->getOperands(), op->getResultTypes());
542       rewriter.modifyOpInPlace(
543           op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); });
544       newOp->setAttr("skip", rewriter.getBoolAttr(true));
545 
546       return success();
547     }
548   };
549 
550   // Remove an operation may introduce the re-visiting of its operands.
551   class EraseOp : public RewritePattern {
552   public:
553     EraseOp(MLIRContext *context)
554         : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
555     LogicalResult matchAndRewrite(Operation *op,
556                                   PatternRewriter &rewriter) const override {
557       rewriter.eraseOp(op);
558       return success();
559     }
560   };
561 
562   // The following two patterns test RewriterBase::replaceAllUsesWith.
563   //
564   // That function replaces all usages of a Block (or a Value) with another one
565   // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
566   // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
567   // worklist: when an op is modified, it is added to the worklist. The two
568   // patterns below make the tracking observable: ChangeBlockOp replaces all
569   // usages of a block and that pattern is applied because the corresponding ops
570   // are put on the initial worklist (see above). ImplicitChangeOp does an
571   // unrelated change but ops of the corresponding type are *not* on the initial
572   // worklist, so the effect of the second pattern is only visible if the
573   // tracking and subsequent adding to the worklist actually works.
574 
575   // Replace all usages of the first successor with the second successor.
576   class ChangeBlockOp : public RewritePattern {
577   public:
578     ChangeBlockOp(MLIRContext *context)
579         : RewritePattern("test.change_block_op", /*benefit=*/1, context) {}
580     LogicalResult matchAndRewrite(Operation *op,
581                                   PatternRewriter &rewriter) const override {
582       if (op->getNumSuccessors() < 2)
583         return failure();
584       Block *firstSuccessor = op->getSuccessor(0);
585       Block *secondSuccessor = op->getSuccessor(1);
586       if (firstSuccessor == secondSuccessor)
587         return failure();
588       // This is the function being tested:
589       rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor);
590       // Using the following line instead would make the test fail:
591       // firstSuccessor->replaceAllUsesWith(secondSuccessor);
592       return success();
593     }
594   };
595 
596   // Changes the successor to the parent block.
597   class ImplicitChangeOp : public RewritePattern {
598   public:
599     ImplicitChangeOp(MLIRContext *context)
600         : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {}
601     LogicalResult matchAndRewrite(Operation *op,
602                                   PatternRewriter &rewriter) const override {
603       if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock())
604         return failure();
605       rewriter.modifyOpInPlace(op,
606                                [&]() { op->setSuccessor(op->getBlock(), 0); });
607       return success();
608     }
609   };
610 };
611 
612 struct TestWalkPatternDriver final
613     : PassWrapper<TestWalkPatternDriver, OperationPass<>> {
614   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver)
615 
616   TestWalkPatternDriver() = default;
617   TestWalkPatternDriver(const TestWalkPatternDriver &other)
618       : PassWrapper(other) {}
619 
620   StringRef getArgument() const override {
621     return "test-walk-pattern-rewrite-driver";
622   }
623   StringRef getDescription() const override {
624     return "Run test walk pattern rewrite driver";
625   }
626   void runOnOperation() override {
627     mlir::RewritePatternSet patterns(&getContext());
628 
629     // Patterns for testing the WalkPatternRewriteDriver.
630     patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
631                  MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>(
632         &getContext());
633 
634     DumpNotifications dumpListener;
635     walkAndApplyPatterns(getOperation(), std::move(patterns),
636                          dumpNotifications ? &dumpListener : nullptr);
637   }
638 
639   Option<bool> dumpNotifications{
640       *this, "dump-notifications",
641       llvm::cl::desc("Print rewrite listener notifications"),
642       llvm::cl::init(false)};
643 };
644 
645 } // namespace
646 
647 //===----------------------------------------------------------------------===//
648 // ReturnType Driver.
649 //===----------------------------------------------------------------------===//
650 
651 namespace {
652 // Generate ops for each instance where the type can be successfully inferred.
653 template <typename OpTy>
654 static void invokeCreateWithInferredReturnType(Operation *op) {
655   auto *context = op->getContext();
656   auto fop = op->getParentOfType<func::FuncOp>();
657   auto location = UnknownLoc::get(context);
658   OpBuilder b(op);
659   b.setInsertionPointAfter(op);
660 
661   // Use permutations of 2 args as operands.
662   assert(fop.getNumArguments() >= 2);
663   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
664     for (int j = 0; j < e; ++j) {
665       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
666       SmallVector<Type, 2> inferredReturnTypes;
667       if (succeeded(OpTy::inferReturnTypes(
668               context, std::nullopt, values, op->getDiscardableAttrDictionary(),
669               op->getPropertiesStorage(), op->getRegions(),
670               inferredReturnTypes))) {
671         OperationState state(location, OpTy::getOperationName());
672         // TODO: Expand to regions.
673         OpTy::build(b, state, values, op->getAttrs());
674         (void)b.create(state);
675       }
676     }
677   }
678 }
679 
680 static void reifyReturnShape(Operation *op) {
681   OpBuilder b(op);
682 
683   // Use permutations of 2 args as operands.
684   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
685   SmallVector<Value, 2> shapes;
686   if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
687       !llvm::hasSingleElement(shapes))
688     return;
689   for (const auto &it : llvm::enumerate(shapes)) {
690     op->emitRemark() << "value " << it.index() << ": "
691                      << it.value().getDefiningOp();
692   }
693 }
694 
695 struct TestReturnTypeDriver
696     : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> {
697   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver)
698 
699   void getDependentDialects(DialectRegistry &registry) const override {
700     registry.insert<tensor::TensorDialect>();
701   }
702   StringRef getArgument() const final { return "test-return-type"; }
703   StringRef getDescription() const final { return "Run return type functions"; }
704 
705   void runOnOperation() override {
706     if (getOperation().getName() == "testCreateFunctions") {
707       std::vector<Operation *> ops;
708       // Collect ops to avoid triggering on inserted ops.
709       for (auto &op : getOperation().getBody().front())
710         ops.push_back(&op);
711       // Generate test patterns for each, but skip terminator.
712       for (auto *op : llvm::ArrayRef(ops).drop_back()) {
713         // Test create method of each of the Op classes below. The resultant
714         // output would be in reverse order underneath `op` from which
715         // the attributes and regions are used.
716         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
717         invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>(
718             op);
719         invokeCreateWithInferredReturnType<
720             OpWithShapedTypeInferTypeInterfaceOp>(op);
721       };
722       return;
723     }
724     if (getOperation().getName() == "testReifyFunctions") {
725       std::vector<Operation *> ops;
726       // Collect ops to avoid triggering on inserted ops.
727       for (auto &op : getOperation().getBody().front())
728         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
729           ops.push_back(&op);
730       // Generate test patterns for each, but skip terminator.
731       for (auto *op : ops)
732         reifyReturnShape(op);
733     }
734   }
735 };
736 } // namespace
737 
738 namespace {
739 struct TestDerivedAttributeDriver
740     : public PassWrapper<TestDerivedAttributeDriver,
741                          OperationPass<func::FuncOp>> {
742   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver)
743 
744   StringRef getArgument() const final { return "test-derived-attr"; }
745   StringRef getDescription() const final {
746     return "Run test derived attributes";
747   }
748   void runOnOperation() override;
749 };
750 } // namespace
751 
752 void TestDerivedAttributeDriver::runOnOperation() {
753   getOperation().walk([](DerivedAttributeOpInterface dOp) {
754     auto dAttr = dOp.materializeDerivedAttributes();
755     if (!dAttr)
756       return;
757     for (auto d : dAttr)
758       dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
759   });
760 }
761 
762 //===----------------------------------------------------------------------===//
763 // Legalization Driver.
764 //===----------------------------------------------------------------------===//
765 
766 namespace {
767 //===----------------------------------------------------------------------===//
768 // Region-Block Rewrite Testing
769 
770 /// This pattern applies a signature conversion to a block inside a detached
771 /// region.
772 struct TestDetachedSignatureConversion : public ConversionPattern {
773   TestDetachedSignatureConversion(MLIRContext *ctx)
774       : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1,
775                           ctx) {}
776 
777   LogicalResult
778   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
779                   ConversionPatternRewriter &rewriter) const final {
780     if (op->getNumRegions() != 1)
781       return failure();
782     OperationState state(op->getLoc(), "test.legal_op_with_region", operands,
783                          op->getResultTypes(), {}, BlockRange());
784     Region *newRegion = state.addRegion();
785     rewriter.inlineRegionBefore(op->getRegion(0), *newRegion,
786                                 newRegion->begin());
787     TypeConverter::SignatureConversion result(newRegion->getNumArguments());
788     for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i)
789       result.addInputs(i, rewriter.getF64Type());
790     rewriter.applySignatureConversion(&newRegion->front(), result);
791     Operation *newOp = rewriter.create(state);
792     rewriter.replaceOp(op, newOp->getResults());
793     return success();
794   }
795 };
796 
797 /// This pattern is a simple pattern that inlines the first region of a given
798 /// operation into the parent region.
799 struct TestRegionRewriteBlockMovement : public ConversionPattern {
800   TestRegionRewriteBlockMovement(MLIRContext *ctx)
801       : ConversionPattern("test.region", 1, ctx) {}
802 
803   LogicalResult
804   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
805                   ConversionPatternRewriter &rewriter) const final {
806     // Inline this region into the parent region.
807     auto &parentRegion = *op->getParentRegion();
808     auto &opRegion = op->getRegion(0);
809     if (op->getDiscardableAttr("legalizer.should_clone"))
810       rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
811     else
812       rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
813 
814     if (op->getDiscardableAttr("legalizer.erase_old_blocks")) {
815       while (!opRegion.empty())
816         rewriter.eraseBlock(&opRegion.front());
817     }
818 
819     // Drop this operation.
820     rewriter.eraseOp(op);
821     return success();
822   }
823 };
824 /// This pattern is a simple pattern that generates a region containing an
825 /// illegal operation.
826 struct TestRegionRewriteUndo : public RewritePattern {
827   TestRegionRewriteUndo(MLIRContext *ctx)
828       : RewritePattern("test.region_builder", 1, ctx) {}
829 
830   LogicalResult matchAndRewrite(Operation *op,
831                                 PatternRewriter &rewriter) const final {
832     // Create the region operation with an entry block containing arguments.
833     OperationState newRegion(op->getLoc(), "test.region");
834     newRegion.addRegion();
835     auto *regionOp = rewriter.create(newRegion);
836     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
837     entryBlock->addArgument(rewriter.getIntegerType(64),
838                             rewriter.getUnknownLoc());
839 
840     // Add an explicitly illegal operation to ensure the conversion fails.
841     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
842     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
843 
844     // Drop this operation.
845     rewriter.eraseOp(op);
846     return success();
847   }
848 };
849 /// A simple pattern that creates a block at the end of the parent region of the
850 /// matched operation.
851 struct TestCreateBlock : public RewritePattern {
852   TestCreateBlock(MLIRContext *ctx)
853       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
854 
855   LogicalResult matchAndRewrite(Operation *op,
856                                 PatternRewriter &rewriter) const final {
857     Region &region = *op->getParentRegion();
858     Type i32Type = rewriter.getIntegerType(32);
859     Location loc = op->getLoc();
860     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
861     rewriter.create<TerminatorOp>(loc);
862     rewriter.eraseOp(op);
863     return success();
864   }
865 };
866 
867 /// A simple pattern that creates a block containing an invalid operation in
868 /// order to trigger the block creation undo mechanism.
869 struct TestCreateIllegalBlock : public RewritePattern {
870   TestCreateIllegalBlock(MLIRContext *ctx)
871       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
872 
873   LogicalResult matchAndRewrite(Operation *op,
874                                 PatternRewriter &rewriter) const final {
875     Region &region = *op->getParentRegion();
876     Type i32Type = rewriter.getIntegerType(32);
877     Location loc = op->getLoc();
878     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
879     // Create an illegal op to ensure the conversion fails.
880     rewriter.create<ILLegalOpF>(loc, i32Type);
881     rewriter.create<TerminatorOp>(loc);
882     rewriter.eraseOp(op);
883     return success();
884   }
885 };
886 
887 /// A simple pattern that tests the undo mechanism when replacing the uses of a
888 /// block argument.
889 struct TestUndoBlockArgReplace : public ConversionPattern {
890   TestUndoBlockArgReplace(MLIRContext *ctx)
891       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
892 
893   LogicalResult
894   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
895                   ConversionPatternRewriter &rewriter) const final {
896     auto illegalOp =
897         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
898     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
899                                         illegalOp->getResult(0));
900     rewriter.modifyOpInPlace(op, [] {});
901     return success();
902   }
903 };
904 
905 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
906 /// This is to test the rollback logic.
907 struct TestUndoMoveOpBefore : public ConversionPattern {
908   TestUndoMoveOpBefore(MLIRContext *ctx)
909       : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
910 
911   LogicalResult
912   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
913                   ConversionPatternRewriter &rewriter) const override {
914     rewriter.moveOpBefore(op, op->getParentOp());
915     // Replace with an illegal op to ensure the conversion fails.
916     rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
917     return success();
918   }
919 };
920 
921 /// A rewrite pattern that tests the undo mechanism when erasing a block.
922 struct TestUndoBlockErase : public ConversionPattern {
923   TestUndoBlockErase(MLIRContext *ctx)
924       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
925 
926   LogicalResult
927   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
928                   ConversionPatternRewriter &rewriter) const final {
929     Block *secondBlock = &*std::next(op->getRegion(0).begin());
930     rewriter.setInsertionPointToStart(secondBlock);
931     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
932     rewriter.eraseBlock(secondBlock);
933     rewriter.modifyOpInPlace(op, [] {});
934     return success();
935   }
936 };
937 
938 /// A pattern that modifies a property in-place, but keeps the op illegal.
939 struct TestUndoPropertiesModification : public ConversionPattern {
940   TestUndoPropertiesModification(MLIRContext *ctx)
941       : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {}
942   LogicalResult
943   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
944                   ConversionPatternRewriter &rewriter) const final {
945     if (!op->hasAttr("modify_inplace"))
946       return failure();
947     rewriter.modifyOpInPlace(
948         op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); });
949     return success();
950   }
951 };
952 
953 //===----------------------------------------------------------------------===//
954 // Type-Conversion Rewrite Testing
955 
956 /// This patterns erases a region operation that has had a type conversion.
957 struct TestDropOpSignatureConversion : public ConversionPattern {
958   TestDropOpSignatureConversion(MLIRContext *ctx,
959                                 const TypeConverter &converter)
960       : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
961   LogicalResult
962   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
963                   ConversionPatternRewriter &rewriter) const override {
964     Region &region = op->getRegion(0);
965     Block *entry = &region.front();
966 
967     // Convert the original entry arguments.
968     const TypeConverter &converter = *getTypeConverter();
969     TypeConverter::SignatureConversion result(entry->getNumArguments());
970     if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
971                                               result)) ||
972         failed(rewriter.convertRegionTypes(&region, converter, &result)))
973       return failure();
974 
975     // Convert the region signature and just drop the operation.
976     rewriter.eraseOp(op);
977     return success();
978   }
979 };
980 /// This pattern simply updates the operands of the given operation.
981 struct TestPassthroughInvalidOp : public ConversionPattern {
982   TestPassthroughInvalidOp(MLIRContext *ctx)
983       : ConversionPattern("test.invalid", 1, ctx) {}
984   LogicalResult
985   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
986                   ConversionPatternRewriter &rewriter) const final {
987     rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, operands,
988                                              std::nullopt);
989     return success();
990   }
991 };
992 /// Replace with valid op, but simply drop the operands. This is used in a
993 /// regression where we used to generate circular unrealized_conversion_cast
994 /// ops.
995 struct TestDropAndReplaceInvalidOp : public ConversionPattern {
996   TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
997       : ConversionPattern(converter,
998                           "test.drop_operands_and_replace_with_valid", 1, ctx) {
999   }
1000   LogicalResult
1001   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1002                   ConversionPatternRewriter &rewriter) const final {
1003     rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(),
1004                                              std::nullopt);
1005     return success();
1006   }
1007 };
1008 /// This pattern handles the case of a split return value.
1009 struct TestSplitReturnType : public ConversionPattern {
1010   TestSplitReturnType(MLIRContext *ctx)
1011       : ConversionPattern("test.return", 1, ctx) {}
1012   LogicalResult
1013   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1014                   ConversionPatternRewriter &rewriter) const final {
1015     // Check for a return of F32.
1016     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
1017       return failure();
1018 
1019     // Check if the first operation is a cast operation, if it is we use the
1020     // results directly.
1021     auto *defOp = operands[0].getDefiningOp();
1022     if (auto packerOp =
1023             llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
1024       rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
1025       return success();
1026     }
1027 
1028     // Otherwise, fail to match.
1029     return failure();
1030   }
1031 };
1032 
1033 //===----------------------------------------------------------------------===//
1034 // Multi-Level Type-Conversion Rewrite Testing
1035 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
1036   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
1037       : ConversionPattern("test.type_producer", 1, ctx) {}
1038   LogicalResult
1039   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1040                   ConversionPatternRewriter &rewriter) const final {
1041     // If the type is I32, change the type to F32.
1042     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
1043       return failure();
1044     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
1045     return success();
1046   }
1047 };
1048 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
1049   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
1050       : ConversionPattern("test.type_producer", 1, ctx) {}
1051   LogicalResult
1052   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1053                   ConversionPatternRewriter &rewriter) const final {
1054     // If the type is F32, change the type to F64.
1055     if (!Type(*op->result_type_begin()).isF32())
1056       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
1057     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
1058     return success();
1059   }
1060 };
1061 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
1062   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
1063       : ConversionPattern("test.type_producer", 10, ctx) {}
1064   LogicalResult
1065   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1066                   ConversionPatternRewriter &rewriter) const final {
1067     // Always convert to B16, even though it is not a legal type. This tests
1068     // that values are unmapped correctly.
1069     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
1070     return success();
1071   }
1072 };
1073 struct TestUpdateConsumerType : public ConversionPattern {
1074   TestUpdateConsumerType(MLIRContext *ctx)
1075       : ConversionPattern("test.type_consumer", 1, ctx) {}
1076   LogicalResult
1077   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1078                   ConversionPatternRewriter &rewriter) const final {
1079     // Verify that the incoming operand has been successfully remapped to F64.
1080     if (!operands[0].getType().isF64())
1081       return failure();
1082     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
1083     return success();
1084   }
1085 };
1086 
1087 //===----------------------------------------------------------------------===//
1088 // Non-Root Replacement Rewrite Testing
1089 /// This pattern generates an invalid operation, but replaces it before the
1090 /// pattern is finished. This checks that we don't need to legalize the
1091 /// temporary op.
1092 struct TestNonRootReplacement : public RewritePattern {
1093   TestNonRootReplacement(MLIRContext *ctx)
1094       : RewritePattern("test.replace_non_root", 1, ctx) {}
1095 
1096   LogicalResult matchAndRewrite(Operation *op,
1097                                 PatternRewriter &rewriter) const final {
1098     auto resultType = *op->result_type_begin();
1099     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
1100     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
1101 
1102     rewriter.replaceOp(illegalOp, legalOp);
1103     rewriter.replaceOp(op, illegalOp);
1104     return success();
1105   }
1106 };
1107 
1108 //===----------------------------------------------------------------------===//
1109 // Recursive Rewrite Testing
1110 /// This pattern is applied to the same operation multiple times, but has a
1111 /// bounded recursion.
1112 struct TestBoundedRecursiveRewrite
1113     : public OpRewritePattern<TestRecursiveRewriteOp> {
1114   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
1115 
1116   void initialize() {
1117     // The conversion target handles bounding the recursion of this pattern.
1118     setHasBoundedRewriteRecursion();
1119   }
1120 
1121   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
1122                                 PatternRewriter &rewriter) const final {
1123     // Decrement the depth of the op in-place.
1124     rewriter.modifyOpInPlace(op, [&] {
1125       op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
1126     });
1127     return success();
1128   }
1129 };
1130 
1131 struct TestNestedOpCreationUndoRewrite
1132     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
1133   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
1134 
1135   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
1136                                 PatternRewriter &rewriter) const final {
1137     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1138     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1139     return success();
1140   };
1141 };
1142 
1143 // This pattern matches `test.blackhole` and delete this op and its producer.
1144 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
1145   using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
1146 
1147   LogicalResult matchAndRewrite(BlackHoleOp op,
1148                                 PatternRewriter &rewriter) const final {
1149     Operation *producer = op.getOperand().getDefiningOp();
1150     // Always erase the user before the producer, the framework should handle
1151     // this correctly.
1152     rewriter.eraseOp(op);
1153     rewriter.eraseOp(producer);
1154     return success();
1155   };
1156 };
1157 
1158 // This pattern replaces explicitly illegal op with explicitly legal op,
1159 // but in addition creates unregistered operation.
1160 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
1161   using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
1162 
1163   LogicalResult matchAndRewrite(ILLegalOpG op,
1164                                 PatternRewriter &rewriter) const final {
1165     IntegerAttr attr = rewriter.getI32IntegerAttr(0);
1166     Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
1167     rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
1168     return success();
1169   };
1170 };
1171 
1172 class TestEraseOp : public ConversionPattern {
1173 public:
1174   TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {}
1175   LogicalResult
1176   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1177                   ConversionPatternRewriter &rewriter) const final {
1178     // Erase op without replacements.
1179     rewriter.eraseOp(op);
1180     return success();
1181   }
1182 };
1183 
1184 } // namespace
1185 
1186 namespace {
1187 struct TestTypeConverter : public TypeConverter {
1188   using TypeConverter::TypeConverter;
1189   TestTypeConverter() {
1190     addConversion(convertType);
1191     addArgumentMaterialization(materializeCast);
1192     addSourceMaterialization(materializeCast);
1193   }
1194 
1195   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
1196     // Drop I16 types.
1197     if (t.isSignlessInteger(16))
1198       return success();
1199 
1200     // Convert I64 to F64.
1201     if (t.isSignlessInteger(64)) {
1202       results.push_back(FloatType::getF64(t.getContext()));
1203       return success();
1204     }
1205 
1206     // Convert I42 to I43.
1207     if (t.isInteger(42)) {
1208       results.push_back(IntegerType::get(t.getContext(), 43));
1209       return success();
1210     }
1211 
1212     // Split F32 into F16,F16.
1213     if (t.isF32()) {
1214       results.assign(2, FloatType::getF16(t.getContext()));
1215       return success();
1216     }
1217 
1218     // Otherwise, convert the type directly.
1219     results.push_back(t);
1220     return success();
1221   }
1222 
1223   /// Hook for materializing a conversion. This is necessary because we generate
1224   /// 1->N type mappings.
1225   static Value materializeCast(OpBuilder &builder, Type resultType,
1226                                ValueRange inputs, Location loc) {
1227     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1228   }
1229 };
1230 
1231 struct TestLegalizePatternDriver
1232     : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> {
1233   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)
1234 
1235   StringRef getArgument() const final { return "test-legalize-patterns"; }
1236   StringRef getDescription() const final {
1237     return "Run test dialect legalization patterns";
1238   }
1239   /// The mode of conversion to use with the driver.
1240   enum class ConversionMode { Analysis, Full, Partial };
1241 
1242   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
1243 
1244   void getDependentDialects(DialectRegistry &registry) const override {
1245     registry.insert<func::FuncDialect, test::TestDialect>();
1246   }
1247 
1248   void runOnOperation() override {
1249     TestTypeConverter converter;
1250     mlir::RewritePatternSet patterns(&getContext());
1251     populateWithGenerated(patterns);
1252     patterns.add<
1253         TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1254         TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1255         TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
1256         TestSplitReturnType, TestChangeProducerTypeI32ToF32,
1257         TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
1258         TestUpdateConsumerType, TestNonRootReplacement,
1259         TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
1260         TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1261         TestUndoPropertiesModification, TestEraseOp>(&getContext());
1262     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
1263         &getContext(), converter);
1264     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1265                                                               converter);
1266     mlir::populateCallOpTypeConversionPattern(patterns, converter);
1267 
1268     // Define the conversion target used for the test.
1269     ConversionTarget target(getContext());
1270     target.addLegalOp<ModuleOp>();
1271     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
1272                       TerminatorOp, OneRegionOp>();
1273     target.addLegalOp(
1274         OperationName("test.legal_op_with_region", &getContext()));
1275     target
1276         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
1277     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
1278       // Don't allow F32 operands.
1279       return llvm::none_of(op.getOperandTypes(),
1280                            [](Type type) { return type.isF32(); });
1281     });
1282     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1283       return converter.isSignatureLegal(op.getFunctionType()) &&
1284              converter.isLegal(&op.getBody());
1285     });
1286     target.addDynamicallyLegalOp<func::CallOp>(
1287         [&](func::CallOp op) { return converter.isLegal(op); });
1288 
1289     // TestCreateUnregisteredOp creates `arith.constant` operation,
1290     // which was not added to target intentionally to test
1291     // correct error code from conversion driver.
1292     target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });
1293 
1294     // Expect the type_producer/type_consumer operations to only operate on f64.
1295     target.addDynamicallyLegalOp<TestTypeProducerOp>(
1296         [](TestTypeProducerOp op) { return op.getType().isF64(); });
1297     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1298       return op.getOperand().getType().isF64();
1299     });
1300 
1301     // Check support for marking certain operations as recursively legal.
1302     target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) {
1303       return static_cast<bool>(
1304           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
1305     });
1306 
1307     // Mark the bound recursion operation as dynamically legal.
1308     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
1309         [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
1310 
1311     // Create a dynamically legal rule that can only be legalized by folding it.
1312     target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
1313         [](TestOpInPlaceSelfFold op) { return op.getFolded(); });
1314 
1315     // Handle a partial conversion.
1316     if (mode == ConversionMode::Partial) {
1317       DenseSet<Operation *> unlegalizedOps;
1318       ConversionConfig config;
1319       DumpNotifications dumpNotifications;
1320       config.listener = &dumpNotifications;
1321       config.unlegalizedOps = &unlegalizedOps;
1322       if (failed(applyPartialConversion(getOperation(), target,
1323                                         std::move(patterns), config))) {
1324         getOperation()->emitRemark() << "applyPartialConversion failed";
1325       }
1326       // Emit remarks for each legalizable operation.
1327       for (auto *op : unlegalizedOps)
1328         op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1329       return;
1330     }
1331 
1332     // Handle a full conversion.
1333     if (mode == ConversionMode::Full) {
1334       // Check support for marking unknown operations as dynamically legal.
1335       target.markUnknownOpDynamicallyLegal([](Operation *op) {
1336         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
1337       });
1338 
1339       ConversionConfig config;
1340       DumpNotifications dumpNotifications;
1341       config.listener = &dumpNotifications;
1342       if (failed(applyFullConversion(getOperation(), target,
1343                                      std::move(patterns), config))) {
1344         getOperation()->emitRemark() << "applyFullConversion failed";
1345       }
1346       return;
1347     }
1348 
1349     // Otherwise, handle an analysis conversion.
1350     assert(mode == ConversionMode::Analysis);
1351 
1352     // Analyze the convertible operations.
1353     DenseSet<Operation *> legalizedOps;
1354     ConversionConfig config;
1355     config.legalizableOps = &legalizedOps;
1356     if (failed(applyAnalysisConversion(getOperation(), target,
1357                                        std::move(patterns), config)))
1358       return signalPassFailure();
1359 
1360     // Emit remarks for each legalizable operation.
1361     for (auto *op : legalizedOps)
1362       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
1363   }
1364 
1365   /// The mode of conversion to use.
1366   ConversionMode mode;
1367 };
1368 } // namespace
1369 
1370 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
1371     legalizerConversionMode(
1372         "test-legalize-mode",
1373         llvm::cl::desc("The legalization mode to use with the test driver"),
1374         llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
1375         llvm::cl::values(
1376             clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
1377                        "analysis", "Perform an analysis conversion"),
1378             clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
1379                        "Perform a full conversion"),
1380             clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
1381                        "partial", "Perform a partial conversion")));
1382 
1383 //===----------------------------------------------------------------------===//
1384 // ConversionPatternRewriter::getRemappedValue testing. This method is used
1385 // to get the remapped value of an original value that was replaced using
1386 // ConversionPatternRewriter.
1387 namespace {
1388 struct TestRemapValueTypeConverter : public TypeConverter {
1389   using TypeConverter::TypeConverter;
1390 
1391   TestRemapValueTypeConverter() {
1392     addConversion(
1393         [](Float32Type type) { return Float64Type::get(type.getContext()); });
1394     addConversion([](Type type) { return type; });
1395   }
1396 };
1397 
1398 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
1399 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
1400 /// operand twice.
1401 ///
1402 /// Example:
1403 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
1404 /// is replaced with:
1405 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
1406 struct OneVResOneVOperandOp1Converter
1407     : public OpConversionPattern<OneVResOneVOperandOp1> {
1408   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
1409 
1410   LogicalResult
1411   matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
1412                   ConversionPatternRewriter &rewriter) const override {
1413     auto origOps = op.getOperands();
1414     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
1415            "One operand expected");
1416     Value origOp = *origOps.begin();
1417     SmallVector<Value, 2> remappedOperands;
1418     // Replicate the remapped original operand twice. Note that we don't used
1419     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
1420     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1421     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1422 
1423     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
1424                                                        remappedOperands);
1425     return success();
1426   }
1427 };
1428 
1429 /// A rewriter pattern that tests that blocks can be merged.
1430 struct TestRemapValueInRegion
1431     : public OpConversionPattern<TestRemappedValueRegionOp> {
1432   using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;
1433 
1434   LogicalResult
1435   matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
1436                   ConversionPatternRewriter &rewriter) const final {
1437     Block &block = op.getBody().front();
1438     Operation *terminator = block.getTerminator();
1439 
1440     // Merge the block into the parent region.
1441     Block *parentBlock = op->getBlock();
1442     Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
1443     rewriter.mergeBlocks(&block, parentBlock, ValueRange());
1444     rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());
1445 
1446     // Replace the results of this operation with the remapped terminator
1447     // values.
1448     SmallVector<Value> terminatorOperands;
1449     if (failed(rewriter.getRemappedValues(terminator->getOperands(),
1450                                           terminatorOperands)))
1451       return failure();
1452 
1453     rewriter.eraseOp(terminator);
1454     rewriter.replaceOp(op, terminatorOperands);
1455     return success();
1456   }
1457 };
1458 
1459 struct TestRemappedValue
1460     : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> {
1461   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue)
1462 
1463   StringRef getArgument() const final { return "test-remapped-value"; }
1464   StringRef getDescription() const final {
1465     return "Test public remapped value mechanism in ConversionPatternRewriter";
1466   }
1467   void runOnOperation() override {
1468     TestRemapValueTypeConverter typeConverter;
1469 
1470     mlir::RewritePatternSet patterns(&getContext());
1471     patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
1472     patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
1473         &getContext());
1474     patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
1475 
1476     mlir::ConversionTarget target(getContext());
1477     target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>();
1478 
1479     // Expect the type_producer/type_consumer operations to only operate on f64.
1480     target.addDynamicallyLegalOp<TestTypeProducerOp>(
1481         [](TestTypeProducerOp op) { return op.getType().isF64(); });
1482     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1483       return op.getOperand().getType().isF64();
1484     });
1485 
1486     // We make OneVResOneVOperandOp1 legal only when it has more that one
1487     // operand. This will trigger the conversion that will replace one-operand
1488     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
1489     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
1490         [](Operation *op) { return op->getNumOperands() > 1; });
1491 
1492     if (failed(mlir::applyFullConversion(getOperation(), target,
1493                                          std::move(patterns)))) {
1494       signalPassFailure();
1495     }
1496   }
1497 };
1498 } // namespace
1499 
1500 //===----------------------------------------------------------------------===//
1501 // Test patterns without a specific root operation kind
1502 //===----------------------------------------------------------------------===//
1503 
1504 namespace {
1505 /// This pattern matches and removes any operation in the test dialect.
1506 struct RemoveTestDialectOps : public RewritePattern {
1507   RemoveTestDialectOps(MLIRContext *context)
1508       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
1509 
1510   LogicalResult matchAndRewrite(Operation *op,
1511                                 PatternRewriter &rewriter) const override {
1512     if (!isa<TestDialect>(op->getDialect()))
1513       return failure();
1514     rewriter.eraseOp(op);
1515     return success();
1516   }
1517 };
1518 
1519 struct TestUnknownRootOpDriver
1520     : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> {
1521   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver)
1522 
1523   StringRef getArgument() const final {
1524     return "test-legalize-unknown-root-patterns";
1525   }
1526   StringRef getDescription() const final {
1527     return "Test public remapped value mechanism in ConversionPatternRewriter";
1528   }
1529   void runOnOperation() override {
1530     mlir::RewritePatternSet patterns(&getContext());
1531     patterns.add<RemoveTestDialectOps>(&getContext());
1532 
1533     mlir::ConversionTarget target(getContext());
1534     target.addIllegalDialect<TestDialect>();
1535     if (failed(applyPartialConversion(getOperation(), target,
1536                                       std::move(patterns))))
1537       signalPassFailure();
1538   }
1539 };
1540 } // namespace
1541 
1542 //===----------------------------------------------------------------------===//
// Test patterns that use operations and types defined at runtime
1544 //===----------------------------------------------------------------------===//
1545 
1546 namespace {
1547 /// This pattern matches dynamic operations 'test.one_operand_two_results' and
1548 /// replace them with dynamic operations 'test.generic_dynamic_op'.
1549 struct RewriteDynamicOp : public RewritePattern {
1550   RewriteDynamicOp(MLIRContext *context)
1551       : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
1552                        context) {}
1553 
1554   LogicalResult matchAndRewrite(Operation *op,
1555                                 PatternRewriter &rewriter) const override {
1556     assert(op->getName().getStringRef() ==
1557                "test.dynamic_one_operand_two_results" &&
1558            "rewrite pattern should only match operations with the right name");
1559 
1560     OperationState state(op->getLoc(), "test.dynamic_generic",
1561                          op->getOperands(), op->getResultTypes(),
1562                          op->getAttrs());
1563     auto *newOp = rewriter.create(state);
1564     rewriter.replaceOp(op, newOp->getResults());
1565     return success();
1566   }
1567 };
1568 
1569 struct TestRewriteDynamicOpDriver
1570     : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> {
1571   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
1572 
1573   void getDependentDialects(DialectRegistry &registry) const override {
1574     registry.insert<TestDialect>();
1575   }
1576   StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
1577   StringRef getDescription() const final {
1578     return "Test rewritting on dynamic operations";
1579   }
1580   void runOnOperation() override {
1581     RewritePatternSet patterns(&getContext());
1582     patterns.add<RewriteDynamicOp>(&getContext());
1583 
1584     ConversionTarget target(getContext());
1585     target.addIllegalOp(
1586         OperationName("test.dynamic_one_operand_two_results", &getContext()));
1587     target.addLegalOp(OperationName("test.dynamic_generic", &getContext()));
1588     if (failed(applyPartialConversion(getOperation(), target,
1589                                       std::move(patterns))))
1590       signalPassFailure();
1591   }
1592 };
1593 } // end anonymous namespace
1594 
1595 //===----------------------------------------------------------------------===//
1596 // Test type conversions
1597 //===----------------------------------------------------------------------===//
1598 
1599 namespace {
1600 struct TestTypeConversionProducer
1601     : public OpConversionPattern<TestTypeProducerOp> {
1602   using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
1603   LogicalResult
1604   matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
1605                   ConversionPatternRewriter &rewriter) const final {
1606     Type resultType = op.getType();
1607     Type convertedType = getTypeConverter()
1608                              ? getTypeConverter()->convertType(resultType)
1609                              : resultType;
1610     if (isa<FloatType>(resultType))
1611       resultType = rewriter.getF64Type();
1612     else if (resultType.isInteger(16))
1613       resultType = rewriter.getIntegerType(64);
1614     else if (isa<test::TestRecursiveType>(resultType) &&
1615              convertedType != resultType)
1616       resultType = convertedType;
1617     else
1618       return failure();
1619 
1620     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
1621     return success();
1622   }
1623 };
1624 
1625 /// Call signature conversion and then fail the rewrite to trigger the undo
1626 /// mechanism.
1627 struct TestSignatureConversionUndo
1628     : public OpConversionPattern<TestSignatureConversionUndoOp> {
1629   using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
1630 
1631   LogicalResult
1632   matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
1633                   ConversionPatternRewriter &rewriter) const final {
1634     (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
1635     return failure();
1636   }
1637 };
1638 
1639 /// Call signature conversion without providing a type converter to handle
1640 /// materializations.
1641 struct TestTestSignatureConversionNoConverter
1642     : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
1643   TestTestSignatureConversionNoConverter(const TypeConverter &converter,
1644                                          MLIRContext *context)
1645       : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
1646         converter(converter) {}
1647 
1648   LogicalResult
1649   matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
1650                   ConversionPatternRewriter &rewriter) const final {
1651     Region &region = op->getRegion(0);
1652     Block *entry = &region.front();
1653 
1654     // Convert the original entry arguments.
1655     TypeConverter::SignatureConversion result(entry->getNumArguments());
1656     if (failed(
1657             converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
1658       return failure();
1659     rewriter.modifyOpInPlace(op, [&] {
1660       rewriter.applySignatureConversion(&region.front(), result);
1661     });
1662     return success();
1663   }
1664 
1665   const TypeConverter &converter;
1666 };
1667 
1668 /// Just forward the operands to the root op. This is essentially a no-op
1669 /// pattern that is used to trigger target materialization.
1670 struct TestTypeConsumerForward
1671     : public OpConversionPattern<TestTypeConsumerOp> {
1672   using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
1673 
1674   LogicalResult
1675   matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
1676                   ConversionPatternRewriter &rewriter) const final {
1677     rewriter.modifyOpInPlace(op,
1678                              [&] { op->setOperands(adaptor.getOperands()); });
1679     return success();
1680   }
1681 };
1682 
1683 struct TestTypeConversionAnotherProducer
1684     : public OpRewritePattern<TestAnotherTypeProducerOp> {
1685   using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
1686 
1687   LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
1688                                 PatternRewriter &rewriter) const final {
1689     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
1690     return success();
1691   }
1692 };
1693 
1694 struct TestReplaceWithLegalOp : public ConversionPattern {
1695   TestReplaceWithLegalOp(MLIRContext *ctx)
1696       : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
1697   LogicalResult
1698   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1699                   ConversionPatternRewriter &rewriter) const final {
1700     rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
1701     return success();
1702   }
1703 };
1704 
1705 struct TestTypeConversionDriver
1706     : public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
1707   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
1708 
1709   void getDependentDialects(DialectRegistry &registry) const override {
1710     registry.insert<TestDialect>();
1711   }
1712   StringRef getArgument() const final {
1713     return "test-legalize-type-conversion";
1714   }
1715   StringRef getDescription() const final {
1716     return "Test various type conversion functionalities in DialectConversion";
1717   }
1718 
1719   void runOnOperation() override {
1720     // Initialize the type converter.
1721     SmallVector<Type, 2> conversionCallStack;
1722     TypeConverter converter;
1723 
1724     /// Add the legal set of type conversions.
1725     converter.addConversion([](Type type) -> Type {
1726       // Treat F64 as legal.
1727       if (type.isF64())
1728         return type;
1729       // Allow converting BF16/F16/F32 to F64.
1730       if (type.isBF16() || type.isF16() || type.isF32())
1731         return FloatType::getF64(type.getContext());
1732       // Otherwise, the type is illegal.
1733       return nullptr;
1734     });
1735     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
1736       // Drop all integer types.
1737       return success();
1738     });
1739     converter.addConversion(
1740         // Convert a recursive self-referring type into a non-self-referring
1741         // type named "outer_converted_type" that contains a SimpleAType.
1742         [&](test::TestRecursiveType type,
1743             SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
1744           // If the type is already converted, return it to indicate that it is
1745           // legal.
1746           if (type.getName() == "outer_converted_type") {
1747             results.push_back(type);
1748             return success();
1749           }
1750 
1751           conversionCallStack.push_back(type);
1752           auto popConversionCallStack = llvm::make_scope_exit(
1753               [&conversionCallStack]() { conversionCallStack.pop_back(); });
1754 
1755           // If the type is on the call stack more than once (it is there at
1756           // least once because of the _current_ call, which is always the last
1757           // element on the stack), we've hit the recursive case. Just return
1758           // SimpleAType here to create a non-recursive type as a result.
1759           if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(),
1760                                  type)) {
1761             results.push_back(test::SimpleAType::get(type.getContext()));
1762             return success();
1763           }
1764 
1765           // Convert the body recursively.
1766           auto result = test::TestRecursiveType::get(type.getContext(),
1767                                                      "outer_converted_type");
1768           if (failed(result.setBody(converter.convertType(type.getBody()))))
1769             return failure();
1770           results.push_back(result);
1771           return success();
1772         });
1773 
1774     /// Add the legal set of type materializations.
1775     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
1776                                           ValueRange inputs,
1777                                           Location loc) -> Value {
1778       // Allow casting from F64 back to F32.
1779       if (!resultType.isF16() && inputs.size() == 1 &&
1780           inputs[0].getType().isF64())
1781         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1782       // Allow producing an i32 or i64 from nothing.
1783       if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
1784           inputs.empty())
1785         return builder.create<TestTypeProducerOp>(loc, resultType);
1786       // Allow producing an i64 from an integer.
1787       if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
1788           isa<IntegerType>(inputs[0].getType()))
1789         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1790       // Otherwise, fail.
1791       return nullptr;
1792     });
1793 
1794     // Initialize the conversion target.
1795     mlir::ConversionTarget target(getContext());
1796     target.addLegalOp<LegalOpD>();
1797     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
1798       auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
1799       return op.getType().isF64() || op.getType().isInteger(64) ||
1800              (recursiveType &&
1801               recursiveType.getName() == "outer_converted_type");
1802     });
1803     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1804       return converter.isSignatureLegal(op.getFunctionType()) &&
1805              converter.isLegal(&op.getBody());
1806     });
1807     target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
1808       // Allow casts from F64 to F32.
1809       return (*op.operand_type_begin()).isF64() && op.getType().isF32();
1810     });
1811     target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
1812         [&](TestSignatureConversionNoConverterOp op) {
1813           return converter.isLegal(op.getRegion().front().getArgumentTypes());
1814         });
1815 
1816     // Initialize the set of rewrite patterns.
1817     RewritePatternSet patterns(&getContext());
1818     patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
1819                  TestSignatureConversionUndo,
1820                  TestTestSignatureConversionNoConverter>(converter,
1821                                                          &getContext());
1822     patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
1823         &getContext());
1824     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1825                                                               converter);
1826 
1827     if (failed(applyPartialConversion(getOperation(), target,
1828                                       std::move(patterns))))
1829       signalPassFailure();
1830   }
1831 };
1832 } // namespace
1833 
1834 //===----------------------------------------------------------------------===//
1835 // Test Target Materialization With No Uses
1836 //===----------------------------------------------------------------------===//
1837 
1838 namespace {
1839 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1840   using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1841 
1842   LogicalResult
1843   matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1844                   ConversionPatternRewriter &rewriter) const final {
1845     rewriter.replaceOp(op, adaptor.getOperands());
1846     return success();
1847   }
1848 };
1849 
1850 struct TestTargetMaterializationWithNoUses
1851     : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> {
1852   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1853       TestTargetMaterializationWithNoUses)
1854 
1855   StringRef getArgument() const final {
1856     return "test-target-materialization-with-no-uses";
1857   }
1858   StringRef getDescription() const final {
1859     return "Test a special case of target materialization in DialectConversion";
1860   }
1861 
1862   void runOnOperation() override {
1863     TypeConverter converter;
1864     converter.addConversion([](Type t) { return t; });
1865     converter.addConversion([](IntegerType intTy) -> Type {
1866       if (intTy.getWidth() == 16)
1867         return IntegerType::get(intTy.getContext(), 64);
1868       return intTy;
1869     });
1870     converter.addTargetMaterialization(
1871         [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1872           return builder.create<TestCastOp>(loc, type, inputs).getResult();
1873         });
1874 
1875     ConversionTarget target(getContext());
1876     target.addIllegalOp<TestTypeChangerOp>();
1877 
1878     RewritePatternSet patterns(&getContext());
1879     patterns.add<ForwardOperandPattern>(converter, &getContext());
1880 
1881     if (failed(applyPartialConversion(getOperation(), target,
1882                                       std::move(patterns))))
1883       signalPassFailure();
1884   }
1885 };
1886 } // namespace
1887 
1888 //===----------------------------------------------------------------------===//
1889 // Test Block Merging
1890 //===----------------------------------------------------------------------===//
1891 
1892 namespace {
1893 /// A rewriter pattern that tests that blocks can be merged.
1894 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
1895   using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
1896 
1897   LogicalResult
1898   matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
1899                   ConversionPatternRewriter &rewriter) const final {
1900     Block &firstBlock = op.getBody().front();
1901     Operation *branchOp = firstBlock.getTerminator();
1902     Block *secondBlock = &*(std::next(op.getBody().begin()));
1903     auto succOperands = branchOp->getOperands();
1904     SmallVector<Value, 2> replacements(succOperands);
1905     rewriter.eraseOp(branchOp);
1906     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1907     rewriter.modifyOpInPlace(op, [] {});
1908     return success();
1909   }
1910 };
1911 
1912 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
1913 struct TestUndoBlocksMerge : public ConversionPattern {
1914   TestUndoBlocksMerge(MLIRContext *ctx)
1915       : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
1916   LogicalResult
1917   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1918                   ConversionPatternRewriter &rewriter) const final {
1919     Block &firstBlock = op->getRegion(0).front();
1920     Operation *branchOp = firstBlock.getTerminator();
1921     Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
1922     rewriter.setInsertionPointToStart(secondBlock);
1923     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
1924     auto succOperands = branchOp->getOperands();
1925     SmallVector<Value, 2> replacements(succOperands);
1926     rewriter.eraseOp(branchOp);
1927     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1928     rewriter.modifyOpInPlace(op, [] {});
1929     return success();
1930   }
1931 };
1932 
1933 /// A rewrite mechanism to inline the body of the op into its parent, when both
1934 /// ops can have a single block.
1935 struct TestMergeSingleBlockOps
1936     : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
1937   using OpConversionPattern<
1938       SingleBlockImplicitTerminatorOp>::OpConversionPattern;
1939 
1940   LogicalResult
1941   matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
1942                   ConversionPatternRewriter &rewriter) const final {
1943     SingleBlockImplicitTerminatorOp parentOp =
1944         op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1945     if (!parentOp)
1946       return failure();
1947     Block &innerBlock = op.getRegion().front();
1948     TerminatorOp innerTerminator =
1949         cast<TerminatorOp>(innerBlock.getTerminator());
1950     rewriter.inlineBlockBefore(&innerBlock, op);
1951     rewriter.eraseOp(innerTerminator);
1952     rewriter.eraseOp(op);
1953     return success();
1954   }
1955 };
1956 
1957 struct TestMergeBlocksPatternDriver
1958     : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> {
1959   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver)
1960 
1961   StringRef getArgument() const final { return "test-merge-blocks"; }
1962   StringRef getDescription() const final {
1963     return "Test Merging operation in ConversionPatternRewriter";
1964   }
1965   void runOnOperation() override {
1966     MLIRContext *context = &getContext();
1967     mlir::RewritePatternSet patterns(context);
1968     patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
1969         context);
1970     ConversionTarget target(*context);
1971     target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
1972                       TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
1973     target.addIllegalOp<ILLegalOpF>();
1974 
1975     /// Expect the op to have a single block after legalization.
1976     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
1977         [&](TestMergeBlocksOp op) -> bool {
1978           return llvm::hasSingleElement(op.getBody());
1979         });
1980 
1981     /// Only allow `test.br` within test.merge_blocks op.
1982     target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
1983       return op->getParentOfType<TestMergeBlocksOp>();
1984     });
1985 
1986     /// Expect that all nested test.SingleBlockImplicitTerminator ops are
1987     /// inlined.
1988     target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
1989         [&](SingleBlockImplicitTerminatorOp op) -> bool {
1990           return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1991         });
1992 
1993     DenseSet<Operation *> unlegalizedOps;
1994     ConversionConfig config;
1995     config.unlegalizedOps = &unlegalizedOps;
1996     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
1997                                  config);
1998     for (auto *op : unlegalizedOps)
1999       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
2000   }
2001 };
2002 } // namespace
2003 
2004 //===----------------------------------------------------------------------===//
2005 // Test Selective Replacement
2006 //===----------------------------------------------------------------------===//
2007 
2008 namespace {
2009 /// A rewrite mechanism to inline the body of the op into its parent, when both
2010 /// ops can have a single block.
2011 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
2012   using OpRewritePattern<TestCastOp>::OpRewritePattern;
2013 
2014   LogicalResult matchAndRewrite(TestCastOp op,
2015                                 PatternRewriter &rewriter) const final {
2016     if (op.getNumOperands() != 2)
2017       return failure();
2018     OperandRange operands = op.getOperands();
2019 
2020     // Replace non-terminator uses with the first operand.
2021     rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
2022       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
2023     });
2024     // Replace everything else with the second operand if the operation isn't
2025     // dead.
2026     rewriter.replaceOp(op, op.getOperand(1));
2027     return success();
2028   }
2029 };
2030 
2031 struct TestSelectiveReplacementPatternDriver
2032     : public PassWrapper<TestSelectiveReplacementPatternDriver,
2033                          OperationPass<>> {
2034   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
2035       TestSelectiveReplacementPatternDriver)
2036 
2037   StringRef getArgument() const final {
2038     return "test-pattern-selective-replacement";
2039   }
2040   StringRef getDescription() const final {
2041     return "Test selective replacement in the PatternRewriter";
2042   }
2043   void runOnOperation() override {
2044     MLIRContext *context = &getContext();
2045     mlir::RewritePatternSet patterns(context);
2046     patterns.add<TestSelectiveOpReplacementPattern>(context);
2047     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
2048   }
2049 };
2050 } // namespace
2051 
2052 //===----------------------------------------------------------------------===//
2053 // PassRegistration
2054 //===----------------------------------------------------------------------===//
2055 
2056 namespace mlir {
2057 namespace test {
2058 void registerPatternsTestPass() {
2059   PassRegistration<TestReturnTypeDriver>();
2060 
2061   PassRegistration<TestDerivedAttributeDriver>();
2062 
2063   PassRegistration<TestGreedyPatternDriver>();
2064   PassRegistration<TestStrictPatternDriver>();
2065   PassRegistration<TestWalkPatternDriver>();
2066 
2067   PassRegistration<TestLegalizePatternDriver>([] {
2068     return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
2069   });
2070 
2071   PassRegistration<TestRemappedValue>();
2072 
2073   PassRegistration<TestUnknownRootOpDriver>();
2074 
2075   PassRegistration<TestTypeConversionDriver>();
2076   PassRegistration<TestTargetMaterializationWithNoUses>();
2077 
2078   PassRegistration<TestRewriteDynamicOpDriver>();
2079 
2080   PassRegistration<TestMergeBlocksPatternDriver>();
2081   PassRegistration<TestSelectiveReplacementPatternDriver>();
2082 }
2083 } // namespace test
2084 } // namespace mlir
2085