xref: /llvm-project/mlir/test/lib/Dialect/Test/TestPatterns.cpp (revision 3cc311ab8674eab6b9101cdf3823b55ea23d6535)
1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "TestTypes.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BuiltinAttributes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Visitors.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include <cstdint>
26 
27 using namespace mlir;
28 using namespace test;
29 
30 // Native function for testing NativeCodeCall
31 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
32   return choice.getValue() ? input1 : input2;
33 }
34 
35 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
36   rewriter.create<OpI>(loc, input);
37 }
38 
39 static void handleNoResultOp(PatternRewriter &rewriter,
40                              OpSymbolBindingNoResult op) {
41   // Turn the no result op to a one-result op.
42   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
43                                     op.getOperand());
44 }
45 
46 static bool getFirstI32Result(Operation *op, Value &value) {
47   if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
48     return false;
49   value = op->getResult(0);
50   return true;
51 }
52 
53 static Value bindNativeCodeCallResult(Value value) { return value; }
54 
55 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
56                                                               Value input2) {
57   return SmallVector<Value, 2>({input2, input1});
58 }
59 
60 // Test that natives calls are only called once during rewrites.
61 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
62 // This let us check the number of times OpM_Test was called by inspecting
63 // the returned value in the MLIR output.
64 static int64_t opMIncreasingValue = 314159265;
65 static Attribute opMTest(PatternRewriter &rewriter, Value val) {
66   int64_t i = opMIncreasingValue++;
67   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
68 }
69 
70 namespace {
71 #include "TestPatterns.inc"
72 } // namespace
73 
74 //===----------------------------------------------------------------------===//
75 // Test Reduce Pattern Interface
76 //===----------------------------------------------------------------------===//
77 
/// Adds the patterns registered by `populateWithGenerated` to `patterns`.
void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
  populateWithGenerated(patterns);
}
81 
82 //===----------------------------------------------------------------------===//
83 // Canonicalizer Driver.
84 //===----------------------------------------------------------------------===//
85 
86 namespace {
87 struct FoldingPattern : public RewritePattern {
88 public:
89   FoldingPattern(MLIRContext *context)
90       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
91                        /*benefit=*/1, context) {}
92 
93   LogicalResult matchAndRewrite(Operation *op,
94                                 PatternRewriter &rewriter) const override {
95     // Exercise createOrFold API for a single-result operation that is folded
96     // upon construction. The operation being created has an in-place folder,
97     // and it should be still present in the output. Furthermore, the folder
98     // should not crash when attempting to recover the (unchanged) operation
99     // result.
100     Value result = rewriter.createOrFold<TestOpInPlaceFold>(
101         op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0));
102     assert(result);
103     rewriter.replaceOp(op, result);
104     return success();
105   }
106 };
107 
108 /// This pattern creates a foldable operation at the entry point of the block.
109 /// This tests the situation where the operation folder will need to replace an
110 /// operation with a previously created constant that does not initially
111 /// dominate the operation to replace.
112 struct FolderInsertBeforePreviouslyFoldedConstantPattern
113     : public OpRewritePattern<TestCastOp> {
114 public:
115   using OpRewritePattern<TestCastOp>::OpRewritePattern;
116 
117   LogicalResult matchAndRewrite(TestCastOp op,
118                                 PatternRewriter &rewriter) const override {
119     if (!op->hasAttr("test_fold_before_previously_folded_op"))
120       return failure();
121     rewriter.setInsertionPointToStart(op->getBlock());
122 
123     auto constOp = rewriter.create<arith::ConstantOp>(
124         op.getLoc(), rewriter.getBoolAttr(true));
125     rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
126                                             Value(constOp));
127     return success();
128   }
129 };
130 
131 /// This pattern matches test.op_commutative2 with the first operand being
132 /// another test.op_commutative2 with a constant on the right side and fold it
133 /// away by propagating it as its result. This is intend to check that patterns
134 /// are applied after the commutative property moves constant to the right.
135 struct FolderCommutativeOp2WithConstant
136     : public OpRewritePattern<TestCommutative2Op> {
137 public:
138   using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
139 
140   LogicalResult matchAndRewrite(TestCommutative2Op op,
141                                 PatternRewriter &rewriter) const override {
142     auto operand =
143         dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
144     if (!operand)
145       return failure();
146     Attribute constInput;
147     if (!matchPattern(operand->getOperand(1), m_Constant(&constInput)))
148       return failure();
149     rewriter.replaceOp(op, operand->getOperand(1));
150     return success();
151   }
152 };
153 
154 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
155 /// attribute with value smaller than MaxVal, it increments the value by 1.
156 template <int MaxVal>
157 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
158   using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern;
159 
160   LogicalResult matchAndRewrite(AnyAttrOfOp op,
161                                 PatternRewriter &rewriter) const override {
162     auto intAttr = dyn_cast<IntegerAttr>(op.getAttr());
163     if (!intAttr)
164       return failure();
165     int64_t val = intAttr.getInt();
166     if (val >= MaxVal)
167       return failure();
168     rewriter.modifyOpInPlace(
169         op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
170     return success();
171   }
172 };
173 
174 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op".
175 struct MakeOpEligible : public RewritePattern {
176   MakeOpEligible(MLIRContext *context)
177       : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {}
178 
179   LogicalResult matchAndRewrite(Operation *op,
180                                 PatternRewriter &rewriter) const override {
181     if (op->hasAttr("eligible"))
182       return failure();
183     rewriter.modifyOpInPlace(
184         op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); });
185     return success();
186   }
187 };
188 
189 /// This pattern hoists eligible ops out of a "test.one_region_op".
190 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
191   using OpRewritePattern<test::OneRegionOp>::OpRewritePattern;
192 
193   LogicalResult matchAndRewrite(test::OneRegionOp op,
194                                 PatternRewriter &rewriter) const override {
195     Operation *terminator = op.getRegion().front().getTerminator();
196     Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp();
197     if (toBeHoisted->getParentOp() != op)
198       return failure();
199     if (!toBeHoisted->hasAttr("eligible"))
200       return failure();
201     rewriter.moveOpBefore(toBeHoisted, op);
202     return success();
203   }
204 };
205 
206 /// This pattern moves "test.move_before_parent_op" before the parent op.
207 struct MoveBeforeParentOp : public RewritePattern {
208   MoveBeforeParentOp(MLIRContext *context)
209       : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
210 
211   LogicalResult matchAndRewrite(Operation *op,
212                                 PatternRewriter &rewriter) const override {
213     // Do not hoist past functions.
214     if (isa<FunctionOpInterface>(op->getParentOp()))
215       return failure();
216     rewriter.moveOpBefore(op, op->getParentOp());
217     return success();
218   }
219 };
220 
221 /// This pattern moves "test.move_after_parent_op" after the parent op.
222 struct MoveAfterParentOp : public RewritePattern {
223   MoveAfterParentOp(MLIRContext *context)
224       : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {}
225 
226   LogicalResult matchAndRewrite(Operation *op,
227                                 PatternRewriter &rewriter) const override {
228     // Do not hoist past functions.
229     if (isa<FunctionOpInterface>(op->getParentOp()))
230       return failure();
231 
232     int64_t moveForwardBy = 0;
233     if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance"))
234       moveForwardBy = advanceBy.getInt();
235 
236     Operation *moveAfter = op->getParentOp();
237     for (int64_t i = 0; i < moveForwardBy; ++i)
238       moveAfter = moveAfter->getNextNode();
239 
240     rewriter.moveOpAfter(op, moveAfter);
241     return success();
242   }
243 };
244 
245 /// This pattern inlines blocks that are nested in
246 /// "test.inline_blocks_into_parent" into the parent block.
247 struct InlineBlocksIntoParent : public RewritePattern {
248   InlineBlocksIntoParent(MLIRContext *context)
249       : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
250                        context) {}
251 
252   LogicalResult matchAndRewrite(Operation *op,
253                                 PatternRewriter &rewriter) const override {
254     bool changed = false;
255     for (Region &r : op->getRegions()) {
256       while (!r.empty()) {
257         rewriter.inlineBlockBefore(&r.front(), op);
258         changed = true;
259       }
260     }
261     return success(changed);
262   }
263 };
264 
265 /// This pattern splits blocks at "test.split_block_here" and replaces the op
266 /// with a new op (to prevent an infinite loop of block splitting).
267 struct SplitBlockHere : public RewritePattern {
268   SplitBlockHere(MLIRContext *context)
269       : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
270 
271   LogicalResult matchAndRewrite(Operation *op,
272                                 PatternRewriter &rewriter) const override {
273     rewriter.splitBlock(op->getBlock(), op->getIterator());
274     Operation *newOp = rewriter.create(
275         op->getLoc(),
276         OperationName("test.new_op", op->getContext()).getIdentifier(),
277         op->getOperands(), op->getResultTypes());
278     rewriter.replaceOp(op, newOp);
279     return success();
280   }
281 };
282 
283 /// This pattern clones "test.clone_me" ops.
284 struct CloneOp : public RewritePattern {
285   CloneOp(MLIRContext *context)
286       : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
287 
288   LogicalResult matchAndRewrite(Operation *op,
289                                 PatternRewriter &rewriter) const override {
290     // Do not clone already cloned ops to avoid going into an infinite loop.
291     if (op->hasAttr("was_cloned"))
292       return failure();
293     Operation *cloned = rewriter.clone(*op);
294     cloned->setAttr("was_cloned", rewriter.getUnitAttr());
295     return success();
296   }
297 };
298 
299 /// This pattern clones regions of "test.clone_region_before" ops before the
300 /// parent block.
301 struct CloneRegionBeforeOp : public RewritePattern {
302   CloneRegionBeforeOp(MLIRContext *context)
303       : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
304 
305   LogicalResult matchAndRewrite(Operation *op,
306                                 PatternRewriter &rewriter) const override {
307     // Do not clone already cloned ops to avoid going into an infinite loop.
308     if (op->hasAttr("was_cloned"))
309       return failure();
310     for (Region &r : op->getRegions())
311       rewriter.cloneRegionBefore(r, op->getBlock());
312     op->setAttr("was_cloned", rewriter.getUnitAttr());
313     return success();
314   }
315 };
316 
317 /// Replace an operation may introduce the re-visiting of its users.
318 class ReplaceWithNewOp : public RewritePattern {
319 public:
320   ReplaceWithNewOp(MLIRContext *context)
321       : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
322 
323   LogicalResult matchAndRewrite(Operation *op,
324                                 PatternRewriter &rewriter) const override {
325     Operation *newOp;
326     if (op->hasAttr("create_erase_op")) {
327       newOp = rewriter.create(
328           op->getLoc(),
329           OperationName("test.erase_op", op->getContext()).getIdentifier(),
330           ValueRange(), TypeRange());
331     } else {
332       newOp = rewriter.create(
333           op->getLoc(),
334           OperationName("test.new_op", op->getContext()).getIdentifier(),
335           op->getOperands(), op->getResultTypes());
336     }
337     // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
338     // A "notifyOperationReplaced" callback is triggered in either case.
339     rewriter.replaceAllOpUsesWith(op, newOp->getResults());
340     rewriter.eraseOp(op);
341     return success();
342   }
343 };
344 
345 /// Erases the first child block of the matched "test.erase_first_block"
346 /// operation.
347 class EraseFirstBlock : public RewritePattern {
348 public:
349   EraseFirstBlock(MLIRContext *context)
350       : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {}
351 
352   LogicalResult matchAndRewrite(Operation *op,
353                                 PatternRewriter &rewriter) const override {
354     for (Region &r : op->getRegions()) {
355       for (Block &b : r.getBlocks()) {
356         rewriter.eraseBlock(&b);
357         return success();
358       }
359     }
360 
361     return failure();
362   }
363 };
364 
365 struct TestGreedyPatternDriver
366     : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
367   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
368 
369   TestGreedyPatternDriver() = default;
370   TestGreedyPatternDriver(const TestGreedyPatternDriver &other)
371       : PassWrapper(other) {}
372 
373   StringRef getArgument() const final { return "test-greedy-patterns"; }
374   StringRef getDescription() const final { return "Run test dialect patterns"; }
375   void runOnOperation() override {
376     mlir::RewritePatternSet patterns(&getContext());
377     populateWithGenerated(patterns);
378 
379     // Verify named pattern is generated with expected name.
380     patterns.add<FoldingPattern, TestNamedPatternRule,
381                  FolderInsertBeforePreviouslyFoldedConstantPattern,
382                  FolderCommutativeOp2WithConstant, HoistEligibleOps,
383                  MakeOpEligible>(&getContext());
384 
385     // Additional patterns for testing the GreedyPatternRewriteDriver.
386     patterns.insert<IncrementIntAttribute<3>>(&getContext());
387 
388     GreedyRewriteConfig config;
389     config.useTopDownTraversal = this->useTopDownTraversal;
390     config.maxIterations = this->maxIterations;
391     config.fold = this->fold;
392     config.cseConstants = this->cseConstants;
393     (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
394   }
395 
396   Option<bool> useTopDownTraversal{
397       *this, "top-down",
398       llvm::cl::desc("Seed the worklist in general top-down order"),
399       llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
400   Option<int> maxIterations{
401       *this, "max-iterations",
402       llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
403       llvm::cl::init(GreedyRewriteConfig().maxIterations)};
404   Option<bool> fold{*this, "fold", llvm::cl::desc("Whether to fold"),
405                     llvm::cl::init(GreedyRewriteConfig().fold)};
406   Option<bool> cseConstants{*this, "cse-constants",
407                             llvm::cl::desc("Whether to CSE constants"),
408                             llvm::cl::init(GreedyRewriteConfig().cseConstants)};
409 };
410 
411 struct DumpNotifications : public RewriterBase::Listener {
412   void notifyBlockInserted(Block *block, Region *previous,
413                            Region::iterator previousIt) override {
414     llvm::outs() << "notifyBlockInserted";
415     if (block->getParentOp()) {
416       llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
417     } else {
418       llvm::outs() << " into unknown op: ";
419     }
420     if (previous == nullptr) {
421       llvm::outs() << "was unlinked\n";
422     } else {
423       llvm::outs() << "was linked\n";
424     }
425   }
426   void notifyOperationInserted(Operation *op,
427                                OpBuilder::InsertPoint previous) override {
428     llvm::outs() << "notifyOperationInserted: " << op->getName();
429     if (!previous.isSet()) {
430       llvm::outs() << ", was unlinked\n";
431     } else {
432       if (!previous.getPoint().getNodePtr()) {
433         llvm::outs() << ", was linked, exact position unknown\n";
434       } else if (previous.getPoint() == previous.getBlock()->end()) {
435         llvm::outs() << ", was last in block\n";
436       } else {
437         llvm::outs() << ", previous = " << previous.getPoint()->getName()
438                      << "\n";
439       }
440     }
441   }
442   void notifyBlockErased(Block *block) override {
443     llvm::outs() << "notifyBlockErased\n";
444   }
445   void notifyOperationErased(Operation *op) override {
446     llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
447   }
448   void notifyOperationModified(Operation *op) override {
449     llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
450   }
451   void notifyOperationReplaced(Operation *op, ValueRange values) override {
452     llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
453   }
454 };
455 
456 struct TestStrictPatternDriver
457     : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
458 public:
459   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
460 
461   TestStrictPatternDriver() = default;
462   TestStrictPatternDriver(const TestStrictPatternDriver &other)
463       : PassWrapper(other) {
464     strictMode = other.strictMode;
465   }
466 
467   StringRef getArgument() const final { return "test-strict-pattern-driver"; }
468   StringRef getDescription() const final {
469     return "Test strict mode of pattern driver";
470   }
471 
472   void runOnOperation() override {
473     MLIRContext *ctx = &getContext();
474     mlir::RewritePatternSet patterns(ctx);
475     patterns.add<
476         // clang-format off
477         ChangeBlockOp,
478         CloneOp,
479         CloneRegionBeforeOp,
480         EraseOp,
481         ImplicitChangeOp,
482         InlineBlocksIntoParent,
483         InsertSameOp,
484         MoveBeforeParentOp,
485         ReplaceWithNewOp,
486         SplitBlockHere
487         // clang-format on
488         >(ctx);
489     SmallVector<Operation *> ops;
490     getOperation()->walk([&](Operation *op) {
491       StringRef opName = op->getName().getStringRef();
492       if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
493           opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
494           opName == "test.move_before_parent_op" ||
495           opName == "test.inline_blocks_into_parent" ||
496           opName == "test.split_block_here" || opName == "test.clone_me" ||
497           opName == "test.clone_region_before") {
498         ops.push_back(op);
499       }
500     });
501 
502     DumpNotifications dumpNotifications;
503     GreedyRewriteConfig config;
504     config.listener = &dumpNotifications;
505     if (strictMode == "AnyOp") {
506       config.strictMode = GreedyRewriteStrictness::AnyOp;
507     } else if (strictMode == "ExistingAndNewOps") {
508       config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
509     } else if (strictMode == "ExistingOps") {
510       config.strictMode = GreedyRewriteStrictness::ExistingOps;
511     } else {
512       llvm_unreachable("invalid strictness option");
513     }
514 
515     // Check if these transformations introduce visiting of operations that
516     // are not in the `ops` set (The new created ops are valid). An invalid
517     // operation will trigger the assertion while processing.
518     bool changed = false;
519     bool allErased = false;
520     (void)applyOpPatternsGreedily(ArrayRef(ops), std::move(patterns), config,
521                                   &changed, &allErased);
522     Builder b(ctx);
523     getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
524     getOperation()->setAttr("pattern_driver_all_erased",
525                             b.getBoolAttr(allErased));
526   }
527 
528   Option<std::string> strictMode{
529       *this, "strictness",
530       llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"),
531       llvm::cl::init("AnyOp")};
532 
533 private:
534   // New inserted operation is valid for further transformation.
535   class InsertSameOp : public RewritePattern {
536   public:
537     InsertSameOp(MLIRContext *context)
538         : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
539 
540     LogicalResult matchAndRewrite(Operation *op,
541                                   PatternRewriter &rewriter) const override {
542       if (op->hasAttr("skip"))
543         return failure();
544 
545       Operation *newOp =
546           rewriter.create(op->getLoc(), op->getName().getIdentifier(),
547                           op->getOperands(), op->getResultTypes());
548       rewriter.modifyOpInPlace(
549           op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); });
550       newOp->setAttr("skip", rewriter.getBoolAttr(true));
551 
552       return success();
553     }
554   };
555 
556   // Remove an operation may introduce the re-visiting of its operands.
557   class EraseOp : public RewritePattern {
558   public:
559     EraseOp(MLIRContext *context)
560         : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
561     LogicalResult matchAndRewrite(Operation *op,
562                                   PatternRewriter &rewriter) const override {
563       rewriter.eraseOp(op);
564       return success();
565     }
566   };
567 
568   // The following two patterns test RewriterBase::replaceAllUsesWith.
569   //
570   // That function replaces all usages of a Block (or a Value) with another one
571   // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
572   // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
573   // worklist: when an op is modified, it is added to the worklist. The two
574   // patterns below make the tracking observable: ChangeBlockOp replaces all
575   // usages of a block and that pattern is applied because the corresponding ops
576   // are put on the initial worklist (see above). ImplicitChangeOp does an
577   // unrelated change but ops of the corresponding type are *not* on the initial
578   // worklist, so the effect of the second pattern is only visible if the
579   // tracking and subsequent adding to the worklist actually works.
580 
581   // Replace all usages of the first successor with the second successor.
582   class ChangeBlockOp : public RewritePattern {
583   public:
584     ChangeBlockOp(MLIRContext *context)
585         : RewritePattern("test.change_block_op", /*benefit=*/1, context) {}
586     LogicalResult matchAndRewrite(Operation *op,
587                                   PatternRewriter &rewriter) const override {
588       if (op->getNumSuccessors() < 2)
589         return failure();
590       Block *firstSuccessor = op->getSuccessor(0);
591       Block *secondSuccessor = op->getSuccessor(1);
592       if (firstSuccessor == secondSuccessor)
593         return failure();
594       // This is the function being tested:
595       rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor);
596       // Using the following line instead would make the test fail:
597       // firstSuccessor->replaceAllUsesWith(secondSuccessor);
598       return success();
599     }
600   };
601 
602   // Changes the successor to the parent block.
603   class ImplicitChangeOp : public RewritePattern {
604   public:
605     ImplicitChangeOp(MLIRContext *context)
606         : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {}
607     LogicalResult matchAndRewrite(Operation *op,
608                                   PatternRewriter &rewriter) const override {
609       if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock())
610         return failure();
611       rewriter.modifyOpInPlace(op,
612                                [&]() { op->setSuccessor(op->getBlock(), 0); });
613       return success();
614     }
615   };
616 };
617 
618 struct TestWalkPatternDriver final
619     : PassWrapper<TestWalkPatternDriver, OperationPass<>> {
620   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver)
621 
622   TestWalkPatternDriver() = default;
623   TestWalkPatternDriver(const TestWalkPatternDriver &other)
624       : PassWrapper(other) {}
625 
626   StringRef getArgument() const override {
627     return "test-walk-pattern-rewrite-driver";
628   }
629   StringRef getDescription() const override {
630     return "Run test walk pattern rewrite driver";
631   }
632   void runOnOperation() override {
633     mlir::RewritePatternSet patterns(&getContext());
634 
635     // Patterns for testing the WalkPatternRewriteDriver.
636     patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
637                  MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>(
638         &getContext());
639 
640     DumpNotifications dumpListener;
641     walkAndApplyPatterns(getOperation(), std::move(patterns),
642                          dumpNotifications ? &dumpListener : nullptr);
643   }
644 
645   Option<bool> dumpNotifications{
646       *this, "dump-notifications",
647       llvm::cl::desc("Print rewrite listener notifications"),
648       llvm::cl::init(false)};
649 };
650 
651 } // namespace
652 
653 //===----------------------------------------------------------------------===//
654 // ReturnType Driver.
655 //===----------------------------------------------------------------------===//
656 
657 namespace {
658 // Generate ops for each instance where the type can be successfully inferred.
659 template <typename OpTy>
660 static void invokeCreateWithInferredReturnType(Operation *op) {
661   auto *context = op->getContext();
662   auto fop = op->getParentOfType<func::FuncOp>();
663   auto location = UnknownLoc::get(context);
664   OpBuilder b(op);
665   b.setInsertionPointAfter(op);
666 
667   // Use permutations of 2 args as operands.
668   assert(fop.getNumArguments() >= 2);
669   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
670     for (int j = 0; j < e; ++j) {
671       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
672       SmallVector<Type, 2> inferredReturnTypes;
673       if (succeeded(OpTy::inferReturnTypes(
674               context, std::nullopt, values, op->getDiscardableAttrDictionary(),
675               op->getPropertiesStorage(), op->getRegions(),
676               inferredReturnTypes))) {
677         OperationState state(location, OpTy::getOperationName());
678         // TODO: Expand to regions.
679         OpTy::build(b, state, values, op->getAttrs());
680         (void)b.create(state);
681       }
682     }
683   }
684 }
685 
686 static void reifyReturnShape(Operation *op) {
687   OpBuilder b(op);
688 
689   // Use permutations of 2 args as operands.
690   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
691   SmallVector<Value, 2> shapes;
692   if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
693       !llvm::hasSingleElement(shapes))
694     return;
695   for (const auto &it : llvm::enumerate(shapes)) {
696     op->emitRemark() << "value " << it.index() << ": "
697                      << it.value().getDefiningOp();
698   }
699 }
700 
701 struct TestReturnTypeDriver
702     : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> {
703   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver)
704 
705   void getDependentDialects(DialectRegistry &registry) const override {
706     registry.insert<tensor::TensorDialect>();
707   }
708   StringRef getArgument() const final { return "test-return-type"; }
709   StringRef getDescription() const final { return "Run return type functions"; }
710 
711   void runOnOperation() override {
712     if (getOperation().getName() == "testCreateFunctions") {
713       std::vector<Operation *> ops;
714       // Collect ops to avoid triggering on inserted ops.
715       for (auto &op : getOperation().getBody().front())
716         ops.push_back(&op);
717       // Generate test patterns for each, but skip terminator.
718       for (auto *op : llvm::ArrayRef(ops).drop_back()) {
719         // Test create method of each of the Op classes below. The resultant
720         // output would be in reverse order underneath `op` from which
721         // the attributes and regions are used.
722         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
723         invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>(
724             op);
725         invokeCreateWithInferredReturnType<
726             OpWithShapedTypeInferTypeInterfaceOp>(op);
727       };
728       return;
729     }
730     if (getOperation().getName() == "testReifyFunctions") {
731       std::vector<Operation *> ops;
732       // Collect ops to avoid triggering on inserted ops.
733       for (auto &op : getOperation().getBody().front())
734         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
735           ops.push_back(&op);
736       // Generate test patterns for each, but skip terminator.
737       for (auto *op : ops)
738         reifyReturnShape(op);
739     }
740   }
741 };
742 } // namespace
743 
744 namespace {
/// Function pass registered under the "test-derived-attr" argument. Its
/// `runOnOperation` is defined out of line.
struct TestDerivedAttributeDriver
    : public PassWrapper<TestDerivedAttributeDriver,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver)

  StringRef getArgument() const final { return "test-derived-attr"; }
  StringRef getDescription() const final {
    return "Run test derived attributes";
  }
  void runOnOperation() override;
};
756 } // namespace
757 
758 void TestDerivedAttributeDriver::runOnOperation() {
759   getOperation().walk([](DerivedAttributeOpInterface dOp) {
760     auto dAttr = dOp.materializeDerivedAttributes();
761     if (!dAttr)
762       return;
763     for (auto d : dAttr)
764       dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
765   });
766 }
767 
768 //===----------------------------------------------------------------------===//
769 // Legalization Driver.
770 //===----------------------------------------------------------------------===//
771 
772 namespace {
773 //===----------------------------------------------------------------------===//
774 // Region-Block Rewrite Testing
775 
776 /// This pattern applies a signature conversion to a block inside a detached
777 /// region.
778 struct TestDetachedSignatureConversion : public ConversionPattern {
779   TestDetachedSignatureConversion(MLIRContext *ctx)
780       : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1,
781                           ctx) {}
782 
783   LogicalResult
784   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
785                   ConversionPatternRewriter &rewriter) const final {
786     if (op->getNumRegions() != 1)
787       return failure();
788     OperationState state(op->getLoc(), "test.legal_op_with_region", operands,
789                          op->getResultTypes(), {}, BlockRange());
790     Region *newRegion = state.addRegion();
791     rewriter.inlineRegionBefore(op->getRegion(0), *newRegion,
792                                 newRegion->begin());
793     TypeConverter::SignatureConversion result(newRegion->getNumArguments());
794     for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i)
795       result.addInputs(i, rewriter.getF64Type());
796     rewriter.applySignatureConversion(&newRegion->front(), result);
797     Operation *newOp = rewriter.create(state);
798     rewriter.replaceOp(op, newOp->getResults());
799     return success();
800   }
801 };
802 
803 /// This pattern is a simple pattern that inlines the first region of a given
804 /// operation into the parent region.
805 struct TestRegionRewriteBlockMovement : public ConversionPattern {
806   TestRegionRewriteBlockMovement(MLIRContext *ctx)
807       : ConversionPattern("test.region", 1, ctx) {}
808 
809   LogicalResult
810   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
811                   ConversionPatternRewriter &rewriter) const final {
812     // Inline this region into the parent region.
813     auto &parentRegion = *op->getParentRegion();
814     auto &opRegion = op->getRegion(0);
815     if (op->getDiscardableAttr("legalizer.should_clone"))
816       rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
817     else
818       rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
819 
820     if (op->getDiscardableAttr("legalizer.erase_old_blocks")) {
821       while (!opRegion.empty())
822         rewriter.eraseBlock(&opRegion.front());
823     }
824 
825     // Drop this operation.
826     rewriter.eraseOp(op);
827     return success();
828   }
829 };
830 /// This pattern is a simple pattern that generates a region containing an
831 /// illegal operation.
832 struct TestRegionRewriteUndo : public RewritePattern {
833   TestRegionRewriteUndo(MLIRContext *ctx)
834       : RewritePattern("test.region_builder", 1, ctx) {}
835 
836   LogicalResult matchAndRewrite(Operation *op,
837                                 PatternRewriter &rewriter) const final {
838     // Create the region operation with an entry block containing arguments.
839     OperationState newRegion(op->getLoc(), "test.region");
840     newRegion.addRegion();
841     auto *regionOp = rewriter.create(newRegion);
842     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
843     entryBlock->addArgument(rewriter.getIntegerType(64),
844                             rewriter.getUnknownLoc());
845 
846     // Add an explicitly illegal operation to ensure the conversion fails.
847     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
848     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
849 
850     // Drop this operation.
851     rewriter.eraseOp(op);
852     return success();
853   }
854 };
855 /// A simple pattern that creates a block at the end of the parent region of the
856 /// matched operation.
857 struct TestCreateBlock : public RewritePattern {
858   TestCreateBlock(MLIRContext *ctx)
859       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
860 
861   LogicalResult matchAndRewrite(Operation *op,
862                                 PatternRewriter &rewriter) const final {
863     Region &region = *op->getParentRegion();
864     Type i32Type = rewriter.getIntegerType(32);
865     Location loc = op->getLoc();
866     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
867     rewriter.create<TerminatorOp>(loc);
868     rewriter.eraseOp(op);
869     return success();
870   }
871 };
872 
873 /// A simple pattern that creates a block containing an invalid operation in
874 /// order to trigger the block creation undo mechanism.
875 struct TestCreateIllegalBlock : public RewritePattern {
876   TestCreateIllegalBlock(MLIRContext *ctx)
877       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
878 
879   LogicalResult matchAndRewrite(Operation *op,
880                                 PatternRewriter &rewriter) const final {
881     Region &region = *op->getParentRegion();
882     Type i32Type = rewriter.getIntegerType(32);
883     Location loc = op->getLoc();
884     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
885     // Create an illegal op to ensure the conversion fails.
886     rewriter.create<ILLegalOpF>(loc, i32Type);
887     rewriter.create<TerminatorOp>(loc);
888     rewriter.eraseOp(op);
889     return success();
890   }
891 };
892 
893 /// A simple pattern that tests the undo mechanism when replacing the uses of a
894 /// block argument.
895 struct TestUndoBlockArgReplace : public ConversionPattern {
896   TestUndoBlockArgReplace(MLIRContext *ctx)
897       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
898 
899   LogicalResult
900   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
901                   ConversionPatternRewriter &rewriter) const final {
902     auto illegalOp =
903         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
904     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
905                                         illegalOp->getResult(0));
906     rewriter.modifyOpInPlace(op, [] {});
907     return success();
908   }
909 };
910 
911 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
912 /// This is to test the rollback logic.
913 struct TestUndoMoveOpBefore : public ConversionPattern {
914   TestUndoMoveOpBefore(MLIRContext *ctx)
915       : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
916 
917   LogicalResult
918   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
919                   ConversionPatternRewriter &rewriter) const override {
920     rewriter.moveOpBefore(op, op->getParentOp());
921     // Replace with an illegal op to ensure the conversion fails.
922     rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
923     return success();
924   }
925 };
926 
927 /// A rewrite pattern that tests the undo mechanism when erasing a block.
928 struct TestUndoBlockErase : public ConversionPattern {
929   TestUndoBlockErase(MLIRContext *ctx)
930       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
931 
932   LogicalResult
933   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
934                   ConversionPatternRewriter &rewriter) const final {
935     Block *secondBlock = &*std::next(op->getRegion(0).begin());
936     rewriter.setInsertionPointToStart(secondBlock);
937     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
938     rewriter.eraseBlock(secondBlock);
939     rewriter.modifyOpInPlace(op, [] {});
940     return success();
941   }
942 };
943 
944 /// A pattern that modifies a property in-place, but keeps the op illegal.
945 struct TestUndoPropertiesModification : public ConversionPattern {
946   TestUndoPropertiesModification(MLIRContext *ctx)
947       : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {}
948   LogicalResult
949   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
950                   ConversionPatternRewriter &rewriter) const final {
951     if (!op->hasAttr("modify_inplace"))
952       return failure();
953     rewriter.modifyOpInPlace(
954         op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); });
955     return success();
956   }
957 };
958 
959 //===----------------------------------------------------------------------===//
960 // Type-Conversion Rewrite Testing
961 
962 /// This patterns erases a region operation that has had a type conversion.
963 struct TestDropOpSignatureConversion : public ConversionPattern {
964   TestDropOpSignatureConversion(MLIRContext *ctx,
965                                 const TypeConverter &converter)
966       : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
967   LogicalResult
968   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
969                   ConversionPatternRewriter &rewriter) const override {
970     Region &region = op->getRegion(0);
971     Block *entry = &region.front();
972 
973     // Convert the original entry arguments.
974     const TypeConverter &converter = *getTypeConverter();
975     TypeConverter::SignatureConversion result(entry->getNumArguments());
976     if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
977                                               result)) ||
978         failed(rewriter.convertRegionTypes(&region, converter, &result)))
979       return failure();
980 
981     // Convert the region signature and just drop the operation.
982     rewriter.eraseOp(op);
983     return success();
984   }
985 };
986 /// This pattern simply updates the operands of the given operation.
987 struct TestPassthroughInvalidOp : public ConversionPattern {
988   TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
989       : ConversionPattern(converter, "test.invalid", 1, ctx) {}
990   LogicalResult
991   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
992                   ConversionPatternRewriter &rewriter) const final {
993     SmallVector<Value> flattened;
994     for (auto it : llvm::enumerate(operands)) {
995       ValueRange range = it.value();
996       if (range.size() == 1) {
997         flattened.push_back(range.front());
998         continue;
999       }
1000 
1001       // This is a 1:N replacement. Insert a test.cast op. (That's what the
1002       // argument materialization used to do.)
1003       flattened.push_back(
1004           rewriter
1005               .create<TestCastOp>(op->getLoc(),
1006                                   op->getOperand(it.index()).getType(), range)
1007               .getResult());
1008     }
1009     rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened,
1010                                              std::nullopt);
1011     return success();
1012   }
1013 };
1014 /// Replace with valid op, but simply drop the operands. This is used in a
1015 /// regression where we used to generate circular unrealized_conversion_cast
1016 /// ops.
1017 struct TestDropAndReplaceInvalidOp : public ConversionPattern {
1018   TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
1019       : ConversionPattern(converter,
1020                           "test.drop_operands_and_replace_with_valid", 1, ctx) {
1021   }
1022   LogicalResult
1023   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1024                   ConversionPatternRewriter &rewriter) const final {
1025     rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(),
1026                                              std::nullopt);
1027     return success();
1028   }
1029 };
1030 /// This pattern handles the case of a split return value.
1031 struct TestSplitReturnType : public ConversionPattern {
1032   TestSplitReturnType(MLIRContext *ctx)
1033       : ConversionPattern("test.return", 1, ctx) {}
1034   LogicalResult
1035   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1036                   ConversionPatternRewriter &rewriter) const final {
1037     // Check for a return of F32.
1038     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
1039       return failure();
1040     rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]);
1041     return success();
1042   }
1043 };
1044 
1045 //===----------------------------------------------------------------------===//
1046 // Multi-Level Type-Conversion Rewrite Testing
1047 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
1048   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
1049       : ConversionPattern("test.type_producer", 1, ctx) {}
1050   LogicalResult
1051   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1052                   ConversionPatternRewriter &rewriter) const final {
1053     // If the type is I32, change the type to F32.
1054     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
1055       return failure();
1056     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
1057     return success();
1058   }
1059 };
1060 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
1061   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
1062       : ConversionPattern("test.type_producer", 1, ctx) {}
1063   LogicalResult
1064   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1065                   ConversionPatternRewriter &rewriter) const final {
1066     // If the type is F32, change the type to F64.
1067     if (!Type(*op->result_type_begin()).isF32())
1068       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
1069     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
1070     return success();
1071   }
1072 };
1073 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
1074   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
1075       : ConversionPattern("test.type_producer", 10, ctx) {}
1076   LogicalResult
1077   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1078                   ConversionPatternRewriter &rewriter) const final {
1079     // Always convert to B16, even though it is not a legal type. This tests
1080     // that values are unmapped correctly.
1081     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
1082     return success();
1083   }
1084 };
1085 struct TestUpdateConsumerType : public ConversionPattern {
1086   TestUpdateConsumerType(MLIRContext *ctx)
1087       : ConversionPattern("test.type_consumer", 1, ctx) {}
1088   LogicalResult
1089   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1090                   ConversionPatternRewriter &rewriter) const final {
1091     // Verify that the incoming operand has been successfully remapped to F64.
1092     if (!operands[0].getType().isF64())
1093       return failure();
1094     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
1095     return success();
1096   }
1097 };
1098 
1099 //===----------------------------------------------------------------------===//
1100 // Non-Root Replacement Rewrite Testing
1101 /// This pattern generates an invalid operation, but replaces it before the
1102 /// pattern is finished. This checks that we don't need to legalize the
1103 /// temporary op.
1104 struct TestNonRootReplacement : public RewritePattern {
1105   TestNonRootReplacement(MLIRContext *ctx)
1106       : RewritePattern("test.replace_non_root", 1, ctx) {}
1107 
1108   LogicalResult matchAndRewrite(Operation *op,
1109                                 PatternRewriter &rewriter) const final {
1110     auto resultType = *op->result_type_begin();
1111     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
1112     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
1113 
1114     rewriter.replaceOp(illegalOp, legalOp);
1115     rewriter.replaceOp(op, illegalOp);
1116     return success();
1117   }
1118 };
1119 
1120 //===----------------------------------------------------------------------===//
1121 // Recursive Rewrite Testing
1122 /// This pattern is applied to the same operation multiple times, but has a
1123 /// bounded recursion.
1124 struct TestBoundedRecursiveRewrite
1125     : public OpRewritePattern<TestRecursiveRewriteOp> {
1126   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
1127 
1128   void initialize() {
1129     // The conversion target handles bounding the recursion of this pattern.
1130     setHasBoundedRewriteRecursion();
1131   }
1132 
1133   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
1134                                 PatternRewriter &rewriter) const final {
1135     // Decrement the depth of the op in-place.
1136     rewriter.modifyOpInPlace(op, [&] {
1137       op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
1138     });
1139     return success();
1140   }
1141 };
1142 
1143 struct TestNestedOpCreationUndoRewrite
1144     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
1145   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
1146 
1147   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
1148                                 PatternRewriter &rewriter) const final {
1149     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1150     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1151     return success();
1152   };
1153 };
1154 
1155 // This pattern matches `test.blackhole` and delete this op and its producer.
1156 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
1157   using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
1158 
1159   LogicalResult matchAndRewrite(BlackHoleOp op,
1160                                 PatternRewriter &rewriter) const final {
1161     Operation *producer = op.getOperand().getDefiningOp();
1162     // Always erase the user before the producer, the framework should handle
1163     // this correctly.
1164     rewriter.eraseOp(op);
1165     rewriter.eraseOp(producer);
1166     return success();
1167   };
1168 };
1169 
1170 // This pattern replaces explicitly illegal op with explicitly legal op,
1171 // but in addition creates unregistered operation.
1172 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
1173   using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
1174 
1175   LogicalResult matchAndRewrite(ILLegalOpG op,
1176                                 PatternRewriter &rewriter) const final {
1177     IntegerAttr attr = rewriter.getI32IntegerAttr(0);
1178     Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
1179     rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
1180     return success();
1181   };
1182 };
1183 
1184 class TestEraseOp : public ConversionPattern {
1185 public:
1186   TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {}
1187   LogicalResult
1188   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1189                   ConversionPatternRewriter &rewriter) const final {
1190     // Erase op without replacements.
1191     rewriter.eraseOp(op);
1192     return success();
1193   }
1194 };
1195 
1196 /// This pattern matches a test.duplicate_block_args op and duplicates all
1197 /// block arguments.
1198 class TestDuplicateBlockArgs
1199     : public OpConversionPattern<DuplicateBlockArgsOp> {
1200   using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern;
1201 
1202   LogicalResult
1203   matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor,
1204                   ConversionPatternRewriter &rewriter) const override {
1205     if (op.getIsLegal())
1206       return failure();
1207     rewriter.startOpModification(op);
1208     Block *body = &op.getBody().front();
1209     TypeConverter::SignatureConversion result(body->getNumArguments());
1210     for (auto it : llvm::enumerate(body->getArgumentTypes()))
1211       result.addInputs(it.index(), {it.value(), it.value()});
1212     rewriter.applySignatureConversion(body, result, getTypeConverter());
1213     op.setIsLegal(true);
1214     rewriter.finalizeOpModification(op);
1215     return success();
1216   }
1217 };
1218 
1219 /// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid
1220 /// op. The pattern supports 1:N replacements and forwards the replacement
1221 /// values of the single operand as test.valid operands.
1222 class TestRepetitive1ToNConsumer : public ConversionPattern {
1223 public:
1224   TestRepetitive1ToNConsumer(MLIRContext *ctx)
1225       : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {}
1226   LogicalResult
1227   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1228                   ConversionPatternRewriter &rewriter) const final {
1229     // A single operand is expected.
1230     if (op->getNumOperands() != 1)
1231       return failure();
1232     rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front());
1233     return success();
1234   }
1235 };
1236 
1237 } // namespace
1238 
1239 namespace {
1240 struct TestTypeConverter : public TypeConverter {
1241   using TypeConverter::TypeConverter;
1242   TestTypeConverter() {
1243     addConversion(convertType);
1244     addArgumentMaterialization(materializeCast);
1245     addSourceMaterialization(materializeCast);
1246   }
1247 
1248   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
1249     // Drop I16 types.
1250     if (t.isSignlessInteger(16))
1251       return success();
1252 
1253     // Convert I64 to F64.
1254     if (t.isSignlessInteger(64)) {
1255       results.push_back(FloatType::getF64(t.getContext()));
1256       return success();
1257     }
1258 
1259     // Convert I42 to I43.
1260     if (t.isInteger(42)) {
1261       results.push_back(IntegerType::get(t.getContext(), 43));
1262       return success();
1263     }
1264 
1265     // Split F32 into F16,F16.
1266     if (t.isF32()) {
1267       results.assign(2, FloatType::getF16(t.getContext()));
1268       return success();
1269     }
1270 
1271     // Drop I24 types.
1272     if (t.isInteger(24)) {
1273       return success();
1274     }
1275 
1276     // Otherwise, convert the type directly.
1277     results.push_back(t);
1278     return success();
1279   }
1280 
1281   /// Hook for materializing a conversion. This is necessary because we generate
1282   /// 1->N type mappings.
1283   static Value materializeCast(OpBuilder &builder, Type resultType,
1284                                ValueRange inputs, Location loc) {
1285     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1286   }
1287 };
1288 
1289 struct TestLegalizePatternDriver
1290     : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> {
1291   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)
1292 
1293   StringRef getArgument() const final { return "test-legalize-patterns"; }
1294   StringRef getDescription() const final {
1295     return "Run test dialect legalization patterns";
1296   }
1297   /// The mode of conversion to use with the driver.
1298   enum class ConversionMode { Analysis, Full, Partial };
1299 
1300   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
1301 
1302   void getDependentDialects(DialectRegistry &registry) const override {
1303     registry.insert<func::FuncDialect, test::TestDialect>();
1304   }
1305 
1306   void runOnOperation() override {
1307     TestTypeConverter converter;
1308     mlir::RewritePatternSet patterns(&getContext());
1309     populateWithGenerated(patterns);
1310     patterns
1311         .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1312              TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1313              TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
1314              TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
1315              TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
1316              TestNonRootReplacement, TestBoundedRecursiveRewrite,
1317              TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
1318              TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1319              TestUndoPropertiesModification, TestEraseOp,
1320              TestRepetitive1ToNConsumer>(&getContext());
1321     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
1322                  TestPassthroughInvalidOp>(&getContext(), converter);
1323     patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
1324     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1325                                                               converter);
1326     mlir::populateCallOpTypeConversionPattern(patterns, converter);
1327 
1328     // Define the conversion target used for the test.
1329     ConversionTarget target(getContext());
1330     target.addLegalOp<ModuleOp>();
1331     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
1332                       TerminatorOp, OneRegionOp>();
1333     target.addLegalOp(
1334         OperationName("test.legal_op_with_region", &getContext()));
1335     target
1336         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
1337     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
1338       // Don't allow F32 operands.
1339       return llvm::none_of(op.getOperandTypes(),
1340                            [](Type type) { return type.isF32(); });
1341     });
1342     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1343       return converter.isSignatureLegal(op.getFunctionType()) &&
1344              converter.isLegal(&op.getBody());
1345     });
1346     target.addDynamicallyLegalOp<func::CallOp>(
1347         [&](func::CallOp op) { return converter.isLegal(op); });
1348 
1349     // TestCreateUnregisteredOp creates `arith.constant` operation,
1350     // which was not added to target intentionally to test
1351     // correct error code from conversion driver.
1352     target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });
1353 
1354     // Expect the type_producer/type_consumer operations to only operate on f64.
1355     target.addDynamicallyLegalOp<TestTypeProducerOp>(
1356         [](TestTypeProducerOp op) { return op.getType().isF64(); });
1357     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1358       return op.getOperand().getType().isF64();
1359     });
1360 
1361     // Check support for marking certain operations as recursively legal.
1362     target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) {
1363       return static_cast<bool>(
1364           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
1365     });
1366 
1367     // Mark the bound recursion operation as dynamically legal.
1368     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
1369         [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
1370 
1371     // Create a dynamically legal rule that can only be legalized by folding it.
1372     target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
1373         [](TestOpInPlaceSelfFold op) { return op.getFolded(); });
1374 
1375     target.addDynamicallyLegalOp<DuplicateBlockArgsOp>(
1376         [](DuplicateBlockArgsOp op) { return op.getIsLegal(); });
1377 
1378     // Handle a partial conversion.
1379     if (mode == ConversionMode::Partial) {
1380       DenseSet<Operation *> unlegalizedOps;
1381       ConversionConfig config;
1382       DumpNotifications dumpNotifications;
1383       config.listener = &dumpNotifications;
1384       config.unlegalizedOps = &unlegalizedOps;
1385       if (failed(applyPartialConversion(getOperation(), target,
1386                                         std::move(patterns), config))) {
1387         getOperation()->emitRemark() << "applyPartialConversion failed";
1388       }
1389       // Emit remarks for each legalizable operation.
1390       for (auto *op : unlegalizedOps)
1391         op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1392       return;
1393     }
1394 
1395     // Handle a full conversion.
1396     if (mode == ConversionMode::Full) {
1397       // Check support for marking unknown operations as dynamically legal.
1398       target.markUnknownOpDynamicallyLegal([](Operation *op) {
1399         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
1400       });
1401 
1402       ConversionConfig config;
1403       DumpNotifications dumpNotifications;
1404       config.listener = &dumpNotifications;
1405       if (failed(applyFullConversion(getOperation(), target,
1406                                      std::move(patterns), config))) {
1407         getOperation()->emitRemark() << "applyFullConversion failed";
1408       }
1409       return;
1410     }
1411 
1412     // Otherwise, handle an analysis conversion.
1413     assert(mode == ConversionMode::Analysis);
1414 
1415     // Analyze the convertible operations.
1416     DenseSet<Operation *> legalizedOps;
1417     ConversionConfig config;
1418     config.legalizableOps = &legalizedOps;
1419     if (failed(applyAnalysisConversion(getOperation(), target,
1420                                        std::move(patterns), config)))
1421       return signalPassFailure();
1422 
1423     // Emit remarks for each legalizable operation.
1424     for (auto *op : legalizedOps)
1425       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
1426   }
1427 
1428   /// The mode of conversion to use.
1429   ConversionMode mode;
1430 };
1431 } // namespace
1432 
/// Command-line option "test-legalize-mode" holding a
/// `TestLegalizePatternDriver::ConversionMode`. Accepted values are
/// "analysis", "full", and "partial"; the initial value is `Partial`.
static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
    legalizerConversionMode(
        "test-legalize-mode",
        llvm::cl::desc("The legalization mode to use with the test driver"),
        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
        llvm::cl::values(
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
                       "analysis", "Perform an analysis conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
                       "Perform a full conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
                       "partial", "Perform a partial conversion")));
1445 
1446 //===----------------------------------------------------------------------===//
1447 // ConversionPatternRewriter::getRemappedValue testing. This method is used
1448 // to get the remapped value of an original value that was replaced using
1449 // ConversionPatternRewriter.
1450 namespace {
/// Type converter for the `getRemappedValue` tests. It registers two
/// conversions: f32 -> f64, and an identity conversion for any type.
struct TestRemapValueTypeConverter : public TypeConverter {
  using TypeConverter::TypeConverter;

  TestRemapValueTypeConverter() {
    // NOTE(review): the order of these registrations presumably matters for
    // which conversion is tried first on an f32 type -- confirm against the
    // TypeConverter documentation before reordering.
    // f32 is converted to f64.
    addConversion(
        [](Float32Type type) { return Float64Type::get(type.getContext()); });
    // Any type is mapped to itself.
    addConversion([](Type type) { return type; });
  }
};
1460 
1461 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
1462 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
1463 /// operand twice.
1464 ///
1465 /// Example:
1466 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
1467 /// is replaced with:
1468 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
1469 struct OneVResOneVOperandOp1Converter
1470     : public OpConversionPattern<OneVResOneVOperandOp1> {
1471   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
1472 
1473   LogicalResult
1474   matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
1475                   ConversionPatternRewriter &rewriter) const override {
1476     auto origOps = op.getOperands();
1477     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
1478            "One operand expected");
1479     Value origOp = *origOps.begin();
1480     SmallVector<Value, 2> remappedOperands;
1481     // Replicate the remapped original operand twice. Note that we don't used
1482     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
1483     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1484     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1485 
1486     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
1487                                                        remappedOperands);
1488     return success();
1489   }
1490 };
1491 
1492 /// A rewriter pattern that tests that blocks can be merged.
1493 struct TestRemapValueInRegion
1494     : public OpConversionPattern<TestRemappedValueRegionOp> {
1495   using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;
1496 
1497   LogicalResult
1498   matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
1499                   ConversionPatternRewriter &rewriter) const final {
1500     Block &block = op.getBody().front();
1501     Operation *terminator = block.getTerminator();
1502 
1503     // Merge the block into the parent region.
1504     Block *parentBlock = op->getBlock();
1505     Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
1506     rewriter.mergeBlocks(&block, parentBlock, ValueRange());
1507     rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());
1508 
1509     // Replace the results of this operation with the remapped terminator
1510     // values.
1511     SmallVector<Value> terminatorOperands;
1512     if (failed(rewriter.getRemappedValues(terminator->getOperands(),
1513                                           terminatorOperands)))
1514       return failure();
1515 
1516     rewriter.eraseOp(terminator);
1517     rewriter.replaceOp(op, terminatorOperands);
1518     return success();
1519   }
1520 };
1521 
1522 struct TestRemappedValue
1523     : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> {
1524   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue)
1525 
1526   StringRef getArgument() const final { return "test-remapped-value"; }
1527   StringRef getDescription() const final {
1528     return "Test public remapped value mechanism in ConversionPatternRewriter";
1529   }
1530   void runOnOperation() override {
1531     TestRemapValueTypeConverter typeConverter;
1532 
1533     mlir::RewritePatternSet patterns(&getContext());
1534     patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
1535     patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
1536         &getContext());
1537     patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
1538 
1539     mlir::ConversionTarget target(getContext());
1540     target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>();
1541 
1542     // Expect the type_producer/type_consumer operations to only operate on f64.
1543     target.addDynamicallyLegalOp<TestTypeProducerOp>(
1544         [](TestTypeProducerOp op) { return op.getType().isF64(); });
1545     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1546       return op.getOperand().getType().isF64();
1547     });
1548 
1549     // We make OneVResOneVOperandOp1 legal only when it has more that one
1550     // operand. This will trigger the conversion that will replace one-operand
1551     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
1552     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
1553         [](Operation *op) { return op->getNumOperands() > 1; });
1554 
1555     if (failed(mlir::applyFullConversion(getOperation(), target,
1556                                          std::move(patterns)))) {
1557       signalPassFailure();
1558     }
1559   }
1560 };
1561 } // namespace
1562 
1563 //===----------------------------------------------------------------------===//
1564 // Test patterns without a specific root operation kind
1565 //===----------------------------------------------------------------------===//
1566 
1567 namespace {
1568 /// This pattern matches and removes any operation in the test dialect.
1569 struct RemoveTestDialectOps : public RewritePattern {
1570   RemoveTestDialectOps(MLIRContext *context)
1571       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
1572 
1573   LogicalResult matchAndRewrite(Operation *op,
1574                                 PatternRewriter &rewriter) const override {
1575     if (!isa<TestDialect>(op->getDialect()))
1576       return failure();
1577     rewriter.eraseOp(op);
1578     return success();
1579   }
1580 };
1581 
1582 struct TestUnknownRootOpDriver
1583     : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> {
1584   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver)
1585 
1586   StringRef getArgument() const final {
1587     return "test-legalize-unknown-root-patterns";
1588   }
1589   StringRef getDescription() const final {
1590     return "Test public remapped value mechanism in ConversionPatternRewriter";
1591   }
1592   void runOnOperation() override {
1593     mlir::RewritePatternSet patterns(&getContext());
1594     patterns.add<RemoveTestDialectOps>(&getContext());
1595 
1596     mlir::ConversionTarget target(getContext());
1597     target.addIllegalDialect<TestDialect>();
1598     if (failed(applyPartialConversion(getOperation(), target,
1599                                       std::move(patterns))))
1600       signalPassFailure();
1601   }
1602 };
1603 } // namespace
1604 
1605 //===----------------------------------------------------------------------===//
// Test patterns that use operations and types defined at runtime
1607 //===----------------------------------------------------------------------===//
1608 
1609 namespace {
1610 /// This pattern matches dynamic operations 'test.one_operand_two_results' and
1611 /// replace them with dynamic operations 'test.generic_dynamic_op'.
1612 struct RewriteDynamicOp : public RewritePattern {
1613   RewriteDynamicOp(MLIRContext *context)
1614       : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
1615                        context) {}
1616 
1617   LogicalResult matchAndRewrite(Operation *op,
1618                                 PatternRewriter &rewriter) const override {
1619     assert(op->getName().getStringRef() ==
1620                "test.dynamic_one_operand_two_results" &&
1621            "rewrite pattern should only match operations with the right name");
1622 
1623     OperationState state(op->getLoc(), "test.dynamic_generic",
1624                          op->getOperands(), op->getResultTypes(),
1625                          op->getAttrs());
1626     auto *newOp = rewriter.create(state);
1627     rewriter.replaceOp(op, newOp->getResults());
1628     return success();
1629   }
1630 };
1631 
1632 struct TestRewriteDynamicOpDriver
1633     : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> {
1634   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
1635 
1636   void getDependentDialects(DialectRegistry &registry) const override {
1637     registry.insert<TestDialect>();
1638   }
1639   StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
1640   StringRef getDescription() const final {
1641     return "Test rewritting on dynamic operations";
1642   }
1643   void runOnOperation() override {
1644     RewritePatternSet patterns(&getContext());
1645     patterns.add<RewriteDynamicOp>(&getContext());
1646 
1647     ConversionTarget target(getContext());
1648     target.addIllegalOp(
1649         OperationName("test.dynamic_one_operand_two_results", &getContext()));
1650     target.addLegalOp(OperationName("test.dynamic_generic", &getContext()));
1651     if (failed(applyPartialConversion(getOperation(), target,
1652                                       std::move(patterns))))
1653       signalPassFailure();
1654   }
1655 };
1656 } // end anonymous namespace
1657 
1658 //===----------------------------------------------------------------------===//
1659 // Test type conversions
1660 //===----------------------------------------------------------------------===//
1661 
1662 namespace {
1663 struct TestTypeConversionProducer
1664     : public OpConversionPattern<TestTypeProducerOp> {
1665   using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
1666   LogicalResult
1667   matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
1668                   ConversionPatternRewriter &rewriter) const final {
1669     Type resultType = op.getType();
1670     Type convertedType = getTypeConverter()
1671                              ? getTypeConverter()->convertType(resultType)
1672                              : resultType;
1673     if (isa<FloatType>(resultType))
1674       resultType = rewriter.getF64Type();
1675     else if (resultType.isInteger(16))
1676       resultType = rewriter.getIntegerType(64);
1677     else if (isa<test::TestRecursiveType>(resultType) &&
1678              convertedType != resultType)
1679       resultType = convertedType;
1680     else
1681       return failure();
1682 
1683     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
1684     return success();
1685   }
1686 };
1687 
1688 /// Call signature conversion and then fail the rewrite to trigger the undo
1689 /// mechanism.
1690 struct TestSignatureConversionUndo
1691     : public OpConversionPattern<TestSignatureConversionUndoOp> {
1692   using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
1693 
1694   LogicalResult
1695   matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
1696                   ConversionPatternRewriter &rewriter) const final {
1697     (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
1698     return failure();
1699   }
1700 };
1701 
1702 /// Call signature conversion without providing a type converter to handle
1703 /// materializations.
1704 struct TestTestSignatureConversionNoConverter
1705     : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
1706   TestTestSignatureConversionNoConverter(const TypeConverter &converter,
1707                                          MLIRContext *context)
1708       : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
1709         converter(converter) {}
1710 
1711   LogicalResult
1712   matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
1713                   ConversionPatternRewriter &rewriter) const final {
1714     Region &region = op->getRegion(0);
1715     Block *entry = &region.front();
1716 
1717     // Convert the original entry arguments.
1718     TypeConverter::SignatureConversion result(entry->getNumArguments());
1719     if (failed(
1720             converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
1721       return failure();
1722     rewriter.modifyOpInPlace(op, [&] {
1723       rewriter.applySignatureConversion(&region.front(), result);
1724     });
1725     return success();
1726   }
1727 
1728   const TypeConverter &converter;
1729 };
1730 
1731 /// Just forward the operands to the root op. This is essentially a no-op
1732 /// pattern that is used to trigger target materialization.
1733 struct TestTypeConsumerForward
1734     : public OpConversionPattern<TestTypeConsumerOp> {
1735   using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
1736 
1737   LogicalResult
1738   matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
1739                   ConversionPatternRewriter &rewriter) const final {
1740     rewriter.modifyOpInPlace(op,
1741                              [&] { op->setOperands(adaptor.getOperands()); });
1742     return success();
1743   }
1744 };
1745 
1746 struct TestTypeConversionAnotherProducer
1747     : public OpRewritePattern<TestAnotherTypeProducerOp> {
1748   using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
1749 
1750   LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
1751                                 PatternRewriter &rewriter) const final {
1752     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
1753     return success();
1754   }
1755 };
1756 
1757 struct TestReplaceWithLegalOp : public ConversionPattern {
1758   TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
1759       : ConversionPattern(converter, "test.replace_with_legal_op",
1760                           /*benefit=*/1, ctx) {}
1761   LogicalResult
1762   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1763                   ConversionPatternRewriter &rewriter) const final {
1764     rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
1765     return success();
1766   }
1767 };
1768 
1769 struct TestTypeConversionDriver
1770     : public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
1771   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
1772 
1773   void getDependentDialects(DialectRegistry &registry) const override {
1774     registry.insert<TestDialect>();
1775   }
1776   StringRef getArgument() const final {
1777     return "test-legalize-type-conversion";
1778   }
1779   StringRef getDescription() const final {
1780     return "Test various type conversion functionalities in DialectConversion";
1781   }
1782 
1783   void runOnOperation() override {
1784     // Initialize the type converter.
1785     SmallVector<Type, 2> conversionCallStack;
1786     TypeConverter converter;
1787 
1788     /// Add the legal set of type conversions.
1789     converter.addConversion([](Type type) -> Type {
1790       // Treat F64 as legal.
1791       if (type.isF64())
1792         return type;
1793       // Allow converting BF16/F16/F32 to F64.
1794       if (type.isBF16() || type.isF16() || type.isF32())
1795         return FloatType::getF64(type.getContext());
1796       // Otherwise, the type is illegal.
1797       return nullptr;
1798     });
1799     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
1800       // Drop all integer types.
1801       return success();
1802     });
1803     converter.addConversion(
1804         // Convert a recursive self-referring type into a non-self-referring
1805         // type named "outer_converted_type" that contains a SimpleAType.
1806         [&](test::TestRecursiveType type,
1807             SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
1808           // If the type is already converted, return it to indicate that it is
1809           // legal.
1810           if (type.getName() == "outer_converted_type") {
1811             results.push_back(type);
1812             return success();
1813           }
1814 
1815           conversionCallStack.push_back(type);
1816           auto popConversionCallStack = llvm::make_scope_exit(
1817               [&conversionCallStack]() { conversionCallStack.pop_back(); });
1818 
1819           // If the type is on the call stack more than once (it is there at
1820           // least once because of the _current_ call, which is always the last
1821           // element on the stack), we've hit the recursive case. Just return
1822           // SimpleAType here to create a non-recursive type as a result.
1823           if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(),
1824                                  type)) {
1825             results.push_back(test::SimpleAType::get(type.getContext()));
1826             return success();
1827           }
1828 
1829           // Convert the body recursively.
1830           auto result = test::TestRecursiveType::get(type.getContext(),
1831                                                      "outer_converted_type");
1832           if (failed(result.setBody(converter.convertType(type.getBody()))))
1833             return failure();
1834           results.push_back(result);
1835           return success();
1836         });
1837 
1838     /// Add the legal set of type materializations.
1839     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
1840                                           ValueRange inputs,
1841                                           Location loc) -> Value {
1842       // Allow casting from F64 back to F32.
1843       if (!resultType.isF16() && inputs.size() == 1 &&
1844           inputs[0].getType().isF64())
1845         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1846       // Allow producing an i32 or i64 from nothing.
1847       if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
1848           inputs.empty())
1849         return builder.create<TestTypeProducerOp>(loc, resultType);
1850       // Allow producing an i64 from an integer.
1851       if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
1852           isa<IntegerType>(inputs[0].getType()))
1853         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1854       // Otherwise, fail.
1855       return nullptr;
1856     });
1857 
1858     // Initialize the conversion target.
1859     mlir::ConversionTarget target(getContext());
1860     target.addLegalOp<LegalOpD>();
1861     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
1862       auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
1863       return op.getType().isF64() || op.getType().isInteger(64) ||
1864              (recursiveType &&
1865               recursiveType.getName() == "outer_converted_type");
1866     });
1867     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1868       return converter.isSignatureLegal(op.getFunctionType()) &&
1869              converter.isLegal(&op.getBody());
1870     });
1871     target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
1872       // Allow casts from F64 to F32.
1873       return (*op.operand_type_begin()).isF64() && op.getType().isF32();
1874     });
1875     target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
1876         [&](TestSignatureConversionNoConverterOp op) {
1877           return converter.isLegal(op.getRegion().front().getArgumentTypes());
1878         });
1879 
1880     // Initialize the set of rewrite patterns.
1881     RewritePatternSet patterns(&getContext());
1882     patterns
1883         .add<TestTypeConsumerForward, TestTypeConversionProducer,
1884              TestSignatureConversionUndo,
1885              TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
1886             converter, &getContext());
1887     patterns.add<TestTypeConversionAnotherProducer>(&getContext());
1888     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1889                                                               converter);
1890 
1891     if (failed(applyPartialConversion(getOperation(), target,
1892                                       std::move(patterns))))
1893       signalPassFailure();
1894   }
1895 };
1896 } // namespace
1897 
1898 //===----------------------------------------------------------------------===//
1899 // Test Target Materialization With No Uses
1900 //===----------------------------------------------------------------------===//
1901 
1902 namespace {
1903 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1904   using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1905 
1906   LogicalResult
1907   matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1908                   ConversionPatternRewriter &rewriter) const final {
1909     rewriter.replaceOp(op, adaptor.getOperands());
1910     return success();
1911   }
1912 };
1913 
1914 struct TestTargetMaterializationWithNoUses
1915     : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> {
1916   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1917       TestTargetMaterializationWithNoUses)
1918 
1919   StringRef getArgument() const final {
1920     return "test-target-materialization-with-no-uses";
1921   }
1922   StringRef getDescription() const final {
1923     return "Test a special case of target materialization in DialectConversion";
1924   }
1925 
1926   void runOnOperation() override {
1927     TypeConverter converter;
1928     converter.addConversion([](Type t) { return t; });
1929     converter.addConversion([](IntegerType intTy) -> Type {
1930       if (intTy.getWidth() == 16)
1931         return IntegerType::get(intTy.getContext(), 64);
1932       return intTy;
1933     });
1934     converter.addTargetMaterialization(
1935         [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1936           return builder.create<TestCastOp>(loc, type, inputs).getResult();
1937         });
1938 
1939     ConversionTarget target(getContext());
1940     target.addIllegalOp<TestTypeChangerOp>();
1941 
1942     RewritePatternSet patterns(&getContext());
1943     patterns.add<ForwardOperandPattern>(converter, &getContext());
1944 
1945     if (failed(applyPartialConversion(getOperation(), target,
1946                                       std::move(patterns))))
1947       signalPassFailure();
1948   }
1949 };
1950 } // namespace
1951 
1952 //===----------------------------------------------------------------------===//
1953 // Test Block Merging
1954 //===----------------------------------------------------------------------===//
1955 
1956 namespace {
1957 /// A rewriter pattern that tests that blocks can be merged.
1958 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
1959   using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
1960 
1961   LogicalResult
1962   matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
1963                   ConversionPatternRewriter &rewriter) const final {
1964     Block &firstBlock = op.getBody().front();
1965     Operation *branchOp = firstBlock.getTerminator();
1966     Block *secondBlock = &*(std::next(op.getBody().begin()));
1967     auto succOperands = branchOp->getOperands();
1968     SmallVector<Value, 2> replacements(succOperands);
1969     rewriter.eraseOp(branchOp);
1970     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1971     rewriter.modifyOpInPlace(op, [] {});
1972     return success();
1973   }
1974 };
1975 
1976 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
1977 struct TestUndoBlocksMerge : public ConversionPattern {
1978   TestUndoBlocksMerge(MLIRContext *ctx)
1979       : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
1980   LogicalResult
1981   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1982                   ConversionPatternRewriter &rewriter) const final {
1983     Block &firstBlock = op->getRegion(0).front();
1984     Operation *branchOp = firstBlock.getTerminator();
1985     Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
1986     rewriter.setInsertionPointToStart(secondBlock);
1987     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
1988     auto succOperands = branchOp->getOperands();
1989     SmallVector<Value, 2> replacements(succOperands);
1990     rewriter.eraseOp(branchOp);
1991     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1992     rewriter.modifyOpInPlace(op, [] {});
1993     return success();
1994   }
1995 };
1996 
1997 /// A rewrite mechanism to inline the body of the op into its parent, when both
1998 /// ops can have a single block.
1999 struct TestMergeSingleBlockOps
2000     : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
2001   using OpConversionPattern<
2002       SingleBlockImplicitTerminatorOp>::OpConversionPattern;
2003 
2004   LogicalResult
2005   matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
2006                   ConversionPatternRewriter &rewriter) const final {
2007     SingleBlockImplicitTerminatorOp parentOp =
2008         op->getParentOfType<SingleBlockImplicitTerminatorOp>();
2009     if (!parentOp)
2010       return failure();
2011     Block &innerBlock = op.getRegion().front();
2012     TerminatorOp innerTerminator =
2013         cast<TerminatorOp>(innerBlock.getTerminator());
2014     rewriter.inlineBlockBefore(&innerBlock, op);
2015     rewriter.eraseOp(innerTerminator);
2016     rewriter.eraseOp(op);
2017     return success();
2018   }
2019 };
2020 
2021 struct TestMergeBlocksPatternDriver
2022     : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> {
2023   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver)
2024 
2025   StringRef getArgument() const final { return "test-merge-blocks"; }
2026   StringRef getDescription() const final {
2027     return "Test Merging operation in ConversionPatternRewriter";
2028   }
2029   void runOnOperation() override {
2030     MLIRContext *context = &getContext();
2031     mlir::RewritePatternSet patterns(context);
2032     patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
2033         context);
2034     ConversionTarget target(*context);
2035     target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
2036                       TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
2037     target.addIllegalOp<ILLegalOpF>();
2038 
2039     /// Expect the op to have a single block after legalization.
2040     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
2041         [&](TestMergeBlocksOp op) -> bool {
2042           return llvm::hasSingleElement(op.getBody());
2043         });
2044 
2045     /// Only allow `test.br` within test.merge_blocks op.
2046     target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
2047       return op->getParentOfType<TestMergeBlocksOp>();
2048     });
2049 
2050     /// Expect that all nested test.SingleBlockImplicitTerminator ops are
2051     /// inlined.
2052     target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
2053         [&](SingleBlockImplicitTerminatorOp op) -> bool {
2054           return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
2055         });
2056 
2057     DenseSet<Operation *> unlegalizedOps;
2058     ConversionConfig config;
2059     config.unlegalizedOps = &unlegalizedOps;
2060     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
2061                                  config);
2062     for (auto *op : unlegalizedOps)
2063       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
2064   }
2065 };
2066 } // namespace
2067 
2068 //===----------------------------------------------------------------------===//
2069 // Test Selective Replacement
2070 //===----------------------------------------------------------------------===//
2071 
2072 namespace {
2073 /// A rewrite mechanism to inline the body of the op into its parent, when both
2074 /// ops can have a single block.
2075 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
2076   using OpRewritePattern<TestCastOp>::OpRewritePattern;
2077 
2078   LogicalResult matchAndRewrite(TestCastOp op,
2079                                 PatternRewriter &rewriter) const final {
2080     if (op.getNumOperands() != 2)
2081       return failure();
2082     OperandRange operands = op.getOperands();
2083 
2084     // Replace non-terminator uses with the first operand.
2085     rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
2086       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
2087     });
2088     // Replace everything else with the second operand if the operation isn't
2089     // dead.
2090     rewriter.replaceOp(op, op.getOperand(1));
2091     return success();
2092   }
2093 };
2094 
2095 struct TestSelectiveReplacementPatternDriver
2096     : public PassWrapper<TestSelectiveReplacementPatternDriver,
2097                          OperationPass<>> {
2098   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
2099       TestSelectiveReplacementPatternDriver)
2100 
2101   StringRef getArgument() const final {
2102     return "test-pattern-selective-replacement";
2103   }
2104   StringRef getDescription() const final {
2105     return "Test selective replacement in the PatternRewriter";
2106   }
2107   void runOnOperation() override {
2108     MLIRContext *context = &getContext();
2109     mlir::RewritePatternSet patterns(context);
2110     patterns.add<TestSelectiveOpReplacementPattern>(context);
2111     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
2112   }
2113 };
2114 } // namespace
2115 
2116 //===----------------------------------------------------------------------===//
2117 // PassRegistration
2118 //===----------------------------------------------------------------------===//
2119 
2120 namespace mlir {
2121 namespace test {
2122 void registerPatternsTestPass() {
2123   PassRegistration<TestReturnTypeDriver>();
2124 
2125   PassRegistration<TestDerivedAttributeDriver>();
2126 
2127   PassRegistration<TestGreedyPatternDriver>();
2128   PassRegistration<TestStrictPatternDriver>();
2129   PassRegistration<TestWalkPatternDriver>();
2130 
2131   PassRegistration<TestLegalizePatternDriver>([] {
2132     return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
2133   });
2134 
2135   PassRegistration<TestRemappedValue>();
2136 
2137   PassRegistration<TestUnknownRootOpDriver>();
2138 
2139   PassRegistration<TestTypeConversionDriver>();
2140   PassRegistration<TestTargetMaterializationWithNoUses>();
2141 
2142   PassRegistration<TestRewriteDynamicOpDriver>();
2143 
2144   PassRegistration<TestMergeBlocksPatternDriver>();
2145   PassRegistration<TestSelectiveReplacementPatternDriver>();
2146 }
2147 } // namespace test
2148 } // namespace mlir
2149