16412a135SAlexander Belyaev //===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===// 26412a135SAlexander Belyaev // 36412a135SAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 46412a135SAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information. 56412a135SAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 66412a135SAlexander Belyaev // 76412a135SAlexander Belyaev //===----------------------------------------------------------------------===// 86412a135SAlexander Belyaev 96412a135SAlexander Belyaev #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" 106412a135SAlexander Belyaev 116412a135SAlexander Belyaev #include "mlir/IR/AffineMap.h" 126412a135SAlexander Belyaev #include "mlir/IR/Builders.h" 136412a135SAlexander Belyaev 14d6595278SAlexander Belyaev #include <numeric> 15a1fe1f5fSKazu Hirata #include <optional> 16d6595278SAlexander Belyaev 176412a135SAlexander Belyaev using namespace mlir; 186412a135SAlexander Belyaev 190a81ace0SKazu Hirata std::optional<SmallVector<ReassociationIndices>> 206412a135SAlexander Belyaev mlir::getReassociationIndicesForReshape(ShapedType sourceType, 216412a135SAlexander Belyaev ShapedType targetType) { 22747b10beSAlexander Belyaev if (sourceType.getRank() > targetType.getRank()) 23747b10beSAlexander Belyaev return getReassociationIndicesForCollapse(sourceType.getShape(), 24747b10beSAlexander Belyaev targetType.getShape()); 256412a135SAlexander Belyaev if (sourceType.getRank() < targetType.getRank()) 26747b10beSAlexander Belyaev return getReassociationIndicesForCollapse(targetType.getShape(), 27747b10beSAlexander Belyaev sourceType.getShape()); 281a36588eSKazu Hirata return std::nullopt; 29747b10beSAlexander Belyaev } 306412a135SAlexander Belyaev 310a81ace0SKazu Hirata std::optional<SmallVector<ReassociationIndices>> 32747b10beSAlexander Belyaev mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape, 33747b10beSAlexander Belyaev ArrayRef<int64_t> targetShape) { 34747b10beSAlexander Belyaev if (sourceShape.size() <= targetShape.size()) 351a36588eSKazu Hirata return std::nullopt; 366412a135SAlexander Belyaev unsigned sourceDim = 0; 376412a135SAlexander Belyaev SmallVector<ReassociationIndices> reassociationMap; 38747b10beSAlexander Belyaev reassociationMap.reserve(targetShape.size()); 396412a135SAlexander Belyaev 406412a135SAlexander Belyaev ReassociationIndices currIndices; 416412a135SAlexander Belyaev int64_t prodOfCollapsedDims = 1; 426412a135SAlexander Belyaev while (sourceDim < sourceShape.size()) { 436412a135SAlexander Belyaev unsigned targetDim = reassociationMap.size(); 44a43f7d6dSStephan Herhut // If we have mapped all the target dimensions stop and handle the remaining 453f222f3bSc8ef // tail of size-1 dimensions explicitly. 46747b10beSAlexander Belyaev if (targetDim == targetShape.size()) 47a43f7d6dSStephan Herhut break; 486412a135SAlexander Belyaev 49a43f7d6dSStephan Herhut int64_t currTargetShape = targetShape[targetDim]; 50*2f15d7e4SVinayak Dev while (sourceDim < (sourceShape.size() - 1) && 51ded3644aSShivam Gupta sourceShape[sourceDim] != ShapedType::kDynamic && 52ded3644aSShivam Gupta prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) { 536412a135SAlexander Belyaev prodOfCollapsedDims *= sourceShape[sourceDim]; 546412a135SAlexander Belyaev currIndices.push_back(sourceDim++); 556412a135SAlexander Belyaev } 566412a135SAlexander Belyaev 576412a135SAlexander Belyaev // If the current expanded dimension is dynamic, then the collapsed 586412a135SAlexander Belyaev // dimensions should also be dynamic and product of all previous unprocessed 596412a135SAlexander Belyaev // dimensions of the expanded shape should be 1. 60399638f9SAliia Khasanova if (sourceShape[sourceDim] == ShapedType::kDynamic && 6122426110SRamkumar Ramachandra (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1)) 621a36588eSKazu Hirata return std::nullopt; 636412a135SAlexander Belyaev 646412a135SAlexander Belyaev // If the collapsed dim is dynamic, the current expanded dim should also 656412a135SAlexander Belyaev // be dynamic. 66399638f9SAliia Khasanova if (currTargetShape == ShapedType::kDynamic && 67399638f9SAliia Khasanova sourceShape[sourceDim] != ShapedType::kDynamic) 681a36588eSKazu Hirata return std::nullopt; 696412a135SAlexander Belyaev 706412a135SAlexander Belyaev // For static shapes, if the product of dimensions of the expanded shape 716412a135SAlexander Belyaev // should match the collapsed dimension shape. 726412a135SAlexander Belyaev if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape) 731a36588eSKazu Hirata return std::nullopt; 746412a135SAlexander Belyaev 756412a135SAlexander Belyaev currIndices.push_back(sourceDim++); 766412a135SAlexander Belyaev reassociationMap.emplace_back(ReassociationIndices{}); 776412a135SAlexander Belyaev std::swap(reassociationMap.back(), currIndices); 786412a135SAlexander Belyaev prodOfCollapsedDims = 1; 796412a135SAlexander Belyaev } 80a43f7d6dSStephan Herhut // All the dimensions in the target must have been processed. 81a43f7d6dSStephan Herhut if (reassociationMap.size() != targetShape.size()) 821a36588eSKazu Hirata return std::nullopt; 83a43f7d6dSStephan Herhut // Process any remaining entries in the source shape. They all need to be 84a43f7d6dSStephan Herhut // 1 or dynamic. 85a43f7d6dSStephan Herhut for (; sourceDim < sourceShape.size(); sourceDim++) { 86399638f9SAliia Khasanova if (sourceShape[sourceDim] != ShapedType::kDynamic && 87a43f7d6dSStephan Herhut sourceShape[sourceDim] != 1) 881a36588eSKazu Hirata return std::nullopt; 89a43f7d6dSStephan Herhut // The map is empty when the target type is a scalar. 90a43f7d6dSStephan Herhut if (!reassociationMap.empty()) 91a43f7d6dSStephan Herhut reassociationMap.back().push_back(sourceDim); 92a43f7d6dSStephan Herhut } 936412a135SAlexander Belyaev return reassociationMap; 946412a135SAlexander Belyaev } 956412a135SAlexander Belyaev 960a81ace0SKazu Hirata std::optional<SmallVector<ReassociationIndices>> 970a81ace0SKazu Hirata mlir::composeReassociationIndices( 98d6595278SAlexander Belyaev ArrayRef<ReassociationIndices> producerReassociations, 99d6595278SAlexander Belyaev ArrayRef<ReassociationIndices> consumerReassociations, 1006412a135SAlexander Belyaev MLIRContext *context) { 101d6595278SAlexander Belyaev SmallVector<ReassociationIndices> composedIndices; 1026412a135SAlexander Belyaev // Make the producer the larger sized vector. If they are of same size, the 1036412a135SAlexander Belyaev // resulting reshape is not a supported reshape op. 104d6595278SAlexander Belyaev if (producerReassociations.size() == consumerReassociations.size()) 1051a36588eSKazu Hirata return std::nullopt; 106d6595278SAlexander Belyaev if (producerReassociations.size() < consumerReassociations.size()) 107d6595278SAlexander Belyaev std::swap(producerReassociations, consumerReassociations); 1086412a135SAlexander Belyaev 1096412a135SAlexander Belyaev // Handle the corner case of the result being a rank 0 shaped type. Return an 1106412a135SAlexander Belyaev // empty reassociation. 111d6595278SAlexander Belyaev if (consumerReassociations.empty()) 112d6595278SAlexander Belyaev return composedIndices; 113d6595278SAlexander Belyaev 114d6595278SAlexander Belyaev size_t consumerDims = std::accumulate( 115d6595278SAlexander Belyaev consumerReassociations.begin(), consumerReassociations.end(), 0, 116d6595278SAlexander Belyaev [](size_t all, ReassociationIndicesRef indices) { 117d6595278SAlexander Belyaev return all + indices.size(); 118d6595278SAlexander Belyaev }); 119d6595278SAlexander Belyaev if (producerReassociations.size() != consumerDims) 1201a36588eSKazu Hirata return std::nullopt; 1216412a135SAlexander Belyaev 122d6595278SAlexander Belyaev for (ReassociationIndicesRef consumerIndices : consumerReassociations) { 1236412a135SAlexander Belyaev ReassociationIndices reassociations; 124d6595278SAlexander Belyaev for (int64_t consumerIndex : consumerIndices) { 12589d8035eSBenjamin Kramer llvm::append_range(reassociations, producerReassociations[consumerIndex]); 1266412a135SAlexander Belyaev } 127d6595278SAlexander Belyaev composedIndices.push_back(std::move(reassociations)); 1286412a135SAlexander Belyaev } 129d6595278SAlexander Belyaev return composedIndices; 1306412a135SAlexander Belyaev } 1316412a135SAlexander Belyaev 13246ef86b5SAlexander Belyaev SmallVector<SmallVector<AffineExpr, 2>, 2> 13346ef86b5SAlexander Belyaev mlir::convertReassociationIndicesToExprs( 1348ed66cb8SYi Zhang MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) { 13546ef86b5SAlexander Belyaev SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps; 13646ef86b5SAlexander Belyaev for (const auto &indices : reassociationIndices) { 13746ef86b5SAlexander Belyaev SmallVector<AffineExpr, 2> reassociationMap; 13846ef86b5SAlexander Belyaev reassociationMap.reserve(indices.size()); 13946ef86b5SAlexander Belyaev for (int64_t index : indices) 1408ed66cb8SYi Zhang reassociationMap.push_back(mlir::getAffineDimExpr(index, context)); 14146ef86b5SAlexander Belyaev reassociationMaps.push_back(std::move(reassociationMap)); 14246ef86b5SAlexander Belyaev } 14346ef86b5SAlexander Belyaev return reassociationMaps; 14446ef86b5SAlexander Belyaev } 14546ef86b5SAlexander Belyaev 14646ef86b5SAlexander Belyaev template <typename AffineExprTy> 14746ef86b5SAlexander Belyaev unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) { 14846ef86b5SAlexander Belyaev unsigned pos = 0; 14946ef86b5SAlexander Belyaev for (const auto &exprs : exprArrays) { 15046ef86b5SAlexander Belyaev for (auto expr : exprs) { 15146ef86b5SAlexander Belyaev expr.walk([&pos](AffineExpr e) { 1521609f1c2Slong.chen if (auto d = dyn_cast<AffineExprTy>(e)) 15346ef86b5SAlexander Belyaev pos = std::max(pos, d.getPosition()); 15446ef86b5SAlexander Belyaev }); 15546ef86b5SAlexander Belyaev } 15646ef86b5SAlexander Belyaev } 15746ef86b5SAlexander Belyaev return pos; 15846ef86b5SAlexander Belyaev } 15946ef86b5SAlexander Belyaev 16046ef86b5SAlexander Belyaev ArrayAttr mlir::getReassociationIndicesAttribute( 16146ef86b5SAlexander Belyaev OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) { 16246ef86b5SAlexander Belyaev SmallVector<Attribute, 4> reassociationAttr = 16346ef86b5SAlexander Belyaev llvm::to_vector<4>(llvm::map_range( 1641fc096afSMehdi Amini reassociation, [&](const ReassociationIndices &indices) -> Attribute { 1655550c821STres Popp return cast<Attribute>(b.getI64ArrayAttr(indices)); 16646ef86b5SAlexander Belyaev })); 16746ef86b5SAlexander Belyaev return b.getArrayAttr(reassociationAttr); 16846ef86b5SAlexander Belyaev } 16946ef86b5SAlexander Belyaev 17046ef86b5SAlexander Belyaev SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices( 17197069a86SGaurav Shukla ArrayRef<ReassociationExprs> reassociationExprs) { 17246ef86b5SAlexander Belyaev SmallVector<ReassociationIndices, 2> reassociationIndices; 17346ef86b5SAlexander Belyaev for (const auto &exprs : reassociationExprs) { 17446ef86b5SAlexander Belyaev ReassociationIndices indices; 17546ef86b5SAlexander Belyaev indices.reserve(exprs.size()); 17646ef86b5SAlexander Belyaev for (const auto &expr : exprs) 1771609f1c2Slong.chen indices.push_back(cast<AffineDimExpr>(expr).getPosition()); 17846ef86b5SAlexander Belyaev reassociationIndices.push_back(indices); 17946ef86b5SAlexander Belyaev } 18046ef86b5SAlexander Belyaev return reassociationIndices; 18146ef86b5SAlexander Belyaev } 18246ef86b5SAlexander Belyaev 18346ef86b5SAlexander Belyaev SmallVector<AffineMap, 4> 18446ef86b5SAlexander Belyaev mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) { 18546ef86b5SAlexander Belyaev unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation); 18646ef86b5SAlexander Belyaev assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 && 18746ef86b5SAlexander Belyaev "Expected symbol-less expressions"); 18846ef86b5SAlexander Belyaev SmallVector<AffineMap, 4> maps; 18946ef86b5SAlexander Belyaev maps.reserve(reassociation.size()); 19046ef86b5SAlexander Belyaev for (const auto &exprs : reassociation) { 19146ef86b5SAlexander Belyaev assert(!exprs.empty()); 19246ef86b5SAlexander Belyaev maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext())); 19346ef86b5SAlexander Belyaev } 19446ef86b5SAlexander Belyaev return maps; 19546ef86b5SAlexander Belyaev } 196747b10beSAlexander Belyaev 1976412a135SAlexander Belyaev bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation, 1986412a135SAlexander Belyaev int *invalidIndex) { 1996412a135SAlexander Belyaev if (reassociation.empty()) 2006412a135SAlexander Belyaev return true; 2016412a135SAlexander Belyaev unsigned nDims = reassociation[0].getNumDims(); 2026412a135SAlexander Belyaev unsigned nextExpectedDim = 0; 20389de9cc8SMehdi Amini for (const auto &it : llvm::enumerate(reassociation)) { 2046412a135SAlexander Belyaev auto m = it.value(); 2056412a135SAlexander Belyaev if (m.getNumDims() != nDims || m.getNumSymbols() != 0) { 2066412a135SAlexander Belyaev if (invalidIndex) 2076412a135SAlexander Belyaev *invalidIndex = it.index(); 2086412a135SAlexander Belyaev return false; 2096412a135SAlexander Belyaev } 2106412a135SAlexander Belyaev for (auto e : m.getResults()) { 2111609f1c2Slong.chen auto d = dyn_cast<AffineDimExpr>(e); 2126412a135SAlexander Belyaev if (!d || d.getPosition() != nextExpectedDim++) { 2136412a135SAlexander Belyaev if (invalidIndex) 2146412a135SAlexander Belyaev *invalidIndex = it.index(); 2156412a135SAlexander Belyaev return false; 2166412a135SAlexander Belyaev } 2176412a135SAlexander Belyaev } 2186412a135SAlexander Belyaev } 2196412a135SAlexander Belyaev if (nextExpectedDim != nDims) { 2206412a135SAlexander Belyaev if (invalidIndex) 2216412a135SAlexander Belyaev *invalidIndex = reassociation.size() - 1; 2226412a135SAlexander Belyaev return false; 2236412a135SAlexander Belyaev } 2246412a135SAlexander Belyaev return true; 2256412a135SAlexander Belyaev } 226ff5de8a9SBenjamin Kramer 227ff5de8a9SBenjamin Kramer LogicalResult mlir::reshapeLikeShapesAreCompatible( 228ff5de8a9SBenjamin Kramer function_ref<LogicalResult(const Twine &)> emitError, 229ff5de8a9SBenjamin Kramer ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape, 230ff5de8a9SBenjamin Kramer ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) { 231ff5de8a9SBenjamin Kramer unsigned expandedDimStart = 0; 232ff5de8a9SBenjamin Kramer for (const auto &map : llvm::enumerate(reassociationMaps)) { 23397069a86SGaurav Shukla bool foundDynamicShape = false; 234ff5de8a9SBenjamin Kramer int64_t linearizedStaticShape = 1; 23597069a86SGaurav Shukla 236ff5de8a9SBenjamin Kramer for (const auto &dim : llvm::enumerate( 237ff5de8a9SBenjamin Kramer expandedShape.slice(expandedDimStart, map.value().size()))) { 23897069a86SGaurav Shukla if (ShapedType::isDynamic(dim.value())) 23997069a86SGaurav Shukla foundDynamicShape = true; 24097069a86SGaurav Shukla else 241ff5de8a9SBenjamin Kramer linearizedStaticShape *= dim.value(); 242ff5de8a9SBenjamin Kramer } 24397069a86SGaurav Shukla if (foundDynamicShape) { 244ff5de8a9SBenjamin Kramer if (!ShapedType::isDynamic(collapsedShape[map.index()])) { 245ff5de8a9SBenjamin Kramer return emitError( 246ff5de8a9SBenjamin Kramer "expected dimension " + Twine(map.index()) + 247ff5de8a9SBenjamin Kramer " of collapsed type to be dynamic since one or more of the " 248ff5de8a9SBenjamin Kramer "corresponding dimensions in the expanded type is dynamic"); 249ff5de8a9SBenjamin Kramer } 250ff5de8a9SBenjamin Kramer } else { 251ff5de8a9SBenjamin Kramer if (collapsedShape[map.index()] != linearizedStaticShape) { 252ff5de8a9SBenjamin Kramer return emitError("expected dimension " + Twine(map.index()) + 253ff5de8a9SBenjamin Kramer " of collapsed type to be static value of " + 254ff5de8a9SBenjamin Kramer Twine(linearizedStaticShape)); 255ff5de8a9SBenjamin Kramer } 256ff5de8a9SBenjamin Kramer } 257ff5de8a9SBenjamin Kramer expandedDimStart += map.value().size(); 258ff5de8a9SBenjamin Kramer } 259ff5de8a9SBenjamin Kramer return success(); 260ff5de8a9SBenjamin Kramer } 261747b10beSAlexander Belyaev 262747b10beSAlexander Belyaev bool mlir::hasNonIdentityLayout(Type type) { 2635550c821STres Popp if (auto memrefType = dyn_cast<MemRefType>(type)) 264747b10beSAlexander Belyaev return !memrefType.getLayout().isIdentity(); 265747b10beSAlexander Belyaev return false; 266747b10beSAlexander Belyaev } 267f4a478cdSChristopher Bate 268f4a478cdSChristopher Bate llvm::SmallBitVector 269f4a478cdSChristopher Bate mlir::getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape, 270f4a478cdSChristopher Bate ArrayRef<Range> sliceParams) { 271f4a478cdSChristopher Bate assert(sliceParams.size() == sliceInputShape.size() && 272f4a478cdSChristopher Bate "only supports non rank-reducing case"); 273f4a478cdSChristopher Bate llvm::SmallBitVector mask(sliceInputShape.size()); 274f4a478cdSChristopher Bate unsigned idx = 0; 275f4a478cdSChristopher Bate for (const auto &[offset, size, stride] : sliceParams) { 27622426110SRamkumar Ramachandra std::optional<int64_t> offsetConst = getConstantIntValue(offset); 27722426110SRamkumar Ramachandra std::optional<int64_t> strideConst = getConstantIntValue(stride); 278f4a478cdSChristopher Bate mask[idx] = !isEqualConstantIntOrValue(size, sliceInputShape[idx]) || 279f4a478cdSChristopher Bate (!strideConst || *strideConst != 1) || 280f4a478cdSChristopher Bate (!offsetConst || *offsetConst != 0); 281f4a478cdSChristopher Bate idx++; 282f4a478cdSChristopher Bate } 283f4a478cdSChristopher Bate return mask; 284f4a478cdSChristopher Bate } 285f4a478cdSChristopher Bate 286f4a478cdSChristopher Bate llvm::SmallBitVector mlir::getLinearizedDimensions( 287f4a478cdSChristopher Bate ArrayRef<ReassociationIndices> reassociationIndices) { 288f4a478cdSChristopher Bate llvm::SmallBitVector result(reassociationIndices.size()); 289f4a478cdSChristopher Bate for (const auto &it : llvm::enumerate(reassociationIndices)) 290f4a478cdSChristopher Bate result[it.index()] = it.value().size() > 1; 291f4a478cdSChristopher Bate return result; 292f4a478cdSChristopher Bate } 293f4a478cdSChristopher Bate 294f4a478cdSChristopher Bate SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams( 2954d27f06fSChristopher Bate MLIRContext *ctx, ArrayRef<ValueRange> multiIndices) { 296f4a478cdSChristopher Bate unsigned loopIdx = 0; 297f4a478cdSChristopher Bate auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1); 298f4a478cdSChristopher Bate auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0); 299f4a478cdSChristopher Bate SmallVector<Range> offsetsSizesAndStrides; 300f4a478cdSChristopher Bate offsetsSizesAndStrides.reserve(collapseShapeInputShape.size()); 301f4a478cdSChristopher Bate for (const auto &it : llvm::enumerate(reassociationIndices)) { 302f4a478cdSChristopher Bate // Case 1: Linearized dimensions that have also been sliced. These 303f4a478cdSChristopher Bate // are size of 1 because we are iterating over these dimensions. The 304f4a478cdSChristopher Bate // offsets are exactly the de-linearized multi-indices. 305f4a478cdSChristopher Bate if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) { 306f4a478cdSChristopher Bate llvm::append_range( 307f4a478cdSChristopher Bate offsetsSizesAndStrides, 308f4a478cdSChristopher Bate llvm::map_range(multiIndices[loopIdx++], [&](Value v) -> Range { 309f4a478cdSChristopher Bate return Range{getAsOpFoldResult(v), oneAttr, oneAttr}; 310f4a478cdSChristopher Bate })); 311f4a478cdSChristopher Bate continue; 312f4a478cdSChristopher Bate } 313f4a478cdSChristopher Bate 314f4a478cdSChristopher Bate // Case 2: One or possibly multiple combined input dimensions, but we 315f4a478cdSChristopher Bate // have proven that these are not sliced. In this case we just take 316f4a478cdSChristopher Bate // the full extent of each dimension in the reassociation list. 317f4a478cdSChristopher Bate if (linearizedDimensions[it.index()]) { 318f4a478cdSChristopher Bate llvm::append_range( 319f4a478cdSChristopher Bate offsetsSizesAndStrides, 320f4a478cdSChristopher Bate llvm::map_range(it.value(), [&](int64_t idx) -> Range { 321f4a478cdSChristopher Bate return {zeroAttr, collapseShapeInputShape[idx], oneAttr}; 322f4a478cdSChristopher Bate })); 323f4a478cdSChristopher Bate continue; 324f4a478cdSChristopher Bate } 325f4a478cdSChristopher Bate 326f4a478cdSChristopher Bate // Case 3: A single index, but it may be sliced. 327f4a478cdSChristopher Bate offsetsSizesAndStrides.push_back(sliceParams[it.index()]); 328f4a478cdSChristopher Bate } 329f4a478cdSChristopher Bate return offsetsSizesAndStrides; 330f4a478cdSChristopher Bate } 331f4a478cdSChristopher Bate 332f4a478cdSChristopher Bate SmallVector<Range> 3334d27f06fSChristopher Bate SliceFromCollapseHelper::getInsertSliceParams(MLIRContext *ctx, 3344d27f06fSChristopher Bate ValueRange tileIndices) { 335f4a478cdSChristopher Bate auto one = IntegerAttr::get(IndexType::get(ctx), 1); 336f4a478cdSChristopher Bate auto zero = IntegerAttr::get(IndexType::get(ctx), 0); 337f4a478cdSChristopher Bate SmallVector<Range> insertParams; 338f4a478cdSChristopher Bate insertParams.reserve(linearizedDimensions.size()); 339f4a478cdSChristopher Bate unsigned loopIdx = 0; 340f4a478cdSChristopher Bate for (unsigned i = 0; i < linearizedDimensions.size(); i++) { 341f4a478cdSChristopher Bate if (linearizedDimensions[i] && slicedDimensions[i]) { 342f4a478cdSChristopher Bate insertParams.push_back(Range{tileIndices[loopIdx++], one, one}); 343f4a478cdSChristopher Bate continue; 344f4a478cdSChristopher Bate } 345f4a478cdSChristopher Bate insertParams.push_back(Range{zero, sliceParams[i].size, one}); 346f4a478cdSChristopher Bate } 347f4a478cdSChristopher Bate return insertParams; 348f4a478cdSChristopher Bate } 349446981bdSChristopher Bate 350446981bdSChristopher Bate /// Returns the index of the only non-unit dimension among `indices` of `shape`, 351446981bdSChristopher Bate /// if such a dimension exists and `indices` has more than one element. 352f09b0e35SKazu Hirata /// Otherwise, return std::nullopt. 3530a81ace0SKazu Hirata static std::optional<int64_t> getUniqueNonUnitDim(ArrayRef<int64_t> indices, 354446981bdSChristopher Bate ArrayRef<int64_t> shape) { 355446981bdSChristopher Bate // Return false if more than one of the dimensions in this group are not 1. 35691682b26SKazu Hirata std::optional<int64_t> dimIndex; 357446981bdSChristopher Bate if (indices.size() < 2) 3581a36588eSKazu Hirata return std::nullopt; 359446981bdSChristopher Bate for (int64_t idx : indices) { 360446981bdSChristopher Bate if (shape[idx] != 1) { 3611a36588eSKazu Hirata if (dimIndex != std::nullopt) 3621a36588eSKazu Hirata return std::nullopt; 363446981bdSChristopher Bate dimIndex = idx; 364446981bdSChristopher Bate } 365446981bdSChristopher Bate } 366446981bdSChristopher Bate return dimIndex; 367446981bdSChristopher Bate } 368446981bdSChristopher Bate 369446981bdSChristopher Bate // For each segment in the reassociation indices, check whether we can 370446981bdSChristopher Bate // simplify that segment with a rank-reducing extract slice. We can do this if 371446981bdSChristopher Bate // all but (exactly) one of the corresponding source dims is 1. 3720a81ace0SKazu Hirata static SmallVector<std::optional<int64_t>> getCollapseShapeTrivialSegments( 373446981bdSChristopher Bate RankedTensorType sourceType, 374446981bdSChristopher Bate ArrayRef<ReassociationIndices> reassociationIndices) { 3750a81ace0SKazu Hirata SmallVector<std::optional<int64_t>> trivialSegments; 376446981bdSChristopher Bate for (const auto &indices : reassociationIndices) 377446981bdSChristopher Bate trivialSegments.push_back( 378446981bdSChristopher Bate getUniqueNonUnitDim(indices, sourceType.getShape())); 379446981bdSChristopher Bate return trivialSegments; 380446981bdSChristopher Bate } 381446981bdSChristopher Bate 382446981bdSChristopher Bate /// Returns true if any of the segments of the reassociation indices for a 383446981bdSChristopher Bate /// collapsing reshape can be simplified using a rank-reducing slice. 3840a81ace0SKazu Hirata static FailureOr<SmallVector<std::optional<int64_t>>> 385446981bdSChristopher Bate canCollapseShapeBeSimplifiedByRankReducingSlice( 386446981bdSChristopher Bate RankedTensorType sourceType, 387446981bdSChristopher Bate ArrayRef<ReassociationIndices> reassociationIndices) { 3880a81ace0SKazu Hirata SmallVector<std::optional<int64_t>> trivialSegments = 389446981bdSChristopher Bate getCollapseShapeTrivialSegments(sourceType, reassociationIndices); 3900a81ace0SKazu Hirata if (!llvm::any_of(trivialSegments, [](const std::optional<int64_t> &idx) { 391446981bdSChristopher Bate return idx.has_value(); 392446981bdSChristopher Bate })) 393446981bdSChristopher Bate return failure(); 394446981bdSChristopher Bate return trivialSegments; 395446981bdSChristopher Bate } 396446981bdSChristopher Bate 397446981bdSChristopher Bate FailureOr<CollapseShapeRankReducingSliceSimplificationInfo> 398446981bdSChristopher Bate mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo( 399446981bdSChristopher Bate RankedTensorType sourceType, 400446981bdSChristopher Bate ArrayRef<ReassociationIndices> reassociationIndices) { 4010a81ace0SKazu Hirata FailureOr<SmallVector<std::optional<int64_t>>> trivialSegments = 402446981bdSChristopher Bate canCollapseShapeBeSimplifiedByRankReducingSlice(sourceType, 403446981bdSChristopher Bate reassociationIndices); 404446981bdSChristopher Bate if (failed(trivialSegments)) 405446981bdSChristopher Bate return failure(); 406446981bdSChristopher Bate 407446981bdSChristopher Bate // Create the expected result shape of the rank-reducing slice. 408446981bdSChristopher Bate SmallVector<int64_t> sliceShape; 409446981bdSChristopher Bate for (const auto &[nonUnitDim, indices] : 410446981bdSChristopher Bate llvm::zip(*trivialSegments, reassociationIndices)) { 411446981bdSChristopher Bate if (nonUnitDim) { 412cbb09813SFangrui Song sliceShape.push_back(sourceType.getDimSize(*nonUnitDim)); 413446981bdSChristopher Bate continue; 414446981bdSChristopher Bate } 415446981bdSChristopher Bate llvm::append_range(sliceShape, llvm::map_range(indices, [&](int64_t idx) { 416446981bdSChristopher Bate return sourceType.getDimSize(idx); 417446981bdSChristopher Bate })); 418446981bdSChristopher Bate } 419446981bdSChristopher Bate auto sliceType = 420446981bdSChristopher Bate RankedTensorType::get(sliceShape, sourceType.getElementType()); 421446981bdSChristopher Bate 422446981bdSChristopher Bate // If the rank-reducing slice simplified every segment, then we are done. 423446981bdSChristopher Bate if (sliceShape.size() == reassociationIndices.size()) 4241a36588eSKazu Hirata return CollapseShapeRankReducingSliceSimplificationInfo{sliceType, 4251a36588eSKazu Hirata std::nullopt}; 426446981bdSChristopher Bate 427446981bdSChristopher Bate // Otherwise, we need to create a new collapse_shape op for the segments that 428446981bdSChristopher Bate // weren't covered by the slice. By design, the new reassociation indices has 429446981bdSChristopher Bate // the same number of groups as the old reassociation indices. 430446981bdSChristopher Bate SmallVector<ReassociationIndices> newReassociationIndices; 431446981bdSChristopher Bate SmallVector<int64_t, 2> reassociation; 432446981bdSChristopher Bate int64_t groupIdx = 0; 433446981bdSChristopher Bate for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) { 434446981bdSChristopher Bate reassociation.push_back(dimIdx); 435446981bdSChristopher Bate if ((*trivialSegments)[groupIdx] || 436446981bdSChristopher Bate reassociation.size() == reassociationIndices[groupIdx].size()) { 437446981bdSChristopher Bate newReassociationIndices.push_back(reassociation); 438446981bdSChristopher Bate reassociation.clear(); 439446981bdSChristopher Bate groupIdx++; 440446981bdSChristopher Bate } 441446981bdSChristopher Bate } 442446981bdSChristopher Bate 443446981bdSChristopher Bate return CollapseShapeRankReducingSliceSimplificationInfo{ 444446981bdSChristopher Bate sliceType, newReassociationIndices}; 445446981bdSChristopher Bate } 4460bfbecf5SQuentin Colombet 4470bfbecf5SQuentin Colombet PackingMetadata mlir::computePackingMetadata(int64_t packedRank, 4480bfbecf5SQuentin Colombet ArrayRef<int64_t> innerDimPos) { 4490bfbecf5SQuentin Colombet PackingMetadata res; 4500bfbecf5SQuentin Colombet res.insertPositions.reserve(innerDimPos.size()); 4510bfbecf5SQuentin Colombet // The pack insert position is the position + the number of previously 4520bfbecf5SQuentin Colombet // inserted positions + offset. 4530bfbecf5SQuentin Colombet // The offset controls whether the packing dimension is the first or last. 4540bfbecf5SQuentin Colombet // 4550bfbecf5SQuentin Colombet // Example 4560bfbecf5SQuentin Colombet // ======= 4570bfbecf5SQuentin Colombet // Consider packing from a hypothetical ABCD layout to ABCDba whose 4580bfbecf5SQuentin Colombet // pack.inner_dims is [1, 0]. The first step consists in undoing the 4590bfbecf5SQuentin Colombet // permutation and producing AaBbCD. This is achieved purely by computing the 4600bfbecf5SQuentin Colombet // insert positions of `b` and `a` into `ABCD`, starting from [1, 0]. One 4610bfbecf5SQuentin Colombet // possibility, is to produce insert positions [2, 0], this would result in an 4620bfbecf5SQuentin Colombet // aAbBCD layout (i.e. offset 0). The other possibility, is to produce insert 4630bfbecf5SQuentin Colombet // positions [3, 1], this would result in an AaBbCD layout (i.e. offset 1). 4640bfbecf5SQuentin Colombet // The latter is what we expect from packing. 4650bfbecf5SQuentin Colombet int64_t offset = 1; 4660bfbecf5SQuentin Colombet for (int64_t pos : innerDimPos) { 4670bfbecf5SQuentin Colombet int64_t numInsertedBefore = llvm::count_if( 4680bfbecf5SQuentin Colombet innerDimPos, [&pos](int64_t pos2) { return pos > pos2; }); 4690bfbecf5SQuentin Colombet res.insertPositions.push_back(pos + numInsertedBefore + offset); 4700bfbecf5SQuentin Colombet } 4710bfbecf5SQuentin Colombet 4720bfbecf5SQuentin Colombet DenseSet<int64_t> posSet(res.insertPositions.begin(), 4730bfbecf5SQuentin Colombet res.insertPositions.end()); 4740bfbecf5SQuentin Colombet res.reassociations.reserve(packedRank); 4750bfbecf5SQuentin Colombet for (int64_t i = 1; i <= packedRank; ++i) { 4766f87b50bSHanhan Wang res.outerPositions.push_back(i - 1); 4770bfbecf5SQuentin Colombet if (!posSet.contains(i)) { 4780bfbecf5SQuentin Colombet res.reassociations.push_back(ReassociationIndices{i - 1}); 4790bfbecf5SQuentin Colombet continue; 4800bfbecf5SQuentin Colombet } 4810bfbecf5SQuentin Colombet res.reassociations.push_back(ReassociationIndices{i - 1, i}); 4820bfbecf5SQuentin Colombet ++i; 4830bfbecf5SQuentin Colombet } 4840bfbecf5SQuentin Colombet return res; 4850bfbecf5SQuentin Colombet } 486