133245988STres Popp //===-- RemoveShapeConstraints.cpp - Remove Shape Cstr and Assuming Ops ---===// 233245988STres Popp // 333245988STres Popp // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 433245988STres Popp // See https://llvm.org/LICENSE.txt for license information. 533245988STres Popp // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 633245988STres Popp // 733245988STres Popp //===----------------------------------------------------------------------===// 833245988STres Popp 9039b969bSMichele Scuttari #include "mlir/Dialect/Shape/Transforms/Passes.h" 1067d0d7acSMichele Scuttari 1167d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h" 1267d0d7acSMichele Scuttari #include "mlir/Dialect/Shape/IR/Shape.h" 1333245988STres Popp #include "mlir/Transforms/DialectConversion.h" 14b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 1533245988STres Popp 1667d0d7acSMichele Scuttari namespace mlir { 1767d0d7acSMichele Scuttari #define GEN_PASS_DEF_REMOVESHAPECONSTRAINTS 1867d0d7acSMichele Scuttari #include "mlir/Dialect/Shape/Transforms/Passes.h.inc" 1967d0d7acSMichele Scuttari } // namespace mlir 2067d0d7acSMichele Scuttari 2133245988STres Popp using namespace mlir; 2233245988STres Popp 2333245988STres Popp namespace { 2433245988STres Popp /// Removal patterns. 2533245988STres Popp class RemoveCstrBroadcastableOp 2633245988STres Popp : public OpRewritePattern<shape::CstrBroadcastableOp> { 2733245988STres Popp public: 2833245988STres Popp using OpRewritePattern::OpRewritePattern; 2933245988STres Popp 3033245988STres Popp LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op, 3133245988STres Popp PatternRewriter &rewriter) const override { 3233245988STres Popp rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true); 3333245988STres Popp return success(); 3433245988STres Popp } 3533245988STres Popp }; 3633245988STres Popp 3733245988STres Popp class RemoveCstrEqOp : public OpRewritePattern<shape::CstrEqOp> { 3833245988STres Popp public: 3933245988STres Popp using OpRewritePattern::OpRewritePattern; 4033245988STres Popp 4133245988STres Popp LogicalResult matchAndRewrite(shape::CstrEqOp op, 4233245988STres Popp PatternRewriter &rewriter) const override { 4333245988STres Popp rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true); 4433245988STres Popp return success(); 4533245988STres Popp } 4633245988STres Popp }; 4733245988STres Popp 4833245988STres Popp /// Removal pass. 49039b969bSMichele Scuttari class RemoveShapeConstraintsPass 5067d0d7acSMichele Scuttari : public impl::RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> { 5133245988STres Popp 5241574554SRiver Riddle void runOnOperation() override { 5333245988STres Popp MLIRContext &ctx = getContext(); 5433245988STres Popp 55dc4e913bSChris Lattner RewritePatternSet patterns(&ctx); 563a506b31SChris Lattner populateRemoveShapeConstraintsPatterns(patterns); 5733245988STres Popp 58*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 5933245988STres Popp } 6033245988STres Popp }; 6133245988STres Popp 6233245988STres Popp } // namespace 6333245988STres Popp 64dc4e913bSChris Lattner void mlir::populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns) { 65dc4e913bSChris Lattner patterns.add<RemoveCstrBroadcastableOp, RemoveCstrEqOp>( 663a506b31SChris Lattner patterns.getContext()); 6733245988STres Popp } 6833245988STres Popp 6958ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>> 7041574554SRiver Riddle mlir::createRemoveShapeConstraintsPass() { 7133245988STres Popp return std::make_unique<RemoveShapeConstraintsPass>(); 7233245988STres Popp } 73