1 //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h" 10 11 #include "mlir/Dialect/MemRef/IR/MemRef.h" 12 #include "mlir/Interfaces/ValueBoundsOpInterface.h" 13 14 using namespace mlir; 15 16 namespace mlir { 17 namespace memref { 18 namespace { 19 20 template <typename OpTy> 21 struct AllocOpInterface 22 : public ValueBoundsOpInterface::ExternalModel<AllocOpInterface<OpTy>, 23 OpTy> { 24 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, 25 ValueBoundsConstraintSet &cstr) const { 26 auto allocOp = cast<OpTy>(op); 27 assert(value == allocOp.getResult() && "invalid value"); 28 29 cstr.bound(value)[dim] == allocOp.getMixedSizes()[dim]; 30 } 31 }; 32 33 struct CastOpInterface 34 : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> { 35 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, 36 ValueBoundsConstraintSet &cstr) const { 37 auto castOp = cast<CastOp>(op); 38 assert(value == castOp.getResult() && "invalid value"); 39 40 if (llvm::isa<MemRefType>(castOp.getResult().getType()) && 41 llvm::isa<MemRefType>(castOp.getSource().getType())) { 42 cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim); 43 } 44 } 45 }; 46 47 struct DimOpInterface 48 : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> { 49 void populateBoundsForIndexValue(Operation *op, Value value, 50 ValueBoundsConstraintSet &cstr) const { 51 auto dimOp = cast<DimOp>(op); 52 assert(value == dimOp.getResult() && "invalid value"); 53 54 cstr.bound(value) >= 0; 55 auto constIndex = dimOp.getConstantIndex(); 56 if (!constIndex.has_value()) 57 return; 58 cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex); 59 } 60 }; 61 62 struct GetGlobalOpInterface 63 : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface, 64 GetGlobalOp> { 65 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, 66 ValueBoundsConstraintSet &cstr) const { 67 auto getGlobalOp = cast<GetGlobalOp>(op); 68 assert(value == getGlobalOp.getResult() && "invalid value"); 69 70 auto type = getGlobalOp.getType(); 71 assert(!type.isDynamicDim(dim) && "expected static dim"); 72 cstr.bound(value)[dim] == type.getDimSize(dim); 73 } 74 }; 75 76 struct RankOpInterface 77 : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> { 78 void populateBoundsForIndexValue(Operation *op, Value value, 79 ValueBoundsConstraintSet &cstr) const { 80 auto rankOp = cast<RankOp>(op); 81 assert(value == rankOp.getResult() && "invalid value"); 82 83 auto memrefType = llvm::dyn_cast<MemRefType>(rankOp.getMemref().getType()); 84 if (!memrefType) 85 return; 86 cstr.bound(value) == memrefType.getRank(); 87 } 88 }; 89 90 struct SubViewOpInterface 91 : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface, 92 SubViewOp> { 93 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, 94 ValueBoundsConstraintSet &cstr) const { 95 auto subViewOp = cast<SubViewOp>(op); 96 assert(value == subViewOp.getResult() && "invalid value"); 97 98 llvm::SmallBitVector dropped = subViewOp.getDroppedDims(); 99 int64_t ctr = -1; 100 for (int64_t i = 0, e = subViewOp.getMixedSizes().size(); i < e; ++i) { 101 // Skip over rank-reduced dimensions. 102 if (!dropped.test(i)) 103 ++ctr; 104 if (ctr == dim) { 105 cstr.bound(value)[dim] == subViewOp.getMixedSizes()[i]; 106 return; 107 } 108 } 109 llvm_unreachable("could not find non-rank-reduced dim"); 110 } 111 }; 112 113 } // namespace 114 } // namespace memref 115 } // namespace mlir 116 117 void mlir::memref::registerValueBoundsOpInterfaceExternalModels( 118 DialectRegistry ®istry) { 119 registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { 120 memref::AllocOp::attachInterface<memref::AllocOpInterface<memref::AllocOp>>( 121 *ctx); 122 memref::AllocaOp::attachInterface< 123 memref::AllocOpInterface<memref::AllocaOp>>(*ctx); 124 memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx); 125 memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx); 126 memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx); 127 memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx); 128 memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx); 129 }); 130 } 131