xref: /llvm-project/mlir/test/lib/Dialect/Test/TestPatterns.cpp (revision 9df63b2651b2435c02a7d825953ca2ddc65c778e)
1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "TestTypes.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BuiltinAttributes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Visitors.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include <cstdint>
26 
27 using namespace mlir;
28 using namespace test;
29 
30 // Native function for testing NativeCodeCall
31 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
32   return choice.getValue() ? input1 : input2;
33 }
34 
35 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
36   rewriter.create<OpI>(loc, input);
37 }
38 
39 static void handleNoResultOp(PatternRewriter &rewriter,
40                              OpSymbolBindingNoResult op) {
41   // Turn the no result op to a one-result op.
42   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
43                                     op.getOperand());
44 }
45 
46 static bool getFirstI32Result(Operation *op, Value &value) {
47   if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
48     return false;
49   value = op->getResult(0);
50   return true;
51 }
52 
53 static Value bindNativeCodeCallResult(Value value) { return value; }
54 
55 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
56                                                               Value input2) {
57   return SmallVector<Value, 2>({input2, input1});
58 }
59 
60 // Test that natives calls are only called once during rewrites.
61 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
62 // This let us check the number of times OpM_Test was called by inspecting
63 // the returned value in the MLIR output.
64 static int64_t opMIncreasingValue = 314159265;
65 static Attribute opMTest(PatternRewriter &rewriter, Value val) {
66   int64_t i = opMIncreasingValue++;
67   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
68 }
69 
70 namespace {
71 #include "TestPatterns.inc"
72 } // namespace
73 
74 //===----------------------------------------------------------------------===//
75 // Test Reduce Pattern Interface
76 //===----------------------------------------------------------------------===//
77 
/// Adds every pattern generated from the TableGen'd TestPatterns.inc (via
/// `populateWithGenerated`) to `patterns`. This is the entry point of the
/// "Test Reduce Pattern Interface" section.
void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
  populateWithGenerated(patterns);
}
81 
82 //===----------------------------------------------------------------------===//
83 // Canonicalizer Driver.
84 //===----------------------------------------------------------------------===//
85 
86 namespace {
87 struct FoldingPattern : public RewritePattern {
88 public:
89   FoldingPattern(MLIRContext *context)
90       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
91                        /*benefit=*/1, context) {}
92 
93   LogicalResult matchAndRewrite(Operation *op,
94                                 PatternRewriter &rewriter) const override {
95     // Exercise createOrFold API for a single-result operation that is folded
96     // upon construction. The operation being created has an in-place folder,
97     // and it should be still present in the output. Furthermore, the folder
98     // should not crash when attempting to recover the (unchanged) operation
99     // result.
100     Value result = rewriter.createOrFold<TestOpInPlaceFold>(
101         op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0));
102     assert(result);
103     rewriter.replaceOp(op, result);
104     return success();
105   }
106 };
107 
108 /// This pattern creates a foldable operation at the entry point of the block.
109 /// This tests the situation where the operation folder will need to replace an
110 /// operation with a previously created constant that does not initially
111 /// dominate the operation to replace.
112 struct FolderInsertBeforePreviouslyFoldedConstantPattern
113     : public OpRewritePattern<TestCastOp> {
114 public:
115   using OpRewritePattern<TestCastOp>::OpRewritePattern;
116 
117   LogicalResult matchAndRewrite(TestCastOp op,
118                                 PatternRewriter &rewriter) const override {
119     if (!op->hasAttr("test_fold_before_previously_folded_op"))
120       return failure();
121     rewriter.setInsertionPointToStart(op->getBlock());
122 
123     auto constOp = rewriter.create<arith::ConstantOp>(
124         op.getLoc(), rewriter.getBoolAttr(true));
125     rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
126                                             Value(constOp));
127     return success();
128   }
129 };
130 
131 /// This pattern matches test.op_commutative2 with the first operand being
132 /// another test.op_commutative2 with a constant on the right side and fold it
133 /// away by propagating it as its result. This is intend to check that patterns
134 /// are applied after the commutative property moves constant to the right.
135 struct FolderCommutativeOp2WithConstant
136     : public OpRewritePattern<TestCommutative2Op> {
137 public:
138   using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
139 
140   LogicalResult matchAndRewrite(TestCommutative2Op op,
141                                 PatternRewriter &rewriter) const override {
142     auto operand =
143         dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
144     if (!operand)
145       return failure();
146     Attribute constInput;
147     if (!matchPattern(operand->getOperand(1), m_Constant(&constInput)))
148       return failure();
149     rewriter.replaceOp(op, operand->getOperand(1));
150     return success();
151   }
152 };
153 
154 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
155 /// attribute with value smaller than MaxVal, it increments the value by 1.
156 template <int MaxVal>
157 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
158   using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern;
159 
160   LogicalResult matchAndRewrite(AnyAttrOfOp op,
161                                 PatternRewriter &rewriter) const override {
162     auto intAttr = dyn_cast<IntegerAttr>(op.getAttr());
163     if (!intAttr)
164       return failure();
165     int64_t val = intAttr.getInt();
166     if (val >= MaxVal)
167       return failure();
168     rewriter.modifyOpInPlace(
169         op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
170     return success();
171   }
172 };
173 
174 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op".
175 struct MakeOpEligible : public RewritePattern {
176   MakeOpEligible(MLIRContext *context)
177       : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {}
178 
179   LogicalResult matchAndRewrite(Operation *op,
180                                 PatternRewriter &rewriter) const override {
181     if (op->hasAttr("eligible"))
182       return failure();
183     rewriter.modifyOpInPlace(
184         op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); });
185     return success();
186   }
187 };
188 
189 /// This pattern hoists eligible ops out of a "test.one_region_op".
190 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
191   using OpRewritePattern<test::OneRegionOp>::OpRewritePattern;
192 
193   LogicalResult matchAndRewrite(test::OneRegionOp op,
194                                 PatternRewriter &rewriter) const override {
195     Operation *terminator = op.getRegion().front().getTerminator();
196     Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp();
197     if (toBeHoisted->getParentOp() != op)
198       return failure();
199     if (!toBeHoisted->hasAttr("eligible"))
200       return failure();
201     rewriter.moveOpBefore(toBeHoisted, op);
202     return success();
203   }
204 };
205 
206 /// This pattern moves "test.move_before_parent_op" before the parent op.
207 struct MoveBeforeParentOp : public RewritePattern {
208   MoveBeforeParentOp(MLIRContext *context)
209       : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
210 
211   LogicalResult matchAndRewrite(Operation *op,
212                                 PatternRewriter &rewriter) const override {
213     // Do not hoist past functions.
214     if (isa<FunctionOpInterface>(op->getParentOp()))
215       return failure();
216     rewriter.moveOpBefore(op, op->getParentOp());
217     return success();
218   }
219 };
220 
221 /// This pattern moves "test.move_after_parent_op" after the parent op.
222 struct MoveAfterParentOp : public RewritePattern {
223   MoveAfterParentOp(MLIRContext *context)
224       : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {}
225 
226   LogicalResult matchAndRewrite(Operation *op,
227                                 PatternRewriter &rewriter) const override {
228     // Do not hoist past functions.
229     if (isa<FunctionOpInterface>(op->getParentOp()))
230       return failure();
231 
232     int64_t moveForwardBy = 0;
233     if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance"))
234       moveForwardBy = advanceBy.getInt();
235 
236     Operation *moveAfter = op->getParentOp();
237     for (int64_t i = 0; i < moveForwardBy; ++i)
238       moveAfter = moveAfter->getNextNode();
239 
240     rewriter.moveOpAfter(op, moveAfter);
241     return success();
242   }
243 };
244 
245 /// This pattern inlines blocks that are nested in
246 /// "test.inline_blocks_into_parent" into the parent block.
247 struct InlineBlocksIntoParent : public RewritePattern {
248   InlineBlocksIntoParent(MLIRContext *context)
249       : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
250                        context) {}
251 
252   LogicalResult matchAndRewrite(Operation *op,
253                                 PatternRewriter &rewriter) const override {
254     bool changed = false;
255     for (Region &r : op->getRegions()) {
256       while (!r.empty()) {
257         rewriter.inlineBlockBefore(&r.front(), op);
258         changed = true;
259       }
260     }
261     return success(changed);
262   }
263 };
264 
265 /// This pattern splits blocks at "test.split_block_here" and replaces the op
266 /// with a new op (to prevent an infinite loop of block splitting).
267 struct SplitBlockHere : public RewritePattern {
268   SplitBlockHere(MLIRContext *context)
269       : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
270 
271   LogicalResult matchAndRewrite(Operation *op,
272                                 PatternRewriter &rewriter) const override {
273     rewriter.splitBlock(op->getBlock(), op->getIterator());
274     Operation *newOp = rewriter.create(
275         op->getLoc(),
276         OperationName("test.new_op", op->getContext()).getIdentifier(),
277         op->getOperands(), op->getResultTypes());
278     rewriter.replaceOp(op, newOp);
279     return success();
280   }
281 };
282 
283 /// This pattern clones "test.clone_me" ops.
284 struct CloneOp : public RewritePattern {
285   CloneOp(MLIRContext *context)
286       : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
287 
288   LogicalResult matchAndRewrite(Operation *op,
289                                 PatternRewriter &rewriter) const override {
290     // Do not clone already cloned ops to avoid going into an infinite loop.
291     if (op->hasAttr("was_cloned"))
292       return failure();
293     Operation *cloned = rewriter.clone(*op);
294     cloned->setAttr("was_cloned", rewriter.getUnitAttr());
295     return success();
296   }
297 };
298 
299 /// This pattern clones regions of "test.clone_region_before" ops before the
300 /// parent block.
301 struct CloneRegionBeforeOp : public RewritePattern {
302   CloneRegionBeforeOp(MLIRContext *context)
303       : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
304 
305   LogicalResult matchAndRewrite(Operation *op,
306                                 PatternRewriter &rewriter) const override {
307     // Do not clone already cloned ops to avoid going into an infinite loop.
308     if (op->hasAttr("was_cloned"))
309       return failure();
310     for (Region &r : op->getRegions())
311       rewriter.cloneRegionBefore(r, op->getBlock());
312     op->setAttr("was_cloned", rewriter.getUnitAttr());
313     return success();
314   }
315 };
316 
317 /// Replace an operation may introduce the re-visiting of its users.
318 class ReplaceWithNewOp : public RewritePattern {
319 public:
320   ReplaceWithNewOp(MLIRContext *context)
321       : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
322 
323   LogicalResult matchAndRewrite(Operation *op,
324                                 PatternRewriter &rewriter) const override {
325     Operation *newOp;
326     if (op->hasAttr("create_erase_op")) {
327       newOp = rewriter.create(
328           op->getLoc(),
329           OperationName("test.erase_op", op->getContext()).getIdentifier(),
330           ValueRange(), TypeRange());
331     } else {
332       newOp = rewriter.create(
333           op->getLoc(),
334           OperationName("test.new_op", op->getContext()).getIdentifier(),
335           op->getOperands(), op->getResultTypes());
336     }
337     // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
338     // A "notifyOperationReplaced" callback is triggered in either case.
339     rewriter.replaceAllOpUsesWith(op, newOp->getResults());
340     rewriter.eraseOp(op);
341     return success();
342   }
343 };
344 
345 /// Erases the first child block of the matched "test.erase_first_block"
346 /// operation.
347 class EraseFirstBlock : public RewritePattern {
348 public:
349   EraseFirstBlock(MLIRContext *context)
350       : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {}
351 
352   LogicalResult matchAndRewrite(Operation *op,
353                                 PatternRewriter &rewriter) const override {
354     for (Region &r : op->getRegions()) {
355       for (Block &b : r.getBlocks()) {
356         rewriter.eraseBlock(&b);
357         return success();
358       }
359     }
360 
361     return failure();
362   }
363 };
364 
365 struct TestGreedyPatternDriver
366     : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
367   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
368 
369   TestGreedyPatternDriver() = default;
370   TestGreedyPatternDriver(const TestGreedyPatternDriver &other)
371       : PassWrapper(other) {}
372 
373   StringRef getArgument() const final { return "test-greedy-patterns"; }
374   StringRef getDescription() const final { return "Run test dialect patterns"; }
375   void runOnOperation() override {
376     mlir::RewritePatternSet patterns(&getContext());
377     populateWithGenerated(patterns);
378 
379     // Verify named pattern is generated with expected name.
380     patterns.add<FoldingPattern, TestNamedPatternRule,
381                  FolderInsertBeforePreviouslyFoldedConstantPattern,
382                  FolderCommutativeOp2WithConstant, HoistEligibleOps,
383                  MakeOpEligible>(&getContext());
384 
385     // Additional patterns for testing the GreedyPatternRewriteDriver.
386     patterns.insert<IncrementIntAttribute<3>>(&getContext());
387 
388     GreedyRewriteConfig config;
389     config.useTopDownTraversal = this->useTopDownTraversal;
390     config.maxIterations = this->maxIterations;
391     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
392                                        config);
393   }
394 
395   Option<bool> useTopDownTraversal{
396       *this, "top-down",
397       llvm::cl::desc("Seed the worklist in general top-down order"),
398       llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
399   Option<int> maxIterations{
400       *this, "max-iterations",
401       llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
402       llvm::cl::init(GreedyRewriteConfig().maxIterations)};
403 };
404 
405 struct DumpNotifications : public RewriterBase::Listener {
406   void notifyBlockInserted(Block *block, Region *previous,
407                            Region::iterator previousIt) override {
408     llvm::outs() << "notifyBlockInserted";
409     if (block->getParentOp()) {
410       llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
411     } else {
412       llvm::outs() << " into unknown op: ";
413     }
414     if (previous == nullptr) {
415       llvm::outs() << "was unlinked\n";
416     } else {
417       llvm::outs() << "was linked\n";
418     }
419   }
420   void notifyOperationInserted(Operation *op,
421                                OpBuilder::InsertPoint previous) override {
422     llvm::outs() << "notifyOperationInserted: " << op->getName();
423     if (!previous.isSet()) {
424       llvm::outs() << ", was unlinked\n";
425     } else {
426       if (!previous.getPoint().getNodePtr()) {
427         llvm::outs() << ", was linked, exact position unknown\n";
428       } else if (previous.getPoint() == previous.getBlock()->end()) {
429         llvm::outs() << ", was last in block\n";
430       } else {
431         llvm::outs() << ", previous = " << previous.getPoint()->getName()
432                      << "\n";
433       }
434     }
435   }
436   void notifyBlockErased(Block *block) override {
437     llvm::outs() << "notifyBlockErased\n";
438   }
439   void notifyOperationErased(Operation *op) override {
440     llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
441   }
442   void notifyOperationModified(Operation *op) override {
443     llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
444   }
445   void notifyOperationReplaced(Operation *op, ValueRange values) override {
446     llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
447   }
448 };
449 
450 struct TestStrictPatternDriver
451     : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
452 public:
453   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
454 
455   TestStrictPatternDriver() = default;
456   TestStrictPatternDriver(const TestStrictPatternDriver &other)
457       : PassWrapper(other) {
458     strictMode = other.strictMode;
459   }
460 
461   StringRef getArgument() const final { return "test-strict-pattern-driver"; }
462   StringRef getDescription() const final {
463     return "Test strict mode of pattern driver";
464   }
465 
466   void runOnOperation() override {
467     MLIRContext *ctx = &getContext();
468     mlir::RewritePatternSet patterns(ctx);
469     patterns.add<
470         // clang-format off
471         ChangeBlockOp,
472         CloneOp,
473         CloneRegionBeforeOp,
474         EraseOp,
475         ImplicitChangeOp,
476         InlineBlocksIntoParent,
477         InsertSameOp,
478         MoveBeforeParentOp,
479         ReplaceWithNewOp,
480         SplitBlockHere
481         // clang-format on
482         >(ctx);
483     SmallVector<Operation *> ops;
484     getOperation()->walk([&](Operation *op) {
485       StringRef opName = op->getName().getStringRef();
486       if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
487           opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
488           opName == "test.move_before_parent_op" ||
489           opName == "test.inline_blocks_into_parent" ||
490           opName == "test.split_block_here" || opName == "test.clone_me" ||
491           opName == "test.clone_region_before") {
492         ops.push_back(op);
493       }
494     });
495 
496     DumpNotifications dumpNotifications;
497     GreedyRewriteConfig config;
498     config.listener = &dumpNotifications;
499     if (strictMode == "AnyOp") {
500       config.strictMode = GreedyRewriteStrictness::AnyOp;
501     } else if (strictMode == "ExistingAndNewOps") {
502       config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
503     } else if (strictMode == "ExistingOps") {
504       config.strictMode = GreedyRewriteStrictness::ExistingOps;
505     } else {
506       llvm_unreachable("invalid strictness option");
507     }
508 
509     // Check if these transformations introduce visiting of operations that
510     // are not in the `ops` set (The new created ops are valid). An invalid
511     // operation will trigger the assertion while processing.
512     bool changed = false;
513     bool allErased = false;
514     (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config,
515                                  &changed, &allErased);
516     Builder b(ctx);
517     getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
518     getOperation()->setAttr("pattern_driver_all_erased",
519                             b.getBoolAttr(allErased));
520   }
521 
522   Option<std::string> strictMode{
523       *this, "strictness",
524       llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"),
525       llvm::cl::init("AnyOp")};
526 
527 private:
528   // New inserted operation is valid for further transformation.
529   class InsertSameOp : public RewritePattern {
530   public:
531     InsertSameOp(MLIRContext *context)
532         : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
533 
534     LogicalResult matchAndRewrite(Operation *op,
535                                   PatternRewriter &rewriter) const override {
536       if (op->hasAttr("skip"))
537         return failure();
538 
539       Operation *newOp =
540           rewriter.create(op->getLoc(), op->getName().getIdentifier(),
541                           op->getOperands(), op->getResultTypes());
542       rewriter.modifyOpInPlace(
543           op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); });
544       newOp->setAttr("skip", rewriter.getBoolAttr(true));
545 
546       return success();
547     }
548   };
549 
550   // Remove an operation may introduce the re-visiting of its operands.
551   class EraseOp : public RewritePattern {
552   public:
553     EraseOp(MLIRContext *context)
554         : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
555     LogicalResult matchAndRewrite(Operation *op,
556                                   PatternRewriter &rewriter) const override {
557       rewriter.eraseOp(op);
558       return success();
559     }
560   };
561 
562   // The following two patterns test RewriterBase::replaceAllUsesWith.
563   //
564   // That function replaces all usages of a Block (or a Value) with another one
565   // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
566   // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
567   // worklist: when an op is modified, it is added to the worklist. The two
568   // patterns below make the tracking observable: ChangeBlockOp replaces all
569   // usages of a block and that pattern is applied because the corresponding ops
570   // are put on the initial worklist (see above). ImplicitChangeOp does an
571   // unrelated change but ops of the corresponding type are *not* on the initial
572   // worklist, so the effect of the second pattern is only visible if the
573   // tracking and subsequent adding to the worklist actually works.
574 
575   // Replace all usages of the first successor with the second successor.
576   class ChangeBlockOp : public RewritePattern {
577   public:
578     ChangeBlockOp(MLIRContext *context)
579         : RewritePattern("test.change_block_op", /*benefit=*/1, context) {}
580     LogicalResult matchAndRewrite(Operation *op,
581                                   PatternRewriter &rewriter) const override {
582       if (op->getNumSuccessors() < 2)
583         return failure();
584       Block *firstSuccessor = op->getSuccessor(0);
585       Block *secondSuccessor = op->getSuccessor(1);
586       if (firstSuccessor == secondSuccessor)
587         return failure();
588       // This is the function being tested:
589       rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor);
590       // Using the following line instead would make the test fail:
591       // firstSuccessor->replaceAllUsesWith(secondSuccessor);
592       return success();
593     }
594   };
595 
596   // Changes the successor to the parent block.
597   class ImplicitChangeOp : public RewritePattern {
598   public:
599     ImplicitChangeOp(MLIRContext *context)
600         : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {}
601     LogicalResult matchAndRewrite(Operation *op,
602                                   PatternRewriter &rewriter) const override {
603       if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock())
604         return failure();
605       rewriter.modifyOpInPlace(op,
606                                [&]() { op->setSuccessor(op->getBlock(), 0); });
607       return success();
608     }
609   };
610 };
611 
612 struct TestWalkPatternDriver final
613     : PassWrapper<TestWalkPatternDriver, OperationPass<>> {
614   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver)
615 
616   TestWalkPatternDriver() = default;
617   TestWalkPatternDriver(const TestWalkPatternDriver &other)
618       : PassWrapper(other) {}
619 
620   StringRef getArgument() const override {
621     return "test-walk-pattern-rewrite-driver";
622   }
623   StringRef getDescription() const override {
624     return "Run test walk pattern rewrite driver";
625   }
626   void runOnOperation() override {
627     mlir::RewritePatternSet patterns(&getContext());
628 
629     // Patterns for testing the WalkPatternRewriteDriver.
630     patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
631                  MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>(
632         &getContext());
633 
634     DumpNotifications dumpListener;
635     walkAndApplyPatterns(getOperation(), std::move(patterns),
636                          dumpNotifications ? &dumpListener : nullptr);
637   }
638 
639   Option<bool> dumpNotifications{
640       *this, "dump-notifications",
641       llvm::cl::desc("Print rewrite listener notifications"),
642       llvm::cl::init(false)};
643 };
644 
645 } // namespace
646 
647 //===----------------------------------------------------------------------===//
648 // ReturnType Driver.
649 //===----------------------------------------------------------------------===//
650 
651 namespace {
652 // Generate ops for each instance where the type can be successfully inferred.
653 template <typename OpTy>
654 static void invokeCreateWithInferredReturnType(Operation *op) {
655   auto *context = op->getContext();
656   auto fop = op->getParentOfType<func::FuncOp>();
657   auto location = UnknownLoc::get(context);
658   OpBuilder b(op);
659   b.setInsertionPointAfter(op);
660 
661   // Use permutations of 2 args as operands.
662   assert(fop.getNumArguments() >= 2);
663   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
664     for (int j = 0; j < e; ++j) {
665       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
666       SmallVector<Type, 2> inferredReturnTypes;
667       if (succeeded(OpTy::inferReturnTypes(
668               context, std::nullopt, values, op->getDiscardableAttrDictionary(),
669               op->getPropertiesStorage(), op->getRegions(),
670               inferredReturnTypes))) {
671         OperationState state(location, OpTy::getOperationName());
672         // TODO: Expand to regions.
673         OpTy::build(b, state, values, op->getAttrs());
674         (void)b.create(state);
675       }
676     }
677   }
678 }
679 
680 static void reifyReturnShape(Operation *op) {
681   OpBuilder b(op);
682 
683   // Use permutations of 2 args as operands.
684   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
685   SmallVector<Value, 2> shapes;
686   if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
687       !llvm::hasSingleElement(shapes))
688     return;
689   for (const auto &it : llvm::enumerate(shapes)) {
690     op->emitRemark() << "value " << it.index() << ": "
691                      << it.value().getDefiningOp();
692   }
693 }
694 
695 struct TestReturnTypeDriver
696     : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> {
697   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver)
698 
699   void getDependentDialects(DialectRegistry &registry) const override {
700     registry.insert<tensor::TensorDialect>();
701   }
702   StringRef getArgument() const final { return "test-return-type"; }
703   StringRef getDescription() const final { return "Run return type functions"; }
704 
705   void runOnOperation() override {
706     if (getOperation().getName() == "testCreateFunctions") {
707       std::vector<Operation *> ops;
708       // Collect ops to avoid triggering on inserted ops.
709       for (auto &op : getOperation().getBody().front())
710         ops.push_back(&op);
711       // Generate test patterns for each, but skip terminator.
712       for (auto *op : llvm::ArrayRef(ops).drop_back()) {
713         // Test create method of each of the Op classes below. The resultant
714         // output would be in reverse order underneath `op` from which
715         // the attributes and regions are used.
716         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
717         invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>(
718             op);
719         invokeCreateWithInferredReturnType<
720             OpWithShapedTypeInferTypeInterfaceOp>(op);
721       };
722       return;
723     }
724     if (getOperation().getName() == "testReifyFunctions") {
725       std::vector<Operation *> ops;
726       // Collect ops to avoid triggering on inserted ops.
727       for (auto &op : getOperation().getBody().front())
728         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
729           ops.push_back(&op);
730       // Generate test patterns for each, but skip terminator.
731       for (auto *op : ops)
732         reifyReturnShape(op);
733     }
734   }
735 };
736 } // namespace
737 
738 namespace {
739 struct TestDerivedAttributeDriver
740     : public PassWrapper<TestDerivedAttributeDriver,
741                          OperationPass<func::FuncOp>> {
742   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver)
743 
744   StringRef getArgument() const final { return "test-derived-attr"; }
745   StringRef getDescription() const final {
746     return "Run test derived attributes";
747   }
748   void runOnOperation() override;
749 };
750 } // namespace
751 
752 void TestDerivedAttributeDriver::runOnOperation() {
753   getOperation().walk([](DerivedAttributeOpInterface dOp) {
754     auto dAttr = dOp.materializeDerivedAttributes();
755     if (!dAttr)
756       return;
757     for (auto d : dAttr)
758       dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
759   });
760 }
761 
762 //===----------------------------------------------------------------------===//
763 // Legalization Driver.
764 //===----------------------------------------------------------------------===//
765 
766 namespace {
767 //===----------------------------------------------------------------------===//
768 // Region-Block Rewrite Testing
769 
770 /// This pattern applies a signature conversion to a block inside a detached
771 /// region.
772 struct TestDetachedSignatureConversion : public ConversionPattern {
773   TestDetachedSignatureConversion(MLIRContext *ctx)
774       : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1,
775                           ctx) {}
776 
777   LogicalResult
778   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
779                   ConversionPatternRewriter &rewriter) const final {
780     if (op->getNumRegions() != 1)
781       return failure();
782     OperationState state(op->getLoc(), "test.legal_op_with_region", operands,
783                          op->getResultTypes(), {}, BlockRange());
784     Region *newRegion = state.addRegion();
785     rewriter.inlineRegionBefore(op->getRegion(0), *newRegion,
786                                 newRegion->begin());
787     TypeConverter::SignatureConversion result(newRegion->getNumArguments());
788     for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i)
789       result.addInputs(i, rewriter.getF64Type());
790     rewriter.applySignatureConversion(&newRegion->front(), result);
791     Operation *newOp = rewriter.create(state);
792     rewriter.replaceOp(op, newOp->getResults());
793     return success();
794   }
795 };
796 
797 /// This pattern is a simple pattern that inlines the first region of a given
798 /// operation into the parent region.
799 struct TestRegionRewriteBlockMovement : public ConversionPattern {
800   TestRegionRewriteBlockMovement(MLIRContext *ctx)
801       : ConversionPattern("test.region", 1, ctx) {}
802 
803   LogicalResult
804   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
805                   ConversionPatternRewriter &rewriter) const final {
806     // Inline this region into the parent region.
807     auto &parentRegion = *op->getParentRegion();
808     auto &opRegion = op->getRegion(0);
809     if (op->getDiscardableAttr("legalizer.should_clone"))
810       rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
811     else
812       rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
813 
814     if (op->getDiscardableAttr("legalizer.erase_old_blocks")) {
815       while (!opRegion.empty())
816         rewriter.eraseBlock(&opRegion.front());
817     }
818 
819     // Drop this operation.
820     rewriter.eraseOp(op);
821     return success();
822   }
823 };
824 /// This pattern is a simple pattern that generates a region containing an
825 /// illegal operation.
826 struct TestRegionRewriteUndo : public RewritePattern {
827   TestRegionRewriteUndo(MLIRContext *ctx)
828       : RewritePattern("test.region_builder", 1, ctx) {}
829 
830   LogicalResult matchAndRewrite(Operation *op,
831                                 PatternRewriter &rewriter) const final {
832     // Create the region operation with an entry block containing arguments.
833     OperationState newRegion(op->getLoc(), "test.region");
834     newRegion.addRegion();
835     auto *regionOp = rewriter.create(newRegion);
836     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
837     entryBlock->addArgument(rewriter.getIntegerType(64),
838                             rewriter.getUnknownLoc());
839 
840     // Add an explicitly illegal operation to ensure the conversion fails.
841     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
842     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
843 
844     // Drop this operation.
845     rewriter.eraseOp(op);
846     return success();
847   }
848 };
849 /// A simple pattern that creates a block at the end of the parent region of the
850 /// matched operation.
851 struct TestCreateBlock : public RewritePattern {
852   TestCreateBlock(MLIRContext *ctx)
853       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
854 
855   LogicalResult matchAndRewrite(Operation *op,
856                                 PatternRewriter &rewriter) const final {
857     Region &region = *op->getParentRegion();
858     Type i32Type = rewriter.getIntegerType(32);
859     Location loc = op->getLoc();
860     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
861     rewriter.create<TerminatorOp>(loc);
862     rewriter.eraseOp(op);
863     return success();
864   }
865 };
866 
867 /// A simple pattern that creates a block containing an invalid operation in
868 /// order to trigger the block creation undo mechanism.
869 struct TestCreateIllegalBlock : public RewritePattern {
870   TestCreateIllegalBlock(MLIRContext *ctx)
871       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
872 
873   LogicalResult matchAndRewrite(Operation *op,
874                                 PatternRewriter &rewriter) const final {
875     Region &region = *op->getParentRegion();
876     Type i32Type = rewriter.getIntegerType(32);
877     Location loc = op->getLoc();
878     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
879     // Create an illegal op to ensure the conversion fails.
880     rewriter.create<ILLegalOpF>(loc, i32Type);
881     rewriter.create<TerminatorOp>(loc);
882     rewriter.eraseOp(op);
883     return success();
884   }
885 };
886 
887 /// A simple pattern that tests the undo mechanism when replacing the uses of a
888 /// block argument.
889 struct TestUndoBlockArgReplace : public ConversionPattern {
890   TestUndoBlockArgReplace(MLIRContext *ctx)
891       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
892 
893   LogicalResult
894   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
895                   ConversionPatternRewriter &rewriter) const final {
896     auto illegalOp =
897         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
898     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
899                                         illegalOp->getResult(0));
900     rewriter.modifyOpInPlace(op, [] {});
901     return success();
902   }
903 };
904 
905 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
906 /// This is to test the rollback logic.
907 struct TestUndoMoveOpBefore : public ConversionPattern {
908   TestUndoMoveOpBefore(MLIRContext *ctx)
909       : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
910 
911   LogicalResult
912   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
913                   ConversionPatternRewriter &rewriter) const override {
914     rewriter.moveOpBefore(op, op->getParentOp());
915     // Replace with an illegal op to ensure the conversion fails.
916     rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
917     return success();
918   }
919 };
920 
921 /// A rewrite pattern that tests the undo mechanism when erasing a block.
922 struct TestUndoBlockErase : public ConversionPattern {
923   TestUndoBlockErase(MLIRContext *ctx)
924       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
925 
926   LogicalResult
927   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
928                   ConversionPatternRewriter &rewriter) const final {
929     Block *secondBlock = &*std::next(op->getRegion(0).begin());
930     rewriter.setInsertionPointToStart(secondBlock);
931     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
932     rewriter.eraseBlock(secondBlock);
933     rewriter.modifyOpInPlace(op, [] {});
934     return success();
935   }
936 };
937 
938 /// A pattern that modifies a property in-place, but keeps the op illegal.
939 struct TestUndoPropertiesModification : public ConversionPattern {
940   TestUndoPropertiesModification(MLIRContext *ctx)
941       : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {}
942   LogicalResult
943   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
944                   ConversionPatternRewriter &rewriter) const final {
945     if (!op->hasAttr("modify_inplace"))
946       return failure();
947     rewriter.modifyOpInPlace(
948         op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); });
949     return success();
950   }
951 };
952 
953 //===----------------------------------------------------------------------===//
954 // Type-Conversion Rewrite Testing
955 
956 /// This patterns erases a region operation that has had a type conversion.
957 struct TestDropOpSignatureConversion : public ConversionPattern {
958   TestDropOpSignatureConversion(MLIRContext *ctx,
959                                 const TypeConverter &converter)
960       : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
961   LogicalResult
962   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
963                   ConversionPatternRewriter &rewriter) const override {
964     Region &region = op->getRegion(0);
965     Block *entry = &region.front();
966 
967     // Convert the original entry arguments.
968     const TypeConverter &converter = *getTypeConverter();
969     TypeConverter::SignatureConversion result(entry->getNumArguments());
970     if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
971                                               result)) ||
972         failed(rewriter.convertRegionTypes(&region, converter, &result)))
973       return failure();
974 
975     // Convert the region signature and just drop the operation.
976     rewriter.eraseOp(op);
977     return success();
978   }
979 };
980 /// This pattern simply updates the operands of the given operation.
981 struct TestPassthroughInvalidOp : public ConversionPattern {
982   TestPassthroughInvalidOp(MLIRContext *ctx)
983       : ConversionPattern("test.invalid", 1, ctx) {}
984   LogicalResult
985   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
986                   ConversionPatternRewriter &rewriter) const final {
987     SmallVector<Value> flattened;
988     for (auto it : llvm::enumerate(operands)) {
989       ValueRange range = it.value();
990       if (range.size() == 1) {
991         flattened.push_back(range.front());
992         continue;
993       }
994 
995       // This is a 1:N replacement. Insert a test.cast op. (That's what the
996       // argument materialization used to do.)
997       flattened.push_back(
998           rewriter
999               .create<TestCastOp>(op->getLoc(),
1000                                   op->getOperand(it.index()).getType(), range)
1001               .getResult());
1002     }
1003     rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened,
1004                                              std::nullopt);
1005     return success();
1006   }
1007 };
1008 /// Replace with valid op, but simply drop the operands. This is used in a
1009 /// regression where we used to generate circular unrealized_conversion_cast
1010 /// ops.
1011 struct TestDropAndReplaceInvalidOp : public ConversionPattern {
1012   TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
1013       : ConversionPattern(converter,
1014                           "test.drop_operands_and_replace_with_valid", 1, ctx) {
1015   }
1016   LogicalResult
1017   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1018                   ConversionPatternRewriter &rewriter) const final {
1019     rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(),
1020                                              std::nullopt);
1021     return success();
1022   }
1023 };
1024 /// This pattern handles the case of a split return value.
1025 struct TestSplitReturnType : public ConversionPattern {
1026   TestSplitReturnType(MLIRContext *ctx)
1027       : ConversionPattern("test.return", 1, ctx) {}
1028   LogicalResult
1029   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1030                   ConversionPatternRewriter &rewriter) const final {
1031     // Check for a return of F32.
1032     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
1033       return failure();
1034     rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]);
1035     return success();
1036   }
1037 };
1038 
1039 //===----------------------------------------------------------------------===//
1040 // Multi-Level Type-Conversion Rewrite Testing
1041 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
1042   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
1043       : ConversionPattern("test.type_producer", 1, ctx) {}
1044   LogicalResult
1045   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1046                   ConversionPatternRewriter &rewriter) const final {
1047     // If the type is I32, change the type to F32.
1048     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
1049       return failure();
1050     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
1051     return success();
1052   }
1053 };
1054 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
1055   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
1056       : ConversionPattern("test.type_producer", 1, ctx) {}
1057   LogicalResult
1058   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1059                   ConversionPatternRewriter &rewriter) const final {
1060     // If the type is F32, change the type to F64.
1061     if (!Type(*op->result_type_begin()).isF32())
1062       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
1063     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
1064     return success();
1065   }
1066 };
1067 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
1068   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
1069       : ConversionPattern("test.type_producer", 10, ctx) {}
1070   LogicalResult
1071   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1072                   ConversionPatternRewriter &rewriter) const final {
1073     // Always convert to B16, even though it is not a legal type. This tests
1074     // that values are unmapped correctly.
1075     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
1076     return success();
1077   }
1078 };
1079 struct TestUpdateConsumerType : public ConversionPattern {
1080   TestUpdateConsumerType(MLIRContext *ctx)
1081       : ConversionPattern("test.type_consumer", 1, ctx) {}
1082   LogicalResult
1083   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1084                   ConversionPatternRewriter &rewriter) const final {
1085     // Verify that the incoming operand has been successfully remapped to F64.
1086     if (!operands[0].getType().isF64())
1087       return failure();
1088     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
1089     return success();
1090   }
1091 };
1092 
1093 //===----------------------------------------------------------------------===//
1094 // Non-Root Replacement Rewrite Testing
1095 /// This pattern generates an invalid operation, but replaces it before the
1096 /// pattern is finished. This checks that we don't need to legalize the
1097 /// temporary op.
1098 struct TestNonRootReplacement : public RewritePattern {
1099   TestNonRootReplacement(MLIRContext *ctx)
1100       : RewritePattern("test.replace_non_root", 1, ctx) {}
1101 
1102   LogicalResult matchAndRewrite(Operation *op,
1103                                 PatternRewriter &rewriter) const final {
1104     auto resultType = *op->result_type_begin();
1105     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
1106     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
1107 
1108     rewriter.replaceOp(illegalOp, legalOp);
1109     rewriter.replaceOp(op, illegalOp);
1110     return success();
1111   }
1112 };
1113 
1114 //===----------------------------------------------------------------------===//
1115 // Recursive Rewrite Testing
1116 /// This pattern is applied to the same operation multiple times, but has a
1117 /// bounded recursion.
1118 struct TestBoundedRecursiveRewrite
1119     : public OpRewritePattern<TestRecursiveRewriteOp> {
1120   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
1121 
1122   void initialize() {
1123     // The conversion target handles bounding the recursion of this pattern.
1124     setHasBoundedRewriteRecursion();
1125   }
1126 
1127   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
1128                                 PatternRewriter &rewriter) const final {
1129     // Decrement the depth of the op in-place.
1130     rewriter.modifyOpInPlace(op, [&] {
1131       op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
1132     });
1133     return success();
1134   }
1135 };
1136 
1137 struct TestNestedOpCreationUndoRewrite
1138     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
1139   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
1140 
1141   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
1142                                 PatternRewriter &rewriter) const final {
1143     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1144     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1145     return success();
1146   };
1147 };
1148 
1149 // This pattern matches `test.blackhole` and delete this op and its producer.
1150 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
1151   using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
1152 
1153   LogicalResult matchAndRewrite(BlackHoleOp op,
1154                                 PatternRewriter &rewriter) const final {
1155     Operation *producer = op.getOperand().getDefiningOp();
1156     // Always erase the user before the producer, the framework should handle
1157     // this correctly.
1158     rewriter.eraseOp(op);
1159     rewriter.eraseOp(producer);
1160     return success();
1161   };
1162 };
1163 
1164 // This pattern replaces explicitly illegal op with explicitly legal op,
1165 // but in addition creates unregistered operation.
1166 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
1167   using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
1168 
1169   LogicalResult matchAndRewrite(ILLegalOpG op,
1170                                 PatternRewriter &rewriter) const final {
1171     IntegerAttr attr = rewriter.getI32IntegerAttr(0);
1172     Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
1173     rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
1174     return success();
1175   };
1176 };
1177 
1178 class TestEraseOp : public ConversionPattern {
1179 public:
1180   TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {}
1181   LogicalResult
1182   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1183                   ConversionPatternRewriter &rewriter) const final {
1184     // Erase op without replacements.
1185     rewriter.eraseOp(op);
1186     return success();
1187   }
1188 };
1189 
1190 /// This pattern matches a test.duplicate_block_args op and duplicates all
1191 /// block arguments.
1192 class TestDuplicateBlockArgs
1193     : public OpConversionPattern<DuplicateBlockArgsOp> {
1194   using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern;
1195 
1196   LogicalResult
1197   matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor,
1198                   ConversionPatternRewriter &rewriter) const override {
1199     if (op.getIsLegal())
1200       return failure();
1201     rewriter.startOpModification(op);
1202     Block *body = &op.getBody().front();
1203     TypeConverter::SignatureConversion result(body->getNumArguments());
1204     for (auto it : llvm::enumerate(body->getArgumentTypes()))
1205       result.addInputs(it.index(), {it.value(), it.value()});
1206     rewriter.applySignatureConversion(body, result, getTypeConverter());
1207     op.setIsLegal(true);
1208     rewriter.finalizeOpModification(op);
1209     return success();
1210   }
1211 };
1212 
1213 /// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid
1214 /// op. The pattern supports 1:N replacements and forwards the replacement
1215 /// values of the single operand as test.valid operands.
1216 class TestRepetitive1ToNConsumer : public ConversionPattern {
1217 public:
1218   TestRepetitive1ToNConsumer(MLIRContext *ctx)
1219       : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {}
1220   LogicalResult
1221   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1222                   ConversionPatternRewriter &rewriter) const final {
1223     // A single operand is expected.
1224     if (op->getNumOperands() != 1)
1225       return failure();
1226     rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front());
1227     return success();
1228   }
1229 };
1230 
1231 } // namespace
1232 
1233 namespace {
1234 struct TestTypeConverter : public TypeConverter {
1235   using TypeConverter::TypeConverter;
1236   TestTypeConverter() {
1237     addConversion(convertType);
1238     addArgumentMaterialization(materializeCast);
1239     addSourceMaterialization(materializeCast);
1240   }
1241 
1242   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
1243     // Drop I16 types.
1244     if (t.isSignlessInteger(16))
1245       return success();
1246 
1247     // Convert I64 to F64.
1248     if (t.isSignlessInteger(64)) {
1249       results.push_back(FloatType::getF64(t.getContext()));
1250       return success();
1251     }
1252 
1253     // Convert I42 to I43.
1254     if (t.isInteger(42)) {
1255       results.push_back(IntegerType::get(t.getContext(), 43));
1256       return success();
1257     }
1258 
1259     // Split F32 into F16,F16.
1260     if (t.isF32()) {
1261       results.assign(2, FloatType::getF16(t.getContext()));
1262       return success();
1263     }
1264 
1265     // Drop I24 types.
1266     if (t.isInteger(24)) {
1267       return success();
1268     }
1269 
1270     // Otherwise, convert the type directly.
1271     results.push_back(t);
1272     return success();
1273   }
1274 
1275   /// Hook for materializing a conversion. This is necessary because we generate
1276   /// 1->N type mappings.
1277   static Value materializeCast(OpBuilder &builder, Type resultType,
1278                                ValueRange inputs, Location loc) {
1279     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1280   }
1281 };
1282 
/// Test pass (`test-legalize-patterns`) that runs the test dialect's
/// legalization patterns with a TestTypeConverter against a test-specific
/// ConversionTarget. It uses the dialect conversion mode given at
/// construction (analysis, full, or partial).
struct TestLegalizePatternDriver
    : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)

  StringRef getArgument() const final { return "test-legalize-patterns"; }
  StringRef getDescription() const final {
    return "Run test dialect legalization patterns";
  }
  /// The mode of conversion to use with the driver.
  enum class ConversionMode { Analysis, Full, Partial };

  TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<func::FuncDialect, test::TestDialect>();
  }

  void runOnOperation() override {
    TestTypeConverter converter;
    mlir::RewritePatternSet patterns(&getContext());
    // Patterns generated from the test dialect's declarative rewrite rules.
    populateWithGenerated(patterns);
    // Patterns that need only the context.
    patterns.add<
        TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
        TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
        TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
        TestSplitReturnType, TestChangeProducerTypeI32ToF32,
        TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
        TestUpdateConsumerType, TestNonRootReplacement,
        TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
        TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
        TestUndoPropertiesModification, TestEraseOp,
        TestRepetitive1ToNConsumer>(&getContext());
    // Patterns that also take the type converter. Note the differing
    // constructor argument order between these two `add` calls.
    patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
        &getContext(), converter);
    patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
    // Upstream patterns that convert function signatures and call sites.
    mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                              converter);
    mlir::populateCallOpTypeConversionPattern(patterns, converter);

    // Define the conversion target used for the test.
    ConversionTarget target(getContext());
    target.addLegalOp<ModuleOp>();
    target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
                      TerminatorOp, OneRegionOp>();
    target.addLegalOp(
        OperationName("test.legal_op_with_region", &getContext()));
    target
        .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
    target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
      // Don't allow F32 operands.
      return llvm::none_of(op.getOperandTypes(),
                           [](Type type) { return type.isF32(); });
    });
    // Functions and calls are legal once their types are legal under the
    // test type converter.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType()) &&
             converter.isLegal(&op.getBody());
    });
    target.addDynamicallyLegalOp<func::CallOp>(
        [&](func::CallOp op) { return converter.isLegal(op); });

    // TestCreateUnregisteredOp creates `arith.constant` operation,
    // which was not added to target intentionally to test
    // correct error code from conversion driver.
    target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });

    // Expect the type_producer/type_consumer operations to only operate on f64.
    target.addDynamicallyLegalOp<TestTypeProducerOp>(
        [](TestTypeProducerOp op) { return op.getType().isF64(); });
    target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
      return op.getOperand().getType().isF64();
    });

    // Check support for marking certain operations as recursively legal.
    target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) {
      return static_cast<bool>(
          op->getAttrOfType<UnitAttr>("test.recursively_legal"));
    });

    // Mark the bound recursion operation as dynamically legal.
    target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
        [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });

    // Create a dynamically legal rule that can only be legalized by folding it.
    target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
        [](TestOpInPlaceSelfFold op) { return op.getFolded(); });

    // Legal once TestDuplicateBlockArgs has set the `is_legal` flag.
    target.addDynamicallyLegalOp<DuplicateBlockArgsOp>(
        [](DuplicateBlockArgsOp op) { return op.getIsLegal(); });

    // Handle a partial conversion.
    if (mode == ConversionMode::Partial) {
      DenseSet<Operation *> unlegalizedOps;
      ConversionConfig config;
      // DumpNotifications is a listener defined elsewhere in this file;
      // presumably it prints rewrite notifications (confirm at its definition).
      DumpNotifications dumpNotifications;
      config.listener = &dumpNotifications;
      config.unlegalizedOps = &unlegalizedOps;
      // A failed conversion is reported as a remark, not as a pass failure.
      if (failed(applyPartialConversion(getOperation(), target,
                                        std::move(patterns), config))) {
        getOperation()->emitRemark() << "applyPartialConversion failed";
      }
      // Emit a remark for each operation that could NOT be legalized.
      for (auto *op : unlegalizedOps)
        op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
      return;
    }

    // Handle a full conversion.
    if (mode == ConversionMode::Full) {
      // Check support for marking unknown operations as dynamically legal.
      target.markUnknownOpDynamicallyLegal([](Operation *op) {
        return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
      });

      ConversionConfig config;
      DumpNotifications dumpNotifications;
      config.listener = &dumpNotifications;
      // As in partial mode, failure is only reported as a remark.
      if (failed(applyFullConversion(getOperation(), target,
                                     std::move(patterns), config))) {
        getOperation()->emitRemark() << "applyFullConversion failed";
      }
      return;
    }

    // Otherwise, handle an analysis conversion.
    assert(mode == ConversionMode::Analysis);

    // Analyze the convertible operations.
    DenseSet<Operation *> legalizedOps;
    ConversionConfig config;
    config.legalizableOps = &legalizedOps;
    // Unlike the other modes, an analysis failure fails the pass.
    if (failed(applyAnalysisConversion(getOperation(), target,
                                       std::move(patterns), config)))
      return signalPassFailure();

    // Emit remarks for each legalizable operation.
    for (auto *op : legalizedOps)
      op->emitRemark() << "op '" << op->getName() << "' is legalizable";
  }

  /// The mode of conversion to use.
  ConversionMode mode;
};
1425 } // namespace
1426 
// Command-line option `-test-legalize-mode=<analysis|full|partial>` choosing
// the TestLegalizePatternDriver conversion mode; defaults to `partial`.
// Presumably read where the pass is registered/constructed elsewhere in this
// file (confirm at the registration site).
static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
    legalizerConversionMode(
        "test-legalize-mode",
        llvm::cl::desc("The legalization mode to use with the test driver"),
        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
        llvm::cl::values(
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
                       "analysis", "Perform an analysis conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
                       "Perform a full conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
                       "partial", "Perform a partial conversion")));
1439 
1440 //===----------------------------------------------------------------------===//
1441 // ConversionPatternRewriter::getRemappedValue testing. This method is used
1442 // to get the remapped value of an original value that was replaced using
1443 // ConversionPatternRewriter.
1444 namespace {
1445 struct TestRemapValueTypeConverter : public TypeConverter {
1446   using TypeConverter::TypeConverter;
1447 
1448   TestRemapValueTypeConverter() {
1449     addConversion(
1450         [](Float32Type type) { return Float64Type::get(type.getContext()); });
1451     addConversion([](Type type) { return type; });
1452   }
1453 };
1454 
1455 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
1456 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
1457 /// operand twice.
1458 ///
1459 /// Example:
1460 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
1461 /// is replaced with:
1462 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
1463 struct OneVResOneVOperandOp1Converter
1464     : public OpConversionPattern<OneVResOneVOperandOp1> {
1465   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
1466 
1467   LogicalResult
1468   matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
1469                   ConversionPatternRewriter &rewriter) const override {
1470     auto origOps = op.getOperands();
1471     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
1472            "One operand expected");
1473     Value origOp = *origOps.begin();
1474     SmallVector<Value, 2> remappedOperands;
1475     // Replicate the remapped original operand twice. Note that we don't used
1476     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
1477     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1478     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1479 
1480     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
1481                                                        remappedOperands);
1482     return success();
1483   }
1484 };
1485 
/// A conversion pattern that inlines the body block of a
/// TestRemappedValueRegionOp in place of the op. It then replaces the op's
/// results with the remapped values of the inlined block's terminator operands,
/// which exercises `ConversionPatternRewriter::getRemappedValues`.
struct TestRemapValueInRegion
    : public OpConversionPattern<TestRemappedValueRegionOp> {
  using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Only the first block of the body region is used.
    Block &block = op.getBody().front();
    Operation *terminator = block.getTerminator();

    // Merge the block into the parent region:
    //   1. Split the parent block right before `op`.
    //   2. Append the body block.
    //   3. Re-append the split-off tail, which starts with `op`.
    // The inlined ops end up directly before `op`.
    Block *parentBlock = op->getBlock();
    Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
    rewriter.mergeBlocks(&block, parentBlock, ValueRange());
    rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());

    // Replace the results of this operation with the remapped terminator
    // values.
    // NOTE(review): this failure path is reached after the IR has already been
    // modified above; confirm that relying on the driver's rollback is intended.
    SmallVector<Value> terminatorOperands;
    if (failed(rewriter.getRemappedValues(terminator->getOperands(),
                                          terminatorOperands)))
      return failure();

    // The terminator was inlined together with the block. Drop it, because its
    // (remapped) operands now directly replace the op's results.
    rewriter.eraseOp(terminator);
    rewriter.replaceOp(op, terminatorOperands);
    return success();
  }
};
1515 
1516 struct TestRemappedValue
1517     : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> {
1518   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue)
1519 
1520   StringRef getArgument() const final { return "test-remapped-value"; }
1521   StringRef getDescription() const final {
1522     return "Test public remapped value mechanism in ConversionPatternRewriter";
1523   }
1524   void runOnOperation() override {
1525     TestRemapValueTypeConverter typeConverter;
1526 
1527     mlir::RewritePatternSet patterns(&getContext());
1528     patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
1529     patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
1530         &getContext());
1531     patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
1532 
1533     mlir::ConversionTarget target(getContext());
1534     target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>();
1535 
1536     // Expect the type_producer/type_consumer operations to only operate on f64.
1537     target.addDynamicallyLegalOp<TestTypeProducerOp>(
1538         [](TestTypeProducerOp op) { return op.getType().isF64(); });
1539     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1540       return op.getOperand().getType().isF64();
1541     });
1542 
1543     // We make OneVResOneVOperandOp1 legal only when it has more that one
1544     // operand. This will trigger the conversion that will replace one-operand
1545     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
1546     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
1547         [](Operation *op) { return op->getNumOperands() > 1; });
1548 
1549     if (failed(mlir::applyFullConversion(getOperation(), target,
1550                                          std::move(patterns)))) {
1551       signalPassFailure();
1552     }
1553   }
1554 };
1555 } // namespace
1556 
1557 //===----------------------------------------------------------------------===//
1558 // Test patterns without a specific root operation kind
1559 //===----------------------------------------------------------------------===//
1560 
1561 namespace {
1562 /// This pattern matches and removes any operation in the test dialect.
1563 struct RemoveTestDialectOps : public RewritePattern {
1564   RemoveTestDialectOps(MLIRContext *context)
1565       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
1566 
1567   LogicalResult matchAndRewrite(Operation *op,
1568                                 PatternRewriter &rewriter) const override {
1569     if (!isa<TestDialect>(op->getDialect()))
1570       return failure();
1571     rewriter.eraseOp(op);
1572     return success();
1573   }
1574 };
1575 
1576 struct TestUnknownRootOpDriver
1577     : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> {
1578   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver)
1579 
1580   StringRef getArgument() const final {
1581     return "test-legalize-unknown-root-patterns";
1582   }
1583   StringRef getDescription() const final {
1584     return "Test public remapped value mechanism in ConversionPatternRewriter";
1585   }
1586   void runOnOperation() override {
1587     mlir::RewritePatternSet patterns(&getContext());
1588     patterns.add<RemoveTestDialectOps>(&getContext());
1589 
1590     mlir::ConversionTarget target(getContext());
1591     target.addIllegalDialect<TestDialect>();
1592     if (failed(applyPartialConversion(getOperation(), target,
1593                                       std::move(patterns))))
1594       signalPassFailure();
1595   }
1596 };
1597 } // namespace
1598 
1599 //===----------------------------------------------------------------------===//
// Test patterns that use operations and types defined at runtime
1601 //===----------------------------------------------------------------------===//
1602 
1603 namespace {
1604 /// This pattern matches dynamic operations 'test.one_operand_two_results' and
1605 /// replace them with dynamic operations 'test.generic_dynamic_op'.
1606 struct RewriteDynamicOp : public RewritePattern {
1607   RewriteDynamicOp(MLIRContext *context)
1608       : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
1609                        context) {}
1610 
1611   LogicalResult matchAndRewrite(Operation *op,
1612                                 PatternRewriter &rewriter) const override {
1613     assert(op->getName().getStringRef() ==
1614                "test.dynamic_one_operand_two_results" &&
1615            "rewrite pattern should only match operations with the right name");
1616 
1617     OperationState state(op->getLoc(), "test.dynamic_generic",
1618                          op->getOperands(), op->getResultTypes(),
1619                          op->getAttrs());
1620     auto *newOp = rewriter.create(state);
1621     rewriter.replaceOp(op, newOp->getResults());
1622     return success();
1623   }
1624 };
1625 
1626 struct TestRewriteDynamicOpDriver
1627     : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> {
1628   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
1629 
1630   void getDependentDialects(DialectRegistry &registry) const override {
1631     registry.insert<TestDialect>();
1632   }
1633   StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
1634   StringRef getDescription() const final {
1635     return "Test rewritting on dynamic operations";
1636   }
1637   void runOnOperation() override {
1638     RewritePatternSet patterns(&getContext());
1639     patterns.add<RewriteDynamicOp>(&getContext());
1640 
1641     ConversionTarget target(getContext());
1642     target.addIllegalOp(
1643         OperationName("test.dynamic_one_operand_two_results", &getContext()));
1644     target.addLegalOp(OperationName("test.dynamic_generic", &getContext()));
1645     if (failed(applyPartialConversion(getOperation(), target,
1646                                       std::move(patterns))))
1647       signalPassFailure();
1648   }
1649 };
1650 } // end anonymous namespace
1651 
1652 //===----------------------------------------------------------------------===//
1653 // Test type conversions
1654 //===----------------------------------------------------------------------===//
1655 
1656 namespace {
1657 struct TestTypeConversionProducer
1658     : public OpConversionPattern<TestTypeProducerOp> {
1659   using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
1660   LogicalResult
1661   matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
1662                   ConversionPatternRewriter &rewriter) const final {
1663     Type resultType = op.getType();
1664     Type convertedType = getTypeConverter()
1665                              ? getTypeConverter()->convertType(resultType)
1666                              : resultType;
1667     if (isa<FloatType>(resultType))
1668       resultType = rewriter.getF64Type();
1669     else if (resultType.isInteger(16))
1670       resultType = rewriter.getIntegerType(64);
1671     else if (isa<test::TestRecursiveType>(resultType) &&
1672              convertedType != resultType)
1673       resultType = convertedType;
1674     else
1675       return failure();
1676 
1677     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
1678     return success();
1679   }
1680 };
1681 
1682 /// Call signature conversion and then fail the rewrite to trigger the undo
1683 /// mechanism.
1684 struct TestSignatureConversionUndo
1685     : public OpConversionPattern<TestSignatureConversionUndoOp> {
1686   using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
1687 
1688   LogicalResult
1689   matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
1690                   ConversionPatternRewriter &rewriter) const final {
1691     (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
1692     return failure();
1693   }
1694 };
1695 
1696 /// Call signature conversion without providing a type converter to handle
1697 /// materializations.
1698 struct TestTestSignatureConversionNoConverter
1699     : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
1700   TestTestSignatureConversionNoConverter(const TypeConverter &converter,
1701                                          MLIRContext *context)
1702       : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
1703         converter(converter) {}
1704 
1705   LogicalResult
1706   matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
1707                   ConversionPatternRewriter &rewriter) const final {
1708     Region &region = op->getRegion(0);
1709     Block *entry = &region.front();
1710 
1711     // Convert the original entry arguments.
1712     TypeConverter::SignatureConversion result(entry->getNumArguments());
1713     if (failed(
1714             converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
1715       return failure();
1716     rewriter.modifyOpInPlace(op, [&] {
1717       rewriter.applySignatureConversion(&region.front(), result);
1718     });
1719     return success();
1720   }
1721 
1722   const TypeConverter &converter;
1723 };
1724 
1725 /// Just forward the operands to the root op. This is essentially a no-op
1726 /// pattern that is used to trigger target materialization.
1727 struct TestTypeConsumerForward
1728     : public OpConversionPattern<TestTypeConsumerOp> {
1729   using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
1730 
1731   LogicalResult
1732   matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
1733                   ConversionPatternRewriter &rewriter) const final {
1734     rewriter.modifyOpInPlace(op,
1735                              [&] { op->setOperands(adaptor.getOperands()); });
1736     return success();
1737   }
1738 };
1739 
1740 struct TestTypeConversionAnotherProducer
1741     : public OpRewritePattern<TestAnotherTypeProducerOp> {
1742   using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
1743 
1744   LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
1745                                 PatternRewriter &rewriter) const final {
1746     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
1747     return success();
1748   }
1749 };
1750 
1751 struct TestReplaceWithLegalOp : public ConversionPattern {
1752   TestReplaceWithLegalOp(MLIRContext *ctx)
1753       : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
1754   LogicalResult
1755   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1756                   ConversionPatternRewriter &rewriter) const final {
1757     rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
1758     return success();
1759   }
1760 };
1761 
1762 struct TestTypeConversionDriver
1763     : public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
1764   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
1765 
1766   void getDependentDialects(DialectRegistry &registry) const override {
1767     registry.insert<TestDialect>();
1768   }
1769   StringRef getArgument() const final {
1770     return "test-legalize-type-conversion";
1771   }
1772   StringRef getDescription() const final {
1773     return "Test various type conversion functionalities in DialectConversion";
1774   }
1775 
1776   void runOnOperation() override {
1777     // Initialize the type converter.
1778     SmallVector<Type, 2> conversionCallStack;
1779     TypeConverter converter;
1780 
1781     /// Add the legal set of type conversions.
1782     converter.addConversion([](Type type) -> Type {
1783       // Treat F64 as legal.
1784       if (type.isF64())
1785         return type;
1786       // Allow converting BF16/F16/F32 to F64.
1787       if (type.isBF16() || type.isF16() || type.isF32())
1788         return FloatType::getF64(type.getContext());
1789       // Otherwise, the type is illegal.
1790       return nullptr;
1791     });
1792     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
1793       // Drop all integer types.
1794       return success();
1795     });
1796     converter.addConversion(
1797         // Convert a recursive self-referring type into a non-self-referring
1798         // type named "outer_converted_type" that contains a SimpleAType.
1799         [&](test::TestRecursiveType type,
1800             SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
1801           // If the type is already converted, return it to indicate that it is
1802           // legal.
1803           if (type.getName() == "outer_converted_type") {
1804             results.push_back(type);
1805             return success();
1806           }
1807 
1808           conversionCallStack.push_back(type);
1809           auto popConversionCallStack = llvm::make_scope_exit(
1810               [&conversionCallStack]() { conversionCallStack.pop_back(); });
1811 
1812           // If the type is on the call stack more than once (it is there at
1813           // least once because of the _current_ call, which is always the last
1814           // element on the stack), we've hit the recursive case. Just return
1815           // SimpleAType here to create a non-recursive type as a result.
1816           if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(),
1817                                  type)) {
1818             results.push_back(test::SimpleAType::get(type.getContext()));
1819             return success();
1820           }
1821 
1822           // Convert the body recursively.
1823           auto result = test::TestRecursiveType::get(type.getContext(),
1824                                                      "outer_converted_type");
1825           if (failed(result.setBody(converter.convertType(type.getBody()))))
1826             return failure();
1827           results.push_back(result);
1828           return success();
1829         });
1830 
1831     /// Add the legal set of type materializations.
1832     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
1833                                           ValueRange inputs,
1834                                           Location loc) -> Value {
1835       // Allow casting from F64 back to F32.
1836       if (!resultType.isF16() && inputs.size() == 1 &&
1837           inputs[0].getType().isF64())
1838         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1839       // Allow producing an i32 or i64 from nothing.
1840       if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
1841           inputs.empty())
1842         return builder.create<TestTypeProducerOp>(loc, resultType);
1843       // Allow producing an i64 from an integer.
1844       if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
1845           isa<IntegerType>(inputs[0].getType()))
1846         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1847       // Otherwise, fail.
1848       return nullptr;
1849     });
1850 
1851     // Initialize the conversion target.
1852     mlir::ConversionTarget target(getContext());
1853     target.addLegalOp<LegalOpD>();
1854     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
1855       auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
1856       return op.getType().isF64() || op.getType().isInteger(64) ||
1857              (recursiveType &&
1858               recursiveType.getName() == "outer_converted_type");
1859     });
1860     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1861       return converter.isSignatureLegal(op.getFunctionType()) &&
1862              converter.isLegal(&op.getBody());
1863     });
1864     target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
1865       // Allow casts from F64 to F32.
1866       return (*op.operand_type_begin()).isF64() && op.getType().isF32();
1867     });
1868     target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
1869         [&](TestSignatureConversionNoConverterOp op) {
1870           return converter.isLegal(op.getRegion().front().getArgumentTypes());
1871         });
1872 
1873     // Initialize the set of rewrite patterns.
1874     RewritePatternSet patterns(&getContext());
1875     patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
1876                  TestSignatureConversionUndo,
1877                  TestTestSignatureConversionNoConverter>(converter,
1878                                                          &getContext());
1879     patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
1880         &getContext());
1881     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1882                                                               converter);
1883 
1884     if (failed(applyPartialConversion(getOperation(), target,
1885                                       std::move(patterns))))
1886       signalPassFailure();
1887   }
1888 };
1889 } // namespace
1890 
1891 //===----------------------------------------------------------------------===//
1892 // Test Target Materialization With No Uses
1893 //===----------------------------------------------------------------------===//
1894 
1895 namespace {
1896 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1897   using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1898 
1899   LogicalResult
1900   matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1901                   ConversionPatternRewriter &rewriter) const final {
1902     rewriter.replaceOp(op, adaptor.getOperands());
1903     return success();
1904   }
1905 };
1906 
1907 struct TestTargetMaterializationWithNoUses
1908     : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> {
1909   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1910       TestTargetMaterializationWithNoUses)
1911 
1912   StringRef getArgument() const final {
1913     return "test-target-materialization-with-no-uses";
1914   }
1915   StringRef getDescription() const final {
1916     return "Test a special case of target materialization in DialectConversion";
1917   }
1918 
1919   void runOnOperation() override {
1920     TypeConverter converter;
1921     converter.addConversion([](Type t) { return t; });
1922     converter.addConversion([](IntegerType intTy) -> Type {
1923       if (intTy.getWidth() == 16)
1924         return IntegerType::get(intTy.getContext(), 64);
1925       return intTy;
1926     });
1927     converter.addTargetMaterialization(
1928         [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1929           return builder.create<TestCastOp>(loc, type, inputs).getResult();
1930         });
1931 
1932     ConversionTarget target(getContext());
1933     target.addIllegalOp<TestTypeChangerOp>();
1934 
1935     RewritePatternSet patterns(&getContext());
1936     patterns.add<ForwardOperandPattern>(converter, &getContext());
1937 
1938     if (failed(applyPartialConversion(getOperation(), target,
1939                                       std::move(patterns))))
1940       signalPassFailure();
1941   }
1942 };
1943 } // namespace
1944 
1945 //===----------------------------------------------------------------------===//
1946 // Test Block Merging
1947 //===----------------------------------------------------------------------===//
1948 
1949 namespace {
/// A rewriter pattern that tests that blocks can be merged. It merges the
/// second block of the op's body into the first one, using the first block's
/// branch operands as replacements for the second block's arguments.
struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
  using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Region &body = op.getBody();
    Block &head = body.front();
    Block *tail = &*std::next(body.begin());
    Operation *branch = head.getTerminator();

    // Copy the branch operands before the branch itself is erased.
    SmallVector<Value, 2> tailArgReplacements(branch->getOperands());
    rewriter.eraseOp(branch);
    rewriter.mergeBlocks(tail, &head, tailArgReplacements);
    // Record an (empty) in-place modification of `op` with the rewriter.
    rewriter.modifyOpInPlace(op, [] {});
    return success();
  }
};
1968 
/// A rewrite pattern to test the undo mechanism of blocks being merged.
/// It first inserts an ILLegalOpF at the start of the second block of the op's
/// first region, then merges that block into the first block.
/// TestMergeBlocksPatternDriver marks ILLegalOpF as illegal, so the merge is
/// expected to be rolled back. NOTE(review): confirm against the lit test.
struct TestUndoBlocksMerge : public ConversionPattern {
  TestUndoBlocksMerge(MLIRContext *ctx)
      : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Block &firstBlock = op->getRegion(0).front();
    Operation *branchOp = firstBlock.getTerminator();
    Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
    // Insert an op that the conversion target treats as illegal.
    rewriter.setInsertionPointToStart(secondBlock);
    rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
    // Copy the branch operands before erasing the branch. They replace the
    // second block's arguments in the merge below.
    auto succOperands = branchOp->getOperands();
    SmallVector<Value, 2> replacements(succOperands);
    rewriter.eraseOp(branchOp);
    rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
    // Record an (empty) in-place modification of `op` with the rewriter.
    rewriter.modifyOpInPlace(op, [] {});
    return success();
  }
};
1989 
1990 /// A rewrite mechanism to inline the body of the op into its parent, when both
1991 /// ops can have a single block.
1992 struct TestMergeSingleBlockOps
1993     : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
1994   using OpConversionPattern<
1995       SingleBlockImplicitTerminatorOp>::OpConversionPattern;
1996 
1997   LogicalResult
1998   matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
1999                   ConversionPatternRewriter &rewriter) const final {
2000     SingleBlockImplicitTerminatorOp parentOp =
2001         op->getParentOfType<SingleBlockImplicitTerminatorOp>();
2002     if (!parentOp)
2003       return failure();
2004     Block &innerBlock = op.getRegion().front();
2005     TerminatorOp innerTerminator =
2006         cast<TerminatorOp>(innerBlock.getTerminator());
2007     rewriter.inlineBlockBefore(&innerBlock, op);
2008     rewriter.eraseOp(innerTerminator);
2009     rewriter.eraseOp(op);
2010     return success();
2011   }
2012 };
2013 
2014 struct TestMergeBlocksPatternDriver
2015     : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> {
2016   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver)
2017 
2018   StringRef getArgument() const final { return "test-merge-blocks"; }
2019   StringRef getDescription() const final {
2020     return "Test Merging operation in ConversionPatternRewriter";
2021   }
2022   void runOnOperation() override {
2023     MLIRContext *context = &getContext();
2024     mlir::RewritePatternSet patterns(context);
2025     patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
2026         context);
2027     ConversionTarget target(*context);
2028     target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
2029                       TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
2030     target.addIllegalOp<ILLegalOpF>();
2031 
2032     /// Expect the op to have a single block after legalization.
2033     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
2034         [&](TestMergeBlocksOp op) -> bool {
2035           return llvm::hasSingleElement(op.getBody());
2036         });
2037 
2038     /// Only allow `test.br` within test.merge_blocks op.
2039     target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
2040       return op->getParentOfType<TestMergeBlocksOp>();
2041     });
2042 
2043     /// Expect that all nested test.SingleBlockImplicitTerminator ops are
2044     /// inlined.
2045     target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
2046         [&](SingleBlockImplicitTerminatorOp op) -> bool {
2047           return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
2048         });
2049 
2050     DenseSet<Operation *> unlegalizedOps;
2051     ConversionConfig config;
2052     config.unlegalizedOps = &unlegalizedOps;
2053     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
2054                                  config);
2055     for (auto *op : unlegalizedOps)
2056       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
2057   }
2058 };
2059 } // namespace
2060 
2061 //===----------------------------------------------------------------------===//
2062 // Test Selective Replacement
2063 //===----------------------------------------------------------------------===//
2064 
2065 namespace {
2066 /// A rewrite mechanism to inline the body of the op into its parent, when both
2067 /// ops can have a single block.
2068 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
2069   using OpRewritePattern<TestCastOp>::OpRewritePattern;
2070 
2071   LogicalResult matchAndRewrite(TestCastOp op,
2072                                 PatternRewriter &rewriter) const final {
2073     if (op.getNumOperands() != 2)
2074       return failure();
2075     OperandRange operands = op.getOperands();
2076 
2077     // Replace non-terminator uses with the first operand.
2078     rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
2079       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
2080     });
2081     // Replace everything else with the second operand if the operation isn't
2082     // dead.
2083     rewriter.replaceOp(op, op.getOperand(1));
2084     return success();
2085   }
2086 };
2087 
2088 struct TestSelectiveReplacementPatternDriver
2089     : public PassWrapper<TestSelectiveReplacementPatternDriver,
2090                          OperationPass<>> {
2091   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
2092       TestSelectiveReplacementPatternDriver)
2093 
2094   StringRef getArgument() const final {
2095     return "test-pattern-selective-replacement";
2096   }
2097   StringRef getDescription() const final {
2098     return "Test selective replacement in the PatternRewriter";
2099   }
2100   void runOnOperation() override {
2101     MLIRContext *context = &getContext();
2102     mlir::RewritePatternSet patterns(context);
2103     patterns.add<TestSelectiveOpReplacementPattern>(context);
2104     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
2105   }
2106 };
2107 } // namespace
2108 
2109 //===----------------------------------------------------------------------===//
2110 // PassRegistration
2111 //===----------------------------------------------------------------------===//
2112 
2113 namespace mlir {
2114 namespace test {
/// Registers the pattern and dialect-conversion test passes listed below with
/// the global pass registry.
void registerPatternsTestPass() {
  PassRegistration<TestReturnTypeDriver>();

  PassRegistration<TestDerivedAttributeDriver>();

  // Greedy, strict and walk pattern drivers.
  PassRegistration<TestGreedyPatternDriver>();
  PassRegistration<TestStrictPatternDriver>();
  PassRegistration<TestWalkPatternDriver>();

  // The legalization driver is constructed with the conversion mode held by
  // the `legalizerConversionMode` command-line option.
  PassRegistration<TestLegalizePatternDriver>([] {
    return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
  });

  PassRegistration<TestRemappedValue>();

  PassRegistration<TestUnknownRootOpDriver>();

  // Type conversion and materialization tests.
  PassRegistration<TestTypeConversionDriver>();
  PassRegistration<TestTargetMaterializationWithNoUses>();

  PassRegistration<TestRewriteDynamicOpDriver>();

  // Block merging and selective replacement tests.
  PassRegistration<TestMergeBlocksPatternDriver>();
  PassRegistration<TestSelectiveReplacementPatternDriver>();
}
2140 } // namespace test
2141 } // namespace mlir
2142