xref: /llvm-project/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- TestMakeIsolatedFromAbove.cpp - Test makeIsolatedFromAbove method -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/IR/PatternMatch.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15 #include "mlir/Transforms/RegionUtils.h"
16 
17 using namespace mlir;
18 
19 /// Helper function to call the `makeRegionIsolatedFromAbove` to convert
20 /// `test.one_region_op` to `test.isolated_one_region_op`.
21 static LogicalResult
22 makeIsolatedFromAboveImpl(RewriterBase &rewriter,
23                           test::OneRegionWithOperandsOp regionOp,
24                           llvm::function_ref<bool(Operation *)> callBack) {
25   Region &region = regionOp.getRegion();
26   SmallVector<Value> capturedValues =
27       makeRegionIsolatedFromAbove(rewriter, region, callBack);
28   SmallVector<Value> operands = regionOp.getOperands();
29   operands.append(capturedValues);
30   auto isolatedRegionOp =
31       rewriter.create<test::IsolatedOneRegionOp>(regionOp.getLoc(), operands);
32   rewriter.inlineRegionBefore(region, isolatedRegionOp.getRegion(),
33                               isolatedRegionOp.getRegion().begin());
34   rewriter.eraseOp(regionOp);
35   return success();
36 }
37 
38 namespace {
39 
40 /// Simple test for making region isolated from above without cloning any
41 /// operations.
42 struct SimpleMakeIsolatedFromAbove
43     : OpRewritePattern<test::OneRegionWithOperandsOp> {
44   using OpRewritePattern::OpRewritePattern;
45 
46   LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
47                                 PatternRewriter &rewriter) const override {
48     return makeIsolatedFromAboveImpl(rewriter, regionOp,
49                                      [](Operation *) { return false; });
50   }
51 };
52 
53 /// Test for making region isolated from above while clong operations
54 /// with no operands.
55 struct MakeIsolatedFromAboveAndCloneOpsWithNoOperands
56     : OpRewritePattern<test::OneRegionWithOperandsOp> {
57   using OpRewritePattern::OpRewritePattern;
58 
59   LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
60                                 PatternRewriter &rewriter) const override {
61     return makeIsolatedFromAboveImpl(rewriter, regionOp, [](Operation *op) {
62       return op->getNumOperands() == 0;
63     });
64   }
65 };
66 
67 /// Test for making region isolated from above while clong operations
68 /// with no operands.
69 struct MakeIsolatedFromAboveAndCloneOpsWithOperands
70     : OpRewritePattern<test::OneRegionWithOperandsOp> {
71   using OpRewritePattern::OpRewritePattern;
72 
73   LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
74                                 PatternRewriter &rewriter) const override {
75     return makeIsolatedFromAboveImpl(rewriter, regionOp,
76                                      [](Operation *op) { return true; });
77   }
78 };
79 
80 /// Test pass for testing the `makeIsolatedFromAbove` function.
81 struct TestMakeIsolatedFromAbovePass
82     : public PassWrapper<TestMakeIsolatedFromAbovePass,
83                          OperationPass<func::FuncOp>> {
84 
85   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMakeIsolatedFromAbovePass)
86 
87   TestMakeIsolatedFromAbovePass() = default;
88   TestMakeIsolatedFromAbovePass(const TestMakeIsolatedFromAbovePass &pass)
89       : PassWrapper(pass) {}
90 
91   StringRef getArgument() const final {
92     return "test-make-isolated-from-above";
93   }
94 
95   StringRef getDescription() const final {
96     return "Test making a region isolated from above";
97   }
98 
99   Option<bool> simple{
100       *this, "simple",
101       llvm::cl::desc("Test simple case with no cloning of operations"),
102       llvm::cl::init(false)};
103 
104   Option<bool> cloneOpsWithNoOperands{
105       *this, "clone-ops-with-no-operands",
106       llvm::cl::desc("Test case with cloning of operations with no operands"),
107       llvm::cl::init(false)};
108 
109   Option<bool> cloneOpsWithOperands{
110       *this, "clone-ops-with-operands",
111       llvm::cl::desc("Test case with cloning of operations with no operands"),
112       llvm::cl::init(false)};
113 
114   void runOnOperation() override;
115 };
116 
117 } // namespace
118 
119 void TestMakeIsolatedFromAbovePass::runOnOperation() {
120   MLIRContext *context = &getContext();
121   func::FuncOp funcOp = getOperation();
122 
123   if (simple) {
124     RewritePatternSet patterns(context);
125     patterns.insert<SimpleMakeIsolatedFromAbove>(context);
126     if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
127       return signalPassFailure();
128     }
129     return;
130   }
131 
132   if (cloneOpsWithNoOperands) {
133     RewritePatternSet patterns(context);
134     patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithNoOperands>(context);
135     if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
136       return signalPassFailure();
137     }
138     return;
139   }
140 
141   if (cloneOpsWithOperands) {
142     RewritePatternSet patterns(context);
143     patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithOperands>(context);
144     if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
145       return signalPassFailure();
146     }
147     return;
148   }
149 }
150 
151 namespace mlir {
152 namespace test {
153 void registerTestMakeIsolatedFromAbovePass() {
154   PassRegistration<TestMakeIsolatedFromAbovePass>();
155 }
156 } // namespace test
157 } // namespace mlir
158