xref: /llvm-project/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (revision fcfdabfea18b4d2dd98a1c5b52d5b33aff77ae1a)
1 //===- ReshapeOpsUtils.h - Utilities used by reshape ops --*- C++ -*------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities and common canonicalization patterns for
10 // reshape operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
15 #define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
16 
17 #include "mlir/Dialect/Utils/StaticValueUtils.h"
18 #include "mlir/IR/OpImplementation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Support/LLVM.h"
21 #include "llvm/ADT/StringRef.h"
22 #include <optional>
23 
24 namespace mlir {
25 
26 using ReassociationIndices = SmallVector<int64_t, 2>;
27 using ReassociationIndicesRef = ArrayRef<int64_t>;
28 using ReassociationExprs = SmallVector<AffineExpr, 2>;
29 
30 /// Attribute name for the ArrayAttr which encodes reassociation indices.
31 constexpr StringRef getReassociationAttrName() { return "reassociation"; }
32 
33 /// Compose reassociation maps that are used in pair of reshape ops where one
34 /// is a producer and other is the consumer. Only valid to use this method when
35 /// both the producer and consumer are collapsing dimensions or both are
36 /// expanding dimensions.
37 ///
38 /// For example,
39 ///   producerReassociation = [[0, 1], [2], [3, 4]]
40 ///   consumerReassociation = [[0, 1], [2]]
41 ///
42 /// is folded into
43 ///
44 ///   result = [[0, 1, 2], [3, 4]].
45 std::optional<SmallVector<ReassociationIndices>> composeReassociationIndices(
46     ArrayRef<ReassociationIndices> producerReassociations,
47     ArrayRef<ReassociationIndices> consumerReassociations,
48     MLIRContext *context);
49 
50 /// Convert reassociation indices to affine expressions.
51 SmallVector<SmallVector<AffineExpr, 2>, 2> convertReassociationIndicesToExprs(
52     MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);
53 
54 /// Constructs affine maps out of Array<Array<AffineExpr>>.
55 SmallVector<AffineMap, 4>
56 getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation);
57 
58 /// Wraps a list of reassociations in an ArrayAttr.
59 ArrayAttr
60 getReassociationIndicesAttribute(OpBuilder &b,
61                                  ArrayRef<ReassociationIndices> reassociation);
62 
63 /// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
64 SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
65     ArrayRef<ReassociationExprs> reassociationExprs);
66 
67 /// Return the reassociations maps to use to reshape given the source type and
68 /// the target type when possible. Return std::nullopt when this computation
69 /// failed.
70 std::optional<SmallVector<ReassociationIndices>>
71 getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
72 
73 /// Returns the reassociation maps to collapse `sourceShape` to `targetShape` if
74 /// possible.
75 std::optional<SmallVector<ReassociationIndices>>
76 getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
77                                    ArrayRef<int64_t> targetShape);
78 
79 /// Return true if the reassociation specification is valid, false otherwise.
80 /// When false, the `invalidIndex` integer pointer is optionally filled with the
81 /// index of the offending reassociation map.
82 bool isReassociationValid(ArrayRef<AffineMap> reassociation,
83                           int *invalidIndex = nullptr);
84 
85 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
86 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
87                                   ArrayRef<Attribute> operands) {
88   // Fold identity reshape.
89   if (reshapeOp.getSrcType() == reshapeOp.getType())
90     return reshapeOp.getSrc();
91 
92   // Reshape of a constant can be replaced with a new constant.
93   if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
94     return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
95 
96   // Fold if the producer reshape source has the same shape with at most 1
97   // dynamic dimension.
98   auto reshapeSrcOp =
99       reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
100   if (!reshapeSrcOp)
101     return nullptr;
102   auto srcType = reshapeSrcOp.getSrcType();
103   auto resultType = reshapeOp.getResultType();
104   if (srcType != resultType)
105     return nullptr;
106 
107   if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
108     return reshapeSrcOp.getSrc();
109   }
110 
111   // Fold producer-consumer reshape ops when they are perfect inverses of each
112   // other:
113   //   1) Reassociation indices are equivalent.
114   //   2) Boundary types are equivalent.
115   //   3) No reassociations have more than 1 dynamic dimension, and reassociated
116   //      shapes are equal for each reassociation.
117   auto reassociations = reshapeOp.getReassociationIndices();
118   if (reassociations != reshapeSrcOp.getReassociationIndices())
119     return nullptr;
120   // If the reshapes are expanding and then collapsing, the ops can be folded
121   // despite multiple dynamic dimensions.
122   if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
123     return reshapeSrcOp.getSrc();
124   if (llvm::all_of(reassociations, [&](auto reInd) {
125         ArrayRef<int64_t> srcSlice =
126             srcType.getShape().slice(reInd.front(), reInd.size());
127         return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2;
128       })) {
129     return reshapeSrcOp.getSrc();
130   }
131   return nullptr;
132 }
133 
134 /// Common verifier for reshape-like types. Fills `expandedType` and
135 ///`collapsedType` with the proper `src` or `result` type.
136 template <typename Op, typename T>
137 static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
138                                             T collapsedType, bool isExpansion) {
139 
140   unsigned expandedRank = expandedType.getRank();
141   unsigned collapsedRank = collapsedType.getRank();
142   if (expandedRank < collapsedRank)
143     return op.emitOpError("expected the expanded type, ")
144            << expandedType << " to have a higher (or same) rank "
145            << "than the collapsed type, " << collapsedType << '.';
146 
147   if (collapsedRank != op.getReassociation().size())
148     return op.emitOpError("expected collapsed rank (")
149            << collapsedRank << ") to equal the number of reassociation maps ("
150            << op.getReassociation().size() << ").";
151 
152   auto maps = op.getReassociationMaps();
153   for (auto it : llvm::enumerate(maps))
154     if (it.value().getNumDims() != expandedRank)
155       return op.emitOpError("expected reassociation map #")
156              << it.index() << " to have size equal to the expanded rank ("
157              << expandedRank << "), but it is  " << it.value().getNumDims()
158              << '.';
159 
160   int invalidIdx = 0;
161   if (!isReassociationValid(maps, &invalidIdx))
162     return op.emitOpError("expected reassociation map #")
163            << invalidIdx << " to be valid and contiguous.";
164 
165   return reshapeLikeShapesAreCompatible(
166       [&](const Twine &msg) { return op->emitOpError(msg); },
167       collapsedType.getShape(), expandedType.getShape(),
168       op.getReassociationIndices(), isExpansion);
169 }
170 
171 /// Verify that shapes of the reshaped types using following rule:
172 /// if a dimension in the collapsed type is static, then the corresponding
173 /// dimensions in the expanded shape should be
174 ///    a) static
175 ///    b) the product should be same as the collaped shape.
176 LogicalResult reshapeLikeShapesAreCompatible(
177     function_ref<LogicalResult(const Twine &)> emitError,
178     ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
179     ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
180 
181 /// Returns true iff the type is a MemRefType and has a non-identity layout.
182 bool hasNonIdentityLayout(Type type);
183 
184 enum class ReshapeOpKind { kExpand, kCollapse };
185 
186 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
187 /// dimensions or are both expanding dimensions.
188 template <typename ReshapeOpTy, ReshapeOpKind opKind>
189 struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
190   using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
191   LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
192                                 PatternRewriter &rewriter) const override {
193     auto srcReshapeOp =
194         reshapeOp.getSrc().template getDefiningOp<ReshapeOpTy>();
195     if (!srcReshapeOp)
196       return failure();
197 
198     ShapedType resultType = reshapeOp.getResultType();
199 
200     if (hasNonIdentityLayout(srcReshapeOp.getSrc().getType()) ||
201         hasNonIdentityLayout(reshapeOp.getSrc().getType()) ||
202         hasNonIdentityLayout(reshapeOp.getResult().getType()))
203       return failure();
204 
205     std::optional<SmallVector<ReassociationIndices>> reassociationIndices =
206         composeReassociationIndices(srcReshapeOp.getReassociationIndices(),
207                                     reshapeOp.getReassociationIndices(),
208                                     rewriter.getContext());
209     if (!reassociationIndices)
210       return failure();
211 
212     if constexpr (opKind == ReshapeOpKind::kExpand) {
213       SmallVector<OpFoldResult> outputShape(
214           getMixedValues(reshapeOp.getStaticOutputShape(),
215                          reshapeOp.getOutputShape(), rewriter));
216       rewriter.replaceOpWithNewOp<ReshapeOpTy>(
217           reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
218           outputShape);
219     } else {
220       rewriter.replaceOpWithNewOp<ReshapeOpTy>(
221           reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
222     }
223     return success();
224   }
225 };
226 
227 /// Pattern to compose
228 /// `collapse_shape(expand_shape(%src, reassociation_1), reassociation_2)`.
229 /// In that case both `srcType` and `resultType` can be expressed as a function
230 /// of `intermediateType`.
231 /// In order to demonstrate the approach, let's assume that `rank(srcType) >
232 /// `rank(resultType)`, i.e. the resulting operation should be `collapse_shape`.
233 /// In that case, we can iterate over every set of indices in `reassociation_2`
234 /// and try to find ids of sets of indices in `reassociation_1` that cover it
235 /// completely.
236 ///
237 /// Example:
238 ///
239 ///   %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
240 ///     : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
241 ///   %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
242 ///     : tensor<?x?x?x1xi64> into tensor<?x?xi64>
243 ///
244 /// can be canonicalized into
245 ///
246 ///   %0 = tensor.collapse_shape %arg [[0, 1], [2]]
247 ///     : tensor<?x?x?xi64> into tensor<?x?xi64>
248 ///
249 /// because [0] and [1] from `expand_shape` reassociation cover completely
250 /// `[0, 1]` from `collapse_shape`. If it is impossible to find such union of
251 /// indices, then we fail.
252 //
253 /// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
254 /// `reassociation_2` and produce `expand_shape`.
255 template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy,
256           typename DimOpTy, typename TensorTy>
257 struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
258   using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
259   LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
260                                 PatternRewriter &rewriter) const override {
261     auto expandOp = collapseOp.getSrc().template getDefiningOp<ExpandOpTy>();
262     if (!expandOp)
263       return failure();
264 
265     ShapedType srcType = expandOp.getSrcType();
266     ShapedType resultType = collapseOp.getResultType();
267 
268     if (hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
269         hasNonIdentityLayout(expandOp.getSrc().getType()) ||
270         hasNonIdentityLayout(expandOp.getResult().getType()))
271       return failure();
272 
273     int64_t srcRank = srcType.getRank();
274     int64_t resultRank = resultType.getRank();
275     if (srcType == resultType)
276       return failure();
277 
278     SmallVector<ReassociationIndices, 4> higherRankReassociation,
279         lowerRankReassociation;
280 
281     if (srcRank > resultRank) {
282       higherRankReassociation = expandOp.getReassociationIndices();
283       lowerRankReassociation = collapseOp.getReassociationIndices();
284     } else {
285       higherRankReassociation = collapseOp.getReassociationIndices();
286       lowerRankReassociation = expandOp.getReassociationIndices();
287     }
288 
289     size_t higherRankIndicesID = 0;
290     SmallVector<ReassociationIndices, 4> composedReassociation;
291     for (const auto &lowerRankIndices : lowerRankReassociation) {
292       ReassociationIndices composedIndices;
293       while (higherRankIndicesID < higherRankReassociation.size()) {
294         auto rightmostIndex =
295             higherRankReassociation[higherRankIndicesID].back();
296         if (rightmostIndex > lowerRankIndices.back())
297           return failure();
298         composedIndices.push_back(higherRankIndicesID++);
299         if (rightmostIndex == lowerRankIndices.back())
300           break;
301       }
302       composedReassociation.push_back(composedIndices);
303     }
304     if (srcRank > resultRank) {
305       rewriter.replaceOpWithNewOp<CollapseOpTy>(
306           collapseOp, resultType, expandOp.getSrc(), composedReassociation);
307     } else if (srcRank < resultRank) {
308       rewriter.replaceOpWithNewOp<ExpandOpTy>(
309           collapseOp, resultType, expandOp.getSrc(), composedReassociation);
310     } else {
311       // Collapses/expansions that do not change the rank are not allowed. Use
312       // a cast instead.
313       assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
314              "expected same shape");
315       rewriter.replaceOpWithNewOp<CastOpTy>(collapseOp, resultType,
316                                             expandOp.getSrc());
317     }
318     return success();
319   }
320 };
321 
322 template <typename ExpandOpTy, typename CollapseOpTy>
323 struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
324   using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
325   LogicalResult matchAndRewrite(ExpandOpTy expandOp,
326                                 PatternRewriter &rewriter) const override {
327     auto collapseOp = expandOp.getSrc().template getDefiningOp<CollapseOpTy>();
328     if (!collapseOp)
329       return failure();
330 
331     ShapedType srcType = collapseOp.getSrcType();
332     ShapedType resultType = expandOp.getResultType();
333 
334     if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
335         hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
336         hasNonIdentityLayout(collapseOp.getResult().getType()))
337       return failure();
338 
339     int64_t srcRank = srcType.getRank();
340     int64_t resultRank = resultType.getRank();
341     if (srcRank == resultRank)
342       return failure();
343 
344     auto srcReassociation = collapseOp.getReassociationIndices();
345     auto resultReassociation = expandOp.getReassociationIndices();
346     if (srcRank > resultRank) {
347       auto composedReassociation = findCollapsingReassociation(
348           srcReassociation, resultReassociation, srcType.getShape(),
349           resultType.getShape());
350       if (!composedReassociation)
351         return failure();
352 
353       rewriter.replaceOpWithNewOp<CollapseOpTy>(
354           expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
355       return success();
356     }
357     auto composedReassociation =
358         findCollapsingReassociation(resultReassociation, srcReassociation,
359                                     resultType.getShape(), srcType.getShape());
360     if (!composedReassociation)
361       return failure();
362 
363     SmallVector<OpFoldResult> outputShape(getMixedValues(
364         expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
365     rewriter.replaceOpWithNewOp<ExpandOpTy>(
366         expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
367         outputShape);
368     return success();
369   }
370 
371 private:
372   // Attempts to find a way to collapse `srcShape` to `resultShape` by
373   // collapsing subshapes defined by the reassociation indices.
374   std::optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
375       ArrayRef<ReassociationIndices> srcReassociation,
376       ArrayRef<ReassociationIndices> resultReassociation,
377       ArrayRef<int64_t> srcShape, ArrayRef<int64_t> resultShape) const {
378     SmallVector<ReassociationIndices, 4> composedReassociation;
379 
380     if (srcReassociation.empty())
381       return {getReassociationIndicesForCollapse(srcShape, resultShape)};
382 
383     for (auto item : llvm::zip(srcReassociation, resultReassociation)) {
384       auto &srcIndices = std::get<0>(item);
385       auto &resultIndices = std::get<1>(item);
386       auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
387       auto resultSubShape =
388           resultShape.slice(resultIndices.front(), resultIndices.size());
389 
390       if (srcSubShape.size() == resultSubShape.size()) {
391         if (srcSubShape != resultSubShape ||
392             llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
393           return std::nullopt;
394         }
395         for (auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
396           composedReassociation.emplace_back(1, srcIndices.front() + index);
397         }
398         continue;
399       }
400 
401       // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
402       auto subShapeReassociation =
403           getReassociationIndicesForCollapse(srcSubShape, resultSubShape);
404       if (!subShapeReassociation)
405         return std::nullopt;
406 
407       // Remap the subshape indices back to the original srcShape.
408       for (auto &subshapeIndices : *subShapeReassociation) {
409         ReassociationIndices shapeIndices;
410         for (int64_t index : subshapeIndices)
411           shapeIndices.push_back(srcIndices.front() + index);
412         composedReassociation.push_back(shapeIndices);
413       }
414     }
415     return {std::move(composedReassociation)};
416   }
417 };
418 
419 /// The input parameters `offsets`, `sizes`, `strides` specify a rectangular
420 /// non rank-reducing slice of the collapse_shape output. Try to find which
421 /// dimensions have been sliced and which dimensions are not sliced (offset = 0,
422 /// size = dim, size = 1). Note that this conservative as it cannot detect if a
423 /// dynamic size corresponds to the full tensor dimension or not.
424 llvm::SmallBitVector getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
425                                          ArrayRef<Range> sliceParams);
426 
427 /// Determine which dimensions are linearized by a `tensor.collapse_shape` op by
428 /// inspecting its reassociation indices.
429 llvm::SmallBitVector
430 getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
431 
432 /// Given the parameters for both operations in a `CollapseShape->ExtractSlice`
433 /// chain and reified source and result shapes of the CollapseShapeOp, this
434 /// class provides two functions that assist with directly forming the result
435 /// of the extract slice by "tiling the CollapseShapeOp by 1".
436 //// Example:
437 // clang-format off
438 /// ```
439 /// %0 = linalg.generic ... -> tensor<3x7x11x10xf32>
440 /// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... to tensor<341x10xf32>
441 /// %2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : .... tensor<10x10xf32>
442 /// ```
443 /// This class helps build the below IR to replace %2:
444 /// ```
445 /// %dest = tensor.empty() : tensor<10x10xf32>
446 /// %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0) -> tensor<10x10xf32> {
447 ///    %linear_index = affine.apply affine_map<(d0)[]->(d0*2 + 11)>(%iv)
448 ///    %3:3 = arith.delinearize_index %iv into (3, 7, 11)
449 ///
450 ///    // This function takes %3 (multiIndices) and the parameters for the slice below.
451 ///    %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
452 ///          tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
453 ///
454 ///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
455 ///          tensor<1x1x1x10xf32> into tensor<1x10xf32>
456 ///    %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
457 ///          tensor<1x10xf32> into tensor<10x10xf32>
458 ///    scf.yield %6 : tensor<10x10xf32>
459 /// }
460 /// ```
461 // clang-format on
462 class SliceFromCollapseHelper {
463 public:
464   SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
465                           ArrayRef<OpFoldResult> collapseShapeInputShape,
466                           ArrayRef<OpFoldResult> collapseShapeOutputShape,
467                           ArrayRef<Range> extractSliceParams)
468       : reassociationIndices(reassociationIndices),
469         collapseShapeInputShape(collapseShapeInputShape),
470         collapseShapeOutputShape(collapseShapeOutputShape),
471         sliceParams(extractSliceParams),
472         linearizedDimensions(getLinearizedDimensions(reassociationIndices)),
473         slicedDimensions(getSlicedDimensions(collapseShapeOutputShape,
474                                              extractSliceParams)) {}
475 
476   /// This function takes multi-indices and maps them to ExtractSlice parameters
477   /// in the index space of the CollapseShape's source tensor. This function's
478   /// signature can be described by `(D_0, D_1,.. D_{n-1}) -> (offsets, sizes,
479   /// strides)` where `n` the number of "tiled dimensions", which are the
480   /// dimensions of the output that are linearized by the collapse shape op and
481   /// are also sliced. Each `D_i` is a tuple that must represent a valid
482   /// multi-index for the `i-th` tiled dimension. In the example above, there is
483   /// only one tiled dimension (D_0) and `arith.delinearize_index` produces the
484   /// multi-index (%3) that would be passed to this function to generate the
485   /// parameters for the `tensor.extract_slice` op (%4).
486   SmallVector<Range> getExtractSliceParams(MLIRContext *ctx,
487                                            ArrayRef<ValueRange> multiIndices);
488 
489   /// This function takes indices in the index space of the "tiled dimensions"
490   /// described above and returns a set of Range variables that describe how the
491   /// slice should be inserted into the destination. In the example above, `%iv`
492   /// would be passed to this function to generate the parameters for the
493   /// `tensor.insert_slice` op producing %6.
494   SmallVector<Range> getInsertSliceParams(MLIRContext *ctx,
495                                           ValueRange tileIndices);
496 
497 private:
498   SmallVector<ReassociationIndices> reassociationIndices;
499   SmallVector<OpFoldResult> collapseShapeInputShape;
500   SmallVector<OpFoldResult> collapseShapeOutputShape;
501   SmallVector<Range> sliceParams;
502   llvm::SmallBitVector linearizedDimensions;
503   llvm::SmallBitVector slicedDimensions;
504 };
505 
506 /// Parameters required to simplify a collapsing reshape op with a rank-reducing
507 /// slice operation. See `getSimplifyCollapseShapeWithRankReducingSliceInfo`.
508 struct CollapseShapeRankReducingSliceSimplificationInfo {
509   /// The shape of the output of the rank-reducing slice.
510   RankedTensorType sliceResultType;
511   /// The reassociation indices for the new collapse shape op, if required. If
512   /// `std::nullopt`, the slice should replace the collapse shape op.
513   std::optional<SmallVector<ReassociationIndices>> newReassociationIndices;
514 };
515 
516 /// A collapsing reshape operation can sometimes be simplified or eliminated by
517 /// inserting a single rank-reducing slice operation between it and the source
518 /// tensor. The slice op will either take the place of the source, allowing for
519 /// a new, simpler reshape op to replace the original, or the reshape op will be
520 /// completely replaced by the slice result.
521 ///
522 /// This function returns the parameters required to implement this pattern. If
523 /// the pattern is not applicable, then failure is returned.
524 ///
525 /// ### Example:
526 /// ```
527 /// %result = tensor.collapse_shape %0 [[0, 1], [2, 3]]
528 ///    : tensor<?x1x30x10xf32> to tensor<?x300xf32>
529 /// ```
530 /// can be transformed to
531 /// ```
532 /// %tmp = tensor.extract_slice %0 [0, 0, 0, 0]
533 ///                         [0, %dim1, 30, 30]
534 ///                         [1, 1, 1 1]
535 ///   : tensor<?x1x30x10xf32> to tensor<?x30x10xf32>
536 /// %result = tensor.collapse_shape %tmp [[0], [1, 2]]
537 ///   : tensor<?x30x10xf32> to tensor<?x300xf32>
538 /// ```
539 ///
540 /// ### Example:
541 /// ```
542 /// %result = tensor.collapse_shape %1 [[0, 1], [2]]
543 ///    : tensor<?x1x30xf32> to tensor<?x30xf32>
544 /// ```
545 /// can be transformed to
546 /// ```
547 /// %result = tensor.extract_slice %1 [0, 0, 0]
548 ///                                   [%dim2, 1, 30]
549 ///                                   [1, 1, 1]
550 ///    : tensor<?x1x30xf32> to tensor<?x30xf32>
551 /// ```
552 FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
553 getSimplifyCollapseShapeWithRankReducingSliceInfo(
554     RankedTensorType sourceType,
555     ArrayRef<ReassociationIndices> reassociationIndices);
556 
557 struct PackingMetadata {
558   SmallVector<int64_t> insertPositions;
559   SmallVector<int64_t> outerPositions;
560   SmallVector<ReassociationIndices> reassociations;
561 };
562 
563 /// Given a vector of `positions` indices representing desired packing insertion
564 /// points into a target vector (i.e. pack/unpack.inner_dim_pos), compute the
565 /// final positions in the target shape as well as the reshape reassociations.
566 // Note: This should not be called with a large positions array (or the
567 // implementation needs to be updated to use an N.log N sort instead of
568 // repeated N^2 counts).
569 PackingMetadata computePackingMetadata(int64_t packedRank,
570                                        ArrayRef<int64_t> innerDimPos);
571 } // namespace mlir
572 
573 #endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
574