1 //===- InferTypeOpImpl.cpp - InferType Interface external models *- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" 10 #include "mlir/Dialect/Affine/IR/AffineOps.h" 11 #include "mlir/Dialect/Arith/Utils/Utils.h" 12 #include "mlir/Dialect/Tensor/IR/Tensor.h" 13 #include "mlir/Dialect/Utils/StaticValueUtils.h" 14 #include "mlir/Interfaces/InferTypeOpInterface.h" 15 16 using namespace mlir; 17 using namespace mlir::tensor; 18 19 /// For reshape op compute the shape at dimension `dimIndex` of the output in 20 /// terms of shape of the `src`, when the reshape op is a collapsing 21 /// operation. It is the product of the shape of the collapsed dimensions of the 22 /// `src`. 23 static OpFoldResult getCollapsedOutputDimFromInputShape( 24 OpBuilder &builder, Location loc, int64_t dimIndex, Value src, 25 ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociationMap) { 26 if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) { 27 // Static dimension: return Attribute. 28 return builder.getIndexAttr(dstStaticShape[dimIndex]); 29 } 30 AffineMap map = reassociationMap[dimIndex]; 31 unsigned startPos = 32 cast<AffineDimExpr>(map.getResults().front()).getPosition(); 33 unsigned endPos = cast<AffineDimExpr>(map.getResults().back()).getPosition(); 34 AffineExpr expr; 35 SmallVector<OpFoldResult> dynamicDims; 36 for (auto dim : llvm::seq_inclusive(startPos, endPos)) { 37 dynamicDims.push_back(builder.createOrFold<tensor::DimOp>(loc, src, dim)); 38 AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos); 39 expr = (expr ? expr * currExpr : currExpr); 40 } 41 42 // Dynamic dimension: return Value. 43 return affine::makeComposedAffineApply( 44 builder, loc, AffineMap::get(0, endPos - startPos + 1, expr), 45 dynamicDims) 46 ->getResult(0); 47 } 48 49 /// Given the `src` of a collapsing reshape op and its reassociation maps, 50 /// compute the shape of the result of the reshape. 51 static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape( 52 OpBuilder &builder, Location loc, Value src, 53 ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) { 54 return llvm::to_vector<4>(llvm::map_range( 55 llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) { 56 return getCollapsedOutputDimFromInputShape( 57 builder, loc, dim, src, dstStaticShape, reassociation); 58 })); 59 } 60 61 struct ReifyCollapseShapeOp 62 : public ReifyRankedShapedTypeOpInterface::ExternalModel< 63 ReifyCollapseShapeOp, CollapseShapeOp> { 64 LogicalResult 65 reifyResultShapes(Operation *op, OpBuilder &b, 66 ReifiedRankedShapedTypeDims &reifiedReturnShapes) const { 67 auto loc = op->getLoc(); 68 auto reshapeOp = cast<tensor::CollapseShapeOp>(op); 69 reifiedReturnShapes.push_back(getCollapsedOutputShapeFromInputShape( 70 b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(), 71 reshapeOp.getReassociationMaps())); 72 return success(); 73 } 74 }; 75 76 namespace { 77 78 struct ReifyExpandShapeOp 79 : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp, 80 ExpandShapeOp> { 81 LogicalResult 82 reifyResultShapes(Operation *op, OpBuilder &b, 83 ReifiedRankedShapedTypeDims &reifyResultShapes) const { 84 auto expandShapeOp = cast<tensor::ExpandShapeOp>(op); 85 SmallVector<OpFoldResult> resultShapes = 86 expandShapeOp.getMixedOutputShape(); 87 reifyResultShapes.emplace_back(std::move(resultShapes)); 88 return success(); 89 } 90 }; 91 92 struct ReifyPadOp 93 : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp, 94 PadOp> { 95 LogicalResult 96 reifyResultShapes(Operation *op, OpBuilder &b, 97 ReifiedRankedShapedTypeDims &reifiedReturnShapes) const { 98 auto padOp = cast<PadOp>(op); 99 Location loc = padOp.getLoc(); 100 auto lowPad = padOp.getMixedLowPad(); 101 auto highPad = padOp.getMixedHighPad(); 102 SmallVector<OpFoldResult> shapes; 103 for (auto dim : llvm::seq<int64_t>(0, padOp.getSourceType().getRank())) { 104 if (!padOp.getResultType().isDynamicDim(dim)) { 105 shapes.push_back(b.getIndexAttr(padOp.getResultType().getDimSize(dim))); 106 continue; 107 } 108 109 // Shape along each dimension is source dim + low pad + high pad. 110 SmallVector<OpFoldResult> mapOperands; 111 mapOperands.push_back( 112 b.createOrFold<tensor::DimOp>(loc, padOp.getSource(), dim)); 113 mapOperands.push_back(lowPad[dim]); 114 mapOperands.push_back(highPad[dim]); 115 AffineExpr expr = b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0) + 116 b.getAffineSymbolExpr(1); 117 shapes.push_back(getValueOrCreateConstantIndexOp( 118 b, loc, 119 affine::makeComposedFoldedAffineApply( 120 b, loc, AffineMap::get(1, 2, expr), mapOperands))); 121 } 122 reifiedReturnShapes.emplace_back(std::move(shapes)); 123 return success(); 124 } 125 }; 126 127 } // namespace 128 129 void mlir::tensor::registerInferTypeOpInterfaceExternalModels( 130 DialectRegistry ®istry) { 131 registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) { 132 ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx); 133 CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx); 134 PadOp::attachInterface<ReifyPadOp>(*ctx); 135 }); 136 } 137