xref: /llvm-project/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp (revision 2f15d7e43e17f72839861bfe3a5c466c325bc04d)
16412a135SAlexander Belyaev //===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
26412a135SAlexander Belyaev //
36412a135SAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
46412a135SAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
56412a135SAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66412a135SAlexander Belyaev //
76412a135SAlexander Belyaev //===----------------------------------------------------------------------===//
86412a135SAlexander Belyaev 
96412a135SAlexander Belyaev #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
106412a135SAlexander Belyaev 
116412a135SAlexander Belyaev #include "mlir/IR/AffineMap.h"
126412a135SAlexander Belyaev #include "mlir/IR/Builders.h"
136412a135SAlexander Belyaev 
14d6595278SAlexander Belyaev #include <numeric>
15a1fe1f5fSKazu Hirata #include <optional>
16d6595278SAlexander Belyaev 
176412a135SAlexander Belyaev using namespace mlir;
186412a135SAlexander Belyaev 
190a81ace0SKazu Hirata std::optional<SmallVector<ReassociationIndices>>
206412a135SAlexander Belyaev mlir::getReassociationIndicesForReshape(ShapedType sourceType,
216412a135SAlexander Belyaev                                         ShapedType targetType) {
22747b10beSAlexander Belyaev   if (sourceType.getRank() > targetType.getRank())
23747b10beSAlexander Belyaev     return getReassociationIndicesForCollapse(sourceType.getShape(),
24747b10beSAlexander Belyaev                                               targetType.getShape());
256412a135SAlexander Belyaev   if (sourceType.getRank() < targetType.getRank())
26747b10beSAlexander Belyaev     return getReassociationIndicesForCollapse(targetType.getShape(),
27747b10beSAlexander Belyaev                                               sourceType.getShape());
281a36588eSKazu Hirata   return std::nullopt;
29747b10beSAlexander Belyaev }
306412a135SAlexander Belyaev 
310a81ace0SKazu Hirata std::optional<SmallVector<ReassociationIndices>>
32747b10beSAlexander Belyaev mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
33747b10beSAlexander Belyaev                                          ArrayRef<int64_t> targetShape) {
34747b10beSAlexander Belyaev   if (sourceShape.size() <= targetShape.size())
351a36588eSKazu Hirata     return std::nullopt;
366412a135SAlexander Belyaev   unsigned sourceDim = 0;
376412a135SAlexander Belyaev   SmallVector<ReassociationIndices> reassociationMap;
38747b10beSAlexander Belyaev   reassociationMap.reserve(targetShape.size());
396412a135SAlexander Belyaev 
406412a135SAlexander Belyaev   ReassociationIndices currIndices;
416412a135SAlexander Belyaev   int64_t prodOfCollapsedDims = 1;
426412a135SAlexander Belyaev   while (sourceDim < sourceShape.size()) {
436412a135SAlexander Belyaev     unsigned targetDim = reassociationMap.size();
44a43f7d6dSStephan Herhut     // If we have mapped all the target dimensions stop and handle the remaining
453f222f3bSc8ef     // tail of size-1 dimensions explicitly.
46747b10beSAlexander Belyaev     if (targetDim == targetShape.size())
47a43f7d6dSStephan Herhut       break;
486412a135SAlexander Belyaev 
49a43f7d6dSStephan Herhut     int64_t currTargetShape = targetShape[targetDim];
50*2f15d7e4SVinayak Dev     while (sourceDim < (sourceShape.size() - 1) &&
51ded3644aSShivam Gupta            sourceShape[sourceDim] != ShapedType::kDynamic &&
52ded3644aSShivam Gupta            prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
536412a135SAlexander Belyaev       prodOfCollapsedDims *= sourceShape[sourceDim];
546412a135SAlexander Belyaev       currIndices.push_back(sourceDim++);
556412a135SAlexander Belyaev     }
566412a135SAlexander Belyaev 
576412a135SAlexander Belyaev     // If the current expanded dimension is dynamic, then the collapsed
586412a135SAlexander Belyaev     // dimensions should also be dynamic and product of all previous unprocessed
596412a135SAlexander Belyaev     // dimensions of the expanded shape should be 1.
60399638f9SAliia Khasanova     if (sourceShape[sourceDim] == ShapedType::kDynamic &&
6122426110SRamkumar Ramachandra         (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
621a36588eSKazu Hirata       return std::nullopt;
636412a135SAlexander Belyaev 
646412a135SAlexander Belyaev     // If the collapsed dim is dynamic, the current expanded dim should also
656412a135SAlexander Belyaev     // be dynamic.
66399638f9SAliia Khasanova     if (currTargetShape == ShapedType::kDynamic &&
67399638f9SAliia Khasanova         sourceShape[sourceDim] != ShapedType::kDynamic)
681a36588eSKazu Hirata       return std::nullopt;
696412a135SAlexander Belyaev 
706412a135SAlexander Belyaev     // For static shapes, if the product of dimensions of the expanded shape
716412a135SAlexander Belyaev     // should match the collapsed dimension shape.
726412a135SAlexander Belyaev     if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
731a36588eSKazu Hirata       return std::nullopt;
746412a135SAlexander Belyaev 
756412a135SAlexander Belyaev     currIndices.push_back(sourceDim++);
766412a135SAlexander Belyaev     reassociationMap.emplace_back(ReassociationIndices{});
776412a135SAlexander Belyaev     std::swap(reassociationMap.back(), currIndices);
786412a135SAlexander Belyaev     prodOfCollapsedDims = 1;
796412a135SAlexander Belyaev   }
80a43f7d6dSStephan Herhut   // All the dimensions in the target must have been processed.
81a43f7d6dSStephan Herhut   if (reassociationMap.size() != targetShape.size())
821a36588eSKazu Hirata     return std::nullopt;
83a43f7d6dSStephan Herhut   // Process any remaining entries in the source shape. They all need to be
84a43f7d6dSStephan Herhut   // 1 or dynamic.
85a43f7d6dSStephan Herhut   for (; sourceDim < sourceShape.size(); sourceDim++) {
86399638f9SAliia Khasanova     if (sourceShape[sourceDim] != ShapedType::kDynamic &&
87a43f7d6dSStephan Herhut         sourceShape[sourceDim] != 1)
881a36588eSKazu Hirata       return std::nullopt;
89a43f7d6dSStephan Herhut     // The map is empty when the target type is a scalar.
90a43f7d6dSStephan Herhut     if (!reassociationMap.empty())
91a43f7d6dSStephan Herhut       reassociationMap.back().push_back(sourceDim);
92a43f7d6dSStephan Herhut   }
936412a135SAlexander Belyaev   return reassociationMap;
946412a135SAlexander Belyaev }
956412a135SAlexander Belyaev 
960a81ace0SKazu Hirata std::optional<SmallVector<ReassociationIndices>>
970a81ace0SKazu Hirata mlir::composeReassociationIndices(
98d6595278SAlexander Belyaev     ArrayRef<ReassociationIndices> producerReassociations,
99d6595278SAlexander Belyaev     ArrayRef<ReassociationIndices> consumerReassociations,
1006412a135SAlexander Belyaev     MLIRContext *context) {
101d6595278SAlexander Belyaev   SmallVector<ReassociationIndices> composedIndices;
1026412a135SAlexander Belyaev   // Make the producer the larger sized vector. If they are of same size, the
1036412a135SAlexander Belyaev   // resulting reshape is not a supported reshape op.
104d6595278SAlexander Belyaev   if (producerReassociations.size() == consumerReassociations.size())
1051a36588eSKazu Hirata     return std::nullopt;
106d6595278SAlexander Belyaev   if (producerReassociations.size() < consumerReassociations.size())
107d6595278SAlexander Belyaev     std::swap(producerReassociations, consumerReassociations);
1086412a135SAlexander Belyaev 
1096412a135SAlexander Belyaev   // Handle the corner case of the result being a rank 0 shaped type. Return an
1106412a135SAlexander Belyaev   // empty reassociation.
111d6595278SAlexander Belyaev   if (consumerReassociations.empty())
112d6595278SAlexander Belyaev     return composedIndices;
113d6595278SAlexander Belyaev 
114d6595278SAlexander Belyaev   size_t consumerDims = std::accumulate(
115d6595278SAlexander Belyaev       consumerReassociations.begin(), consumerReassociations.end(), 0,
116d6595278SAlexander Belyaev       [](size_t all, ReassociationIndicesRef indices) {
117d6595278SAlexander Belyaev         return all + indices.size();
118d6595278SAlexander Belyaev       });
119d6595278SAlexander Belyaev   if (producerReassociations.size() != consumerDims)
1201a36588eSKazu Hirata     return std::nullopt;
1216412a135SAlexander Belyaev 
122d6595278SAlexander Belyaev   for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
1236412a135SAlexander Belyaev     ReassociationIndices reassociations;
124d6595278SAlexander Belyaev     for (int64_t consumerIndex : consumerIndices) {
12589d8035eSBenjamin Kramer       llvm::append_range(reassociations, producerReassociations[consumerIndex]);
1266412a135SAlexander Belyaev     }
127d6595278SAlexander Belyaev     composedIndices.push_back(std::move(reassociations));
1286412a135SAlexander Belyaev   }
129d6595278SAlexander Belyaev   return composedIndices;
1306412a135SAlexander Belyaev }
1316412a135SAlexander Belyaev 
13246ef86b5SAlexander Belyaev SmallVector<SmallVector<AffineExpr, 2>, 2>
13346ef86b5SAlexander Belyaev mlir::convertReassociationIndicesToExprs(
1348ed66cb8SYi Zhang     MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
13546ef86b5SAlexander Belyaev   SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
13646ef86b5SAlexander Belyaev   for (const auto &indices : reassociationIndices) {
13746ef86b5SAlexander Belyaev     SmallVector<AffineExpr, 2> reassociationMap;
13846ef86b5SAlexander Belyaev     reassociationMap.reserve(indices.size());
13946ef86b5SAlexander Belyaev     for (int64_t index : indices)
1408ed66cb8SYi Zhang       reassociationMap.push_back(mlir::getAffineDimExpr(index, context));
14146ef86b5SAlexander Belyaev     reassociationMaps.push_back(std::move(reassociationMap));
14246ef86b5SAlexander Belyaev   }
14346ef86b5SAlexander Belyaev   return reassociationMaps;
14446ef86b5SAlexander Belyaev }
14546ef86b5SAlexander Belyaev 
14646ef86b5SAlexander Belyaev template <typename AffineExprTy>
14746ef86b5SAlexander Belyaev unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
14846ef86b5SAlexander Belyaev   unsigned pos = 0;
14946ef86b5SAlexander Belyaev   for (const auto &exprs : exprArrays) {
15046ef86b5SAlexander Belyaev     for (auto expr : exprs) {
15146ef86b5SAlexander Belyaev       expr.walk([&pos](AffineExpr e) {
1521609f1c2Slong.chen         if (auto d = dyn_cast<AffineExprTy>(e))
15346ef86b5SAlexander Belyaev           pos = std::max(pos, d.getPosition());
15446ef86b5SAlexander Belyaev       });
15546ef86b5SAlexander Belyaev     }
15646ef86b5SAlexander Belyaev   }
15746ef86b5SAlexander Belyaev   return pos;
15846ef86b5SAlexander Belyaev }
15946ef86b5SAlexander Belyaev 
16046ef86b5SAlexander Belyaev ArrayAttr mlir::getReassociationIndicesAttribute(
16146ef86b5SAlexander Belyaev     OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) {
16246ef86b5SAlexander Belyaev   SmallVector<Attribute, 4> reassociationAttr =
16346ef86b5SAlexander Belyaev       llvm::to_vector<4>(llvm::map_range(
1641fc096afSMehdi Amini           reassociation, [&](const ReassociationIndices &indices) -> Attribute {
1655550c821STres Popp             return cast<Attribute>(b.getI64ArrayAttr(indices));
16646ef86b5SAlexander Belyaev           }));
16746ef86b5SAlexander Belyaev   return b.getArrayAttr(reassociationAttr);
16846ef86b5SAlexander Belyaev }
16946ef86b5SAlexander Belyaev 
17046ef86b5SAlexander Belyaev SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
17197069a86SGaurav Shukla     ArrayRef<ReassociationExprs> reassociationExprs) {
17246ef86b5SAlexander Belyaev   SmallVector<ReassociationIndices, 2> reassociationIndices;
17346ef86b5SAlexander Belyaev   for (const auto &exprs : reassociationExprs) {
17446ef86b5SAlexander Belyaev     ReassociationIndices indices;
17546ef86b5SAlexander Belyaev     indices.reserve(exprs.size());
17646ef86b5SAlexander Belyaev     for (const auto &expr : exprs)
1771609f1c2Slong.chen       indices.push_back(cast<AffineDimExpr>(expr).getPosition());
17846ef86b5SAlexander Belyaev     reassociationIndices.push_back(indices);
17946ef86b5SAlexander Belyaev   }
18046ef86b5SAlexander Belyaev   return reassociationIndices;
18146ef86b5SAlexander Belyaev }
18246ef86b5SAlexander Belyaev 
18346ef86b5SAlexander Belyaev SmallVector<AffineMap, 4>
18446ef86b5SAlexander Belyaev mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
18546ef86b5SAlexander Belyaev   unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
18646ef86b5SAlexander Belyaev   assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
18746ef86b5SAlexander Belyaev          "Expected symbol-less expressions");
18846ef86b5SAlexander Belyaev   SmallVector<AffineMap, 4> maps;
18946ef86b5SAlexander Belyaev   maps.reserve(reassociation.size());
19046ef86b5SAlexander Belyaev   for (const auto &exprs : reassociation) {
19146ef86b5SAlexander Belyaev     assert(!exprs.empty());
19246ef86b5SAlexander Belyaev     maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
19346ef86b5SAlexander Belyaev   }
19446ef86b5SAlexander Belyaev   return maps;
19546ef86b5SAlexander Belyaev }
196747b10beSAlexander Belyaev 
1976412a135SAlexander Belyaev bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
1986412a135SAlexander Belyaev                                 int *invalidIndex) {
1996412a135SAlexander Belyaev   if (reassociation.empty())
2006412a135SAlexander Belyaev     return true;
2016412a135SAlexander Belyaev   unsigned nDims = reassociation[0].getNumDims();
2026412a135SAlexander Belyaev   unsigned nextExpectedDim = 0;
20389de9cc8SMehdi Amini   for (const auto &it : llvm::enumerate(reassociation)) {
2046412a135SAlexander Belyaev     auto m = it.value();
2056412a135SAlexander Belyaev     if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
2066412a135SAlexander Belyaev       if (invalidIndex)
2076412a135SAlexander Belyaev         *invalidIndex = it.index();
2086412a135SAlexander Belyaev       return false;
2096412a135SAlexander Belyaev     }
2106412a135SAlexander Belyaev     for (auto e : m.getResults()) {
2111609f1c2Slong.chen       auto d = dyn_cast<AffineDimExpr>(e);
2126412a135SAlexander Belyaev       if (!d || d.getPosition() != nextExpectedDim++) {
2136412a135SAlexander Belyaev         if (invalidIndex)
2146412a135SAlexander Belyaev           *invalidIndex = it.index();
2156412a135SAlexander Belyaev         return false;
2166412a135SAlexander Belyaev       }
2176412a135SAlexander Belyaev     }
2186412a135SAlexander Belyaev   }
2196412a135SAlexander Belyaev   if (nextExpectedDim != nDims) {
2206412a135SAlexander Belyaev     if (invalidIndex)
2216412a135SAlexander Belyaev       *invalidIndex = reassociation.size() - 1;
2226412a135SAlexander Belyaev     return false;
2236412a135SAlexander Belyaev   }
2246412a135SAlexander Belyaev   return true;
2256412a135SAlexander Belyaev }
226ff5de8a9SBenjamin Kramer 
227ff5de8a9SBenjamin Kramer LogicalResult mlir::reshapeLikeShapesAreCompatible(
228ff5de8a9SBenjamin Kramer     function_ref<LogicalResult(const Twine &)> emitError,
229ff5de8a9SBenjamin Kramer     ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
230ff5de8a9SBenjamin Kramer     ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
231ff5de8a9SBenjamin Kramer   unsigned expandedDimStart = 0;
232ff5de8a9SBenjamin Kramer   for (const auto &map : llvm::enumerate(reassociationMaps)) {
23397069a86SGaurav Shukla     bool foundDynamicShape = false;
234ff5de8a9SBenjamin Kramer     int64_t linearizedStaticShape = 1;
23597069a86SGaurav Shukla 
236ff5de8a9SBenjamin Kramer     for (const auto &dim : llvm::enumerate(
237ff5de8a9SBenjamin Kramer              expandedShape.slice(expandedDimStart, map.value().size()))) {
23897069a86SGaurav Shukla       if (ShapedType::isDynamic(dim.value()))
23997069a86SGaurav Shukla         foundDynamicShape = true;
24097069a86SGaurav Shukla       else
241ff5de8a9SBenjamin Kramer         linearizedStaticShape *= dim.value();
242ff5de8a9SBenjamin Kramer     }
24397069a86SGaurav Shukla     if (foundDynamicShape) {
244ff5de8a9SBenjamin Kramer       if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
245ff5de8a9SBenjamin Kramer         return emitError(
246ff5de8a9SBenjamin Kramer             "expected dimension " + Twine(map.index()) +
247ff5de8a9SBenjamin Kramer             " of collapsed type to be dynamic since one or more of the "
248ff5de8a9SBenjamin Kramer             "corresponding dimensions in the expanded type is dynamic");
249ff5de8a9SBenjamin Kramer       }
250ff5de8a9SBenjamin Kramer     } else {
251ff5de8a9SBenjamin Kramer       if (collapsedShape[map.index()] != linearizedStaticShape) {
252ff5de8a9SBenjamin Kramer         return emitError("expected dimension " + Twine(map.index()) +
253ff5de8a9SBenjamin Kramer                          " of collapsed type to be static value of " +
254ff5de8a9SBenjamin Kramer                          Twine(linearizedStaticShape));
255ff5de8a9SBenjamin Kramer       }
256ff5de8a9SBenjamin Kramer     }
257ff5de8a9SBenjamin Kramer     expandedDimStart += map.value().size();
258ff5de8a9SBenjamin Kramer   }
259ff5de8a9SBenjamin Kramer   return success();
260ff5de8a9SBenjamin Kramer }
261747b10beSAlexander Belyaev 
262747b10beSAlexander Belyaev bool mlir::hasNonIdentityLayout(Type type) {
2635550c821STres Popp   if (auto memrefType = dyn_cast<MemRefType>(type))
264747b10beSAlexander Belyaev     return !memrefType.getLayout().isIdentity();
265747b10beSAlexander Belyaev   return false;
266747b10beSAlexander Belyaev }
267f4a478cdSChristopher Bate 
268f4a478cdSChristopher Bate llvm::SmallBitVector
269f4a478cdSChristopher Bate mlir::getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
270f4a478cdSChristopher Bate                           ArrayRef<Range> sliceParams) {
271f4a478cdSChristopher Bate   assert(sliceParams.size() == sliceInputShape.size() &&
272f4a478cdSChristopher Bate          "only supports non rank-reducing case");
273f4a478cdSChristopher Bate   llvm::SmallBitVector mask(sliceInputShape.size());
274f4a478cdSChristopher Bate   unsigned idx = 0;
275f4a478cdSChristopher Bate   for (const auto &[offset, size, stride] : sliceParams) {
27622426110SRamkumar Ramachandra     std::optional<int64_t> offsetConst = getConstantIntValue(offset);
27722426110SRamkumar Ramachandra     std::optional<int64_t> strideConst = getConstantIntValue(stride);
278f4a478cdSChristopher Bate     mask[idx] = !isEqualConstantIntOrValue(size, sliceInputShape[idx]) ||
279f4a478cdSChristopher Bate                 (!strideConst || *strideConst != 1) ||
280f4a478cdSChristopher Bate                 (!offsetConst || *offsetConst != 0);
281f4a478cdSChristopher Bate     idx++;
282f4a478cdSChristopher Bate   }
283f4a478cdSChristopher Bate   return mask;
284f4a478cdSChristopher Bate }
285f4a478cdSChristopher Bate 
286f4a478cdSChristopher Bate llvm::SmallBitVector mlir::getLinearizedDimensions(
287f4a478cdSChristopher Bate     ArrayRef<ReassociationIndices> reassociationIndices) {
288f4a478cdSChristopher Bate   llvm::SmallBitVector result(reassociationIndices.size());
289f4a478cdSChristopher Bate   for (const auto &it : llvm::enumerate(reassociationIndices))
290f4a478cdSChristopher Bate     result[it.index()] = it.value().size() > 1;
291f4a478cdSChristopher Bate   return result;
292f4a478cdSChristopher Bate }
293f4a478cdSChristopher Bate 
294f4a478cdSChristopher Bate SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
2954d27f06fSChristopher Bate     MLIRContext *ctx, ArrayRef<ValueRange> multiIndices) {
296f4a478cdSChristopher Bate   unsigned loopIdx = 0;
297f4a478cdSChristopher Bate   auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1);
298f4a478cdSChristopher Bate   auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
299f4a478cdSChristopher Bate   SmallVector<Range> offsetsSizesAndStrides;
300f4a478cdSChristopher Bate   offsetsSizesAndStrides.reserve(collapseShapeInputShape.size());
301f4a478cdSChristopher Bate   for (const auto &it : llvm::enumerate(reassociationIndices)) {
302f4a478cdSChristopher Bate     // Case 1: Linearized dimensions that have also been sliced. These
303f4a478cdSChristopher Bate     // are size of 1 because we are iterating over these dimensions. The
304f4a478cdSChristopher Bate     // offsets are exactly the de-linearized multi-indices.
305f4a478cdSChristopher Bate     if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) {
306f4a478cdSChristopher Bate       llvm::append_range(
307f4a478cdSChristopher Bate           offsetsSizesAndStrides,
308f4a478cdSChristopher Bate           llvm::map_range(multiIndices[loopIdx++], [&](Value v) -> Range {
309f4a478cdSChristopher Bate             return Range{getAsOpFoldResult(v), oneAttr, oneAttr};
310f4a478cdSChristopher Bate           }));
311f4a478cdSChristopher Bate       continue;
312f4a478cdSChristopher Bate     }
313f4a478cdSChristopher Bate 
314f4a478cdSChristopher Bate     // Case 2: One or possibly multiple combined input dimensions, but we
315f4a478cdSChristopher Bate     // have proven that these are not sliced. In this case we just take
316f4a478cdSChristopher Bate     // the full extent of each dimension in the reassociation list.
317f4a478cdSChristopher Bate     if (linearizedDimensions[it.index()]) {
318f4a478cdSChristopher Bate       llvm::append_range(
319f4a478cdSChristopher Bate           offsetsSizesAndStrides,
320f4a478cdSChristopher Bate           llvm::map_range(it.value(), [&](int64_t idx) -> Range {
321f4a478cdSChristopher Bate             return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
322f4a478cdSChristopher Bate           }));
323f4a478cdSChristopher Bate       continue;
324f4a478cdSChristopher Bate     }
325f4a478cdSChristopher Bate 
326f4a478cdSChristopher Bate     // Case 3: A single index, but it may be sliced.
327f4a478cdSChristopher Bate     offsetsSizesAndStrides.push_back(sliceParams[it.index()]);
328f4a478cdSChristopher Bate   }
329f4a478cdSChristopher Bate   return offsetsSizesAndStrides;
330f4a478cdSChristopher Bate }
331f4a478cdSChristopher Bate 
332f4a478cdSChristopher Bate SmallVector<Range>
3334d27f06fSChristopher Bate SliceFromCollapseHelper::getInsertSliceParams(MLIRContext *ctx,
3344d27f06fSChristopher Bate                                               ValueRange tileIndices) {
335f4a478cdSChristopher Bate   auto one = IntegerAttr::get(IndexType::get(ctx), 1);
336f4a478cdSChristopher Bate   auto zero = IntegerAttr::get(IndexType::get(ctx), 0);
337f4a478cdSChristopher Bate   SmallVector<Range> insertParams;
338f4a478cdSChristopher Bate   insertParams.reserve(linearizedDimensions.size());
339f4a478cdSChristopher Bate   unsigned loopIdx = 0;
340f4a478cdSChristopher Bate   for (unsigned i = 0; i < linearizedDimensions.size(); i++) {
341f4a478cdSChristopher Bate     if (linearizedDimensions[i] && slicedDimensions[i]) {
342f4a478cdSChristopher Bate       insertParams.push_back(Range{tileIndices[loopIdx++], one, one});
343f4a478cdSChristopher Bate       continue;
344f4a478cdSChristopher Bate     }
345f4a478cdSChristopher Bate     insertParams.push_back(Range{zero, sliceParams[i].size, one});
346f4a478cdSChristopher Bate   }
347f4a478cdSChristopher Bate   return insertParams;
348f4a478cdSChristopher Bate }
349446981bdSChristopher Bate 
350446981bdSChristopher Bate /// Returns the index of the only non-unit dimension among `indices` of `shape`,
351446981bdSChristopher Bate /// if such a dimension exists and `indices` has more than one element.
352f09b0e35SKazu Hirata /// Otherwise, return std::nullopt.
3530a81ace0SKazu Hirata static std::optional<int64_t> getUniqueNonUnitDim(ArrayRef<int64_t> indices,
354446981bdSChristopher Bate                                                   ArrayRef<int64_t> shape) {
355446981bdSChristopher Bate   // Return false if more than one of the dimensions in this group are not 1.
35691682b26SKazu Hirata   std::optional<int64_t> dimIndex;
357446981bdSChristopher Bate   if (indices.size() < 2)
3581a36588eSKazu Hirata     return std::nullopt;
359446981bdSChristopher Bate   for (int64_t idx : indices) {
360446981bdSChristopher Bate     if (shape[idx] != 1) {
3611a36588eSKazu Hirata       if (dimIndex != std::nullopt)
3621a36588eSKazu Hirata         return std::nullopt;
363446981bdSChristopher Bate       dimIndex = idx;
364446981bdSChristopher Bate     }
365446981bdSChristopher Bate   }
366446981bdSChristopher Bate   return dimIndex;
367446981bdSChristopher Bate }
368446981bdSChristopher Bate 
369446981bdSChristopher Bate // For each segment in the reassociation indices, check whether we can
370446981bdSChristopher Bate // simplify that segment with a rank-reducing extract slice. We can do this if
371446981bdSChristopher Bate // all but (exactly) one of the corresponding source dims is 1.
3720a81ace0SKazu Hirata static SmallVector<std::optional<int64_t>> getCollapseShapeTrivialSegments(
373446981bdSChristopher Bate     RankedTensorType sourceType,
374446981bdSChristopher Bate     ArrayRef<ReassociationIndices> reassociationIndices) {
3750a81ace0SKazu Hirata   SmallVector<std::optional<int64_t>> trivialSegments;
376446981bdSChristopher Bate   for (const auto &indices : reassociationIndices)
377446981bdSChristopher Bate     trivialSegments.push_back(
378446981bdSChristopher Bate         getUniqueNonUnitDim(indices, sourceType.getShape()));
379446981bdSChristopher Bate   return trivialSegments;
380446981bdSChristopher Bate }
381446981bdSChristopher Bate 
382446981bdSChristopher Bate /// Returns true if any of the segments of the reassociation indices for a
383446981bdSChristopher Bate /// collapsing reshape can be simplified using a rank-reducing slice.
3840a81ace0SKazu Hirata static FailureOr<SmallVector<std::optional<int64_t>>>
385446981bdSChristopher Bate canCollapseShapeBeSimplifiedByRankReducingSlice(
386446981bdSChristopher Bate     RankedTensorType sourceType,
387446981bdSChristopher Bate     ArrayRef<ReassociationIndices> reassociationIndices) {
3880a81ace0SKazu Hirata   SmallVector<std::optional<int64_t>> trivialSegments =
389446981bdSChristopher Bate       getCollapseShapeTrivialSegments(sourceType, reassociationIndices);
3900a81ace0SKazu Hirata   if (!llvm::any_of(trivialSegments, [](const std::optional<int64_t> &idx) {
391446981bdSChristopher Bate         return idx.has_value();
392446981bdSChristopher Bate       }))
393446981bdSChristopher Bate     return failure();
394446981bdSChristopher Bate   return trivialSegments;
395446981bdSChristopher Bate }
396446981bdSChristopher Bate 
397446981bdSChristopher Bate FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
398446981bdSChristopher Bate mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
399446981bdSChristopher Bate     RankedTensorType sourceType,
400446981bdSChristopher Bate     ArrayRef<ReassociationIndices> reassociationIndices) {
4010a81ace0SKazu Hirata   FailureOr<SmallVector<std::optional<int64_t>>> trivialSegments =
402446981bdSChristopher Bate       canCollapseShapeBeSimplifiedByRankReducingSlice(sourceType,
403446981bdSChristopher Bate                                                       reassociationIndices);
404446981bdSChristopher Bate   if (failed(trivialSegments))
405446981bdSChristopher Bate     return failure();
406446981bdSChristopher Bate 
407446981bdSChristopher Bate   // Create the expected result shape of the rank-reducing slice.
408446981bdSChristopher Bate   SmallVector<int64_t> sliceShape;
409446981bdSChristopher Bate   for (const auto &[nonUnitDim, indices] :
410446981bdSChristopher Bate        llvm::zip(*trivialSegments, reassociationIndices)) {
411446981bdSChristopher Bate     if (nonUnitDim) {
412cbb09813SFangrui Song       sliceShape.push_back(sourceType.getDimSize(*nonUnitDim));
413446981bdSChristopher Bate       continue;
414446981bdSChristopher Bate     }
415446981bdSChristopher Bate     llvm::append_range(sliceShape, llvm::map_range(indices, [&](int64_t idx) {
416446981bdSChristopher Bate                          return sourceType.getDimSize(idx);
417446981bdSChristopher Bate                        }));
418446981bdSChristopher Bate   }
419446981bdSChristopher Bate   auto sliceType =
420446981bdSChristopher Bate       RankedTensorType::get(sliceShape, sourceType.getElementType());
421446981bdSChristopher Bate 
422446981bdSChristopher Bate   // If the rank-reducing slice simplified every segment, then we are done.
423446981bdSChristopher Bate   if (sliceShape.size() == reassociationIndices.size())
4241a36588eSKazu Hirata     return CollapseShapeRankReducingSliceSimplificationInfo{sliceType,
4251a36588eSKazu Hirata                                                             std::nullopt};
426446981bdSChristopher Bate 
427446981bdSChristopher Bate   // Otherwise, we need to create a new collapse_shape op for the segments that
428446981bdSChristopher Bate   // weren't covered by the slice. By design, the new reassociation indices has
429446981bdSChristopher Bate   // the same number of groups as the old reassociation indices.
430446981bdSChristopher Bate   SmallVector<ReassociationIndices> newReassociationIndices;
431446981bdSChristopher Bate   SmallVector<int64_t, 2> reassociation;
432446981bdSChristopher Bate   int64_t groupIdx = 0;
433446981bdSChristopher Bate   for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) {
434446981bdSChristopher Bate     reassociation.push_back(dimIdx);
435446981bdSChristopher Bate     if ((*trivialSegments)[groupIdx] ||
436446981bdSChristopher Bate         reassociation.size() == reassociationIndices[groupIdx].size()) {
437446981bdSChristopher Bate       newReassociationIndices.push_back(reassociation);
438446981bdSChristopher Bate       reassociation.clear();
439446981bdSChristopher Bate       groupIdx++;
440446981bdSChristopher Bate     }
441446981bdSChristopher Bate   }
442446981bdSChristopher Bate 
443446981bdSChristopher Bate   return CollapseShapeRankReducingSliceSimplificationInfo{
444446981bdSChristopher Bate       sliceType, newReassociationIndices};
445446981bdSChristopher Bate }
4460bfbecf5SQuentin Colombet 
4470bfbecf5SQuentin Colombet PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
4480bfbecf5SQuentin Colombet                                              ArrayRef<int64_t> innerDimPos) {
4490bfbecf5SQuentin Colombet   PackingMetadata res;
4500bfbecf5SQuentin Colombet   res.insertPositions.reserve(innerDimPos.size());
4510bfbecf5SQuentin Colombet   // The pack insert position is the position + the number of previously
4520bfbecf5SQuentin Colombet   // inserted positions + offset.
4530bfbecf5SQuentin Colombet   // The offset controls whether the packing dimension is the first or last.
4540bfbecf5SQuentin Colombet   //
4550bfbecf5SQuentin Colombet   // Example
4560bfbecf5SQuentin Colombet   // =======
4570bfbecf5SQuentin Colombet   // Consider packing from a hypothetical ABCD layout to ABCDba whose
4580bfbecf5SQuentin Colombet   // pack.inner_dims is [1, 0]. The first step consists in undoing the
4590bfbecf5SQuentin Colombet   // permutation and producing AaBbCD. This is achieved purely by computing the
4600bfbecf5SQuentin Colombet   // insert positions of `b` and `a` into `ABCD`, starting from [1, 0]. One
4610bfbecf5SQuentin Colombet   // possibility, is to produce insert positions [2, 0], this would result in an
4620bfbecf5SQuentin Colombet   // aAbBCD layout (i.e. offset 0). The other possibility, is to produce insert
4630bfbecf5SQuentin Colombet   // positions [3, 1], this would result in an AaBbCD layout (i.e. offset 1).
4640bfbecf5SQuentin Colombet   // The latter is what we expect from packing.
4650bfbecf5SQuentin Colombet   int64_t offset = 1;
4660bfbecf5SQuentin Colombet   for (int64_t pos : innerDimPos) {
4670bfbecf5SQuentin Colombet     int64_t numInsertedBefore = llvm::count_if(
4680bfbecf5SQuentin Colombet         innerDimPos, [&pos](int64_t pos2) { return pos > pos2; });
4690bfbecf5SQuentin Colombet     res.insertPositions.push_back(pos + numInsertedBefore + offset);
4700bfbecf5SQuentin Colombet   }
4710bfbecf5SQuentin Colombet 
4720bfbecf5SQuentin Colombet   DenseSet<int64_t> posSet(res.insertPositions.begin(),
4730bfbecf5SQuentin Colombet                            res.insertPositions.end());
4740bfbecf5SQuentin Colombet   res.reassociations.reserve(packedRank);
4750bfbecf5SQuentin Colombet   for (int64_t i = 1; i <= packedRank; ++i) {
4766f87b50bSHanhan Wang     res.outerPositions.push_back(i - 1);
4770bfbecf5SQuentin Colombet     if (!posSet.contains(i)) {
4780bfbecf5SQuentin Colombet       res.reassociations.push_back(ReassociationIndices{i - 1});
4790bfbecf5SQuentin Colombet       continue;
4800bfbecf5SQuentin Colombet     }
4810bfbecf5SQuentin Colombet     res.reassociations.push_back(ReassociationIndices{i - 1, i});
4820bfbecf5SQuentin Colombet     ++i;
4830bfbecf5SQuentin Colombet   }
4840bfbecf5SQuentin Colombet   return res;
4850bfbecf5SQuentin Colombet }
486