xref: /llvm-project/mlir/test/lib/Dialect/Test/TestPatterns.cpp (revision 0f8a6b7d03550cb58cf49535af2de2230abfe997)
1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "TestTypes.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BuiltinAttributes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Visitors.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include <cstdint>
26 
27 using namespace mlir;
28 using namespace test;
29 
30 // Native function for testing NativeCodeCall
31 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
32   return choice.getValue() ? input1 : input2;
33 }
34 
35 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
36   rewriter.create<OpI>(loc, input);
37 }
38 
39 static void handleNoResultOp(PatternRewriter &rewriter,
40                              OpSymbolBindingNoResult op) {
41   // Turn the no result op to a one-result op.
42   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
43                                     op.getOperand());
44 }
45 
46 static bool getFirstI32Result(Operation *op, Value &value) {
47   if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
48     return false;
49   value = op->getResult(0);
50   return true;
51 }
52 
53 static Value bindNativeCodeCallResult(Value value) { return value; }
54 
55 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
56                                                               Value input2) {
57   return SmallVector<Value, 2>({input2, input1});
58 }
59 
60 // Test that natives calls are only called once during rewrites.
61 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
62 // This let us check the number of times OpM_Test was called by inspecting
63 // the returned value in the MLIR output.
64 static int64_t opMIncreasingValue = 314159265;
65 static Attribute opMTest(PatternRewriter &rewriter, Value val) {
66   int64_t i = opMIncreasingValue++;
67   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
68 }
69 
// Include the generated patterns inside an anonymous namespace so that the
// generated definitions stay local to this translation unit.
// NOTE(review): `populateWithGenerated`, called below, is presumably defined
// by this generated file — confirm against the corresponding .td source.
namespace {
#include "TestPatterns.inc"
} // namespace
73 
74 //===----------------------------------------------------------------------===//
75 // Test Reduce Pattern Interface
76 //===----------------------------------------------------------------------===//
77 
78 void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
79   populateWithGenerated(patterns);
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // Canonicalizer Driver.
84 //===----------------------------------------------------------------------===//
85 
86 namespace {
87 struct FoldingPattern : public RewritePattern {
88 public:
89   FoldingPattern(MLIRContext *context)
90       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
91                        /*benefit=*/1, context) {}
92 
93   LogicalResult matchAndRewrite(Operation *op,
94                                 PatternRewriter &rewriter) const override {
95     // Exercise createOrFold API for a single-result operation that is folded
96     // upon construction. The operation being created has an in-place folder,
97     // and it should be still present in the output. Furthermore, the folder
98     // should not crash when attempting to recover the (unchanged) operation
99     // result.
100     Value result = rewriter.createOrFold<TestOpInPlaceFold>(
101         op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0));
102     assert(result);
103     rewriter.replaceOp(op, result);
104     return success();
105   }
106 };
107 
108 /// This pattern creates a foldable operation at the entry point of the block.
109 /// This tests the situation where the operation folder will need to replace an
110 /// operation with a previously created constant that does not initially
111 /// dominate the operation to replace.
112 struct FolderInsertBeforePreviouslyFoldedConstantPattern
113     : public OpRewritePattern<TestCastOp> {
114 public:
115   using OpRewritePattern<TestCastOp>::OpRewritePattern;
116 
117   LogicalResult matchAndRewrite(TestCastOp op,
118                                 PatternRewriter &rewriter) const override {
119     if (!op->hasAttr("test_fold_before_previously_folded_op"))
120       return failure();
121     rewriter.setInsertionPointToStart(op->getBlock());
122 
123     auto constOp = rewriter.create<arith::ConstantOp>(
124         op.getLoc(), rewriter.getBoolAttr(true));
125     rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
126                                             Value(constOp));
127     return success();
128   }
129 };
130 
131 /// This pattern matches test.op_commutative2 with the first operand being
132 /// another test.op_commutative2 with a constant on the right side and fold it
133 /// away by propagating it as its result. This is intend to check that patterns
134 /// are applied after the commutative property moves constant to the right.
135 struct FolderCommutativeOp2WithConstant
136     : public OpRewritePattern<TestCommutative2Op> {
137 public:
138   using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
139 
140   LogicalResult matchAndRewrite(TestCommutative2Op op,
141                                 PatternRewriter &rewriter) const override {
142     auto operand =
143         dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
144     if (!operand)
145       return failure();
146     Attribute constInput;
147     if (!matchPattern(operand->getOperand(1), m_Constant(&constInput)))
148       return failure();
149     rewriter.replaceOp(op, operand->getOperand(1));
150     return success();
151   }
152 };
153 
154 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
155 /// attribute with value smaller than MaxVal, it increments the value by 1.
156 template <int MaxVal>
157 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
158   using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern;
159 
160   LogicalResult matchAndRewrite(AnyAttrOfOp op,
161                                 PatternRewriter &rewriter) const override {
162     auto intAttr = dyn_cast<IntegerAttr>(op.getAttr());
163     if (!intAttr)
164       return failure();
165     int64_t val = intAttr.getInt();
166     if (val >= MaxVal)
167       return failure();
168     rewriter.modifyOpInPlace(
169         op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
170     return success();
171   }
172 };
173 
174 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op".
175 struct MakeOpEligible : public RewritePattern {
176   MakeOpEligible(MLIRContext *context)
177       : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {}
178 
179   LogicalResult matchAndRewrite(Operation *op,
180                                 PatternRewriter &rewriter) const override {
181     if (op->hasAttr("eligible"))
182       return failure();
183     rewriter.modifyOpInPlace(
184         op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); });
185     return success();
186   }
187 };
188 
189 /// This pattern hoists eligible ops out of a "test.one_region_op".
190 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
191   using OpRewritePattern<test::OneRegionOp>::OpRewritePattern;
192 
193   LogicalResult matchAndRewrite(test::OneRegionOp op,
194                                 PatternRewriter &rewriter) const override {
195     Operation *terminator = op.getRegion().front().getTerminator();
196     Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp();
197     if (toBeHoisted->getParentOp() != op)
198       return failure();
199     if (!toBeHoisted->hasAttr("eligible"))
200       return failure();
201     rewriter.moveOpBefore(toBeHoisted, op);
202     return success();
203   }
204 };
205 
206 /// This pattern moves "test.move_before_parent_op" before the parent op.
207 struct MoveBeforeParentOp : public RewritePattern {
208   MoveBeforeParentOp(MLIRContext *context)
209       : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
210 
211   LogicalResult matchAndRewrite(Operation *op,
212                                 PatternRewriter &rewriter) const override {
213     // Do not hoist past functions.
214     if (isa<FunctionOpInterface>(op->getParentOp()))
215       return failure();
216     rewriter.moveOpBefore(op, op->getParentOp());
217     return success();
218   }
219 };
220 
221 /// This pattern moves "test.move_after_parent_op" after the parent op.
222 struct MoveAfterParentOp : public RewritePattern {
223   MoveAfterParentOp(MLIRContext *context)
224       : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {}
225 
226   LogicalResult matchAndRewrite(Operation *op,
227                                 PatternRewriter &rewriter) const override {
228     // Do not hoist past functions.
229     if (isa<FunctionOpInterface>(op->getParentOp()))
230       return failure();
231 
232     int64_t moveForwardBy = 0;
233     if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance"))
234       moveForwardBy = advanceBy.getInt();
235 
236     Operation *moveAfter = op->getParentOp();
237     for (int64_t i = 0; i < moveForwardBy; ++i)
238       moveAfter = moveAfter->getNextNode();
239 
240     rewriter.moveOpAfter(op, moveAfter);
241     return success();
242   }
243 };
244 
245 /// This pattern inlines blocks that are nested in
246 /// "test.inline_blocks_into_parent" into the parent block.
247 struct InlineBlocksIntoParent : public RewritePattern {
248   InlineBlocksIntoParent(MLIRContext *context)
249       : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
250                        context) {}
251 
252   LogicalResult matchAndRewrite(Operation *op,
253                                 PatternRewriter &rewriter) const override {
254     bool changed = false;
255     for (Region &r : op->getRegions()) {
256       while (!r.empty()) {
257         rewriter.inlineBlockBefore(&r.front(), op);
258         changed = true;
259       }
260     }
261     return success(changed);
262   }
263 };
264 
265 /// This pattern splits blocks at "test.split_block_here" and replaces the op
266 /// with a new op (to prevent an infinite loop of block splitting).
267 struct SplitBlockHere : public RewritePattern {
268   SplitBlockHere(MLIRContext *context)
269       : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
270 
271   LogicalResult matchAndRewrite(Operation *op,
272                                 PatternRewriter &rewriter) const override {
273     rewriter.splitBlock(op->getBlock(), op->getIterator());
274     Operation *newOp = rewriter.create(
275         op->getLoc(),
276         OperationName("test.new_op", op->getContext()).getIdentifier(),
277         op->getOperands(), op->getResultTypes());
278     rewriter.replaceOp(op, newOp);
279     return success();
280   }
281 };
282 
283 /// This pattern clones "test.clone_me" ops.
284 struct CloneOp : public RewritePattern {
285   CloneOp(MLIRContext *context)
286       : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
287 
288   LogicalResult matchAndRewrite(Operation *op,
289                                 PatternRewriter &rewriter) const override {
290     // Do not clone already cloned ops to avoid going into an infinite loop.
291     if (op->hasAttr("was_cloned"))
292       return failure();
293     Operation *cloned = rewriter.clone(*op);
294     cloned->setAttr("was_cloned", rewriter.getUnitAttr());
295     return success();
296   }
297 };
298 
299 /// This pattern clones regions of "test.clone_region_before" ops before the
300 /// parent block.
301 struct CloneRegionBeforeOp : public RewritePattern {
302   CloneRegionBeforeOp(MLIRContext *context)
303       : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
304 
305   LogicalResult matchAndRewrite(Operation *op,
306                                 PatternRewriter &rewriter) const override {
307     // Do not clone already cloned ops to avoid going into an infinite loop.
308     if (op->hasAttr("was_cloned"))
309       return failure();
310     for (Region &r : op->getRegions())
311       rewriter.cloneRegionBefore(r, op->getBlock());
312     op->setAttr("was_cloned", rewriter.getUnitAttr());
313     return success();
314   }
315 };
316 
317 /// Replace an operation may introduce the re-visiting of its users.
318 class ReplaceWithNewOp : public RewritePattern {
319 public:
320   ReplaceWithNewOp(MLIRContext *context)
321       : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
322 
323   LogicalResult matchAndRewrite(Operation *op,
324                                 PatternRewriter &rewriter) const override {
325     Operation *newOp;
326     if (op->hasAttr("create_erase_op")) {
327       newOp = rewriter.create(
328           op->getLoc(),
329           OperationName("test.erase_op", op->getContext()).getIdentifier(),
330           ValueRange(), TypeRange());
331     } else {
332       newOp = rewriter.create(
333           op->getLoc(),
334           OperationName("test.new_op", op->getContext()).getIdentifier(),
335           op->getOperands(), op->getResultTypes());
336     }
337     // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
338     // A "notifyOperationReplaced" callback is triggered in either case.
339     rewriter.replaceAllOpUsesWith(op, newOp->getResults());
340     rewriter.eraseOp(op);
341     return success();
342   }
343 };
344 
345 /// Erases the first child block of the matched "test.erase_first_block"
346 /// operation.
347 class EraseFirstBlock : public RewritePattern {
348 public:
349   EraseFirstBlock(MLIRContext *context)
350       : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {}
351 
352   LogicalResult matchAndRewrite(Operation *op,
353                                 PatternRewriter &rewriter) const override {
354     llvm::errs() << "Num regions: " << op->getNumRegions() << "\n";
355     for (Region &r : op->getRegions()) {
356       for (Block &b : r.getBlocks()) {
357         rewriter.eraseBlock(&b);
358         llvm::errs() << "Erasing block: " << b << "\n";
359         return success();
360       }
361     }
362 
363     return failure();
364   }
365 };
366 
367 struct TestGreedyPatternDriver
368     : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
369   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
370 
371   TestGreedyPatternDriver() = default;
372   TestGreedyPatternDriver(const TestGreedyPatternDriver &other)
373       : PassWrapper(other) {}
374 
375   StringRef getArgument() const final { return "test-greedy-patterns"; }
376   StringRef getDescription() const final { return "Run test dialect patterns"; }
377   void runOnOperation() override {
378     mlir::RewritePatternSet patterns(&getContext());
379     populateWithGenerated(patterns);
380 
381     // Verify named pattern is generated with expected name.
382     patterns.add<FoldingPattern, TestNamedPatternRule,
383                  FolderInsertBeforePreviouslyFoldedConstantPattern,
384                  FolderCommutativeOp2WithConstant, HoistEligibleOps,
385                  MakeOpEligible>(&getContext());
386 
387     // Additional patterns for testing the GreedyPatternRewriteDriver.
388     patterns.insert<IncrementIntAttribute<3>>(&getContext());
389 
390     GreedyRewriteConfig config;
391     config.useTopDownTraversal = this->useTopDownTraversal;
392     config.maxIterations = this->maxIterations;
393     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
394                                        config);
395   }
396 
397   Option<bool> useTopDownTraversal{
398       *this, "top-down",
399       llvm::cl::desc("Seed the worklist in general top-down order"),
400       llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
401   Option<int> maxIterations{
402       *this, "max-iterations",
403       llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
404       llvm::cl::init(GreedyRewriteConfig().maxIterations)};
405 };
406 
407 struct DumpNotifications : public RewriterBase::Listener {
408   void notifyBlockInserted(Block *block, Region *previous,
409                            Region::iterator previousIt) override {
410     llvm::outs() << "notifyBlockInserted";
411     if (block->getParentOp()) {
412       llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
413     } else {
414       llvm::outs() << " into unknown op: ";
415     }
416     if (previous == nullptr) {
417       llvm::outs() << "was unlinked\n";
418     } else {
419       llvm::outs() << "was linked\n";
420     }
421   }
422   void notifyOperationInserted(Operation *op,
423                                OpBuilder::InsertPoint previous) override {
424     llvm::outs() << "notifyOperationInserted: " << op->getName();
425     if (!previous.isSet()) {
426       llvm::outs() << ", was unlinked\n";
427     } else {
428       if (!previous.getPoint().getNodePtr()) {
429         llvm::outs() << ", was linked, exact position unknown\n";
430       } else if (previous.getPoint() == previous.getBlock()->end()) {
431         llvm::outs() << ", was last in block\n";
432       } else {
433         llvm::outs() << ", previous = " << previous.getPoint()->getName()
434                      << "\n";
435       }
436     }
437   }
438   void notifyBlockErased(Block *block) override {
439     llvm::outs() << "notifyBlockErased\n";
440   }
441   void notifyOperationErased(Operation *op) override {
442     llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
443   }
444   void notifyOperationModified(Operation *op) override {
445     llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
446   }
447   void notifyOperationReplaced(Operation *op, ValueRange values) override {
448     llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
449   }
450 };
451 
452 struct TestStrictPatternDriver
453     : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
454 public:
455   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
456 
457   TestStrictPatternDriver() = default;
458   TestStrictPatternDriver(const TestStrictPatternDriver &other)
459       : PassWrapper(other) {
460     strictMode = other.strictMode;
461   }
462 
463   StringRef getArgument() const final { return "test-strict-pattern-driver"; }
464   StringRef getDescription() const final {
465     return "Test strict mode of pattern driver";
466   }
467 
468   void runOnOperation() override {
469     MLIRContext *ctx = &getContext();
470     mlir::RewritePatternSet patterns(ctx);
471     patterns.add<
472         // clang-format off
473         ChangeBlockOp,
474         CloneOp,
475         CloneRegionBeforeOp,
476         EraseOp,
477         ImplicitChangeOp,
478         InlineBlocksIntoParent,
479         InsertSameOp,
480         MoveBeforeParentOp,
481         ReplaceWithNewOp,
482         SplitBlockHere
483         // clang-format on
484         >(ctx);
485     SmallVector<Operation *> ops;
486     getOperation()->walk([&](Operation *op) {
487       StringRef opName = op->getName().getStringRef();
488       if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
489           opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
490           opName == "test.move_before_parent_op" ||
491           opName == "test.inline_blocks_into_parent" ||
492           opName == "test.split_block_here" || opName == "test.clone_me" ||
493           opName == "test.clone_region_before") {
494         ops.push_back(op);
495       }
496     });
497 
498     DumpNotifications dumpNotifications;
499     GreedyRewriteConfig config;
500     config.listener = &dumpNotifications;
501     if (strictMode == "AnyOp") {
502       config.strictMode = GreedyRewriteStrictness::AnyOp;
503     } else if (strictMode == "ExistingAndNewOps") {
504       config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
505     } else if (strictMode == "ExistingOps") {
506       config.strictMode = GreedyRewriteStrictness::ExistingOps;
507     } else {
508       llvm_unreachable("invalid strictness option");
509     }
510 
511     // Check if these transformations introduce visiting of operations that
512     // are not in the `ops` set (The new created ops are valid). An invalid
513     // operation will trigger the assertion while processing.
514     bool changed = false;
515     bool allErased = false;
516     (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config,
517                                  &changed, &allErased);
518     Builder b(ctx);
519     getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
520     getOperation()->setAttr("pattern_driver_all_erased",
521                             b.getBoolAttr(allErased));
522   }
523 
524   Option<std::string> strictMode{
525       *this, "strictness",
526       llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"),
527       llvm::cl::init("AnyOp")};
528 
529 private:
530   // New inserted operation is valid for further transformation.
531   class InsertSameOp : public RewritePattern {
532   public:
533     InsertSameOp(MLIRContext *context)
534         : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
535 
536     LogicalResult matchAndRewrite(Operation *op,
537                                   PatternRewriter &rewriter) const override {
538       if (op->hasAttr("skip"))
539         return failure();
540 
541       Operation *newOp =
542           rewriter.create(op->getLoc(), op->getName().getIdentifier(),
543                           op->getOperands(), op->getResultTypes());
544       rewriter.modifyOpInPlace(
545           op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); });
546       newOp->setAttr("skip", rewriter.getBoolAttr(true));
547 
548       return success();
549     }
550   };
551 
552   // Remove an operation may introduce the re-visiting of its operands.
553   class EraseOp : public RewritePattern {
554   public:
555     EraseOp(MLIRContext *context)
556         : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
557     LogicalResult matchAndRewrite(Operation *op,
558                                   PatternRewriter &rewriter) const override {
559       rewriter.eraseOp(op);
560       return success();
561     }
562   };
563 
564   // The following two patterns test RewriterBase::replaceAllUsesWith.
565   //
566   // That function replaces all usages of a Block (or a Value) with another one
567   // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
568   // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
569   // worklist: when an op is modified, it is added to the worklist. The two
570   // patterns below make the tracking observable: ChangeBlockOp replaces all
571   // usages of a block and that pattern is applied because the corresponding ops
572   // are put on the initial worklist (see above). ImplicitChangeOp does an
573   // unrelated change but ops of the corresponding type are *not* on the initial
574   // worklist, so the effect of the second pattern is only visible if the
575   // tracking and subsequent adding to the worklist actually works.
576 
577   // Replace all usages of the first successor with the second successor.
578   class ChangeBlockOp : public RewritePattern {
579   public:
580     ChangeBlockOp(MLIRContext *context)
581         : RewritePattern("test.change_block_op", /*benefit=*/1, context) {}
582     LogicalResult matchAndRewrite(Operation *op,
583                                   PatternRewriter &rewriter) const override {
584       if (op->getNumSuccessors() < 2)
585         return failure();
586       Block *firstSuccessor = op->getSuccessor(0);
587       Block *secondSuccessor = op->getSuccessor(1);
588       if (firstSuccessor == secondSuccessor)
589         return failure();
590       // This is the function being tested:
591       rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor);
592       // Using the following line instead would make the test fail:
593       // firstSuccessor->replaceAllUsesWith(secondSuccessor);
594       return success();
595     }
596   };
597 
598   // Changes the successor to the parent block.
599   class ImplicitChangeOp : public RewritePattern {
600   public:
601     ImplicitChangeOp(MLIRContext *context)
602         : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {}
603     LogicalResult matchAndRewrite(Operation *op,
604                                   PatternRewriter &rewriter) const override {
605       if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock())
606         return failure();
607       rewriter.modifyOpInPlace(op,
608                                [&]() { op->setSuccessor(op->getBlock(), 0); });
609       return success();
610     }
611   };
612 };
613 
614 struct TestWalkPatternDriver final
615     : PassWrapper<TestWalkPatternDriver, OperationPass<>> {
616   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver)
617 
618   TestWalkPatternDriver() = default;
619   TestWalkPatternDriver(const TestWalkPatternDriver &other)
620       : PassWrapper(other) {}
621 
622   StringRef getArgument() const override {
623     return "test-walk-pattern-rewrite-driver";
624   }
625   StringRef getDescription() const override {
626     return "Run test walk pattern rewrite driver";
627   }
628   void runOnOperation() override {
629     mlir::RewritePatternSet patterns(&getContext());
630 
631     // Patterns for testing the WalkPatternRewriteDriver.
632     patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
633                  MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>(
634         &getContext());
635 
636     DumpNotifications dumpListener;
637     walkAndApplyPatterns(getOperation(), std::move(patterns),
638                          dumpNotifications ? &dumpListener : nullptr);
639   }
640 
641   Option<bool> dumpNotifications{
642       *this, "dump-notifications",
643       llvm::cl::desc("Print rewrite listener notifications"),
644       llvm::cl::init(false)};
645 };
646 
647 } // namespace
648 
649 //===----------------------------------------------------------------------===//
650 // ReturnType Driver.
651 //===----------------------------------------------------------------------===//
652 
653 namespace {
654 // Generate ops for each instance where the type can be successfully inferred.
655 template <typename OpTy>
656 static void invokeCreateWithInferredReturnType(Operation *op) {
657   auto *context = op->getContext();
658   auto fop = op->getParentOfType<func::FuncOp>();
659   auto location = UnknownLoc::get(context);
660   OpBuilder b(op);
661   b.setInsertionPointAfter(op);
662 
663   // Use permutations of 2 args as operands.
664   assert(fop.getNumArguments() >= 2);
665   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
666     for (int j = 0; j < e; ++j) {
667       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
668       SmallVector<Type, 2> inferredReturnTypes;
669       if (succeeded(OpTy::inferReturnTypes(
670               context, std::nullopt, values, op->getDiscardableAttrDictionary(),
671               op->getPropertiesStorage(), op->getRegions(),
672               inferredReturnTypes))) {
673         OperationState state(location, OpTy::getOperationName());
674         // TODO: Expand to regions.
675         OpTy::build(b, state, values, op->getAttrs());
676         (void)b.create(state);
677       }
678     }
679   }
680 }
681 
682 static void reifyReturnShape(Operation *op) {
683   OpBuilder b(op);
684 
685   // Use permutations of 2 args as operands.
686   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
687   SmallVector<Value, 2> shapes;
688   if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
689       !llvm::hasSingleElement(shapes))
690     return;
691   for (const auto &it : llvm::enumerate(shapes)) {
692     op->emitRemark() << "value " << it.index() << ": "
693                      << it.value().getDefiningOp();
694   }
695 }
696 
697 struct TestReturnTypeDriver
698     : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> {
699   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver)
700 
701   void getDependentDialects(DialectRegistry &registry) const override {
702     registry.insert<tensor::TensorDialect>();
703   }
704   StringRef getArgument() const final { return "test-return-type"; }
705   StringRef getDescription() const final { return "Run return type functions"; }
706 
707   void runOnOperation() override {
708     if (getOperation().getName() == "testCreateFunctions") {
709       std::vector<Operation *> ops;
710       // Collect ops to avoid triggering on inserted ops.
711       for (auto &op : getOperation().getBody().front())
712         ops.push_back(&op);
713       // Generate test patterns for each, but skip terminator.
714       for (auto *op : llvm::ArrayRef(ops).drop_back()) {
715         // Test create method of each of the Op classes below. The resultant
716         // output would be in reverse order underneath `op` from which
717         // the attributes and regions are used.
718         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
719         invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>(
720             op);
721         invokeCreateWithInferredReturnType<
722             OpWithShapedTypeInferTypeInterfaceOp>(op);
723       };
724       return;
725     }
726     if (getOperation().getName() == "testReifyFunctions") {
727       std::vector<Operation *> ops;
728       // Collect ops to avoid triggering on inserted ops.
729       for (auto &op : getOperation().getBody().front())
730         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
731           ops.push_back(&op);
732       // Generate test patterns for each, but skip terminator.
733       for (auto *op : ops)
734         reifyReturnShape(op);
735     }
736   }
737 };
738 } // namespace
739 
740 namespace {
741 struct TestDerivedAttributeDriver
742     : public PassWrapper<TestDerivedAttributeDriver,
743                          OperationPass<func::FuncOp>> {
744   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver)
745 
746   StringRef getArgument() const final { return "test-derived-attr"; }
747   StringRef getDescription() const final {
748     return "Run test derived attributes";
749   }
750   void runOnOperation() override;
751 };
752 } // namespace
753 
754 void TestDerivedAttributeDriver::runOnOperation() {
755   getOperation().walk([](DerivedAttributeOpInterface dOp) {
756     auto dAttr = dOp.materializeDerivedAttributes();
757     if (!dAttr)
758       return;
759     for (auto d : dAttr)
760       dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
761   });
762 }
763 
764 //===----------------------------------------------------------------------===//
765 // Legalization Driver.
766 //===----------------------------------------------------------------------===//
767 
768 namespace {
769 //===----------------------------------------------------------------------===//
770 // Region-Block Rewrite Testing
771 
772 /// This pattern applies a signature conversion to a block inside a detached
773 /// region.
774 struct TestDetachedSignatureConversion : public ConversionPattern {
775   TestDetachedSignatureConversion(MLIRContext *ctx)
776       : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1,
777                           ctx) {}
778 
779   LogicalResult
780   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
781                   ConversionPatternRewriter &rewriter) const final {
782     if (op->getNumRegions() != 1)
783       return failure();
784     OperationState state(op->getLoc(), "test.legal_op_with_region", operands,
785                          op->getResultTypes(), {}, BlockRange());
786     Region *newRegion = state.addRegion();
787     rewriter.inlineRegionBefore(op->getRegion(0), *newRegion,
788                                 newRegion->begin());
789     TypeConverter::SignatureConversion result(newRegion->getNumArguments());
790     for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i)
791       result.addInputs(i, rewriter.getF64Type());
792     rewriter.applySignatureConversion(&newRegion->front(), result);
793     Operation *newOp = rewriter.create(state);
794     rewriter.replaceOp(op, newOp->getResults());
795     return success();
796   }
797 };
798 
799 /// This pattern is a simple pattern that inlines the first region of a given
800 /// operation into the parent region.
801 struct TestRegionRewriteBlockMovement : public ConversionPattern {
802   TestRegionRewriteBlockMovement(MLIRContext *ctx)
803       : ConversionPattern("test.region", 1, ctx) {}
804 
805   LogicalResult
806   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
807                   ConversionPatternRewriter &rewriter) const final {
808     // Inline this region into the parent region.
809     auto &parentRegion = *op->getParentRegion();
810     auto &opRegion = op->getRegion(0);
811     if (op->getDiscardableAttr("legalizer.should_clone"))
812       rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
813     else
814       rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
815 
816     if (op->getDiscardableAttr("legalizer.erase_old_blocks")) {
817       while (!opRegion.empty())
818         rewriter.eraseBlock(&opRegion.front());
819     }
820 
821     // Drop this operation.
822     rewriter.eraseOp(op);
823     return success();
824   }
825 };
826 /// This pattern is a simple pattern that generates a region containing an
827 /// illegal operation.
828 struct TestRegionRewriteUndo : public RewritePattern {
829   TestRegionRewriteUndo(MLIRContext *ctx)
830       : RewritePattern("test.region_builder", 1, ctx) {}
831 
832   LogicalResult matchAndRewrite(Operation *op,
833                                 PatternRewriter &rewriter) const final {
834     // Create the region operation with an entry block containing arguments.
835     OperationState newRegion(op->getLoc(), "test.region");
836     newRegion.addRegion();
837     auto *regionOp = rewriter.create(newRegion);
838     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
839     entryBlock->addArgument(rewriter.getIntegerType(64),
840                             rewriter.getUnknownLoc());
841 
842     // Add an explicitly illegal operation to ensure the conversion fails.
843     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
844     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
845 
846     // Drop this operation.
847     rewriter.eraseOp(op);
848     return success();
849   }
850 };
851 /// A simple pattern that creates a block at the end of the parent region of the
852 /// matched operation.
853 struct TestCreateBlock : public RewritePattern {
854   TestCreateBlock(MLIRContext *ctx)
855       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
856 
857   LogicalResult matchAndRewrite(Operation *op,
858                                 PatternRewriter &rewriter) const final {
859     Region &region = *op->getParentRegion();
860     Type i32Type = rewriter.getIntegerType(32);
861     Location loc = op->getLoc();
862     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
863     rewriter.create<TerminatorOp>(loc);
864     rewriter.eraseOp(op);
865     return success();
866   }
867 };
868 
869 /// A simple pattern that creates a block containing an invalid operation in
870 /// order to trigger the block creation undo mechanism.
871 struct TestCreateIllegalBlock : public RewritePattern {
872   TestCreateIllegalBlock(MLIRContext *ctx)
873       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
874 
875   LogicalResult matchAndRewrite(Operation *op,
876                                 PatternRewriter &rewriter) const final {
877     Region &region = *op->getParentRegion();
878     Type i32Type = rewriter.getIntegerType(32);
879     Location loc = op->getLoc();
880     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
881     // Create an illegal op to ensure the conversion fails.
882     rewriter.create<ILLegalOpF>(loc, i32Type);
883     rewriter.create<TerminatorOp>(loc);
884     rewriter.eraseOp(op);
885     return success();
886   }
887 };
888 
889 /// A simple pattern that tests the undo mechanism when replacing the uses of a
890 /// block argument.
891 struct TestUndoBlockArgReplace : public ConversionPattern {
892   TestUndoBlockArgReplace(MLIRContext *ctx)
893       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
894 
895   LogicalResult
896   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
897                   ConversionPatternRewriter &rewriter) const final {
898     auto illegalOp =
899         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
900     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
901                                         illegalOp->getResult(0));
902     rewriter.modifyOpInPlace(op, [] {});
903     return success();
904   }
905 };
906 
907 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
908 /// This is to test the rollback logic.
909 struct TestUndoMoveOpBefore : public ConversionPattern {
910   TestUndoMoveOpBefore(MLIRContext *ctx)
911       : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
912 
913   LogicalResult
914   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
915                   ConversionPatternRewriter &rewriter) const override {
916     rewriter.moveOpBefore(op, op->getParentOp());
917     // Replace with an illegal op to ensure the conversion fails.
918     rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
919     return success();
920   }
921 };
922 
923 /// A rewrite pattern that tests the undo mechanism when erasing a block.
924 struct TestUndoBlockErase : public ConversionPattern {
925   TestUndoBlockErase(MLIRContext *ctx)
926       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
927 
928   LogicalResult
929   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
930                   ConversionPatternRewriter &rewriter) const final {
931     Block *secondBlock = &*std::next(op->getRegion(0).begin());
932     rewriter.setInsertionPointToStart(secondBlock);
933     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
934     rewriter.eraseBlock(secondBlock);
935     rewriter.modifyOpInPlace(op, [] {});
936     return success();
937   }
938 };
939 
940 /// A pattern that modifies a property in-place, but keeps the op illegal.
941 struct TestUndoPropertiesModification : public ConversionPattern {
942   TestUndoPropertiesModification(MLIRContext *ctx)
943       : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {}
944   LogicalResult
945   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
946                   ConversionPatternRewriter &rewriter) const final {
947     if (!op->hasAttr("modify_inplace"))
948       return failure();
949     rewriter.modifyOpInPlace(
950         op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); });
951     return success();
952   }
953 };
954 
955 //===----------------------------------------------------------------------===//
956 // Type-Conversion Rewrite Testing
957 
958 /// This patterns erases a region operation that has had a type conversion.
959 struct TestDropOpSignatureConversion : public ConversionPattern {
960   TestDropOpSignatureConversion(MLIRContext *ctx,
961                                 const TypeConverter &converter)
962       : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
963   LogicalResult
964   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
965                   ConversionPatternRewriter &rewriter) const override {
966     Region &region = op->getRegion(0);
967     Block *entry = &region.front();
968 
969     // Convert the original entry arguments.
970     const TypeConverter &converter = *getTypeConverter();
971     TypeConverter::SignatureConversion result(entry->getNumArguments());
972     if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
973                                               result)) ||
974         failed(rewriter.convertRegionTypes(&region, converter, &result)))
975       return failure();
976 
977     // Convert the region signature and just drop the operation.
978     rewriter.eraseOp(op);
979     return success();
980   }
981 };
982 /// This pattern simply updates the operands of the given operation.
983 struct TestPassthroughInvalidOp : public ConversionPattern {
984   TestPassthroughInvalidOp(MLIRContext *ctx)
985       : ConversionPattern("test.invalid", 1, ctx) {}
986   LogicalResult
987   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
988                   ConversionPatternRewriter &rewriter) const final {
989     rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, operands,
990                                              std::nullopt);
991     return success();
992   }
993 };
994 /// Replace with valid op, but simply drop the operands. This is used in a
995 /// regression where we used to generate circular unrealized_conversion_cast
996 /// ops.
997 struct TestDropAndReplaceInvalidOp : public ConversionPattern {
998   TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
999       : ConversionPattern(converter,
1000                           "test.drop_operands_and_replace_with_valid", 1, ctx) {
1001   }
1002   LogicalResult
1003   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1004                   ConversionPatternRewriter &rewriter) const final {
1005     rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(),
1006                                              std::nullopt);
1007     return success();
1008   }
1009 };
1010 /// This pattern handles the case of a split return value.
1011 struct TestSplitReturnType : public ConversionPattern {
1012   TestSplitReturnType(MLIRContext *ctx)
1013       : ConversionPattern("test.return", 1, ctx) {}
1014   LogicalResult
1015   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1016                   ConversionPatternRewriter &rewriter) const final {
1017     // Check for a return of F32.
1018     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
1019       return failure();
1020 
1021     // Check if the first operation is a cast operation, if it is we use the
1022     // results directly.
1023     auto *defOp = operands[0].getDefiningOp();
1024     if (auto packerOp =
1025             llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
1026       rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
1027       return success();
1028     }
1029 
1030     // Otherwise, fail to match.
1031     return failure();
1032   }
1033 };
1034 
1035 //===----------------------------------------------------------------------===//
1036 // Multi-Level Type-Conversion Rewrite Testing
1037 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
1038   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
1039       : ConversionPattern("test.type_producer", 1, ctx) {}
1040   LogicalResult
1041   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1042                   ConversionPatternRewriter &rewriter) const final {
1043     // If the type is I32, change the type to F32.
1044     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
1045       return failure();
1046     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
1047     return success();
1048   }
1049 };
1050 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
1051   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
1052       : ConversionPattern("test.type_producer", 1, ctx) {}
1053   LogicalResult
1054   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1055                   ConversionPatternRewriter &rewriter) const final {
1056     // If the type is F32, change the type to F64.
1057     if (!Type(*op->result_type_begin()).isF32())
1058       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
1059     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
1060     return success();
1061   }
1062 };
1063 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
1064   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
1065       : ConversionPattern("test.type_producer", 10, ctx) {}
1066   LogicalResult
1067   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1068                   ConversionPatternRewriter &rewriter) const final {
1069     // Always convert to B16, even though it is not a legal type. This tests
1070     // that values are unmapped correctly.
1071     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
1072     return success();
1073   }
1074 };
1075 struct TestUpdateConsumerType : public ConversionPattern {
1076   TestUpdateConsumerType(MLIRContext *ctx)
1077       : ConversionPattern("test.type_consumer", 1, ctx) {}
1078   LogicalResult
1079   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1080                   ConversionPatternRewriter &rewriter) const final {
1081     // Verify that the incoming operand has been successfully remapped to F64.
1082     if (!operands[0].getType().isF64())
1083       return failure();
1084     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
1085     return success();
1086   }
1087 };
1088 
1089 //===----------------------------------------------------------------------===//
1090 // Non-Root Replacement Rewrite Testing
1091 /// This pattern generates an invalid operation, but replaces it before the
1092 /// pattern is finished. This checks that we don't need to legalize the
1093 /// temporary op.
1094 struct TestNonRootReplacement : public RewritePattern {
1095   TestNonRootReplacement(MLIRContext *ctx)
1096       : RewritePattern("test.replace_non_root", 1, ctx) {}
1097 
1098   LogicalResult matchAndRewrite(Operation *op,
1099                                 PatternRewriter &rewriter) const final {
1100     auto resultType = *op->result_type_begin();
1101     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
1102     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
1103 
1104     rewriter.replaceOp(illegalOp, legalOp);
1105     rewriter.replaceOp(op, illegalOp);
1106     return success();
1107   }
1108 };
1109 
1110 //===----------------------------------------------------------------------===//
1111 // Recursive Rewrite Testing
1112 /// This pattern is applied to the same operation multiple times, but has a
1113 /// bounded recursion.
1114 struct TestBoundedRecursiveRewrite
1115     : public OpRewritePattern<TestRecursiveRewriteOp> {
1116   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
1117 
1118   void initialize() {
1119     // The conversion target handles bounding the recursion of this pattern.
1120     setHasBoundedRewriteRecursion();
1121   }
1122 
1123   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
1124                                 PatternRewriter &rewriter) const final {
1125     // Decrement the depth of the op in-place.
1126     rewriter.modifyOpInPlace(op, [&] {
1127       op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
1128     });
1129     return success();
1130   }
1131 };
1132 
1133 struct TestNestedOpCreationUndoRewrite
1134     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
1135   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
1136 
1137   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
1138                                 PatternRewriter &rewriter) const final {
1139     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1140     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1141     return success();
1142   };
1143 };
1144 
1145 // This pattern matches `test.blackhole` and delete this op and its producer.
1146 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
1147   using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
1148 
1149   LogicalResult matchAndRewrite(BlackHoleOp op,
1150                                 PatternRewriter &rewriter) const final {
1151     Operation *producer = op.getOperand().getDefiningOp();
1152     // Always erase the user before the producer, the framework should handle
1153     // this correctly.
1154     rewriter.eraseOp(op);
1155     rewriter.eraseOp(producer);
1156     return success();
1157   };
1158 };
1159 
1160 // This pattern replaces explicitly illegal op with explicitly legal op,
1161 // but in addition creates unregistered operation.
1162 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
1163   using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
1164 
1165   LogicalResult matchAndRewrite(ILLegalOpG op,
1166                                 PatternRewriter &rewriter) const final {
1167     IntegerAttr attr = rewriter.getI32IntegerAttr(0);
1168     Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
1169     rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
1170     return success();
1171   };
1172 };
1173 
1174 class TestEraseOp : public ConversionPattern {
1175 public:
1176   TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {}
1177   LogicalResult
1178   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1179                   ConversionPatternRewriter &rewriter) const final {
1180     // Erase op without replacements.
1181     rewriter.eraseOp(op);
1182     return success();
1183   }
1184 };
1185 
1186 } // namespace
1187 
1188 namespace {
1189 struct TestTypeConverter : public TypeConverter {
1190   using TypeConverter::TypeConverter;
1191   TestTypeConverter() {
1192     addConversion(convertType);
1193     addArgumentMaterialization(materializeCast);
1194     addSourceMaterialization(materializeCast);
1195   }
1196 
1197   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
1198     // Drop I16 types.
1199     if (t.isSignlessInteger(16))
1200       return success();
1201 
1202     // Convert I64 to F64.
1203     if (t.isSignlessInteger(64)) {
1204       results.push_back(FloatType::getF64(t.getContext()));
1205       return success();
1206     }
1207 
1208     // Convert I42 to I43.
1209     if (t.isInteger(42)) {
1210       results.push_back(IntegerType::get(t.getContext(), 43));
1211       return success();
1212     }
1213 
1214     // Split F32 into F16,F16.
1215     if (t.isF32()) {
1216       results.assign(2, FloatType::getF16(t.getContext()));
1217       return success();
1218     }
1219 
1220     // Otherwise, convert the type directly.
1221     results.push_back(t);
1222     return success();
1223   }
1224 
1225   /// Hook for materializing a conversion. This is necessary because we generate
1226   /// 1->N type mappings.
1227   static Value materializeCast(OpBuilder &builder, Type resultType,
1228                                ValueRange inputs, Location loc) {
1229     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1230   }
1231 };
1232 
1233 struct TestLegalizePatternDriver
1234     : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> {
1235   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)
1236 
1237   StringRef getArgument() const final { return "test-legalize-patterns"; }
1238   StringRef getDescription() const final {
1239     return "Run test dialect legalization patterns";
1240   }
1241   /// The mode of conversion to use with the driver.
1242   enum class ConversionMode { Analysis, Full, Partial };
1243 
1244   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
1245 
1246   void getDependentDialects(DialectRegistry &registry) const override {
1247     registry.insert<func::FuncDialect, test::TestDialect>();
1248   }
1249 
1250   void runOnOperation() override {
1251     TestTypeConverter converter;
1252     mlir::RewritePatternSet patterns(&getContext());
1253     populateWithGenerated(patterns);
1254     patterns.add<
1255         TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1256         TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1257         TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
1258         TestSplitReturnType, TestChangeProducerTypeI32ToF32,
1259         TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
1260         TestUpdateConsumerType, TestNonRootReplacement,
1261         TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
1262         TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1263         TestUndoPropertiesModification, TestEraseOp>(&getContext());
1264     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
1265         &getContext(), converter);
1266     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1267                                                               converter);
1268     mlir::populateCallOpTypeConversionPattern(patterns, converter);
1269 
1270     // Define the conversion target used for the test.
1271     ConversionTarget target(getContext());
1272     target.addLegalOp<ModuleOp>();
1273     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
1274                       TerminatorOp, OneRegionOp>();
1275     target.addLegalOp(
1276         OperationName("test.legal_op_with_region", &getContext()));
1277     target
1278         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
1279     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
1280       // Don't allow F32 operands.
1281       return llvm::none_of(op.getOperandTypes(),
1282                            [](Type type) { return type.isF32(); });
1283     });
1284     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1285       return converter.isSignatureLegal(op.getFunctionType()) &&
1286              converter.isLegal(&op.getBody());
1287     });
1288     target.addDynamicallyLegalOp<func::CallOp>(
1289         [&](func::CallOp op) { return converter.isLegal(op); });
1290 
1291     // TestCreateUnregisteredOp creates `arith.constant` operation,
1292     // which was not added to target intentionally to test
1293     // correct error code from conversion driver.
1294     target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });
1295 
1296     // Expect the type_producer/type_consumer operations to only operate on f64.
1297     target.addDynamicallyLegalOp<TestTypeProducerOp>(
1298         [](TestTypeProducerOp op) { return op.getType().isF64(); });
1299     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1300       return op.getOperand().getType().isF64();
1301     });
1302 
1303     // Check support for marking certain operations as recursively legal.
1304     target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) {
1305       return static_cast<bool>(
1306           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
1307     });
1308 
1309     // Mark the bound recursion operation as dynamically legal.
1310     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
1311         [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
1312 
1313     // Create a dynamically legal rule that can only be legalized by folding it.
1314     target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
1315         [](TestOpInPlaceSelfFold op) { return op.getFolded(); });
1316 
1317     // Handle a partial conversion.
1318     if (mode == ConversionMode::Partial) {
1319       DenseSet<Operation *> unlegalizedOps;
1320       ConversionConfig config;
1321       DumpNotifications dumpNotifications;
1322       config.listener = &dumpNotifications;
1323       config.unlegalizedOps = &unlegalizedOps;
1324       if (failed(applyPartialConversion(getOperation(), target,
1325                                         std::move(patterns), config))) {
1326         getOperation()->emitRemark() << "applyPartialConversion failed";
1327       }
1328       // Emit remarks for each legalizable operation.
1329       for (auto *op : unlegalizedOps)
1330         op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1331       return;
1332     }
1333 
1334     // Handle a full conversion.
1335     if (mode == ConversionMode::Full) {
1336       // Check support for marking unknown operations as dynamically legal.
1337       target.markUnknownOpDynamicallyLegal([](Operation *op) {
1338         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
1339       });
1340 
1341       ConversionConfig config;
1342       DumpNotifications dumpNotifications;
1343       config.listener = &dumpNotifications;
1344       if (failed(applyFullConversion(getOperation(), target,
1345                                      std::move(patterns), config))) {
1346         getOperation()->emitRemark() << "applyFullConversion failed";
1347       }
1348       return;
1349     }
1350 
1351     // Otherwise, handle an analysis conversion.
1352     assert(mode == ConversionMode::Analysis);
1353 
1354     // Analyze the convertible operations.
1355     DenseSet<Operation *> legalizedOps;
1356     ConversionConfig config;
1357     config.legalizableOps = &legalizedOps;
1358     if (failed(applyAnalysisConversion(getOperation(), target,
1359                                        std::move(patterns), config)))
1360       return signalPassFailure();
1361 
1362     // Emit remarks for each legalizable operation.
1363     for (auto *op : legalizedOps)
1364       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
1365   }
1366 
1367   /// The mode of conversion to use.
1368   ConversionMode mode;
1369 };
1370 } // namespace
1371 
1372 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
1373     legalizerConversionMode(
1374         "test-legalize-mode",
1375         llvm::cl::desc("The legalization mode to use with the test driver"),
1376         llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
1377         llvm::cl::values(
1378             clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
1379                        "analysis", "Perform an analysis conversion"),
1380             clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
1381                        "Perform a full conversion"),
1382             clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
1383                        "partial", "Perform a partial conversion")));
1384 
1385 //===----------------------------------------------------------------------===//
1386 // ConversionPatternRewriter::getRemappedValue testing. This method is used
1387 // to get the remapped value of an original value that was replaced using
1388 // ConversionPatternRewriter.
1389 namespace {
/// Type converter handed to TestRemapValueInRegion. It registers an
/// f32 -> f64 rule followed by an identity rule for every type.
/// NOTE(review): TypeConverter tries the most recently added conversion
/// first. The identity rule registered last therefore appears to shadow the
/// f32 -> f64 rule, so every type converts to itself. This is left unchanged
/// because the lit tests may depend on it; confirm whether it is intended.
struct TestRemapValueTypeConverter : public TypeConverter {
  using TypeConverter::TypeConverter;

  TestRemapValueTypeConverter() {
    // f32 -> f64.
    addConversion(
        [](Float32Type type) { return Float64Type::get(type.getContext()); });
    // Identity for any type (registered last, so tried first).
    addConversion([](Type type) { return type; });
  }
};
1399 
1400 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
1401 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
1402 /// operand twice.
1403 ///
1404 /// Example:
1405 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
1406 /// is replaced with:
1407 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
1408 struct OneVResOneVOperandOp1Converter
1409     : public OpConversionPattern<OneVResOneVOperandOp1> {
1410   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
1411 
1412   LogicalResult
1413   matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
1414                   ConversionPatternRewriter &rewriter) const override {
1415     auto origOps = op.getOperands();
1416     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
1417            "One operand expected");
1418     Value origOp = *origOps.begin();
1419     SmallVector<Value, 2> remappedOperands;
1420     // Replicate the remapped original operand twice. Note that we don't used
1421     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
1422     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1423     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1424 
1425     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
1426                                                        remappedOperands);
1427     return success();
1428   }
1429 };
1430 
1431 /// A rewriter pattern that tests that blocks can be merged.
1432 struct TestRemapValueInRegion
1433     : public OpConversionPattern<TestRemappedValueRegionOp> {
1434   using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;
1435 
1436   LogicalResult
1437   matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
1438                   ConversionPatternRewriter &rewriter) const final {
1439     Block &block = op.getBody().front();
1440     Operation *terminator = block.getTerminator();
1441 
1442     // Merge the block into the parent region.
1443     Block *parentBlock = op->getBlock();
1444     Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
1445     rewriter.mergeBlocks(&block, parentBlock, ValueRange());
1446     rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());
1447 
1448     // Replace the results of this operation with the remapped terminator
1449     // values.
1450     SmallVector<Value> terminatorOperands;
1451     if (failed(rewriter.getRemappedValues(terminator->getOperands(),
1452                                           terminatorOperands)))
1453       return failure();
1454 
1455     rewriter.eraseOp(terminator);
1456     rewriter.replaceOp(op, terminatorOperands);
1457     return success();
1458   }
1459 };
1460 
1461 struct TestRemappedValue
1462     : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> {
1463   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue)
1464 
1465   StringRef getArgument() const final { return "test-remapped-value"; }
1466   StringRef getDescription() const final {
1467     return "Test public remapped value mechanism in ConversionPatternRewriter";
1468   }
1469   void runOnOperation() override {
1470     TestRemapValueTypeConverter typeConverter;
1471 
1472     mlir::RewritePatternSet patterns(&getContext());
1473     patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
1474     patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
1475         &getContext());
1476     patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
1477 
1478     mlir::ConversionTarget target(getContext());
1479     target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>();
1480 
1481     // Expect the type_producer/type_consumer operations to only operate on f64.
1482     target.addDynamicallyLegalOp<TestTypeProducerOp>(
1483         [](TestTypeProducerOp op) { return op.getType().isF64(); });
1484     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1485       return op.getOperand().getType().isF64();
1486     });
1487 
1488     // We make OneVResOneVOperandOp1 legal only when it has more that one
1489     // operand. This will trigger the conversion that will replace one-operand
1490     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
1491     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
1492         [](Operation *op) { return op->getNumOperands() > 1; });
1493 
1494     if (failed(mlir::applyFullConversion(getOperation(), target,
1495                                          std::move(patterns)))) {
1496       signalPassFailure();
1497     }
1498   }
1499 };
1500 } // namespace
1501 
1502 //===----------------------------------------------------------------------===//
1503 // Test patterns without a specific root operation kind
1504 //===----------------------------------------------------------------------===//
1505 
1506 namespace {
1507 /// This pattern matches and removes any operation in the test dialect.
1508 struct RemoveTestDialectOps : public RewritePattern {
1509   RemoveTestDialectOps(MLIRContext *context)
1510       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
1511 
1512   LogicalResult matchAndRewrite(Operation *op,
1513                                 PatternRewriter &rewriter) const override {
1514     if (!isa<TestDialect>(op->getDialect()))
1515       return failure();
1516     rewriter.eraseOp(op);
1517     return success();
1518   }
1519 };
1520 
1521 struct TestUnknownRootOpDriver
1522     : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> {
1523   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver)
1524 
1525   StringRef getArgument() const final {
1526     return "test-legalize-unknown-root-patterns";
1527   }
1528   StringRef getDescription() const final {
1529     return "Test public remapped value mechanism in ConversionPatternRewriter";
1530   }
1531   void runOnOperation() override {
1532     mlir::RewritePatternSet patterns(&getContext());
1533     patterns.add<RemoveTestDialectOps>(&getContext());
1534 
1535     mlir::ConversionTarget target(getContext());
1536     target.addIllegalDialect<TestDialect>();
1537     if (failed(applyPartialConversion(getOperation(), target,
1538                                       std::move(patterns))))
1539       signalPassFailure();
1540   }
1541 };
1542 } // namespace
1543 
1544 //===----------------------------------------------------------------------===//
1545 // Test patterns that uses operations and types defined at runtime
1546 //===----------------------------------------------------------------------===//
1547 
1548 namespace {
1549 /// This pattern matches dynamic operations 'test.one_operand_two_results' and
1550 /// replace them with dynamic operations 'test.generic_dynamic_op'.
1551 struct RewriteDynamicOp : public RewritePattern {
1552   RewriteDynamicOp(MLIRContext *context)
1553       : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
1554                        context) {}
1555 
1556   LogicalResult matchAndRewrite(Operation *op,
1557                                 PatternRewriter &rewriter) const override {
1558     assert(op->getName().getStringRef() ==
1559                "test.dynamic_one_operand_two_results" &&
1560            "rewrite pattern should only match operations with the right name");
1561 
1562     OperationState state(op->getLoc(), "test.dynamic_generic",
1563                          op->getOperands(), op->getResultTypes(),
1564                          op->getAttrs());
1565     auto *newOp = rewriter.create(state);
1566     rewriter.replaceOp(op, newOp->getResults());
1567     return success();
1568   }
1569 };
1570 
1571 struct TestRewriteDynamicOpDriver
1572     : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> {
1573   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
1574 
1575   void getDependentDialects(DialectRegistry &registry) const override {
1576     registry.insert<TestDialect>();
1577   }
1578   StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
1579   StringRef getDescription() const final {
1580     return "Test rewritting on dynamic operations";
1581   }
1582   void runOnOperation() override {
1583     RewritePatternSet patterns(&getContext());
1584     patterns.add<RewriteDynamicOp>(&getContext());
1585 
1586     ConversionTarget target(getContext());
1587     target.addIllegalOp(
1588         OperationName("test.dynamic_one_operand_two_results", &getContext()));
1589     target.addLegalOp(OperationName("test.dynamic_generic", &getContext()));
1590     if (failed(applyPartialConversion(getOperation(), target,
1591                                       std::move(patterns))))
1592       signalPassFailure();
1593   }
1594 };
1595 } // end anonymous namespace
1596 
1597 //===----------------------------------------------------------------------===//
1598 // Test type conversions
1599 //===----------------------------------------------------------------------===//
1600 
1601 namespace {
1602 struct TestTypeConversionProducer
1603     : public OpConversionPattern<TestTypeProducerOp> {
1604   using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
1605   LogicalResult
1606   matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
1607                   ConversionPatternRewriter &rewriter) const final {
1608     Type resultType = op.getType();
1609     Type convertedType = getTypeConverter()
1610                              ? getTypeConverter()->convertType(resultType)
1611                              : resultType;
1612     if (isa<FloatType>(resultType))
1613       resultType = rewriter.getF64Type();
1614     else if (resultType.isInteger(16))
1615       resultType = rewriter.getIntegerType(64);
1616     else if (isa<test::TestRecursiveType>(resultType) &&
1617              convertedType != resultType)
1618       resultType = convertedType;
1619     else
1620       return failure();
1621 
1622     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
1623     return success();
1624   }
1625 };
1626 
1627 /// Call signature conversion and then fail the rewrite to trigger the undo
1628 /// mechanism.
1629 struct TestSignatureConversionUndo
1630     : public OpConversionPattern<TestSignatureConversionUndoOp> {
1631   using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
1632 
1633   LogicalResult
1634   matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
1635                   ConversionPatternRewriter &rewriter) const final {
1636     (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
1637     return failure();
1638   }
1639 };
1640 
1641 /// Call signature conversion without providing a type converter to handle
1642 /// materializations.
1643 struct TestTestSignatureConversionNoConverter
1644     : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
1645   TestTestSignatureConversionNoConverter(const TypeConverter &converter,
1646                                          MLIRContext *context)
1647       : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
1648         converter(converter) {}
1649 
1650   LogicalResult
1651   matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
1652                   ConversionPatternRewriter &rewriter) const final {
1653     Region &region = op->getRegion(0);
1654     Block *entry = &region.front();
1655 
1656     // Convert the original entry arguments.
1657     TypeConverter::SignatureConversion result(entry->getNumArguments());
1658     if (failed(
1659             converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
1660       return failure();
1661     rewriter.modifyOpInPlace(op, [&] {
1662       rewriter.applySignatureConversion(&region.front(), result);
1663     });
1664     return success();
1665   }
1666 
1667   const TypeConverter &converter;
1668 };
1669 
1670 /// Just forward the operands to the root op. This is essentially a no-op
1671 /// pattern that is used to trigger target materialization.
1672 struct TestTypeConsumerForward
1673     : public OpConversionPattern<TestTypeConsumerOp> {
1674   using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
1675 
1676   LogicalResult
1677   matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
1678                   ConversionPatternRewriter &rewriter) const final {
1679     rewriter.modifyOpInPlace(op,
1680                              [&] { op->setOperands(adaptor.getOperands()); });
1681     return success();
1682   }
1683 };
1684 
1685 struct TestTypeConversionAnotherProducer
1686     : public OpRewritePattern<TestAnotherTypeProducerOp> {
1687   using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
1688 
1689   LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
1690                                 PatternRewriter &rewriter) const final {
1691     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
1692     return success();
1693   }
1694 };
1695 
1696 struct TestReplaceWithLegalOp : public ConversionPattern {
1697   TestReplaceWithLegalOp(MLIRContext *ctx)
1698       : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
1699   LogicalResult
1700   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1701                   ConversionPatternRewriter &rewriter) const final {
1702     rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
1703     return success();
1704   }
1705 };
1706 
1707 struct TestTypeConversionDriver
1708     : public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
1709   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
1710 
1711   void getDependentDialects(DialectRegistry &registry) const override {
1712     registry.insert<TestDialect>();
1713   }
1714   StringRef getArgument() const final {
1715     return "test-legalize-type-conversion";
1716   }
1717   StringRef getDescription() const final {
1718     return "Test various type conversion functionalities in DialectConversion";
1719   }
1720 
1721   void runOnOperation() override {
1722     // Initialize the type converter.
1723     SmallVector<Type, 2> conversionCallStack;
1724     TypeConverter converter;
1725 
1726     /// Add the legal set of type conversions.
1727     converter.addConversion([](Type type) -> Type {
1728       // Treat F64 as legal.
1729       if (type.isF64())
1730         return type;
1731       // Allow converting BF16/F16/F32 to F64.
1732       if (type.isBF16() || type.isF16() || type.isF32())
1733         return FloatType::getF64(type.getContext());
1734       // Otherwise, the type is illegal.
1735       return nullptr;
1736     });
1737     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
1738       // Drop all integer types.
1739       return success();
1740     });
1741     converter.addConversion(
1742         // Convert a recursive self-referring type into a non-self-referring
1743         // type named "outer_converted_type" that contains a SimpleAType.
1744         [&](test::TestRecursiveType type,
1745             SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
1746           // If the type is already converted, return it to indicate that it is
1747           // legal.
1748           if (type.getName() == "outer_converted_type") {
1749             results.push_back(type);
1750             return success();
1751           }
1752 
1753           conversionCallStack.push_back(type);
1754           auto popConversionCallStack = llvm::make_scope_exit(
1755               [&conversionCallStack]() { conversionCallStack.pop_back(); });
1756 
1757           // If the type is on the call stack more than once (it is there at
1758           // least once because of the _current_ call, which is always the last
1759           // element on the stack), we've hit the recursive case. Just return
1760           // SimpleAType here to create a non-recursive type as a result.
1761           if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(),
1762                                  type)) {
1763             results.push_back(test::SimpleAType::get(type.getContext()));
1764             return success();
1765           }
1766 
1767           // Convert the body recursively.
1768           auto result = test::TestRecursiveType::get(type.getContext(),
1769                                                      "outer_converted_type");
1770           if (failed(result.setBody(converter.convertType(type.getBody()))))
1771             return failure();
1772           results.push_back(result);
1773           return success();
1774         });
1775 
1776     /// Add the legal set of type materializations.
1777     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
1778                                           ValueRange inputs,
1779                                           Location loc) -> Value {
1780       // Allow casting from F64 back to F32.
1781       if (!resultType.isF16() && inputs.size() == 1 &&
1782           inputs[0].getType().isF64())
1783         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1784       // Allow producing an i32 or i64 from nothing.
1785       if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
1786           inputs.empty())
1787         return builder.create<TestTypeProducerOp>(loc, resultType);
1788       // Allow producing an i64 from an integer.
1789       if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
1790           isa<IntegerType>(inputs[0].getType()))
1791         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1792       // Otherwise, fail.
1793       return nullptr;
1794     });
1795 
1796     // Initialize the conversion target.
1797     mlir::ConversionTarget target(getContext());
1798     target.addLegalOp<LegalOpD>();
1799     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
1800       auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
1801       return op.getType().isF64() || op.getType().isInteger(64) ||
1802              (recursiveType &&
1803               recursiveType.getName() == "outer_converted_type");
1804     });
1805     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1806       return converter.isSignatureLegal(op.getFunctionType()) &&
1807              converter.isLegal(&op.getBody());
1808     });
1809     target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
1810       // Allow casts from F64 to F32.
1811       return (*op.operand_type_begin()).isF64() && op.getType().isF32();
1812     });
1813     target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
1814         [&](TestSignatureConversionNoConverterOp op) {
1815           return converter.isLegal(op.getRegion().front().getArgumentTypes());
1816         });
1817 
1818     // Initialize the set of rewrite patterns.
1819     RewritePatternSet patterns(&getContext());
1820     patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
1821                  TestSignatureConversionUndo,
1822                  TestTestSignatureConversionNoConverter>(converter,
1823                                                          &getContext());
1824     patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
1825         &getContext());
1826     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1827                                                               converter);
1828 
1829     if (failed(applyPartialConversion(getOperation(), target,
1830                                       std::move(patterns))))
1831       signalPassFailure();
1832   }
1833 };
1834 } // namespace
1835 
1836 //===----------------------------------------------------------------------===//
1837 // Test Target Materialization With No Uses
1838 //===----------------------------------------------------------------------===//
1839 
1840 namespace {
1841 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1842   using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1843 
1844   LogicalResult
1845   matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1846                   ConversionPatternRewriter &rewriter) const final {
1847     rewriter.replaceOp(op, adaptor.getOperands());
1848     return success();
1849   }
1850 };
1851 
1852 struct TestTargetMaterializationWithNoUses
1853     : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> {
1854   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1855       TestTargetMaterializationWithNoUses)
1856 
1857   StringRef getArgument() const final {
1858     return "test-target-materialization-with-no-uses";
1859   }
1860   StringRef getDescription() const final {
1861     return "Test a special case of target materialization in DialectConversion";
1862   }
1863 
1864   void runOnOperation() override {
1865     TypeConverter converter;
1866     converter.addConversion([](Type t) { return t; });
1867     converter.addConversion([](IntegerType intTy) -> Type {
1868       if (intTy.getWidth() == 16)
1869         return IntegerType::get(intTy.getContext(), 64);
1870       return intTy;
1871     });
1872     converter.addTargetMaterialization(
1873         [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1874           return builder.create<TestCastOp>(loc, type, inputs).getResult();
1875         });
1876 
1877     ConversionTarget target(getContext());
1878     target.addIllegalOp<TestTypeChangerOp>();
1879 
1880     RewritePatternSet patterns(&getContext());
1881     patterns.add<ForwardOperandPattern>(converter, &getContext());
1882 
1883     if (failed(applyPartialConversion(getOperation(), target,
1884                                       std::move(patterns))))
1885       signalPassFailure();
1886   }
1887 };
1888 } // namespace
1889 
1890 //===----------------------------------------------------------------------===//
1891 // Test Block Merging
1892 //===----------------------------------------------------------------------===//
1893 
1894 namespace {
1895 /// A rewriter pattern that tests that blocks can be merged.
1896 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
1897   using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
1898 
1899   LogicalResult
1900   matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
1901                   ConversionPatternRewriter &rewriter) const final {
1902     Block &firstBlock = op.getBody().front();
1903     Operation *branchOp = firstBlock.getTerminator();
1904     Block *secondBlock = &*(std::next(op.getBody().begin()));
1905     auto succOperands = branchOp->getOperands();
1906     SmallVector<Value, 2> replacements(succOperands);
1907     rewriter.eraseOp(branchOp);
1908     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1909     rewriter.modifyOpInPlace(op, [] {});
1910     return success();
1911   }
1912 };
1913 
1914 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
1915 struct TestUndoBlocksMerge : public ConversionPattern {
1916   TestUndoBlocksMerge(MLIRContext *ctx)
1917       : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
1918   LogicalResult
1919   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1920                   ConversionPatternRewriter &rewriter) const final {
1921     Block &firstBlock = op->getRegion(0).front();
1922     Operation *branchOp = firstBlock.getTerminator();
1923     Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
1924     rewriter.setInsertionPointToStart(secondBlock);
1925     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
1926     auto succOperands = branchOp->getOperands();
1927     SmallVector<Value, 2> replacements(succOperands);
1928     rewriter.eraseOp(branchOp);
1929     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1930     rewriter.modifyOpInPlace(op, [] {});
1931     return success();
1932   }
1933 };
1934 
1935 /// A rewrite mechanism to inline the body of the op into its parent, when both
1936 /// ops can have a single block.
1937 struct TestMergeSingleBlockOps
1938     : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
1939   using OpConversionPattern<
1940       SingleBlockImplicitTerminatorOp>::OpConversionPattern;
1941 
1942   LogicalResult
1943   matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
1944                   ConversionPatternRewriter &rewriter) const final {
1945     SingleBlockImplicitTerminatorOp parentOp =
1946         op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1947     if (!parentOp)
1948       return failure();
1949     Block &innerBlock = op.getRegion().front();
1950     TerminatorOp innerTerminator =
1951         cast<TerminatorOp>(innerBlock.getTerminator());
1952     rewriter.inlineBlockBefore(&innerBlock, op);
1953     rewriter.eraseOp(innerTerminator);
1954     rewriter.eraseOp(op);
1955     return success();
1956   }
1957 };
1958 
1959 struct TestMergeBlocksPatternDriver
1960     : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> {
1961   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver)
1962 
1963   StringRef getArgument() const final { return "test-merge-blocks"; }
1964   StringRef getDescription() const final {
1965     return "Test Merging operation in ConversionPatternRewriter";
1966   }
1967   void runOnOperation() override {
1968     MLIRContext *context = &getContext();
1969     mlir::RewritePatternSet patterns(context);
1970     patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
1971         context);
1972     ConversionTarget target(*context);
1973     target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
1974                       TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
1975     target.addIllegalOp<ILLegalOpF>();
1976 
1977     /// Expect the op to have a single block after legalization.
1978     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
1979         [&](TestMergeBlocksOp op) -> bool {
1980           return llvm::hasSingleElement(op.getBody());
1981         });
1982 
1983     /// Only allow `test.br` within test.merge_blocks op.
1984     target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
1985       return op->getParentOfType<TestMergeBlocksOp>();
1986     });
1987 
1988     /// Expect that all nested test.SingleBlockImplicitTerminator ops are
1989     /// inlined.
1990     target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
1991         [&](SingleBlockImplicitTerminatorOp op) -> bool {
1992           return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1993         });
1994 
1995     DenseSet<Operation *> unlegalizedOps;
1996     ConversionConfig config;
1997     config.unlegalizedOps = &unlegalizedOps;
1998     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
1999                                  config);
2000     for (auto *op : unlegalizedOps)
2001       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
2002   }
2003 };
2004 } // namespace
2005 
2006 //===----------------------------------------------------------------------===//
2007 // Test Selective Replacement
2008 //===----------------------------------------------------------------------===//
2009 
2010 namespace {
2011 /// A rewrite mechanism to inline the body of the op into its parent, when both
2012 /// ops can have a single block.
2013 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
2014   using OpRewritePattern<TestCastOp>::OpRewritePattern;
2015 
2016   LogicalResult matchAndRewrite(TestCastOp op,
2017                                 PatternRewriter &rewriter) const final {
2018     if (op.getNumOperands() != 2)
2019       return failure();
2020     OperandRange operands = op.getOperands();
2021 
2022     // Replace non-terminator uses with the first operand.
2023     rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
2024       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
2025     });
2026     // Replace everything else with the second operand if the operation isn't
2027     // dead.
2028     rewriter.replaceOp(op, op.getOperand(1));
2029     return success();
2030   }
2031 };
2032 
2033 struct TestSelectiveReplacementPatternDriver
2034     : public PassWrapper<TestSelectiveReplacementPatternDriver,
2035                          OperationPass<>> {
2036   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
2037       TestSelectiveReplacementPatternDriver)
2038 
2039   StringRef getArgument() const final {
2040     return "test-pattern-selective-replacement";
2041   }
2042   StringRef getDescription() const final {
2043     return "Test selective replacement in the PatternRewriter";
2044   }
2045   void runOnOperation() override {
2046     MLIRContext *context = &getContext();
2047     mlir::RewritePatternSet patterns(context);
2048     patterns.add<TestSelectiveOpReplacementPattern>(context);
2049     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
2050   }
2051 };
2052 } // namespace
2053 
2054 //===----------------------------------------------------------------------===//
2055 // PassRegistration
2056 //===----------------------------------------------------------------------===//
2057 
2058 namespace mlir {
2059 namespace test {
2060 void registerPatternsTestPass() {
2061   PassRegistration<TestReturnTypeDriver>();
2062 
2063   PassRegistration<TestDerivedAttributeDriver>();
2064 
2065   PassRegistration<TestGreedyPatternDriver>();
2066   PassRegistration<TestStrictPatternDriver>();
2067   PassRegistration<TestWalkPatternDriver>();
2068 
2069   PassRegistration<TestLegalizePatternDriver>([] {
2070     return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
2071   });
2072 
2073   PassRegistration<TestRemappedValue>();
2074 
2075   PassRegistration<TestUnknownRootOpDriver>();
2076 
2077   PassRegistration<TestTypeConversionDriver>();
2078   PassRegistration<TestTargetMaterializationWithNoUses>();
2079 
2080   PassRegistration<TestRewriteDynamicOpDriver>();
2081 
2082   PassRegistration<TestMergeBlocksPatternDriver>();
2083   PassRegistration<TestSelectiveReplacementPatternDriver>();
2084 }
2085 } // namespace test
2086 } // namespace mlir
2087