xref: /llvm-project/mlir/test/lib/Dialect/Test/TestPatterns.cpp (revision 412e30b2274a134c01c2140ac7c7579be70f0896)
1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "TestTypes.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BuiltinAttributes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Visitors.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include <cstdint>
26 
27 using namespace mlir;
28 using namespace test;
29 
30 // Native function for testing NativeCodeCall
31 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
32   return choice.getValue() ? input1 : input2;
33 }
34 
35 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
36   rewriter.create<OpI>(loc, input);
37 }
38 
39 static void handleNoResultOp(PatternRewriter &rewriter,
40                              OpSymbolBindingNoResult op) {
41   // Turn the no result op to a one-result op.
42   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
43                                     op.getOperand());
44 }
45 
46 static bool getFirstI32Result(Operation *op, Value &value) {
47   if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
48     return false;
49   value = op->getResult(0);
50   return true;
51 }
52 
53 static Value bindNativeCodeCallResult(Value value) { return value; }
54 
55 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
56                                                               Value input2) {
57   return SmallVector<Value, 2>({input2, input1});
58 }
59 
60 // Test that natives calls are only called once during rewrites.
61 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
62 // This let us check the number of times OpM_Test was called by inspecting
63 // the returned value in the MLIR output.
64 static int64_t opMIncreasingValue = 314159265;
65 static Attribute opMTest(PatternRewriter &rewriter, Value val) {
66   int64_t i = opMIncreasingValue++;
67   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
68 }
69 
70 namespace {
71 #include "TestPatterns.inc"
72 } // namespace
73 
74 //===----------------------------------------------------------------------===//
75 // Test Reduce Pattern Interface
76 //===----------------------------------------------------------------------===//
77 
78 void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
79   populateWithGenerated(patterns);
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // Canonicalizer Driver.
84 //===----------------------------------------------------------------------===//
85 
86 namespace {
87 struct FoldingPattern : public RewritePattern {
88 public:
89   FoldingPattern(MLIRContext *context)
90       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
91                        /*benefit=*/1, context) {}
92 
93   LogicalResult matchAndRewrite(Operation *op,
94                                 PatternRewriter &rewriter) const override {
95     // Exercise createOrFold API for a single-result operation that is folded
96     // upon construction. The operation being created has an in-place folder,
97     // and it should be still present in the output. Furthermore, the folder
98     // should not crash when attempting to recover the (unchanged) operation
99     // result.
100     Value result = rewriter.createOrFold<TestOpInPlaceFold>(
101         op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0));
102     assert(result);
103     rewriter.replaceOp(op, result);
104     return success();
105   }
106 };
107 
108 /// This pattern creates a foldable operation at the entry point of the block.
109 /// This tests the situation where the operation folder will need to replace an
110 /// operation with a previously created constant that does not initially
111 /// dominate the operation to replace.
112 struct FolderInsertBeforePreviouslyFoldedConstantPattern
113     : public OpRewritePattern<TestCastOp> {
114 public:
115   using OpRewritePattern<TestCastOp>::OpRewritePattern;
116 
117   LogicalResult matchAndRewrite(TestCastOp op,
118                                 PatternRewriter &rewriter) const override {
119     if (!op->hasAttr("test_fold_before_previously_folded_op"))
120       return failure();
121     rewriter.setInsertionPointToStart(op->getBlock());
122 
123     auto constOp = rewriter.create<arith::ConstantOp>(
124         op.getLoc(), rewriter.getBoolAttr(true));
125     rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
126                                             Value(constOp));
127     return success();
128   }
129 };
130 
131 /// This pattern matches test.op_commutative2 with the first operand being
132 /// another test.op_commutative2 with a constant on the right side and fold it
133 /// away by propagating it as its result. This is intend to check that patterns
134 /// are applied after the commutative property moves constant to the right.
135 struct FolderCommutativeOp2WithConstant
136     : public OpRewritePattern<TestCommutative2Op> {
137 public:
138   using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
139 
140   LogicalResult matchAndRewrite(TestCommutative2Op op,
141                                 PatternRewriter &rewriter) const override {
142     auto operand =
143         dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
144     if (!operand)
145       return failure();
146     Attribute constInput;
147     if (!matchPattern(operand->getOperand(1), m_Constant(&constInput)))
148       return failure();
149     rewriter.replaceOp(op, operand->getOperand(1));
150     return success();
151   }
152 };
153 
154 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
155 /// attribute with value smaller than MaxVal, it increments the value by 1.
156 template <int MaxVal>
157 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
158   using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern;
159 
160   LogicalResult matchAndRewrite(AnyAttrOfOp op,
161                                 PatternRewriter &rewriter) const override {
162     auto intAttr = dyn_cast<IntegerAttr>(op.getAttr());
163     if (!intAttr)
164       return failure();
165     int64_t val = intAttr.getInt();
166     if (val >= MaxVal)
167       return failure();
168     rewriter.modifyOpInPlace(
169         op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
170     return success();
171   }
172 };
173 
174 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op".
175 struct MakeOpEligible : public RewritePattern {
176   MakeOpEligible(MLIRContext *context)
177       : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {}
178 
179   LogicalResult matchAndRewrite(Operation *op,
180                                 PatternRewriter &rewriter) const override {
181     if (op->hasAttr("eligible"))
182       return failure();
183     rewriter.modifyOpInPlace(
184         op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); });
185     return success();
186   }
187 };
188 
189 /// This pattern hoists eligible ops out of a "test.one_region_op".
190 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
191   using OpRewritePattern<test::OneRegionOp>::OpRewritePattern;
192 
193   LogicalResult matchAndRewrite(test::OneRegionOp op,
194                                 PatternRewriter &rewriter) const override {
195     Operation *terminator = op.getRegion().front().getTerminator();
196     Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp();
197     if (toBeHoisted->getParentOp() != op)
198       return failure();
199     if (!toBeHoisted->hasAttr("eligible"))
200       return failure();
201     rewriter.moveOpBefore(toBeHoisted, op);
202     return success();
203   }
204 };
205 
206 /// This pattern moves "test.move_before_parent_op" before the parent op.
207 struct MoveBeforeParentOp : public RewritePattern {
208   MoveBeforeParentOp(MLIRContext *context)
209       : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
210 
211   LogicalResult matchAndRewrite(Operation *op,
212                                 PatternRewriter &rewriter) const override {
213     // Do not hoist past functions.
214     if (isa<FunctionOpInterface>(op->getParentOp()))
215       return failure();
216     rewriter.moveOpBefore(op, op->getParentOp());
217     return success();
218   }
219 };
220 
221 /// This pattern moves "test.move_after_parent_op" after the parent op.
222 struct MoveAfterParentOp : public RewritePattern {
223   MoveAfterParentOp(MLIRContext *context)
224       : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {}
225 
226   LogicalResult matchAndRewrite(Operation *op,
227                                 PatternRewriter &rewriter) const override {
228     // Do not hoist past functions.
229     if (isa<FunctionOpInterface>(op->getParentOp()))
230       return failure();
231 
232     int64_t moveForwardBy = 0;
233     if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance"))
234       moveForwardBy = advanceBy.getInt();
235 
236     Operation *moveAfter = op->getParentOp();
237     for (int64_t i = 0; i < moveForwardBy; ++i)
238       moveAfter = moveAfter->getNextNode();
239 
240     rewriter.moveOpAfter(op, moveAfter);
241     return success();
242   }
243 };
244 
245 /// This pattern inlines blocks that are nested in
246 /// "test.inline_blocks_into_parent" into the parent block.
247 struct InlineBlocksIntoParent : public RewritePattern {
248   InlineBlocksIntoParent(MLIRContext *context)
249       : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
250                        context) {}
251 
252   LogicalResult matchAndRewrite(Operation *op,
253                                 PatternRewriter &rewriter) const override {
254     bool changed = false;
255     for (Region &r : op->getRegions()) {
256       while (!r.empty()) {
257         rewriter.inlineBlockBefore(&r.front(), op);
258         changed = true;
259       }
260     }
261     return success(changed);
262   }
263 };
264 
265 /// This pattern splits blocks at "test.split_block_here" and replaces the op
266 /// with a new op (to prevent an infinite loop of block splitting).
267 struct SplitBlockHere : public RewritePattern {
268   SplitBlockHere(MLIRContext *context)
269       : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
270 
271   LogicalResult matchAndRewrite(Operation *op,
272                                 PatternRewriter &rewriter) const override {
273     rewriter.splitBlock(op->getBlock(), op->getIterator());
274     Operation *newOp = rewriter.create(
275         op->getLoc(),
276         OperationName("test.new_op", op->getContext()).getIdentifier(),
277         op->getOperands(), op->getResultTypes());
278     rewriter.replaceOp(op, newOp);
279     return success();
280   }
281 };
282 
283 /// This pattern clones "test.clone_me" ops.
284 struct CloneOp : public RewritePattern {
285   CloneOp(MLIRContext *context)
286       : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
287 
288   LogicalResult matchAndRewrite(Operation *op,
289                                 PatternRewriter &rewriter) const override {
290     // Do not clone already cloned ops to avoid going into an infinite loop.
291     if (op->hasAttr("was_cloned"))
292       return failure();
293     Operation *cloned = rewriter.clone(*op);
294     cloned->setAttr("was_cloned", rewriter.getUnitAttr());
295     return success();
296   }
297 };
298 
299 /// This pattern clones regions of "test.clone_region_before" ops before the
300 /// parent block.
301 struct CloneRegionBeforeOp : public RewritePattern {
302   CloneRegionBeforeOp(MLIRContext *context)
303       : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
304 
305   LogicalResult matchAndRewrite(Operation *op,
306                                 PatternRewriter &rewriter) const override {
307     // Do not clone already cloned ops to avoid going into an infinite loop.
308     if (op->hasAttr("was_cloned"))
309       return failure();
310     for (Region &r : op->getRegions())
311       rewriter.cloneRegionBefore(r, op->getBlock());
312     op->setAttr("was_cloned", rewriter.getUnitAttr());
313     return success();
314   }
315 };
316 
317 /// Replace an operation may introduce the re-visiting of its users.
318 class ReplaceWithNewOp : public RewritePattern {
319 public:
320   ReplaceWithNewOp(MLIRContext *context)
321       : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
322 
323   LogicalResult matchAndRewrite(Operation *op,
324                                 PatternRewriter &rewriter) const override {
325     Operation *newOp;
326     if (op->hasAttr("create_erase_op")) {
327       newOp = rewriter.create(
328           op->getLoc(),
329           OperationName("test.erase_op", op->getContext()).getIdentifier(),
330           ValueRange(), TypeRange());
331     } else {
332       newOp = rewriter.create(
333           op->getLoc(),
334           OperationName("test.new_op", op->getContext()).getIdentifier(),
335           op->getOperands(), op->getResultTypes());
336     }
337     // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
338     // A "notifyOperationReplaced" callback is triggered in either case.
339     rewriter.replaceAllOpUsesWith(op, newOp->getResults());
340     rewriter.eraseOp(op);
341     return success();
342   }
343 };
344 
345 /// Erases the first child block of the matched "test.erase_first_block"
346 /// operation.
347 class EraseFirstBlock : public RewritePattern {
348 public:
349   EraseFirstBlock(MLIRContext *context)
350       : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {}
351 
352   LogicalResult matchAndRewrite(Operation *op,
353                                 PatternRewriter &rewriter) const override {
354     for (Region &r : op->getRegions()) {
355       for (Block &b : r.getBlocks()) {
356         rewriter.eraseBlock(&b);
357         return success();
358       }
359     }
360 
361     return failure();
362   }
363 };
364 
365 struct TestGreedyPatternDriver
366     : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
367   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
368 
369   TestGreedyPatternDriver() = default;
370   TestGreedyPatternDriver(const TestGreedyPatternDriver &other)
371       : PassWrapper(other) {}
372 
373   StringRef getArgument() const final { return "test-greedy-patterns"; }
374   StringRef getDescription() const final { return "Run test dialect patterns"; }
375   void runOnOperation() override {
376     mlir::RewritePatternSet patterns(&getContext());
377     populateWithGenerated(patterns);
378 
379     // Verify named pattern is generated with expected name.
380     patterns.add<FoldingPattern, TestNamedPatternRule,
381                  FolderInsertBeforePreviouslyFoldedConstantPattern,
382                  FolderCommutativeOp2WithConstant, HoistEligibleOps,
383                  MakeOpEligible>(&getContext());
384 
385     // Additional patterns for testing the GreedyPatternRewriteDriver.
386     patterns.insert<IncrementIntAttribute<3>>(&getContext());
387 
388     GreedyRewriteConfig config;
389     config.useTopDownTraversal = this->useTopDownTraversal;
390     config.maxIterations = this->maxIterations;
391     config.fold = this->fold;
392     config.cseConstants = this->cseConstants;
393     (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
394   }
395 
396   Option<bool> useTopDownTraversal{
397       *this, "top-down",
398       llvm::cl::desc("Seed the worklist in general top-down order"),
399       llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
400   Option<int> maxIterations{
401       *this, "max-iterations",
402       llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
403       llvm::cl::init(GreedyRewriteConfig().maxIterations)};
404   Option<bool> fold{*this, "fold", llvm::cl::desc("Whether to fold"),
405                     llvm::cl::init(GreedyRewriteConfig().fold)};
406   Option<bool> cseConstants{*this, "cse-constants",
407                             llvm::cl::desc("Whether to CSE constants"),
408                             llvm::cl::init(GreedyRewriteConfig().cseConstants)};
409 };
410 
411 struct DumpNotifications : public RewriterBase::Listener {
412   void notifyBlockInserted(Block *block, Region *previous,
413                            Region::iterator previousIt) override {
414     llvm::outs() << "notifyBlockInserted";
415     if (block->getParentOp()) {
416       llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
417     } else {
418       llvm::outs() << " into unknown op: ";
419     }
420     if (previous == nullptr) {
421       llvm::outs() << "was unlinked\n";
422     } else {
423       llvm::outs() << "was linked\n";
424     }
425   }
426   void notifyOperationInserted(Operation *op,
427                                OpBuilder::InsertPoint previous) override {
428     llvm::outs() << "notifyOperationInserted: " << op->getName();
429     if (!previous.isSet()) {
430       llvm::outs() << ", was unlinked\n";
431     } else {
432       if (!previous.getPoint().getNodePtr()) {
433         llvm::outs() << ", was linked, exact position unknown\n";
434       } else if (previous.getPoint() == previous.getBlock()->end()) {
435         llvm::outs() << ", was last in block\n";
436       } else {
437         llvm::outs() << ", previous = " << previous.getPoint()->getName()
438                      << "\n";
439       }
440     }
441   }
442   void notifyBlockErased(Block *block) override {
443     llvm::outs() << "notifyBlockErased\n";
444   }
445   void notifyOperationErased(Operation *op) override {
446     llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
447   }
448   void notifyOperationModified(Operation *op) override {
449     llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
450   }
451   void notifyOperationReplaced(Operation *op, ValueRange values) override {
452     llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
453   }
454 };
455 
456 struct TestStrictPatternDriver
457     : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
458 public:
459   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
460 
461   TestStrictPatternDriver() = default;
462   TestStrictPatternDriver(const TestStrictPatternDriver &other)
463       : PassWrapper(other) {
464     strictMode = other.strictMode;
465   }
466 
467   StringRef getArgument() const final { return "test-strict-pattern-driver"; }
468   StringRef getDescription() const final {
469     return "Test strict mode of pattern driver";
470   }
471 
472   void runOnOperation() override {
473     MLIRContext *ctx = &getContext();
474     mlir::RewritePatternSet patterns(ctx);
475     patterns.add<
476         // clang-format off
477         ChangeBlockOp,
478         CloneOp,
479         CloneRegionBeforeOp,
480         EraseOp,
481         ImplicitChangeOp,
482         InlineBlocksIntoParent,
483         InsertSameOp,
484         MoveBeforeParentOp,
485         ReplaceWithNewOp,
486         SplitBlockHere
487         // clang-format on
488         >(ctx);
489     SmallVector<Operation *> ops;
490     getOperation()->walk([&](Operation *op) {
491       StringRef opName = op->getName().getStringRef();
492       if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
493           opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
494           opName == "test.move_before_parent_op" ||
495           opName == "test.inline_blocks_into_parent" ||
496           opName == "test.split_block_here" || opName == "test.clone_me" ||
497           opName == "test.clone_region_before") {
498         ops.push_back(op);
499       }
500     });
501 
502     DumpNotifications dumpNotifications;
503     GreedyRewriteConfig config;
504     config.listener = &dumpNotifications;
505     if (strictMode == "AnyOp") {
506       config.strictMode = GreedyRewriteStrictness::AnyOp;
507     } else if (strictMode == "ExistingAndNewOps") {
508       config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
509     } else if (strictMode == "ExistingOps") {
510       config.strictMode = GreedyRewriteStrictness::ExistingOps;
511     } else {
512       llvm_unreachable("invalid strictness option");
513     }
514 
515     // Check if these transformations introduce visiting of operations that
516     // are not in the `ops` set (The new created ops are valid). An invalid
517     // operation will trigger the assertion while processing.
518     bool changed = false;
519     bool allErased = false;
520     (void)applyOpPatternsGreedily(ArrayRef(ops), std::move(patterns), config,
521                                   &changed, &allErased);
522     Builder b(ctx);
523     getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
524     getOperation()->setAttr("pattern_driver_all_erased",
525                             b.getBoolAttr(allErased));
526   }
527 
528   Option<std::string> strictMode{
529       *this, "strictness",
530       llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"),
531       llvm::cl::init("AnyOp")};
532 
533 private:
534   // New inserted operation is valid for further transformation.
535   class InsertSameOp : public RewritePattern {
536   public:
537     InsertSameOp(MLIRContext *context)
538         : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
539 
540     LogicalResult matchAndRewrite(Operation *op,
541                                   PatternRewriter &rewriter) const override {
542       if (op->hasAttr("skip"))
543         return failure();
544 
545       Operation *newOp =
546           rewriter.create(op->getLoc(), op->getName().getIdentifier(),
547                           op->getOperands(), op->getResultTypes());
548       rewriter.modifyOpInPlace(
549           op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); });
550       newOp->setAttr("skip", rewriter.getBoolAttr(true));
551 
552       return success();
553     }
554   };
555 
556   // Remove an operation may introduce the re-visiting of its operands.
557   class EraseOp : public RewritePattern {
558   public:
559     EraseOp(MLIRContext *context)
560         : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
561     LogicalResult matchAndRewrite(Operation *op,
562                                   PatternRewriter &rewriter) const override {
563       rewriter.eraseOp(op);
564       return success();
565     }
566   };
567 
568   // The following two patterns test RewriterBase::replaceAllUsesWith.
569   //
570   // That function replaces all usages of a Block (or a Value) with another one
571   // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
572   // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
573   // worklist: when an op is modified, it is added to the worklist. The two
574   // patterns below make the tracking observable: ChangeBlockOp replaces all
575   // usages of a block and that pattern is applied because the corresponding ops
576   // are put on the initial worklist (see above). ImplicitChangeOp does an
577   // unrelated change but ops of the corresponding type are *not* on the initial
578   // worklist, so the effect of the second pattern is only visible if the
579   // tracking and subsequent adding to the worklist actually works.
580 
581   // Replace all usages of the first successor with the second successor.
582   class ChangeBlockOp : public RewritePattern {
583   public:
584     ChangeBlockOp(MLIRContext *context)
585         : RewritePattern("test.change_block_op", /*benefit=*/1, context) {}
586     LogicalResult matchAndRewrite(Operation *op,
587                                   PatternRewriter &rewriter) const override {
588       if (op->getNumSuccessors() < 2)
589         return failure();
590       Block *firstSuccessor = op->getSuccessor(0);
591       Block *secondSuccessor = op->getSuccessor(1);
592       if (firstSuccessor == secondSuccessor)
593         return failure();
594       // This is the function being tested:
595       rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor);
596       // Using the following line instead would make the test fail:
597       // firstSuccessor->replaceAllUsesWith(secondSuccessor);
598       return success();
599     }
600   };
601 
602   // Changes the successor to the parent block.
603   class ImplicitChangeOp : public RewritePattern {
604   public:
605     ImplicitChangeOp(MLIRContext *context)
606         : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {}
607     LogicalResult matchAndRewrite(Operation *op,
608                                   PatternRewriter &rewriter) const override {
609       if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock())
610         return failure();
611       rewriter.modifyOpInPlace(op,
612                                [&]() { op->setSuccessor(op->getBlock(), 0); });
613       return success();
614     }
615   };
616 };
617 
618 struct TestWalkPatternDriver final
619     : PassWrapper<TestWalkPatternDriver, OperationPass<>> {
620   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver)
621 
622   TestWalkPatternDriver() = default;
623   TestWalkPatternDriver(const TestWalkPatternDriver &other)
624       : PassWrapper(other) {}
625 
626   StringRef getArgument() const override {
627     return "test-walk-pattern-rewrite-driver";
628   }
629   StringRef getDescription() const override {
630     return "Run test walk pattern rewrite driver";
631   }
632   void runOnOperation() override {
633     mlir::RewritePatternSet patterns(&getContext());
634 
635     // Patterns for testing the WalkPatternRewriteDriver.
636     patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
637                  MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>(
638         &getContext());
639 
640     DumpNotifications dumpListener;
641     walkAndApplyPatterns(getOperation(), std::move(patterns),
642                          dumpNotifications ? &dumpListener : nullptr);
643   }
644 
645   Option<bool> dumpNotifications{
646       *this, "dump-notifications",
647       llvm::cl::desc("Print rewrite listener notifications"),
648       llvm::cl::init(false)};
649 };
650 
651 } // namespace
652 
653 //===----------------------------------------------------------------------===//
654 // ReturnType Driver.
655 //===----------------------------------------------------------------------===//
656 
657 namespace {
658 // Generate ops for each instance where the type can be successfully inferred.
659 template <typename OpTy>
660 static void invokeCreateWithInferredReturnType(Operation *op) {
661   auto *context = op->getContext();
662   auto fop = op->getParentOfType<func::FuncOp>();
663   auto location = UnknownLoc::get(context);
664   OpBuilder b(op);
665   b.setInsertionPointAfter(op);
666 
667   // Use permutations of 2 args as operands.
668   assert(fop.getNumArguments() >= 2);
669   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
670     for (int j = 0; j < e; ++j) {
671       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
672       SmallVector<Type, 2> inferredReturnTypes;
673       if (succeeded(OpTy::inferReturnTypes(
674               context, std::nullopt, values, op->getDiscardableAttrDictionary(),
675               op->getPropertiesStorage(), op->getRegions(),
676               inferredReturnTypes))) {
677         OperationState state(location, OpTy::getOperationName());
678         // TODO: Expand to regions.
679         OpTy::build(b, state, values, op->getAttrs());
680         (void)b.create(state);
681       }
682     }
683   }
684 }
685 
686 static void reifyReturnShape(Operation *op) {
687   OpBuilder b(op);
688 
689   // Use permutations of 2 args as operands.
690   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
691   SmallVector<Value, 2> shapes;
692   if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
693       !llvm::hasSingleElement(shapes))
694     return;
695   for (const auto &it : llvm::enumerate(shapes)) {
696     op->emitRemark() << "value " << it.index() << ": "
697                      << it.value().getDefiningOp();
698   }
699 }
700 
701 struct TestReturnTypeDriver
702     : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> {
703   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver)
704 
705   void getDependentDialects(DialectRegistry &registry) const override {
706     registry.insert<tensor::TensorDialect>();
707   }
708   StringRef getArgument() const final { return "test-return-type"; }
709   StringRef getDescription() const final { return "Run return type functions"; }
710 
711   void runOnOperation() override {
712     if (getOperation().getName() == "testCreateFunctions") {
713       std::vector<Operation *> ops;
714       // Collect ops to avoid triggering on inserted ops.
715       for (auto &op : getOperation().getBody().front())
716         ops.push_back(&op);
717       // Generate test patterns for each, but skip terminator.
718       for (auto *op : llvm::ArrayRef(ops).drop_back()) {
719         // Test create method of each of the Op classes below. The resultant
720         // output would be in reverse order underneath `op` from which
721         // the attributes and regions are used.
722         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
723         invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>(
724             op);
725         invokeCreateWithInferredReturnType<
726             OpWithShapedTypeInferTypeInterfaceOp>(op);
727       };
728       return;
729     }
730     if (getOperation().getName() == "testReifyFunctions") {
731       std::vector<Operation *> ops;
732       // Collect ops to avoid triggering on inserted ops.
733       for (auto &op : getOperation().getBody().front())
734         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
735           ops.push_back(&op);
736       // Generate test patterns for each, but skip terminator.
737       for (auto *op : ops)
738         reifyReturnShape(op);
739     }
740   }
741 };
742 } // namespace
743 
744 namespace {
745 struct TestDerivedAttributeDriver
746     : public PassWrapper<TestDerivedAttributeDriver,
747                          OperationPass<func::FuncOp>> {
748   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver)
749 
750   StringRef getArgument() const final { return "test-derived-attr"; }
751   StringRef getDescription() const final {
752     return "Run test derived attributes";
753   }
754   void runOnOperation() override;
755 };
756 } // namespace
757 
758 void TestDerivedAttributeDriver::runOnOperation() {
759   getOperation().walk([](DerivedAttributeOpInterface dOp) {
760     auto dAttr = dOp.materializeDerivedAttributes();
761     if (!dAttr)
762       return;
763     for (auto d : dAttr)
764       dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
765   });
766 }
767 
768 //===----------------------------------------------------------------------===//
769 // Legalization Driver.
770 //===----------------------------------------------------------------------===//
771 
772 namespace {
773 //===----------------------------------------------------------------------===//
774 // Region-Block Rewrite Testing
775 
776 /// This pattern applies a signature conversion to a block inside a detached
777 /// region.
778 struct TestDetachedSignatureConversion : public ConversionPattern {
779   TestDetachedSignatureConversion(MLIRContext *ctx)
780       : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1,
781                           ctx) {}
782 
783   LogicalResult
784   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
785                   ConversionPatternRewriter &rewriter) const final {
786     if (op->getNumRegions() != 1)
787       return failure();
788     OperationState state(op->getLoc(), "test.legal_op", operands,
789                          op->getResultTypes(), {}, BlockRange());
790     Region *newRegion = state.addRegion();
791     rewriter.inlineRegionBefore(op->getRegion(0), *newRegion,
792                                 newRegion->begin());
793     TypeConverter::SignatureConversion result(newRegion->getNumArguments());
794     for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i)
795       result.addInputs(i, rewriter.getF64Type());
796     rewriter.applySignatureConversion(&newRegion->front(), result);
797     Operation *newOp = rewriter.create(state);
798     rewriter.replaceOp(op, newOp->getResults());
799     return success();
800   }
801 };
802 
803 /// This pattern is a simple pattern that inlines the first region of a given
804 /// operation into the parent region.
805 struct TestRegionRewriteBlockMovement : public ConversionPattern {
806   TestRegionRewriteBlockMovement(MLIRContext *ctx)
807       : ConversionPattern("test.region", 1, ctx) {}
808 
809   LogicalResult
810   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
811                   ConversionPatternRewriter &rewriter) const final {
812     // Inline this region into the parent region.
813     auto &parentRegion = *op->getParentRegion();
814     auto &opRegion = op->getRegion(0);
815     if (op->getDiscardableAttr("legalizer.should_clone"))
816       rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
817     else
818       rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
819 
820     if (op->getDiscardableAttr("legalizer.erase_old_blocks")) {
821       while (!opRegion.empty())
822         rewriter.eraseBlock(&opRegion.front());
823     }
824 
825     // Drop this operation.
826     rewriter.eraseOp(op);
827     return success();
828   }
829 };
830 /// This pattern is a simple pattern that generates a region containing an
831 /// illegal operation.
832 struct TestRegionRewriteUndo : public RewritePattern {
833   TestRegionRewriteUndo(MLIRContext *ctx)
834       : RewritePattern("test.region_builder", 1, ctx) {}
835 
836   LogicalResult matchAndRewrite(Operation *op,
837                                 PatternRewriter &rewriter) const final {
838     // Create the region operation with an entry block containing arguments.
839     OperationState newRegion(op->getLoc(), "test.region");
840     newRegion.addRegion();
841     auto *regionOp = rewriter.create(newRegion);
842     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
843     entryBlock->addArgument(rewriter.getIntegerType(64),
844                             rewriter.getUnknownLoc());
845 
846     // Add an explicitly illegal operation to ensure the conversion fails.
847     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
848     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
849 
850     // Drop this operation.
851     rewriter.eraseOp(op);
852     return success();
853   }
854 };
855 /// A simple pattern that creates a block at the end of the parent region of the
856 /// matched operation.
857 struct TestCreateBlock : public RewritePattern {
858   TestCreateBlock(MLIRContext *ctx)
859       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
860 
861   LogicalResult matchAndRewrite(Operation *op,
862                                 PatternRewriter &rewriter) const final {
863     Region &region = *op->getParentRegion();
864     Type i32Type = rewriter.getIntegerType(32);
865     Location loc = op->getLoc();
866     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
867     rewriter.create<TerminatorOp>(loc);
868     rewriter.eraseOp(op);
869     return success();
870   }
871 };
872 
873 /// A simple pattern that creates a block containing an invalid operation in
874 /// order to trigger the block creation undo mechanism.
875 struct TestCreateIllegalBlock : public RewritePattern {
876   TestCreateIllegalBlock(MLIRContext *ctx)
877       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
878 
879   LogicalResult matchAndRewrite(Operation *op,
880                                 PatternRewriter &rewriter) const final {
881     Region &region = *op->getParentRegion();
882     Type i32Type = rewriter.getIntegerType(32);
883     Location loc = op->getLoc();
884     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
885     // Create an illegal op to ensure the conversion fails.
886     rewriter.create<ILLegalOpF>(loc, i32Type);
887     rewriter.create<TerminatorOp>(loc);
888     rewriter.eraseOp(op);
889     return success();
890   }
891 };
892 
893 /// A simple pattern that tests the undo mechanism when replacing the uses of a
894 /// block argument.
895 struct TestUndoBlockArgReplace : public ConversionPattern {
896   TestUndoBlockArgReplace(MLIRContext *ctx)
897       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
898 
899   LogicalResult
900   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
901                   ConversionPatternRewriter &rewriter) const final {
902     auto illegalOp =
903         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
904     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
905                                         illegalOp->getResult(0));
906     rewriter.modifyOpInPlace(op, [] {});
907     return success();
908   }
909 };
910 
911 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
912 /// This is to test the rollback logic.
913 struct TestUndoMoveOpBefore : public ConversionPattern {
914   TestUndoMoveOpBefore(MLIRContext *ctx)
915       : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
916 
917   LogicalResult
918   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
919                   ConversionPatternRewriter &rewriter) const override {
920     rewriter.moveOpBefore(op, op->getParentOp());
921     // Replace with an illegal op to ensure the conversion fails.
922     rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
923     return success();
924   }
925 };
926 
927 /// A rewrite pattern that tests the undo mechanism when erasing a block.
928 struct TestUndoBlockErase : public ConversionPattern {
929   TestUndoBlockErase(MLIRContext *ctx)
930       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
931 
932   LogicalResult
933   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
934                   ConversionPatternRewriter &rewriter) const final {
935     Block *secondBlock = &*std::next(op->getRegion(0).begin());
936     rewriter.setInsertionPointToStart(secondBlock);
937     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
938     rewriter.eraseBlock(secondBlock);
939     rewriter.modifyOpInPlace(op, [] {});
940     return success();
941   }
942 };
943 
944 /// A pattern that modifies a property in-place, but keeps the op illegal.
945 struct TestUndoPropertiesModification : public ConversionPattern {
946   TestUndoPropertiesModification(MLIRContext *ctx)
947       : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {}
948   LogicalResult
949   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
950                   ConversionPatternRewriter &rewriter) const final {
951     if (!op->hasAttr("modify_inplace"))
952       return failure();
953     rewriter.modifyOpInPlace(
954         op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); });
955     return success();
956   }
957 };
958 
959 //===----------------------------------------------------------------------===//
960 // Type-Conversion Rewrite Testing
961 
962 /// This patterns erases a region operation that has had a type conversion.
963 struct TestDropOpSignatureConversion : public ConversionPattern {
964   TestDropOpSignatureConversion(MLIRContext *ctx,
965                                 const TypeConverter &converter)
966       : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
967   LogicalResult
968   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
969                   ConversionPatternRewriter &rewriter) const override {
970     Region &region = op->getRegion(0);
971     Block *entry = &region.front();
972 
973     // Convert the original entry arguments.
974     const TypeConverter &converter = *getTypeConverter();
975     TypeConverter::SignatureConversion result(entry->getNumArguments());
976     if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
977                                               result)) ||
978         failed(rewriter.convertRegionTypes(&region, converter, &result)))
979       return failure();
980 
981     // Convert the region signature and just drop the operation.
982     rewriter.eraseOp(op);
983     return success();
984   }
985 };
986 /// This pattern simply updates the operands of the given operation.
987 struct TestPassthroughInvalidOp : public ConversionPattern {
988   TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
989       : ConversionPattern(converter, "test.invalid", 1, ctx) {}
990   LogicalResult
991   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
992                   ConversionPatternRewriter &rewriter) const final {
993     SmallVector<Value> flattened;
994     for (auto it : llvm::enumerate(operands)) {
995       ValueRange range = it.value();
996       if (range.size() == 1) {
997         flattened.push_back(range.front());
998         continue;
999       }
1000 
1001       // This is a 1:N replacement. Insert a test.cast op. (That's what the
1002       // argument materialization used to do.)
1003       flattened.push_back(
1004           rewriter
1005               .create<TestCastOp>(op->getLoc(),
1006                                   op->getOperand(it.index()).getType(), range)
1007               .getResult());
1008     }
1009     rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened,
1010                                              std::nullopt);
1011     return success();
1012   }
1013 };
1014 /// Replace with valid op, but simply drop the operands. This is used in a
1015 /// regression where we used to generate circular unrealized_conversion_cast
1016 /// ops.
1017 struct TestDropAndReplaceInvalidOp : public ConversionPattern {
1018   TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
1019       : ConversionPattern(converter,
1020                           "test.drop_operands_and_replace_with_valid", 1, ctx) {
1021   }
1022   LogicalResult
1023   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1024                   ConversionPatternRewriter &rewriter) const final {
1025     rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(),
1026                                              std::nullopt);
1027     return success();
1028   }
1029 };
1030 /// This pattern handles the case of a split return value.
1031 struct TestSplitReturnType : public ConversionPattern {
1032   TestSplitReturnType(MLIRContext *ctx)
1033       : ConversionPattern("test.return", 1, ctx) {}
1034   LogicalResult
1035   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1036                   ConversionPatternRewriter &rewriter) const final {
1037     // Check for a return of F32.
1038     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
1039       return failure();
1040     rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]);
1041     return success();
1042   }
1043 };
1044 
1045 //===----------------------------------------------------------------------===//
1046 // Multi-Level Type-Conversion Rewrite Testing
1047 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
1048   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
1049       : ConversionPattern("test.type_producer", 1, ctx) {}
1050   LogicalResult
1051   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1052                   ConversionPatternRewriter &rewriter) const final {
1053     // If the type is I32, change the type to F32.
1054     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
1055       return failure();
1056     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
1057     return success();
1058   }
1059 };
1060 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
1061   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
1062       : ConversionPattern("test.type_producer", 1, ctx) {}
1063   LogicalResult
1064   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1065                   ConversionPatternRewriter &rewriter) const final {
1066     // If the type is F32, change the type to F64.
1067     if (!Type(*op->result_type_begin()).isF32())
1068       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
1069     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
1070     return success();
1071   }
1072 };
1073 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
1074   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
1075       : ConversionPattern("test.type_producer", 10, ctx) {}
1076   LogicalResult
1077   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1078                   ConversionPatternRewriter &rewriter) const final {
1079     // Always convert to B16, even though it is not a legal type. This tests
1080     // that values are unmapped correctly.
1081     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
1082     return success();
1083   }
1084 };
1085 struct TestUpdateConsumerType : public ConversionPattern {
1086   TestUpdateConsumerType(MLIRContext *ctx)
1087       : ConversionPattern("test.type_consumer", 1, ctx) {}
1088   LogicalResult
1089   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1090                   ConversionPatternRewriter &rewriter) const final {
1091     // Verify that the incoming operand has been successfully remapped to F64.
1092     if (!operands[0].getType().isF64())
1093       return failure();
1094     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
1095     return success();
1096   }
1097 };
1098 
1099 //===----------------------------------------------------------------------===//
1100 // Non-Root Replacement Rewrite Testing
1101 /// This pattern generates an invalid operation, but replaces it before the
1102 /// pattern is finished. This checks that we don't need to legalize the
1103 /// temporary op.
1104 struct TestNonRootReplacement : public RewritePattern {
1105   TestNonRootReplacement(MLIRContext *ctx)
1106       : RewritePattern("test.replace_non_root", 1, ctx) {}
1107 
1108   LogicalResult matchAndRewrite(Operation *op,
1109                                 PatternRewriter &rewriter) const final {
1110     auto resultType = *op->result_type_begin();
1111     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
1112     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
1113 
1114     rewriter.replaceOp(illegalOp, legalOp);
1115     rewriter.replaceOp(op, illegalOp);
1116     return success();
1117   }
1118 };
1119 
1120 //===----------------------------------------------------------------------===//
1121 // Recursive Rewrite Testing
1122 /// This pattern is applied to the same operation multiple times, but has a
1123 /// bounded recursion.
1124 struct TestBoundedRecursiveRewrite
1125     : public OpRewritePattern<TestRecursiveRewriteOp> {
1126   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
1127 
1128   void initialize() {
1129     // The conversion target handles bounding the recursion of this pattern.
1130     setHasBoundedRewriteRecursion();
1131   }
1132 
1133   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
1134                                 PatternRewriter &rewriter) const final {
1135     // Decrement the depth of the op in-place.
1136     rewriter.modifyOpInPlace(op, [&] {
1137       op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
1138     });
1139     return success();
1140   }
1141 };
1142 
1143 struct TestNestedOpCreationUndoRewrite
1144     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
1145   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
1146 
1147   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
1148                                 PatternRewriter &rewriter) const final {
1149     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1150     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1151     return success();
1152   };
1153 };
1154 
1155 // This pattern matches `test.blackhole` and delete this op and its producer.
1156 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
1157   using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
1158 
1159   LogicalResult matchAndRewrite(BlackHoleOp op,
1160                                 PatternRewriter &rewriter) const final {
1161     Operation *producer = op.getOperand().getDefiningOp();
1162     // Always erase the user before the producer, the framework should handle
1163     // this correctly.
1164     rewriter.eraseOp(op);
1165     rewriter.eraseOp(producer);
1166     return success();
1167   };
1168 };
1169 
1170 // This pattern replaces explicitly illegal op with explicitly legal op,
1171 // but in addition creates unregistered operation.
1172 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
1173   using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
1174 
1175   LogicalResult matchAndRewrite(ILLegalOpG op,
1176                                 PatternRewriter &rewriter) const final {
1177     IntegerAttr attr = rewriter.getI32IntegerAttr(0);
1178     Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
1179     rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
1180     return success();
1181   };
1182 };
1183 
1184 class TestEraseOp : public ConversionPattern {
1185 public:
1186   TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {}
1187   LogicalResult
1188   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1189                   ConversionPatternRewriter &rewriter) const final {
1190     // Erase op without replacements.
1191     rewriter.eraseOp(op);
1192     return success();
1193   }
1194 };
1195 
1196 /// This pattern matches a test.duplicate_block_args op and duplicates all
1197 /// block arguments.
1198 class TestDuplicateBlockArgs
1199     : public OpConversionPattern<DuplicateBlockArgsOp> {
1200   using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern;
1201 
1202   LogicalResult
1203   matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor,
1204                   ConversionPatternRewriter &rewriter) const override {
1205     if (op.getIsLegal())
1206       return failure();
1207     rewriter.startOpModification(op);
1208     Block *body = &op.getBody().front();
1209     TypeConverter::SignatureConversion result(body->getNumArguments());
1210     for (auto it : llvm::enumerate(body->getArgumentTypes()))
1211       result.addInputs(it.index(), {it.value(), it.value()});
1212     rewriter.applySignatureConversion(body, result, getTypeConverter());
1213     op.setIsLegal(true);
1214     rewriter.finalizeOpModification(op);
1215     return success();
1216   }
1217 };
1218 
1219 /// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid
1220 /// op. The pattern supports 1:N replacements and forwards the replacement
1221 /// values of the single operand as test.valid operands.
1222 class TestRepetitive1ToNConsumer : public ConversionPattern {
1223 public:
1224   TestRepetitive1ToNConsumer(MLIRContext *ctx)
1225       : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {}
1226   LogicalResult
1227   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1228                   ConversionPatternRewriter &rewriter) const final {
1229     // A single operand is expected.
1230     if (op->getNumOperands() != 1)
1231       return failure();
1232     rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front());
1233     return success();
1234   }
1235 };
1236 
1237 /// A pattern that tests two back-to-back 1 -> 2 op replacements.
1238 class TestMultiple1ToNReplacement : public ConversionPattern {
1239 public:
1240   TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter)
1241       : ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1,
1242                           ctx) {}
1243   LogicalResult
1244   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1245                   ConversionPatternRewriter &rewriter) const final {
1246     // Helper function that replaces the given op with a new op of the given
1247     // name and doubles each result (1 -> 2 replacement of each result).
1248     auto replaceWithDoubleResults = [&](Operation *op, StringRef name) {
1249       SmallVector<Type> types;
1250       for (Type t : op->getResultTypes()) {
1251         types.push_back(t);
1252         types.push_back(t);
1253       }
1254       OperationState state(op->getLoc(), name,
1255                            /*operands=*/{}, types, op->getAttrs());
1256       auto *newOp = rewriter.create(state);
1257       SmallVector<ValueRange> repls;
1258       for (size_t i = 0, e = op->getNumResults(); i < e; ++i)
1259         repls.push_back(newOp->getResults().slice(2 * i, 2));
1260       rewriter.replaceOpWithMultiple(op, repls);
1261       return newOp;
1262     };
1263 
1264     // Replace test.multiple_1_to_n_replacement with test.step_1.
1265     Operation *repl1 = replaceWithDoubleResults(op, "test.step_1");
1266     // Now replace test.step_1 with test.legal_op.
1267     // TODO: Ideally, it should not be necessary to reset the insertion point
1268     // here. Based on the API calls, it looks like test.step_1 is entirely
1269     // erased. But that's not the case: an argument materialization will
1270     // survive. And that argument materialization will be used by the users of
1271     // `op`. If we don't reset the insertion point here, we get dominance
1272     // errors. This will be fixed when we have 1:N support in the conversion
1273     // value mapping.
1274     rewriter.setInsertionPoint(repl1);
1275     replaceWithDoubleResults(repl1, "test.legal_op");
1276     return success();
1277   }
1278 };
1279 
1280 } // namespace
1281 
1282 namespace {
1283 struct TestTypeConverter : public TypeConverter {
1284   using TypeConverter::TypeConverter;
1285   TestTypeConverter() {
1286     addConversion(convertType);
1287     addArgumentMaterialization(materializeCast);
1288     addSourceMaterialization(materializeCast);
1289   }
1290 
1291   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
1292     // Drop I16 types.
1293     if (t.isSignlessInteger(16))
1294       return success();
1295 
1296     // Convert I64 to F64.
1297     if (t.isSignlessInteger(64)) {
1298       results.push_back(FloatType::getF64(t.getContext()));
1299       return success();
1300     }
1301 
1302     // Convert I42 to I43.
1303     if (t.isInteger(42)) {
1304       results.push_back(IntegerType::get(t.getContext(), 43));
1305       return success();
1306     }
1307 
1308     // Split F32 into F16,F16.
1309     if (t.isF32()) {
1310       results.assign(2, FloatType::getF16(t.getContext()));
1311       return success();
1312     }
1313 
1314     // Drop I24 types.
1315     if (t.isInteger(24)) {
1316       return success();
1317     }
1318 
1319     // Otherwise, convert the type directly.
1320     results.push_back(t);
1321     return success();
1322   }
1323 
1324   /// Hook for materializing a conversion. This is necessary because we generate
1325   /// 1->N type mappings.
1326   static Value materializeCast(OpBuilder &builder, Type resultType,
1327                                ValueRange inputs, Location loc) {
1328     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1329   }
1330 };
1331 
1332 struct TestLegalizePatternDriver
1333     : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> {
1334   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)
1335 
1336   StringRef getArgument() const final { return "test-legalize-patterns"; }
1337   StringRef getDescription() const final {
1338     return "Run test dialect legalization patterns";
1339   }
1340   /// The mode of conversion to use with the driver.
1341   enum class ConversionMode { Analysis, Full, Partial };
1342 
1343   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
1344 
1345   void getDependentDialects(DialectRegistry &registry) const override {
1346     registry.insert<func::FuncDialect, test::TestDialect>();
1347   }
1348 
1349   void runOnOperation() override {
1350     TestTypeConverter converter;
1351     mlir::RewritePatternSet patterns(&getContext());
1352     populateWithGenerated(patterns);
1353     patterns
1354         .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1355              TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1356              TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
1357              TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
1358              TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
1359              TestNonRootReplacement, TestBoundedRecursiveRewrite,
1360              TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
1361              TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1362              TestUndoPropertiesModification, TestEraseOp,
1363              TestRepetitive1ToNConsumer>(&getContext());
1364     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
1365                  TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
1366         &getContext(), converter);
1367     patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
1368     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1369                                                               converter);
1370     mlir::populateCallOpTypeConversionPattern(patterns, converter);
1371 
1372     // Define the conversion target used for the test.
1373     ConversionTarget target(getContext());
1374     target.addLegalOp<ModuleOp>();
1375     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
1376                       TerminatorOp, OneRegionOp>();
1377     target.addLegalOp(OperationName("test.legal_op", &getContext()));
1378     target
1379         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
1380     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
1381       // Don't allow F32 operands.
1382       return llvm::none_of(op.getOperandTypes(),
1383                            [](Type type) { return type.isF32(); });
1384     });
1385     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1386       return converter.isSignatureLegal(op.getFunctionType()) &&
1387              converter.isLegal(&op.getBody());
1388     });
1389     target.addDynamicallyLegalOp<func::CallOp>(
1390         [&](func::CallOp op) { return converter.isLegal(op); });
1391 
1392     // TestCreateUnregisteredOp creates `arith.constant` operation,
1393     // which was not added to target intentionally to test
1394     // correct error code from conversion driver.
1395     target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });
1396 
1397     // Expect the type_producer/type_consumer operations to only operate on f64.
1398     target.addDynamicallyLegalOp<TestTypeProducerOp>(
1399         [](TestTypeProducerOp op) { return op.getType().isF64(); });
1400     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1401       return op.getOperand().getType().isF64();
1402     });
1403 
1404     // Check support for marking certain operations as recursively legal.
1405     target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) {
1406       return static_cast<bool>(
1407           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
1408     });
1409 
1410     // Mark the bound recursion operation as dynamically legal.
1411     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
1412         [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
1413 
1414     // Create a dynamically legal rule that can only be legalized by folding it.
1415     target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
1416         [](TestOpInPlaceSelfFold op) { return op.getFolded(); });
1417 
1418     target.addDynamicallyLegalOp<DuplicateBlockArgsOp>(
1419         [](DuplicateBlockArgsOp op) { return op.getIsLegal(); });
1420 
1421     // Handle a partial conversion.
1422     if (mode == ConversionMode::Partial) {
1423       DenseSet<Operation *> unlegalizedOps;
1424       ConversionConfig config;
1425       DumpNotifications dumpNotifications;
1426       config.listener = &dumpNotifications;
1427       config.unlegalizedOps = &unlegalizedOps;
1428       if (failed(applyPartialConversion(getOperation(), target,
1429                                         std::move(patterns), config))) {
1430         getOperation()->emitRemark() << "applyPartialConversion failed";
1431       }
1432       // Emit remarks for each legalizable operation.
1433       for (auto *op : unlegalizedOps)
1434         op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1435       return;
1436     }
1437 
1438     // Handle a full conversion.
1439     if (mode == ConversionMode::Full) {
1440       // Check support for marking unknown operations as dynamically legal.
1441       target.markUnknownOpDynamicallyLegal([](Operation *op) {
1442         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
1443       });
1444 
1445       ConversionConfig config;
1446       DumpNotifications dumpNotifications;
1447       config.listener = &dumpNotifications;
1448       if (failed(applyFullConversion(getOperation(), target,
1449                                      std::move(patterns), config))) {
1450         getOperation()->emitRemark() << "applyFullConversion failed";
1451       }
1452       return;
1453     }
1454 
1455     // Otherwise, handle an analysis conversion.
1456     assert(mode == ConversionMode::Analysis);
1457 
1458     // Analyze the convertible operations.
1459     DenseSet<Operation *> legalizedOps;
1460     ConversionConfig config;
1461     config.legalizableOps = &legalizedOps;
1462     if (failed(applyAnalysisConversion(getOperation(), target,
1463                                        std::move(patterns), config)))
1464       return signalPassFailure();
1465 
1466     // Emit remarks for each legalizable operation.
1467     for (auto *op : legalizedOps)
1468       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
1469   }
1470 
1471   /// The mode of conversion to use.
1472   ConversionMode mode;
1473 };
1474 } // namespace
1475 
1476 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
1477     legalizerConversionMode(
1478         "test-legalize-mode",
1479         llvm::cl::desc("The legalization mode to use with the test driver"),
1480         llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
1481         llvm::cl::values(
1482             clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
1483                        "analysis", "Perform an analysis conversion"),
1484             clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
1485                        "Perform a full conversion"),
1486             clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
1487                        "partial", "Perform a partial conversion")));
1488 
1489 //===----------------------------------------------------------------------===//
1490 // ConversionPatternRewriter::getRemappedValue testing. This method is used
1491 // to get the remapped value of an original value that was replaced using
1492 // ConversionPatternRewriter.
1493 namespace {
1494 struct TestRemapValueTypeConverter : public TypeConverter {
1495   using TypeConverter::TypeConverter;
1496 
1497   TestRemapValueTypeConverter() {
1498     addConversion(
1499         [](Float32Type type) { return Float64Type::get(type.getContext()); });
1500     addConversion([](Type type) { return type; });
1501   }
1502 };
1503 
1504 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
1505 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
1506 /// operand twice.
1507 ///
1508 /// Example:
1509 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
1510 /// is replaced with:
1511 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
1512 struct OneVResOneVOperandOp1Converter
1513     : public OpConversionPattern<OneVResOneVOperandOp1> {
1514   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
1515 
1516   LogicalResult
1517   matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
1518                   ConversionPatternRewriter &rewriter) const override {
1519     auto origOps = op.getOperands();
1520     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
1521            "One operand expected");
1522     Value origOp = *origOps.begin();
1523     SmallVector<Value, 2> remappedOperands;
1524     // Replicate the remapped original operand twice. Note that we don't used
1525     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
1526     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1527     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1528 
1529     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
1530                                                        remappedOperands);
1531     return success();
1532   }
1533 };
1534 
/// A conversion pattern that inlines the front block of a
/// TestRemappedValueRegionOp's body into the block containing the op. It then
/// replaces the op's results with the remapped operands of that body block's
/// terminator. This exercises `ConversionPatternRewriter::getRemappedValues`
/// on values defined inside a region that has been inlined.
struct TestRemapValueInRegion
    : public OpConversionPattern<TestRemappedValueRegionOp> {
  using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Block &block = op.getBody().front();
    Operation *terminator = block.getTerminator();

    // Merge the block into the parent region.
    // Splitting at `op` moves `op` and everything after it into `finalBlock`.
    // The body block is then appended to the truncated parent block, and
    // `finalBlock` is re-appended. Net effect: the body's ops (including its
    // terminator) sit immediately before `op`.
    Block *parentBlock = op->getBlock();
    Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
    rewriter.mergeBlocks(&block, parentBlock, ValueRange());
    rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());

    // Replace the results of this operation with the remapped terminator
    // values.
    SmallVector<Value> terminatorOperands;
    if (failed(rewriter.getRemappedValues(terminator->getOperands(),
                                          terminatorOperands)))
      return failure();

    // The moved terminator is no longer needed once its operands are captured.
    rewriter.eraseOp(terminator);
    rewriter.replaceOp(op, terminatorOperands);
    return success();
  }
};
1564 
1565 struct TestRemappedValue
1566     : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> {
1567   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue)
1568 
1569   StringRef getArgument() const final { return "test-remapped-value"; }
1570   StringRef getDescription() const final {
1571     return "Test public remapped value mechanism in ConversionPatternRewriter";
1572   }
1573   void runOnOperation() override {
1574     TestRemapValueTypeConverter typeConverter;
1575 
1576     mlir::RewritePatternSet patterns(&getContext());
1577     patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
1578     patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
1579         &getContext());
1580     patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
1581 
1582     mlir::ConversionTarget target(getContext());
1583     target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>();
1584 
1585     // Expect the type_producer/type_consumer operations to only operate on f64.
1586     target.addDynamicallyLegalOp<TestTypeProducerOp>(
1587         [](TestTypeProducerOp op) { return op.getType().isF64(); });
1588     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1589       return op.getOperand().getType().isF64();
1590     });
1591 
1592     // We make OneVResOneVOperandOp1 legal only when it has more that one
1593     // operand. This will trigger the conversion that will replace one-operand
1594     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
1595     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
1596         [](Operation *op) { return op->getNumOperands() > 1; });
1597 
1598     if (failed(mlir::applyFullConversion(getOperation(), target,
1599                                          std::move(patterns)))) {
1600       signalPassFailure();
1601     }
1602   }
1603 };
1604 } // namespace
1605 
1606 //===----------------------------------------------------------------------===//
1607 // Test patterns without a specific root operation kind
1608 //===----------------------------------------------------------------------===//
1609 
1610 namespace {
1611 /// This pattern matches and removes any operation in the test dialect.
1612 struct RemoveTestDialectOps : public RewritePattern {
1613   RemoveTestDialectOps(MLIRContext *context)
1614       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
1615 
1616   LogicalResult matchAndRewrite(Operation *op,
1617                                 PatternRewriter &rewriter) const override {
1618     if (!isa<TestDialect>(op->getDialect()))
1619       return failure();
1620     rewriter.eraseOp(op);
1621     return success();
1622   }
1623 };
1624 
1625 struct TestUnknownRootOpDriver
1626     : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> {
1627   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver)
1628 
1629   StringRef getArgument() const final {
1630     return "test-legalize-unknown-root-patterns";
1631   }
1632   StringRef getDescription() const final {
1633     return "Test public remapped value mechanism in ConversionPatternRewriter";
1634   }
1635   void runOnOperation() override {
1636     mlir::RewritePatternSet patterns(&getContext());
1637     patterns.add<RemoveTestDialectOps>(&getContext());
1638 
1639     mlir::ConversionTarget target(getContext());
1640     target.addIllegalDialect<TestDialect>();
1641     if (failed(applyPartialConversion(getOperation(), target,
1642                                       std::move(patterns))))
1643       signalPassFailure();
1644   }
1645 };
1646 } // namespace
1647 
1648 //===----------------------------------------------------------------------===//
// Test patterns that use operations and types defined at runtime
1650 //===----------------------------------------------------------------------===//
1651 
1652 namespace {
1653 /// This pattern matches dynamic operations 'test.one_operand_two_results' and
1654 /// replace them with dynamic operations 'test.generic_dynamic_op'.
1655 struct RewriteDynamicOp : public RewritePattern {
1656   RewriteDynamicOp(MLIRContext *context)
1657       : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
1658                        context) {}
1659 
1660   LogicalResult matchAndRewrite(Operation *op,
1661                                 PatternRewriter &rewriter) const override {
1662     assert(op->getName().getStringRef() ==
1663                "test.dynamic_one_operand_two_results" &&
1664            "rewrite pattern should only match operations with the right name");
1665 
1666     OperationState state(op->getLoc(), "test.dynamic_generic",
1667                          op->getOperands(), op->getResultTypes(),
1668                          op->getAttrs());
1669     auto *newOp = rewriter.create(state);
1670     rewriter.replaceOp(op, newOp->getResults());
1671     return success();
1672   }
1673 };
1674 
1675 struct TestRewriteDynamicOpDriver
1676     : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> {
1677   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
1678 
1679   void getDependentDialects(DialectRegistry &registry) const override {
1680     registry.insert<TestDialect>();
1681   }
1682   StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
1683   StringRef getDescription() const final {
1684     return "Test rewritting on dynamic operations";
1685   }
1686   void runOnOperation() override {
1687     RewritePatternSet patterns(&getContext());
1688     patterns.add<RewriteDynamicOp>(&getContext());
1689 
1690     ConversionTarget target(getContext());
1691     target.addIllegalOp(
1692         OperationName("test.dynamic_one_operand_two_results", &getContext()));
1693     target.addLegalOp(OperationName("test.dynamic_generic", &getContext()));
1694     if (failed(applyPartialConversion(getOperation(), target,
1695                                       std::move(patterns))))
1696       signalPassFailure();
1697   }
1698 };
1699 } // end anonymous namespace
1700 
1701 //===----------------------------------------------------------------------===//
1702 // Test type conversions
1703 //===----------------------------------------------------------------------===//
1704 
1705 namespace {
1706 struct TestTypeConversionProducer
1707     : public OpConversionPattern<TestTypeProducerOp> {
1708   using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
1709   LogicalResult
1710   matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
1711                   ConversionPatternRewriter &rewriter) const final {
1712     Type resultType = op.getType();
1713     Type convertedType = getTypeConverter()
1714                              ? getTypeConverter()->convertType(resultType)
1715                              : resultType;
1716     if (isa<FloatType>(resultType))
1717       resultType = rewriter.getF64Type();
1718     else if (resultType.isInteger(16))
1719       resultType = rewriter.getIntegerType(64);
1720     else if (isa<test::TestRecursiveType>(resultType) &&
1721              convertedType != resultType)
1722       resultType = convertedType;
1723     else
1724       return failure();
1725 
1726     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
1727     return success();
1728   }
1729 };
1730 
1731 /// Call signature conversion and then fail the rewrite to trigger the undo
1732 /// mechanism.
1733 struct TestSignatureConversionUndo
1734     : public OpConversionPattern<TestSignatureConversionUndoOp> {
1735   using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
1736 
1737   LogicalResult
1738   matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
1739                   ConversionPatternRewriter &rewriter) const final {
1740     (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
1741     return failure();
1742   }
1743 };
1744 
1745 /// Call signature conversion without providing a type converter to handle
1746 /// materializations.
1747 struct TestTestSignatureConversionNoConverter
1748     : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
1749   TestTestSignatureConversionNoConverter(const TypeConverter &converter,
1750                                          MLIRContext *context)
1751       : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
1752         converter(converter) {}
1753 
1754   LogicalResult
1755   matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
1756                   ConversionPatternRewriter &rewriter) const final {
1757     Region &region = op->getRegion(0);
1758     Block *entry = &region.front();
1759 
1760     // Convert the original entry arguments.
1761     TypeConverter::SignatureConversion result(entry->getNumArguments());
1762     if (failed(
1763             converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
1764       return failure();
1765     rewriter.modifyOpInPlace(op, [&] {
1766       rewriter.applySignatureConversion(&region.front(), result);
1767     });
1768     return success();
1769   }
1770 
1771   const TypeConverter &converter;
1772 };
1773 
1774 /// Just forward the operands to the root op. This is essentially a no-op
1775 /// pattern that is used to trigger target materialization.
1776 struct TestTypeConsumerForward
1777     : public OpConversionPattern<TestTypeConsumerOp> {
1778   using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
1779 
1780   LogicalResult
1781   matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
1782                   ConversionPatternRewriter &rewriter) const final {
1783     rewriter.modifyOpInPlace(op,
1784                              [&] { op->setOperands(adaptor.getOperands()); });
1785     return success();
1786   }
1787 };
1788 
1789 struct TestTypeConversionAnotherProducer
1790     : public OpRewritePattern<TestAnotherTypeProducerOp> {
1791   using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
1792 
1793   LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
1794                                 PatternRewriter &rewriter) const final {
1795     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
1796     return success();
1797   }
1798 };
1799 
1800 struct TestReplaceWithLegalOp : public ConversionPattern {
1801   TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
1802       : ConversionPattern(converter, "test.replace_with_legal_op",
1803                           /*benefit=*/1, ctx) {}
1804   LogicalResult
1805   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1806                   ConversionPatternRewriter &rewriter) const final {
1807     rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
1808     return success();
1809   }
1810 };
1811 
1812 struct TestTypeConversionDriver
1813     : public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
1814   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
1815 
1816   void getDependentDialects(DialectRegistry &registry) const override {
1817     registry.insert<TestDialect>();
1818   }
1819   StringRef getArgument() const final {
1820     return "test-legalize-type-conversion";
1821   }
1822   StringRef getDescription() const final {
1823     return "Test various type conversion functionalities in DialectConversion";
1824   }
1825 
1826   void runOnOperation() override {
1827     // Initialize the type converter.
1828     SmallVector<Type, 2> conversionCallStack;
1829     TypeConverter converter;
1830 
1831     /// Add the legal set of type conversions.
1832     converter.addConversion([](Type type) -> Type {
1833       // Treat F64 as legal.
1834       if (type.isF64())
1835         return type;
1836       // Allow converting BF16/F16/F32 to F64.
1837       if (type.isBF16() || type.isF16() || type.isF32())
1838         return FloatType::getF64(type.getContext());
1839       // Otherwise, the type is illegal.
1840       return nullptr;
1841     });
1842     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
1843       // Drop all integer types.
1844       return success();
1845     });
1846     converter.addConversion(
1847         // Convert a recursive self-referring type into a non-self-referring
1848         // type named "outer_converted_type" that contains a SimpleAType.
1849         [&](test::TestRecursiveType type,
1850             SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
1851           // If the type is already converted, return it to indicate that it is
1852           // legal.
1853           if (type.getName() == "outer_converted_type") {
1854             results.push_back(type);
1855             return success();
1856           }
1857 
1858           conversionCallStack.push_back(type);
1859           auto popConversionCallStack = llvm::make_scope_exit(
1860               [&conversionCallStack]() { conversionCallStack.pop_back(); });
1861 
1862           // If the type is on the call stack more than once (it is there at
1863           // least once because of the _current_ call, which is always the last
1864           // element on the stack), we've hit the recursive case. Just return
1865           // SimpleAType here to create a non-recursive type as a result.
1866           if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(),
1867                                  type)) {
1868             results.push_back(test::SimpleAType::get(type.getContext()));
1869             return success();
1870           }
1871 
1872           // Convert the body recursively.
1873           auto result = test::TestRecursiveType::get(type.getContext(),
1874                                                      "outer_converted_type");
1875           if (failed(result.setBody(converter.convertType(type.getBody()))))
1876             return failure();
1877           results.push_back(result);
1878           return success();
1879         });
1880 
1881     /// Add the legal set of type materializations.
1882     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
1883                                           ValueRange inputs,
1884                                           Location loc) -> Value {
1885       // Allow casting from F64 back to F32.
1886       if (!resultType.isF16() && inputs.size() == 1 &&
1887           inputs[0].getType().isF64())
1888         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1889       // Allow producing an i32 or i64 from nothing.
1890       if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
1891           inputs.empty())
1892         return builder.create<TestTypeProducerOp>(loc, resultType);
1893       // Allow producing an i64 from an integer.
1894       if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
1895           isa<IntegerType>(inputs[0].getType()))
1896         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1897       // Otherwise, fail.
1898       return nullptr;
1899     });
1900 
1901     // Initialize the conversion target.
1902     mlir::ConversionTarget target(getContext());
1903     target.addLegalOp<LegalOpD>();
1904     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
1905       auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
1906       return op.getType().isF64() || op.getType().isInteger(64) ||
1907              (recursiveType &&
1908               recursiveType.getName() == "outer_converted_type");
1909     });
1910     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1911       return converter.isSignatureLegal(op.getFunctionType()) &&
1912              converter.isLegal(&op.getBody());
1913     });
1914     target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
1915       // Allow casts from F64 to F32.
1916       return (*op.operand_type_begin()).isF64() && op.getType().isF32();
1917     });
1918     target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
1919         [&](TestSignatureConversionNoConverterOp op) {
1920           return converter.isLegal(op.getRegion().front().getArgumentTypes());
1921         });
1922 
1923     // Initialize the set of rewrite patterns.
1924     RewritePatternSet patterns(&getContext());
1925     patterns
1926         .add<TestTypeConsumerForward, TestTypeConversionProducer,
1927              TestSignatureConversionUndo,
1928              TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
1929             converter, &getContext());
1930     patterns.add<TestTypeConversionAnotherProducer>(&getContext());
1931     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1932                                                               converter);
1933 
1934     if (failed(applyPartialConversion(getOperation(), target,
1935                                       std::move(patterns))))
1936       signalPassFailure();
1937   }
1938 };
1939 } // namespace
1940 
1941 //===----------------------------------------------------------------------===//
1942 // Test Target Materialization With No Uses
1943 //===----------------------------------------------------------------------===//
1944 
1945 namespace {
1946 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1947   using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1948 
1949   LogicalResult
1950   matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1951                   ConversionPatternRewriter &rewriter) const final {
1952     rewriter.replaceOp(op, adaptor.getOperands());
1953     return success();
1954   }
1955 };
1956 
1957 struct TestTargetMaterializationWithNoUses
1958     : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> {
1959   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1960       TestTargetMaterializationWithNoUses)
1961 
1962   StringRef getArgument() const final {
1963     return "test-target-materialization-with-no-uses";
1964   }
1965   StringRef getDescription() const final {
1966     return "Test a special case of target materialization in DialectConversion";
1967   }
1968 
1969   void runOnOperation() override {
1970     TypeConverter converter;
1971     converter.addConversion([](Type t) { return t; });
1972     converter.addConversion([](IntegerType intTy) -> Type {
1973       if (intTy.getWidth() == 16)
1974         return IntegerType::get(intTy.getContext(), 64);
1975       return intTy;
1976     });
1977     converter.addTargetMaterialization(
1978         [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1979           return builder.create<TestCastOp>(loc, type, inputs).getResult();
1980         });
1981 
1982     ConversionTarget target(getContext());
1983     target.addIllegalOp<TestTypeChangerOp>();
1984 
1985     RewritePatternSet patterns(&getContext());
1986     patterns.add<ForwardOperandPattern>(converter, &getContext());
1987 
1988     if (failed(applyPartialConversion(getOperation(), target,
1989                                       std::move(patterns))))
1990       signalPassFailure();
1991   }
1992 };
1993 } // namespace
1994 
1995 //===----------------------------------------------------------------------===//
1996 // Test Block Merging
1997 //===----------------------------------------------------------------------===//
1998 
1999 namespace {
2000 /// A rewriter pattern that tests that blocks can be merged.
2001 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
2002   using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
2003 
2004   LogicalResult
2005   matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
2006                   ConversionPatternRewriter &rewriter) const final {
2007     Block &firstBlock = op.getBody().front();
2008     Operation *branchOp = firstBlock.getTerminator();
2009     Block *secondBlock = &*(std::next(op.getBody().begin()));
2010     auto succOperands = branchOp->getOperands();
2011     SmallVector<Value, 2> replacements(succOperands);
2012     rewriter.eraseOp(branchOp);
2013     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
2014     rewriter.modifyOpInPlace(op, [] {});
2015     return success();
2016   }
2017 };
2018 
2019 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
2020 struct TestUndoBlocksMerge : public ConversionPattern {
2021   TestUndoBlocksMerge(MLIRContext *ctx)
2022       : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
2023   LogicalResult
2024   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
2025                   ConversionPatternRewriter &rewriter) const final {
2026     Block &firstBlock = op->getRegion(0).front();
2027     Operation *branchOp = firstBlock.getTerminator();
2028     Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
2029     rewriter.setInsertionPointToStart(secondBlock);
2030     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
2031     auto succOperands = branchOp->getOperands();
2032     SmallVector<Value, 2> replacements(succOperands);
2033     rewriter.eraseOp(branchOp);
2034     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
2035     rewriter.modifyOpInPlace(op, [] {});
2036     return success();
2037   }
2038 };
2039 
2040 /// A rewrite mechanism to inline the body of the op into its parent, when both
2041 /// ops can have a single block.
2042 struct TestMergeSingleBlockOps
2043     : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
2044   using OpConversionPattern<
2045       SingleBlockImplicitTerminatorOp>::OpConversionPattern;
2046 
2047   LogicalResult
2048   matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
2049                   ConversionPatternRewriter &rewriter) const final {
2050     SingleBlockImplicitTerminatorOp parentOp =
2051         op->getParentOfType<SingleBlockImplicitTerminatorOp>();
2052     if (!parentOp)
2053       return failure();
2054     Block &innerBlock = op.getRegion().front();
2055     TerminatorOp innerTerminator =
2056         cast<TerminatorOp>(innerBlock.getTerminator());
2057     rewriter.inlineBlockBefore(&innerBlock, op);
2058     rewriter.eraseOp(innerTerminator);
2059     rewriter.eraseOp(op);
2060     return success();
2061   }
2062 };
2063 
2064 struct TestMergeBlocksPatternDriver
2065     : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> {
2066   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver)
2067 
2068   StringRef getArgument() const final { return "test-merge-blocks"; }
2069   StringRef getDescription() const final {
2070     return "Test Merging operation in ConversionPatternRewriter";
2071   }
2072   void runOnOperation() override {
2073     MLIRContext *context = &getContext();
2074     mlir::RewritePatternSet patterns(context);
2075     patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
2076         context);
2077     ConversionTarget target(*context);
2078     target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
2079                       TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
2080     target.addIllegalOp<ILLegalOpF>();
2081 
2082     /// Expect the op to have a single block after legalization.
2083     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
2084         [&](TestMergeBlocksOp op) -> bool {
2085           return llvm::hasSingleElement(op.getBody());
2086         });
2087 
2088     /// Only allow `test.br` within test.merge_blocks op.
2089     target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
2090       return op->getParentOfType<TestMergeBlocksOp>();
2091     });
2092 
2093     /// Expect that all nested test.SingleBlockImplicitTerminator ops are
2094     /// inlined.
2095     target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
2096         [&](SingleBlockImplicitTerminatorOp op) -> bool {
2097           return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
2098         });
2099 
2100     DenseSet<Operation *> unlegalizedOps;
2101     ConversionConfig config;
2102     config.unlegalizedOps = &unlegalizedOps;
2103     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
2104                                  config);
2105     for (auto *op : unlegalizedOps)
2106       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
2107   }
2108 };
2109 } // namespace
2110 
2111 //===----------------------------------------------------------------------===//
2112 // Test Selective Replacement
2113 //===----------------------------------------------------------------------===//
2114 
2115 namespace {
2116 /// A rewrite mechanism to inline the body of the op into its parent, when both
2117 /// ops can have a single block.
2118 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
2119   using OpRewritePattern<TestCastOp>::OpRewritePattern;
2120 
2121   LogicalResult matchAndRewrite(TestCastOp op,
2122                                 PatternRewriter &rewriter) const final {
2123     if (op.getNumOperands() != 2)
2124       return failure();
2125     OperandRange operands = op.getOperands();
2126 
2127     // Replace non-terminator uses with the first operand.
2128     rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
2129       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
2130     });
2131     // Replace everything else with the second operand if the operation isn't
2132     // dead.
2133     rewriter.replaceOp(op, op.getOperand(1));
2134     return success();
2135   }
2136 };
2137 
2138 struct TestSelectiveReplacementPatternDriver
2139     : public PassWrapper<TestSelectiveReplacementPatternDriver,
2140                          OperationPass<>> {
2141   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
2142       TestSelectiveReplacementPatternDriver)
2143 
2144   StringRef getArgument() const final {
2145     return "test-pattern-selective-replacement";
2146   }
2147   StringRef getDescription() const final {
2148     return "Test selective replacement in the PatternRewriter";
2149   }
2150   void runOnOperation() override {
2151     MLIRContext *context = &getContext();
2152     mlir::RewritePatternSet patterns(context);
2153     patterns.add<TestSelectiveOpReplacementPattern>(context);
2154     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
2155   }
2156 };
2157 } // namespace
2158 
2159 //===----------------------------------------------------------------------===//
2160 // PassRegistration
2161 //===----------------------------------------------------------------------===//
2162 
2163 namespace mlir {
2164 namespace test {
/// Registers the pattern and dialect-conversion test passes listed below with
/// the global pass registry via PassRegistration.
void registerPatternsTestPass() {
  PassRegistration<TestReturnTypeDriver>();

  PassRegistration<TestDerivedAttributeDriver>();

  // Pattern drivers (greedy, strict and walk).
  PassRegistration<TestGreedyPatternDriver>();
  PassRegistration<TestStrictPatternDriver>();
  PassRegistration<TestWalkPatternDriver>();

  // Registered with a factory so each instance is constructed with the mode
  // selected by the `-test-legalize-mode` option.
  PassRegistration<TestLegalizePatternDriver>([] {
    return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
  });

  PassRegistration<TestRemappedValue>();

  PassRegistration<TestUnknownRootOpDriver>();

  // Type conversion / materialization tests.
  PassRegistration<TestTypeConversionDriver>();
  PassRegistration<TestTargetMaterializationWithNoUses>();

  PassRegistration<TestRewriteDynamicOpDriver>();

  PassRegistration<TestMergeBlocksPatternDriver>();
  PassRegistration<TestSelectiveReplacementPatternDriver>();
}
2190 } // namespace test
2191 } // namespace mlir
2192