1fd0c6f53SAlexander Belyaev //===- Utils.cpp - Utilities to support the Tensor dialect ----------------===// 2fd0c6f53SAlexander Belyaev // 3fd0c6f53SAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4fd0c6f53SAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information. 5fd0c6f53SAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6fd0c6f53SAlexander Belyaev // 7fd0c6f53SAlexander Belyaev //===----------------------------------------------------------------------===// 8fd0c6f53SAlexander Belyaev // 9fd0c6f53SAlexander Belyaev // This file implements utilities for the Tensor dialect. 10fd0c6f53SAlexander Belyaev // 11fd0c6f53SAlexander Belyaev //===----------------------------------------------------------------------===// 12fd0c6f53SAlexander Belyaev 13fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Tensor/Utils/Utils.h" 14fd0c6f53SAlexander Belyaev 15fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Affine/IR/AffineOps.h" 16abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 174dc72d47SNicolas Vasilache #include "mlir/Dialect/Arith/Utils/Utils.h" 182f07d627SNicolas Vasilache #include "mlir/Dialect/Utils/IndexingUtils.h" 19eb6222b9SDanial Klimkin #include "mlir/Dialect/Vector/IR/VectorOps.h" 2026864d8fSMatthias Springer #include "mlir/Interfaces/ValueBoundsOpInterface.h" 21fd0c6f53SAlexander Belyaev 22fd0c6f53SAlexander Belyaev using namespace mlir; 23fd0c6f53SAlexander Belyaev using namespace mlir::tensor; 24fd0c6f53SAlexander Belyaev 2566f84c8bSAndrzej Warzyński PadOp mlir::tensor::createPadHighOp(RankedTensorType resType, Value source, 26d26c42afSgysit Value pad, bool nofold, Location loc, 2766f84c8bSAndrzej Warzyński OpBuilder &b, 2866f84c8bSAndrzej Warzyński SmallVector<Value> dynOutDims) { 2966f84c8bSAndrzej Warzyński 30*a24c4687SAlexander Pivovarov assert(((resType.getNumDynamicDims() == dynOutDims.size()) || 31*a24c4687SAlexander Pivovarov dynOutDims.empty()) && 3266f84c8bSAndrzej Warzyński "Either none or all output dynamic dims must be specified!"); 3366f84c8bSAndrzej Warzyński 3466f84c8bSAndrzej Warzyński // Init "low" and "high" padding values ("low" is kept as is, "high" is 3566f84c8bSAndrzej Warzyński // computed below). 3666f84c8bSAndrzej Warzyński SmallVector<OpFoldResult> low(resType.getRank(), b.getIndexAttr(0)); 3766f84c8bSAndrzej Warzyński SmallVector<OpFoldResult> high(resType.getRank(), b.getIndexAttr(0)); 3866f84c8bSAndrzej Warzyński 3966f84c8bSAndrzej Warzyński size_t outDimIdx = 0; 4066f84c8bSAndrzej Warzyński 4166f84c8bSAndrzej Warzyński for (const auto [idx, val] : enumerate(resType.getShape())) { 4266f84c8bSAndrzej Warzyński bool isDimDynamic = ShapedType::isDynamic(val); 4366f84c8bSAndrzej Warzyński bool updatePadHigh = !isDimDynamic || !dynOutDims.empty(); 4466f84c8bSAndrzej Warzyński 4566f84c8bSAndrzej Warzyński // Keep the default padding width (i.e. "0") when the output dim is dynamic 4666f84c8bSAndrzej Warzyński // and no actual output sizes have been provided. 4766f84c8bSAndrzej Warzyński if (!updatePadHigh) 48a285ba75SHan-Chung Wang continue; 4966f84c8bSAndrzej Warzyński 5066f84c8bSAndrzej Warzyński // Compute the padding width: resDim - sourceDim. 5166f84c8bSAndrzej Warzyński AffineExpr d0, d1; 5266f84c8bSAndrzej Warzyński bindDims(b.getContext(), d0, d1); 5366f84c8bSAndrzej Warzyński OpFoldResult sourceDim = tensor::getMixedSize(b, loc, source, idx); 5466f84c8bSAndrzej Warzyński OpFoldResult outDim = isDimDynamic ? OpFoldResult(dynOutDims[outDimIdx++]) 5566f84c8bSAndrzej Warzyński : OpFoldResult(b.getIndexAttr(val)); 5666f84c8bSAndrzej Warzyński 5766f84c8bSAndrzej Warzyński high[idx] = affine::makeComposedFoldedAffineApply(b, loc, d0 - d1, 5866f84c8bSAndrzej Warzyński {outDim, sourceDim}); 59fd0c6f53SAlexander Belyaev } 6066f84c8bSAndrzej Warzyński return b.create<PadOp>(loc, resType, source, low, high, pad, nofold); 61fd0c6f53SAlexander Belyaev } 62ff6ce9e8SFrederik Gossen 63ff6ce9e8SFrederik Gossen SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b, 64ff6ce9e8SFrederik Gossen Location loc, 65ff6ce9e8SFrederik Gossen Value rankedTensor) { 665550c821STres Popp auto tensorTy = cast<RankedTensorType>(rankedTensor.getType()); 67ff6ce9e8SFrederik Gossen SmallVector<Value> dynamicDims; 68ff6ce9e8SFrederik Gossen for (const auto &en : llvm::enumerate(tensorTy.getShape())) { 69399638f9SAliia Khasanova if (en.value() == ShapedType::kDynamic) 70ff6ce9e8SFrederik Gossen dynamicDims.push_back( 71ff6ce9e8SFrederik Gossen b.create<tensor::DimOp>(loc, rankedTensor, en.index())); 72ff6ce9e8SFrederik Gossen } 73ff6ce9e8SFrederik Gossen return dynamicDims; 74ff6ce9e8SFrederik Gossen } 752c3ca3b6SFrederik Gossen 762f07d627SNicolas Vasilache FailureOr<RankedTensorType> 772f07d627SNicolas Vasilache mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType, 782f07d627SNicolas Vasilache ArrayRef<int64_t> transposeVector) { 792f07d627SNicolas Vasilache if (transposeVector.empty()) 802f07d627SNicolas Vasilache return rankedTensorType; 811445c11fSNicolas Vasilache 822f07d627SNicolas Vasilache if (!isPermutationVector(transposeVector) || 832f07d627SNicolas Vasilache transposeVector.size() != static_cast<size_t>(rankedTensorType.getRank())) 842f07d627SNicolas Vasilache return failure(); 852f07d627SNicolas Vasilache 865262865aSKazu Hirata SmallVector<int64_t> transposedShape(rankedTensorType.getShape()); 872f07d627SNicolas Vasilache applyPermutationToVector(transposedShape, transposeVector); 882f07d627SNicolas Vasilache 892f07d627SNicolas Vasilache using RTTBuilder = RankedTensorType::Builder; 902f07d627SNicolas Vasilache RankedTensorType transposedTensorType = 912f07d627SNicolas Vasilache RTTBuilder(rankedTensorType).setShape(transposedShape); 922f07d627SNicolas Vasilache return transposedTensorType; 932f07d627SNicolas Vasilache } 942db190fdSHan-Chung Wang 95adf838daSBalaji V. Iyer /// The permutation can be obtained from two permutations: 96adf838daSBalaji V. Iyer /// a) Compute the permutation vector to move the last `numPackedDims` into 97adf838daSBalaji V. Iyer /// the `innerPosDims` of a shape of rank `rank`. 98adf838daSBalaji V. Iyer /// b) Compute the permutation vector to move outer dims if the 99adf838daSBalaji V. Iyer /// `outerPerm` parameter is not empty. 100adf838daSBalaji V. Iyer /// Apply (b) permutation on (a) permutation to get the final permutation. 101adf838daSBalaji V. Iyer static SmallVector<int64_t> 102adf838daSBalaji V. Iyer computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos, 103adf838daSBalaji V. Iyer ArrayRef<int64_t> &outerPerm, 104adf838daSBalaji V. Iyer PackingMetadata &packingMetadata) { 105adf838daSBalaji V. Iyer int64_t numPackedDims = innerDimsPos.size(); 106adf838daSBalaji V. Iyer auto lastDims = 107adf838daSBalaji V. Iyer llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank)); 108adf838daSBalaji V. Iyer packingMetadata = computePackingMetadata(rank, innerDimsPos); 109adf838daSBalaji V. Iyer SmallVector<int64_t> innerPositionsPerm = 110adf838daSBalaji V. Iyer computePermutationVector(rank, lastDims, packingMetadata.insertPositions); 1117880b2c8SMax191 1127880b2c8SMax191 SmallVector<int64_t> outerPos = packingMetadata.outerPositions; 1137880b2c8SMax191 if (!outerPerm.empty()) 1147880b2c8SMax191 applyPermutationToVector(outerPos, outerPerm); 115adf838daSBalaji V. Iyer SmallVector<int64_t> outerPositionPerm = 116adf838daSBalaji V. Iyer computePermutationVector(rank, packingMetadata.outerPositions, outerPos); 1177880b2c8SMax191 1187880b2c8SMax191 SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm; 1197880b2c8SMax191 applyPermutationToVector(packInverseDestPermutation, outerPositionPerm); 1207880b2c8SMax191 return packInverseDestPermutation; 1217880b2c8SMax191 } 1227880b2c8SMax191 123adf838daSBalaji V. Iyer SmallVector<int64_t> mlir::tensor::getPackInverseDestPerm(PackOp packOp) { 124adf838daSBalaji V. Iyer 125adf838daSBalaji V. Iyer PackingMetadata pMetadata; 126adf838daSBalaji V. Iyer int64_t packedRank = packOp.getDestType().getRank(); 127adf838daSBalaji V. Iyer ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos(); 128adf838daSBalaji V. Iyer ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm(); 129adf838daSBalaji V. Iyer SmallVector<int64_t> packInvDestPerm = 130adf838daSBalaji V. Iyer computePackUnPackPerm(packedRank, innerDimPos, outerPerm, pMetadata); 131adf838daSBalaji V. Iyer return packInvDestPerm; 132adf838daSBalaji V. Iyer } 133adf838daSBalaji V. Iyer 134adf838daSBalaji V. Iyer SmallVector<int64_t> mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp) { 135adf838daSBalaji V. Iyer PackingMetadata metadata; 136adf838daSBalaji V. Iyer return mlir::tensor::getUnPackInverseSrcPerm(unpackOp, metadata); 137adf838daSBalaji V. Iyer } 138adf838daSBalaji V. Iyer 139adf838daSBalaji V. Iyer SmallVector<int64_t> 140adf838daSBalaji V. Iyer mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp, 141adf838daSBalaji V. Iyer PackingMetadata &metadata) { 142adf838daSBalaji V. Iyer int64_t unpackRank = unpackOp.getSourceType().getRank(); 143adf838daSBalaji V. Iyer ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos(); 144adf838daSBalaji V. Iyer ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm(); 145adf838daSBalaji V. Iyer SmallVector<int64_t> unpackInvSrcPerm = 146adf838daSBalaji V. Iyer computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata); 147adf838daSBalaji V. Iyer return unpackInvSrcPerm; 148adf838daSBalaji V. Iyer } 149adf838daSBalaji V. Iyer 15026864d8fSMatthias Springer bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) { 15126864d8fSMatthias Springer llvm::SmallBitVector droppedDims = op.getDroppedDims(); 15226864d8fSMatthias Springer int64_t srcDim = 0; 153f566b079SJerry Wu RankedTensorType resultType = op.getDestType(); 15426864d8fSMatthias Springer // Source dims and destination dims (apart from dropped dims) must have the 15526864d8fSMatthias Springer // same size. 156f566b079SJerry Wu for (int64_t resultDim = 0; resultDim < resultType.getRank(); ++resultDim) { 15726864d8fSMatthias Springer if (droppedDims.test(resultDim)) { 158f566b079SJerry Wu // InsertSlice may expand unit dimensions that result from inserting a 159f566b079SJerry Wu // size-1 slice into a non-size-1 result dimension. 160f566b079SJerry Wu if (resultType.getDimSize(resultDim) != 1) 161f566b079SJerry Wu return false; 16226864d8fSMatthias Springer continue; 16326864d8fSMatthias Springer } 16426864d8fSMatthias Springer FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual( 16540dd3aa9SMatthias Springer {op.getSource(), srcDim}, {op.getResult(), resultDim}); 16626864d8fSMatthias Springer if (failed(equalDimSize) || !*equalDimSize) 16726864d8fSMatthias Springer return false; 16826864d8fSMatthias Springer ++srcDim; 16926864d8fSMatthias Springer } 17026864d8fSMatthias Springer 17126864d8fSMatthias Springer return true; 17226864d8fSMatthias Springer } 17334cf67aeSMatthias Springer 17434cf67aeSMatthias Springer bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) { 17534cf67aeSMatthias Springer llvm::SmallBitVector droppedDims = op.getDroppedDims(); 17634cf67aeSMatthias Springer int64_t resultDim = 0; 17734cf67aeSMatthias Springer // Source dims and result dims (apart from dropped dims) must have the same 17834cf67aeSMatthias Springer // size. 1799bd19bb7SChristopher Bate RankedTensorType sourceType = op.getSourceType(); 1809bd19bb7SChristopher Bate for (int64_t dim = 0, e = sourceType.getRank(); dim < e; ++dim) { 18134cf67aeSMatthias Springer if (droppedDims.test(dim)) { 1829bd19bb7SChristopher Bate // ExtractSlice may drop unit dimensions that result from taking a size-1 1839bd19bb7SChristopher Bate // slice from a non-size-1 source dimension. 1849bd19bb7SChristopher Bate if (sourceType.getDimSize(dim) != 1) 1859bd19bb7SChristopher Bate return false; 18634cf67aeSMatthias Springer continue; 18734cf67aeSMatthias Springer } 18834cf67aeSMatthias Springer FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual( 18940dd3aa9SMatthias Springer {op.getSource(), dim}, {op.getResult(), resultDim}); 19034cf67aeSMatthias Springer if (failed(equalDimSize) || !*equalDimSize) 19134cf67aeSMatthias Springer return false; 19234cf67aeSMatthias Springer ++resultDim; 19334cf67aeSMatthias Springer } 19434cf67aeSMatthias Springer 19534cf67aeSMatthias Springer return true; 19634cf67aeSMatthias Springer } 197