xref: /llvm-project/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===-- RemoveShapeConstraints.cpp - Remove Shape Cstr and Assuming Ops ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/Shape/IR/Shape.h"
13 #include "mlir/Transforms/DialectConversion.h"
14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15 
16 namespace mlir {
17 #define GEN_PASS_DEF_REMOVESHAPECONSTRAINTS
18 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
19 } // namespace mlir
20 
21 using namespace mlir;
22 
23 namespace {
24 /// Removal patterns.
25 class RemoveCstrBroadcastableOp
26     : public OpRewritePattern<shape::CstrBroadcastableOp> {
27 public:
28   using OpRewritePattern::OpRewritePattern;
29 
30   LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
31                                 PatternRewriter &rewriter) const override {
32     rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
33     return success();
34   }
35 };
36 
37 class RemoveCstrEqOp : public OpRewritePattern<shape::CstrEqOp> {
38 public:
39   using OpRewritePattern::OpRewritePattern;
40 
41   LogicalResult matchAndRewrite(shape::CstrEqOp op,
42                                 PatternRewriter &rewriter) const override {
43     rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
44     return success();
45   }
46 };
47 
48 /// Removal pass.
49 class RemoveShapeConstraintsPass
50     : public impl::RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> {
51 
52   void runOnOperation() override {
53     MLIRContext &ctx = getContext();
54 
55     RewritePatternSet patterns(&ctx);
56     populateRemoveShapeConstraintsPatterns(patterns);
57 
58     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
59   }
60 };
61 
62 } // namespace
63 
64 void mlir::populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns) {
65   patterns.add<RemoveCstrBroadcastableOp, RemoveCstrEqOp>(
66       patterns.getContext());
67 }
68 
69 std::unique_ptr<OperationPass<func::FuncOp>>
70 mlir::createRemoveShapeConstraintsPass() {
71   return std::make_unique<RemoveShapeConstraintsPass>();
72 }
73