xref: /llvm-project/mlir/test/lib/Dialect/Test/TestPatterns.cpp (revision f023da12d12635f5fba436e825cbfc999e28e623)
1fec6c5acSUday Bondhugula //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2fec6c5acSUday Bondhugula //
3fec6c5acSUday Bondhugula // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fec6c5acSUday Bondhugula // See https://llvm.org/LICENSE.txt for license information.
5fec6c5acSUday Bondhugula // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fec6c5acSUday Bondhugula //
7fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
8fec6c5acSUday Bondhugula 
9fec6c5acSUday Bondhugula #include "TestDialect.h"
10e95e94adSJeff Niu #include "TestOps.h"
119c5982efSAlex Zinenko #include "TestTypes.h"
12abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1323aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1423aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15c0a6318dSMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
160f8a6b7dSJakub Kuderski #include "mlir/IR/BuiltinAttributes.h"
172bf423b0SRob Suderman #include "mlir/IR/Matchers.h"
180f8a6b7dSJakub Kuderski #include "mlir/IR/Visitors.h"
19fec6c5acSUday Bondhugula #include "mlir/Pass/Pass.h"
20fec6c5acSUday Bondhugula #include "mlir/Transforms/DialectConversion.h"
2126f93d9fSAlex Zinenko #include "mlir/Transforms/FoldUtils.h"
22b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
230f8a6b7dSJakub Kuderski #include "mlir/Transforms/WalkPatternRewriteDriver.h"
24dc3dc974SMehdi Amini #include "llvm/ADT/ScopeExit.h"
250f8a6b7dSJakub Kuderski #include <cstdint>
269ba37b3bSJacques Pienaar 
27fec6c5acSUday Bondhugula using namespace mlir;
287776b19eSStephen Neuendorffer using namespace test;
29fec6c5acSUday Bondhugula 
30fec6c5acSUday Bondhugula // Native function for testing NativeCodeCall
31fec6c5acSUday Bondhugula static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
32fec6c5acSUday Bondhugula   return choice.getValue() ? input1 : input2;
33fec6c5acSUday Bondhugula }
34fec6c5acSUday Bondhugula 
3529429d1aSJacques Pienaar static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
3629429d1aSJacques Pienaar   rewriter.create<OpI>(loc, input);
37fec6c5acSUday Bondhugula }
38fec6c5acSUday Bondhugula 
39fec6c5acSUday Bondhugula static void handleNoResultOp(PatternRewriter &rewriter,
40fec6c5acSUday Bondhugula                              OpSymbolBindingNoResult op) {
41fec6c5acSUday Bondhugula   // Turn the no result op to a one-result op.
426a994233SJacques Pienaar   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
436a994233SJacques Pienaar                                     op.getOperand());
44fec6c5acSUday Bondhugula }
45fec6c5acSUday Bondhugula 
4634b5482bSChia-hung Duan static bool getFirstI32Result(Operation *op, Value &value) {
4734b5482bSChia-hung Duan   if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
4834b5482bSChia-hung Duan     return false;
4934b5482bSChia-hung Duan   value = op->getResult(0);
5034b5482bSChia-hung Duan   return true;
5134b5482bSChia-hung Duan }
5234b5482bSChia-hung Duan 
5334b5482bSChia-hung Duan static Value bindNativeCodeCallResult(Value value) { return value; }
5434b5482bSChia-hung Duan 
55d7314b3cSChia-hung Duan static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
56d7314b3cSChia-hung Duan                                                               Value input2) {
57d7314b3cSChia-hung Duan   return SmallVector<Value, 2>({input2, input1});
58d7314b3cSChia-hung Duan }
59d7314b3cSChia-hung Duan 
6001641197SAlexEichenberger // Test that natives calls are only called once during rewrites.
6101641197SAlexEichenberger // OpM_Test will return Pi, increased by 1 for each subsequent calls.
6201641197SAlexEichenberger // This let us check the number of times OpM_Test was called by inspecting
6301641197SAlexEichenberger // the returned value in the MLIR output.
6401641197SAlexEichenberger static int64_t opMIncreasingValue = 314159265;
6502b6fb21SMehdi Amini static Attribute opMTest(PatternRewriter &rewriter, Value val) {
6601641197SAlexEichenberger   int64_t i = opMIncreasingValue++;
6701641197SAlexEichenberger   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
6801641197SAlexEichenberger }
6901641197SAlexEichenberger 
70fec6c5acSUday Bondhugula namespace {
71fec6c5acSUday Bondhugula #include "TestPatterns.inc"
72be0a7e9fSMehdi Amini } // namespace
73fec6c5acSUday Bondhugula 
74fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
75c484c7ddSChia-hung Duan // Test Reduce Pattern Interface
76c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
77c484c7ddSChia-hung Duan 
787776b19eSStephen Neuendorffer void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
79c484c7ddSChia-hung Duan   populateWithGenerated(patterns);
80c484c7ddSChia-hung Duan }
81c484c7ddSChia-hung Duan 
82c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
83fec6c5acSUday Bondhugula // Canonicalizer Driver.
84fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
85fec6c5acSUday Bondhugula 
86fec6c5acSUday Bondhugula namespace {
8726f93d9fSAlex Zinenko struct FoldingPattern : public RewritePattern {
8826f93d9fSAlex Zinenko public:
8926f93d9fSAlex Zinenko   FoldingPattern(MLIRContext *context)
9026f93d9fSAlex Zinenko       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
9126f93d9fSAlex Zinenko                        /*benefit=*/1, context) {}
9226f93d9fSAlex Zinenko 
9326f93d9fSAlex Zinenko   LogicalResult matchAndRewrite(Operation *op,
9426f93d9fSAlex Zinenko                                 PatternRewriter &rewriter) const override {
95506fd672SMatthias Springer     // Exercise createOrFold API for a single-result operation that is folded
96506fd672SMatthias Springer     // upon construction. The operation being created has an in-place folder,
97506fd672SMatthias Springer     // and it should be still present in the output. Furthermore, the folder
98506fd672SMatthias Springer     // should not crash when attempting to recover the (unchanged) operation
99506fd672SMatthias Springer     // result.
100506fd672SMatthias Springer     Value result = rewriter.createOrFold<TestOpInPlaceFold>(
101506fd672SMatthias Springer         op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0));
10226f93d9fSAlex Zinenko     assert(result);
10326f93d9fSAlex Zinenko     rewriter.replaceOp(op, result);
10426f93d9fSAlex Zinenko     return success();
10526f93d9fSAlex Zinenko   }
10626f93d9fSAlex Zinenko };
10726f93d9fSAlex Zinenko 
108e4635e63SRiver Riddle /// This pattern creates a foldable operation at the entry point of the block.
109e4635e63SRiver Riddle /// This tests the situation where the operation folder will need to replace an
110e4635e63SRiver Riddle /// operation with a previously created constant that does not initially
111e4635e63SRiver Riddle /// dominate the operation to replace.
112e4635e63SRiver Riddle struct FolderInsertBeforePreviouslyFoldedConstantPattern
113e4635e63SRiver Riddle     : public OpRewritePattern<TestCastOp> {
114e4635e63SRiver Riddle public:
115e4635e63SRiver Riddle   using OpRewritePattern<TestCastOp>::OpRewritePattern;
116e4635e63SRiver Riddle 
117e4635e63SRiver Riddle   LogicalResult matchAndRewrite(TestCastOp op,
118e4635e63SRiver Riddle                                 PatternRewriter &rewriter) const override {
119e4635e63SRiver Riddle     if (!op->hasAttr("test_fold_before_previously_folded_op"))
120e4635e63SRiver Riddle       return failure();
121e4635e63SRiver Riddle     rewriter.setInsertionPointToStart(op->getBlock());
122e4635e63SRiver Riddle 
123a54f4eaeSMogball     auto constOp = rewriter.create<arith::ConstantOp>(
124a54f4eaeSMogball         op.getLoc(), rewriter.getBoolAttr(true));
125e4635e63SRiver Riddle     rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
126e4635e63SRiver Riddle                                             Value(constOp));
127e4635e63SRiver Riddle     return success();
128e4635e63SRiver Riddle   }
129e4635e63SRiver Riddle };
130e4635e63SRiver Riddle 
13143f0d5f9SMehdi Amini /// This pattern matches test.op_commutative2 with the first operand being
13243f0d5f9SMehdi Amini /// another test.op_commutative2 with a constant on the right side and fold it
13343f0d5f9SMehdi Amini /// away by propagating it as its result. This is intend to check that patterns
13443f0d5f9SMehdi Amini /// are applied after the commutative property moves constant to the right.
13543f0d5f9SMehdi Amini struct FolderCommutativeOp2WithConstant
13643f0d5f9SMehdi Amini     : public OpRewritePattern<TestCommutative2Op> {
13743f0d5f9SMehdi Amini public:
13843f0d5f9SMehdi Amini   using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
13943f0d5f9SMehdi Amini 
14043f0d5f9SMehdi Amini   LogicalResult matchAndRewrite(TestCommutative2Op op,
14143f0d5f9SMehdi Amini                                 PatternRewriter &rewriter) const override {
14243f0d5f9SMehdi Amini     auto operand =
14343f0d5f9SMehdi Amini         dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
14443f0d5f9SMehdi Amini     if (!operand)
14543f0d5f9SMehdi Amini       return failure();
14643f0d5f9SMehdi Amini     Attribute constInput;
14743f0d5f9SMehdi Amini     if (!matchPattern(operand->getOperand(1), m_Constant(&constInput)))
14843f0d5f9SMehdi Amini       return failure();
14943f0d5f9SMehdi Amini     rewriter.replaceOp(op, operand->getOperand(1));
15043f0d5f9SMehdi Amini     return success();
15143f0d5f9SMehdi Amini   }
15243f0d5f9SMehdi Amini };
15343f0d5f9SMehdi Amini 
1542fea658aSMatthias Springer /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
1552fea658aSMatthias Springer /// attribute with value smaller than MaxVal, it increments the value by 1.
1562fea658aSMatthias Springer template <int MaxVal>
1572fea658aSMatthias Springer struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
1582fea658aSMatthias Springer   using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern;
1592fea658aSMatthias Springer 
1602fea658aSMatthias Springer   LogicalResult matchAndRewrite(AnyAttrOfOp op,
1612fea658aSMatthias Springer                                 PatternRewriter &rewriter) const override {
1625550c821STres Popp     auto intAttr = dyn_cast<IntegerAttr>(op.getAttr());
1632fea658aSMatthias Springer     if (!intAttr)
1642fea658aSMatthias Springer       return failure();
1652fea658aSMatthias Springer     int64_t val = intAttr.getInt();
1662fea658aSMatthias Springer     if (val >= MaxVal)
1672fea658aSMatthias Springer       return failure();
1685fcf907bSMatthias Springer     rewriter.modifyOpInPlace(
1692fea658aSMatthias Springer         op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
1702fea658aSMatthias Springer     return success();
1712fea658aSMatthias Springer   }
1722fea658aSMatthias Springer };
1732fea658aSMatthias Springer 
174ed9194beSMatthias Springer /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op".
175ed9194beSMatthias Springer struct MakeOpEligible : public RewritePattern {
176ed9194beSMatthias Springer   MakeOpEligible(MLIRContext *context)
177ed9194beSMatthias Springer       : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {}
178ed9194beSMatthias Springer 
179ed9194beSMatthias Springer   LogicalResult matchAndRewrite(Operation *op,
180ed9194beSMatthias Springer                                 PatternRewriter &rewriter) const override {
181ed9194beSMatthias Springer     if (op->hasAttr("eligible"))
182ed9194beSMatthias Springer       return failure();
1835fcf907bSMatthias Springer     rewriter.modifyOpInPlace(
184ed9194beSMatthias Springer         op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); });
185ed9194beSMatthias Springer     return success();
186ed9194beSMatthias Springer   }
187ed9194beSMatthias Springer };
188ed9194beSMatthias Springer 
189ed9194beSMatthias Springer /// This pattern hoists eligible ops out of a "test.one_region_op".
190ed9194beSMatthias Springer struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
191ed9194beSMatthias Springer   using OpRewritePattern<test::OneRegionOp>::OpRewritePattern;
192ed9194beSMatthias Springer 
193ed9194beSMatthias Springer   LogicalResult matchAndRewrite(test::OneRegionOp op,
194ed9194beSMatthias Springer                                 PatternRewriter &rewriter) const override {
195ed9194beSMatthias Springer     Operation *terminator = op.getRegion().front().getTerminator();
196ed9194beSMatthias Springer     Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp();
197ed9194beSMatthias Springer     if (toBeHoisted->getParentOp() != op)
198ed9194beSMatthias Springer       return failure();
199ed9194beSMatthias Springer     if (!toBeHoisted->hasAttr("eligible"))
200ed9194beSMatthias Springer       return failure();
2015cc0f76dSMatthias Springer     rewriter.moveOpBefore(toBeHoisted, op);
202ed9194beSMatthias Springer     return success();
203ed9194beSMatthias Springer   }
204ed9194beSMatthias Springer };
205ed9194beSMatthias Springer 
206da784a25SMatthias Springer /// This pattern moves "test.move_before_parent_op" before the parent op.
207da784a25SMatthias Springer struct MoveBeforeParentOp : public RewritePattern {
208da784a25SMatthias Springer   MoveBeforeParentOp(MLIRContext *context)
209da784a25SMatthias Springer       : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
210da784a25SMatthias Springer 
211da784a25SMatthias Springer   LogicalResult matchAndRewrite(Operation *op,
212da784a25SMatthias Springer                                 PatternRewriter &rewriter) const override {
213da784a25SMatthias Springer     // Do not hoist past functions.
214da784a25SMatthias Springer     if (isa<FunctionOpInterface>(op->getParentOp()))
215da784a25SMatthias Springer       return failure();
216da784a25SMatthias Springer     rewriter.moveOpBefore(op, op->getParentOp());
217da784a25SMatthias Springer     return success();
218da784a25SMatthias Springer   }
219da784a25SMatthias Springer };
220da784a25SMatthias Springer 
2210f8a6b7dSJakub Kuderski /// This pattern moves "test.move_after_parent_op" after the parent op.
2220f8a6b7dSJakub Kuderski struct MoveAfterParentOp : public RewritePattern {
2230f8a6b7dSJakub Kuderski   MoveAfterParentOp(MLIRContext *context)
2240f8a6b7dSJakub Kuderski       : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {}
2250f8a6b7dSJakub Kuderski 
2260f8a6b7dSJakub Kuderski   LogicalResult matchAndRewrite(Operation *op,
2270f8a6b7dSJakub Kuderski                                 PatternRewriter &rewriter) const override {
2280f8a6b7dSJakub Kuderski     // Do not hoist past functions.
2290f8a6b7dSJakub Kuderski     if (isa<FunctionOpInterface>(op->getParentOp()))
2300f8a6b7dSJakub Kuderski       return failure();
2310f8a6b7dSJakub Kuderski 
2320f8a6b7dSJakub Kuderski     int64_t moveForwardBy = 0;
2330f8a6b7dSJakub Kuderski     if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance"))
2340f8a6b7dSJakub Kuderski       moveForwardBy = advanceBy.getInt();
2350f8a6b7dSJakub Kuderski 
2360f8a6b7dSJakub Kuderski     Operation *moveAfter = op->getParentOp();
2370f8a6b7dSJakub Kuderski     for (int64_t i = 0; i < moveForwardBy; ++i)
2380f8a6b7dSJakub Kuderski       moveAfter = moveAfter->getNextNode();
2390f8a6b7dSJakub Kuderski 
2400f8a6b7dSJakub Kuderski     rewriter.moveOpAfter(op, moveAfter);
2410f8a6b7dSJakub Kuderski     return success();
2420f8a6b7dSJakub Kuderski   }
2430f8a6b7dSJakub Kuderski };
2440f8a6b7dSJakub Kuderski 
245c672b342SMatthias Springer /// This pattern inlines blocks that are nested in
246c672b342SMatthias Springer /// "test.inline_blocks_into_parent" into the parent block.
247c672b342SMatthias Springer struct InlineBlocksIntoParent : public RewritePattern {
248c672b342SMatthias Springer   InlineBlocksIntoParent(MLIRContext *context)
249c672b342SMatthias Springer       : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
250c672b342SMatthias Springer                        context) {}
251c672b342SMatthias Springer 
252c672b342SMatthias Springer   LogicalResult matchAndRewrite(Operation *op,
253c672b342SMatthias Springer                                 PatternRewriter &rewriter) const override {
254c672b342SMatthias Springer     bool changed = false;
255c672b342SMatthias Springer     for (Region &r : op->getRegions()) {
256c672b342SMatthias Springer       while (!r.empty()) {
257c672b342SMatthias Springer         rewriter.inlineBlockBefore(&r.front(), op);
258c672b342SMatthias Springer         changed = true;
259c672b342SMatthias Springer       }
260c672b342SMatthias Springer     }
261c672b342SMatthias Springer     return success(changed);
262c672b342SMatthias Springer   }
263c672b342SMatthias Springer };
264c672b342SMatthias Springer 
265c2675ba9SMatthias Springer /// This pattern splits blocks at "test.split_block_here" and replaces the op
266c2675ba9SMatthias Springer /// with a new op (to prevent an infinite loop of block splitting).
267c2675ba9SMatthias Springer struct SplitBlockHere : public RewritePattern {
268c2675ba9SMatthias Springer   SplitBlockHere(MLIRContext *context)
269c2675ba9SMatthias Springer       : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
270c2675ba9SMatthias Springer 
271c2675ba9SMatthias Springer   LogicalResult matchAndRewrite(Operation *op,
272c2675ba9SMatthias Springer                                 PatternRewriter &rewriter) const override {
273c2675ba9SMatthias Springer     rewriter.splitBlock(op->getBlock(), op->getIterator());
274c2675ba9SMatthias Springer     Operation *newOp = rewriter.create(
275c2675ba9SMatthias Springer         op->getLoc(),
276c2675ba9SMatthias Springer         OperationName("test.new_op", op->getContext()).getIdentifier(),
277c2675ba9SMatthias Springer         op->getOperands(), op->getResultTypes());
278c2675ba9SMatthias Springer     rewriter.replaceOp(op, newOp);
279c2675ba9SMatthias Springer     return success();
280c2675ba9SMatthias Springer   }
281c2675ba9SMatthias Springer };
282c2675ba9SMatthias Springer 
283237a799eSMatthias Springer /// This pattern clones "test.clone_me" ops.
284237a799eSMatthias Springer struct CloneOp : public RewritePattern {
285237a799eSMatthias Springer   CloneOp(MLIRContext *context)
286237a799eSMatthias Springer       : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
287237a799eSMatthias Springer 
288237a799eSMatthias Springer   LogicalResult matchAndRewrite(Operation *op,
289237a799eSMatthias Springer                                 PatternRewriter &rewriter) const override {
290237a799eSMatthias Springer     // Do not clone already cloned ops to avoid going into an infinite loop.
291237a799eSMatthias Springer     if (op->hasAttr("was_cloned"))
292237a799eSMatthias Springer       return failure();
293237a799eSMatthias Springer     Operation *cloned = rewriter.clone(*op);
294237a799eSMatthias Springer     cloned->setAttr("was_cloned", rewriter.getUnitAttr());
295237a799eSMatthias Springer     return success();
296237a799eSMatthias Springer   }
297237a799eSMatthias Springer };
298237a799eSMatthias Springer 
299b840d296SMatthias Springer /// This pattern clones regions of "test.clone_region_before" ops before the
300b840d296SMatthias Springer /// parent block.
301b840d296SMatthias Springer struct CloneRegionBeforeOp : public RewritePattern {
302b840d296SMatthias Springer   CloneRegionBeforeOp(MLIRContext *context)
303b840d296SMatthias Springer       : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
304b840d296SMatthias Springer 
305b840d296SMatthias Springer   LogicalResult matchAndRewrite(Operation *op,
306b840d296SMatthias Springer                                 PatternRewriter &rewriter) const override {
307b840d296SMatthias Springer     // Do not clone already cloned ops to avoid going into an infinite loop.
308b840d296SMatthias Springer     if (op->hasAttr("was_cloned"))
309b840d296SMatthias Springer       return failure();
310b840d296SMatthias Springer     for (Region &r : op->getRegions())
311b840d296SMatthias Springer       rewriter.cloneRegionBefore(r, op->getBlock());
312b840d296SMatthias Springer     op->setAttr("was_cloned", rewriter.getUnitAttr());
313b840d296SMatthias Springer     return success();
314b840d296SMatthias Springer   }
315b840d296SMatthias Springer };
316b840d296SMatthias Springer 
3170f8a6b7dSJakub Kuderski /// Replace an operation may introduce the re-visiting of its users.
3180f8a6b7dSJakub Kuderski class ReplaceWithNewOp : public RewritePattern {
3190f8a6b7dSJakub Kuderski public:
3200f8a6b7dSJakub Kuderski   ReplaceWithNewOp(MLIRContext *context)
3210f8a6b7dSJakub Kuderski       : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
3225e50dd04SRiver Riddle 
3230f8a6b7dSJakub Kuderski   LogicalResult matchAndRewrite(Operation *op,
3240f8a6b7dSJakub Kuderski                                 PatternRewriter &rewriter) const override {
3250f8a6b7dSJakub Kuderski     Operation *newOp;
3260f8a6b7dSJakub Kuderski     if (op->hasAttr("create_erase_op")) {
3270f8a6b7dSJakub Kuderski       newOp = rewriter.create(
3280f8a6b7dSJakub Kuderski           op->getLoc(),
3290f8a6b7dSJakub Kuderski           OperationName("test.erase_op", op->getContext()).getIdentifier(),
3300f8a6b7dSJakub Kuderski           ValueRange(), TypeRange());
3310f8a6b7dSJakub Kuderski     } else {
3320f8a6b7dSJakub Kuderski       newOp = rewriter.create(
3330f8a6b7dSJakub Kuderski           op->getLoc(),
3340f8a6b7dSJakub Kuderski           OperationName("test.new_op", op->getContext()).getIdentifier(),
3350f8a6b7dSJakub Kuderski           op->getOperands(), op->getResultTypes());
3360f8a6b7dSJakub Kuderski     }
3370f8a6b7dSJakub Kuderski     // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
3380f8a6b7dSJakub Kuderski     // A "notifyOperationReplaced" callback is triggered in either case.
3390f8a6b7dSJakub Kuderski     rewriter.replaceAllOpUsesWith(op, newOp->getResults());
3400f8a6b7dSJakub Kuderski     rewriter.eraseOp(op);
3410f8a6b7dSJakub Kuderski     return success();
3420f8a6b7dSJakub Kuderski   }
3430f8a6b7dSJakub Kuderski };
3447814b559Srkayaith 
3450f8a6b7dSJakub Kuderski /// Erases the first child block of the matched "test.erase_first_block"
3460f8a6b7dSJakub Kuderski /// operation.
3470f8a6b7dSJakub Kuderski class EraseFirstBlock : public RewritePattern {
3480f8a6b7dSJakub Kuderski public:
3490f8a6b7dSJakub Kuderski   EraseFirstBlock(MLIRContext *context)
3500f8a6b7dSJakub Kuderski       : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {}
3510f8a6b7dSJakub Kuderski 
3520f8a6b7dSJakub Kuderski   LogicalResult matchAndRewrite(Operation *op,
3530f8a6b7dSJakub Kuderski                                 PatternRewriter &rewriter) const override {
3540f8a6b7dSJakub Kuderski     for (Region &r : op->getRegions()) {
3550f8a6b7dSJakub Kuderski       for (Block &b : r.getBlocks()) {
3560f8a6b7dSJakub Kuderski         rewriter.eraseBlock(&b);
3570f8a6b7dSJakub Kuderski         return success();
3580f8a6b7dSJakub Kuderski       }
3590f8a6b7dSJakub Kuderski     }
3600f8a6b7dSJakub Kuderski 
3610f8a6b7dSJakub Kuderski     return failure();
3620f8a6b7dSJakub Kuderski   }
3630f8a6b7dSJakub Kuderski };
3640f8a6b7dSJakub Kuderski 
3650f8a6b7dSJakub Kuderski struct TestGreedyPatternDriver
3660f8a6b7dSJakub Kuderski     : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
3670f8a6b7dSJakub Kuderski   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
3680f8a6b7dSJakub Kuderski 
3690f8a6b7dSJakub Kuderski   TestGreedyPatternDriver() = default;
3700f8a6b7dSJakub Kuderski   TestGreedyPatternDriver(const TestGreedyPatternDriver &other)
3710f8a6b7dSJakub Kuderski       : PassWrapper(other) {}
3720f8a6b7dSJakub Kuderski 
3730f8a6b7dSJakub Kuderski   StringRef getArgument() const final { return "test-greedy-patterns"; }
374b5e22e6dSMehdi Amini   StringRef getDescription() const final { return "Run test dialect patterns"; }
37541574554SRiver Riddle   void runOnOperation() override {
376dc4e913bSChris Lattner     mlir::RewritePatternSet patterns(&getContext());
37783e3c6a7SMehdi Amini     populateWithGenerated(patterns);
378fec6c5acSUday Bondhugula 
379fec6c5acSUday Bondhugula     // Verify named pattern is generated with expected name.
380e4635e63SRiver Riddle     patterns.add<FoldingPattern, TestNamedPatternRule,
38143f0d5f9SMehdi Amini                  FolderInsertBeforePreviouslyFoldedConstantPattern,
382ed9194beSMatthias Springer                  FolderCommutativeOp2WithConstant, HoistEligibleOps,
383ed9194beSMatthias Springer                  MakeOpEligible>(&getContext());
384fec6c5acSUday Bondhugula 
3852fea658aSMatthias Springer     // Additional patterns for testing the GreedyPatternRewriteDriver.
3862fea658aSMatthias Springer     patterns.insert<IncrementIntAttribute<3>>(&getContext());
3872fea658aSMatthias Springer 
3887814b559Srkayaith     GreedyRewriteConfig config;
3897814b559Srkayaith     config.useTopDownTraversal = this->useTopDownTraversal;
3902fea658aSMatthias Springer     config.maxIterations = this->maxIterations;
39109dfc571SJacques Pienaar     config.fold = this->fold;
39209dfc571SJacques Pienaar     config.cseConstants = this->cseConstants;
39309dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
394fec6c5acSUday Bondhugula   }
3957814b559Srkayaith 
3967814b559Srkayaith   Option<bool> useTopDownTraversal{
3977814b559Srkayaith       *this, "top-down",
3987814b559Srkayaith       llvm::cl::desc("Seed the worklist in general top-down order"),
3997814b559Srkayaith       llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
4002fea658aSMatthias Springer   Option<int> maxIterations{
4012fea658aSMatthias Springer       *this, "max-iterations",
4022fea658aSMatthias Springer       llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
4032fea658aSMatthias Springer       llvm::cl::init(GreedyRewriteConfig().maxIterations)};
40409dfc571SJacques Pienaar   Option<bool> fold{*this, "fold", llvm::cl::desc("Whether to fold"),
40509dfc571SJacques Pienaar                     llvm::cl::init(GreedyRewriteConfig().fold)};
40609dfc571SJacques Pienaar   Option<bool> cseConstants{*this, "cse-constants",
40709dfc571SJacques Pienaar                             llvm::cl::desc("Whether to CSE constants"),
40809dfc571SJacques Pienaar                             llvm::cl::init(GreedyRewriteConfig().cseConstants)};
409fec6c5acSUday Bondhugula };
410ba3a9f51SChia-hung Duan 
411695a5a6aSMatthias Springer struct DumpNotifications : public RewriterBase::Listener {
412237a799eSMatthias Springer   void notifyBlockInserted(Block *block, Region *previous,
413237a799eSMatthias Springer                            Region::iterator previousIt) override {
41460a20bd6SMatthias Springer     llvm::outs() << "notifyBlockInserted";
41560a20bd6SMatthias Springer     if (block->getParentOp()) {
41660a20bd6SMatthias Springer       llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
41760a20bd6SMatthias Springer     } else {
41860a20bd6SMatthias Springer       llvm::outs() << " into unknown op: ";
41960a20bd6SMatthias Springer     }
420237a799eSMatthias Springer     if (previous == nullptr) {
421237a799eSMatthias Springer       llvm::outs() << "was unlinked\n";
422237a799eSMatthias Springer     } else {
423237a799eSMatthias Springer       llvm::outs() << "was linked\n";
424237a799eSMatthias Springer     }
425237a799eSMatthias Springer   }
426da784a25SMatthias Springer   void notifyOperationInserted(Operation *op,
427da784a25SMatthias Springer                                OpBuilder::InsertPoint previous) override {
428da784a25SMatthias Springer     llvm::outs() << "notifyOperationInserted: " << op->getName();
429da784a25SMatthias Springer     if (!previous.isSet()) {
430da784a25SMatthias Springer       llvm::outs() << ", was unlinked\n";
431da784a25SMatthias Springer     } else {
43260a20bd6SMatthias Springer       if (!previous.getPoint().getNodePtr()) {
43360a20bd6SMatthias Springer         llvm::outs() << ", was linked, exact position unknown\n";
43460a20bd6SMatthias Springer       } else if (previous.getPoint() == previous.getBlock()->end()) {
435da784a25SMatthias Springer         llvm::outs() << ", was last in block\n";
436da784a25SMatthias Springer       } else {
437da784a25SMatthias Springer         llvm::outs() << ", previous = " << previous.getPoint()->getName()
438da784a25SMatthias Springer                      << "\n";
439da784a25SMatthias Springer       }
440da784a25SMatthias Springer     }
441da784a25SMatthias Springer   }
44260a20bd6SMatthias Springer   void notifyBlockErased(Block *block) override {
44360a20bd6SMatthias Springer     llvm::outs() << "notifyBlockErased\n";
44460a20bd6SMatthias Springer   }
445914e6074SMatthias Springer   void notifyOperationErased(Operation *op) override {
446914e6074SMatthias Springer     llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
447695a5a6aSMatthias Springer   }
44860a20bd6SMatthias Springer   void notifyOperationModified(Operation *op) override {
44960a20bd6SMatthias Springer     llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
45060a20bd6SMatthias Springer   }
45160a20bd6SMatthias Springer   void notifyOperationReplaced(Operation *op, ValueRange values) override {
45260a20bd6SMatthias Springer     llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
45360a20bd6SMatthias Springer   }
454695a5a6aSMatthias Springer };
455695a5a6aSMatthias Springer 
456ba3a9f51SChia-hung Duan struct TestStrictPatternDriver
457ba3a9f51SChia-hung Duan     : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
458ba3a9f51SChia-hung Duan public:
459ba3a9f51SChia-hung Duan   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
460ba3a9f51SChia-hung Duan 
461ba3a9f51SChia-hung Duan   TestStrictPatternDriver() = default;
4626f469d60SMehdi Amini   TestStrictPatternDriver(const TestStrictPatternDriver &other)
4636f469d60SMehdi Amini       : PassWrapper(other) {
46487e345b1SMatthias Springer     strictMode = other.strictMode;
46587e345b1SMatthias Springer   }
466ba3a9f51SChia-hung Duan 
467ba3a9f51SChia-hung Duan   StringRef getArgument() const final { return "test-strict-pattern-driver"; }
468ba3a9f51SChia-hung Duan   StringRef getDescription() const final {
46987e345b1SMatthias Springer     return "Test strict mode of pattern driver";
470ba3a9f51SChia-hung Duan   }
471ba3a9f51SChia-hung Duan 
472ba3a9f51SChia-hung Duan   void runOnOperation() override {
473774416bdSMatthias Springer     MLIRContext *ctx = &getContext();
474774416bdSMatthias Springer     mlir::RewritePatternSet patterns(ctx);
4754bba8bd3SIngo Müller     patterns.add<
4764bba8bd3SIngo Müller         // clang-format off
4774bba8bd3SIngo Müller         ChangeBlockOp,
478237a799eSMatthias Springer         CloneOp,
479b840d296SMatthias Springer         CloneRegionBeforeOp,
480c672b342SMatthias Springer         EraseOp,
481da784a25SMatthias Springer         ImplicitChangeOp,
482c672b342SMatthias Springer         InlineBlocksIntoParent,
483c672b342SMatthias Springer         InsertSameOp,
484c672b342SMatthias Springer         MoveBeforeParentOp,
485c2675ba9SMatthias Springer         ReplaceWithNewOp,
486c2675ba9SMatthias Springer         SplitBlockHere
4874bba8bd3SIngo Müller         // clang-format on
4884bba8bd3SIngo Müller         >(ctx);
489ba3a9f51SChia-hung Duan     SmallVector<Operation *> ops;
490ba3a9f51SChia-hung Duan     getOperation()->walk([&](Operation *op) {
491ba3a9f51SChia-hung Duan       StringRef opName = op->getName().getStringRef();
4924bba8bd3SIngo Müller       if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
493da784a25SMatthias Springer           opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
494c672b342SMatthias Springer           opName == "test.move_before_parent_op" ||
495c2675ba9SMatthias Springer           opName == "test.inline_blocks_into_parent" ||
496b840d296SMatthias Springer           opName == "test.split_block_here" || opName == "test.clone_me" ||
497b840d296SMatthias Springer           opName == "test.clone_region_before") {
498ba3a9f51SChia-hung Duan         ops.push_back(op);
499ba3a9f51SChia-hung Duan       }
500ba3a9f51SChia-hung Duan     });
501ba3a9f51SChia-hung Duan 
502695a5a6aSMatthias Springer     DumpNotifications dumpNotifications;
5036bdecbcbSMatthias Springer     GreedyRewriteConfig config;
504695a5a6aSMatthias Springer     config.listener = &dumpNotifications;
50587e345b1SMatthias Springer     if (strictMode == "AnyOp") {
5066bdecbcbSMatthias Springer       config.strictMode = GreedyRewriteStrictness::AnyOp;
50787e345b1SMatthias Springer     } else if (strictMode == "ExistingAndNewOps") {
5086bdecbcbSMatthias Springer       config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
50987e345b1SMatthias Springer     } else if (strictMode == "ExistingOps") {
5106bdecbcbSMatthias Springer       config.strictMode = GreedyRewriteStrictness::ExistingOps;
51187e345b1SMatthias Springer     } else {
51287e345b1SMatthias Springer       llvm_unreachable("invalid strictness option");
51387e345b1SMatthias Springer     }
51487e345b1SMatthias Springer 
515ba3a9f51SChia-hung Duan     // Check if these transformations introduce visiting of operations that
516ba3a9f51SChia-hung Duan     // are not in the `ops` set (The new created ops are valid). An invalid
517ba3a9f51SChia-hung Duan     // operation will trigger the assertion while processing.
518774416bdSMatthias Springer     bool changed = false;
519774416bdSMatthias Springer     bool allErased = false;
52009dfc571SJacques Pienaar     (void)applyOpPatternsGreedily(ArrayRef(ops), std::move(patterns), config,
5216bdecbcbSMatthias Springer                                   &changed, &allErased);
522774416bdSMatthias Springer     Builder b(ctx);
523774416bdSMatthias Springer     getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
524774416bdSMatthias Springer     getOperation()->setAttr("pattern_driver_all_erased",
525774416bdSMatthias Springer                             b.getBoolAttr(allErased));
526ba3a9f51SChia-hung Duan   }
527ba3a9f51SChia-hung Duan 
52887e345b1SMatthias Springer   Option<std::string> strictMode{
52987e345b1SMatthias Springer       *this, "strictness",
53087e345b1SMatthias Springer       llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"),
53187e345b1SMatthias Springer       llvm::cl::init("AnyOp")};
53287e345b1SMatthias Springer 
533ba3a9f51SChia-hung Duan private:
534ba3a9f51SChia-hung Duan   // New inserted operation is valid for further transformation.
535ba3a9f51SChia-hung Duan   class InsertSameOp : public RewritePattern {
536ba3a9f51SChia-hung Duan   public:
537ba3a9f51SChia-hung Duan     InsertSameOp(MLIRContext *context)
538ba3a9f51SChia-hung Duan         : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
539ba3a9f51SChia-hung Duan 
540ba3a9f51SChia-hung Duan     LogicalResult matchAndRewrite(Operation *op,
541ba3a9f51SChia-hung Duan                                   PatternRewriter &rewriter) const override {
542ba3a9f51SChia-hung Duan       if (op->hasAttr("skip"))
543ba3a9f51SChia-hung Duan         return failure();
544ba3a9f51SChia-hung Duan 
545ba3a9f51SChia-hung Duan       Operation *newOp =
546ba3a9f51SChia-hung Duan           rewriter.create(op->getLoc(), op->getName().getIdentifier(),
547ba3a9f51SChia-hung Duan                           op->getOperands(), op->getResultTypes());
5485fcf907bSMatthias Springer       rewriter.modifyOpInPlace(
549e219e66eSMatthias Springer           op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); });
550ba3a9f51SChia-hung Duan       newOp->setAttr("skip", rewriter.getBoolAttr(true));
551ba3a9f51SChia-hung Duan 
552ba3a9f51SChia-hung Duan       return success();
553ba3a9f51SChia-hung Duan     }
554ba3a9f51SChia-hung Duan   };
555ba3a9f51SChia-hung Duan 
5564bba8bd3SIngo Müller   // Remove an operation may introduce the re-visiting of its operands.
557ba3a9f51SChia-hung Duan   class EraseOp : public RewritePattern {
558ba3a9f51SChia-hung Duan   public:
559ba3a9f51SChia-hung Duan     EraseOp(MLIRContext *context)
560ba3a9f51SChia-hung Duan         : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
561ba3a9f51SChia-hung Duan     LogicalResult matchAndRewrite(Operation *op,
562ba3a9f51SChia-hung Duan                                   PatternRewriter &rewriter) const override {
563ba3a9f51SChia-hung Duan       rewriter.eraseOp(op);
564ba3a9f51SChia-hung Duan       return success();
565ba3a9f51SChia-hung Duan     }
566ba3a9f51SChia-hung Duan   };
5674bba8bd3SIngo Müller 
5684bba8bd3SIngo Müller   // The following two patterns test RewriterBase::replaceAllUsesWith.
5694bba8bd3SIngo Müller   //
5704bba8bd3SIngo Müller   // That function replaces all usages of a Block (or a Value) with another one
5714bba8bd3SIngo Müller   // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
5724bba8bd3SIngo Müller   // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
5734bba8bd3SIngo Müller   // worklist: when an op is modified, it is added to the worklist. The two
5744bba8bd3SIngo Müller   // patterns below make the tracking observable: ChangeBlockOp replaces all
5754bba8bd3SIngo Müller   // usages of a block and that pattern is applied because the corresponding ops
5764bba8bd3SIngo Müller   // are put on the initial worklist (see above). ImplicitChangeOp does an
5774bba8bd3SIngo Müller   // unrelated change but ops of the corresponding type are *not* on the initial
5784bba8bd3SIngo Müller   // worklist, so the effect of the second pattern is only visible if the
5794bba8bd3SIngo Müller   // tracking and subsequent adding to the worklist actually works.
5804bba8bd3SIngo Müller 
5814bba8bd3SIngo Müller   // Replace all usages of the first successor with the second successor.
5824bba8bd3SIngo Müller   class ChangeBlockOp : public RewritePattern {
5834bba8bd3SIngo Müller   public:
5844bba8bd3SIngo Müller     ChangeBlockOp(MLIRContext *context)
5854bba8bd3SIngo Müller         : RewritePattern("test.change_block_op", /*benefit=*/1, context) {}
5864bba8bd3SIngo Müller     LogicalResult matchAndRewrite(Operation *op,
5874bba8bd3SIngo Müller                                   PatternRewriter &rewriter) const override {
5884bba8bd3SIngo Müller       if (op->getNumSuccessors() < 2)
5894bba8bd3SIngo Müller         return failure();
5904bba8bd3SIngo Müller       Block *firstSuccessor = op->getSuccessor(0);
5914bba8bd3SIngo Müller       Block *secondSuccessor = op->getSuccessor(1);
5924bba8bd3SIngo Müller       if (firstSuccessor == secondSuccessor)
5934bba8bd3SIngo Müller         return failure();
5944bba8bd3SIngo Müller       // This is the function being tested:
5954bba8bd3SIngo Müller       rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor);
5964bba8bd3SIngo Müller       // Using the following line instead would make the test fail:
5974bba8bd3SIngo Müller       // firstSuccessor->replaceAllUsesWith(secondSuccessor);
5984bba8bd3SIngo Müller       return success();
5994bba8bd3SIngo Müller     }
6004bba8bd3SIngo Müller   };
6014bba8bd3SIngo Müller 
6024bba8bd3SIngo Müller   // Changes the successor to the parent block.
6034bba8bd3SIngo Müller   class ImplicitChangeOp : public RewritePattern {
6044bba8bd3SIngo Müller   public:
6054bba8bd3SIngo Müller     ImplicitChangeOp(MLIRContext *context)
6064bba8bd3SIngo Müller         : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {}
6074bba8bd3SIngo Müller     LogicalResult matchAndRewrite(Operation *op,
6084bba8bd3SIngo Müller                                   PatternRewriter &rewriter) const override {
6094bba8bd3SIngo Müller       if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock())
6104bba8bd3SIngo Müller         return failure();
6115fcf907bSMatthias Springer       rewriter.modifyOpInPlace(op,
6125fcf907bSMatthias Springer                                [&]() { op->setSuccessor(op->getBlock(), 0); });
6134bba8bd3SIngo Müller       return success();
6144bba8bd3SIngo Müller     }
6154bba8bd3SIngo Müller   };
616ba3a9f51SChia-hung Duan };
617ba3a9f51SChia-hung Duan 
6180f8a6b7dSJakub Kuderski struct TestWalkPatternDriver final
6190f8a6b7dSJakub Kuderski     : PassWrapper<TestWalkPatternDriver, OperationPass<>> {
6200f8a6b7dSJakub Kuderski   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver)
6210f8a6b7dSJakub Kuderski 
6220f8a6b7dSJakub Kuderski   TestWalkPatternDriver() = default;
6230f8a6b7dSJakub Kuderski   TestWalkPatternDriver(const TestWalkPatternDriver &other)
6240f8a6b7dSJakub Kuderski       : PassWrapper(other) {}
6250f8a6b7dSJakub Kuderski 
6260f8a6b7dSJakub Kuderski   StringRef getArgument() const override {
6270f8a6b7dSJakub Kuderski     return "test-walk-pattern-rewrite-driver";
6280f8a6b7dSJakub Kuderski   }
6290f8a6b7dSJakub Kuderski   StringRef getDescription() const override {
6300f8a6b7dSJakub Kuderski     return "Run test walk pattern rewrite driver";
6310f8a6b7dSJakub Kuderski   }
6320f8a6b7dSJakub Kuderski   void runOnOperation() override {
6330f8a6b7dSJakub Kuderski     mlir::RewritePatternSet patterns(&getContext());
6340f8a6b7dSJakub Kuderski 
6350f8a6b7dSJakub Kuderski     // Patterns for testing the WalkPatternRewriteDriver.
6360f8a6b7dSJakub Kuderski     patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
6370f8a6b7dSJakub Kuderski                  MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>(
6380f8a6b7dSJakub Kuderski         &getContext());
6390f8a6b7dSJakub Kuderski 
6400f8a6b7dSJakub Kuderski     DumpNotifications dumpListener;
6410f8a6b7dSJakub Kuderski     walkAndApplyPatterns(getOperation(), std::move(patterns),
6420f8a6b7dSJakub Kuderski                          dumpNotifications ? &dumpListener : nullptr);
6430f8a6b7dSJakub Kuderski   }
6440f8a6b7dSJakub Kuderski 
6450f8a6b7dSJakub Kuderski   Option<bool> dumpNotifications{
6460f8a6b7dSJakub Kuderski       *this, "dump-notifications",
6470f8a6b7dSJakub Kuderski       llvm::cl::desc("Print rewrite listener notifications"),
6480f8a6b7dSJakub Kuderski       llvm::cl::init(false)};
6490f8a6b7dSJakub Kuderski };
6500f8a6b7dSJakub Kuderski 
651be0a7e9fSMehdi Amini } // namespace
652fec6c5acSUday Bondhugula 
653fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
654fec6c5acSUday Bondhugula // ReturnType Driver.
655fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
656fec6c5acSUday Bondhugula 
657fec6c5acSUday Bondhugula namespace {
658fec6c5acSUday Bondhugula // Generate ops for each instance where the type can be successfully inferred.
659fec6c5acSUday Bondhugula template <typename OpTy>
660fec6c5acSUday Bondhugula static void invokeCreateWithInferredReturnType(Operation *op) {
661fec6c5acSUday Bondhugula   auto *context = op->getContext();
66258ceae95SRiver Riddle   auto fop = op->getParentOfType<func::FuncOp>();
663fec6c5acSUday Bondhugula   auto location = UnknownLoc::get(context);
664fec6c5acSUday Bondhugula   OpBuilder b(op);
665fec6c5acSUday Bondhugula   b.setInsertionPointAfter(op);
666fec6c5acSUday Bondhugula 
667fec6c5acSUday Bondhugula   // Use permutations of 2 args as operands.
668fec6c5acSUday Bondhugula   assert(fop.getNumArguments() >= 2);
669fec6c5acSUday Bondhugula   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
670fec6c5acSUday Bondhugula     for (int j = 0; j < e; ++j) {
671fec6c5acSUday Bondhugula       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
672fec6c5acSUday Bondhugula       SmallVector<Type, 2> inferredReturnTypes;
6735eae715aSJacques Pienaar       if (succeeded(OpTy::inferReturnTypes(
674bbe5bf17SMehdi Amini               context, std::nullopt, values, op->getDiscardableAttrDictionary(),
6755e118f93SMehdi Amini               op->getPropertiesStorage(), op->getRegions(),
6765e118f93SMehdi Amini               inferredReturnTypes))) {
677fec6c5acSUday Bondhugula         OperationState state(location, OpTy::getOperationName());
6789db53a18SRiver Riddle         // TODO: Expand to regions.
679bb1d976fSAlex Zinenko         OpTy::build(b, state, values, op->getAttrs());
68014ecafd0SChia-hung Duan         (void)b.create(state);
681fec6c5acSUday Bondhugula       }
682fec6c5acSUday Bondhugula     }
683fec6c5acSUday Bondhugula   }
684fec6c5acSUday Bondhugula }
685fec6c5acSUday Bondhugula 
686fec6c5acSUday Bondhugula static void reifyReturnShape(Operation *op) {
687fec6c5acSUday Bondhugula   OpBuilder b(op);
688fec6c5acSUday Bondhugula 
689fec6c5acSUday Bondhugula   // Use permutations of 2 args as operands.
690fec6c5acSUday Bondhugula   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
691fec6c5acSUday Bondhugula   SmallVector<Value, 2> shapes;
692851d02f6SWenyi Zhao   if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
6939b051703SMaheshRavishankar       !llvm::hasSingleElement(shapes))
694fec6c5acSUday Bondhugula     return;
69589de9cc8SMehdi Amini   for (const auto &it : llvm::enumerate(shapes)) {
696fec6c5acSUday Bondhugula     op->emitRemark() << "value " << it.index() << ": "
697fec6c5acSUday Bondhugula                      << it.value().getDefiningOp();
698fec6c5acSUday Bondhugula   }
6999b051703SMaheshRavishankar }
700fec6c5acSUday Bondhugula 
70180aca1eaSRiver Riddle struct TestReturnTypeDriver
70258ceae95SRiver Riddle     : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> {
7035e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver)
7045e50dd04SRiver Riddle 
705e2310704SJulian Gross   void getDependentDialects(DialectRegistry &registry) const override {
706c0a6318dSMatthias Springer     registry.insert<tensor::TensorDialect>();
707e2310704SJulian Gross   }
708b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-return-type"; }
709b5e22e6dSMehdi Amini   StringRef getDescription() const final { return "Run return type functions"; }
710e2310704SJulian Gross 
71141574554SRiver Riddle   void runOnOperation() override {
71241574554SRiver Riddle     if (getOperation().getName() == "testCreateFunctions") {
713fec6c5acSUday Bondhugula       std::vector<Operation *> ops;
714fec6c5acSUday Bondhugula       // Collect ops to avoid triggering on inserted ops.
71541574554SRiver Riddle       for (auto &op : getOperation().getBody().front())
716fec6c5acSUday Bondhugula         ops.push_back(&op);
717fec6c5acSUday Bondhugula       // Generate test patterns for each, but skip terminator.
718984b800aSserge-sans-paille       for (auto *op : llvm::ArrayRef(ops).drop_back()) {
719fec6c5acSUday Bondhugula         // Test create method of each of the Op classes below. The resultant
720fec6c5acSUday Bondhugula         // output would be in reverse order underneath `op` from which
721fec6c5acSUday Bondhugula         // the attributes and regions are used.
722fec6c5acSUday Bondhugula         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
7235267ed05SAmanda Tang         invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>(
7245267ed05SAmanda Tang             op);
725fec6c5acSUday Bondhugula         invokeCreateWithInferredReturnType<
726fec6c5acSUday Bondhugula             OpWithShapedTypeInferTypeInterfaceOp>(op);
727fec6c5acSUday Bondhugula       };
728fec6c5acSUday Bondhugula       return;
729fec6c5acSUday Bondhugula     }
73041574554SRiver Riddle     if (getOperation().getName() == "testReifyFunctions") {
731fec6c5acSUday Bondhugula       std::vector<Operation *> ops;
732fec6c5acSUday Bondhugula       // Collect ops to avoid triggering on inserted ops.
73341574554SRiver Riddle       for (auto &op : getOperation().getBody().front())
734fec6c5acSUday Bondhugula         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
735fec6c5acSUday Bondhugula           ops.push_back(&op);
736fec6c5acSUday Bondhugula       // Generate test patterns for each, but skip terminator.
737fec6c5acSUday Bondhugula       for (auto *op : ops)
738fec6c5acSUday Bondhugula         reifyReturnShape(op);
739fec6c5acSUday Bondhugula     }
740fec6c5acSUday Bondhugula   }
741fec6c5acSUday Bondhugula };
742be0a7e9fSMehdi Amini } // namespace
743fec6c5acSUday Bondhugula 
7449ba37b3bSJacques Pienaar namespace {
7459ba37b3bSJacques Pienaar struct TestDerivedAttributeDriver
74658ceae95SRiver Riddle     : public PassWrapper<TestDerivedAttributeDriver,
74758ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
7485e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver)
7495e50dd04SRiver Riddle 
750b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-derived-attr"; }
751b5e22e6dSMehdi Amini   StringRef getDescription() const final {
752b5e22e6dSMehdi Amini     return "Run test derived attributes";
753b5e22e6dSMehdi Amini   }
75441574554SRiver Riddle   void runOnOperation() override;
7559ba37b3bSJacques Pienaar };
756be0a7e9fSMehdi Amini } // namespace
7579ba37b3bSJacques Pienaar 
75841574554SRiver Riddle void TestDerivedAttributeDriver::runOnOperation() {
75941574554SRiver Riddle   getOperation().walk([](DerivedAttributeOpInterface dOp) {
7609ba37b3bSJacques Pienaar     auto dAttr = dOp.materializeDerivedAttributes();
7619ba37b3bSJacques Pienaar     if (!dAttr)
7629ba37b3bSJacques Pienaar       return;
7639ba37b3bSJacques Pienaar     for (auto d : dAttr)
7640c7890c8SRiver Riddle       dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
7659ba37b3bSJacques Pienaar   });
7669ba37b3bSJacques Pienaar }
7679ba37b3bSJacques Pienaar 
768fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
769fec6c5acSUday Bondhugula // Legalization Driver.
770fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
771fec6c5acSUday Bondhugula 
772fec6c5acSUday Bondhugula namespace {
773fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
774fec6c5acSUday Bondhugula // Region-Block Rewrite Testing
775fec6c5acSUday Bondhugula 
776684a5a30SMatthias Springer /// This pattern applies a signature conversion to a block inside a detached
777684a5a30SMatthias Springer /// region.
778684a5a30SMatthias Springer struct TestDetachedSignatureConversion : public ConversionPattern {
779684a5a30SMatthias Springer   TestDetachedSignatureConversion(MLIRContext *ctx)
780684a5a30SMatthias Springer       : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1,
781684a5a30SMatthias Springer                           ctx) {}
782684a5a30SMatthias Springer 
783684a5a30SMatthias Springer   LogicalResult
784684a5a30SMatthias Springer   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
785684a5a30SMatthias Springer                   ConversionPatternRewriter &rewriter) const final {
786684a5a30SMatthias Springer     if (op->getNumRegions() != 1)
787684a5a30SMatthias Springer       return failure();
788412e30b2SMatthias Springer     OperationState state(op->getLoc(), "test.legal_op", operands,
789684a5a30SMatthias Springer                          op->getResultTypes(), {}, BlockRange());
790684a5a30SMatthias Springer     Region *newRegion = state.addRegion();
791684a5a30SMatthias Springer     rewriter.inlineRegionBefore(op->getRegion(0), *newRegion,
792684a5a30SMatthias Springer                                 newRegion->begin());
793684a5a30SMatthias Springer     TypeConverter::SignatureConversion result(newRegion->getNumArguments());
794684a5a30SMatthias Springer     for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i)
795684a5a30SMatthias Springer       result.addInputs(i, rewriter.getF64Type());
796684a5a30SMatthias Springer     rewriter.applySignatureConversion(&newRegion->front(), result);
797684a5a30SMatthias Springer     Operation *newOp = rewriter.create(state);
798684a5a30SMatthias Springer     rewriter.replaceOp(op, newOp->getResults());
799684a5a30SMatthias Springer     return success();
800684a5a30SMatthias Springer   }
801684a5a30SMatthias Springer };
802684a5a30SMatthias Springer 
803fec6c5acSUday Bondhugula /// This pattern is a simple pattern that inlines the first region of a given
804fec6c5acSUday Bondhugula /// operation into the parent region.
805fec6c5acSUday Bondhugula struct TestRegionRewriteBlockMovement : public ConversionPattern {
806fec6c5acSUday Bondhugula   TestRegionRewriteBlockMovement(MLIRContext *ctx)
807fec6c5acSUday Bondhugula       : ConversionPattern("test.region", 1, ctx) {}
808fec6c5acSUday Bondhugula 
809fec6c5acSUday Bondhugula   LogicalResult
810fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
811fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
812fec6c5acSUday Bondhugula     // Inline this region into the parent region.
813fec6c5acSUday Bondhugula     auto &parentRegion = *op->getParentRegion();
814b0750e2dSTres Popp     auto &opRegion = op->getRegion(0);
815830b9b07SMehdi Amini     if (op->getDiscardableAttr("legalizer.should_clone"))
816b0750e2dSTres Popp       rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
817fec6c5acSUday Bondhugula     else
818b0750e2dSTres Popp       rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
819b0750e2dSTres Popp 
820830b9b07SMehdi Amini     if (op->getDiscardableAttr("legalizer.erase_old_blocks")) {
821b0750e2dSTres Popp       while (!opRegion.empty())
822b0750e2dSTres Popp         rewriter.eraseBlock(&opRegion.front());
823b0750e2dSTres Popp     }
824fec6c5acSUday Bondhugula 
825fec6c5acSUday Bondhugula     // Drop this operation.
826fec6c5acSUday Bondhugula     rewriter.eraseOp(op);
827fec6c5acSUday Bondhugula     return success();
828fec6c5acSUday Bondhugula   }
829fec6c5acSUday Bondhugula };
830fec6c5acSUday Bondhugula /// This pattern is a simple pattern that generates a region containing an
831fec6c5acSUday Bondhugula /// illegal operation.
832fec6c5acSUday Bondhugula struct TestRegionRewriteUndo : public RewritePattern {
833fec6c5acSUday Bondhugula   TestRegionRewriteUndo(MLIRContext *ctx)
834fec6c5acSUday Bondhugula       : RewritePattern("test.region_builder", 1, ctx) {}
835fec6c5acSUday Bondhugula 
836fec6c5acSUday Bondhugula   LogicalResult matchAndRewrite(Operation *op,
837fec6c5acSUday Bondhugula                                 PatternRewriter &rewriter) const final {
838fec6c5acSUday Bondhugula     // Create the region operation with an entry block containing arguments.
839fec6c5acSUday Bondhugula     OperationState newRegion(op->getLoc(), "test.region");
840fec6c5acSUday Bondhugula     newRegion.addRegion();
84114ecafd0SChia-hung Duan     auto *regionOp = rewriter.create(newRegion);
842fec6c5acSUday Bondhugula     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
843e084679fSRiver Riddle     entryBlock->addArgument(rewriter.getIntegerType(64),
844e084679fSRiver Riddle                             rewriter.getUnknownLoc());
845fec6c5acSUday Bondhugula 
846fec6c5acSUday Bondhugula     // Add an explicitly illegal operation to ensure the conversion fails.
847fec6c5acSUday Bondhugula     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
848fec6c5acSUday Bondhugula     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
849fec6c5acSUday Bondhugula 
850fec6c5acSUday Bondhugula     // Drop this operation.
851fec6c5acSUday Bondhugula     rewriter.eraseOp(op);
852fec6c5acSUday Bondhugula     return success();
853fec6c5acSUday Bondhugula   }
854fec6c5acSUday Bondhugula };
855f27f1e8cSAlex Zinenko /// A simple pattern that creates a block at the end of the parent region of the
856f27f1e8cSAlex Zinenko /// matched operation.
857f27f1e8cSAlex Zinenko struct TestCreateBlock : public RewritePattern {
858f27f1e8cSAlex Zinenko   TestCreateBlock(MLIRContext *ctx)
859f27f1e8cSAlex Zinenko       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
860f27f1e8cSAlex Zinenko 
861f27f1e8cSAlex Zinenko   LogicalResult matchAndRewrite(Operation *op,
862f27f1e8cSAlex Zinenko                                 PatternRewriter &rewriter) const final {
863f27f1e8cSAlex Zinenko     Region &region = *op->getParentRegion();
864f27f1e8cSAlex Zinenko     Type i32Type = rewriter.getIntegerType(32);
865e084679fSRiver Riddle     Location loc = op->getLoc();
866e084679fSRiver Riddle     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
867e084679fSRiver Riddle     rewriter.create<TerminatorOp>(loc);
86871d50c89SMatthias Springer     rewriter.eraseOp(op);
869f27f1e8cSAlex Zinenko     return success();
870f27f1e8cSAlex Zinenko   }
871f27f1e8cSAlex Zinenko };
872f27f1e8cSAlex Zinenko 
873a23d0559SKazuaki Ishizaki /// A simple pattern that creates a block containing an invalid operation in
874f27f1e8cSAlex Zinenko /// order to trigger the block creation undo mechanism.
875f27f1e8cSAlex Zinenko struct TestCreateIllegalBlock : public RewritePattern {
876f27f1e8cSAlex Zinenko   TestCreateIllegalBlock(MLIRContext *ctx)
877f27f1e8cSAlex Zinenko       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
878f27f1e8cSAlex Zinenko 
879f27f1e8cSAlex Zinenko   LogicalResult matchAndRewrite(Operation *op,
880f27f1e8cSAlex Zinenko                                 PatternRewriter &rewriter) const final {
881f27f1e8cSAlex Zinenko     Region &region = *op->getParentRegion();
882f27f1e8cSAlex Zinenko     Type i32Type = rewriter.getIntegerType(32);
883e084679fSRiver Riddle     Location loc = op->getLoc();
884e084679fSRiver Riddle     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
885f27f1e8cSAlex Zinenko     // Create an illegal op to ensure the conversion fails.
886e084679fSRiver Riddle     rewriter.create<ILLegalOpF>(loc, i32Type);
887e084679fSRiver Riddle     rewriter.create<TerminatorOp>(loc);
88871d50c89SMatthias Springer     rewriter.eraseOp(op);
889f27f1e8cSAlex Zinenko     return success();
890f27f1e8cSAlex Zinenko   }
891f27f1e8cSAlex Zinenko };
892fec6c5acSUday Bondhugula 
8930816de16SRiver Riddle /// A simple pattern that tests the undo mechanism when replacing the uses of a
8940816de16SRiver Riddle /// block argument.
8950816de16SRiver Riddle struct TestUndoBlockArgReplace : public ConversionPattern {
8960816de16SRiver Riddle   TestUndoBlockArgReplace(MLIRContext *ctx)
8970816de16SRiver Riddle       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
8980816de16SRiver Riddle 
8990816de16SRiver Riddle   LogicalResult
9000816de16SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
9010816de16SRiver Riddle                   ConversionPatternRewriter &rewriter) const final {
9020816de16SRiver Riddle     auto illegalOp =
9030816de16SRiver Riddle         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
904e2b71610SRahul Joshi     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
905a0b21046SRahul Kayaith                                         illegalOp->getResult(0));
9065fcf907bSMatthias Springer     rewriter.modifyOpInPlace(op, [] {});
9070816de16SRiver Riddle     return success();
9080816de16SRiver Riddle   }
9090816de16SRiver Riddle };
9100816de16SRiver Riddle 
9118f4cd2c7SMatthias Springer /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
9128f4cd2c7SMatthias Springer /// This is to test the rollback logic.
9138f4cd2c7SMatthias Springer struct TestUndoMoveOpBefore : public ConversionPattern {
9148f4cd2c7SMatthias Springer   TestUndoMoveOpBefore(MLIRContext *ctx)
9158f4cd2c7SMatthias Springer       : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
9168f4cd2c7SMatthias Springer 
9178f4cd2c7SMatthias Springer   LogicalResult
9188f4cd2c7SMatthias Springer   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
9198f4cd2c7SMatthias Springer                   ConversionPatternRewriter &rewriter) const override {
9208f4cd2c7SMatthias Springer     rewriter.moveOpBefore(op, op->getParentOp());
9218f4cd2c7SMatthias Springer     // Replace with an illegal op to ensure the conversion fails.
9228f4cd2c7SMatthias Springer     rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
9238f4cd2c7SMatthias Springer     return success();
9248f4cd2c7SMatthias Springer   }
9258f4cd2c7SMatthias Springer };
9268f4cd2c7SMatthias Springer 
927df48026bSAlex Zinenko /// A rewrite pattern that tests the undo mechanism when erasing a block.
928df48026bSAlex Zinenko struct TestUndoBlockErase : public ConversionPattern {
929df48026bSAlex Zinenko   TestUndoBlockErase(MLIRContext *ctx)
930df48026bSAlex Zinenko       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
931df48026bSAlex Zinenko 
932df48026bSAlex Zinenko   LogicalResult
933df48026bSAlex Zinenko   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
934df48026bSAlex Zinenko                   ConversionPatternRewriter &rewriter) const final {
935df48026bSAlex Zinenko     Block *secondBlock = &*std::next(op->getRegion(0).begin());
936df48026bSAlex Zinenko     rewriter.setInsertionPointToStart(secondBlock);
937df48026bSAlex Zinenko     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
938df48026bSAlex Zinenko     rewriter.eraseBlock(secondBlock);
9395fcf907bSMatthias Springer     rewriter.modifyOpInPlace(op, [] {});
940df48026bSAlex Zinenko     return success();
941df48026bSAlex Zinenko   }
942df48026bSAlex Zinenko };
943df48026bSAlex Zinenko 
9443a70335bSMatthias Springer /// A pattern that modifies a property in-place, but keeps the op illegal.
9453a70335bSMatthias Springer struct TestUndoPropertiesModification : public ConversionPattern {
9463a70335bSMatthias Springer   TestUndoPropertiesModification(MLIRContext *ctx)
9473a70335bSMatthias Springer       : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {}
9483a70335bSMatthias Springer   LogicalResult
9493a70335bSMatthias Springer   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
9503a70335bSMatthias Springer                   ConversionPatternRewriter &rewriter) const final {
9513a70335bSMatthias Springer     if (!op->hasAttr("modify_inplace"))
9523a70335bSMatthias Springer       return failure();
9533a70335bSMatthias Springer     rewriter.modifyOpInPlace(
9543a70335bSMatthias Springer         op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); });
9553a70335bSMatthias Springer     return success();
9563a70335bSMatthias Springer   }
9573a70335bSMatthias Springer };
9583a70335bSMatthias Springer 
959fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
960fec6c5acSUday Bondhugula // Type-Conversion Rewrite Testing
961fec6c5acSUday Bondhugula 
962fec6c5acSUday Bondhugula /// This patterns erases a region operation that has had a type conversion.
963fec6c5acSUday Bondhugula struct TestDropOpSignatureConversion : public ConversionPattern {
964ce254598SMatthias Springer   TestDropOpSignatureConversion(MLIRContext *ctx,
965ce254598SMatthias Springer                                 const TypeConverter &converter)
96676f3c2f3SRiver Riddle       : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
967fec6c5acSUday Bondhugula   LogicalResult
968fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
969fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const override {
970fec6c5acSUday Bondhugula     Region &region = op->getRegion(0);
971fec6c5acSUday Bondhugula     Block *entry = &region.front();
972fec6c5acSUday Bondhugula 
973fec6c5acSUday Bondhugula     // Convert the original entry arguments.
974ce254598SMatthias Springer     const TypeConverter &converter = *getTypeConverter();
975fec6c5acSUday Bondhugula     TypeConverter::SignatureConversion result(entry->getNumArguments());
9768d67d187SRiver Riddle     if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
9778d67d187SRiver Riddle                                               result)) ||
9788d67d187SRiver Riddle         failed(rewriter.convertRegionTypes(&region, converter, &result)))
979fec6c5acSUday Bondhugula       return failure();
980fec6c5acSUday Bondhugula 
981fec6c5acSUday Bondhugula     // Convert the region signature and just drop the operation.
982fec6c5acSUday Bondhugula     rewriter.eraseOp(op);
983fec6c5acSUday Bondhugula     return success();
984fec6c5acSUday Bondhugula   }
985fec6c5acSUday Bondhugula };
986fec6c5acSUday Bondhugula /// This pattern simply updates the operands of the given operation.
987fec6c5acSUday Bondhugula struct TestPassthroughInvalidOp : public ConversionPattern {
9883cc311abSMatthias Springer   TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
9893cc311abSMatthias Springer       : ConversionPattern(converter, "test.invalid", 1, ctx) {}
990fec6c5acSUday Bondhugula   LogicalResult
9919df63b26SMatthias Springer   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
992fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
9939df63b26SMatthias Springer     SmallVector<Value> flattened;
9949df63b26SMatthias Springer     for (auto it : llvm::enumerate(operands)) {
9959df63b26SMatthias Springer       ValueRange range = it.value();
9969df63b26SMatthias Springer       if (range.size() == 1) {
9979df63b26SMatthias Springer         flattened.push_back(range.front());
9989df63b26SMatthias Springer         continue;
9999df63b26SMatthias Springer       }
10009df63b26SMatthias Springer 
10019df63b26SMatthias Springer       // This is a 1:N replacement. Insert a test.cast op. (That's what the
10029df63b26SMatthias Springer       // argument materialization used to do.)
10039df63b26SMatthias Springer       flattened.push_back(
10049df63b26SMatthias Springer           rewriter
10059df63b26SMatthias Springer               .create<TestCastOp>(op->getLoc(),
10069df63b26SMatthias Springer                                   op->getOperand(it.index()).getType(), range)
10079df63b26SMatthias Springer               .getResult());
10089df63b26SMatthias Springer     }
10099df63b26SMatthias Springer     rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened,
10101a36588eSKazu Hirata                                              std::nullopt);
1011fec6c5acSUday Bondhugula     return success();
1012fec6c5acSUday Bondhugula   }
1013fec6c5acSUday Bondhugula };
10143815f478SMatthias Springer /// Replace with valid op, but simply drop the operands. This is used in a
10153815f478SMatthias Springer /// regression where we used to generate circular unrealized_conversion_cast
10163815f478SMatthias Springer /// ops.
10173815f478SMatthias Springer struct TestDropAndReplaceInvalidOp : public ConversionPattern {
10183815f478SMatthias Springer   TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
10193815f478SMatthias Springer       : ConversionPattern(converter,
10203815f478SMatthias Springer                           "test.drop_operands_and_replace_with_valid", 1, ctx) {
10213815f478SMatthias Springer   }
10223815f478SMatthias Springer   LogicalResult
10233815f478SMatthias Springer   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
10243815f478SMatthias Springer                   ConversionPatternRewriter &rewriter) const final {
10253815f478SMatthias Springer     rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(),
10263815f478SMatthias Springer                                              std::nullopt);
10273815f478SMatthias Springer     return success();
10283815f478SMatthias Springer   }
10293815f478SMatthias Springer };
1030fec6c5acSUday Bondhugula /// This pattern handles the case of a split return value.
1031fec6c5acSUday Bondhugula struct TestSplitReturnType : public ConversionPattern {
1032fec6c5acSUday Bondhugula   TestSplitReturnType(MLIRContext *ctx)
1033fec6c5acSUday Bondhugula       : ConversionPattern("test.return", 1, ctx) {}
1034fec6c5acSUday Bondhugula   LogicalResult
10359df63b26SMatthias Springer   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1036fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
1037fec6c5acSUday Bondhugula     // Check for a return of F32.
1038fec6c5acSUday Bondhugula     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
1039fec6c5acSUday Bondhugula       return failure();
10409df63b26SMatthias Springer     rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]);
1041fec6c5acSUday Bondhugula     return success();
1042fec6c5acSUday Bondhugula   }
1043fec6c5acSUday Bondhugula };
1044fec6c5acSUday Bondhugula 
1045fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
1046fec6c5acSUday Bondhugula // Multi-Level Type-Conversion Rewrite Testing
1047fec6c5acSUday Bondhugula struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
1048fec6c5acSUday Bondhugula   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
1049fec6c5acSUday Bondhugula       : ConversionPattern("test.type_producer", 1, ctx) {}
1050fec6c5acSUday Bondhugula   LogicalResult
1051fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1052fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
1053fec6c5acSUday Bondhugula     // If the type is I32, change the type to F32.
1054fec6c5acSUday Bondhugula     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
1055fec6c5acSUday Bondhugula       return failure();
1056fec6c5acSUday Bondhugula     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
1057fec6c5acSUday Bondhugula     return success();
1058fec6c5acSUday Bondhugula   }
1059fec6c5acSUday Bondhugula };
1060fec6c5acSUday Bondhugula struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
1061fec6c5acSUday Bondhugula   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
1062fec6c5acSUday Bondhugula       : ConversionPattern("test.type_producer", 1, ctx) {}
1063fec6c5acSUday Bondhugula   LogicalResult
1064fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1065fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
1066fec6c5acSUday Bondhugula     // If the type is F32, change the type to F64.
1067fec6c5acSUday Bondhugula     if (!Type(*op->result_type_begin()).isF32())
1068fec6c5acSUday Bondhugula       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
1069fec6c5acSUday Bondhugula     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
1070fec6c5acSUday Bondhugula     return success();
1071fec6c5acSUday Bondhugula   }
1072fec6c5acSUday Bondhugula };
1073fec6c5acSUday Bondhugula struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
1074fec6c5acSUday Bondhugula   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
1075fec6c5acSUday Bondhugula       : ConversionPattern("test.type_producer", 10, ctx) {}
1076fec6c5acSUday Bondhugula   LogicalResult
1077fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1078fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
1079fec6c5acSUday Bondhugula     // Always convert to B16, even though it is not a legal type. This tests
1080fec6c5acSUday Bondhugula     // that values are unmapped correctly.
1081fec6c5acSUday Bondhugula     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
1082fec6c5acSUday Bondhugula     return success();
1083fec6c5acSUday Bondhugula   }
1084fec6c5acSUday Bondhugula };
1085fec6c5acSUday Bondhugula struct TestUpdateConsumerType : public ConversionPattern {
1086fec6c5acSUday Bondhugula   TestUpdateConsumerType(MLIRContext *ctx)
1087fec6c5acSUday Bondhugula       : ConversionPattern("test.type_consumer", 1, ctx) {}
1088fec6c5acSUday Bondhugula   LogicalResult
1089fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1090fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
1091fec6c5acSUday Bondhugula     // Verify that the incoming operand has been successfully remapped to F64.
1092fec6c5acSUday Bondhugula     if (!operands[0].getType().isF64())
1093fec6c5acSUday Bondhugula       return failure();
1094fec6c5acSUday Bondhugula     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
1095fec6c5acSUday Bondhugula     return success();
1096fec6c5acSUday Bondhugula   }
1097fec6c5acSUday Bondhugula };
1098fec6c5acSUday Bondhugula 
1099fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
1100fec6c5acSUday Bondhugula // Non-Root Replacement Rewrite Testing
1101fec6c5acSUday Bondhugula /// This pattern generates an invalid operation, but replaces it before the
1102fec6c5acSUday Bondhugula /// pattern is finished. This checks that we don't need to legalize the
1103fec6c5acSUday Bondhugula /// temporary op.
1104fec6c5acSUday Bondhugula struct TestNonRootReplacement : public RewritePattern {
1105fec6c5acSUday Bondhugula   TestNonRootReplacement(MLIRContext *ctx)
1106fec6c5acSUday Bondhugula       : RewritePattern("test.replace_non_root", 1, ctx) {}
1107fec6c5acSUday Bondhugula 
1108fec6c5acSUday Bondhugula   LogicalResult matchAndRewrite(Operation *op,
1109fec6c5acSUday Bondhugula                                 PatternRewriter &rewriter) const final {
1110fec6c5acSUday Bondhugula     auto resultType = *op->result_type_begin();
1111fec6c5acSUday Bondhugula     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
1112fec6c5acSUday Bondhugula     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
1113fec6c5acSUday Bondhugula 
111471d50c89SMatthias Springer     rewriter.replaceOp(illegalOp, legalOp);
111571d50c89SMatthias Springer     rewriter.replaceOp(op, illegalOp);
1116fec6c5acSUday Bondhugula     return success();
1117fec6c5acSUday Bondhugula   }
1118fec6c5acSUday Bondhugula };
1119bd1ccfe6SRiver Riddle 
1120bd1ccfe6SRiver Riddle //===----------------------------------------------------------------------===//
1121bd1ccfe6SRiver Riddle // Recursive Rewrite Testing
1122bd1ccfe6SRiver Riddle /// This pattern is applied to the same operation multiple times, but has a
1123bd1ccfe6SRiver Riddle /// bounded recursion.
1124bd1ccfe6SRiver Riddle struct TestBoundedRecursiveRewrite
1125bd1ccfe6SRiver Riddle     : public OpRewritePattern<TestRecursiveRewriteOp> {
11262257e4a7SRiver Riddle   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
11272257e4a7SRiver Riddle 
11282257e4a7SRiver Riddle   void initialize() {
1129b99bd771SRiver Riddle     // The conversion target handles bounding the recursion of this pattern.
1130b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1131b99bd771SRiver Riddle   }
1132bd1ccfe6SRiver Riddle 
1133bd1ccfe6SRiver Riddle   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
1134bd1ccfe6SRiver Riddle                                 PatternRewriter &rewriter) const final {
1135bd1ccfe6SRiver Riddle     // Decrement the depth of the op in-place.
11365fcf907bSMatthias Springer     rewriter.modifyOpInPlace(op, [&] {
11376a994233SJacques Pienaar       op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
1138bd1ccfe6SRiver Riddle     });
1139bd1ccfe6SRiver Riddle     return success();
1140bd1ccfe6SRiver Riddle   }
1141bd1ccfe6SRiver Riddle };
11425d5df06aSAlex Zinenko 
11435d5df06aSAlex Zinenko struct TestNestedOpCreationUndoRewrite
11445d5df06aSAlex Zinenko     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
11455d5df06aSAlex Zinenko   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
11465d5df06aSAlex Zinenko 
11475d5df06aSAlex Zinenko   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
11485d5df06aSAlex Zinenko                                 PatternRewriter &rewriter) const final {
11495d5df06aSAlex Zinenko     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
11505d5df06aSAlex Zinenko     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
11515d5df06aSAlex Zinenko     return success();
11525d5df06aSAlex Zinenko   };
11535d5df06aSAlex Zinenko };
1154a360a978SMehdi Amini 
1155a360a978SMehdi Amini // This pattern matches `test.blackhole` and delete this op and its producer.
1156a360a978SMehdi Amini struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
1157a360a978SMehdi Amini   using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
1158a360a978SMehdi Amini 
1159a360a978SMehdi Amini   LogicalResult matchAndRewrite(BlackHoleOp op,
1160a360a978SMehdi Amini                                 PatternRewriter &rewriter) const final {
1161a360a978SMehdi Amini     Operation *producer = op.getOperand().getDefiningOp();
1162a360a978SMehdi Amini     // Always erase the user before the producer, the framework should handle
1163a360a978SMehdi Amini     // this correctly.
1164a360a978SMehdi Amini     rewriter.eraseOp(op);
1165a360a978SMehdi Amini     rewriter.eraseOp(producer);
1166a360a978SMehdi Amini     return success();
1167a360a978SMehdi Amini   };
1168a360a978SMehdi Amini };
1169ec03bbe8SVladislav Vinogradov 
1170ec03bbe8SVladislav Vinogradov // This pattern replaces explicitly illegal op with explicitly legal op,
1171ec03bbe8SVladislav Vinogradov // but in addition creates unregistered operation.
1172ec03bbe8SVladislav Vinogradov struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
1173ec03bbe8SVladislav Vinogradov   using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
1174ec03bbe8SVladislav Vinogradov 
1175ec03bbe8SVladislav Vinogradov   LogicalResult matchAndRewrite(ILLegalOpG op,
1176ec03bbe8SVladislav Vinogradov                                 PatternRewriter &rewriter) const final {
1177ec03bbe8SVladislav Vinogradov     IntegerAttr attr = rewriter.getI32IntegerAttr(0);
11788e123ca6SRiver Riddle     Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
1179ec03bbe8SVladislav Vinogradov     rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
1180ec03bbe8SVladislav Vinogradov     return success();
1181ec03bbe8SVladislav Vinogradov   };
1182ec03bbe8SVladislav Vinogradov };
11833815f478SMatthias Springer 
11843815f478SMatthias Springer class TestEraseOp : public ConversionPattern {
11853815f478SMatthias Springer public:
11863815f478SMatthias Springer   TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {}
11873815f478SMatthias Springer   LogicalResult
11883815f478SMatthias Springer   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
11893815f478SMatthias Springer                   ConversionPatternRewriter &rewriter) const final {
11903815f478SMatthias Springer     // Erase op without replacements.
11913815f478SMatthias Springer     rewriter.eraseOp(op);
11923815f478SMatthias Springer     return success();
11933815f478SMatthias Springer   }
11943815f478SMatthias Springer };
11953815f478SMatthias Springer 
11969df63b26SMatthias Springer /// This pattern matches a test.duplicate_block_args op and duplicates all
11979df63b26SMatthias Springer /// block arguments.
11989df63b26SMatthias Springer class TestDuplicateBlockArgs
11999df63b26SMatthias Springer     : public OpConversionPattern<DuplicateBlockArgsOp> {
12009df63b26SMatthias Springer   using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern;
12019df63b26SMatthias Springer 
12029df63b26SMatthias Springer   LogicalResult
12039df63b26SMatthias Springer   matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor,
12049df63b26SMatthias Springer                   ConversionPatternRewriter &rewriter) const override {
12059df63b26SMatthias Springer     if (op.getIsLegal())
12069df63b26SMatthias Springer       return failure();
12079df63b26SMatthias Springer     rewriter.startOpModification(op);
12089df63b26SMatthias Springer     Block *body = &op.getBody().front();
12099df63b26SMatthias Springer     TypeConverter::SignatureConversion result(body->getNumArguments());
12109df63b26SMatthias Springer     for (auto it : llvm::enumerate(body->getArgumentTypes()))
12119df63b26SMatthias Springer       result.addInputs(it.index(), {it.value(), it.value()});
12129df63b26SMatthias Springer     rewriter.applySignatureConversion(body, result, getTypeConverter());
12139df63b26SMatthias Springer     op.setIsLegal(true);
12149df63b26SMatthias Springer     rewriter.finalizeOpModification(op);
12159df63b26SMatthias Springer     return success();
12169df63b26SMatthias Springer   }
12179df63b26SMatthias Springer };
12189df63b26SMatthias Springer 
12199df63b26SMatthias Springer /// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid
12209df63b26SMatthias Springer /// op. The pattern supports 1:N replacements and forwards the replacement
12219df63b26SMatthias Springer /// values of the single operand as test.valid operands.
12229df63b26SMatthias Springer class TestRepetitive1ToNConsumer : public ConversionPattern {
12239df63b26SMatthias Springer public:
12249df63b26SMatthias Springer   TestRepetitive1ToNConsumer(MLIRContext *ctx)
12259df63b26SMatthias Springer       : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {}
12269df63b26SMatthias Springer   LogicalResult
12279df63b26SMatthias Springer   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
12289df63b26SMatthias Springer                   ConversionPatternRewriter &rewriter) const final {
12299df63b26SMatthias Springer     // A single operand is expected.
12309df63b26SMatthias Springer     if (op->getNumOperands() != 1)
12319df63b26SMatthias Springer       return failure();
12329df63b26SMatthias Springer     rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front());
12339df63b26SMatthias Springer     return success();
12349df63b26SMatthias Springer   }
12359df63b26SMatthias Springer };
12369df63b26SMatthias Springer 
1237412e30b2SMatthias Springer /// A pattern that tests two back-to-back 1 -> 2 op replacements.
1238412e30b2SMatthias Springer class TestMultiple1ToNReplacement : public ConversionPattern {
1239412e30b2SMatthias Springer public:
1240412e30b2SMatthias Springer   TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter)
1241412e30b2SMatthias Springer       : ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1,
1242412e30b2SMatthias Springer                           ctx) {}
1243412e30b2SMatthias Springer   LogicalResult
1244412e30b2SMatthias Springer   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1245412e30b2SMatthias Springer                   ConversionPatternRewriter &rewriter) const final {
1246412e30b2SMatthias Springer     // Helper function that replaces the given op with a new op of the given
1247412e30b2SMatthias Springer     // name and doubles each result (1 -> 2 replacement of each result).
1248412e30b2SMatthias Springer     auto replaceWithDoubleResults = [&](Operation *op, StringRef name) {
1249412e30b2SMatthias Springer       SmallVector<Type> types;
1250412e30b2SMatthias Springer       for (Type t : op->getResultTypes()) {
1251412e30b2SMatthias Springer         types.push_back(t);
1252412e30b2SMatthias Springer         types.push_back(t);
1253412e30b2SMatthias Springer       }
1254412e30b2SMatthias Springer       OperationState state(op->getLoc(), name,
1255412e30b2SMatthias Springer                            /*operands=*/{}, types, op->getAttrs());
1256412e30b2SMatthias Springer       auto *newOp = rewriter.create(state);
1257412e30b2SMatthias Springer       SmallVector<ValueRange> repls;
1258412e30b2SMatthias Springer       for (size_t i = 0, e = op->getNumResults(); i < e; ++i)
1259412e30b2SMatthias Springer         repls.push_back(newOp->getResults().slice(2 * i, 2));
1260412e30b2SMatthias Springer       rewriter.replaceOpWithMultiple(op, repls);
1261412e30b2SMatthias Springer       return newOp;
1262412e30b2SMatthias Springer     };
1263412e30b2SMatthias Springer 
1264412e30b2SMatthias Springer     // Replace test.multiple_1_to_n_replacement with test.step_1.
1265412e30b2SMatthias Springer     Operation *repl1 = replaceWithDoubleResults(op, "test.step_1");
1266412e30b2SMatthias Springer     // Now replace test.step_1 with test.legal_op.
1267412e30b2SMatthias Springer     replaceWithDoubleResults(repl1, "test.legal_op");
1268412e30b2SMatthias Springer     return success();
1269412e30b2SMatthias Springer   }
1270412e30b2SMatthias Springer };
1271412e30b2SMatthias Springer 
1272fec6c5acSUday Bondhugula } // namespace
1273fec6c5acSUday Bondhugula 
1274fec6c5acSUday Bondhugula namespace {
1275fec6c5acSUday Bondhugula struct TestTypeConverter : public TypeConverter {
1276fec6c5acSUday Bondhugula   using TypeConverter::TypeConverter;
12775c5dafc5SAlex Zinenko   TestTypeConverter() {
12785c5dafc5SAlex Zinenko     addConversion(convertType);
12794589dd92SRiver Riddle     addSourceMaterialization(materializeCast);
12805c5dafc5SAlex Zinenko   }
1281fec6c5acSUday Bondhugula 
1282fec6c5acSUday Bondhugula   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
1283fec6c5acSUday Bondhugula     // Drop I16 types.
1284fec6c5acSUday Bondhugula     if (t.isSignlessInteger(16))
1285fec6c5acSUday Bondhugula       return success();
1286fec6c5acSUday Bondhugula 
1287fec6c5acSUday Bondhugula     // Convert I64 to F64.
1288fec6c5acSUday Bondhugula     if (t.isSignlessInteger(64)) {
1289*f023da12SMatthias Springer       results.push_back(Float64Type::get(t.getContext()));
1290fec6c5acSUday Bondhugula       return success();
1291fec6c5acSUday Bondhugula     }
1292fec6c5acSUday Bondhugula 
12935c5dafc5SAlex Zinenko     // Convert I42 to I43.
12945c5dafc5SAlex Zinenko     if (t.isInteger(42)) {
12951b97cdf8SRiver Riddle       results.push_back(IntegerType::get(t.getContext(), 43));
12965c5dafc5SAlex Zinenko       return success();
12975c5dafc5SAlex Zinenko     }
12985c5dafc5SAlex Zinenko 
1299fec6c5acSUday Bondhugula     // Split F32 into F16,F16.
1300fec6c5acSUday Bondhugula     if (t.isF32()) {
1301*f023da12SMatthias Springer       results.assign(2, Float16Type::get(t.getContext()));
1302fec6c5acSUday Bondhugula       return success();
1303fec6c5acSUday Bondhugula     }
1304fec6c5acSUday Bondhugula 
130508e6566dSMatthias Springer     // Drop I24 types.
130608e6566dSMatthias Springer     if (t.isInteger(24)) {
130708e6566dSMatthias Springer       return success();
130808e6566dSMatthias Springer     }
130908e6566dSMatthias Springer 
1310fec6c5acSUday Bondhugula     // Otherwise, convert the type directly.
1311fec6c5acSUday Bondhugula     results.push_back(t);
1312fec6c5acSUday Bondhugula     return success();
1313fec6c5acSUday Bondhugula   }
1314fec6c5acSUday Bondhugula 
13155c5dafc5SAlex Zinenko   /// Hook for materializing a conversion. This is necessary because we generate
13165c5dafc5SAlex Zinenko   /// 1->N type mappings.
1317f18c3e4eSMatthias Springer   static Value materializeCast(OpBuilder &builder, Type resultType,
13184589dd92SRiver Riddle                                ValueRange inputs, Location loc) {
13194589dd92SRiver Riddle     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
13205c5dafc5SAlex Zinenko   }
1321fec6c5acSUday Bondhugula };
1322fec6c5acSUday Bondhugula 
1323fec6c5acSUday Bondhugula struct TestLegalizePatternDriver
1324c3a8ac79SMatthias Springer     : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> {
13255e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)
13265e50dd04SRiver Riddle 
1327b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-legalize-patterns"; }
1328b5e22e6dSMehdi Amini   StringRef getDescription() const final {
1329b5e22e6dSMehdi Amini     return "Run test dialect legalization patterns";
1330b5e22e6dSMehdi Amini   }
1331fec6c5acSUday Bondhugula   /// The mode of conversion to use with the driver.
1332fec6c5acSUday Bondhugula   enum class ConversionMode { Analysis, Full, Partial };
1333fec6c5acSUday Bondhugula 
1334fec6c5acSUday Bondhugula   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
1335fec6c5acSUday Bondhugula 
1336ec03bbe8SVladislav Vinogradov   void getDependentDialects(DialectRegistry &registry) const override {
133799fd12f2SMehdi Amini     registry.insert<func::FuncDialect, test::TestDialect>();
1338ec03bbe8SVladislav Vinogradov   }
1339ec03bbe8SVladislav Vinogradov 
1340722f909fSRiver Riddle   void runOnOperation() override {
1341fec6c5acSUday Bondhugula     TestTypeConverter converter;
1342dc4e913bSChris Lattner     mlir::RewritePatternSet patterns(&getContext());
13431d909c9aSChris Lattner     populateWithGenerated(patterns);
13443cc311abSMatthias Springer     patterns
13453cc311abSMatthias Springer         .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1346684a5a30SMatthias Springer              TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
13473cc311abSMatthias Springer              TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
13483cc311abSMatthias Springer              TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
13493cc311abSMatthias Springer              TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
13503cc311abSMatthias Springer              TestNonRootReplacement, TestBoundedRecursiveRewrite,
13513cc311abSMatthias Springer              TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
13523cc311abSMatthias Springer              TestCreateUnregisteredOp, TestUndoMoveOpBefore,
13539df63b26SMatthias Springer              TestUndoPropertiesModification, TestEraseOp,
13549df63b26SMatthias Springer              TestRepetitive1ToNConsumer>(&getContext());
13553cc311abSMatthias Springer     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
1356412e30b2SMatthias Springer                  TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
1357412e30b2SMatthias Springer         &getContext(), converter);
13589df63b26SMatthias Springer     patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
1359ed4749f9SIvan Butygin     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1360ed4749f9SIvan Butygin                                                               converter);
13613a506b31SChris Lattner     mlir::populateCallOpTypeConversionPattern(patterns, converter);
1362fec6c5acSUday Bondhugula 
1363fec6c5acSUday Bondhugula     // Define the conversion target used for the test.
1364fec6c5acSUday Bondhugula     ConversionTarget target(getContext());
1365973ddb7dSMehdi Amini     target.addLegalOp<ModuleOp>();
1366ec03bbe8SVladislav Vinogradov     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
13678f4cd2c7SMatthias Springer                       TerminatorOp, OneRegionOp>();
1368412e30b2SMatthias Springer     target.addLegalOp(OperationName("test.legal_op", &getContext()));
1369fec6c5acSUday Bondhugula     target
1370fec6c5acSUday Bondhugula         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
1371fec6c5acSUday Bondhugula     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
1372fec6c5acSUday Bondhugula       // Don't allow F32 operands.
1373fec6c5acSUday Bondhugula       return llvm::none_of(op.getOperandTypes(),
1374fec6c5acSUday Bondhugula                            [](Type type) { return type.isF32(); });
1375fec6c5acSUday Bondhugula     });
137658ceae95SRiver Riddle     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
13774a3460a7SRiver Riddle       return converter.isSignatureLegal(op.getFunctionType()) &&
13788d67d187SRiver Riddle              converter.isLegal(&op.getBody());
13798d67d187SRiver Riddle     });
138023aa5a74SRiver Riddle     target.addDynamicallyLegalOp<func::CallOp>(
138123aa5a74SRiver Riddle         [&](func::CallOp op) { return converter.isLegal(op); });
1382fec6c5acSUday Bondhugula 
1383a54f4eaeSMogball     // TestCreateUnregisteredOp creates `arith.constant` operation,
1384ec03bbe8SVladislav Vinogradov     // which was not added to target intentionally to test
1385ec03bbe8SVladislav Vinogradov     // correct error code from conversion driver.
1386ec03bbe8SVladislav Vinogradov     target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });
1387ec03bbe8SVladislav Vinogradov 
1388fec6c5acSUday Bondhugula     // Expect the type_producer/type_consumer operations to only operate on f64.
1389fec6c5acSUday Bondhugula     target.addDynamicallyLegalOp<TestTypeProducerOp>(
1390fec6c5acSUday Bondhugula         [](TestTypeProducerOp op) { return op.getType().isF64(); });
1391fec6c5acSUday Bondhugula     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1392fec6c5acSUday Bondhugula       return op.getOperand().getType().isF64();
1393fec6c5acSUday Bondhugula     });
1394fec6c5acSUday Bondhugula 
1395fec6c5acSUday Bondhugula     // Check support for marking certain operations as recursively legal.
139658ceae95SRiver Riddle     target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) {
1397fec6c5acSUday Bondhugula       return static_cast<bool>(
1398fec6c5acSUday Bondhugula           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
1399fec6c5acSUday Bondhugula     });
1400fec6c5acSUday Bondhugula 
1401bd1ccfe6SRiver Riddle     // Mark the bound recursion operation as dynamically legal.
1402bd1ccfe6SRiver Riddle     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
14036a994233SJacques Pienaar         [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
1404bd1ccfe6SRiver Riddle 
14054513050fSChristian Ulmann     // Create a dynamically legal rule that can only be legalized by folding it.
14064513050fSChristian Ulmann     target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
14074513050fSChristian Ulmann         [](TestOpInPlaceSelfFold op) { return op.getFolded(); });
14084513050fSChristian Ulmann 
14099df63b26SMatthias Springer     target.addDynamicallyLegalOp<DuplicateBlockArgsOp>(
14109df63b26SMatthias Springer         [](DuplicateBlockArgsOp op) { return op.getIsLegal(); });
14119df63b26SMatthias Springer 
1412fec6c5acSUday Bondhugula     // Handle a partial conversion.
1413fec6c5acSUday Bondhugula     if (mode == ConversionMode::Partial) {
14148de482eaSLucy Fox       DenseSet<Operation *> unlegalizedOps;
1415a2821094SMatthias Springer       ConversionConfig config;
141660a20bd6SMatthias Springer       DumpNotifications dumpNotifications;
141760a20bd6SMatthias Springer       config.listener = &dumpNotifications;
1418a2821094SMatthias Springer       config.unlegalizedOps = &unlegalizedOps;
1419a2821094SMatthias Springer       if (failed(applyPartialConversion(getOperation(), target,
1420a2821094SMatthias Springer                                         std::move(patterns), config))) {
1421ec03bbe8SVladislav Vinogradov         getOperation()->emitRemark() << "applyPartialConversion failed";
1422ec03bbe8SVladislav Vinogradov       }
14238de482eaSLucy Fox       // Emit remarks for each legalizable operation.
14248de482eaSLucy Fox       for (auto *op : unlegalizedOps)
14258de482eaSLucy Fox         op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1426fec6c5acSUday Bondhugula       return;
1427fec6c5acSUday Bondhugula     }
1428fec6c5acSUday Bondhugula 
1429fec6c5acSUday Bondhugula     // Handle a full conversion.
1430fec6c5acSUday Bondhugula     if (mode == ConversionMode::Full) {
1431fec6c5acSUday Bondhugula       // Check support for marking unknown operations as dynamically legal.
1432fec6c5acSUday Bondhugula       target.markUnknownOpDynamicallyLegal([](Operation *op) {
1433fec6c5acSUday Bondhugula         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
1434fec6c5acSUday Bondhugula       });
1435fec6c5acSUday Bondhugula 
143660a20bd6SMatthias Springer       ConversionConfig config;
143760a20bd6SMatthias Springer       DumpNotifications dumpNotifications;
143860a20bd6SMatthias Springer       config.listener = &dumpNotifications;
1439ec03bbe8SVladislav Vinogradov       if (failed(applyFullConversion(getOperation(), target,
144060a20bd6SMatthias Springer                                      std::move(patterns), config))) {
1441ec03bbe8SVladislav Vinogradov         getOperation()->emitRemark() << "applyFullConversion failed";
1442ec03bbe8SVladislav Vinogradov       }
1443fec6c5acSUday Bondhugula       return;
1444fec6c5acSUday Bondhugula     }
1445fec6c5acSUday Bondhugula 
1446fec6c5acSUday Bondhugula     // Otherwise, handle an analysis conversion.
1447fec6c5acSUday Bondhugula     assert(mode == ConversionMode::Analysis);
1448fec6c5acSUday Bondhugula 
1449fec6c5acSUday Bondhugula     // Analyze the convertible operations.
1450fec6c5acSUday Bondhugula     DenseSet<Operation *> legalizedOps;
1451a2821094SMatthias Springer     ConversionConfig config;
1452a2821094SMatthias Springer     config.legalizableOps = &legalizedOps;
14533fffffa8SRiver Riddle     if (failed(applyAnalysisConversion(getOperation(), target,
1454a2821094SMatthias Springer                                        std::move(patterns), config)))
1455fec6c5acSUday Bondhugula       return signalPassFailure();
1456fec6c5acSUday Bondhugula 
1457fec6c5acSUday Bondhugula     // Emit remarks for each legalizable operation.
1458fec6c5acSUday Bondhugula     for (auto *op : legalizedOps)
1459fec6c5acSUday Bondhugula       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
1460fec6c5acSUday Bondhugula   }
1461fec6c5acSUday Bondhugula 
1462fec6c5acSUday Bondhugula   /// The mode of conversion to use.
1463fec6c5acSUday Bondhugula   ConversionMode mode;
1464fec6c5acSUday Bondhugula };
1465be0a7e9fSMehdi Amini } // namespace
1466fec6c5acSUday Bondhugula 
1467fec6c5acSUday Bondhugula static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
1468fec6c5acSUday Bondhugula     legalizerConversionMode(
1469fec6c5acSUday Bondhugula         "test-legalize-mode",
1470fec6c5acSUday Bondhugula         llvm::cl::desc("The legalization mode to use with the test driver"),
1471fec6c5acSUday Bondhugula         llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
1472fec6c5acSUday Bondhugula         llvm::cl::values(
1473fec6c5acSUday Bondhugula             clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
1474fec6c5acSUday Bondhugula                        "analysis", "Perform an analysis conversion"),
1475fec6c5acSUday Bondhugula             clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
1476fec6c5acSUday Bondhugula                        "Perform a full conversion"),
1477fec6c5acSUday Bondhugula             clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
1478fec6c5acSUday Bondhugula                        "partial", "Perform a partial conversion")));
1479fec6c5acSUday Bondhugula 
1480fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
1481fec6c5acSUday Bondhugula // ConversionPatternRewriter::getRemappedValue testing. This method is used
14825aacce3dSKazuaki Ishizaki // to get the remapped value of an original value that was replaced using
1483fec6c5acSUday Bondhugula // ConversionPatternRewriter.
1484fec6c5acSUday Bondhugula namespace {
1485015192c6SRiver Riddle struct TestRemapValueTypeConverter : public TypeConverter {
1486015192c6SRiver Riddle   using TypeConverter::TypeConverter;
1487015192c6SRiver Riddle 
1488015192c6SRiver Riddle   TestRemapValueTypeConverter() {
1489015192c6SRiver Riddle     addConversion(
1490015192c6SRiver Riddle         [](Float32Type type) { return Float64Type::get(type.getContext()); });
1491015192c6SRiver Riddle     addConversion([](Type type) { return type; });
1492015192c6SRiver Riddle   }
1493015192c6SRiver Riddle };
1494015192c6SRiver Riddle 
1495fec6c5acSUday Bondhugula /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
1496fec6c5acSUday Bondhugula /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
1497fec6c5acSUday Bondhugula /// operand twice.
1498fec6c5acSUday Bondhugula ///
1499fec6c5acSUday Bondhugula /// Example:
1500fec6c5acSUday Bondhugula ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
1501fec6c5acSUday Bondhugula /// is replaced with:
1502fec6c5acSUday Bondhugula ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
1503fec6c5acSUday Bondhugula struct OneVResOneVOperandOp1Converter
1504fec6c5acSUday Bondhugula     : public OpConversionPattern<OneVResOneVOperandOp1> {
1505fec6c5acSUday Bondhugula   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
1506fec6c5acSUday Bondhugula 
1507fec6c5acSUday Bondhugula   LogicalResult
1508ef976337SRiver Riddle   matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
1509fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const override {
1510fec6c5acSUday Bondhugula     auto origOps = op.getOperands();
1511fec6c5acSUday Bondhugula     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
1512fec6c5acSUday Bondhugula            "One operand expected");
1513fec6c5acSUday Bondhugula     Value origOp = *origOps.begin();
1514fec6c5acSUday Bondhugula     SmallVector<Value, 2> remappedOperands;
1515fec6c5acSUday Bondhugula     // Replicate the remapped original operand twice. Note that we don't used
1516fec6c5acSUday Bondhugula     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
1517fec6c5acSUday Bondhugula     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1518fec6c5acSUday Bondhugula     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1519fec6c5acSUday Bondhugula 
1520fec6c5acSUday Bondhugula     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
1521fec6c5acSUday Bondhugula                                                        remappedOperands);
1522fec6c5acSUday Bondhugula     return success();
1523fec6c5acSUday Bondhugula   }
1524fec6c5acSUday Bondhugula };
1525fec6c5acSUday Bondhugula 
1526015192c6SRiver Riddle /// A rewriter pattern that tests that blocks can be merged.
1527015192c6SRiver Riddle struct TestRemapValueInRegion
1528015192c6SRiver Riddle     : public OpConversionPattern<TestRemappedValueRegionOp> {
1529015192c6SRiver Riddle   using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;
1530015192c6SRiver Riddle 
1531015192c6SRiver Riddle   LogicalResult
1532015192c6SRiver Riddle   matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
1533015192c6SRiver Riddle                   ConversionPatternRewriter &rewriter) const final {
1534015192c6SRiver Riddle     Block &block = op.getBody().front();
1535015192c6SRiver Riddle     Operation *terminator = block.getTerminator();
1536015192c6SRiver Riddle 
1537015192c6SRiver Riddle     // Merge the block into the parent region.
1538015192c6SRiver Riddle     Block *parentBlock = op->getBlock();
1539015192c6SRiver Riddle     Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
1540015192c6SRiver Riddle     rewriter.mergeBlocks(&block, parentBlock, ValueRange());
1541015192c6SRiver Riddle     rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());
1542015192c6SRiver Riddle 
1543015192c6SRiver Riddle     // Replace the results of this operation with the remapped terminator
1544015192c6SRiver Riddle     // values.
1545015192c6SRiver Riddle     SmallVector<Value> terminatorOperands;
1546015192c6SRiver Riddle     if (failed(rewriter.getRemappedValues(terminator->getOperands(),
1547015192c6SRiver Riddle                                           terminatorOperands)))
1548015192c6SRiver Riddle       return failure();
1549015192c6SRiver Riddle 
1550015192c6SRiver Riddle     rewriter.eraseOp(terminator);
1551015192c6SRiver Riddle     rewriter.replaceOp(op, terminatorOperands);
1552015192c6SRiver Riddle     return success();
1553015192c6SRiver Riddle   }
1554015192c6SRiver Riddle };
1555015192c6SRiver Riddle 
155680aca1eaSRiver Riddle struct TestRemappedValue
1557c3a8ac79SMatthias Springer     : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> {
15585e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue)
15595e50dd04SRiver Riddle 
1560b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-remapped-value"; }
1561b5e22e6dSMehdi Amini   StringRef getDescription() const final {
1562b5e22e6dSMehdi Amini     return "Test public remapped value mechanism in ConversionPatternRewriter";
1563b5e22e6dSMehdi Amini   }
156441574554SRiver Riddle   void runOnOperation() override {
1565015192c6SRiver Riddle     TestRemapValueTypeConverter typeConverter;
1566015192c6SRiver Riddle 
1567dc4e913bSChris Lattner     mlir::RewritePatternSet patterns(&getContext());
1568dc4e913bSChris Lattner     patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
1569015192c6SRiver Riddle     patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
1570015192c6SRiver Riddle         &getContext());
1571015192c6SRiver Riddle     patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
1572fec6c5acSUday Bondhugula 
1573fec6c5acSUday Bondhugula     mlir::ConversionTarget target(getContext());
157458ceae95SRiver Riddle     target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>();
1575015192c6SRiver Riddle 
1576015192c6SRiver Riddle     // Expect the type_producer/type_consumer operations to only operate on f64.
1577015192c6SRiver Riddle     target.addDynamicallyLegalOp<TestTypeProducerOp>(
1578015192c6SRiver Riddle         [](TestTypeProducerOp op) { return op.getType().isF64(); });
1579015192c6SRiver Riddle     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1580015192c6SRiver Riddle       return op.getOperand().getType().isF64();
1581015192c6SRiver Riddle     });
1582015192c6SRiver Riddle 
1583fec6c5acSUday Bondhugula     // We make OneVResOneVOperandOp1 legal only when it has more that one
1584fec6c5acSUday Bondhugula     // operand. This will trigger the conversion that will replace one-operand
1585fec6c5acSUday Bondhugula     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
1586fec6c5acSUday Bondhugula     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
1587015192c6SRiver Riddle         [](Operation *op) { return op->getNumOperands() > 1; });
1588fec6c5acSUday Bondhugula 
158941574554SRiver Riddle     if (failed(mlir::applyFullConversion(getOperation(), target,
15903fffffa8SRiver Riddle                                          std::move(patterns)))) {
1591fec6c5acSUday Bondhugula       signalPassFailure();
1592fec6c5acSUday Bondhugula     }
1593fec6c5acSUday Bondhugula   }
1594fec6c5acSUday Bondhugula };
1595be0a7e9fSMehdi Amini } // namespace
1596fec6c5acSUday Bondhugula 
159780d7ac3bSRiver Riddle //===----------------------------------------------------------------------===//
159880d7ac3bSRiver Riddle // Test patterns without a specific root operation kind
159980d7ac3bSRiver Riddle //===----------------------------------------------------------------------===//
160080d7ac3bSRiver Riddle 
160180d7ac3bSRiver Riddle namespace {
160280d7ac3bSRiver Riddle /// This pattern matches and removes any operation in the test dialect.
160380d7ac3bSRiver Riddle struct RemoveTestDialectOps : public RewritePattern {
160476f3c2f3SRiver Riddle   RemoveTestDialectOps(MLIRContext *context)
160576f3c2f3SRiver Riddle       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
160680d7ac3bSRiver Riddle 
160780d7ac3bSRiver Riddle   LogicalResult matchAndRewrite(Operation *op,
160880d7ac3bSRiver Riddle                                 PatternRewriter &rewriter) const override {
160980d7ac3bSRiver Riddle     if (!isa<TestDialect>(op->getDialect()))
161080d7ac3bSRiver Riddle       return failure();
161180d7ac3bSRiver Riddle     rewriter.eraseOp(op);
161280d7ac3bSRiver Riddle     return success();
161380d7ac3bSRiver Riddle   }
161480d7ac3bSRiver Riddle };
161580d7ac3bSRiver Riddle 
161680d7ac3bSRiver Riddle struct TestUnknownRootOpDriver
1617c3a8ac79SMatthias Springer     : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> {
16185e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver)
16195e50dd04SRiver Riddle 
1620b5e22e6dSMehdi Amini   StringRef getArgument() const final {
1621b5e22e6dSMehdi Amini     return "test-legalize-unknown-root-patterns";
1622b5e22e6dSMehdi Amini   }
1623b5e22e6dSMehdi Amini   StringRef getDescription() const final {
1624b5e22e6dSMehdi Amini     return "Test public remapped value mechanism in ConversionPatternRewriter";
1625b5e22e6dSMehdi Amini   }
162641574554SRiver Riddle   void runOnOperation() override {
1627dc4e913bSChris Lattner     mlir::RewritePatternSet patterns(&getContext());
162876f3c2f3SRiver Riddle     patterns.add<RemoveTestDialectOps>(&getContext());
162980d7ac3bSRiver Riddle 
163080d7ac3bSRiver Riddle     mlir::ConversionTarget target(getContext());
163180d7ac3bSRiver Riddle     target.addIllegalDialect<TestDialect>();
163241574554SRiver Riddle     if (failed(applyPartialConversion(getOperation(), target,
163341574554SRiver Riddle                                       std::move(patterns))))
163480d7ac3bSRiver Riddle       signalPassFailure();
163580d7ac3bSRiver Riddle   }
163680d7ac3bSRiver Riddle };
1637be0a7e9fSMehdi Amini } // namespace
163880d7ac3bSRiver Riddle 
16394589dd92SRiver Riddle //===----------------------------------------------------------------------===//
164088bc24a7SMathieu Fehr // Test patterns that use operations and types defined at runtime
164188bc24a7SMathieu Fehr //===----------------------------------------------------------------------===//
164288bc24a7SMathieu Fehr 
164388bc24a7SMathieu Fehr namespace {
164488bc24a7SMathieu Fehr /// This pattern matches dynamic operations 'test.one_operand_two_results' and
164588bc24a7SMathieu Fehr /// replace them with dynamic operations 'test.generic_dynamic_op'.
164688bc24a7SMathieu Fehr struct RewriteDynamicOp : public RewritePattern {
164788bc24a7SMathieu Fehr   RewriteDynamicOp(MLIRContext *context)
164888bc24a7SMathieu Fehr       : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
164988bc24a7SMathieu Fehr                        context) {}
165088bc24a7SMathieu Fehr 
165188bc24a7SMathieu Fehr   LogicalResult matchAndRewrite(Operation *op,
165288bc24a7SMathieu Fehr                                 PatternRewriter &rewriter) const override {
165388bc24a7SMathieu Fehr     assert(op->getName().getStringRef() ==
165488bc24a7SMathieu Fehr                "test.dynamic_one_operand_two_results" &&
165588bc24a7SMathieu Fehr            "rewrite pattern should only match operations with the right name");
165688bc24a7SMathieu Fehr 
165788bc24a7SMathieu Fehr     OperationState state(op->getLoc(), "test.dynamic_generic",
165888bc24a7SMathieu Fehr                          op->getOperands(), op->getResultTypes(),
165988bc24a7SMathieu Fehr                          op->getAttrs());
166088bc24a7SMathieu Fehr     auto *newOp = rewriter.create(state);
166188bc24a7SMathieu Fehr     rewriter.replaceOp(op, newOp->getResults());
166288bc24a7SMathieu Fehr     return success();
166388bc24a7SMathieu Fehr   }
166488bc24a7SMathieu Fehr };
166588bc24a7SMathieu Fehr 
166688bc24a7SMathieu Fehr struct TestRewriteDynamicOpDriver
1667253afd03SMatthias Springer     : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> {
166888bc24a7SMathieu Fehr   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
166988bc24a7SMathieu Fehr 
167088bc24a7SMathieu Fehr   void getDependentDialects(DialectRegistry &registry) const override {
167188bc24a7SMathieu Fehr     registry.insert<TestDialect>();
167288bc24a7SMathieu Fehr   }
167388bc24a7SMathieu Fehr   StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
167488bc24a7SMathieu Fehr   StringRef getDescription() const final {
167588bc24a7SMathieu Fehr     return "Test rewritting on dynamic operations";
167688bc24a7SMathieu Fehr   }
167788bc24a7SMathieu Fehr   void runOnOperation() override {
167888bc24a7SMathieu Fehr     RewritePatternSet patterns(&getContext());
167988bc24a7SMathieu Fehr     patterns.add<RewriteDynamicOp>(&getContext());
168088bc24a7SMathieu Fehr 
168188bc24a7SMathieu Fehr     ConversionTarget target(getContext());
168288bc24a7SMathieu Fehr     target.addIllegalOp(
168388bc24a7SMathieu Fehr         OperationName("test.dynamic_one_operand_two_results", &getContext()));
168488bc24a7SMathieu Fehr     target.addLegalOp(OperationName("test.dynamic_generic", &getContext()));
168588bc24a7SMathieu Fehr     if (failed(applyPartialConversion(getOperation(), target,
168688bc24a7SMathieu Fehr                                       std::move(patterns))))
168788bc24a7SMathieu Fehr       signalPassFailure();
168888bc24a7SMathieu Fehr   }
168988bc24a7SMathieu Fehr };
169088bc24a7SMathieu Fehr } // end anonymous namespace
169188bc24a7SMathieu Fehr 
169288bc24a7SMathieu Fehr //===----------------------------------------------------------------------===//
16934589dd92SRiver Riddle // Test type conversions
16944589dd92SRiver Riddle //===----------------------------------------------------------------------===//
16954589dd92SRiver Riddle 
16964589dd92SRiver Riddle namespace {
16974589dd92SRiver Riddle struct TestTypeConversionProducer
16984589dd92SRiver Riddle     : public OpConversionPattern<TestTypeProducerOp> {
16994589dd92SRiver Riddle   using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
17004589dd92SRiver Riddle   LogicalResult
1701ef976337SRiver Riddle   matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
17024589dd92SRiver Riddle                   ConversionPatternRewriter &rewriter) const final {
17034589dd92SRiver Riddle     Type resultType = op.getType();
17049c5982efSAlex Zinenko     Type convertedType = getTypeConverter()
17059c5982efSAlex Zinenko                              ? getTypeConverter()->convertType(resultType)
17069c5982efSAlex Zinenko                              : resultType;
17075550c821STres Popp     if (isa<FloatType>(resultType))
17084589dd92SRiver Riddle       resultType = rewriter.getF64Type();
17094589dd92SRiver Riddle     else if (resultType.isInteger(16))
17104589dd92SRiver Riddle       resultType = rewriter.getIntegerType(64);
17115550c821STres Popp     else if (isa<test::TestRecursiveType>(resultType) &&
17129c5982efSAlex Zinenko              convertedType != resultType)
17139c5982efSAlex Zinenko       resultType = convertedType;
17144589dd92SRiver Riddle     else
17154589dd92SRiver Riddle       return failure();
17164589dd92SRiver Riddle 
17174589dd92SRiver Riddle     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
17184589dd92SRiver Riddle     return success();
17194589dd92SRiver Riddle   }
17204589dd92SRiver Riddle };
17214589dd92SRiver Riddle 
17220409eb28SAlex Zinenko /// Call signature conversion and then fail the rewrite to trigger the undo
17230409eb28SAlex Zinenko /// mechanism.
17240409eb28SAlex Zinenko struct TestSignatureConversionUndo
17250409eb28SAlex Zinenko     : public OpConversionPattern<TestSignatureConversionUndoOp> {
17260409eb28SAlex Zinenko   using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
17270409eb28SAlex Zinenko 
17280409eb28SAlex Zinenko   LogicalResult
1729ef976337SRiver Riddle   matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
17300409eb28SAlex Zinenko                   ConversionPatternRewriter &rewriter) const final {
17310409eb28SAlex Zinenko     (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
17320409eb28SAlex Zinenko     return failure();
17330409eb28SAlex Zinenko   }
17340409eb28SAlex Zinenko };
17350409eb28SAlex Zinenko 
17364070f305SRiver Riddle /// Call signature conversion without providing a type converter to handle
17374070f305SRiver Riddle /// materializations.
17384070f305SRiver Riddle struct TestTestSignatureConversionNoConverter
17394070f305SRiver Riddle     : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
1740ce254598SMatthias Springer   TestTestSignatureConversionNoConverter(const TypeConverter &converter,
17414070f305SRiver Riddle                                          MLIRContext *context)
17424070f305SRiver Riddle       : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
17434070f305SRiver Riddle         converter(converter) {}
17444070f305SRiver Riddle 
17454070f305SRiver Riddle   LogicalResult
17464070f305SRiver Riddle   matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
17474070f305SRiver Riddle                   ConversionPatternRewriter &rewriter) const final {
17484070f305SRiver Riddle     Region &region = op->getRegion(0);
17494070f305SRiver Riddle     Block *entry = &region.front();
17504070f305SRiver Riddle 
17514070f305SRiver Riddle     // Convert the original entry arguments.
17524070f305SRiver Riddle     TypeConverter::SignatureConversion result(entry->getNumArguments());
17534070f305SRiver Riddle     if (failed(
17544070f305SRiver Riddle             converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
17554070f305SRiver Riddle       return failure();
175652050f3fSMatthias Springer     rewriter.modifyOpInPlace(op, [&] {
175752050f3fSMatthias Springer       rewriter.applySignatureConversion(&region.front(), result);
175852050f3fSMatthias Springer     });
17594070f305SRiver Riddle     return success();
17604070f305SRiver Riddle   }
17614070f305SRiver Riddle 
1762ce254598SMatthias Springer   const TypeConverter &converter;
17634070f305SRiver Riddle };
17644070f305SRiver Riddle 
17650409eb28SAlex Zinenko /// Just forward the operands to the root op. This is essentially a no-op
17660409eb28SAlex Zinenko /// pattern that is used to trigger target materialization.
17670409eb28SAlex Zinenko struct TestTypeConsumerForward
17680409eb28SAlex Zinenko     : public OpConversionPattern<TestTypeConsumerOp> {
17690409eb28SAlex Zinenko   using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
17700409eb28SAlex Zinenko 
17710409eb28SAlex Zinenko   LogicalResult
1772ef976337SRiver Riddle   matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
17730409eb28SAlex Zinenko                   ConversionPatternRewriter &rewriter) const final {
17745fcf907bSMatthias Springer     rewriter.modifyOpInPlace(op,
1775ef976337SRiver Riddle                              [&] { op->setOperands(adaptor.getOperands()); });
17760409eb28SAlex Zinenko     return success();
17770409eb28SAlex Zinenko   }
17780409eb28SAlex Zinenko };
17790409eb28SAlex Zinenko 
17805b91060dSAlex Zinenko struct TestTypeConversionAnotherProducer
17815b91060dSAlex Zinenko     : public OpRewritePattern<TestAnotherTypeProducerOp> {
17825b91060dSAlex Zinenko   using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
17835b91060dSAlex Zinenko 
17845b91060dSAlex Zinenko   LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
17855b91060dSAlex Zinenko                                 PatternRewriter &rewriter) const final {
17865b91060dSAlex Zinenko     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
17875b91060dSAlex Zinenko     return success();
17885b91060dSAlex Zinenko   }
17895b91060dSAlex Zinenko };
17905b91060dSAlex Zinenko 
17918fc32942SMatthias Springer struct TestReplaceWithLegalOp : public ConversionPattern {
17923cc311abSMatthias Springer   TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
17933cc311abSMatthias Springer       : ConversionPattern(converter, "test.replace_with_legal_op",
17943cc311abSMatthias Springer                           /*benefit=*/1, ctx) {}
17958fc32942SMatthias Springer   LogicalResult
17968fc32942SMatthias Springer   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
17978fc32942SMatthias Springer                   ConversionPatternRewriter &rewriter) const final {
17988fc32942SMatthias Springer     rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
17998fc32942SMatthias Springer     return success();
18008fc32942SMatthias Springer   }
18018fc32942SMatthias Springer };
18028fc32942SMatthias Springer 
18034589dd92SRiver Riddle struct TestTypeConversionDriver
1804c3a8ac79SMatthias Springer     : public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
18055e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
18065e50dd04SRiver Riddle 
1807f9dc2b70SMehdi Amini   void getDependentDialects(DialectRegistry &registry) const override {
1808f9dc2b70SMehdi Amini     registry.insert<TestDialect>();
1809f9dc2b70SMehdi Amini   }
1810b5e22e6dSMehdi Amini   StringRef getArgument() const final {
1811b5e22e6dSMehdi Amini     return "test-legalize-type-conversion";
1812b5e22e6dSMehdi Amini   }
1813b5e22e6dSMehdi Amini   StringRef getDescription() const final {
1814b5e22e6dSMehdi Amini     return "Test various type conversion functionalities in DialectConversion";
1815b5e22e6dSMehdi Amini   }
1816f9dc2b70SMehdi Amini 
18174589dd92SRiver Riddle   void runOnOperation() override {
18184589dd92SRiver Riddle     // Initialize the type converter.
1819dc3dc974SMehdi Amini     SmallVector<Type, 2> conversionCallStack;
18204589dd92SRiver Riddle     TypeConverter converter;
18214589dd92SRiver Riddle 
18224589dd92SRiver Riddle     /// Add the legal set of type conversions.
18234589dd92SRiver Riddle     converter.addConversion([](Type type) -> Type {
18244589dd92SRiver Riddle       // Treat F64 as legal.
18254589dd92SRiver Riddle       if (type.isF64())
18264589dd92SRiver Riddle         return type;
18274589dd92SRiver Riddle       // Allow converting BF16/F16/F32 to F64.
18284589dd92SRiver Riddle       if (type.isBF16() || type.isF16() || type.isF32())
1829*f023da12SMatthias Springer         return Float64Type::get(type.getContext());
18304589dd92SRiver Riddle       // Otherwise, the type is illegal.
18314589dd92SRiver Riddle       return nullptr;
18324589dd92SRiver Riddle     });
18334589dd92SRiver Riddle     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
18344589dd92SRiver Riddle       // Drop all integer types.
18354589dd92SRiver Riddle       return success();
18364589dd92SRiver Riddle     });
18379c5982efSAlex Zinenko     converter.addConversion(
18389c5982efSAlex Zinenko         // Convert a recursive self-referring type into a non-self-referring
18399c5982efSAlex Zinenko         // type named "outer_converted_type" that contains a SimpleAType.
1840dc3dc974SMehdi Amini         [&](test::TestRecursiveType type,
1841dc3dc974SMehdi Amini             SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
18429c5982efSAlex Zinenko           // If the type is already converted, return it to indicate that it is
18439c5982efSAlex Zinenko           // legal.
18449c5982efSAlex Zinenko           if (type.getName() == "outer_converted_type") {
18459c5982efSAlex Zinenko             results.push_back(type);
18469c5982efSAlex Zinenko             return success();
18479c5982efSAlex Zinenko           }
18489c5982efSAlex Zinenko 
1849dc3dc974SMehdi Amini           conversionCallStack.push_back(type);
1850dc3dc974SMehdi Amini           auto popConversionCallStack = llvm::make_scope_exit(
1851dc3dc974SMehdi Amini               [&conversionCallStack]() { conversionCallStack.pop_back(); });
1852dc3dc974SMehdi Amini 
18539c5982efSAlex Zinenko           // If the type is on the call stack more than once (it is there at
18549c5982efSAlex Zinenko           // least once because of the _current_ call, which is always the last
18559c5982efSAlex Zinenko           // element on the stack), we've hit the recursive case. Just return
18569c5982efSAlex Zinenko           // SimpleAType here to create a non-recursive type as a result.
1857dc3dc974SMehdi Amini           if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(),
1858dc3dc974SMehdi Amini                                  type)) {
18599c5982efSAlex Zinenko             results.push_back(test::SimpleAType::get(type.getContext()));
18609c5982efSAlex Zinenko             return success();
18619c5982efSAlex Zinenko           }
18629c5982efSAlex Zinenko 
18639c5982efSAlex Zinenko           // Convert the body recursively.
18649c5982efSAlex Zinenko           auto result = test::TestRecursiveType::get(type.getContext(),
18659c5982efSAlex Zinenko                                                      "outer_converted_type");
18669c5982efSAlex Zinenko           if (failed(result.setBody(converter.convertType(type.getBody()))))
18679c5982efSAlex Zinenko             return failure();
18689c5982efSAlex Zinenko           results.push_back(result);
18699c5982efSAlex Zinenko           return success();
18709c5982efSAlex Zinenko         });
18714589dd92SRiver Riddle 
18724589dd92SRiver Riddle     /// Add the legal set of type materializations.
18734589dd92SRiver Riddle     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
18744589dd92SRiver Riddle                                           ValueRange inputs,
18754589dd92SRiver Riddle                                           Location loc) -> Value {
18764589dd92SRiver Riddle       // Allow casting from F64 back to F32.
18774589dd92SRiver Riddle       if (!resultType.isF16() && inputs.size() == 1 &&
18784589dd92SRiver Riddle           inputs[0].getType().isF64())
18794589dd92SRiver Riddle         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
18804589dd92SRiver Riddle       // Allow producing an i32 or i64 from nothing.
18814589dd92SRiver Riddle       if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
18824589dd92SRiver Riddle           inputs.empty())
18834589dd92SRiver Riddle         return builder.create<TestTypeProducerOp>(loc, resultType);
18844589dd92SRiver Riddle       // Allow producing an i64 from an integer.
18855550c821STres Popp       if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
18865550c821STres Popp           isa<IntegerType>(inputs[0].getType()))
18874589dd92SRiver Riddle         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
18884589dd92SRiver Riddle       // Otherwise, fail.
18894589dd92SRiver Riddle       return nullptr;
18904589dd92SRiver Riddle     });
18914589dd92SRiver Riddle 
18924589dd92SRiver Riddle     // Initialize the conversion target.
18934589dd92SRiver Riddle     mlir::ConversionTarget target(getContext());
18948fc32942SMatthias Springer     target.addLegalOp<LegalOpD>();
18954589dd92SRiver Riddle     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
18965550c821STres Popp       auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
18979c5982efSAlex Zinenko       return op.getType().isF64() || op.getType().isInteger(64) ||
18989c5982efSAlex Zinenko              (recursiveType &&
18999c5982efSAlex Zinenko               recursiveType.getName() == "outer_converted_type");
19004589dd92SRiver Riddle     });
190158ceae95SRiver Riddle     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
19024a3460a7SRiver Riddle       return converter.isSignatureLegal(op.getFunctionType()) &&
19034589dd92SRiver Riddle              converter.isLegal(&op.getBody());
19044589dd92SRiver Riddle     });
19054589dd92SRiver Riddle     target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
19064589dd92SRiver Riddle       // Allow casts from F64 to F32.
19074589dd92SRiver Riddle       return (*op.operand_type_begin()).isF64() && op.getType().isF32();
19084589dd92SRiver Riddle     });
19094070f305SRiver Riddle     target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
19104070f305SRiver Riddle         [&](TestSignatureConversionNoConverterOp op) {
19114070f305SRiver Riddle           return converter.isLegal(op.getRegion().front().getArgumentTypes());
19124070f305SRiver Riddle         });
19134589dd92SRiver Riddle 
19144589dd92SRiver Riddle     // Initialize the set of rewrite patterns.
1915dc4e913bSChris Lattner     RewritePatternSet patterns(&getContext());
19163cc311abSMatthias Springer     patterns
19173cc311abSMatthias Springer         .add<TestTypeConsumerForward, TestTypeConversionProducer,
19184070f305SRiver Riddle              TestSignatureConversionUndo,
19193cc311abSMatthias Springer              TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
19203cc311abSMatthias Springer             converter, &getContext());
19213cc311abSMatthias Springer     patterns.add<TestTypeConversionAnotherProducer>(&getContext());
1922ed4749f9SIvan Butygin     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1923ed4749f9SIvan Butygin                                                               converter);
19244589dd92SRiver Riddle 
19253fffffa8SRiver Riddle     if (failed(applyPartialConversion(getOperation(), target,
19263fffffa8SRiver Riddle                                       std::move(patterns))))
19274589dd92SRiver Riddle       signalPassFailure();
19284589dd92SRiver Riddle   }
19294589dd92SRiver Riddle };
1930be0a7e9fSMehdi Amini } // namespace
19314589dd92SRiver Riddle 
1932c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===//
1933d4a53f3bSAlex Zinenko // Test Target Materialization With No Uses
1934d4a53f3bSAlex Zinenko //===----------------------------------------------------------------------===//
1935d4a53f3bSAlex Zinenko 
1936d4a53f3bSAlex Zinenko namespace {
1937d4a53f3bSAlex Zinenko struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1938d4a53f3bSAlex Zinenko   using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1939d4a53f3bSAlex Zinenko 
1940d4a53f3bSAlex Zinenko   LogicalResult
1941d4a53f3bSAlex Zinenko   matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1942d4a53f3bSAlex Zinenko                   ConversionPatternRewriter &rewriter) const final {
1943d4a53f3bSAlex Zinenko     rewriter.replaceOp(op, adaptor.getOperands());
1944d4a53f3bSAlex Zinenko     return success();
1945d4a53f3bSAlex Zinenko   }
1946d4a53f3bSAlex Zinenko };
1947d4a53f3bSAlex Zinenko 
1948d4a53f3bSAlex Zinenko struct TestTargetMaterializationWithNoUses
1949c3a8ac79SMatthias Springer     : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> {
19505e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
19515e50dd04SRiver Riddle       TestTargetMaterializationWithNoUses)
19525e50dd04SRiver Riddle 
1953d4a53f3bSAlex Zinenko   StringRef getArgument() const final {
1954d4a53f3bSAlex Zinenko     return "test-target-materialization-with-no-uses";
1955d4a53f3bSAlex Zinenko   }
1956d4a53f3bSAlex Zinenko   StringRef getDescription() const final {
1957d4a53f3bSAlex Zinenko     return "Test a special case of target materialization in DialectConversion";
1958d4a53f3bSAlex Zinenko   }
1959d4a53f3bSAlex Zinenko 
1960d4a53f3bSAlex Zinenko   void runOnOperation() override {
1961d4a53f3bSAlex Zinenko     TypeConverter converter;
1962d4a53f3bSAlex Zinenko     converter.addConversion([](Type t) { return t; });
1963d4a53f3bSAlex Zinenko     converter.addConversion([](IntegerType intTy) -> Type {
1964d4a53f3bSAlex Zinenko       if (intTy.getWidth() == 16)
1965d4a53f3bSAlex Zinenko         return IntegerType::get(intTy.getContext(), 64);
1966d4a53f3bSAlex Zinenko       return intTy;
1967d4a53f3bSAlex Zinenko     });
1968d4a53f3bSAlex Zinenko     converter.addTargetMaterialization(
1969d4a53f3bSAlex Zinenko         [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1970d4a53f3bSAlex Zinenko           return builder.create<TestCastOp>(loc, type, inputs).getResult();
1971d4a53f3bSAlex Zinenko         });
1972d4a53f3bSAlex Zinenko 
1973d4a53f3bSAlex Zinenko     ConversionTarget target(getContext());
1974d4a53f3bSAlex Zinenko     target.addIllegalOp<TestTypeChangerOp>();
1975d4a53f3bSAlex Zinenko 
1976d4a53f3bSAlex Zinenko     RewritePatternSet patterns(&getContext());
1977d4a53f3bSAlex Zinenko     patterns.add<ForwardOperandPattern>(converter, &getContext());
1978d4a53f3bSAlex Zinenko 
1979d4a53f3bSAlex Zinenko     if (failed(applyPartialConversion(getOperation(), target,
1980d4a53f3bSAlex Zinenko                                       std::move(patterns))))
1981d4a53f3bSAlex Zinenko       signalPassFailure();
1982d4a53f3bSAlex Zinenko   }
1983d4a53f3bSAlex Zinenko };
1984d4a53f3bSAlex Zinenko } // namespace
1985d4a53f3bSAlex Zinenko 
1986d4a53f3bSAlex Zinenko //===----------------------------------------------------------------------===//
1987c8fb6ee3SRiver Riddle // Test Block Merging
1988c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===//
1989c8fb6ee3SRiver Riddle 
1990e888886cSMaheshRavishankar namespace {
1991e888886cSMaheshRavishankar /// A rewriter pattern that tests that blocks can be merged.
1992e888886cSMaheshRavishankar struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
1993e888886cSMaheshRavishankar   using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
1994e888886cSMaheshRavishankar 
1995e888886cSMaheshRavishankar   LogicalResult
1996ef976337SRiver Riddle   matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
1997e888886cSMaheshRavishankar                   ConversionPatternRewriter &rewriter) const final {
19986a994233SJacques Pienaar     Block &firstBlock = op.getBody().front();
1999e888886cSMaheshRavishankar     Operation *branchOp = firstBlock.getTerminator();
20006a994233SJacques Pienaar     Block *secondBlock = &*(std::next(op.getBody().begin()));
2001e888886cSMaheshRavishankar     auto succOperands = branchOp->getOperands();
2002e888886cSMaheshRavishankar     SmallVector<Value, 2> replacements(succOperands);
2003e888886cSMaheshRavishankar     rewriter.eraseOp(branchOp);
2004e888886cSMaheshRavishankar     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
20055fcf907bSMatthias Springer     rewriter.modifyOpInPlace(op, [] {});
2006e888886cSMaheshRavishankar     return success();
2007e888886cSMaheshRavishankar   }
2008e888886cSMaheshRavishankar };
2009e888886cSMaheshRavishankar 
2010e888886cSMaheshRavishankar /// A rewrite pattern to tests the undo mechanism of blocks being merged.
2011e888886cSMaheshRavishankar struct TestUndoBlocksMerge : public ConversionPattern {
2012e888886cSMaheshRavishankar   TestUndoBlocksMerge(MLIRContext *ctx)
2013e888886cSMaheshRavishankar       : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
2014e888886cSMaheshRavishankar   LogicalResult
2015e888886cSMaheshRavishankar   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
2016e888886cSMaheshRavishankar                   ConversionPatternRewriter &rewriter) const final {
2017e888886cSMaheshRavishankar     Block &firstBlock = op->getRegion(0).front();
2018e888886cSMaheshRavishankar     Operation *branchOp = firstBlock.getTerminator();
2019e888886cSMaheshRavishankar     Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
2020e888886cSMaheshRavishankar     rewriter.setInsertionPointToStart(secondBlock);
2021e888886cSMaheshRavishankar     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
2022e888886cSMaheshRavishankar     auto succOperands = branchOp->getOperands();
2023e888886cSMaheshRavishankar     SmallVector<Value, 2> replacements(succOperands);
2024e888886cSMaheshRavishankar     rewriter.eraseOp(branchOp);
2025e888886cSMaheshRavishankar     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
20265fcf907bSMatthias Springer     rewriter.modifyOpInPlace(op, [] {});
2027e888886cSMaheshRavishankar     return success();
2028e888886cSMaheshRavishankar   }
2029e888886cSMaheshRavishankar };
2030e888886cSMaheshRavishankar 
2031e888886cSMaheshRavishankar /// A rewrite mechanism to inline the body of the op into its parent, when both
2032e888886cSMaheshRavishankar /// ops can have a single block.
2033e888886cSMaheshRavishankar struct TestMergeSingleBlockOps
2034e888886cSMaheshRavishankar     : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
2035e888886cSMaheshRavishankar   using OpConversionPattern<
2036e888886cSMaheshRavishankar       SingleBlockImplicitTerminatorOp>::OpConversionPattern;
2037e888886cSMaheshRavishankar 
2038e888886cSMaheshRavishankar   LogicalResult
2039ef976337SRiver Riddle   matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
2040e888886cSMaheshRavishankar                   ConversionPatternRewriter &rewriter) const final {
2041e888886cSMaheshRavishankar     SingleBlockImplicitTerminatorOp parentOp =
20420bf4a82aSChristian Sigg         op->getParentOfType<SingleBlockImplicitTerminatorOp>();
2043e888886cSMaheshRavishankar     if (!parentOp)
2044e888886cSMaheshRavishankar       return failure();
20456a994233SJacques Pienaar     Block &innerBlock = op.getRegion().front();
2046e888886cSMaheshRavishankar     TerminatorOp innerTerminator =
2047e888886cSMaheshRavishankar         cast<TerminatorOp>(innerBlock.getTerminator());
204842c31d83SMatthias Springer     rewriter.inlineBlockBefore(&innerBlock, op);
2049e888886cSMaheshRavishankar     rewriter.eraseOp(innerTerminator);
2050e888886cSMaheshRavishankar     rewriter.eraseOp(op);
2051e888886cSMaheshRavishankar     return success();
2052e888886cSMaheshRavishankar   }
2053e888886cSMaheshRavishankar };
2054e888886cSMaheshRavishankar 
2055e888886cSMaheshRavishankar struct TestMergeBlocksPatternDriver
2056c3a8ac79SMatthias Springer     : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> {
20575e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver)
20585e50dd04SRiver Riddle 
2059b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-merge-blocks"; }
2060b5e22e6dSMehdi Amini   StringRef getDescription() const final {
2061b5e22e6dSMehdi Amini     return "Test Merging operation in ConversionPatternRewriter";
2062b5e22e6dSMehdi Amini   }
2063e888886cSMaheshRavishankar   void runOnOperation() override {
2064e888886cSMaheshRavishankar     MLIRContext *context = &getContext();
2065dc4e913bSChris Lattner     mlir::RewritePatternSet patterns(context);
2066dc4e913bSChris Lattner     patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
2067e888886cSMaheshRavishankar         context);
2068e888886cSMaheshRavishankar     ConversionTarget target(*context);
206958ceae95SRiver Riddle     target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
2070973ddb7dSMehdi Amini                       TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
2071e888886cSMaheshRavishankar     target.addIllegalOp<ILLegalOpF>();
2072e888886cSMaheshRavishankar 
2073e888886cSMaheshRavishankar     /// Expect the op to have a single block after legalization.
2074e888886cSMaheshRavishankar     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
2075e888886cSMaheshRavishankar         [&](TestMergeBlocksOp op) -> bool {
20766a994233SJacques Pienaar           return llvm::hasSingleElement(op.getBody());
2077e888886cSMaheshRavishankar         });
2078e888886cSMaheshRavishankar 
2079e888886cSMaheshRavishankar     /// Only allow `test.br` within test.merge_blocks op.
2080e888886cSMaheshRavishankar     target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
20810bf4a82aSChristian Sigg       return op->getParentOfType<TestMergeBlocksOp>();
2082e888886cSMaheshRavishankar     });
2083e888886cSMaheshRavishankar 
2084e888886cSMaheshRavishankar     /// Expect that all nested test.SingleBlockImplicitTerminator ops are
2085e888886cSMaheshRavishankar     /// inlined.
2086e888886cSMaheshRavishankar     target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
2087e888886cSMaheshRavishankar         [&](SingleBlockImplicitTerminatorOp op) -> bool {
20880bf4a82aSChristian Sigg           return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
2089e888886cSMaheshRavishankar         });
2090e888886cSMaheshRavishankar 
2091e888886cSMaheshRavishankar     DenseSet<Operation *> unlegalizedOps;
2092a2821094SMatthias Springer     ConversionConfig config;
2093a2821094SMatthias Springer     config.unlegalizedOps = &unlegalizedOps;
20943fffffa8SRiver Riddle     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
2095a2821094SMatthias Springer                                  config);
2096e888886cSMaheshRavishankar     for (auto *op : unlegalizedOps)
2097e888886cSMaheshRavishankar       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
2098e888886cSMaheshRavishankar   }
2099e888886cSMaheshRavishankar };
2100e888886cSMaheshRavishankar } // namespace
2101e888886cSMaheshRavishankar 
21024589dd92SRiver Riddle //===----------------------------------------------------------------------===//
2103c8fb6ee3SRiver Riddle // Test Selective Replacement
2104c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===//
2105c8fb6ee3SRiver Riddle 
2106c8fb6ee3SRiver Riddle namespace {
2107c8fb6ee3SRiver Riddle /// A rewrite mechanism to inline the body of the op into its parent, when both
2108c8fb6ee3SRiver Riddle /// ops can have a single block.
2109c8fb6ee3SRiver Riddle struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
2110c8fb6ee3SRiver Riddle   using OpRewritePattern<TestCastOp>::OpRewritePattern;
2111c8fb6ee3SRiver Riddle 
2112c8fb6ee3SRiver Riddle   LogicalResult matchAndRewrite(TestCastOp op,
2113c8fb6ee3SRiver Riddle                                 PatternRewriter &rewriter) const final {
2114c8fb6ee3SRiver Riddle     if (op.getNumOperands() != 2)
2115c8fb6ee3SRiver Riddle       return failure();
2116c8fb6ee3SRiver Riddle     OperandRange operands = op.getOperands();
2117c8fb6ee3SRiver Riddle 
2118c8fb6ee3SRiver Riddle     // Replace non-terminator uses with the first operand.
211959a92019SMatthias Springer     rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
2120fe7c0d90SRiver Riddle       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
2121c8fb6ee3SRiver Riddle     });
2122c8fb6ee3SRiver Riddle     // Replace everything else with the second operand if the operation isn't
2123c8fb6ee3SRiver Riddle     // dead.
2124c8fb6ee3SRiver Riddle     rewriter.replaceOp(op, op.getOperand(1));
2125c8fb6ee3SRiver Riddle     return success();
2126c8fb6ee3SRiver Riddle   }
2127c8fb6ee3SRiver Riddle };
2128c8fb6ee3SRiver Riddle 
2129c8fb6ee3SRiver Riddle struct TestSelectiveReplacementPatternDriver
2130c8fb6ee3SRiver Riddle     : public PassWrapper<TestSelectiveReplacementPatternDriver,
2131c8fb6ee3SRiver Riddle                          OperationPass<>> {
21325e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
21335e50dd04SRiver Riddle       TestSelectiveReplacementPatternDriver)
21345e50dd04SRiver Riddle 
2135b5e22e6dSMehdi Amini   StringRef getArgument() const final {
2136b5e22e6dSMehdi Amini     return "test-pattern-selective-replacement";
2137b5e22e6dSMehdi Amini   }
2138b5e22e6dSMehdi Amini   StringRef getDescription() const final {
2139b5e22e6dSMehdi Amini     return "Test selective replacement in the PatternRewriter";
2140b5e22e6dSMehdi Amini   }
2141c8fb6ee3SRiver Riddle   void runOnOperation() override {
2142c8fb6ee3SRiver Riddle     MLIRContext *context = &getContext();
2143dc4e913bSChris Lattner     mlir::RewritePatternSet patterns(context);
2144dc4e913bSChris Lattner     patterns.add<TestSelectiveOpReplacementPattern>(context);
214509dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
2146c8fb6ee3SRiver Riddle   }
2147c8fb6ee3SRiver Riddle };
2148c8fb6ee3SRiver Riddle } // namespace
2149c8fb6ee3SRiver Riddle 
2150c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===//
21514589dd92SRiver Riddle // PassRegistration
21524589dd92SRiver Riddle //===----------------------------------------------------------------------===//
21534589dd92SRiver Riddle 
2154fec6c5acSUday Bondhugula namespace mlir {
215572c65b69SAlexander Belyaev namespace test {
2156fec6c5acSUday Bondhugula void registerPatternsTestPass() {
2157b5e22e6dSMehdi Amini   PassRegistration<TestReturnTypeDriver>();
2158fec6c5acSUday Bondhugula 
2159b5e22e6dSMehdi Amini   PassRegistration<TestDerivedAttributeDriver>();
21609ba37b3bSJacques Pienaar 
21610f8a6b7dSJakub Kuderski   PassRegistration<TestGreedyPatternDriver>();
2162ba3a9f51SChia-hung Duan   PassRegistration<TestStrictPatternDriver>();
21630f8a6b7dSJakub Kuderski   PassRegistration<TestWalkPatternDriver>();
2164fec6c5acSUday Bondhugula 
2165b5e22e6dSMehdi Amini   PassRegistration<TestLegalizePatternDriver>([] {
2166b5e22e6dSMehdi Amini     return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
2167fec6c5acSUday Bondhugula   });
2168fec6c5acSUday Bondhugula 
2169b5e22e6dSMehdi Amini   PassRegistration<TestRemappedValue>();
217080d7ac3bSRiver Riddle 
2171b5e22e6dSMehdi Amini   PassRegistration<TestUnknownRootOpDriver>();
21724589dd92SRiver Riddle 
2173b5e22e6dSMehdi Amini   PassRegistration<TestTypeConversionDriver>();
2174d4a53f3bSAlex Zinenko   PassRegistration<TestTargetMaterializationWithNoUses>();
2175e888886cSMaheshRavishankar 
217688bc24a7SMathieu Fehr   PassRegistration<TestRewriteDynamicOpDriver>();
217788bc24a7SMathieu Fehr 
2178b5e22e6dSMehdi Amini   PassRegistration<TestMergeBlocksPatternDriver>();
2179b5e22e6dSMehdi Amini   PassRegistration<TestSelectiveReplacementPatternDriver>();
2180fec6c5acSUday Bondhugula }
218172c65b69SAlexander Belyaev } // namespace test
2182fec6c5acSUday Bondhugula } // namespace mlir
2183