1 //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h" 10 11 #include "mlir/Dialect/Tensor/IR/Tensor.h" 12 #include "mlir/Interfaces/ValueBoundsOpInterface.h" 13 14 using namespace mlir; 15 16 namespace mlir { 17 namespace tensor { 18 namespace { 19 20 struct CastOpInterface 21 : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> { 22 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, 23 ValueBoundsConstraintSet &cstr) const { 24 auto castOp = cast<CastOp>(op); 25 assert(value == castOp.getResult() && "invalid value"); 26 27 if (llvm::isa<RankedTensorType>(castOp.getResult().getType()) && 28 llvm::isa<RankedTensorType>(castOp.getSource().getType())) { 29 cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim); 30 } 31 } 32 }; 33 34 struct DimOpInterface 35 : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> { 36 void populateBoundsForIndexValue(Operation *op, Value value, 37 ValueBoundsConstraintSet &cstr) const { 38 auto dimOp = cast<DimOp>(op); 39 assert(value == dimOp.getResult() && "invalid value"); 40 41 cstr.bound(value) >= 0; 42 auto constIndex = dimOp.getConstantIndex(); 43 if (!constIndex.has_value()) 44 return; 45 cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex); 46 } 47 }; 48 49 struct EmptyOpInterface 50 : public ValueBoundsOpInterface::ExternalModel<EmptyOpInterface, EmptyOp> { 51 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, 52 ValueBoundsConstraintSet &cstr) const { 53 auto emptyOp = cast<EmptyOp>(op); 54 assert(value == emptyOp.getResult() && "invalid value"); 55 56 cstr.bound(value)[dim] == emptyOp.getMixedSizes()[dim]; 57 } 58 }; 59 60 struct ExtractSliceOpInterface 61 : public ValueBoundsOpInterface::ExternalModel<ExtractSliceOpInterface, 62 ExtractSliceOp> { 63 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, 64 ValueBoundsConstraintSet &cstr) const { 65 auto extractSliceOp = cast<ExtractSliceOp>(op); 66 assert(value == extractSliceOp.getResult() && "invalid value"); 67 68 llvm::SmallBitVector dropped = extractSliceOp.getDroppedDims(); 69 int64_t ctr = -1; 70 for (int64_t i = 0, e = extractSliceOp.getMixedSizes().size(); i < e; ++i) { 71 // Skip over rank-reduced dimensions. 72 if (!dropped.test(i)) 73 ++ctr; 74 if (ctr == dim) { 75 cstr.bound(value)[dim] == extractSliceOp.getMixedSizes()[i]; 76 return; 77 } 78 } 79 llvm_unreachable("could not find non-rank-reduced dim"); 80 } 81 }; 82 83 struct PadOpInterface 84 : public ValueBoundsOpInterface::ExternalModel<PadOpInterface, PadOp> { 85 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, 86 ValueBoundsConstraintSet &cstr) const { 87 auto padOp = cast<PadOp>(op); 88 assert(value == padOp.getResult() && "invalid value"); 89 90 AffineExpr srcSize = cstr.getExpr(padOp.getSource(), dim); 91 AffineExpr lowPad = cstr.getExpr(padOp.getMixedLowPad()[dim]); 92 AffineExpr highPad = cstr.getExpr(padOp.getMixedHighPad()[dim]); 93 cstr.bound(value)[dim] == srcSize + lowPad + highPad; 94 } 95 }; 96 97 struct RankOpInterface 98 : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> { 99 void populateBoundsForIndexValue(Operation *op, Value value, 100 ValueBoundsConstraintSet &cstr) const { 101 auto rankOp = cast<RankOp>(op); 102 assert(value == rankOp.getResult() && "invalid value"); 103 104 auto tensorType = 105 llvm::dyn_cast<RankedTensorType>(rankOp.getTensor().getType()); 106 if (!tensorType) 107 return; 108 cstr.bound(value) == tensorType.getRank(); 109 } 110 }; 111 112 } // namespace 113 } // namespace tensor 114 } // namespace mlir 115 116 void mlir::tensor::registerValueBoundsOpInterfaceExternalModels( 117 DialectRegistry ®istry) { 118 registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { 119 tensor::CastOp::attachInterface<tensor::CastOpInterface>(*ctx); 120 tensor::DimOp::attachInterface<tensor::DimOpInterface>(*ctx); 121 tensor::EmptyOp::attachInterface<tensor::EmptyOpInterface>(*ctx); 122 tensor::ExtractSliceOp::attachInterface<tensor::ExtractSliceOpInterface>( 123 *ctx); 124 tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx); 125 tensor::RankOp::attachInterface<tensor::RankOpInterface>(*ctx); 126 // Note: ValueBoundsOpInterface implementation is not required for ops that 127 // implement `DestinationStyleOpInterface` (for querying shaped OpResults). 128 }); 129 } 130