19ed1e587SSean Silva //===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===// 29ed1e587SSean Silva // 39ed1e587SSean Silva // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 49ed1e587SSean Silva // See https://llvm.org/LICENSE.txt for license information. 59ed1e587SSean Silva // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 69ed1e587SSean Silva // 79ed1e587SSean Silva //===----------------------------------------------------------------------===// 89ed1e587SSean Silva 99ed1e587SSean Silva #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 109ed1e587SSean Silva 11ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 128b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 139ed1e587SSean Silva #include "mlir/Dialect/Shape/IR/Shape.h" 14444822d7SSean Silva #include "mlir/Dialect/Tensor/IR/Tensor.h" 159ed1e587SSean Silva #include "mlir/IR/PatternMatch.h" 169ed1e587SSean Silva #include "mlir/Pass/Pass.h" 179ed1e587SSean Silva #include "mlir/Pass/PassRegistry.h" 18b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 199ed1e587SSean Silva 2067d0d7acSMichele Scuttari namespace mlir { 2167d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTSHAPECONSTRAINTS 2267d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc" 2367d0d7acSMichele Scuttari } // namespace mlir 2467d0d7acSMichele Scuttari 259ed1e587SSean Silva using namespace mlir; 2667d0d7acSMichele Scuttari 279ed1e587SSean Silva namespace { 283842d4b6STres Popp #include "ShapeToStandard.cpp.inc" 299ed1e587SSean Silva } // namespace 309ed1e587SSean Silva 319ed1e587SSean Silva namespace { 329ed1e587SSean Silva class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> { 339ed1e587SSean Silva public: 349ed1e587SSean Silva using OpRewritePattern::OpRewritePattern; 359ed1e587SSean Silva LogicalResult matchAndRewrite(shape::CstrRequireOp op, 369ed1e587SSean Silva PatternRewriter &rewriter) const override { 37ace01605SRiver Riddle rewriter.create<cf::AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr()); 389ed1e587SSean Silva rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true); 399ed1e587SSean Silva return success(); 409ed1e587SSean Silva } 419ed1e587SSean Silva }; 429ed1e587SSean Silva } // namespace 439ed1e587SSean Silva 449ed1e587SSean Silva void mlir::populateConvertShapeConstraintsConversionPatterns( 45dc4e913bSChris Lattner RewritePatternSet &patterns) { 46dc4e913bSChris Lattner patterns.add<CstrBroadcastableToRequire>(patterns.getContext()); 47dc4e913bSChris Lattner patterns.add<CstrEqToRequire>(patterns.getContext()); 48dc4e913bSChris Lattner patterns.add<ConvertCstrRequireOp>(patterns.getContext()); 499ed1e587SSean Silva } 509ed1e587SSean Silva 519ed1e587SSean Silva namespace { 529ed1e587SSean Silva // This pass eliminates shape constraints from the program, converting them to 539ed1e587SSean Silva // eager (side-effecting) error handling code. After eager error handling code 549ed1e587SSean Silva // is emitted, witnesses are satisfied, so they are replace with 559ed1e587SSean Silva // `shape.const_witness true`. 56039b969bSMichele Scuttari class ConvertShapeConstraints 5767d0d7acSMichele Scuttari : public impl::ConvertShapeConstraintsBase<ConvertShapeConstraints> { 582d128b04SRahul Joshi void runOnOperation() override { 59ceefc261SMehdi Amini auto *func = getOperation(); 609ed1e587SSean Silva auto *context = &getContext(); 619ed1e587SSean Silva 62dc4e913bSChris Lattner RewritePatternSet patterns(context); 633a506b31SChris Lattner populateConvertShapeConstraintsConversionPatterns(patterns); 649ed1e587SSean Silva 65*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(func, std::move(patterns)))) 669ed1e587SSean Silva return signalPassFailure(); 679ed1e587SSean Silva } 689ed1e587SSean Silva }; 699ed1e587SSean Silva } // namespace 70039b969bSMichele Scuttari 71039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createConvertShapeConstraintsPass() { 72039b969bSMichele Scuttari return std::make_unique<ConvertShapeConstraints>(); 73039b969bSMichele Scuttari } 74