xref: /llvm-project/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp (revision 092372da15e5165be14cdbb7cac3cf4976fd82d0)
1 //===- InferTypeOpImpl.cpp - InferType Interface external models *- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
10 #include "mlir/Dialect/Affine/IR/AffineOps.h"
11 #include "mlir/Dialect/Arith/Utils/Utils.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/Dialect/Utils/StaticValueUtils.h"
14 #include "mlir/Interfaces/InferTypeOpInterface.h"
15 
16 using namespace mlir;
17 using namespace mlir::tensor;
18 
19 /// For reshape op compute the shape at dimension `dimIndex` of the output in
20 /// terms of shape of the `src`, when the reshape op is a collapsing
21 /// operation. It is the product of the shape of the collapsed dimensions of the
22 /// `src`.
23 static OpFoldResult getCollapsedOutputDimFromInputShape(
24     OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
25     ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociationMap) {
26   if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
27     // Static dimension: return Attribute.
28     return builder.getIndexAttr(dstStaticShape[dimIndex]);
29   }
30   AffineMap map = reassociationMap[dimIndex];
31   unsigned startPos =
32       cast<AffineDimExpr>(map.getResults().front()).getPosition();
33   unsigned endPos = cast<AffineDimExpr>(map.getResults().back()).getPosition();
34   AffineExpr expr;
35   SmallVector<OpFoldResult> dynamicDims;
36   for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
37     dynamicDims.push_back(builder.createOrFold<tensor::DimOp>(loc, src, dim));
38     AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
39     expr = (expr ? expr * currExpr : currExpr);
40   }
41 
42   // Dynamic dimension: return Value.
43   return affine::makeComposedAffineApply(
44              builder, loc, AffineMap::get(0, endPos - startPos + 1, expr),
45              dynamicDims)
46       ->getResult(0);
47 }
48 
49 /// Given the `src` of a collapsing reshape op and its reassociation maps,
50 /// compute the shape of the result of the reshape.
51 static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
52     OpBuilder &builder, Location loc, Value src,
53     ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
54   return llvm::to_vector<4>(llvm::map_range(
55       llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
56         return getCollapsedOutputDimFromInputShape(
57             builder, loc, dim, src, dstStaticShape, reassociation);
58       }));
59 }
60 
61 struct ReifyCollapseShapeOp
62     : public ReifyRankedShapedTypeOpInterface::ExternalModel<
63           ReifyCollapseShapeOp, CollapseShapeOp> {
64   LogicalResult
65   reifyResultShapes(Operation *op, OpBuilder &b,
66                     ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
67     auto loc = op->getLoc();
68     auto reshapeOp = cast<tensor::CollapseShapeOp>(op);
69     reifiedReturnShapes.push_back(getCollapsedOutputShapeFromInputShape(
70         b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
71         reshapeOp.getReassociationMaps()));
72     return success();
73   }
74 };
75 
76 namespace {
77 
78 struct ReifyExpandShapeOp
79     : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
80                                                              ExpandShapeOp> {
81   LogicalResult
82   reifyResultShapes(Operation *op, OpBuilder &b,
83                     ReifiedRankedShapedTypeDims &reifyResultShapes) const {
84     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
85     SmallVector<OpFoldResult> resultShapes =
86         expandShapeOp.getMixedOutputShape();
87     reifyResultShapes.emplace_back(std::move(resultShapes));
88     return success();
89   }
90 };
91 
92 struct ReifyPadOp
93     : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
94                                                              PadOp> {
95   LogicalResult
96   reifyResultShapes(Operation *op, OpBuilder &b,
97                     ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
98     auto padOp = cast<PadOp>(op);
99     Location loc = padOp.getLoc();
100     auto lowPad = padOp.getMixedLowPad();
101     auto highPad = padOp.getMixedHighPad();
102     SmallVector<OpFoldResult> shapes;
103     for (auto dim : llvm::seq<int64_t>(0, padOp.getSourceType().getRank())) {
104       if (!padOp.getResultType().isDynamicDim(dim)) {
105         shapes.push_back(b.getIndexAttr(padOp.getResultType().getDimSize(dim)));
106         continue;
107       }
108 
109       // Shape along each dimension is source dim + low pad + high pad.
110       SmallVector<OpFoldResult> mapOperands;
111       mapOperands.push_back(
112           b.createOrFold<tensor::DimOp>(loc, padOp.getSource(), dim));
113       mapOperands.push_back(lowPad[dim]);
114       mapOperands.push_back(highPad[dim]);
115       AffineExpr expr = b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0) +
116                         b.getAffineSymbolExpr(1);
117       shapes.push_back(getValueOrCreateConstantIndexOp(
118           b, loc,
119           affine::makeComposedFoldedAffineApply(
120               b, loc, AffineMap::get(1, 2, expr), mapOperands)));
121     }
122     reifiedReturnShapes.emplace_back(std::move(shapes));
123     return success();
124   }
125 };
126 
127 } // namespace
128 
129 void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
130     DialectRegistry &registry) {
131   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
132     ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
133     CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
134     PadOp::attachInterface<ReifyPadOp>(*ctx);
135   });
136 }
137