xref: /llvm-project/mlir/lib/Dialect/Tensor/Utils/Utils.cpp (revision a24c468782010e17563f6aa93c5bb173c7f873b2)
1 //===- Utils.cpp - Utilities to support the Tensor dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the Tensor dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Tensor/Utils/Utils.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Arith/Utils/Utils.h"
18 #include "mlir/Dialect/Utils/IndexingUtils.h"
19 #include "mlir/Dialect/Vector/IR/VectorOps.h"
20 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
21 
22 using namespace mlir;
23 using namespace mlir::tensor;
24 
25 PadOp mlir::tensor::createPadHighOp(RankedTensorType resType, Value source,
26                                     Value pad, bool nofold, Location loc,
27                                     OpBuilder &b,
28                                     SmallVector<Value> dynOutDims) {
29 
30   assert(((resType.getNumDynamicDims() == dynOutDims.size()) ||
31           dynOutDims.empty()) &&
32          "Either none or all output dynamic dims must be specified!");
33 
34   // Init "low" and "high" padding values ("low" is kept as is, "high" is
35   // computed below).
36   SmallVector<OpFoldResult> low(resType.getRank(), b.getIndexAttr(0));
37   SmallVector<OpFoldResult> high(resType.getRank(), b.getIndexAttr(0));
38 
39   size_t outDimIdx = 0;
40 
41   for (const auto [idx, val] : enumerate(resType.getShape())) {
42     bool isDimDynamic = ShapedType::isDynamic(val);
43     bool updatePadHigh = !isDimDynamic || !dynOutDims.empty();
44 
45     // Keep the default padding width (i.e. "0") when the output dim is dynamic
46     // and no actual output sizes have been provided.
47     if (!updatePadHigh)
48       continue;
49 
50     // Compute the padding width: resDim - sourceDim.
51     AffineExpr d0, d1;
52     bindDims(b.getContext(), d0, d1);
53     OpFoldResult sourceDim = tensor::getMixedSize(b, loc, source, idx);
54     OpFoldResult outDim = isDimDynamic ? OpFoldResult(dynOutDims[outDimIdx++])
55                                        : OpFoldResult(b.getIndexAttr(val));
56 
57     high[idx] = affine::makeComposedFoldedAffineApply(b, loc, d0 - d1,
58                                                       {outDim, sourceDim});
59   }
60   return b.create<PadOp>(loc, resType, source, low, high, pad, nofold);
61 }
62 
63 SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
64                                                         Location loc,
65                                                         Value rankedTensor) {
66   auto tensorTy = cast<RankedTensorType>(rankedTensor.getType());
67   SmallVector<Value> dynamicDims;
68   for (const auto &en : llvm::enumerate(tensorTy.getShape())) {
69     if (en.value() == ShapedType::kDynamic)
70       dynamicDims.push_back(
71           b.create<tensor::DimOp>(loc, rankedTensor, en.index()));
72   }
73   return dynamicDims;
74 }
75 
76 FailureOr<RankedTensorType>
77 mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
78                                     ArrayRef<int64_t> transposeVector) {
79   if (transposeVector.empty())
80     return rankedTensorType;
81 
82   if (!isPermutationVector(transposeVector) ||
83       transposeVector.size() != static_cast<size_t>(rankedTensorType.getRank()))
84     return failure();
85 
86   SmallVector<int64_t> transposedShape(rankedTensorType.getShape());
87   applyPermutationToVector(transposedShape, transposeVector);
88 
89   using RTTBuilder = RankedTensorType::Builder;
90   RankedTensorType transposedTensorType =
91       RTTBuilder(rankedTensorType).setShape(transposedShape);
92   return transposedTensorType;
93 }
94 
95 /// The permutation can be obtained from two permutations:
96 ///   a) Compute the permutation vector to move the last `numPackedDims` into
97 ///      the `innerPosDims` of a shape of rank `rank`.
98 ///   b) Compute the permutation vector to move outer dims if the
99 ///      `outerPerm` parameter is not empty.
100 /// Apply (b) permutation on (a) permutation to get the final permutation.
101 static SmallVector<int64_t>
102 computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
103                       ArrayRef<int64_t> &outerPerm,
104                       PackingMetadata &packingMetadata) {
105   int64_t numPackedDims = innerDimsPos.size();
106   auto lastDims =
107       llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
108   packingMetadata = computePackingMetadata(rank, innerDimsPos);
109   SmallVector<int64_t> innerPositionsPerm =
110       computePermutationVector(rank, lastDims, packingMetadata.insertPositions);
111 
112   SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
113   if (!outerPerm.empty())
114     applyPermutationToVector(outerPos, outerPerm);
115   SmallVector<int64_t> outerPositionPerm =
116       computePermutationVector(rank, packingMetadata.outerPositions, outerPos);
117 
118   SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
119   applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
120   return packInverseDestPermutation;
121 }
122 
123 SmallVector<int64_t> mlir::tensor::getPackInverseDestPerm(PackOp packOp) {
124 
125   PackingMetadata pMetadata;
126   int64_t packedRank = packOp.getDestType().getRank();
127   ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
128   ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
129   SmallVector<int64_t> packInvDestPerm =
130       computePackUnPackPerm(packedRank, innerDimPos, outerPerm, pMetadata);
131   return packInvDestPerm;
132 }
133 
134 SmallVector<int64_t> mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp) {
135   PackingMetadata metadata;
136   return mlir::tensor::getUnPackInverseSrcPerm(unpackOp, metadata);
137 }
138 
139 SmallVector<int64_t>
140 mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp,
141                                       PackingMetadata &metadata) {
142   int64_t unpackRank = unpackOp.getSourceType().getRank();
143   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
144   ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
145   SmallVector<int64_t> unpackInvSrcPerm =
146       computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata);
147   return unpackInvSrcPerm;
148 }
149 
150 bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
151   llvm::SmallBitVector droppedDims = op.getDroppedDims();
152   int64_t srcDim = 0;
153   RankedTensorType resultType = op.getDestType();
154   // Source dims and destination dims (apart from dropped dims) must have the
155   // same size.
156   for (int64_t resultDim = 0; resultDim < resultType.getRank(); ++resultDim) {
157     if (droppedDims.test(resultDim)) {
158       // InsertSlice may expand unit dimensions that result from inserting a
159       // size-1 slice into a non-size-1 result dimension.
160       if (resultType.getDimSize(resultDim) != 1)
161         return false;
162       continue;
163     }
164     FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
165         {op.getSource(), srcDim}, {op.getResult(), resultDim});
166     if (failed(equalDimSize) || !*equalDimSize)
167       return false;
168     ++srcDim;
169   }
170 
171   return true;
172 }
173 
174 bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
175   llvm::SmallBitVector droppedDims = op.getDroppedDims();
176   int64_t resultDim = 0;
177   // Source dims and result dims (apart from dropped dims) must have the same
178   // size.
179   RankedTensorType sourceType = op.getSourceType();
180   for (int64_t dim = 0, e = sourceType.getRank(); dim < e; ++dim) {
181     if (droppedDims.test(dim)) {
182       // ExtractSlice may drop unit dimensions that result from taking a size-1
183       // slice from a non-size-1 source dimension.
184       if (sourceType.getDimSize(dim) != 1)
185         return false;
186       continue;
187     }
188     FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
189         {op.getSource(), dim}, {op.getResult(), resultDim});
190     if (failed(equalDimSize) || !*equalDimSize)
191       return false;
192     ++resultDim;
193   }
194 
195   return true;
196 }
197