xref: /llvm-project/mlir/lib/Dialect/Tensor/Utils/Utils.cpp (revision a24c468782010e17563f6aa93c5bb173c7f873b2)
1fd0c6f53SAlexander Belyaev //===- Utils.cpp - Utilities to support the Tensor dialect ----------------===//
2fd0c6f53SAlexander Belyaev //
3fd0c6f53SAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fd0c6f53SAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
5fd0c6f53SAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fd0c6f53SAlexander Belyaev //
7fd0c6f53SAlexander Belyaev //===----------------------------------------------------------------------===//
8fd0c6f53SAlexander Belyaev //
9fd0c6f53SAlexander Belyaev // This file implements utilities for the Tensor dialect.
10fd0c6f53SAlexander Belyaev //
11fd0c6f53SAlexander Belyaev //===----------------------------------------------------------------------===//
12fd0c6f53SAlexander Belyaev 
13fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Tensor/Utils/Utils.h"
14fd0c6f53SAlexander Belyaev 
15fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Affine/IR/AffineOps.h"
16abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
174dc72d47SNicolas Vasilache #include "mlir/Dialect/Arith/Utils/Utils.h"
182f07d627SNicolas Vasilache #include "mlir/Dialect/Utils/IndexingUtils.h"
19eb6222b9SDanial Klimkin #include "mlir/Dialect/Vector/IR/VectorOps.h"
2026864d8fSMatthias Springer #include "mlir/Interfaces/ValueBoundsOpInterface.h"
21fd0c6f53SAlexander Belyaev 
22fd0c6f53SAlexander Belyaev using namespace mlir;
23fd0c6f53SAlexander Belyaev using namespace mlir::tensor;
24fd0c6f53SAlexander Belyaev 
2566f84c8bSAndrzej Warzyński PadOp mlir::tensor::createPadHighOp(RankedTensorType resType, Value source,
26d26c42afSgysit                                     Value pad, bool nofold, Location loc,
2766f84c8bSAndrzej Warzyński                                     OpBuilder &b,
2866f84c8bSAndrzej Warzyński                                     SmallVector<Value> dynOutDims) {
2966f84c8bSAndrzej Warzyński 
30*a24c4687SAlexander Pivovarov   assert(((resType.getNumDynamicDims() == dynOutDims.size()) ||
31*a24c4687SAlexander Pivovarov           dynOutDims.empty()) &&
3266f84c8bSAndrzej Warzyński          "Either none or all output dynamic dims must be specified!");
3366f84c8bSAndrzej Warzyński 
3466f84c8bSAndrzej Warzyński   // Init "low" and "high" padding values ("low" is kept as is, "high" is
3566f84c8bSAndrzej Warzyński   // computed below).
3666f84c8bSAndrzej Warzyński   SmallVector<OpFoldResult> low(resType.getRank(), b.getIndexAttr(0));
3766f84c8bSAndrzej Warzyński   SmallVector<OpFoldResult> high(resType.getRank(), b.getIndexAttr(0));
3866f84c8bSAndrzej Warzyński 
3966f84c8bSAndrzej Warzyński   size_t outDimIdx = 0;
4066f84c8bSAndrzej Warzyński 
4166f84c8bSAndrzej Warzyński   for (const auto [idx, val] : enumerate(resType.getShape())) {
4266f84c8bSAndrzej Warzyński     bool isDimDynamic = ShapedType::isDynamic(val);
4366f84c8bSAndrzej Warzyński     bool updatePadHigh = !isDimDynamic || !dynOutDims.empty();
4466f84c8bSAndrzej Warzyński 
4566f84c8bSAndrzej Warzyński     // Keep the default padding width (i.e. "0") when the output dim is dynamic
4666f84c8bSAndrzej Warzyński     // and no actual output sizes have been provided.
4766f84c8bSAndrzej Warzyński     if (!updatePadHigh)
48a285ba75SHan-Chung Wang       continue;
4966f84c8bSAndrzej Warzyński 
5066f84c8bSAndrzej Warzyński     // Compute the padding width: resDim - sourceDim.
5166f84c8bSAndrzej Warzyński     AffineExpr d0, d1;
5266f84c8bSAndrzej Warzyński     bindDims(b.getContext(), d0, d1);
5366f84c8bSAndrzej Warzyński     OpFoldResult sourceDim = tensor::getMixedSize(b, loc, source, idx);
5466f84c8bSAndrzej Warzyński     OpFoldResult outDim = isDimDynamic ? OpFoldResult(dynOutDims[outDimIdx++])
5566f84c8bSAndrzej Warzyński                                        : OpFoldResult(b.getIndexAttr(val));
5666f84c8bSAndrzej Warzyński 
5766f84c8bSAndrzej Warzyński     high[idx] = affine::makeComposedFoldedAffineApply(b, loc, d0 - d1,
5866f84c8bSAndrzej Warzyński                                                       {outDim, sourceDim});
59fd0c6f53SAlexander Belyaev   }
6066f84c8bSAndrzej Warzyński   return b.create<PadOp>(loc, resType, source, low, high, pad, nofold);
61fd0c6f53SAlexander Belyaev }
62ff6ce9e8SFrederik Gossen 
63ff6ce9e8SFrederik Gossen SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
64ff6ce9e8SFrederik Gossen                                                         Location loc,
65ff6ce9e8SFrederik Gossen                                                         Value rankedTensor) {
665550c821STres Popp   auto tensorTy = cast<RankedTensorType>(rankedTensor.getType());
67ff6ce9e8SFrederik Gossen   SmallVector<Value> dynamicDims;
68ff6ce9e8SFrederik Gossen   for (const auto &en : llvm::enumerate(tensorTy.getShape())) {
69399638f9SAliia Khasanova     if (en.value() == ShapedType::kDynamic)
70ff6ce9e8SFrederik Gossen       dynamicDims.push_back(
71ff6ce9e8SFrederik Gossen           b.create<tensor::DimOp>(loc, rankedTensor, en.index()));
72ff6ce9e8SFrederik Gossen   }
73ff6ce9e8SFrederik Gossen   return dynamicDims;
74ff6ce9e8SFrederik Gossen }
752c3ca3b6SFrederik Gossen 
762f07d627SNicolas Vasilache FailureOr<RankedTensorType>
772f07d627SNicolas Vasilache mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
782f07d627SNicolas Vasilache                                     ArrayRef<int64_t> transposeVector) {
792f07d627SNicolas Vasilache   if (transposeVector.empty())
802f07d627SNicolas Vasilache     return rankedTensorType;
811445c11fSNicolas Vasilache 
822f07d627SNicolas Vasilache   if (!isPermutationVector(transposeVector) ||
832f07d627SNicolas Vasilache       transposeVector.size() != static_cast<size_t>(rankedTensorType.getRank()))
842f07d627SNicolas Vasilache     return failure();
852f07d627SNicolas Vasilache 
865262865aSKazu Hirata   SmallVector<int64_t> transposedShape(rankedTensorType.getShape());
872f07d627SNicolas Vasilache   applyPermutationToVector(transposedShape, transposeVector);
882f07d627SNicolas Vasilache 
892f07d627SNicolas Vasilache   using RTTBuilder = RankedTensorType::Builder;
902f07d627SNicolas Vasilache   RankedTensorType transposedTensorType =
912f07d627SNicolas Vasilache       RTTBuilder(rankedTensorType).setShape(transposedShape);
922f07d627SNicolas Vasilache   return transposedTensorType;
932f07d627SNicolas Vasilache }
942db190fdSHan-Chung Wang 
95adf838daSBalaji V. Iyer /// The permutation can be obtained from two permutations:
96adf838daSBalaji V. Iyer ///   a) Compute the permutation vector to move the last `numPackedDims` into
97adf838daSBalaji V. Iyer ///      the `innerPosDims` of a shape of rank `rank`.
98adf838daSBalaji V. Iyer ///   b) Compute the permutation vector to move outer dims if the
99adf838daSBalaji V. Iyer ///      `outerPerm` parameter is not empty.
100adf838daSBalaji V. Iyer /// Apply (b) permutation on (a) permutation to get the final permutation.
101adf838daSBalaji V. Iyer static SmallVector<int64_t>
102adf838daSBalaji V. Iyer computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
103adf838daSBalaji V. Iyer                       ArrayRef<int64_t> &outerPerm,
104adf838daSBalaji V. Iyer                       PackingMetadata &packingMetadata) {
105adf838daSBalaji V. Iyer   int64_t numPackedDims = innerDimsPos.size();
106adf838daSBalaji V. Iyer   auto lastDims =
107adf838daSBalaji V. Iyer       llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
108adf838daSBalaji V. Iyer   packingMetadata = computePackingMetadata(rank, innerDimsPos);
109adf838daSBalaji V. Iyer   SmallVector<int64_t> innerPositionsPerm =
110adf838daSBalaji V. Iyer       computePermutationVector(rank, lastDims, packingMetadata.insertPositions);
1117880b2c8SMax191 
1127880b2c8SMax191   SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
1137880b2c8SMax191   if (!outerPerm.empty())
1147880b2c8SMax191     applyPermutationToVector(outerPos, outerPerm);
115adf838daSBalaji V. Iyer   SmallVector<int64_t> outerPositionPerm =
116adf838daSBalaji V. Iyer       computePermutationVector(rank, packingMetadata.outerPositions, outerPos);
1177880b2c8SMax191 
1187880b2c8SMax191   SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
1197880b2c8SMax191   applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
1207880b2c8SMax191   return packInverseDestPermutation;
1217880b2c8SMax191 }
1227880b2c8SMax191 
123adf838daSBalaji V. Iyer SmallVector<int64_t> mlir::tensor::getPackInverseDestPerm(PackOp packOp) {
124adf838daSBalaji V. Iyer 
125adf838daSBalaji V. Iyer   PackingMetadata pMetadata;
126adf838daSBalaji V. Iyer   int64_t packedRank = packOp.getDestType().getRank();
127adf838daSBalaji V. Iyer   ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
128adf838daSBalaji V. Iyer   ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
129adf838daSBalaji V. Iyer   SmallVector<int64_t> packInvDestPerm =
130adf838daSBalaji V. Iyer       computePackUnPackPerm(packedRank, innerDimPos, outerPerm, pMetadata);
131adf838daSBalaji V. Iyer   return packInvDestPerm;
132adf838daSBalaji V. Iyer }
133adf838daSBalaji V. Iyer 
134adf838daSBalaji V. Iyer SmallVector<int64_t> mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp) {
135adf838daSBalaji V. Iyer   PackingMetadata metadata;
136adf838daSBalaji V. Iyer   return mlir::tensor::getUnPackInverseSrcPerm(unpackOp, metadata);
137adf838daSBalaji V. Iyer }
138adf838daSBalaji V. Iyer 
139adf838daSBalaji V. Iyer SmallVector<int64_t>
140adf838daSBalaji V. Iyer mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp,
141adf838daSBalaji V. Iyer                                       PackingMetadata &metadata) {
142adf838daSBalaji V. Iyer   int64_t unpackRank = unpackOp.getSourceType().getRank();
143adf838daSBalaji V. Iyer   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
144adf838daSBalaji V. Iyer   ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
145adf838daSBalaji V. Iyer   SmallVector<int64_t> unpackInvSrcPerm =
146adf838daSBalaji V. Iyer       computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata);
147adf838daSBalaji V. Iyer   return unpackInvSrcPerm;
148adf838daSBalaji V. Iyer }
149adf838daSBalaji V. Iyer 
15026864d8fSMatthias Springer bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
15126864d8fSMatthias Springer   llvm::SmallBitVector droppedDims = op.getDroppedDims();
15226864d8fSMatthias Springer   int64_t srcDim = 0;
153f566b079SJerry Wu   RankedTensorType resultType = op.getDestType();
15426864d8fSMatthias Springer   // Source dims and destination dims (apart from dropped dims) must have the
15526864d8fSMatthias Springer   // same size.
156f566b079SJerry Wu   for (int64_t resultDim = 0; resultDim < resultType.getRank(); ++resultDim) {
15726864d8fSMatthias Springer     if (droppedDims.test(resultDim)) {
158f566b079SJerry Wu       // InsertSlice may expand unit dimensions that result from inserting a
159f566b079SJerry Wu       // size-1 slice into a non-size-1 result dimension.
160f566b079SJerry Wu       if (resultType.getDimSize(resultDim) != 1)
161f566b079SJerry Wu         return false;
16226864d8fSMatthias Springer       continue;
16326864d8fSMatthias Springer     }
16426864d8fSMatthias Springer     FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
16540dd3aa9SMatthias Springer         {op.getSource(), srcDim}, {op.getResult(), resultDim});
16626864d8fSMatthias Springer     if (failed(equalDimSize) || !*equalDimSize)
16726864d8fSMatthias Springer       return false;
16826864d8fSMatthias Springer     ++srcDim;
16926864d8fSMatthias Springer   }
17026864d8fSMatthias Springer 
17126864d8fSMatthias Springer   return true;
17226864d8fSMatthias Springer }
17334cf67aeSMatthias Springer 
17434cf67aeSMatthias Springer bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
17534cf67aeSMatthias Springer   llvm::SmallBitVector droppedDims = op.getDroppedDims();
17634cf67aeSMatthias Springer   int64_t resultDim = 0;
17734cf67aeSMatthias Springer   // Source dims and result dims (apart from dropped dims) must have the same
17834cf67aeSMatthias Springer   // size.
1799bd19bb7SChristopher Bate   RankedTensorType sourceType = op.getSourceType();
1809bd19bb7SChristopher Bate   for (int64_t dim = 0, e = sourceType.getRank(); dim < e; ++dim) {
18134cf67aeSMatthias Springer     if (droppedDims.test(dim)) {
1829bd19bb7SChristopher Bate       // ExtractSlice may drop unit dimensions that result from taking a size-1
1839bd19bb7SChristopher Bate       // slice from a non-size-1 source dimension.
1849bd19bb7SChristopher Bate       if (sourceType.getDimSize(dim) != 1)
1859bd19bb7SChristopher Bate         return false;
18634cf67aeSMatthias Springer       continue;
18734cf67aeSMatthias Springer     }
18834cf67aeSMatthias Springer     FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
18940dd3aa9SMatthias Springer         {op.getSource(), dim}, {op.getResult(), resultDim});
19034cf67aeSMatthias Springer     if (failed(equalDimSize) || !*equalDimSize)
19134cf67aeSMatthias Springer       return false;
19234cf67aeSMatthias Springer     ++resultDim;
19334cf67aeSMatthias Springer   }
19434cf67aeSMatthias Springer 
19534cf67aeSMatthias Springer   return true;
19634cf67aeSMatthias Springer }
197