xref: /llvm-project/mlir/test/lib/Dialect/Test/TestPatterns.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "TestTypes.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BuiltinAttributes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Visitors.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include <cstdint>
26 
27 using namespace mlir;
28 using namespace test;
29 
30 // Native function for testing NativeCodeCall
31 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
32   return choice.getValue() ? input1 : input2;
33 }
34 
35 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
36   rewriter.create<OpI>(loc, input);
37 }
38 
39 static void handleNoResultOp(PatternRewriter &rewriter,
40                              OpSymbolBindingNoResult op) {
41   // Turn the no result op to a one-result op.
42   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
43                                     op.getOperand());
44 }
45 
46 static bool getFirstI32Result(Operation *op, Value &value) {
47   if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
48     return false;
49   value = op->getResult(0);
50   return true;
51 }
52 
53 static Value bindNativeCodeCallResult(Value value) { return value; }
54 
55 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
56                                                               Value input2) {
57   return SmallVector<Value, 2>({input2, input1});
58 }
59 
60 // Test that natives calls are only called once during rewrites.
61 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
62 // This let us check the number of times OpM_Test was called by inspecting
63 // the returned value in the MLIR output.
64 static int64_t opMIncreasingValue = 314159265;
65 static Attribute opMTest(PatternRewriter &rewriter, Value val) {
66   int64_t i = opMIncreasingValue++;
67   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
68 }
69 
70 namespace {
71 #include "TestPatterns.inc"
72 } // namespace
73 
74 //===----------------------------------------------------------------------===//
75 // Test Reduce Pattern Interface
76 //===----------------------------------------------------------------------===//
77 
78 void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
79   populateWithGenerated(patterns);
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // Canonicalizer Driver.
84 //===----------------------------------------------------------------------===//
85 
86 namespace {
87 struct FoldingPattern : public RewritePattern {
88 public:
89   FoldingPattern(MLIRContext *context)
90       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
91                        /*benefit=*/1, context) {}
92 
93   LogicalResult matchAndRewrite(Operation *op,
94                                 PatternRewriter &rewriter) const override {
95     // Exercise createOrFold API for a single-result operation that is folded
96     // upon construction. The operation being created has an in-place folder,
97     // and it should be still present in the output. Furthermore, the folder
98     // should not crash when attempting to recover the (unchanged) operation
99     // result.
100     Value result = rewriter.createOrFold<TestOpInPlaceFold>(
101         op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0));
102     assert(result);
103     rewriter.replaceOp(op, result);
104     return success();
105   }
106 };
107 
108 /// This pattern creates a foldable operation at the entry point of the block.
109 /// This tests the situation where the operation folder will need to replace an
110 /// operation with a previously created constant that does not initially
111 /// dominate the operation to replace.
112 struct FolderInsertBeforePreviouslyFoldedConstantPattern
113     : public OpRewritePattern<TestCastOp> {
114 public:
115   using OpRewritePattern<TestCastOp>::OpRewritePattern;
116 
117   LogicalResult matchAndRewrite(TestCastOp op,
118                                 PatternRewriter &rewriter) const override {
119     if (!op->hasAttr("test_fold_before_previously_folded_op"))
120       return failure();
121     rewriter.setInsertionPointToStart(op->getBlock());
122 
123     auto constOp = rewriter.create<arith::ConstantOp>(
124         op.getLoc(), rewriter.getBoolAttr(true));
125     rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
126                                             Value(constOp));
127     return success();
128   }
129 };
130 
131 /// This pattern matches test.op_commutative2 with the first operand being
132 /// another test.op_commutative2 with a constant on the right side and fold it
133 /// away by propagating it as its result. This is intend to check that patterns
134 /// are applied after the commutative property moves constant to the right.
135 struct FolderCommutativeOp2WithConstant
136     : public OpRewritePattern<TestCommutative2Op> {
137 public:
138   using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
139 
140   LogicalResult matchAndRewrite(TestCommutative2Op op,
141                                 PatternRewriter &rewriter) const override {
142     auto operand =
143         dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
144     if (!operand)
145       return failure();
146     Attribute constInput;
147     if (!matchPattern(operand->getOperand(1), m_Constant(&constInput)))
148       return failure();
149     rewriter.replaceOp(op, operand->getOperand(1));
150     return success();
151   }
152 };
153 
154 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
155 /// attribute with value smaller than MaxVal, it increments the value by 1.
156 template <int MaxVal>
157 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
158   using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern;
159 
160   LogicalResult matchAndRewrite(AnyAttrOfOp op,
161                                 PatternRewriter &rewriter) const override {
162     auto intAttr = dyn_cast<IntegerAttr>(op.getAttr());
163     if (!intAttr)
164       return failure();
165     int64_t val = intAttr.getInt();
166     if (val >= MaxVal)
167       return failure();
168     rewriter.modifyOpInPlace(
169         op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
170     return success();
171   }
172 };
173 
174 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op".
175 struct MakeOpEligible : public RewritePattern {
176   MakeOpEligible(MLIRContext *context)
177       : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {}
178 
179   LogicalResult matchAndRewrite(Operation *op,
180                                 PatternRewriter &rewriter) const override {
181     if (op->hasAttr("eligible"))
182       return failure();
183     rewriter.modifyOpInPlace(
184         op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); });
185     return success();
186   }
187 };
188 
189 /// This pattern hoists eligible ops out of a "test.one_region_op".
190 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
191   using OpRewritePattern<test::OneRegionOp>::OpRewritePattern;
192 
193   LogicalResult matchAndRewrite(test::OneRegionOp op,
194                                 PatternRewriter &rewriter) const override {
195     Operation *terminator = op.getRegion().front().getTerminator();
196     Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp();
197     if (toBeHoisted->getParentOp() != op)
198       return failure();
199     if (!toBeHoisted->hasAttr("eligible"))
200       return failure();
201     rewriter.moveOpBefore(toBeHoisted, op);
202     return success();
203   }
204 };
205 
206 /// This pattern moves "test.move_before_parent_op" before the parent op.
207 struct MoveBeforeParentOp : public RewritePattern {
208   MoveBeforeParentOp(MLIRContext *context)
209       : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
210 
211   LogicalResult matchAndRewrite(Operation *op,
212                                 PatternRewriter &rewriter) const override {
213     // Do not hoist past functions.
214     if (isa<FunctionOpInterface>(op->getParentOp()))
215       return failure();
216     rewriter.moveOpBefore(op, op->getParentOp());
217     return success();
218   }
219 };
220 
221 /// This pattern moves "test.move_after_parent_op" after the parent op.
222 struct MoveAfterParentOp : public RewritePattern {
223   MoveAfterParentOp(MLIRContext *context)
224       : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {}
225 
226   LogicalResult matchAndRewrite(Operation *op,
227                                 PatternRewriter &rewriter) const override {
228     // Do not hoist past functions.
229     if (isa<FunctionOpInterface>(op->getParentOp()))
230       return failure();
231 
232     int64_t moveForwardBy = 0;
233     if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance"))
234       moveForwardBy = advanceBy.getInt();
235 
236     Operation *moveAfter = op->getParentOp();
237     for (int64_t i = 0; i < moveForwardBy; ++i)
238       moveAfter = moveAfter->getNextNode();
239 
240     rewriter.moveOpAfter(op, moveAfter);
241     return success();
242   }
243 };
244 
245 /// This pattern inlines blocks that are nested in
246 /// "test.inline_blocks_into_parent" into the parent block.
247 struct InlineBlocksIntoParent : public RewritePattern {
248   InlineBlocksIntoParent(MLIRContext *context)
249       : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
250                        context) {}
251 
252   LogicalResult matchAndRewrite(Operation *op,
253                                 PatternRewriter &rewriter) const override {
254     bool changed = false;
255     for (Region &r : op->getRegions()) {
256       while (!r.empty()) {
257         rewriter.inlineBlockBefore(&r.front(), op);
258         changed = true;
259       }
260     }
261     return success(changed);
262   }
263 };
264 
265 /// This pattern splits blocks at "test.split_block_here" and replaces the op
266 /// with a new op (to prevent an infinite loop of block splitting).
267 struct SplitBlockHere : public RewritePattern {
268   SplitBlockHere(MLIRContext *context)
269       : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
270 
271   LogicalResult matchAndRewrite(Operation *op,
272                                 PatternRewriter &rewriter) const override {
273     rewriter.splitBlock(op->getBlock(), op->getIterator());
274     Operation *newOp = rewriter.create(
275         op->getLoc(),
276         OperationName("test.new_op", op->getContext()).getIdentifier(),
277         op->getOperands(), op->getResultTypes());
278     rewriter.replaceOp(op, newOp);
279     return success();
280   }
281 };
282 
283 /// This pattern clones "test.clone_me" ops.
284 struct CloneOp : public RewritePattern {
285   CloneOp(MLIRContext *context)
286       : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
287 
288   LogicalResult matchAndRewrite(Operation *op,
289                                 PatternRewriter &rewriter) const override {
290     // Do not clone already cloned ops to avoid going into an infinite loop.
291     if (op->hasAttr("was_cloned"))
292       return failure();
293     Operation *cloned = rewriter.clone(*op);
294     cloned->setAttr("was_cloned", rewriter.getUnitAttr());
295     return success();
296   }
297 };
298 
299 /// This pattern clones regions of "test.clone_region_before" ops before the
300 /// parent block.
301 struct CloneRegionBeforeOp : public RewritePattern {
302   CloneRegionBeforeOp(MLIRContext *context)
303       : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
304 
305   LogicalResult matchAndRewrite(Operation *op,
306                                 PatternRewriter &rewriter) const override {
307     // Do not clone already cloned ops to avoid going into an infinite loop.
308     if (op->hasAttr("was_cloned"))
309       return failure();
310     for (Region &r : op->getRegions())
311       rewriter.cloneRegionBefore(r, op->getBlock());
312     op->setAttr("was_cloned", rewriter.getUnitAttr());
313     return success();
314   }
315 };
316 
317 /// Replace an operation may introduce the re-visiting of its users.
318 class ReplaceWithNewOp : public RewritePattern {
319 public:
320   ReplaceWithNewOp(MLIRContext *context)
321       : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
322 
323   LogicalResult matchAndRewrite(Operation *op,
324                                 PatternRewriter &rewriter) const override {
325     Operation *newOp;
326     if (op->hasAttr("create_erase_op")) {
327       newOp = rewriter.create(
328           op->getLoc(),
329           OperationName("test.erase_op", op->getContext()).getIdentifier(),
330           ValueRange(), TypeRange());
331     } else {
332       newOp = rewriter.create(
333           op->getLoc(),
334           OperationName("test.new_op", op->getContext()).getIdentifier(),
335           op->getOperands(), op->getResultTypes());
336     }
337     // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
338     // A "notifyOperationReplaced" callback is triggered in either case.
339     rewriter.replaceAllOpUsesWith(op, newOp->getResults());
340     rewriter.eraseOp(op);
341     return success();
342   }
343 };
344 
345 /// Erases the first child block of the matched "test.erase_first_block"
346 /// operation.
347 class EraseFirstBlock : public RewritePattern {
348 public:
349   EraseFirstBlock(MLIRContext *context)
350       : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {}
351 
352   LogicalResult matchAndRewrite(Operation *op,
353                                 PatternRewriter &rewriter) const override {
354     for (Region &r : op->getRegions()) {
355       for (Block &b : r.getBlocks()) {
356         rewriter.eraseBlock(&b);
357         return success();
358       }
359     }
360 
361     return failure();
362   }
363 };
364 
365 struct TestGreedyPatternDriver
366     : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
367   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
368 
369   TestGreedyPatternDriver() = default;
370   TestGreedyPatternDriver(const TestGreedyPatternDriver &other)
371       : PassWrapper(other) {}
372 
373   StringRef getArgument() const final { return "test-greedy-patterns"; }
374   StringRef getDescription() const final { return "Run test dialect patterns"; }
375   void runOnOperation() override {
376     mlir::RewritePatternSet patterns(&getContext());
377     populateWithGenerated(patterns);
378 
379     // Verify named pattern is generated with expected name.
380     patterns.add<FoldingPattern, TestNamedPatternRule,
381                  FolderInsertBeforePreviouslyFoldedConstantPattern,
382                  FolderCommutativeOp2WithConstant, HoistEligibleOps,
383                  MakeOpEligible>(&getContext());
384 
385     // Additional patterns for testing the GreedyPatternRewriteDriver.
386     patterns.insert<IncrementIntAttribute<3>>(&getContext());
387 
388     GreedyRewriteConfig config;
389     config.useTopDownTraversal = this->useTopDownTraversal;
390     config.maxIterations = this->maxIterations;
391     config.fold = this->fold;
392     config.cseConstants = this->cseConstants;
393     (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
394   }
395 
396   Option<bool> useTopDownTraversal{
397       *this, "top-down",
398       llvm::cl::desc("Seed the worklist in general top-down order"),
399       llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
400   Option<int> maxIterations{
401       *this, "max-iterations",
402       llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
403       llvm::cl::init(GreedyRewriteConfig().maxIterations)};
404   Option<bool> fold{*this, "fold", llvm::cl::desc("Whether to fold"),
405                     llvm::cl::init(GreedyRewriteConfig().fold)};
406   Option<bool> cseConstants{*this, "cse-constants",
407                             llvm::cl::desc("Whether to CSE constants"),
408                             llvm::cl::init(GreedyRewriteConfig().cseConstants)};
409 };
410 
411 struct DumpNotifications : public RewriterBase::Listener {
412   void notifyBlockInserted(Block *block, Region *previous,
413                            Region::iterator previousIt) override {
414     llvm::outs() << "notifyBlockInserted";
415     if (block->getParentOp()) {
416       llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
417     } else {
418       llvm::outs() << " into unknown op: ";
419     }
420     if (previous == nullptr) {
421       llvm::outs() << "was unlinked\n";
422     } else {
423       llvm::outs() << "was linked\n";
424     }
425   }
426   void notifyOperationInserted(Operation *op,
427                                OpBuilder::InsertPoint previous) override {
428     llvm::outs() << "notifyOperationInserted: " << op->getName();
429     if (!previous.isSet()) {
430       llvm::outs() << ", was unlinked\n";
431     } else {
432       if (!previous.getPoint().getNodePtr()) {
433         llvm::outs() << ", was linked, exact position unknown\n";
434       } else if (previous.getPoint() == previous.getBlock()->end()) {
435         llvm::outs() << ", was last in block\n";
436       } else {
437         llvm::outs() << ", previous = " << previous.getPoint()->getName()
438                      << "\n";
439       }
440     }
441   }
442   void notifyBlockErased(Block *block) override {
443     llvm::outs() << "notifyBlockErased\n";
444   }
445   void notifyOperationErased(Operation *op) override {
446     llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
447   }
448   void notifyOperationModified(Operation *op) override {
449     llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
450   }
451   void notifyOperationReplaced(Operation *op, ValueRange values) override {
452     llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
453   }
454 };
455 
456 struct TestStrictPatternDriver
457     : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
458 public:
459   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
460 
461   TestStrictPatternDriver() = default;
462   TestStrictPatternDriver(const TestStrictPatternDriver &other)
463       : PassWrapper(other) {
464     strictMode = other.strictMode;
465   }
466 
467   StringRef getArgument() const final { return "test-strict-pattern-driver"; }
468   StringRef getDescription() const final {
469     return "Test strict mode of pattern driver";
470   }
471 
472   void runOnOperation() override {
473     MLIRContext *ctx = &getContext();
474     mlir::RewritePatternSet patterns(ctx);
475     patterns.add<
476         // clang-format off
477         ChangeBlockOp,
478         CloneOp,
479         CloneRegionBeforeOp,
480         EraseOp,
481         ImplicitChangeOp,
482         InlineBlocksIntoParent,
483         InsertSameOp,
484         MoveBeforeParentOp,
485         ReplaceWithNewOp,
486         SplitBlockHere
487         // clang-format on
488         >(ctx);
489     SmallVector<Operation *> ops;
490     getOperation()->walk([&](Operation *op) {
491       StringRef opName = op->getName().getStringRef();
492       if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
493           opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
494           opName == "test.move_before_parent_op" ||
495           opName == "test.inline_blocks_into_parent" ||
496           opName == "test.split_block_here" || opName == "test.clone_me" ||
497           opName == "test.clone_region_before") {
498         ops.push_back(op);
499       }
500     });
501 
502     DumpNotifications dumpNotifications;
503     GreedyRewriteConfig config;
504     config.listener = &dumpNotifications;
505     if (strictMode == "AnyOp") {
506       config.strictMode = GreedyRewriteStrictness::AnyOp;
507     } else if (strictMode == "ExistingAndNewOps") {
508       config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
509     } else if (strictMode == "ExistingOps") {
510       config.strictMode = GreedyRewriteStrictness::ExistingOps;
511     } else {
512       llvm_unreachable("invalid strictness option");
513     }
514 
515     // Check if these transformations introduce visiting of operations that
516     // are not in the `ops` set (The new created ops are valid). An invalid
517     // operation will trigger the assertion while processing.
518     bool changed = false;
519     bool allErased = false;
520     (void)applyOpPatternsGreedily(ArrayRef(ops), std::move(patterns), config,
521                                   &changed, &allErased);
522     Builder b(ctx);
523     getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
524     getOperation()->setAttr("pattern_driver_all_erased",
525                             b.getBoolAttr(allErased));
526   }
527 
528   Option<std::string> strictMode{
529       *this, "strictness",
530       llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"),
531       llvm::cl::init("AnyOp")};
532 
533 private:
534   // New inserted operation is valid for further transformation.
535   class InsertSameOp : public RewritePattern {
536   public:
537     InsertSameOp(MLIRContext *context)
538         : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
539 
540     LogicalResult matchAndRewrite(Operation *op,
541                                   PatternRewriter &rewriter) const override {
542       if (op->hasAttr("skip"))
543         return failure();
544 
545       Operation *newOp =
546           rewriter.create(op->getLoc(), op->getName().getIdentifier(),
547                           op->getOperands(), op->getResultTypes());
548       rewriter.modifyOpInPlace(
549           op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); });
550       newOp->setAttr("skip", rewriter.getBoolAttr(true));
551 
552       return success();
553     }
554   };
555 
556   // Remove an operation may introduce the re-visiting of its operands.
557   class EraseOp : public RewritePattern {
558   public:
559     EraseOp(MLIRContext *context)
560         : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
561     LogicalResult matchAndRewrite(Operation *op,
562                                   PatternRewriter &rewriter) const override {
563       rewriter.eraseOp(op);
564       return success();
565     }
566   };
567 
568   // The following two patterns test RewriterBase::replaceAllUsesWith.
569   //
570   // That function replaces all usages of a Block (or a Value) with another one
571   // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
572   // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
573   // worklist: when an op is modified, it is added to the worklist. The two
574   // patterns below make the tracking observable: ChangeBlockOp replaces all
575   // usages of a block and that pattern is applied because the corresponding ops
576   // are put on the initial worklist (see above). ImplicitChangeOp does an
577   // unrelated change but ops of the corresponding type are *not* on the initial
578   // worklist, so the effect of the second pattern is only visible if the
579   // tracking and subsequent adding to the worklist actually works.
580 
581   // Replace all usages of the first successor with the second successor.
582   class ChangeBlockOp : public RewritePattern {
583   public:
584     ChangeBlockOp(MLIRContext *context)
585         : RewritePattern("test.change_block_op", /*benefit=*/1, context) {}
586     LogicalResult matchAndRewrite(Operation *op,
587                                   PatternRewriter &rewriter) const override {
588       if (op->getNumSuccessors() < 2)
589         return failure();
590       Block *firstSuccessor = op->getSuccessor(0);
591       Block *secondSuccessor = op->getSuccessor(1);
592       if (firstSuccessor == secondSuccessor)
593         return failure();
594       // This is the function being tested:
595       rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor);
596       // Using the following line instead would make the test fail:
597       // firstSuccessor->replaceAllUsesWith(secondSuccessor);
598       return success();
599     }
600   };
601 
602   // Changes the successor to the parent block.
603   class ImplicitChangeOp : public RewritePattern {
604   public:
605     ImplicitChangeOp(MLIRContext *context)
606         : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {}
607     LogicalResult matchAndRewrite(Operation *op,
608                                   PatternRewriter &rewriter) const override {
609       if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock())
610         return failure();
611       rewriter.modifyOpInPlace(op,
612                                [&]() { op->setSuccessor(op->getBlock(), 0); });
613       return success();
614     }
615   };
616 };
617 
618 struct TestWalkPatternDriver final
619     : PassWrapper<TestWalkPatternDriver, OperationPass<>> {
620   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver)
621 
622   TestWalkPatternDriver() = default;
623   TestWalkPatternDriver(const TestWalkPatternDriver &other)
624       : PassWrapper(other) {}
625 
626   StringRef getArgument() const override {
627     return "test-walk-pattern-rewrite-driver";
628   }
629   StringRef getDescription() const override {
630     return "Run test walk pattern rewrite driver";
631   }
632   void runOnOperation() override {
633     mlir::RewritePatternSet patterns(&getContext());
634 
635     // Patterns for testing the WalkPatternRewriteDriver.
636     patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
637                  MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>(
638         &getContext());
639 
640     DumpNotifications dumpListener;
641     walkAndApplyPatterns(getOperation(), std::move(patterns),
642                          dumpNotifications ? &dumpListener : nullptr);
643   }
644 
645   Option<bool> dumpNotifications{
646       *this, "dump-notifications",
647       llvm::cl::desc("Print rewrite listener notifications"),
648       llvm::cl::init(false)};
649 };
650 
651 } // namespace
652 
653 //===----------------------------------------------------------------------===//
654 // ReturnType Driver.
655 //===----------------------------------------------------------------------===//
656 
657 namespace {
658 // Generate ops for each instance where the type can be successfully inferred.
659 template <typename OpTy>
660 static void invokeCreateWithInferredReturnType(Operation *op) {
661   auto *context = op->getContext();
662   auto fop = op->getParentOfType<func::FuncOp>();
663   auto location = UnknownLoc::get(context);
664   OpBuilder b(op);
665   b.setInsertionPointAfter(op);
666 
667   // Use permutations of 2 args as operands.
668   assert(fop.getNumArguments() >= 2);
669   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
670     for (int j = 0; j < e; ++j) {
671       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
672       SmallVector<Type, 2> inferredReturnTypes;
673       if (succeeded(OpTy::inferReturnTypes(
674               context, std::nullopt, values, op->getDiscardableAttrDictionary(),
675               op->getPropertiesStorage(), op->getRegions(),
676               inferredReturnTypes))) {
677         OperationState state(location, OpTy::getOperationName());
678         // TODO: Expand to regions.
679         OpTy::build(b, state, values, op->getAttrs());
680         (void)b.create(state);
681       }
682     }
683   }
684 }
685 
686 static void reifyReturnShape(Operation *op) {
687   OpBuilder b(op);
688 
689   // Use permutations of 2 args as operands.
690   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
691   SmallVector<Value, 2> shapes;
692   if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
693       !llvm::hasSingleElement(shapes))
694     return;
695   for (const auto &it : llvm::enumerate(shapes)) {
696     op->emitRemark() << "value " << it.index() << ": "
697                      << it.value().getDefiningOp();
698   }
699 }
700 
701 struct TestReturnTypeDriver
702     : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> {
703   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver)
704 
705   void getDependentDialects(DialectRegistry &registry) const override {
706     registry.insert<tensor::TensorDialect>();
707   }
708   StringRef getArgument() const final { return "test-return-type"; }
709   StringRef getDescription() const final { return "Run return type functions"; }
710 
711   void runOnOperation() override {
712     if (getOperation().getName() == "testCreateFunctions") {
713       std::vector<Operation *> ops;
714       // Collect ops to avoid triggering on inserted ops.
715       for (auto &op : getOperation().getBody().front())
716         ops.push_back(&op);
717       // Generate test patterns for each, but skip terminator.
718       for (auto *op : llvm::ArrayRef(ops).drop_back()) {
719         // Test create method of each of the Op classes below. The resultant
720         // output would be in reverse order underneath `op` from which
721         // the attributes and regions are used.
722         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
723         invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>(
724             op);
725         invokeCreateWithInferredReturnType<
726             OpWithShapedTypeInferTypeInterfaceOp>(op);
727       };
728       return;
729     }
730     if (getOperation().getName() == "testReifyFunctions") {
731       std::vector<Operation *> ops;
732       // Collect ops to avoid triggering on inserted ops.
733       for (auto &op : getOperation().getBody().front())
734         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
735           ops.push_back(&op);
736       // Generate test patterns for each, but skip terminator.
737       for (auto *op : ops)
738         reifyReturnShape(op);
739     }
740   }
741 };
742 } // namespace
743 
744 namespace {
745 struct TestDerivedAttributeDriver
746     : public PassWrapper<TestDerivedAttributeDriver,
747                          OperationPass<func::FuncOp>> {
748   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver)
749 
750   StringRef getArgument() const final { return "test-derived-attr"; }
751   StringRef getDescription() const final {
752     return "Run test derived attributes";
753   }
754   void runOnOperation() override;
755 };
756 } // namespace
757 
758 void TestDerivedAttributeDriver::runOnOperation() {
759   getOperation().walk([](DerivedAttributeOpInterface dOp) {
760     auto dAttr = dOp.materializeDerivedAttributes();
761     if (!dAttr)
762       return;
763     for (auto d : dAttr)
764       dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
765   });
766 }
767 
768 //===----------------------------------------------------------------------===//
769 // Legalization Driver.
770 //===----------------------------------------------------------------------===//
771 
772 namespace {
773 //===----------------------------------------------------------------------===//
774 // Region-Block Rewrite Testing
775 
776 /// This pattern applies a signature conversion to a block inside a detached
777 /// region.
778 struct TestDetachedSignatureConversion : public ConversionPattern {
779   TestDetachedSignatureConversion(MLIRContext *ctx)
780       : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1,
781                           ctx) {}
782 
783   LogicalResult
784   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
785                   ConversionPatternRewriter &rewriter) const final {
786     if (op->getNumRegions() != 1)
787       return failure();
788     OperationState state(op->getLoc(), "test.legal_op_with_region", operands,
789                          op->getResultTypes(), {}, BlockRange());
790     Region *newRegion = state.addRegion();
791     rewriter.inlineRegionBefore(op->getRegion(0), *newRegion,
792                                 newRegion->begin());
793     TypeConverter::SignatureConversion result(newRegion->getNumArguments());
794     for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i)
795       result.addInputs(i, rewriter.getF64Type());
796     rewriter.applySignatureConversion(&newRegion->front(), result);
797     Operation *newOp = rewriter.create(state);
798     rewriter.replaceOp(op, newOp->getResults());
799     return success();
800   }
801 };
802 
803 /// This pattern is a simple pattern that inlines the first region of a given
804 /// operation into the parent region.
805 struct TestRegionRewriteBlockMovement : public ConversionPattern {
806   TestRegionRewriteBlockMovement(MLIRContext *ctx)
807       : ConversionPattern("test.region", 1, ctx) {}
808 
809   LogicalResult
810   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
811                   ConversionPatternRewriter &rewriter) const final {
812     // Inline this region into the parent region.
813     auto &parentRegion = *op->getParentRegion();
814     auto &opRegion = op->getRegion(0);
815     if (op->getDiscardableAttr("legalizer.should_clone"))
816       rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
817     else
818       rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
819 
820     if (op->getDiscardableAttr("legalizer.erase_old_blocks")) {
821       while (!opRegion.empty())
822         rewriter.eraseBlock(&opRegion.front());
823     }
824 
825     // Drop this operation.
826     rewriter.eraseOp(op);
827     return success();
828   }
829 };
830 /// This pattern is a simple pattern that generates a region containing an
831 /// illegal operation.
832 struct TestRegionRewriteUndo : public RewritePattern {
833   TestRegionRewriteUndo(MLIRContext *ctx)
834       : RewritePattern("test.region_builder", 1, ctx) {}
835 
836   LogicalResult matchAndRewrite(Operation *op,
837                                 PatternRewriter &rewriter) const final {
838     // Create the region operation with an entry block containing arguments.
839     OperationState newRegion(op->getLoc(), "test.region");
840     newRegion.addRegion();
841     auto *regionOp = rewriter.create(newRegion);
842     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
843     entryBlock->addArgument(rewriter.getIntegerType(64),
844                             rewriter.getUnknownLoc());
845 
846     // Add an explicitly illegal operation to ensure the conversion fails.
847     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
848     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
849 
850     // Drop this operation.
851     rewriter.eraseOp(op);
852     return success();
853   }
854 };
855 /// A simple pattern that creates a block at the end of the parent region of the
856 /// matched operation.
857 struct TestCreateBlock : public RewritePattern {
858   TestCreateBlock(MLIRContext *ctx)
859       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
860 
861   LogicalResult matchAndRewrite(Operation *op,
862                                 PatternRewriter &rewriter) const final {
863     Region &region = *op->getParentRegion();
864     Type i32Type = rewriter.getIntegerType(32);
865     Location loc = op->getLoc();
866     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
867     rewriter.create<TerminatorOp>(loc);
868     rewriter.eraseOp(op);
869     return success();
870   }
871 };
872 
873 /// A simple pattern that creates a block containing an invalid operation in
874 /// order to trigger the block creation undo mechanism.
875 struct TestCreateIllegalBlock : public RewritePattern {
876   TestCreateIllegalBlock(MLIRContext *ctx)
877       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
878 
879   LogicalResult matchAndRewrite(Operation *op,
880                                 PatternRewriter &rewriter) const final {
881     Region &region = *op->getParentRegion();
882     Type i32Type = rewriter.getIntegerType(32);
883     Location loc = op->getLoc();
884     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
885     // Create an illegal op to ensure the conversion fails.
886     rewriter.create<ILLegalOpF>(loc, i32Type);
887     rewriter.create<TerminatorOp>(loc);
888     rewriter.eraseOp(op);
889     return success();
890   }
891 };
892 
893 /// A simple pattern that tests the undo mechanism when replacing the uses of a
894 /// block argument.
895 struct TestUndoBlockArgReplace : public ConversionPattern {
896   TestUndoBlockArgReplace(MLIRContext *ctx)
897       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
898 
899   LogicalResult
900   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
901                   ConversionPatternRewriter &rewriter) const final {
902     auto illegalOp =
903         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
904     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
905                                         illegalOp->getResult(0));
906     rewriter.modifyOpInPlace(op, [] {});
907     return success();
908   }
909 };
910 
911 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
912 /// This is to test the rollback logic.
913 struct TestUndoMoveOpBefore : public ConversionPattern {
914   TestUndoMoveOpBefore(MLIRContext *ctx)
915       : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
916 
917   LogicalResult
918   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
919                   ConversionPatternRewriter &rewriter) const override {
920     rewriter.moveOpBefore(op, op->getParentOp());
921     // Replace with an illegal op to ensure the conversion fails.
922     rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
923     return success();
924   }
925 };
926 
927 /// A rewrite pattern that tests the undo mechanism when erasing a block.
928 struct TestUndoBlockErase : public ConversionPattern {
929   TestUndoBlockErase(MLIRContext *ctx)
930       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
931 
932   LogicalResult
933   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
934                   ConversionPatternRewriter &rewriter) const final {
935     Block *secondBlock = &*std::next(op->getRegion(0).begin());
936     rewriter.setInsertionPointToStart(secondBlock);
937     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
938     rewriter.eraseBlock(secondBlock);
939     rewriter.modifyOpInPlace(op, [] {});
940     return success();
941   }
942 };
943 
944 /// A pattern that modifies a property in-place, but keeps the op illegal.
945 struct TestUndoPropertiesModification : public ConversionPattern {
946   TestUndoPropertiesModification(MLIRContext *ctx)
947       : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {}
948   LogicalResult
949   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
950                   ConversionPatternRewriter &rewriter) const final {
951     if (!op->hasAttr("modify_inplace"))
952       return failure();
953     rewriter.modifyOpInPlace(
954         op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); });
955     return success();
956   }
957 };
958 
959 //===----------------------------------------------------------------------===//
960 // Type-Conversion Rewrite Testing
961 
962 /// This patterns erases a region operation that has had a type conversion.
963 struct TestDropOpSignatureConversion : public ConversionPattern {
964   TestDropOpSignatureConversion(MLIRContext *ctx,
965                                 const TypeConverter &converter)
966       : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
967   LogicalResult
968   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
969                   ConversionPatternRewriter &rewriter) const override {
970     Region &region = op->getRegion(0);
971     Block *entry = &region.front();
972 
973     // Convert the original entry arguments.
974     const TypeConverter &converter = *getTypeConverter();
975     TypeConverter::SignatureConversion result(entry->getNumArguments());
976     if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
977                                               result)) ||
978         failed(rewriter.convertRegionTypes(&region, converter, &result)))
979       return failure();
980 
981     // Convert the region signature and just drop the operation.
982     rewriter.eraseOp(op);
983     return success();
984   }
985 };
986 /// This pattern simply updates the operands of the given operation.
987 struct TestPassthroughInvalidOp : public ConversionPattern {
988   TestPassthroughInvalidOp(MLIRContext *ctx)
989       : ConversionPattern("test.invalid", 1, ctx) {}
990   LogicalResult
991   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
992                   ConversionPatternRewriter &rewriter) const final {
993     SmallVector<Value> flattened;
994     for (auto it : llvm::enumerate(operands)) {
995       ValueRange range = it.value();
996       if (range.size() == 1) {
997         flattened.push_back(range.front());
998         continue;
999       }
1000 
1001       // This is a 1:N replacement. Insert a test.cast op. (That's what the
1002       // argument materialization used to do.)
1003       flattened.push_back(
1004           rewriter
1005               .create<TestCastOp>(op->getLoc(),
1006                                   op->getOperand(it.index()).getType(), range)
1007               .getResult());
1008     }
1009     rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened,
1010                                              std::nullopt);
1011     return success();
1012   }
1013 };
1014 /// Replace with valid op, but simply drop the operands. This is used in a
1015 /// regression where we used to generate circular unrealized_conversion_cast
1016 /// ops.
1017 struct TestDropAndReplaceInvalidOp : public ConversionPattern {
1018   TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
1019       : ConversionPattern(converter,
1020                           "test.drop_operands_and_replace_with_valid", 1, ctx) {
1021   }
1022   LogicalResult
1023   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1024                   ConversionPatternRewriter &rewriter) const final {
1025     rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(),
1026                                              std::nullopt);
1027     return success();
1028   }
1029 };
1030 /// This pattern handles the case of a split return value.
1031 struct TestSplitReturnType : public ConversionPattern {
1032   TestSplitReturnType(MLIRContext *ctx)
1033       : ConversionPattern("test.return", 1, ctx) {}
1034   LogicalResult
1035   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1036                   ConversionPatternRewriter &rewriter) const final {
1037     // Check for a return of F32.
1038     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
1039       return failure();
1040     rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]);
1041     return success();
1042   }
1043 };
1044 
1045 //===----------------------------------------------------------------------===//
1046 // Multi-Level Type-Conversion Rewrite Testing
1047 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
1048   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
1049       : ConversionPattern("test.type_producer", 1, ctx) {}
1050   LogicalResult
1051   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1052                   ConversionPatternRewriter &rewriter) const final {
1053     // If the type is I32, change the type to F32.
1054     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
1055       return failure();
1056     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
1057     return success();
1058   }
1059 };
1060 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
1061   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
1062       : ConversionPattern("test.type_producer", 1, ctx) {}
1063   LogicalResult
1064   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1065                   ConversionPatternRewriter &rewriter) const final {
1066     // If the type is F32, change the type to F64.
1067     if (!Type(*op->result_type_begin()).isF32())
1068       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
1069     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
1070     return success();
1071   }
1072 };
1073 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
1074   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
1075       : ConversionPattern("test.type_producer", 10, ctx) {}
1076   LogicalResult
1077   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1078                   ConversionPatternRewriter &rewriter) const final {
1079     // Always convert to B16, even though it is not a legal type. This tests
1080     // that values are unmapped correctly.
1081     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
1082     return success();
1083   }
1084 };
1085 struct TestUpdateConsumerType : public ConversionPattern {
1086   TestUpdateConsumerType(MLIRContext *ctx)
1087       : ConversionPattern("test.type_consumer", 1, ctx) {}
1088   LogicalResult
1089   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1090                   ConversionPatternRewriter &rewriter) const final {
1091     // Verify that the incoming operand has been successfully remapped to F64.
1092     if (!operands[0].getType().isF64())
1093       return failure();
1094     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
1095     return success();
1096   }
1097 };
1098 
1099 //===----------------------------------------------------------------------===//
1100 // Non-Root Replacement Rewrite Testing
1101 /// This pattern generates an invalid operation, but replaces it before the
1102 /// pattern is finished. This checks that we don't need to legalize the
1103 /// temporary op.
1104 struct TestNonRootReplacement : public RewritePattern {
1105   TestNonRootReplacement(MLIRContext *ctx)
1106       : RewritePattern("test.replace_non_root", 1, ctx) {}
1107 
1108   LogicalResult matchAndRewrite(Operation *op,
1109                                 PatternRewriter &rewriter) const final {
1110     auto resultType = *op->result_type_begin();
1111     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
1112     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
1113 
1114     rewriter.replaceOp(illegalOp, legalOp);
1115     rewriter.replaceOp(op, illegalOp);
1116     return success();
1117   }
1118 };
1119 
1120 //===----------------------------------------------------------------------===//
1121 // Recursive Rewrite Testing
1122 /// This pattern is applied to the same operation multiple times, but has a
1123 /// bounded recursion.
1124 struct TestBoundedRecursiveRewrite
1125     : public OpRewritePattern<TestRecursiveRewriteOp> {
1126   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
1127 
1128   void initialize() {
1129     // The conversion target handles bounding the recursion of this pattern.
1130     setHasBoundedRewriteRecursion();
1131   }
1132 
1133   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
1134                                 PatternRewriter &rewriter) const final {
1135     // Decrement the depth of the op in-place.
1136     rewriter.modifyOpInPlace(op, [&] {
1137       op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
1138     });
1139     return success();
1140   }
1141 };
1142 
1143 struct TestNestedOpCreationUndoRewrite
1144     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
1145   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
1146 
1147   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
1148                                 PatternRewriter &rewriter) const final {
1149     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1150     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1151     return success();
1152   };
1153 };
1154 
1155 // This pattern matches `test.blackhole` and delete this op and its producer.
1156 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
1157   using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
1158 
1159   LogicalResult matchAndRewrite(BlackHoleOp op,
1160                                 PatternRewriter &rewriter) const final {
1161     Operation *producer = op.getOperand().getDefiningOp();
1162     // Always erase the user before the producer, the framework should handle
1163     // this correctly.
1164     rewriter.eraseOp(op);
1165     rewriter.eraseOp(producer);
1166     return success();
1167   };
1168 };
1169 
1170 // This pattern replaces explicitly illegal op with explicitly legal op,
1171 // but in addition creates unregistered operation.
1172 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
1173   using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
1174 
1175   LogicalResult matchAndRewrite(ILLegalOpG op,
1176                                 PatternRewriter &rewriter) const final {
1177     IntegerAttr attr = rewriter.getI32IntegerAttr(0);
1178     Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
1179     rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
1180     return success();
1181   };
1182 };
1183 
1184 class TestEraseOp : public ConversionPattern {
1185 public:
1186   TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {}
1187   LogicalResult
1188   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1189                   ConversionPatternRewriter &rewriter) const final {
1190     // Erase op without replacements.
1191     rewriter.eraseOp(op);
1192     return success();
1193   }
1194 };
1195 
1196 /// This pattern matches a test.duplicate_block_args op and duplicates all
1197 /// block arguments.
1198 class TestDuplicateBlockArgs
1199     : public OpConversionPattern<DuplicateBlockArgsOp> {
1200   using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern;
1201 
1202   LogicalResult
1203   matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor,
1204                   ConversionPatternRewriter &rewriter) const override {
1205     if (op.getIsLegal())
1206       return failure();
1207     rewriter.startOpModification(op);
1208     Block *body = &op.getBody().front();
1209     TypeConverter::SignatureConversion result(body->getNumArguments());
1210     for (auto it : llvm::enumerate(body->getArgumentTypes()))
1211       result.addInputs(it.index(), {it.value(), it.value()});
1212     rewriter.applySignatureConversion(body, result, getTypeConverter());
1213     op.setIsLegal(true);
1214     rewriter.finalizeOpModification(op);
1215     return success();
1216   }
1217 };
1218 
1219 /// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid
1220 /// op. The pattern supports 1:N replacements and forwards the replacement
1221 /// values of the single operand as test.valid operands.
1222 class TestRepetitive1ToNConsumer : public ConversionPattern {
1223 public:
1224   TestRepetitive1ToNConsumer(MLIRContext *ctx)
1225       : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {}
1226   LogicalResult
1227   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1228                   ConversionPatternRewriter &rewriter) const final {
1229     // A single operand is expected.
1230     if (op->getNumOperands() != 1)
1231       return failure();
1232     rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front());
1233     return success();
1234   }
1235 };
1236 
1237 } // namespace
1238 
1239 namespace {
1240 struct TestTypeConverter : public TypeConverter {
1241   using TypeConverter::TypeConverter;
1242   TestTypeConverter() {
1243     addConversion(convertType);
1244     addArgumentMaterialization(materializeCast);
1245     addSourceMaterialization(materializeCast);
1246   }
1247 
1248   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
1249     // Drop I16 types.
1250     if (t.isSignlessInteger(16))
1251       return success();
1252 
1253     // Convert I64 to F64.
1254     if (t.isSignlessInteger(64)) {
1255       results.push_back(FloatType::getF64(t.getContext()));
1256       return success();
1257     }
1258 
1259     // Convert I42 to I43.
1260     if (t.isInteger(42)) {
1261       results.push_back(IntegerType::get(t.getContext(), 43));
1262       return success();
1263     }
1264 
1265     // Split F32 into F16,F16.
1266     if (t.isF32()) {
1267       results.assign(2, FloatType::getF16(t.getContext()));
1268       return success();
1269     }
1270 
1271     // Drop I24 types.
1272     if (t.isInteger(24)) {
1273       return success();
1274     }
1275 
1276     // Otherwise, convert the type directly.
1277     results.push_back(t);
1278     return success();
1279   }
1280 
1281   /// Hook for materializing a conversion. This is necessary because we generate
1282   /// 1->N type mappings.
1283   static Value materializeCast(OpBuilder &builder, Type resultType,
1284                                ValueRange inputs, Location loc) {
1285     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1286   }
1287 };
1288 
1289 struct TestLegalizePatternDriver
1290     : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> {
1291   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)
1292 
1293   StringRef getArgument() const final { return "test-legalize-patterns"; }
1294   StringRef getDescription() const final {
1295     return "Run test dialect legalization patterns";
1296   }
1297   /// The mode of conversion to use with the driver.
1298   enum class ConversionMode { Analysis, Full, Partial };
1299 
1300   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
1301 
1302   void getDependentDialects(DialectRegistry &registry) const override {
1303     registry.insert<func::FuncDialect, test::TestDialect>();
1304   }
1305 
1306   void runOnOperation() override {
1307     TestTypeConverter converter;
1308     mlir::RewritePatternSet patterns(&getContext());
1309     populateWithGenerated(patterns);
1310     patterns.add<
1311         TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1312         TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1313         TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
1314         TestSplitReturnType, TestChangeProducerTypeI32ToF32,
1315         TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
1316         TestUpdateConsumerType, TestNonRootReplacement,
1317         TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
1318         TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1319         TestUndoPropertiesModification, TestEraseOp,
1320         TestRepetitive1ToNConsumer>(&getContext());
1321     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
1322         &getContext(), converter);
1323     patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
1324     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1325                                                               converter);
1326     mlir::populateCallOpTypeConversionPattern(patterns, converter);
1327 
1328     // Define the conversion target used for the test.
1329     ConversionTarget target(getContext());
1330     target.addLegalOp<ModuleOp>();
1331     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
1332                       TerminatorOp, OneRegionOp>();
1333     target.addLegalOp(
1334         OperationName("test.legal_op_with_region", &getContext()));
1335     target
1336         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
1337     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
1338       // Don't allow F32 operands.
1339       return llvm::none_of(op.getOperandTypes(),
1340                            [](Type type) { return type.isF32(); });
1341     });
1342     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1343       return converter.isSignatureLegal(op.getFunctionType()) &&
1344              converter.isLegal(&op.getBody());
1345     });
1346     target.addDynamicallyLegalOp<func::CallOp>(
1347         [&](func::CallOp op) { return converter.isLegal(op); });
1348 
1349     // TestCreateUnregisteredOp creates `arith.constant` operation,
1350     // which was not added to target intentionally to test
1351     // correct error code from conversion driver.
1352     target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });
1353 
1354     // Expect the type_producer/type_consumer operations to only operate on f64.
1355     target.addDynamicallyLegalOp<TestTypeProducerOp>(
1356         [](TestTypeProducerOp op) { return op.getType().isF64(); });
1357     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1358       return op.getOperand().getType().isF64();
1359     });
1360 
1361     // Check support for marking certain operations as recursively legal.
1362     target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) {
1363       return static_cast<bool>(
1364           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
1365     });
1366 
1367     // Mark the bound recursion operation as dynamically legal.
1368     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
1369         [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
1370 
1371     // Create a dynamically legal rule that can only be legalized by folding it.
1372     target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
1373         [](TestOpInPlaceSelfFold op) { return op.getFolded(); });
1374 
1375     target.addDynamicallyLegalOp<DuplicateBlockArgsOp>(
1376         [](DuplicateBlockArgsOp op) { return op.getIsLegal(); });
1377 
1378     // Handle a partial conversion.
1379     if (mode == ConversionMode::Partial) {
1380       DenseSet<Operation *> unlegalizedOps;
1381       ConversionConfig config;
1382       DumpNotifications dumpNotifications;
1383       config.listener = &dumpNotifications;
1384       config.unlegalizedOps = &unlegalizedOps;
1385       if (failed(applyPartialConversion(getOperation(), target,
1386                                         std::move(patterns), config))) {
1387         getOperation()->emitRemark() << "applyPartialConversion failed";
1388       }
1389       // Emit remarks for each legalizable operation.
1390       for (auto *op : unlegalizedOps)
1391         op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1392       return;
1393     }
1394 
1395     // Handle a full conversion.
1396     if (mode == ConversionMode::Full) {
1397       // Check support for marking unknown operations as dynamically legal.
1398       target.markUnknownOpDynamicallyLegal([](Operation *op) {
1399         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
1400       });
1401 
1402       ConversionConfig config;
1403       DumpNotifications dumpNotifications;
1404       config.listener = &dumpNotifications;
1405       if (failed(applyFullConversion(getOperation(), target,
1406                                      std::move(patterns), config))) {
1407         getOperation()->emitRemark() << "applyFullConversion failed";
1408       }
1409       return;
1410     }
1411 
1412     // Otherwise, handle an analysis conversion.
1413     assert(mode == ConversionMode::Analysis);
1414 
1415     // Analyze the convertible operations.
1416     DenseSet<Operation *> legalizedOps;
1417     ConversionConfig config;
1418     config.legalizableOps = &legalizedOps;
1419     if (failed(applyAnalysisConversion(getOperation(), target,
1420                                        std::move(patterns), config)))
1421       return signalPassFailure();
1422 
1423     // Emit remarks for each legalizable operation.
1424     for (auto *op : legalizedOps)
1425       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
1426   }
1427 
1428   /// The mode of conversion to use.
1429   ConversionMode mode;
1430 };
1431 } // namespace
1432 
1433 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
1434     legalizerConversionMode(
1435         "test-legalize-mode",
1436         llvm::cl::desc("The legalization mode to use with the test driver"),
1437         llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
1438         llvm::cl::values(
1439             clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
1440                        "analysis", "Perform an analysis conversion"),
1441             clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
1442                        "Perform a full conversion"),
1443             clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
1444                        "partial", "Perform a partial conversion")));
1445 
1446 //===----------------------------------------------------------------------===//
1447 // ConversionPatternRewriter::getRemappedValue testing. This method is used
1448 // to get the remapped value of an original value that was replaced using
1449 // ConversionPatternRewriter.
1450 namespace {
1451 struct TestRemapValueTypeConverter : public TypeConverter {
1452   using TypeConverter::TypeConverter;
1453 
1454   TestRemapValueTypeConverter() {
1455     addConversion(
1456         [](Float32Type type) { return Float64Type::get(type.getContext()); });
1457     addConversion([](Type type) { return type; });
1458   }
1459 };
1460 
1461 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
1462 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
1463 /// operand twice.
1464 ///
1465 /// Example:
1466 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
1467 /// is replaced with:
1468 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
1469 struct OneVResOneVOperandOp1Converter
1470     : public OpConversionPattern<OneVResOneVOperandOp1> {
1471   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
1472 
1473   LogicalResult
1474   matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
1475                   ConversionPatternRewriter &rewriter) const override {
1476     auto origOps = op.getOperands();
1477     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
1478            "One operand expected");
1479     Value origOp = *origOps.begin();
1480     SmallVector<Value, 2> remappedOperands;
1481     // Replicate the remapped original operand twice. Note that we don't used
1482     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
1483     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1484     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1485 
1486     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
1487                                                        remappedOperands);
1488     return success();
1489   }
1490 };
1491 
/// A conversion pattern that inlines the single block of a
/// TestRemappedValueRegionOp right before the op, and then replaces the op's
/// results with the remapped values of that block's terminator operands. This
/// exercises ConversionPatternRewriter::getRemappedValues after the region
/// body has been merged into the parent block.
struct TestRemapValueInRegion
    : public OpConversionPattern<TestRemappedValueRegionOp> {
  using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Block &block = op.getBody().front();
    Operation *terminator = block.getTerminator();

    // Inline the region's block before `op`: split the parent block at `op`,
    // append the region's block to the head, then append the split-off tail
    // (which starts with `op`) back onto it.
    Block *parentBlock = op->getBlock();
    Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
    rewriter.mergeBlocks(&block, parentBlock, ValueRange());
    rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());

    // Replace the results of this operation with the remapped terminator
    // values.
    // NOTE(review): on failure this returns after the blocks were already
    // split/merged; the conversion driver is expected to roll that back —
    // confirm this is intended.
    SmallVector<Value> terminatorOperands;
    if (failed(rewriter.getRemappedValues(terminator->getOperands(),
                                          terminatorOperands)))
      return failure();

    rewriter.eraseOp(terminator);
    rewriter.replaceOp(op, terminatorOperands);
    return success();
  }
};
1521 
1522 struct TestRemappedValue
1523     : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> {
1524   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue)
1525 
1526   StringRef getArgument() const final { return "test-remapped-value"; }
1527   StringRef getDescription() const final {
1528     return "Test public remapped value mechanism in ConversionPatternRewriter";
1529   }
1530   void runOnOperation() override {
1531     TestRemapValueTypeConverter typeConverter;
1532 
1533     mlir::RewritePatternSet patterns(&getContext());
1534     patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
1535     patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
1536         &getContext());
1537     patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
1538 
1539     mlir::ConversionTarget target(getContext());
1540     target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>();
1541 
1542     // Expect the type_producer/type_consumer operations to only operate on f64.
1543     target.addDynamicallyLegalOp<TestTypeProducerOp>(
1544         [](TestTypeProducerOp op) { return op.getType().isF64(); });
1545     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1546       return op.getOperand().getType().isF64();
1547     });
1548 
1549     // We make OneVResOneVOperandOp1 legal only when it has more that one
1550     // operand. This will trigger the conversion that will replace one-operand
1551     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
1552     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
1553         [](Operation *op) { return op->getNumOperands() > 1; });
1554 
1555     if (failed(mlir::applyFullConversion(getOperation(), target,
1556                                          std::move(patterns)))) {
1557       signalPassFailure();
1558     }
1559   }
1560 };
1561 } // namespace
1562 
1563 //===----------------------------------------------------------------------===//
1564 // Test patterns without a specific root operation kind
1565 //===----------------------------------------------------------------------===//
1566 
1567 namespace {
1568 /// This pattern matches and removes any operation in the test dialect.
1569 struct RemoveTestDialectOps : public RewritePattern {
1570   RemoveTestDialectOps(MLIRContext *context)
1571       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
1572 
1573   LogicalResult matchAndRewrite(Operation *op,
1574                                 PatternRewriter &rewriter) const override {
1575     if (!isa<TestDialect>(op->getDialect()))
1576       return failure();
1577     rewriter.eraseOp(op);
1578     return success();
1579   }
1580 };
1581 
1582 struct TestUnknownRootOpDriver
1583     : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> {
1584   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver)
1585 
1586   StringRef getArgument() const final {
1587     return "test-legalize-unknown-root-patterns";
1588   }
1589   StringRef getDescription() const final {
1590     return "Test public remapped value mechanism in ConversionPatternRewriter";
1591   }
1592   void runOnOperation() override {
1593     mlir::RewritePatternSet patterns(&getContext());
1594     patterns.add<RemoveTestDialectOps>(&getContext());
1595 
1596     mlir::ConversionTarget target(getContext());
1597     target.addIllegalDialect<TestDialect>();
1598     if (failed(applyPartialConversion(getOperation(), target,
1599                                       std::move(patterns))))
1600       signalPassFailure();
1601   }
1602 };
1603 } // namespace
1604 
1605 //===----------------------------------------------------------------------===//
// Test patterns that use operations and types defined at runtime
1607 //===----------------------------------------------------------------------===//
1608 
1609 namespace {
1610 /// This pattern matches dynamic operations 'test.one_operand_two_results' and
1611 /// replace them with dynamic operations 'test.generic_dynamic_op'.
1612 struct RewriteDynamicOp : public RewritePattern {
1613   RewriteDynamicOp(MLIRContext *context)
1614       : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
1615                        context) {}
1616 
1617   LogicalResult matchAndRewrite(Operation *op,
1618                                 PatternRewriter &rewriter) const override {
1619     assert(op->getName().getStringRef() ==
1620                "test.dynamic_one_operand_two_results" &&
1621            "rewrite pattern should only match operations with the right name");
1622 
1623     OperationState state(op->getLoc(), "test.dynamic_generic",
1624                          op->getOperands(), op->getResultTypes(),
1625                          op->getAttrs());
1626     auto *newOp = rewriter.create(state);
1627     rewriter.replaceOp(op, newOp->getResults());
1628     return success();
1629   }
1630 };
1631 
1632 struct TestRewriteDynamicOpDriver
1633     : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> {
1634   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
1635 
1636   void getDependentDialects(DialectRegistry &registry) const override {
1637     registry.insert<TestDialect>();
1638   }
1639   StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
1640   StringRef getDescription() const final {
1641     return "Test rewritting on dynamic operations";
1642   }
1643   void runOnOperation() override {
1644     RewritePatternSet patterns(&getContext());
1645     patterns.add<RewriteDynamicOp>(&getContext());
1646 
1647     ConversionTarget target(getContext());
1648     target.addIllegalOp(
1649         OperationName("test.dynamic_one_operand_two_results", &getContext()));
1650     target.addLegalOp(OperationName("test.dynamic_generic", &getContext()));
1651     if (failed(applyPartialConversion(getOperation(), target,
1652                                       std::move(patterns))))
1653       signalPassFailure();
1654   }
1655 };
1656 } // end anonymous namespace
1657 
1658 //===----------------------------------------------------------------------===//
1659 // Test type conversions
1660 //===----------------------------------------------------------------------===//
1661 
1662 namespace {
1663 struct TestTypeConversionProducer
1664     : public OpConversionPattern<TestTypeProducerOp> {
1665   using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
1666   LogicalResult
1667   matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
1668                   ConversionPatternRewriter &rewriter) const final {
1669     Type resultType = op.getType();
1670     Type convertedType = getTypeConverter()
1671                              ? getTypeConverter()->convertType(resultType)
1672                              : resultType;
1673     if (isa<FloatType>(resultType))
1674       resultType = rewriter.getF64Type();
1675     else if (resultType.isInteger(16))
1676       resultType = rewriter.getIntegerType(64);
1677     else if (isa<test::TestRecursiveType>(resultType) &&
1678              convertedType != resultType)
1679       resultType = convertedType;
1680     else
1681       return failure();
1682 
1683     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
1684     return success();
1685   }
1686 };
1687 
1688 /// Call signature conversion and then fail the rewrite to trigger the undo
1689 /// mechanism.
1690 struct TestSignatureConversionUndo
1691     : public OpConversionPattern<TestSignatureConversionUndoOp> {
1692   using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
1693 
1694   LogicalResult
1695   matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
1696                   ConversionPatternRewriter &rewriter) const final {
1697     (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
1698     return failure();
1699   }
1700 };
1701 
1702 /// Call signature conversion without providing a type converter to handle
1703 /// materializations.
1704 struct TestTestSignatureConversionNoConverter
1705     : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
1706   TestTestSignatureConversionNoConverter(const TypeConverter &converter,
1707                                          MLIRContext *context)
1708       : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
1709         converter(converter) {}
1710 
1711   LogicalResult
1712   matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
1713                   ConversionPatternRewriter &rewriter) const final {
1714     Region &region = op->getRegion(0);
1715     Block *entry = &region.front();
1716 
1717     // Convert the original entry arguments.
1718     TypeConverter::SignatureConversion result(entry->getNumArguments());
1719     if (failed(
1720             converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
1721       return failure();
1722     rewriter.modifyOpInPlace(op, [&] {
1723       rewriter.applySignatureConversion(&region.front(), result);
1724     });
1725     return success();
1726   }
1727 
1728   const TypeConverter &converter;
1729 };
1730 
1731 /// Just forward the operands to the root op. This is essentially a no-op
1732 /// pattern that is used to trigger target materialization.
1733 struct TestTypeConsumerForward
1734     : public OpConversionPattern<TestTypeConsumerOp> {
1735   using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
1736 
1737   LogicalResult
1738   matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
1739                   ConversionPatternRewriter &rewriter) const final {
1740     rewriter.modifyOpInPlace(op,
1741                              [&] { op->setOperands(adaptor.getOperands()); });
1742     return success();
1743   }
1744 };
1745 
1746 struct TestTypeConversionAnotherProducer
1747     : public OpRewritePattern<TestAnotherTypeProducerOp> {
1748   using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
1749 
1750   LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
1751                                 PatternRewriter &rewriter) const final {
1752     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
1753     return success();
1754   }
1755 };
1756 
1757 struct TestReplaceWithLegalOp : public ConversionPattern {
1758   TestReplaceWithLegalOp(MLIRContext *ctx)
1759       : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
1760   LogicalResult
1761   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1762                   ConversionPatternRewriter &rewriter) const final {
1763     rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
1764     return success();
1765   }
1766 };
1767 
1768 struct TestTypeConversionDriver
1769     : public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
1770   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
1771 
1772   void getDependentDialects(DialectRegistry &registry) const override {
1773     registry.insert<TestDialect>();
1774   }
1775   StringRef getArgument() const final {
1776     return "test-legalize-type-conversion";
1777   }
1778   StringRef getDescription() const final {
1779     return "Test various type conversion functionalities in DialectConversion";
1780   }
1781 
1782   void runOnOperation() override {
1783     // Initialize the type converter.
1784     SmallVector<Type, 2> conversionCallStack;
1785     TypeConverter converter;
1786 
1787     /// Add the legal set of type conversions.
1788     converter.addConversion([](Type type) -> Type {
1789       // Treat F64 as legal.
1790       if (type.isF64())
1791         return type;
1792       // Allow converting BF16/F16/F32 to F64.
1793       if (type.isBF16() || type.isF16() || type.isF32())
1794         return FloatType::getF64(type.getContext());
1795       // Otherwise, the type is illegal.
1796       return nullptr;
1797     });
1798     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
1799       // Drop all integer types.
1800       return success();
1801     });
1802     converter.addConversion(
1803         // Convert a recursive self-referring type into a non-self-referring
1804         // type named "outer_converted_type" that contains a SimpleAType.
1805         [&](test::TestRecursiveType type,
1806             SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
1807           // If the type is already converted, return it to indicate that it is
1808           // legal.
1809           if (type.getName() == "outer_converted_type") {
1810             results.push_back(type);
1811             return success();
1812           }
1813 
1814           conversionCallStack.push_back(type);
1815           auto popConversionCallStack = llvm::make_scope_exit(
1816               [&conversionCallStack]() { conversionCallStack.pop_back(); });
1817 
1818           // If the type is on the call stack more than once (it is there at
1819           // least once because of the _current_ call, which is always the last
1820           // element on the stack), we've hit the recursive case. Just return
1821           // SimpleAType here to create a non-recursive type as a result.
1822           if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(),
1823                                  type)) {
1824             results.push_back(test::SimpleAType::get(type.getContext()));
1825             return success();
1826           }
1827 
1828           // Convert the body recursively.
1829           auto result = test::TestRecursiveType::get(type.getContext(),
1830                                                      "outer_converted_type");
1831           if (failed(result.setBody(converter.convertType(type.getBody()))))
1832             return failure();
1833           results.push_back(result);
1834           return success();
1835         });
1836 
1837     /// Add the legal set of type materializations.
1838     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
1839                                           ValueRange inputs,
1840                                           Location loc) -> Value {
1841       // Allow casting from F64 back to F32.
1842       if (!resultType.isF16() && inputs.size() == 1 &&
1843           inputs[0].getType().isF64())
1844         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1845       // Allow producing an i32 or i64 from nothing.
1846       if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
1847           inputs.empty())
1848         return builder.create<TestTypeProducerOp>(loc, resultType);
1849       // Allow producing an i64 from an integer.
1850       if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
1851           isa<IntegerType>(inputs[0].getType()))
1852         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1853       // Otherwise, fail.
1854       return nullptr;
1855     });
1856 
1857     // Initialize the conversion target.
1858     mlir::ConversionTarget target(getContext());
1859     target.addLegalOp<LegalOpD>();
1860     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
1861       auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
1862       return op.getType().isF64() || op.getType().isInteger(64) ||
1863              (recursiveType &&
1864               recursiveType.getName() == "outer_converted_type");
1865     });
1866     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1867       return converter.isSignatureLegal(op.getFunctionType()) &&
1868              converter.isLegal(&op.getBody());
1869     });
1870     target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
1871       // Allow casts from F64 to F32.
1872       return (*op.operand_type_begin()).isF64() && op.getType().isF32();
1873     });
1874     target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
1875         [&](TestSignatureConversionNoConverterOp op) {
1876           return converter.isLegal(op.getRegion().front().getArgumentTypes());
1877         });
1878 
1879     // Initialize the set of rewrite patterns.
1880     RewritePatternSet patterns(&getContext());
1881     patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
1882                  TestSignatureConversionUndo,
1883                  TestTestSignatureConversionNoConverter>(converter,
1884                                                          &getContext());
1885     patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
1886         &getContext());
1887     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1888                                                               converter);
1889 
1890     if (failed(applyPartialConversion(getOperation(), target,
1891                                       std::move(patterns))))
1892       signalPassFailure();
1893   }
1894 };
1895 } // namespace
1896 
1897 //===----------------------------------------------------------------------===//
1898 // Test Target Materialization With No Uses
1899 //===----------------------------------------------------------------------===//
1900 
1901 namespace {
1902 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1903   using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1904 
1905   LogicalResult
1906   matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1907                   ConversionPatternRewriter &rewriter) const final {
1908     rewriter.replaceOp(op, adaptor.getOperands());
1909     return success();
1910   }
1911 };
1912 
1913 struct TestTargetMaterializationWithNoUses
1914     : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> {
1915   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1916       TestTargetMaterializationWithNoUses)
1917 
1918   StringRef getArgument() const final {
1919     return "test-target-materialization-with-no-uses";
1920   }
1921   StringRef getDescription() const final {
1922     return "Test a special case of target materialization in DialectConversion";
1923   }
1924 
1925   void runOnOperation() override {
1926     TypeConverter converter;
1927     converter.addConversion([](Type t) { return t; });
1928     converter.addConversion([](IntegerType intTy) -> Type {
1929       if (intTy.getWidth() == 16)
1930         return IntegerType::get(intTy.getContext(), 64);
1931       return intTy;
1932     });
1933     converter.addTargetMaterialization(
1934         [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1935           return builder.create<TestCastOp>(loc, type, inputs).getResult();
1936         });
1937 
1938     ConversionTarget target(getContext());
1939     target.addIllegalOp<TestTypeChangerOp>();
1940 
1941     RewritePatternSet patterns(&getContext());
1942     patterns.add<ForwardOperandPattern>(converter, &getContext());
1943 
1944     if (failed(applyPartialConversion(getOperation(), target,
1945                                       std::move(patterns))))
1946       signalPassFailure();
1947   }
1948 };
1949 } // namespace
1950 
1951 //===----------------------------------------------------------------------===//
1952 // Test Block Merging
1953 //===----------------------------------------------------------------------===//
1954 
1955 namespace {
1956 /// A rewriter pattern that tests that blocks can be merged.
1957 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
1958   using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
1959 
1960   LogicalResult
1961   matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
1962                   ConversionPatternRewriter &rewriter) const final {
1963     Block &firstBlock = op.getBody().front();
1964     Operation *branchOp = firstBlock.getTerminator();
1965     Block *secondBlock = &*(std::next(op.getBody().begin()));
1966     auto succOperands = branchOp->getOperands();
1967     SmallVector<Value, 2> replacements(succOperands);
1968     rewriter.eraseOp(branchOp);
1969     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1970     rewriter.modifyOpInPlace(op, [] {});
1971     return success();
1972   }
1973 };
1974 
1975 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
1976 struct TestUndoBlocksMerge : public ConversionPattern {
1977   TestUndoBlocksMerge(MLIRContext *ctx)
1978       : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
1979   LogicalResult
1980   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1981                   ConversionPatternRewriter &rewriter) const final {
1982     Block &firstBlock = op->getRegion(0).front();
1983     Operation *branchOp = firstBlock.getTerminator();
1984     Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
1985     rewriter.setInsertionPointToStart(secondBlock);
1986     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
1987     auto succOperands = branchOp->getOperands();
1988     SmallVector<Value, 2> replacements(succOperands);
1989     rewriter.eraseOp(branchOp);
1990     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1991     rewriter.modifyOpInPlace(op, [] {});
1992     return success();
1993   }
1994 };
1995 
1996 /// A rewrite mechanism to inline the body of the op into its parent, when both
1997 /// ops can have a single block.
1998 struct TestMergeSingleBlockOps
1999     : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
2000   using OpConversionPattern<
2001       SingleBlockImplicitTerminatorOp>::OpConversionPattern;
2002 
2003   LogicalResult
2004   matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
2005                   ConversionPatternRewriter &rewriter) const final {
2006     SingleBlockImplicitTerminatorOp parentOp =
2007         op->getParentOfType<SingleBlockImplicitTerminatorOp>();
2008     if (!parentOp)
2009       return failure();
2010     Block &innerBlock = op.getRegion().front();
2011     TerminatorOp innerTerminator =
2012         cast<TerminatorOp>(innerBlock.getTerminator());
2013     rewriter.inlineBlockBefore(&innerBlock, op);
2014     rewriter.eraseOp(innerTerminator);
2015     rewriter.eraseOp(op);
2016     return success();
2017   }
2018 };
2019 
2020 struct TestMergeBlocksPatternDriver
2021     : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> {
2022   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver)
2023 
2024   StringRef getArgument() const final { return "test-merge-blocks"; }
2025   StringRef getDescription() const final {
2026     return "Test Merging operation in ConversionPatternRewriter";
2027   }
2028   void runOnOperation() override {
2029     MLIRContext *context = &getContext();
2030     mlir::RewritePatternSet patterns(context);
2031     patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
2032         context);
2033     ConversionTarget target(*context);
2034     target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
2035                       TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
2036     target.addIllegalOp<ILLegalOpF>();
2037 
2038     /// Expect the op to have a single block after legalization.
2039     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
2040         [&](TestMergeBlocksOp op) -> bool {
2041           return llvm::hasSingleElement(op.getBody());
2042         });
2043 
2044     /// Only allow `test.br` within test.merge_blocks op.
2045     target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
2046       return op->getParentOfType<TestMergeBlocksOp>();
2047     });
2048 
2049     /// Expect that all nested test.SingleBlockImplicitTerminator ops are
2050     /// inlined.
2051     target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
2052         [&](SingleBlockImplicitTerminatorOp op) -> bool {
2053           return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
2054         });
2055 
2056     DenseSet<Operation *> unlegalizedOps;
2057     ConversionConfig config;
2058     config.unlegalizedOps = &unlegalizedOps;
2059     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
2060                                  config);
2061     for (auto *op : unlegalizedOps)
2062       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
2063   }
2064 };
2065 } // namespace
2066 
2067 //===----------------------------------------------------------------------===//
2068 // Test Selective Replacement
2069 //===----------------------------------------------------------------------===//
2070 
2071 namespace {
2072 /// A rewrite mechanism to inline the body of the op into its parent, when both
2073 /// ops can have a single block.
2074 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
2075   using OpRewritePattern<TestCastOp>::OpRewritePattern;
2076 
2077   LogicalResult matchAndRewrite(TestCastOp op,
2078                                 PatternRewriter &rewriter) const final {
2079     if (op.getNumOperands() != 2)
2080       return failure();
2081     OperandRange operands = op.getOperands();
2082 
2083     // Replace non-terminator uses with the first operand.
2084     rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
2085       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
2086     });
2087     // Replace everything else with the second operand if the operation isn't
2088     // dead.
2089     rewriter.replaceOp(op, op.getOperand(1));
2090     return success();
2091   }
2092 };
2093 
2094 struct TestSelectiveReplacementPatternDriver
2095     : public PassWrapper<TestSelectiveReplacementPatternDriver,
2096                          OperationPass<>> {
2097   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
2098       TestSelectiveReplacementPatternDriver)
2099 
2100   StringRef getArgument() const final {
2101     return "test-pattern-selective-replacement";
2102   }
2103   StringRef getDescription() const final {
2104     return "Test selective replacement in the PatternRewriter";
2105   }
2106   void runOnOperation() override {
2107     MLIRContext *context = &getContext();
2108     mlir::RewritePatternSet patterns(context);
2109     patterns.add<TestSelectiveOpReplacementPattern>(context);
2110     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
2111   }
2112 };
2113 } // namespace
2114 
2115 //===----------------------------------------------------------------------===//
2116 // PassRegistration
2117 //===----------------------------------------------------------------------===//
2118 
2119 namespace mlir {
2120 namespace test {
/// Registers the pattern-rewriting and dialect-conversion test passes listed
/// below with the pass registry via PassRegistration.
void registerPatternsTestPass() {
  PassRegistration<TestReturnTypeDriver>();

  PassRegistration<TestDerivedAttributeDriver>();

  PassRegistration<TestGreedyPatternDriver>();
  PassRegistration<TestStrictPatternDriver>();
  PassRegistration<TestWalkPatternDriver>();

  // Registered through a factory so that each instance is constructed with
  // the mode selected by the `test-legalize-mode` option.
  PassRegistration<TestLegalizePatternDriver>([] {
    return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
  });

  PassRegistration<TestRemappedValue>();

  PassRegistration<TestUnknownRootOpDriver>();

  PassRegistration<TestTypeConversionDriver>();
  PassRegistration<TestTargetMaterializationWithNoUses>();

  PassRegistration<TestRewriteDynamicOpDriver>();

  PassRegistration<TestMergeBlocksPatternDriver>();
  PassRegistration<TestSelectiveReplacementPatternDriver>();
}
2146 } // namespace test
2147 } // namespace mlir
2148