//===- ReshapeOpsUtils.h - Utilities used by reshape ops --*- C++ -*------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines utilities and common canonicalization patterns for
// reshape operations.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
#define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H

#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
#include <optional>

namespace mlir {

using ReassociationIndices = SmallVector<int64_t, 2>;
using ReassociationIndicesRef = ArrayRef<int64_t>;
using ReassociationExprs = SmallVector<AffineExpr, 2>;

/// Attribute name for the ArrayAttr which encodes reassociation indices.
constexpr StringRef getReassociationAttrName() { return "reassociation"; }

/// Compose reassociation maps that are used in a pair of reshape ops where one
/// is the producer and the other is the consumer. Only valid to use this
/// method when both the producer and consumer are collapsing dimensions or
/// both are expanding dimensions.
///
/// For example,
///   producerReassociation = [[0, 1], [2], [3, 4]]
///   consumerReassociation = [[0, 1], [2]]
///
/// is folded into
///
///   result = [[0, 1, 2], [3, 4]].
std::optional<SmallVector<ReassociationIndices>> composeReassociationIndices(
    ArrayRef<ReassociationIndices> producerReassociations,
    ArrayRef<ReassociationIndices> consumerReassociations,
    MLIRContext *context);

/// Convert reassociation indices to affine expressions.
SmallVector<SmallVector<AffineExpr, 2>, 2> convertReassociationIndicesToExprs(
    MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);

/// Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector<AffineMap, 4>
getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation);

/// Wraps a list of reassociations in an ArrayAttr.
ArrayAttr
getReassociationIndicesAttribute(OpBuilder &b,
                                 ArrayRef<ReassociationIndices> reassociation);

/// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
    ArrayRef<ReassociationExprs> reassociationExprs);

/// Return the reassociation maps to use to reshape the given source type into
/// the given target type when possible. Return std::nullopt when this
/// computation fails.
std::optional<SmallVector<ReassociationIndices>>
getReassociationIndicesForReshape(ShapedType sourceType,
                                  ShapedType targetType);

/// Returns the reassociation maps to collapse `sourceShape` to `targetShape`
/// if possible.
std::optional<SmallVector<ReassociationIndices>>
getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                   ArrayRef<int64_t> targetShape);
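
// A minimal usage sketch for the helper above (shapes and variable names are
// hypothetical, chosen for this comment only):
//
//   SmallVector<int64_t> src = {2, 3, 4, 5};
//   SmallVector<int64_t> dst = {6, 20};
//   // Expected to yield [[0, 1], [2, 3]] because 2*3 == 6 and 4*5 == 20.
//   std::optional<SmallVector<ReassociationIndices>> ri =
//       getReassociationIndicesForCollapse(src, dst);
//   // A target such as {8, 15} cannot be formed by grouping adjacent source
//   // dimensions, so the helper is expected to return std::nullopt.
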
/// Return true if the reassociation specification is valid, false otherwise.
/// When false, the `invalidIndex` integer pointer is optionally filled with
/// the index of the offending reassociation map.
bool isReassociationValid(ArrayRef<AffineMap> reassociation,
                          int *invalidIndex = nullptr);

template <typename ReshapeOpTy, typename InverseReshapeOpTy>
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                  ArrayRef<Attribute> operands) {
  // Fold identity reshape.
  if (reshapeOp.getSrcType() == reshapeOp.getType())
    return reshapeOp.getSrc();

  // Reshape of a constant can be replaced with a new constant.
  if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
    return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));

  // Fold if the producer reshape source type matches the consumer result type
  // and has at most 1 dynamic dimension.
  auto reshapeSrcOp =
      reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
  if (!reshapeSrcOp)
    return nullptr;
  auto srcType = reshapeSrcOp.getSrcType();
  auto resultType = reshapeOp.getResultType();
  if (srcType != resultType)
    return nullptr;

  if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
    return reshapeSrcOp.getSrc();
  }

  // Fold producer-consumer reshape ops when they are perfect inverses of each
  // other:
  //   1) Reassociation indices are equivalent.
  //   2) Boundary types are equivalent.
  //   3) No reassociations have more than 1 dynamic dimension, and
  //      reassociated shapes are equal for each reassociation.
  auto reassociations = reshapeOp.getReassociationIndices();
  if (reassociations != reshapeSrcOp.getReassociationIndices())
    return nullptr;
  // If the reshapes are expanding and then collapsing, the ops can be folded
  // despite multiple dynamic dimensions.
  if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
    return reshapeSrcOp.getSrc();
  if (llvm::all_of(reassociations, [&](auto reInd) {
        ArrayRef<int64_t> srcSlice =
            srcType.getShape().slice(reInd.front(), reInd.size());
        return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2;
      })) {
    return reshapeSrcOp.getSrc();
  }
  return nullptr;
}
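
// For illustration only (hypothetical IR, assuming the tensor dialect reshape
// ops): a collapse of an expand with matching reassociation and boundary types,
// such as
//
//   %e = tensor.expand_shape %src [[0, 1], [2]] output_shape [2, 3, 4]
//       : tensor<6x4xf32> into tensor<2x3x4xf32>
//   %c = tensor.collapse_shape %e [[0, 1], [2]]
//       : tensor<2x3x4xf32> into tensor<6x4xf32>
//
// is folded by the helper above so that %c is replaced by %src, since the two
// ops are perfect inverses of each other.
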
/// Common verifier for reshape-like types. Fills `expandedType` and
/// `collapsedType` with the proper `src` or `result` type.
template <typename Op, typename T>
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
                                            T collapsedType, bool isExpansion) {

  unsigned expandedRank = expandedType.getRank();
  unsigned collapsedRank = collapsedType.getRank();
  if (expandedRank < collapsedRank)
    return op.emitOpError("expected the expanded type, ")
           << expandedType << " to have a higher (or same) rank "
           << "than the collapsed type, " << collapsedType << '.';

  if (collapsedRank != op.getReassociation().size())
    return op.emitOpError("expected collapsed rank (")
           << collapsedRank << ") to equal the number of reassociation maps ("
           << op.getReassociation().size() << ").";

  auto maps = op.getReassociationMaps();
  for (auto it : llvm::enumerate(maps))
    if (it.value().getNumDims() != expandedRank)
      return op.emitOpError("expected reassociation map #")
             << it.index() << " to have size equal to the expanded rank ("
             << expandedRank << "), but it is " << it.value().getNumDims()
             << '.';

  int invalidIdx = 0;
  if (!isReassociationValid(maps, &invalidIdx))
    return op.emitOpError("expected reassociation map #")
           << invalidIdx << " to be valid and contiguous.";

  return reshapeLikeShapesAreCompatible(
      [&](const Twine &msg) { return op->emitOpError(msg); },
      collapsedType.getShape(), expandedType.getShape(),
      op.getReassociationIndices(), isExpansion);
}

/// Verify the shapes of the reshaped types using the following rule:
/// if a dimension in the collapsed type is static, then the corresponding
/// dimensions in the expanded shape must be
///   a) static, and
///   b) their product must equal the collapsed dimension.
LogicalResult reshapeLikeShapesAreCompatible(
    function_ref<LogicalResult(const Twine &)> emitError,
    ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
    ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
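
// A small worked example of the rule above (shapes invented for this comment):
// with reassociation [[0, 1], [2]], a collapsed shape [6, ?] is compatible
// with an expanded shape [2, 3, ?] because the static group satisfies
// 2 * 3 == 6, whereas an expanded shape [2, 4, ?] is not, since 2 * 4 != 6.
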
/// Returns true iff the type is a MemRefType and has a non-identity layout.
bool hasNonIdentityLayout(Type type);

enum class ReshapeOpKind { kExpand, kCollapse };

/// Pattern to collapse producer/consumer reshape ops that are both collapsing
/// dimensions or are both expanding dimensions.
template <typename ReshapeOpTy, ReshapeOpKind opKind>
struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
  using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto srcReshapeOp =
        reshapeOp.getSrc().template getDefiningOp<ReshapeOpTy>();
    if (!srcReshapeOp)
      return failure();

    ShapedType resultType = reshapeOp.getResultType();

    if (hasNonIdentityLayout(srcReshapeOp.getSrc().getType()) ||
        hasNonIdentityLayout(reshapeOp.getSrc().getType()) ||
        hasNonIdentityLayout(reshapeOp.getResult().getType()))
      return failure();

    std::optional<SmallVector<ReassociationIndices>> reassociationIndices =
        composeReassociationIndices(srcReshapeOp.getReassociationIndices(),
                                    reshapeOp.getReassociationIndices(),
                                    rewriter.getContext());
    if (!reassociationIndices)
      return failure();

    if constexpr (opKind == ReshapeOpKind::kExpand) {
      SmallVector<OpFoldResult> outputShape(
          getMixedValues(reshapeOp.getStaticOutputShape(),
                         reshapeOp.getOutputShape(), rewriter));
      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
          reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
          outputShape);
    } else {
      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
          reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
    }
    return success();
  }
};
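
// For illustration only (hypothetical IR, reusing the reassociations from the
// `composeReassociationIndices` example above): two collapsing reshapes such as
//
//   %0 = tensor.collapse_shape %arg [[0, 1], [2], [3, 4]]
//       : tensor<2x3x4x5x6xf32> into tensor<6x4x30xf32>
//   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
//       : tensor<6x4x30xf32> into tensor<24x30xf32>
//
// compose into a single op with reassociation [[0, 1, 2], [3, 4]]:
//
//   %1 = tensor.collapse_shape %arg [[0, 1, 2], [3, 4]]
//       : tensor<2x3x4x5x6xf32> into tensor<24x30xf32>
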
/// Pattern to compose
/// `collapse_shape(expand_shape(%src, reassociation_1), reassociation_2)`.
/// In that case both `srcType` and `resultType` can be expressed as a function
/// of `intermediateType`.
/// In order to demonstrate the approach, let's assume that
/// `rank(srcType) > rank(resultType)`, i.e. the resulting operation should be
/// `collapse_shape`. In that case, we can iterate over every set of indices in
/// `reassociation_2` and try to find the ids of the sets of indices in
/// `reassociation_1` that cover it completely.
///
/// Example:
///
///   %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
///     : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
///   %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
///     : tensor<?x?x?x1xi64> into tensor<?x?xi64>
///
/// can be canonicalized into
///
///   %0 = tensor.collapse_shape %arg [[0, 1], [2]]
///     : tensor<?x?x?xi64> into tensor<?x?xi64>
///
/// because [0] and [1] from the `expand_shape` reassociation completely cover
/// `[0, 1]` from the `collapse_shape` reassociation. If it is impossible to
/// find such a union of indices, then we fail.
///
/// When `rank(srcType) < rank(resultType)`, we simply swap `reassociation_1`
/// and `reassociation_2` and produce an `expand_shape`.
template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy,
          typename DimOpTy, typename TensorTy>
struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
  using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
                                PatternRewriter &rewriter) const override {
    auto expandOp = collapseOp.getSrc().template getDefiningOp<ExpandOpTy>();
    if (!expandOp)
      return failure();

    ShapedType srcType = expandOp.getSrcType();
    ShapedType resultType = collapseOp.getResultType();

    if (hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
        hasNonIdentityLayout(expandOp.getSrc().getType()) ||
        hasNonIdentityLayout(expandOp.getResult().getType()))
      return failure();

    int64_t srcRank = srcType.getRank();
    int64_t resultRank = resultType.getRank();
    if (srcType == resultType)
      return failure();

    SmallVector<ReassociationIndices, 4> higherRankReassociation,
        lowerRankReassociation;

    if (srcRank > resultRank) {
      higherRankReassociation = expandOp.getReassociationIndices();
      lowerRankReassociation = collapseOp.getReassociationIndices();
    } else {
      higherRankReassociation = collapseOp.getReassociationIndices();
      lowerRankReassociation = expandOp.getReassociationIndices();
    }

    size_t higherRankIndicesID = 0;
    SmallVector<ReassociationIndices, 4> composedReassociation;
    for (const auto &lowerRankIndices : lowerRankReassociation) {
      ReassociationIndices composedIndices;
      while (higherRankIndicesID < higherRankReassociation.size()) {
        auto rightmostIndex =
            higherRankReassociation[higherRankIndicesID].back();
        if (rightmostIndex > lowerRankIndices.back())
          return failure();
        composedIndices.push_back(higherRankIndicesID++);
        if (rightmostIndex == lowerRankIndices.back())
          break;
      }
      composedReassociation.push_back(composedIndices);
    }
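    // At this point each entry of `composedReassociation` maps one dimension
    // of the lower-ranked type to the contiguous run of dimensions of the
    // higher-ranked type that collapse into it.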
    if (srcRank > resultRank) {
      rewriter.replaceOpWithNewOp<CollapseOpTy>(
          collapseOp, resultType, expandOp.getSrc(), composedReassociation);
    } else if (srcRank < resultRank) {
      rewriter.replaceOpWithNewOp<ExpandOpTy>(
          collapseOp, resultType, expandOp.getSrc(), composedReassociation);
    } else {
      // Collapses/expansions that do not change the rank are not allowed. Use
      // a cast instead.
      assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
             "expected same shape");
      rewriter.replaceOpWithNewOp<CastOpTy>(collapseOp, resultType,
                                            expandOp.getSrc());
    }
    return success();
  }
};
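
/// Pattern to compose
/// `expand_shape(collapse_shape(%src, reassociation_1), reassociation_2)`.
/// This is the inverse composition of the pattern above: when a collapsing
/// reassociation between the source and result shapes can be found, the pair
/// is rewritten into a single `collapse_shape` (if the source rank is higher)
/// or a single `expand_shape` (if the result rank is higher); otherwise the
/// pattern fails.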
template <typename ExpandOpTy, typename CollapseOpTy>
struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
  using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(ExpandOpTy expandOp,
                                PatternRewriter &rewriter) const override {
    auto collapseOp = expandOp.getSrc().template getDefiningOp<CollapseOpTy>();
    if (!collapseOp)
      return failure();

    ShapedType srcType = collapseOp.getSrcType();
    ShapedType resultType = expandOp.getResultType();

    if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
        hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
        hasNonIdentityLayout(collapseOp.getResult().getType()))
      return failure();

    int64_t srcRank = srcType.getRank();
    int64_t resultRank = resultType.getRank();
    if (srcRank == resultRank)
      return failure();

    auto srcReassociation = collapseOp.getReassociationIndices();
    auto resultReassociation = expandOp.getReassociationIndices();
    if (srcRank > resultRank) {
      auto composedReassociation = findCollapsingReassociation(
          srcReassociation, resultReassociation, srcType.getShape(),
          resultType.getShape());
      if (!composedReassociation)
        return failure();

      rewriter.replaceOpWithNewOp<CollapseOpTy>(
          expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
      return success();
    }
    auto composedReassociation =
        findCollapsingReassociation(resultReassociation, srcReassociation,
                                    resultType.getShape(), srcType.getShape());
    if (!composedReassociation)
      return failure();

    SmallVector<OpFoldResult> outputShape(getMixedValues(
        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
    rewriter.replaceOpWithNewOp<ExpandOpTy>(
        expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
        outputShape);
    return success();
  }

private:
  // Attempts to find a way to collapse `srcShape` to `resultShape` by
  // collapsing subshapes defined by the reassociation indices.
  std::optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
      ArrayRef<ReassociationIndices> srcReassociation,
      ArrayRef<ReassociationIndices> resultReassociation,
      ArrayRef<int64_t> srcShape, ArrayRef<int64_t> resultShape) const {
    SmallVector<ReassociationIndices, 4> composedReassociation;

    if (srcReassociation.empty())
      return {getReassociationIndicesForCollapse(srcShape, resultShape)};

    for (auto item : llvm::zip(srcReassociation, resultReassociation)) {
      auto &srcIndices = std::get<0>(item);
      auto &resultIndices = std::get<1>(item);
      auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
      auto resultSubShape =
          resultShape.slice(resultIndices.front(), resultIndices.size());

      if (srcSubShape.size() == resultSubShape.size()) {
        if (srcSubShape != resultSubShape ||
            llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
          return std::nullopt;
        }
        for (auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
          composedReassociation.emplace_back(1, srcIndices.front() + index);
        }
        continue;
      }

      // Find the reassociation to collapse `srcSubShape` into
      // `resultSubShape`.
      auto subShapeReassociation =
          getReassociationIndicesForCollapse(srcSubShape, resultSubShape);
      if (!subShapeReassociation)
        return std::nullopt;

      // Remap the subshape indices back to the original srcShape.
      for (auto &subshapeIndices : *subShapeReassociation) {
        ReassociationIndices shapeIndices;
        for (int64_t index : subshapeIndices)
          shapeIndices.push_back(srcIndices.front() + index);
        composedReassociation.push_back(shapeIndices);
      }
    }
    return {std::move(composedReassociation)};
  }
};

/// The input parameters `offsets`, `sizes`, `strides` specify a rectangular,
/// non rank-reducing slice of the collapse_shape output. Try to find which
/// dimensions have been sliced and which dimensions are not sliced (offset =
/// 0, size = dim, stride = 1). Note that this is conservative, as it cannot
/// detect whether a dynamic size corresponds to the full tensor dimension or
/// not.
llvm::SmallBitVector getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
                                         ArrayRef<Range> sliceParams);

/// Determine which dimensions are linearized by a `tensor.collapse_shape` op
/// by inspecting its reassociation indices.
llvm::SmallBitVector
getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
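
// A sketch of the two helpers above on the running example documented below
// (values are only illustrative): for reassociation [[0, 1, 2], [3]] the first
// result dimension is formed from three source dimensions, so
// `getLinearizedDimensions` is expected to return a bit vector with only bit 0
// set. For a slice of a 341x10 result with offsets [13, 0], sizes [10, 10],
// strides [2, 1], only dimension 0 is actually sliced, so
// `getSlicedDimensions` is also expected to set only bit 0.
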
/// Given the parameters for both operations in a `CollapseShape->ExtractSlice`
/// chain and the reified source and result shapes of the CollapseShapeOp, this
/// class provides two functions that assist with directly forming the result
/// of the extract slice by "tiling the CollapseShapeOp by 1".
///
/// Example:
// clang-format off
/// ```
/// %0 = linalg.generic ... -> tensor<3x7x11x10xf32>
/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... to tensor<341x10xf32>
/// %2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : .... tensor<10x10xf32>
/// ```
/// This class helps build the below IR to replace %2:
/// ```
/// %dest = tensor.empty() : tensor<10x10xf32>
/// %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0) -> tensor<10x10xf32> {
///   %linear_index = affine.apply affine_map<(d0) -> (d0 * 2 + 13)>(%iv)
///   %3:3 = arith.delinearize_index %linear_index into (3, 7, 11)
///
///   // This function takes %3 (multiIndices) and the parameters for the slice
///   // below.
///   %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
///          tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
///
///   %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
///          tensor<1x1x1x10xf32> into tensor<1x10xf32>
///   %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
///          tensor<1x10xf32> into tensor<10x10xf32>
///   scf.yield %6 : tensor<10x10xf32>
/// }
/// ```
// clang-format on
class SliceFromCollapseHelper {
public:
  SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
                          ArrayRef<OpFoldResult> collapseShapeInputShape,
                          ArrayRef<OpFoldResult> collapseShapeOutputShape,
                          ArrayRef<Range> extractSliceParams)
      : reassociationIndices(reassociationIndices),
        collapseShapeInputShape(collapseShapeInputShape),
        collapseShapeOutputShape(collapseShapeOutputShape),
        sliceParams(extractSliceParams),
        linearizedDimensions(getLinearizedDimensions(reassociationIndices)),
        slicedDimensions(getSlicedDimensions(collapseShapeOutputShape,
                                             extractSliceParams)) {}

  /// This function takes multi-indices and maps them to ExtractSlice
  /// parameters in the index space of the CollapseShape's source tensor. This
  /// function's signature can be described by `(D_0, D_1, .., D_{n-1}) ->
  /// (offsets, sizes, strides)` where `n` is the number of "tiled dimensions",
  /// which are the dimensions of the output that are linearized by the
  /// collapse shape op and are also sliced. Each `D_i` is a tuple that must
  /// represent a valid multi-index for the `i`-th tiled dimension. In the
  /// example above, there is only one tiled dimension (D_0) and
  /// `arith.delinearize_index` produces the multi-index (%3) that would be
  /// passed to this function to generate the parameters for the
  /// `tensor.extract_slice` op (%4).
  SmallVector<Range> getExtractSliceParams(MLIRContext *ctx,
                                           ArrayRef<ValueRange> multiIndices);

  /// This function takes indices in the index space of the "tiled dimensions"
  /// described above and returns a set of Range variables that describe how
  /// the slice should be inserted into the destination. In the example above,
  /// `%iv` would be passed to this function to generate the parameters for the
  /// `tensor.insert_slice` op producing %6.
  SmallVector<Range> getInsertSliceParams(MLIRContext *ctx,
                                          ValueRange tileIndices);

private:
  SmallVector<ReassociationIndices> reassociationIndices;
  SmallVector<OpFoldResult> collapseShapeInputShape;
  SmallVector<OpFoldResult> collapseShapeOutputShape;
  SmallVector<Range> sliceParams;
  llvm::SmallBitVector linearizedDimensions;
  llvm::SmallBitVector slicedDimensions;
};
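
// A minimal usage sketch for the helper above (variable names are hypothetical
// and only illustrate the calling convention, not a tested transformation):
//
//   SliceFromCollapseHelper helper(reassociationIndices, srcShape, dstShape,
//                                  sliceParams);
//   // Inside the generated loop nest, pass the delinearized multi-indices of
//   // the tiled dimensions to obtain the extract_slice parameters on the
//   // collapse_shape source ...
//   SmallVector<Range> extractParams =
//       helper.getExtractSliceParams(rewriter.getContext(), multiIndices);
//   // ... and the loop induction variables to obtain the insert_slice
//   // parameters on the destination.
//   SmallVector<Range> insertParams =
//       helper.getInsertSliceParams(rewriter.getContext(), ivs);
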
/// Parameters required to simplify a collapsing reshape op with a
/// rank-reducing slice operation. See
/// `getSimplifyCollapseShapeWithRankReducingSliceInfo`.
struct CollapseShapeRankReducingSliceSimplificationInfo {
  /// The shape of the output of the rank-reducing slice.
  RankedTensorType sliceResultType;
  /// The reassociation indices for the new collapse shape op, if required. If
  /// `std::nullopt`, the slice should replace the collapse shape op.
  std::optional<SmallVector<ReassociationIndices>> newReassociationIndices;
};

/// A collapsing reshape operation can sometimes be simplified or eliminated by
/// inserting a single rank-reducing slice operation between it and the source
/// tensor. The slice op will either take the place of the source, allowing for
/// a new, simpler reshape op to replace the original, or the reshape op will
/// be completely replaced by the slice result.
///
/// This function returns the parameters required to implement this pattern. If
/// the pattern is not applicable, then failure is returned.
///
/// ### Example:
/// ```
/// %result = tensor.collapse_shape %0 [[0, 1], [2, 3]]
///    : tensor<?x1x30x10xf32> to tensor<?x300xf32>
/// ```
/// can be transformed to
/// ```
/// %tmp = tensor.extract_slice %0 [0, 0, 0, 0]
///                                [%dim1, 1, 30, 10]
///                                [1, 1, 1, 1]
///    : tensor<?x1x30x10xf32> to tensor<?x30x10xf32>
/// %result = tensor.collapse_shape %tmp [[0], [1, 2]]
///    : tensor<?x30x10xf32> to tensor<?x300xf32>
/// ```
///
/// ### Example:
/// ```
/// %result = tensor.collapse_shape %1 [[0, 1], [2]]
///    : tensor<?x1x30xf32> to tensor<?x30xf32>
/// ```
/// can be transformed to
/// ```
/// %result = tensor.extract_slice %1 [0, 0, 0]
///                                   [%dim2, 1, 30]
///                                   [1, 1, 1]
///    : tensor<?x1x30xf32> to tensor<?x30xf32>
/// ```
FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
getSimplifyCollapseShapeWithRankReducingSliceInfo(
    RankedTensorType sourceType,
    ArrayRef<ReassociationIndices> reassociationIndices);

struct PackingMetadata {
  SmallVector<int64_t> insertPositions;
  SmallVector<int64_t> outerPositions;
  SmallVector<ReassociationIndices> reassociations;
};

/// Given a vector of `positions` indices representing desired packing
/// insertion points into a target vector (i.e. pack/unpack.inner_dim_pos),
/// compute the final positions in the target shape as well as the reshape
/// reassociations.
// Note: This should not be called with a large positions array (or the
// implementation needs to be updated to use an N.log N sort instead of
// repeated N^2 counts).
PackingMetadata computePackingMetadata(int64_t packedRank,
                                       ArrayRef<int64_t> innerDimPos);
} // namespace mlir

#endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H