1 //===- TestMakeIsolatedFromAbove.cpp - Test makeIsolatedFromAbove method -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestOps.h" 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/IR/PatternMatch.h" 13 #include "mlir/Pass/Pass.h" 14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 15 #include "mlir/Transforms/RegionUtils.h" 16 17 using namespace mlir; 18 19 /// Helper function to call the `makeRegionIsolatedFromAbove` to convert 20 /// `test.one_region_op` to `test.isolated_one_region_op`. 21 static LogicalResult 22 makeIsolatedFromAboveImpl(RewriterBase &rewriter, 23 test::OneRegionWithOperandsOp regionOp, 24 llvm::function_ref<bool(Operation *)> callBack) { 25 Region ®ion = regionOp.getRegion(); 26 SmallVector<Value> capturedValues = 27 makeRegionIsolatedFromAbove(rewriter, region, callBack); 28 SmallVector<Value> operands = regionOp.getOperands(); 29 operands.append(capturedValues); 30 auto isolatedRegionOp = 31 rewriter.create<test::IsolatedOneRegionOp>(regionOp.getLoc(), operands); 32 rewriter.inlineRegionBefore(region, isolatedRegionOp.getRegion(), 33 isolatedRegionOp.getRegion().begin()); 34 rewriter.eraseOp(regionOp); 35 return success(); 36 } 37 38 namespace { 39 40 /// Simple test for making region isolated from above without cloning any 41 /// operations. 42 struct SimpleMakeIsolatedFromAbove 43 : OpRewritePattern<test::OneRegionWithOperandsOp> { 44 using OpRewritePattern::OpRewritePattern; 45 46 LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp, 47 PatternRewriter &rewriter) const override { 48 return makeIsolatedFromAboveImpl(rewriter, regionOp, 49 [](Operation *) { return false; }); 50 } 51 }; 52 53 /// Test for making region isolated from above while clong operations 54 /// with no operands. 55 struct MakeIsolatedFromAboveAndCloneOpsWithNoOperands 56 : OpRewritePattern<test::OneRegionWithOperandsOp> { 57 using OpRewritePattern::OpRewritePattern; 58 59 LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp, 60 PatternRewriter &rewriter) const override { 61 return makeIsolatedFromAboveImpl(rewriter, regionOp, [](Operation *op) { 62 return op->getNumOperands() == 0; 63 }); 64 } 65 }; 66 67 /// Test for making region isolated from above while clong operations 68 /// with no operands. 69 struct MakeIsolatedFromAboveAndCloneOpsWithOperands 70 : OpRewritePattern<test::OneRegionWithOperandsOp> { 71 using OpRewritePattern::OpRewritePattern; 72 73 LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp, 74 PatternRewriter &rewriter) const override { 75 return makeIsolatedFromAboveImpl(rewriter, regionOp, 76 [](Operation *op) { return true; }); 77 } 78 }; 79 80 /// Test pass for testing the `makeIsolatedFromAbove` function. 81 struct TestMakeIsolatedFromAbovePass 82 : public PassWrapper<TestMakeIsolatedFromAbovePass, 83 OperationPass<func::FuncOp>> { 84 85 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMakeIsolatedFromAbovePass) 86 87 TestMakeIsolatedFromAbovePass() = default; 88 TestMakeIsolatedFromAbovePass(const TestMakeIsolatedFromAbovePass &pass) 89 : PassWrapper(pass) {} 90 91 StringRef getArgument() const final { 92 return "test-make-isolated-from-above"; 93 } 94 95 StringRef getDescription() const final { 96 return "Test making a region isolated from above"; 97 } 98 99 Option<bool> simple{ 100 *this, "simple", 101 llvm::cl::desc("Test simple case with no cloning of operations"), 102 llvm::cl::init(false)}; 103 104 Option<bool> cloneOpsWithNoOperands{ 105 *this, "clone-ops-with-no-operands", 106 llvm::cl::desc("Test case with cloning of operations with no operands"), 107 llvm::cl::init(false)}; 108 109 Option<bool> cloneOpsWithOperands{ 110 *this, "clone-ops-with-operands", 111 llvm::cl::desc("Test case with cloning of operations with no operands"), 112 llvm::cl::init(false)}; 113 114 void runOnOperation() override; 115 }; 116 117 } // namespace 118 119 void TestMakeIsolatedFromAbovePass::runOnOperation() { 120 MLIRContext *context = &getContext(); 121 func::FuncOp funcOp = getOperation(); 122 123 if (simple) { 124 RewritePatternSet patterns(context); 125 patterns.insert<SimpleMakeIsolatedFromAbove>(context); 126 if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { 127 return signalPassFailure(); 128 } 129 return; 130 } 131 132 if (cloneOpsWithNoOperands) { 133 RewritePatternSet patterns(context); 134 patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithNoOperands>(context); 135 if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { 136 return signalPassFailure(); 137 } 138 return; 139 } 140 141 if (cloneOpsWithOperands) { 142 RewritePatternSet patterns(context); 143 patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithOperands>(context); 144 if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { 145 return signalPassFailure(); 146 } 147 return; 148 } 149 } 150 151 namespace mlir { 152 namespace test { 153 void registerTestMakeIsolatedFromAbovePass() { 154 PassRegistration<TestMakeIsolatedFromAbovePass>(); 155 } 156 } // namespace test 157 } // namespace mlir 158