1 //===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 12 #include "mlir/Dialect/SCF/IR/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Pass/PassRegistry.h" 18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19 20 namespace mlir { 21 #define GEN_PASS_DEF_CONVERTSHAPECONSTRAINTS 22 #include "mlir/Conversion/Passes.h.inc" 23 } // namespace mlir 24 25 using namespace mlir; 26 27 namespace { 28 #include "ShapeToStandard.cpp.inc" 29 } // namespace 30 31 namespace { 32 class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> { 33 public: 34 using OpRewritePattern::OpRewritePattern; 35 LogicalResult matchAndRewrite(shape::CstrRequireOp op, 36 PatternRewriter &rewriter) const override { 37 rewriter.create<cf::AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr()); 38 rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true); 39 return success(); 40 } 41 }; 42 } // namespace 43 44 void mlir::populateConvertShapeConstraintsConversionPatterns( 45 RewritePatternSet &patterns) { 46 patterns.add<CstrBroadcastableToRequire>(patterns.getContext()); 47 patterns.add<CstrEqToRequire>(patterns.getContext()); 48 patterns.add<ConvertCstrRequireOp>(patterns.getContext()); 49 } 50 51 namespace { 52 // This pass eliminates shape constraints from the program, converting them to 53 // eager (side-effecting) error handling code. After eager error handling code 54 // is emitted, witnesses are satisfied, so they are replace with 55 // `shape.const_witness true`. 56 class ConvertShapeConstraints 57 : public impl::ConvertShapeConstraintsBase<ConvertShapeConstraints> { 58 void runOnOperation() override { 59 auto *func = getOperation(); 60 auto *context = &getContext(); 61 62 RewritePatternSet patterns(context); 63 populateConvertShapeConstraintsConversionPatterns(patterns); 64 65 if (failed(applyPatternsGreedily(func, std::move(patterns)))) 66 return signalPassFailure(); 67 } 68 }; 69 } // namespace 70 71 std::unique_ptr<Pass> mlir::createConvertShapeConstraintsPass() { 72 return std::make_unique<ConvertShapeConstraints>(); 73 } 74