xref: /llvm-project/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
//===-- RemoveShapeConstraints.cpp - Remove Shape Cstr and Assuming Ops ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
833245988STres Popp 
9039b969bSMichele Scuttari #include "mlir/Dialect/Shape/Transforms/Passes.h"
1067d0d7acSMichele Scuttari 
1167d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h"
1267d0d7acSMichele Scuttari #include "mlir/Dialect/Shape/IR/Shape.h"
1333245988STres Popp #include "mlir/Transforms/DialectConversion.h"
14b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1533245988STres Popp 
1667d0d7acSMichele Scuttari namespace mlir {
1767d0d7acSMichele Scuttari #define GEN_PASS_DEF_REMOVESHAPECONSTRAINTS
1867d0d7acSMichele Scuttari #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
1967d0d7acSMichele Scuttari } // namespace mlir
2067d0d7acSMichele Scuttari 
2133245988STres Popp using namespace mlir;
2233245988STres Popp 
2333245988STres Popp namespace {
2433245988STres Popp /// Removal patterns.
2533245988STres Popp class RemoveCstrBroadcastableOp
2633245988STres Popp     : public OpRewritePattern<shape::CstrBroadcastableOp> {
2733245988STres Popp public:
2833245988STres Popp   using OpRewritePattern::OpRewritePattern;
2933245988STres Popp 
3033245988STres Popp   LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
3133245988STres Popp                                 PatternRewriter &rewriter) const override {
3233245988STres Popp     rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
3333245988STres Popp     return success();
3433245988STres Popp   }
3533245988STres Popp };
3633245988STres Popp 
3733245988STres Popp class RemoveCstrEqOp : public OpRewritePattern<shape::CstrEqOp> {
3833245988STres Popp public:
3933245988STres Popp   using OpRewritePattern::OpRewritePattern;
4033245988STres Popp 
4133245988STres Popp   LogicalResult matchAndRewrite(shape::CstrEqOp op,
4233245988STres Popp                                 PatternRewriter &rewriter) const override {
4333245988STres Popp     rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
4433245988STres Popp     return success();
4533245988STres Popp   }
4633245988STres Popp };
4733245988STres Popp 
4833245988STres Popp /// Removal pass.
49039b969bSMichele Scuttari class RemoveShapeConstraintsPass
5067d0d7acSMichele Scuttari     : public impl::RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> {
5133245988STres Popp 
5241574554SRiver Riddle   void runOnOperation() override {
5333245988STres Popp     MLIRContext &ctx = getContext();
5433245988STres Popp 
55dc4e913bSChris Lattner     RewritePatternSet patterns(&ctx);
563a506b31SChris Lattner     populateRemoveShapeConstraintsPatterns(patterns);
5733245988STres Popp 
58*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
5933245988STres Popp   }
6033245988STres Popp };
6133245988STres Popp 
6233245988STres Popp } // namespace
6333245988STres Popp 
64dc4e913bSChris Lattner void mlir::populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns) {
65dc4e913bSChris Lattner   patterns.add<RemoveCstrBroadcastableOp, RemoveCstrEqOp>(
663a506b31SChris Lattner       patterns.getContext());
6733245988STres Popp }
6833245988STres Popp 
6958ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>>
7041574554SRiver Riddle mlir::createRemoveShapeConstraintsPass() {
7133245988STres Popp   return std::make_unique<RemoveShapeConstraintsPass>();
7233245988STres Popp }
73