xref: /llvm-project/mlir/test/lib/Dialect/Test/TestPatterns.cpp (revision 60a20bd6973c8fc7aa9a19465ed042604e07fb17)
1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestTypes.h"
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include "mlir/Transforms/FoldUtils.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 #include "llvm/ADT/ScopeExit.h"
21 
22 using namespace mlir;
23 using namespace test;
24 
25 // Native function for testing NativeCodeCall
26 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
27   return choice.getValue() ? input1 : input2;
28 }
29 
30 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
31   rewriter.create<OpI>(loc, input);
32 }
33 
34 static void handleNoResultOp(PatternRewriter &rewriter,
35                              OpSymbolBindingNoResult op) {
36   // Turn the no result op to a one-result op.
37   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
38                                     op.getOperand());
39 }
40 
41 static bool getFirstI32Result(Operation *op, Value &value) {
42   if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
43     return false;
44   value = op->getResult(0);
45   return true;
46 }
47 
48 static Value bindNativeCodeCallResult(Value value) { return value; }
49 
50 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
51                                                               Value input2) {
52   return SmallVector<Value, 2>({input2, input1});
53 }
54 
55 // Test that natives calls are only called once during rewrites.
56 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
57 // This let us check the number of times OpM_Test was called by inspecting
58 // the returned value in the MLIR output.
59 static int64_t opMIncreasingValue = 314159265;
60 static Attribute opMTest(PatternRewriter &rewriter, Value val) {
61   int64_t i = opMIncreasingValue++;
62   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
63 }
64 
65 namespace {
66 #include "TestPatterns.inc"
67 } // namespace
68 
69 //===----------------------------------------------------------------------===//
70 // Test Reduce Pattern Interface
71 //===----------------------------------------------------------------------===//
72 
73 void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
74   populateWithGenerated(patterns);
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Canonicalizer Driver.
79 //===----------------------------------------------------------------------===//
80 
81 namespace {
82 struct FoldingPattern : public RewritePattern {
83 public:
84   FoldingPattern(MLIRContext *context)
85       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
86                        /*benefit=*/1, context) {}
87 
88   LogicalResult matchAndRewrite(Operation *op,
89                                 PatternRewriter &rewriter) const override {
90     // Exercise createOrFold API for a single-result operation that is folded
91     // upon construction. The operation being created has an in-place folder,
92     // and it should be still present in the output. Furthermore, the folder
93     // should not crash when attempting to recover the (unchanged) operation
94     // result.
95     Value result = rewriter.createOrFold<TestOpInPlaceFold>(
96         op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0));
97     assert(result);
98     rewriter.replaceOp(op, result);
99     return success();
100   }
101 };
102 
103 /// This pattern creates a foldable operation at the entry point of the block.
104 /// This tests the situation where the operation folder will need to replace an
105 /// operation with a previously created constant that does not initially
106 /// dominate the operation to replace.
107 struct FolderInsertBeforePreviouslyFoldedConstantPattern
108     : public OpRewritePattern<TestCastOp> {
109 public:
110   using OpRewritePattern<TestCastOp>::OpRewritePattern;
111 
112   LogicalResult matchAndRewrite(TestCastOp op,
113                                 PatternRewriter &rewriter) const override {
114     if (!op->hasAttr("test_fold_before_previously_folded_op"))
115       return failure();
116     rewriter.setInsertionPointToStart(op->getBlock());
117 
118     auto constOp = rewriter.create<arith::ConstantOp>(
119         op.getLoc(), rewriter.getBoolAttr(true));
120     rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
121                                             Value(constOp));
122     return success();
123   }
124 };
125 
126 /// This pattern matches test.op_commutative2 with the first operand being
127 /// another test.op_commutative2 with a constant on the right side and fold it
128 /// away by propagating it as its result. This is intend to check that patterns
129 /// are applied after the commutative property moves constant to the right.
130 struct FolderCommutativeOp2WithConstant
131     : public OpRewritePattern<TestCommutative2Op> {
132 public:
133   using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
134 
135   LogicalResult matchAndRewrite(TestCommutative2Op op,
136                                 PatternRewriter &rewriter) const override {
137     auto operand =
138         dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
139     if (!operand)
140       return failure();
141     Attribute constInput;
142     if (!matchPattern(operand->getOperand(1), m_Constant(&constInput)))
143       return failure();
144     rewriter.replaceOp(op, operand->getOperand(1));
145     return success();
146   }
147 };
148 
149 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
150 /// attribute with value smaller than MaxVal, it increments the value by 1.
151 template <int MaxVal>
152 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
153   using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern;
154 
155   LogicalResult matchAndRewrite(AnyAttrOfOp op,
156                                 PatternRewriter &rewriter) const override {
157     auto intAttr = dyn_cast<IntegerAttr>(op.getAttr());
158     if (!intAttr)
159       return failure();
160     int64_t val = intAttr.getInt();
161     if (val >= MaxVal)
162       return failure();
163     rewriter.modifyOpInPlace(
164         op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
165     return success();
166   }
167 };
168 
169 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op".
170 struct MakeOpEligible : public RewritePattern {
171   MakeOpEligible(MLIRContext *context)
172       : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {}
173 
174   LogicalResult matchAndRewrite(Operation *op,
175                                 PatternRewriter &rewriter) const override {
176     if (op->hasAttr("eligible"))
177       return failure();
178     rewriter.modifyOpInPlace(
179         op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); });
180     return success();
181   }
182 };
183 
184 /// This pattern hoists eligible ops out of a "test.one_region_op".
185 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
186   using OpRewritePattern<test::OneRegionOp>::OpRewritePattern;
187 
188   LogicalResult matchAndRewrite(test::OneRegionOp op,
189                                 PatternRewriter &rewriter) const override {
190     Operation *terminator = op.getRegion().front().getTerminator();
191     Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp();
192     if (toBeHoisted->getParentOp() != op)
193       return failure();
194     if (!toBeHoisted->hasAttr("eligible"))
195       return failure();
196     rewriter.moveOpBefore(toBeHoisted, op);
197     return success();
198   }
199 };
200 
201 /// This pattern moves "test.move_before_parent_op" before the parent op.
202 struct MoveBeforeParentOp : public RewritePattern {
203   MoveBeforeParentOp(MLIRContext *context)
204       : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
205 
206   LogicalResult matchAndRewrite(Operation *op,
207                                 PatternRewriter &rewriter) const override {
208     // Do not hoist past functions.
209     if (isa<FunctionOpInterface>(op->getParentOp()))
210       return failure();
211     rewriter.moveOpBefore(op, op->getParentOp());
212     return success();
213   }
214 };
215 
216 /// This pattern inlines blocks that are nested in
217 /// "test.inline_blocks_into_parent" into the parent block.
218 struct InlineBlocksIntoParent : public RewritePattern {
219   InlineBlocksIntoParent(MLIRContext *context)
220       : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
221                        context) {}
222 
223   LogicalResult matchAndRewrite(Operation *op,
224                                 PatternRewriter &rewriter) const override {
225     bool changed = false;
226     for (Region &r : op->getRegions()) {
227       while (!r.empty()) {
228         rewriter.inlineBlockBefore(&r.front(), op);
229         changed = true;
230       }
231     }
232     return success(changed);
233   }
234 };
235 
236 /// This pattern splits blocks at "test.split_block_here" and replaces the op
237 /// with a new op (to prevent an infinite loop of block splitting).
238 struct SplitBlockHere : public RewritePattern {
239   SplitBlockHere(MLIRContext *context)
240       : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
241 
242   LogicalResult matchAndRewrite(Operation *op,
243                                 PatternRewriter &rewriter) const override {
244     rewriter.splitBlock(op->getBlock(), op->getIterator());
245     Operation *newOp = rewriter.create(
246         op->getLoc(),
247         OperationName("test.new_op", op->getContext()).getIdentifier(),
248         op->getOperands(), op->getResultTypes());
249     rewriter.replaceOp(op, newOp);
250     return success();
251   }
252 };
253 
254 /// This pattern clones "test.clone_me" ops.
255 struct CloneOp : public RewritePattern {
256   CloneOp(MLIRContext *context)
257       : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
258 
259   LogicalResult matchAndRewrite(Operation *op,
260                                 PatternRewriter &rewriter) const override {
261     // Do not clone already cloned ops to avoid going into an infinite loop.
262     if (op->hasAttr("was_cloned"))
263       return failure();
264     Operation *cloned = rewriter.clone(*op);
265     cloned->setAttr("was_cloned", rewriter.getUnitAttr());
266     return success();
267   }
268 };
269 
270 /// This pattern clones regions of "test.clone_region_before" ops before the
271 /// parent block.
272 struct CloneRegionBeforeOp : public RewritePattern {
273   CloneRegionBeforeOp(MLIRContext *context)
274       : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
275 
276   LogicalResult matchAndRewrite(Operation *op,
277                                 PatternRewriter &rewriter) const override {
278     // Do not clone already cloned ops to avoid going into an infinite loop.
279     if (op->hasAttr("was_cloned"))
280       return failure();
281     for (Region &r : op->getRegions())
282       rewriter.cloneRegionBefore(r, op->getBlock());
283     op->setAttr("was_cloned", rewriter.getUnitAttr());
284     return success();
285   }
286 };
287 
288 struct TestPatternDriver
289     : public PassWrapper<TestPatternDriver, OperationPass<>> {
290   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
291 
292   TestPatternDriver() = default;
293   TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {}
294 
295   StringRef getArgument() const final { return "test-patterns"; }
296   StringRef getDescription() const final { return "Run test dialect patterns"; }
297   void runOnOperation() override {
298     mlir::RewritePatternSet patterns(&getContext());
299     populateWithGenerated(patterns);
300 
301     // Verify named pattern is generated with expected name.
302     patterns.add<FoldingPattern, TestNamedPatternRule,
303                  FolderInsertBeforePreviouslyFoldedConstantPattern,
304                  FolderCommutativeOp2WithConstant, HoistEligibleOps,
305                  MakeOpEligible>(&getContext());
306 
307     // Additional patterns for testing the GreedyPatternRewriteDriver.
308     patterns.insert<IncrementIntAttribute<3>>(&getContext());
309 
310     GreedyRewriteConfig config;
311     config.useTopDownTraversal = this->useTopDownTraversal;
312     config.maxIterations = this->maxIterations;
313     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
314                                        config);
315   }
316 
317   Option<bool> useTopDownTraversal{
318       *this, "top-down",
319       llvm::cl::desc("Seed the worklist in general top-down order"),
320       llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
321   Option<int> maxIterations{
322       *this, "max-iterations",
323       llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
324       llvm::cl::init(GreedyRewriteConfig().maxIterations)};
325 };
326 
327 struct DumpNotifications : public RewriterBase::Listener {
328   void notifyBlockInserted(Block *block, Region *previous,
329                            Region::iterator previousIt) override {
330     llvm::outs() << "notifyBlockInserted";
331     if (block->getParentOp()) {
332       llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
333     } else {
334       llvm::outs() << " into unknown op: ";
335     }
336     if (previous == nullptr) {
337       llvm::outs() << "was unlinked\n";
338     } else {
339       llvm::outs() << "was linked\n";
340     }
341   }
342   void notifyOperationInserted(Operation *op,
343                                OpBuilder::InsertPoint previous) override {
344     llvm::outs() << "notifyOperationInserted: " << op->getName();
345     if (!previous.isSet()) {
346       llvm::outs() << ", was unlinked\n";
347     } else {
348       if (!previous.getPoint().getNodePtr()) {
349         llvm::outs() << ", was linked, exact position unknown\n";
350       } else if (previous.getPoint() == previous.getBlock()->end()) {
351         llvm::outs() << ", was last in block\n";
352       } else {
353         llvm::outs() << ", previous = " << previous.getPoint()->getName()
354                      << "\n";
355       }
356     }
357   }
358   void notifyBlockErased(Block *block) override {
359     llvm::outs() << "notifyBlockErased\n";
360   }
361   void notifyOperationErased(Operation *op) override {
362     llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
363   }
364   void notifyOperationModified(Operation *op) override {
365     llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
366   }
367   void notifyOperationReplaced(Operation *op, ValueRange values) override {
368     llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
369   }
370 };
371 
372 struct TestStrictPatternDriver
373     : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
374 public:
375   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
376 
377   TestStrictPatternDriver() = default;
378   TestStrictPatternDriver(const TestStrictPatternDriver &other)
379       : PassWrapper(other) {
380     strictMode = other.strictMode;
381   }
382 
383   StringRef getArgument() const final { return "test-strict-pattern-driver"; }
384   StringRef getDescription() const final {
385     return "Test strict mode of pattern driver";
386   }
387 
388   void runOnOperation() override {
389     MLIRContext *ctx = &getContext();
390     mlir::RewritePatternSet patterns(ctx);
391     patterns.add<
392         // clang-format off
393         ChangeBlockOp,
394         CloneOp,
395         CloneRegionBeforeOp,
396         EraseOp,
397         ImplicitChangeOp,
398         InlineBlocksIntoParent,
399         InsertSameOp,
400         MoveBeforeParentOp,
401         ReplaceWithNewOp,
402         SplitBlockHere
403         // clang-format on
404         >(ctx);
405     SmallVector<Operation *> ops;
406     getOperation()->walk([&](Operation *op) {
407       StringRef opName = op->getName().getStringRef();
408       if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
409           opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
410           opName == "test.move_before_parent_op" ||
411           opName == "test.inline_blocks_into_parent" ||
412           opName == "test.split_block_here" || opName == "test.clone_me" ||
413           opName == "test.clone_region_before") {
414         ops.push_back(op);
415       }
416     });
417 
418     DumpNotifications dumpNotifications;
419     GreedyRewriteConfig config;
420     config.listener = &dumpNotifications;
421     if (strictMode == "AnyOp") {
422       config.strictMode = GreedyRewriteStrictness::AnyOp;
423     } else if (strictMode == "ExistingAndNewOps") {
424       config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
425     } else if (strictMode == "ExistingOps") {
426       config.strictMode = GreedyRewriteStrictness::ExistingOps;
427     } else {
428       llvm_unreachable("invalid strictness option");
429     }
430 
431     // Check if these transformations introduce visiting of operations that
432     // are not in the `ops` set (The new created ops are valid). An invalid
433     // operation will trigger the assertion while processing.
434     bool changed = false;
435     bool allErased = false;
436     (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config,
437                                  &changed, &allErased);
438     Builder b(ctx);
439     getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
440     getOperation()->setAttr("pattern_driver_all_erased",
441                             b.getBoolAttr(allErased));
442   }
443 
444   Option<std::string> strictMode{
445       *this, "strictness",
446       llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"),
447       llvm::cl::init("AnyOp")};
448 
449 private:
450   // New inserted operation is valid for further transformation.
451   class InsertSameOp : public RewritePattern {
452   public:
453     InsertSameOp(MLIRContext *context)
454         : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
455 
456     LogicalResult matchAndRewrite(Operation *op,
457                                   PatternRewriter &rewriter) const override {
458       if (op->hasAttr("skip"))
459         return failure();
460 
461       Operation *newOp =
462           rewriter.create(op->getLoc(), op->getName().getIdentifier(),
463                           op->getOperands(), op->getResultTypes());
464       rewriter.modifyOpInPlace(
465           op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); });
466       newOp->setAttr("skip", rewriter.getBoolAttr(true));
467 
468       return success();
469     }
470   };
471 
472   // Replace an operation may introduce the re-visiting of its users.
473   class ReplaceWithNewOp : public RewritePattern {
474   public:
475     ReplaceWithNewOp(MLIRContext *context)
476         : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
477 
478     LogicalResult matchAndRewrite(Operation *op,
479                                   PatternRewriter &rewriter) const override {
480       Operation *newOp;
481       if (op->hasAttr("create_erase_op")) {
482         newOp = rewriter.create(
483             op->getLoc(),
484             OperationName("test.erase_op", op->getContext()).getIdentifier(),
485             ValueRange(), TypeRange());
486       } else {
487         newOp = rewriter.create(
488             op->getLoc(),
489             OperationName("test.new_op", op->getContext()).getIdentifier(),
490             op->getOperands(), op->getResultTypes());
491       }
492       rewriter.replaceOp(op, newOp->getResults());
493       return success();
494     }
495   };
496 
497   // Remove an operation may introduce the re-visiting of its operands.
498   class EraseOp : public RewritePattern {
499   public:
500     EraseOp(MLIRContext *context)
501         : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
502     LogicalResult matchAndRewrite(Operation *op,
503                                   PatternRewriter &rewriter) const override {
504       rewriter.eraseOp(op);
505       return success();
506     }
507   };
508 
509   // The following two patterns test RewriterBase::replaceAllUsesWith.
510   //
511   // That function replaces all usages of a Block (or a Value) with another one
512   // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
513   // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
514   // worklist: when an op is modified, it is added to the worklist. The two
515   // patterns below make the tracking observable: ChangeBlockOp replaces all
516   // usages of a block and that pattern is applied because the corresponding ops
517   // are put on the initial worklist (see above). ImplicitChangeOp does an
518   // unrelated change but ops of the corresponding type are *not* on the initial
519   // worklist, so the effect of the second pattern is only visible if the
520   // tracking and subsequent adding to the worklist actually works.
521 
522   // Replace all usages of the first successor with the second successor.
523   class ChangeBlockOp : public RewritePattern {
524   public:
525     ChangeBlockOp(MLIRContext *context)
526         : RewritePattern("test.change_block_op", /*benefit=*/1, context) {}
527     LogicalResult matchAndRewrite(Operation *op,
528                                   PatternRewriter &rewriter) const override {
529       if (op->getNumSuccessors() < 2)
530         return failure();
531       Block *firstSuccessor = op->getSuccessor(0);
532       Block *secondSuccessor = op->getSuccessor(1);
533       if (firstSuccessor == secondSuccessor)
534         return failure();
535       // This is the function being tested:
536       rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor);
537       // Using the following line instead would make the test fail:
538       // firstSuccessor->replaceAllUsesWith(secondSuccessor);
539       return success();
540     }
541   };
542 
543   // Changes the successor to the parent block.
544   class ImplicitChangeOp : public RewritePattern {
545   public:
546     ImplicitChangeOp(MLIRContext *context)
547         : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {}
548     LogicalResult matchAndRewrite(Operation *op,
549                                   PatternRewriter &rewriter) const override {
550       if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock())
551         return failure();
552       rewriter.modifyOpInPlace(op,
553                                [&]() { op->setSuccessor(op->getBlock(), 0); });
554       return success();
555     }
556   };
557 };
558 
559 } // namespace
560 
561 //===----------------------------------------------------------------------===//
562 // ReturnType Driver.
563 //===----------------------------------------------------------------------===//
564 
565 namespace {
566 // Generate ops for each instance where the type can be successfully inferred.
567 template <typename OpTy>
568 static void invokeCreateWithInferredReturnType(Operation *op) {
569   auto *context = op->getContext();
570   auto fop = op->getParentOfType<func::FuncOp>();
571   auto location = UnknownLoc::get(context);
572   OpBuilder b(op);
573   b.setInsertionPointAfter(op);
574 
575   // Use permutations of 2 args as operands.
576   assert(fop.getNumArguments() >= 2);
577   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
578     for (int j = 0; j < e; ++j) {
579       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
580       SmallVector<Type, 2> inferredReturnTypes;
581       if (succeeded(OpTy::inferReturnTypes(
582               context, std::nullopt, values, op->getDiscardableAttrDictionary(),
583               op->getPropertiesStorage(), op->getRegions(),
584               inferredReturnTypes))) {
585         OperationState state(location, OpTy::getOperationName());
586         // TODO: Expand to regions.
587         OpTy::build(b, state, values, op->getAttrs());
588         (void)b.create(state);
589       }
590     }
591   }
592 }
593 
594 static void reifyReturnShape(Operation *op) {
595   OpBuilder b(op);
596 
597   // Use permutations of 2 args as operands.
598   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
599   SmallVector<Value, 2> shapes;
600   if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
601       !llvm::hasSingleElement(shapes))
602     return;
603   for (const auto &it : llvm::enumerate(shapes)) {
604     op->emitRemark() << "value " << it.index() << ": "
605                      << it.value().getDefiningOp();
606   }
607 }
608 
609 struct TestReturnTypeDriver
610     : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> {
611   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver)
612 
613   void getDependentDialects(DialectRegistry &registry) const override {
614     registry.insert<tensor::TensorDialect>();
615   }
616   StringRef getArgument() const final { return "test-return-type"; }
617   StringRef getDescription() const final { return "Run return type functions"; }
618 
619   void runOnOperation() override {
620     if (getOperation().getName() == "testCreateFunctions") {
621       std::vector<Operation *> ops;
622       // Collect ops to avoid triggering on inserted ops.
623       for (auto &op : getOperation().getBody().front())
624         ops.push_back(&op);
625       // Generate test patterns for each, but skip terminator.
626       for (auto *op : llvm::ArrayRef(ops).drop_back()) {
627         // Test create method of each of the Op classes below. The resultant
628         // output would be in reverse order underneath `op` from which
629         // the attributes and regions are used.
630         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
631         invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>(
632             op);
633         invokeCreateWithInferredReturnType<
634             OpWithShapedTypeInferTypeInterfaceOp>(op);
635       };
636       return;
637     }
638     if (getOperation().getName() == "testReifyFunctions") {
639       std::vector<Operation *> ops;
640       // Collect ops to avoid triggering on inserted ops.
641       for (auto &op : getOperation().getBody().front())
642         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
643           ops.push_back(&op);
644       // Generate test patterns for each, but skip terminator.
645       for (auto *op : ops)
646         reifyReturnShape(op);
647     }
648   }
649 };
650 } // namespace
651 
652 namespace {
653 struct TestDerivedAttributeDriver
654     : public PassWrapper<TestDerivedAttributeDriver,
655                          OperationPass<func::FuncOp>> {
656   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver)
657 
658   StringRef getArgument() const final { return "test-derived-attr"; }
659   StringRef getDescription() const final {
660     return "Run test derived attributes";
661   }
662   void runOnOperation() override;
663 };
664 } // namespace
665 
666 void TestDerivedAttributeDriver::runOnOperation() {
667   getOperation().walk([](DerivedAttributeOpInterface dOp) {
668     auto dAttr = dOp.materializeDerivedAttributes();
669     if (!dAttr)
670       return;
671     for (auto d : dAttr)
672       dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
673   });
674 }
675 
676 //===----------------------------------------------------------------------===//
677 // Legalization Driver.
678 //===----------------------------------------------------------------------===//
679 
680 namespace {
681 //===----------------------------------------------------------------------===//
682 // Region-Block Rewrite Testing
683 
684 /// This pattern is a simple pattern that inlines the first region of a given
685 /// operation into the parent region.
686 struct TestRegionRewriteBlockMovement : public ConversionPattern {
687   TestRegionRewriteBlockMovement(MLIRContext *ctx)
688       : ConversionPattern("test.region", 1, ctx) {}
689 
690   LogicalResult
691   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
692                   ConversionPatternRewriter &rewriter) const final {
693     // Inline this region into the parent region.
694     auto &parentRegion = *op->getParentRegion();
695     auto &opRegion = op->getRegion(0);
696     if (op->getDiscardableAttr("legalizer.should_clone"))
697       rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
698     else
699       rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
700 
701     if (op->getDiscardableAttr("legalizer.erase_old_blocks")) {
702       while (!opRegion.empty())
703         rewriter.eraseBlock(&opRegion.front());
704     }
705 
706     // Drop this operation.
707     rewriter.eraseOp(op);
708     return success();
709   }
710 };
711 /// This pattern is a simple pattern that generates a region containing an
712 /// illegal operation.
713 struct TestRegionRewriteUndo : public RewritePattern {
714   TestRegionRewriteUndo(MLIRContext *ctx)
715       : RewritePattern("test.region_builder", 1, ctx) {}
716 
717   LogicalResult matchAndRewrite(Operation *op,
718                                 PatternRewriter &rewriter) const final {
719     // Create the region operation with an entry block containing arguments.
720     OperationState newRegion(op->getLoc(), "test.region");
721     newRegion.addRegion();
722     auto *regionOp = rewriter.create(newRegion);
723     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
724     entryBlock->addArgument(rewriter.getIntegerType(64),
725                             rewriter.getUnknownLoc());
726 
727     // Add an explicitly illegal operation to ensure the conversion fails.
728     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
729     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
730 
731     // Drop this operation.
732     rewriter.eraseOp(op);
733     return success();
734   }
735 };
736 /// A simple pattern that creates a block at the end of the parent region of the
737 /// matched operation.
738 struct TestCreateBlock : public RewritePattern {
739   TestCreateBlock(MLIRContext *ctx)
740       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
741 
742   LogicalResult matchAndRewrite(Operation *op,
743                                 PatternRewriter &rewriter) const final {
744     Region &region = *op->getParentRegion();
745     Type i32Type = rewriter.getIntegerType(32);
746     Location loc = op->getLoc();
747     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
748     rewriter.create<TerminatorOp>(loc);
749     rewriter.eraseOp(op);
750     return success();
751   }
752 };
753 
754 /// A simple pattern that creates a block containing an invalid operation in
755 /// order to trigger the block creation undo mechanism.
756 struct TestCreateIllegalBlock : public RewritePattern {
757   TestCreateIllegalBlock(MLIRContext *ctx)
758       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
759 
760   LogicalResult matchAndRewrite(Operation *op,
761                                 PatternRewriter &rewriter) const final {
762     Region &region = *op->getParentRegion();
763     Type i32Type = rewriter.getIntegerType(32);
764     Location loc = op->getLoc();
765     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
766     // Create an illegal op to ensure the conversion fails.
767     rewriter.create<ILLegalOpF>(loc, i32Type);
768     rewriter.create<TerminatorOp>(loc);
769     rewriter.eraseOp(op);
770     return success();
771   }
772 };
773 
774 /// A simple pattern that tests the undo mechanism when replacing the uses of a
775 /// block argument.
776 struct TestUndoBlockArgReplace : public ConversionPattern {
777   TestUndoBlockArgReplace(MLIRContext *ctx)
778       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
779 
780   LogicalResult
781   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
782                   ConversionPatternRewriter &rewriter) const final {
783     auto illegalOp =
784         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
785     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
786                                         illegalOp->getResult(0));
787     rewriter.modifyOpInPlace(op, [] {});
788     return success();
789   }
790 };
791 
792 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
793 /// This is to test the rollback logic.
794 struct TestUndoMoveOpBefore : public ConversionPattern {
795   TestUndoMoveOpBefore(MLIRContext *ctx)
796       : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
797 
798   LogicalResult
799   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
800                   ConversionPatternRewriter &rewriter) const override {
801     rewriter.moveOpBefore(op, op->getParentOp());
802     // Replace with an illegal op to ensure the conversion fails.
803     rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
804     return success();
805   }
806 };
807 
808 /// A rewrite pattern that tests the undo mechanism when erasing a block.
809 struct TestUndoBlockErase : public ConversionPattern {
810   TestUndoBlockErase(MLIRContext *ctx)
811       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
812 
813   LogicalResult
814   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
815                   ConversionPatternRewriter &rewriter) const final {
816     Block *secondBlock = &*std::next(op->getRegion(0).begin());
817     rewriter.setInsertionPointToStart(secondBlock);
818     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
819     rewriter.eraseBlock(secondBlock);
820     rewriter.modifyOpInPlace(op, [] {});
821     return success();
822   }
823 };
824 
825 /// A pattern that modifies a property in-place, but keeps the op illegal.
826 struct TestUndoPropertiesModification : public ConversionPattern {
827   TestUndoPropertiesModification(MLIRContext *ctx)
828       : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {}
829   LogicalResult
830   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
831                   ConversionPatternRewriter &rewriter) const final {
832     if (!op->hasAttr("modify_inplace"))
833       return failure();
834     rewriter.modifyOpInPlace(
835         op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); });
836     return success();
837   }
838 };
839 
840 //===----------------------------------------------------------------------===//
841 // Type-Conversion Rewrite Testing
842 
843 /// This patterns erases a region operation that has had a type conversion.
844 struct TestDropOpSignatureConversion : public ConversionPattern {
845   TestDropOpSignatureConversion(MLIRContext *ctx,
846                                 const TypeConverter &converter)
847       : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
848   LogicalResult
849   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
850                   ConversionPatternRewriter &rewriter) const override {
851     Region &region = op->getRegion(0);
852     Block *entry = &region.front();
853 
854     // Convert the original entry arguments.
855     const TypeConverter &converter = *getTypeConverter();
856     TypeConverter::SignatureConversion result(entry->getNumArguments());
857     if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
858                                               result)) ||
859         failed(rewriter.convertRegionTypes(&region, converter, &result)))
860       return failure();
861 
862     // Convert the region signature and just drop the operation.
863     rewriter.eraseOp(op);
864     return success();
865   }
866 };
867 /// This pattern simply updates the operands of the given operation.
868 struct TestPassthroughInvalidOp : public ConversionPattern {
869   TestPassthroughInvalidOp(MLIRContext *ctx)
870       : ConversionPattern("test.invalid", 1, ctx) {}
871   LogicalResult
872   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
873                   ConversionPatternRewriter &rewriter) const final {
874     rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, operands,
875                                              std::nullopt);
876     return success();
877   }
878 };
879 /// This pattern handles the case of a split return value.
880 struct TestSplitReturnType : public ConversionPattern {
881   TestSplitReturnType(MLIRContext *ctx)
882       : ConversionPattern("test.return", 1, ctx) {}
883   LogicalResult
884   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
885                   ConversionPatternRewriter &rewriter) const final {
886     // Check for a return of F32.
887     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
888       return failure();
889 
890     // Check if the first operation is a cast operation, if it is we use the
891     // results directly.
892     auto *defOp = operands[0].getDefiningOp();
893     if (auto packerOp =
894             llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
895       rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
896       return success();
897     }
898 
899     // Otherwise, fail to match.
900     return failure();
901   }
902 };
903 
904 //===----------------------------------------------------------------------===//
905 // Multi-Level Type-Conversion Rewrite Testing
906 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
907   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
908       : ConversionPattern("test.type_producer", 1, ctx) {}
909   LogicalResult
910   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
911                   ConversionPatternRewriter &rewriter) const final {
912     // If the type is I32, change the type to F32.
913     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
914       return failure();
915     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
916     return success();
917   }
918 };
919 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
920   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
921       : ConversionPattern("test.type_producer", 1, ctx) {}
922   LogicalResult
923   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
924                   ConversionPatternRewriter &rewriter) const final {
925     // If the type is F32, change the type to F64.
926     if (!Type(*op->result_type_begin()).isF32())
927       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
928     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
929     return success();
930   }
931 };
932 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
933   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
934       : ConversionPattern("test.type_producer", 10, ctx) {}
935   LogicalResult
936   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
937                   ConversionPatternRewriter &rewriter) const final {
938     // Always convert to B16, even though it is not a legal type. This tests
939     // that values are unmapped correctly.
940     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
941     return success();
942   }
943 };
944 struct TestUpdateConsumerType : public ConversionPattern {
945   TestUpdateConsumerType(MLIRContext *ctx)
946       : ConversionPattern("test.type_consumer", 1, ctx) {}
947   LogicalResult
948   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
949                   ConversionPatternRewriter &rewriter) const final {
950     // Verify that the incoming operand has been successfully remapped to F64.
951     if (!operands[0].getType().isF64())
952       return failure();
953     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
954     return success();
955   }
956 };
957 
958 //===----------------------------------------------------------------------===//
959 // Non-Root Replacement Rewrite Testing
960 /// This pattern generates an invalid operation, but replaces it before the
961 /// pattern is finished. This checks that we don't need to legalize the
962 /// temporary op.
963 struct TestNonRootReplacement : public RewritePattern {
964   TestNonRootReplacement(MLIRContext *ctx)
965       : RewritePattern("test.replace_non_root", 1, ctx) {}
966 
967   LogicalResult matchAndRewrite(Operation *op,
968                                 PatternRewriter &rewriter) const final {
969     auto resultType = *op->result_type_begin();
970     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
971     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
972 
973     rewriter.replaceOp(illegalOp, legalOp);
974     rewriter.replaceOp(op, illegalOp);
975     return success();
976   }
977 };
978 
979 //===----------------------------------------------------------------------===//
980 // Recursive Rewrite Testing
981 /// This pattern is applied to the same operation multiple times, but has a
982 /// bounded recursion.
983 struct TestBoundedRecursiveRewrite
984     : public OpRewritePattern<TestRecursiveRewriteOp> {
985   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
986 
987   void initialize() {
988     // The conversion target handles bounding the recursion of this pattern.
989     setHasBoundedRewriteRecursion();
990   }
991 
992   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
993                                 PatternRewriter &rewriter) const final {
994     // Decrement the depth of the op in-place.
995     rewriter.modifyOpInPlace(op, [&] {
996       op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
997     });
998     return success();
999   }
1000 };
1001 
1002 struct TestNestedOpCreationUndoRewrite
1003     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
1004   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
1005 
1006   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
1007                                 PatternRewriter &rewriter) const final {
1008     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1009     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1010     return success();
1011   };
1012 };
1013 
1014 // This pattern matches `test.blackhole` and delete this op and its producer.
1015 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
1016   using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
1017 
1018   LogicalResult matchAndRewrite(BlackHoleOp op,
1019                                 PatternRewriter &rewriter) const final {
1020     Operation *producer = op.getOperand().getDefiningOp();
1021     // Always erase the user before the producer, the framework should handle
1022     // this correctly.
1023     rewriter.eraseOp(op);
1024     rewriter.eraseOp(producer);
1025     return success();
1026   };
1027 };
1028 
1029 // This pattern replaces explicitly illegal op with explicitly legal op,
1030 // but in addition creates unregistered operation.
1031 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
1032   using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
1033 
1034   LogicalResult matchAndRewrite(ILLegalOpG op,
1035                                 PatternRewriter &rewriter) const final {
1036     IntegerAttr attr = rewriter.getI32IntegerAttr(0);
1037     Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
1038     rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
1039     return success();
1040   };
1041 };
1042 } // namespace
1043 
1044 namespace {
1045 struct TestTypeConverter : public TypeConverter {
1046   using TypeConverter::TypeConverter;
1047   TestTypeConverter() {
1048     addConversion(convertType);
1049     addArgumentMaterialization(materializeCast);
1050     addSourceMaterialization(materializeCast);
1051   }
1052 
1053   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
1054     // Drop I16 types.
1055     if (t.isSignlessInteger(16))
1056       return success();
1057 
1058     // Convert I64 to F64.
1059     if (t.isSignlessInteger(64)) {
1060       results.push_back(FloatType::getF64(t.getContext()));
1061       return success();
1062     }
1063 
1064     // Convert I42 to I43.
1065     if (t.isInteger(42)) {
1066       results.push_back(IntegerType::get(t.getContext(), 43));
1067       return success();
1068     }
1069 
1070     // Split F32 into F16,F16.
1071     if (t.isF32()) {
1072       results.assign(2, FloatType::getF16(t.getContext()));
1073       return success();
1074     }
1075 
1076     // Otherwise, convert the type directly.
1077     results.push_back(t);
1078     return success();
1079   }
1080 
1081   /// Hook for materializing a conversion. This is necessary because we generate
1082   /// 1->N type mappings.
1083   static std::optional<Value> materializeCast(OpBuilder &builder,
1084                                               Type resultType,
1085                                               ValueRange inputs, Location loc) {
1086     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1087   }
1088 };
1089 
/// Pass ("test-legalize-patterns") that drives the test-dialect legalization
/// patterns. Depending on `mode`, it runs a partial, full, or analysis
/// conversion over the current operation. All three modes share the same
/// pattern set and ConversionTarget.
struct TestLegalizePatternDriver
    : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)

  StringRef getArgument() const final { return "test-legalize-patterns"; }
  StringRef getDescription() const final {
    return "Run test dialect legalization patterns";
  }
  /// The mode of conversion to use with the driver.
  enum class ConversionMode { Analysis, Full, Partial };

  TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<func::FuncDialect, test::TestDialect>();
  }

  void runOnOperation() override {
    // `converter` is handed to several patterns and captured by reference in
    // the legality callbacks below, so it must outlive the conversion calls.
    TestTypeConverter converter;
    mlir::RewritePatternSet patterns(&getContext());
    populateWithGenerated(patterns);
    patterns
        .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
             TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace,
             TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType,
             TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
             TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
             TestNonRootReplacement, TestBoundedRecursiveRewrite,
             TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
             TestCreateUnregisteredOp, TestUndoMoveOpBefore,
             TestUndoPropertiesModification>(&getContext());
    patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
    mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                              converter);
    mlir::populateCallOpTypeConversionPattern(patterns, converter);

    // Define the conversion target used for the test.
    ConversionTarget target(getContext());
    target.addLegalOp<ModuleOp>();
    target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
                      TerminatorOp, OneRegionOp>();
    target
        .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
    target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
      // Don't allow F32 operands.
      return llvm::none_of(op.getOperandTypes(),
                           [](Type type) { return type.isF32(); });
    });
    // Functions are legal once their signature and body types are legal
    // according to `converter`.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType()) &&
             converter.isLegal(&op.getBody());
    });
    target.addDynamicallyLegalOp<func::CallOp>(
        [&](func::CallOp op) { return converter.isLegal(op); });

    // TestCreateUnregisteredOp creates an `arith.constant` operation, which is
    // intentionally not added to the target so that the test can check the
    // error code returned by the conversion driver.
    target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });

    // Expect the type_producer/type_consumer operations to only operate on f64.
    target.addDynamicallyLegalOp<TestTypeProducerOp>(
        [](TestTypeProducerOp op) { return op.getType().isF64(); });
    target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
      return op.getOperand().getType().isF64();
    });

    // Check support for marking certain operations as recursively legal.
    target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) {
      return static_cast<bool>(
          op->getAttrOfType<UnitAttr>("test.recursively_legal"));
    });

    // Mark the bound recursion operation as dynamically legal.
    target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
        [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });

    // Handle a partial conversion. A failed conversion is reported as a remark,
    // not as a pass failure.
    if (mode == ConversionMode::Partial) {
      DenseSet<Operation *> unlegalizedOps;
      ConversionConfig config;
      DumpNotifications dumpNotifications;
      config.listener = &dumpNotifications;
      config.unlegalizedOps = &unlegalizedOps;
      if (failed(applyPartialConversion(getOperation(), target,
                                        std::move(patterns), config))) {
        getOperation()->emitRemark() << "applyPartialConversion failed";
      }
      // Emit a remark for each operation that could not be legalized.
      for (auto *op : unlegalizedOps)
        op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
      return;
    }

    // Handle a full conversion. As in partial mode, failure is reported as a
    // remark rather than a pass failure.
    if (mode == ConversionMode::Full) {
      // Check support for marking unknown operations as dynamically legal.
      target.markUnknownOpDynamicallyLegal([](Operation *op) {
        return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
      });

      ConversionConfig config;
      DumpNotifications dumpNotifications;
      config.listener = &dumpNotifications;
      if (failed(applyFullConversion(getOperation(), target,
                                     std::move(patterns), config))) {
        getOperation()->emitRemark() << "applyFullConversion failed";
      }
      return;
    }

    // Otherwise, handle an analysis conversion.
    assert(mode == ConversionMode::Analysis);

    // Analyze the convertible operations.
    DenseSet<Operation *> legalizedOps;
    ConversionConfig config;
    config.legalizableOps = &legalizedOps;
    if (failed(applyAnalysisConversion(getOperation(), target,
                                       std::move(patterns), config)))
      return signalPassFailure();

    // Emit remarks for each legalizable operation.
    for (auto *op : legalizedOps)
      op->emitRemark() << "op '" << op->getName() << "' is legalizable";
  }

  /// The mode of conversion to use.
  ConversionMode mode;
};
1220 } // namespace
1221 
/// `-test-legalize-mode` command-line option that selects a
/// TestLegalizePatternDriver::ConversionMode ("analysis", "full" or
/// "partial"). It defaults to a partial conversion.
/// NOTE(review): presumably forwarded to the driver's constructor at pass
/// registration, which is not visible in this chunk.
static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
    legalizerConversionMode(
        "test-legalize-mode",
        llvm::cl::desc("The legalization mode to use with the test driver"),
        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
        llvm::cl::values(
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
                       "analysis", "Perform an analysis conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
                       "Perform a full conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
                       "partial", "Perform a partial conversion")));
1234 
1235 //===----------------------------------------------------------------------===//
1236 // ConversionPatternRewriter::getRemappedValue testing. This method is used
1237 // to get the remapped value of an original value that was replaced using
1238 // ConversionPatternRewriter.
1239 namespace {
1240 struct TestRemapValueTypeConverter : public TypeConverter {
1241   using TypeConverter::TypeConverter;
1242 
1243   TestRemapValueTypeConverter() {
1244     addConversion(
1245         [](Float32Type type) { return Float64Type::get(type.getContext()); });
1246     addConversion([](Type type) { return type; });
1247   }
1248 };
1249 
1250 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
1251 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
1252 /// operand twice.
1253 ///
1254 /// Example:
1255 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
1256 /// is replaced with:
1257 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
1258 struct OneVResOneVOperandOp1Converter
1259     : public OpConversionPattern<OneVResOneVOperandOp1> {
1260   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
1261 
1262   LogicalResult
1263   matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
1264                   ConversionPatternRewriter &rewriter) const override {
1265     auto origOps = op.getOperands();
1266     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
1267            "One operand expected");
1268     Value origOp = *origOps.begin();
1269     SmallVector<Value, 2> remappedOperands;
1270     // Replicate the remapped original operand twice. Note that we don't used
1271     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
1272     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1273     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1274 
1275     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
1276                                                        remappedOperands);
1277     return success();
1278   }
1279 };
1280 
1281 /// A rewriter pattern that tests that blocks can be merged.
1282 struct TestRemapValueInRegion
1283     : public OpConversionPattern<TestRemappedValueRegionOp> {
1284   using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;
1285 
1286   LogicalResult
1287   matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
1288                   ConversionPatternRewriter &rewriter) const final {
1289     Block &block = op.getBody().front();
1290     Operation *terminator = block.getTerminator();
1291 
1292     // Merge the block into the parent region.
1293     Block *parentBlock = op->getBlock();
1294     Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
1295     rewriter.mergeBlocks(&block, parentBlock, ValueRange());
1296     rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());
1297 
1298     // Replace the results of this operation with the remapped terminator
1299     // values.
1300     SmallVector<Value> terminatorOperands;
1301     if (failed(rewriter.getRemappedValues(terminator->getOperands(),
1302                                           terminatorOperands)))
1303       return failure();
1304 
1305     rewriter.eraseOp(terminator);
1306     rewriter.replaceOp(op, terminatorOperands);
1307     return success();
1308   }
1309 };
1310 
1311 struct TestRemappedValue
1312     : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> {
1313   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue)
1314 
1315   StringRef getArgument() const final { return "test-remapped-value"; }
1316   StringRef getDescription() const final {
1317     return "Test public remapped value mechanism in ConversionPatternRewriter";
1318   }
1319   void runOnOperation() override {
1320     TestRemapValueTypeConverter typeConverter;
1321 
1322     mlir::RewritePatternSet patterns(&getContext());
1323     patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
1324     patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
1325         &getContext());
1326     patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
1327 
1328     mlir::ConversionTarget target(getContext());
1329     target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>();
1330 
1331     // Expect the type_producer/type_consumer operations to only operate on f64.
1332     target.addDynamicallyLegalOp<TestTypeProducerOp>(
1333         [](TestTypeProducerOp op) { return op.getType().isF64(); });
1334     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1335       return op.getOperand().getType().isF64();
1336     });
1337 
1338     // We make OneVResOneVOperandOp1 legal only when it has more that one
1339     // operand. This will trigger the conversion that will replace one-operand
1340     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
1341     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
1342         [](Operation *op) { return op->getNumOperands() > 1; });
1343 
1344     if (failed(mlir::applyFullConversion(getOperation(), target,
1345                                          std::move(patterns)))) {
1346       signalPassFailure();
1347     }
1348   }
1349 };
1350 } // namespace
1351 
1352 //===----------------------------------------------------------------------===//
1353 // Test patterns without a specific root operation kind
1354 //===----------------------------------------------------------------------===//
1355 
1356 namespace {
1357 /// This pattern matches and removes any operation in the test dialect.
1358 struct RemoveTestDialectOps : public RewritePattern {
1359   RemoveTestDialectOps(MLIRContext *context)
1360       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
1361 
1362   LogicalResult matchAndRewrite(Operation *op,
1363                                 PatternRewriter &rewriter) const override {
1364     if (!isa<TestDialect>(op->getDialect()))
1365       return failure();
1366     rewriter.eraseOp(op);
1367     return success();
1368   }
1369 };
1370 
1371 struct TestUnknownRootOpDriver
1372     : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> {
1373   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver)
1374 
1375   StringRef getArgument() const final {
1376     return "test-legalize-unknown-root-patterns";
1377   }
1378   StringRef getDescription() const final {
1379     return "Test public remapped value mechanism in ConversionPatternRewriter";
1380   }
1381   void runOnOperation() override {
1382     mlir::RewritePatternSet patterns(&getContext());
1383     patterns.add<RemoveTestDialectOps>(&getContext());
1384 
1385     mlir::ConversionTarget target(getContext());
1386     target.addIllegalDialect<TestDialect>();
1387     if (failed(applyPartialConversion(getOperation(), target,
1388                                       std::move(patterns))))
1389       signalPassFailure();
1390   }
1391 };
1392 } // namespace
1393 
1394 //===----------------------------------------------------------------------===//
1395 // Test patterns that uses operations and types defined at runtime
1396 //===----------------------------------------------------------------------===//
1397 
1398 namespace {
1399 /// This pattern matches dynamic operations 'test.one_operand_two_results' and
1400 /// replace them with dynamic operations 'test.generic_dynamic_op'.
1401 struct RewriteDynamicOp : public RewritePattern {
1402   RewriteDynamicOp(MLIRContext *context)
1403       : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
1404                        context) {}
1405 
1406   LogicalResult matchAndRewrite(Operation *op,
1407                                 PatternRewriter &rewriter) const override {
1408     assert(op->getName().getStringRef() ==
1409                "test.dynamic_one_operand_two_results" &&
1410            "rewrite pattern should only match operations with the right name");
1411 
1412     OperationState state(op->getLoc(), "test.dynamic_generic",
1413                          op->getOperands(), op->getResultTypes(),
1414                          op->getAttrs());
1415     auto *newOp = rewriter.create(state);
1416     rewriter.replaceOp(op, newOp->getResults());
1417     return success();
1418   }
1419 };
1420 
1421 struct TestRewriteDynamicOpDriver
1422     : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> {
1423   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
1424 
1425   void getDependentDialects(DialectRegistry &registry) const override {
1426     registry.insert<TestDialect>();
1427   }
1428   StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
1429   StringRef getDescription() const final {
1430     return "Test rewritting on dynamic operations";
1431   }
1432   void runOnOperation() override {
1433     RewritePatternSet patterns(&getContext());
1434     patterns.add<RewriteDynamicOp>(&getContext());
1435 
1436     ConversionTarget target(getContext());
1437     target.addIllegalOp(
1438         OperationName("test.dynamic_one_operand_two_results", &getContext()));
1439     target.addLegalOp(OperationName("test.dynamic_generic", &getContext()));
1440     if (failed(applyPartialConversion(getOperation(), target,
1441                                       std::move(patterns))))
1442       signalPassFailure();
1443   }
1444 };
1445 } // end anonymous namespace
1446 
1447 //===----------------------------------------------------------------------===//
1448 // Test type conversions
1449 //===----------------------------------------------------------------------===//
1450 
1451 namespace {
1452 struct TestTypeConversionProducer
1453     : public OpConversionPattern<TestTypeProducerOp> {
1454   using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
1455   LogicalResult
1456   matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
1457                   ConversionPatternRewriter &rewriter) const final {
1458     Type resultType = op.getType();
1459     Type convertedType = getTypeConverter()
1460                              ? getTypeConverter()->convertType(resultType)
1461                              : resultType;
1462     if (isa<FloatType>(resultType))
1463       resultType = rewriter.getF64Type();
1464     else if (resultType.isInteger(16))
1465       resultType = rewriter.getIntegerType(64);
1466     else if (isa<test::TestRecursiveType>(resultType) &&
1467              convertedType != resultType)
1468       resultType = convertedType;
1469     else
1470       return failure();
1471 
1472     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
1473     return success();
1474   }
1475 };
1476 
1477 /// Call signature conversion and then fail the rewrite to trigger the undo
1478 /// mechanism.
1479 struct TestSignatureConversionUndo
1480     : public OpConversionPattern<TestSignatureConversionUndoOp> {
1481   using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
1482 
1483   LogicalResult
1484   matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
1485                   ConversionPatternRewriter &rewriter) const final {
1486     (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
1487     return failure();
1488   }
1489 };
1490 
1491 /// Call signature conversion without providing a type converter to handle
1492 /// materializations.
1493 struct TestTestSignatureConversionNoConverter
1494     : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
1495   TestTestSignatureConversionNoConverter(const TypeConverter &converter,
1496                                          MLIRContext *context)
1497       : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
1498         converter(converter) {}
1499 
1500   LogicalResult
1501   matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
1502                   ConversionPatternRewriter &rewriter) const final {
1503     Region &region = op->getRegion(0);
1504     Block *entry = &region.front();
1505 
1506     // Convert the original entry arguments.
1507     TypeConverter::SignatureConversion result(entry->getNumArguments());
1508     if (failed(
1509             converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
1510       return failure();
1511     rewriter.modifyOpInPlace(
1512         op, [&] { rewriter.applySignatureConversion(&region, result); });
1513     return success();
1514   }
1515 
1516   const TypeConverter &converter;
1517 };
1518 
1519 /// Just forward the operands to the root op. This is essentially a no-op
1520 /// pattern that is used to trigger target materialization.
1521 struct TestTypeConsumerForward
1522     : public OpConversionPattern<TestTypeConsumerOp> {
1523   using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
1524 
1525   LogicalResult
1526   matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
1527                   ConversionPatternRewriter &rewriter) const final {
1528     rewriter.modifyOpInPlace(op,
1529                              [&] { op->setOperands(adaptor.getOperands()); });
1530     return success();
1531   }
1532 };
1533 
1534 struct TestTypeConversionAnotherProducer
1535     : public OpRewritePattern<TestAnotherTypeProducerOp> {
1536   using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
1537 
1538   LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
1539                                 PatternRewriter &rewriter) const final {
1540     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
1541     return success();
1542   }
1543 };
1544 
/// Pass exercising several dialect-conversion type features on test ops:
/// - 1:1 float conversions and 1:0 (dropping) integer conversions,
/// - conversion of a self-referring TestRecursiveType,
/// - source materializations and dynamic legality,
/// - region/signature conversion (including a pattern that fails after
///   converting a region),
/// - function-op signature conversion.
struct TestTypeConversionDriver
    : public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<TestDialect>();
  }
  StringRef getArgument() const final {
    return "test-legalize-type-conversion";
  }
  StringRef getDescription() const final {
    return "Test various type conversion functionalities in DialectConversion";
  }

  void runOnOperation() override {
    // Initialize the type converter.
    // `conversionCallStack` holds the TestRecursiveTypes whose conversion is
    // currently in progress. The recursive-type conversion below captures it
    // by reference to detect self-reference.
    SmallVector<Type, 2> conversionCallStack;
    TypeConverter converter;

    /// Add the legal set of type conversions.
    // NOTE(review): TypeConverter is documented to try the most recently added
    // conversion first, so the registration order below likely matters.
    // Confirm before reordering.
    converter.addConversion([](Type type) -> Type {
      // Treat F64 as legal.
      if (type.isF64())
        return type;
      // Allow converting BF16/F16/F32 to F64.
      if (type.isBF16() || type.isF16() || type.isF32())
        return FloatType::getF64(type.getContext());
      // Otherwise, the type is illegal.
      return nullptr;
    });
    converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
      // Drop all integer types: succeed without producing any result type.
      return success();
    });
    converter.addConversion(
        // Convert a recursive self-referring type into a non-self-referring
        // type named "outer_converted_type" that contains a SimpleAType.
        [&](test::TestRecursiveType type,
            SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
          // If the type is already converted, return it to indicate that it is
          // legal.
          if (type.getName() == "outer_converted_type") {
            results.push_back(type);
            return success();
          }

          // Record this type for the duration of the call. The scope guard
          // pops it on every return path.
          conversionCallStack.push_back(type);
          auto popConversionCallStack = llvm::make_scope_exit(
              [&conversionCallStack]() { conversionCallStack.pop_back(); });

          // If the type is on the call stack more than once (it is there at
          // least once because of the _current_ call, which is always the last
          // element on the stack), we've hit the recursive case. Just return
          // SimpleAType here to create a non-recursive type as a result.
          if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(),
                                 type)) {
            results.push_back(test::SimpleAType::get(type.getContext()));
            return success();
          }

          // Convert the body recursively.
          // NOTE(review): if the body cannot be converted, convertType
          // presumably returns a null Type, and that null body is passed to
          // setBody. Confirm whether this case should fail the conversion
          // instead.
          auto result = test::TestRecursiveType::get(type.getContext(),
                                                     "outer_converted_type");
          if (failed(result.setBody(converter.convertType(type.getBody()))))
            return failure();
          results.push_back(result);
          return success();
        });

    /// Add the legal set of type materializations.
    converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
                                          ValueRange inputs,
                                          Location loc) -> Value {
      // Allow casting a single F64 value to any result type except F16
      // (e.g. back to F32).
      if (!resultType.isF16() && inputs.size() == 1 &&
          inputs[0].getType().isF64())
        return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
      // Allow producing an i32 or i64 from nothing.
      if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
          inputs.empty())
        return builder.create<TestTypeProducerOp>(loc, resultType);
      // Allow casting a single value of any integer type to any integer
      // result type.
      if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
          isa<IntegerType>(inputs[0].getType()))
        return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
      // Otherwise, fail.
      return nullptr;
    });

    // Initialize the conversion target.
    // A TestTypeProducerOp is legal when it produces f64, i64, or an
    // already-converted recursive type.
    mlir::ConversionTarget target(getContext());
    target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
      auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
      return op.getType().isF64() || op.getType().isInteger(64) ||
             (recursiveType &&
              recursiveType.getName() == "outer_converted_type");
    });
    // Functions are legal once both their signature and their body's block
    // argument types are legal for `converter`.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType()) &&
             converter.isLegal(&op.getBody());
    });
    target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
      // Allow casts from F64 to F32.
      // NOTE(review): the first operand type is dereferenced unconditionally,
      // so this assumes every TestCastOp has at least one operand.
      return (*op.operand_type_begin()).isF64() && op.getType().isF32();
    });
    target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
        [&](TestSignatureConversionNoConverterOp op) {
          return converter.isLegal(op.getRegion().front().getArgumentTypes());
        });

    // Initialize the set of rewrite patterns.
    // Conversion patterns receive `converter`. TestTypeConversionAnotherProducer
    // is a plain rewrite pattern and is constructed from the context only.
    RewritePatternSet patterns(&getContext());
    patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
                 TestSignatureConversionUndo,
                 TestTestSignatureConversionNoConverter>(converter,
                                                         &getContext());
    patterns.add<TestTypeConversionAnotherProducer>(&getContext());
    mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                              converter);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
1670 } // namespace
1671 
1672 //===----------------------------------------------------------------------===//
1673 // Test Target Materialization With No Uses
1674 //===----------------------------------------------------------------------===//
1675 
1676 namespace {
1677 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1678   using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1679 
1680   LogicalResult
1681   matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1682                   ConversionPatternRewriter &rewriter) const final {
1683     rewriter.replaceOp(op, adaptor.getOperands());
1684     return success();
1685   }
1686 };
1687 
1688 struct TestTargetMaterializationWithNoUses
1689     : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> {
1690   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1691       TestTargetMaterializationWithNoUses)
1692 
1693   StringRef getArgument() const final {
1694     return "test-target-materialization-with-no-uses";
1695   }
1696   StringRef getDescription() const final {
1697     return "Test a special case of target materialization in DialectConversion";
1698   }
1699 
1700   void runOnOperation() override {
1701     TypeConverter converter;
1702     converter.addConversion([](Type t) { return t; });
1703     converter.addConversion([](IntegerType intTy) -> Type {
1704       if (intTy.getWidth() == 16)
1705         return IntegerType::get(intTy.getContext(), 64);
1706       return intTy;
1707     });
1708     converter.addTargetMaterialization(
1709         [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1710           return builder.create<TestCastOp>(loc, type, inputs).getResult();
1711         });
1712 
1713     ConversionTarget target(getContext());
1714     target.addIllegalOp<TestTypeChangerOp>();
1715 
1716     RewritePatternSet patterns(&getContext());
1717     patterns.add<ForwardOperandPattern>(converter, &getContext());
1718 
1719     if (failed(applyPartialConversion(getOperation(), target,
1720                                       std::move(patterns))))
1721       signalPassFailure();
1722   }
1723 };
1724 } // namespace
1725 
1726 //===----------------------------------------------------------------------===//
1727 // Test Block Merging
1728 //===----------------------------------------------------------------------===//
1729 
1730 namespace {
1731 /// A rewriter pattern that tests that blocks can be merged.
1732 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
1733   using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
1734 
1735   LogicalResult
1736   matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
1737                   ConversionPatternRewriter &rewriter) const final {
1738     Block &firstBlock = op.getBody().front();
1739     Operation *branchOp = firstBlock.getTerminator();
1740     Block *secondBlock = &*(std::next(op.getBody().begin()));
1741     auto succOperands = branchOp->getOperands();
1742     SmallVector<Value, 2> replacements(succOperands);
1743     rewriter.eraseOp(branchOp);
1744     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1745     rewriter.modifyOpInPlace(op, [] {});
1746     return success();
1747   }
1748 };
1749 
1750 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
1751 struct TestUndoBlocksMerge : public ConversionPattern {
1752   TestUndoBlocksMerge(MLIRContext *ctx)
1753       : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
1754   LogicalResult
1755   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1756                   ConversionPatternRewriter &rewriter) const final {
1757     Block &firstBlock = op->getRegion(0).front();
1758     Operation *branchOp = firstBlock.getTerminator();
1759     Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
1760     rewriter.setInsertionPointToStart(secondBlock);
1761     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
1762     auto succOperands = branchOp->getOperands();
1763     SmallVector<Value, 2> replacements(succOperands);
1764     rewriter.eraseOp(branchOp);
1765     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1766     rewriter.modifyOpInPlace(op, [] {});
1767     return success();
1768   }
1769 };
1770 
1771 /// A rewrite mechanism to inline the body of the op into its parent, when both
1772 /// ops can have a single block.
1773 struct TestMergeSingleBlockOps
1774     : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
1775   using OpConversionPattern<
1776       SingleBlockImplicitTerminatorOp>::OpConversionPattern;
1777 
1778   LogicalResult
1779   matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
1780                   ConversionPatternRewriter &rewriter) const final {
1781     SingleBlockImplicitTerminatorOp parentOp =
1782         op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1783     if (!parentOp)
1784       return failure();
1785     Block &innerBlock = op.getRegion().front();
1786     TerminatorOp innerTerminator =
1787         cast<TerminatorOp>(innerBlock.getTerminator());
1788     rewriter.inlineBlockBefore(&innerBlock, op);
1789     rewriter.eraseOp(innerTerminator);
1790     rewriter.eraseOp(op);
1791     return success();
1792   }
1793 };
1794 
1795 struct TestMergeBlocksPatternDriver
1796     : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> {
1797   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver)
1798 
1799   StringRef getArgument() const final { return "test-merge-blocks"; }
1800   StringRef getDescription() const final {
1801     return "Test Merging operation in ConversionPatternRewriter";
1802   }
1803   void runOnOperation() override {
1804     MLIRContext *context = &getContext();
1805     mlir::RewritePatternSet patterns(context);
1806     patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
1807         context);
1808     ConversionTarget target(*context);
1809     target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
1810                       TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
1811     target.addIllegalOp<ILLegalOpF>();
1812 
1813     /// Expect the op to have a single block after legalization.
1814     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
1815         [&](TestMergeBlocksOp op) -> bool {
1816           return llvm::hasSingleElement(op.getBody());
1817         });
1818 
1819     /// Only allow `test.br` within test.merge_blocks op.
1820     target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
1821       return op->getParentOfType<TestMergeBlocksOp>();
1822     });
1823 
1824     /// Expect that all nested test.SingleBlockImplicitTerminator ops are
1825     /// inlined.
1826     target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
1827         [&](SingleBlockImplicitTerminatorOp op) -> bool {
1828           return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1829         });
1830 
1831     DenseSet<Operation *> unlegalizedOps;
1832     ConversionConfig config;
1833     config.unlegalizedOps = &unlegalizedOps;
1834     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
1835                                  config);
1836     for (auto *op : unlegalizedOps)
1837       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1838   }
1839 };
1840 } // namespace
1841 
1842 //===----------------------------------------------------------------------===//
1843 // Test Selective Replacement
1844 //===----------------------------------------------------------------------===//
1845 
1846 namespace {
1847 /// A rewrite mechanism to inline the body of the op into its parent, when both
1848 /// ops can have a single block.
1849 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
1850   using OpRewritePattern<TestCastOp>::OpRewritePattern;
1851 
1852   LogicalResult matchAndRewrite(TestCastOp op,
1853                                 PatternRewriter &rewriter) const final {
1854     if (op.getNumOperands() != 2)
1855       return failure();
1856     OperandRange operands = op.getOperands();
1857 
1858     // Replace non-terminator uses with the first operand.
1859     rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
1860       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
1861     });
1862     // Replace everything else with the second operand if the operation isn't
1863     // dead.
1864     rewriter.replaceOp(op, op.getOperand(1));
1865     return success();
1866   }
1867 };
1868 
1869 struct TestSelectiveReplacementPatternDriver
1870     : public PassWrapper<TestSelectiveReplacementPatternDriver,
1871                          OperationPass<>> {
1872   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1873       TestSelectiveReplacementPatternDriver)
1874 
1875   StringRef getArgument() const final {
1876     return "test-pattern-selective-replacement";
1877   }
1878   StringRef getDescription() const final {
1879     return "Test selective replacement in the PatternRewriter";
1880   }
1881   void runOnOperation() override {
1882     MLIRContext *context = &getContext();
1883     mlir::RewritePatternSet patterns(context);
1884     patterns.add<TestSelectiveOpReplacementPattern>(context);
1885     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
1886   }
1887 };
1888 } // namespace
1889 
1890 //===----------------------------------------------------------------------===//
1891 // PassRegistration
1892 //===----------------------------------------------------------------------===//
1893 
1894 namespace mlir {
1895 namespace test {
1896 void registerPatternsTestPass() {
1897   PassRegistration<TestReturnTypeDriver>();
1898 
1899   PassRegistration<TestDerivedAttributeDriver>();
1900 
1901   PassRegistration<TestPatternDriver>();
1902   PassRegistration<TestStrictPatternDriver>();
1903 
1904   PassRegistration<TestLegalizePatternDriver>([] {
1905     return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
1906   });
1907 
1908   PassRegistration<TestRemappedValue>();
1909 
1910   PassRegistration<TestUnknownRootOpDriver>();
1911 
1912   PassRegistration<TestTypeConversionDriver>();
1913   PassRegistration<TestTargetMaterializationWithNoUses>();
1914 
1915   PassRegistration<TestRewriteDynamicOpDriver>();
1916 
1917   PassRegistration<TestMergeBlocksPatternDriver>();
1918   PassRegistration<TestSelectiveReplacementPatternDriver>();
1919 }
1920 } // namespace test
1921 } // namespace mlir
1922