1 //===-- RemoveShapeConstraints.cpp - Remove Shape Cstr and Assuming Ops ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/Transforms/Passes.h" 10 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/Dialect/Shape/IR/Shape.h" 13 #include "mlir/Transforms/DialectConversion.h" 14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 15 16 namespace mlir { 17 #define GEN_PASS_DEF_REMOVESHAPECONSTRAINTS 18 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc" 19 } // namespace mlir 20 21 using namespace mlir; 22 23 namespace { 24 /// Removal patterns. 25 class RemoveCstrBroadcastableOp 26 : public OpRewritePattern<shape::CstrBroadcastableOp> { 27 public: 28 using OpRewritePattern::OpRewritePattern; 29 30 LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op, 31 PatternRewriter &rewriter) const override { 32 rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true); 33 return success(); 34 } 35 }; 36 37 class RemoveCstrEqOp : public OpRewritePattern<shape::CstrEqOp> { 38 public: 39 using OpRewritePattern::OpRewritePattern; 40 41 LogicalResult matchAndRewrite(shape::CstrEqOp op, 42 PatternRewriter &rewriter) const override { 43 rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true); 44 return success(); 45 } 46 }; 47 48 /// Removal pass. 49 class RemoveShapeConstraintsPass 50 : public impl::RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> { 51 52 void runOnOperation() override { 53 MLIRContext &ctx = getContext(); 54 55 RewritePatternSet patterns(&ctx); 56 populateRemoveShapeConstraintsPatterns(patterns); 57 58 (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 59 } 60 }; 61 62 } // namespace 63 64 void mlir::populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns) { 65 patterns.add<RemoveCstrBroadcastableOp, RemoveCstrEqOp>( 66 patterns.getContext()); 67 } 68 69 std::unique_ptr<OperationPass<func::FuncOp>> 70 mlir::createRemoveShapeConstraintsPass() { 71 return std::make_unique<RemoveShapeConstraintsPass>(); 72 } 73