//===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace tensor {
namespace {

/// Returns the number of shape sizes that is either dynamic or greater than 1.
/// Dynamic sizes are conservatively counted as non-unit since their runtime
/// value is unknown.
static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
  return llvm::count_if(
      shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
}

/// Returns success() if there is only 1 dimension size in non-packed domain
/// being greater than 1 and packing only happens on the dimension.
/// Note: this method should only be used by pack/unpack to reshape conversion.
/// It assumes that non-unit inner tile size must be used by the non-unit
/// dimension.
static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
                                ArrayRef<int64_t> srcShape,
                                ArrayRef<int64_t> innerPackTileSize) {
  // At most one dimension of the non-packed domain may be non-unit (or
  // dynamic); otherwise the pack is not expressible as a pure reshape.
  if (getNumGtOneDims(srcShape) > 1) {
    return rewriter.notifyMatchFailure(
        op, "expects non-packed domain to have at most one non-unit dims");
  }
  // Non-unit inner tile size must be used by the non-unit dimension. If not, it
  // will fail on getting reassociation maps.
  if (getNumGtOneDims(innerPackTileSize) > 1) {
    return rewriter.notifyMatchFailure(
        op, "expects at most one non-unit inner tiles");
  }
  return success();
}

// If the `linalgOp` represents a transpose, return the permutation vector for
// the transpose. Otherwise, return failure.
static FailureOr<SmallVector<int64_t>>
getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
  // A concrete linalg.transpose op carries its permutation directly.
  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
    return SmallVector<int64_t>(transposeOp.getPermutation());
  // Otherwise, recognize a transpose written as a generic op: all loops must
  // be parallel (elementwise, no reductions).
  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
    return failure();

  // Exactly one input and one init operand.
  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return failure();
  // Both indexing maps must be (distinct) permutations; identical maps would
  // be an identity copy, not a transpose.
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
      mapRange.front() == mapRange.back()) {
    return failure();
  }
  // The body must consist of a single operation (the terminator yielding the
  // input value); any other computation disqualifies it as a pure transpose.
  if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
    return failure();
  AffineMap outMap = mapRange.back();
  AffineMap inMap = mapRange.front();
  // To get the permutation, look at each output index and find which
  // dimension in the input we're reading from for that index.
  return llvm::map_to_vector(outMap.getResults(),
                             [&](AffineExpr expr) -> int64_t {
                               return *inMap.getResultPosition(expr);
                             });
}

/// Packing one-dimensional tensor can be expressed as an expand shape op.
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  /// Creates a tensor.expand_shape from `operand` to `newOperandType` using
  /// `reassociation`. Returns `operand` unchanged when the types already
  /// match (no-op reshape).
  FailureOr<Value>
  insertExpand(RewriterBase &rewriter, Location loc, Value operand,
               Type newOperandType,
               ArrayRef<ReassociationIndices> reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return rewriter
        .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
                                       reassociation)
        .getResult();
  }

  /// Returns success() if it is only packing on the innermost dimension.
  LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
                                     PackOp packOp) const {
    // Any non-identity outer_dims_perm would reorder dimensions, which a pure
    // expand_shape cannot express.
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          packOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    // Exactly one inner tile, and it must tile the innermost source dim.
    int64_t srcRank = packOp.getSourceRank();
    ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
      return rewriter.notifyMatchFailure(
          packOp, "expects packing at the innermost dimension");
    }
    return success();
  }

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    // Padding changes the element count, so the op cannot be a reshape.
    if (packOp.getPaddingValue())
      return rewriter.notifyMatchFailure(packOp, "expects no padding value");

    // The pack is expressible as an expand_shape in any of three cases:
    // innermost-dim packing, 1-D packing, or a pad-like pack.
    RankedTensorType sourceType = packOp.getSourceType();
    if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
        failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
                          packOp.getStaticTiles())) &&
        !packOp.isLikePad()) {
      return failure();
    }

    RankedTensorType destType = packOp.getDestType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    FailureOr<Value> expanded =
        insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
                     *reassociation);
    if (failed(expanded)) {
      return rewriter.notifyMatchFailure(
          packOp, "unable to expand source of tensor.pack");
    }
    rewriter.replaceOp(packOp, *expanded);
    return success();
  }
};

/// Unpacking on only the innermost (or only non-unit) dimension can be
/// expressed as a tensor.collapse_shape op.
struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  /// Creates a tensor.collapse_shape from `operand` to `newOperandType` using
  /// `reassociation`. Returns `operand` unchanged when the types already
  /// match (no-op reshape).
  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
                       Type newOperandType, ArrayAttr reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return
        rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType, operand,
                                                 reassociation);
  }

  /// Returns success() if it is unpacking on the innermost dimension.
  LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
                                       UnPackOp unpackOp) const {
    // Any non-identity outer_dims_perm would reorder dimensions, which a pure
    // collapse_shape cannot express.
    auto outerDimsPerm = unpackOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          unpackOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    RankedTensorType sourceType = unpackOp.getSourceType();
    RankedTensorType destType = unpackOp.getDestType();
    if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
      return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");

    // Exactly one inner tile, and it must unpack into the innermost dest dim.
    ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
      return rewriter.notifyMatchFailure(
          unpackOp, "expects unpacking on the innermost dimension");
    }

    return success();
  }

  LogicalResult matchAndRewrite(UnPackOp unpackOp,
                                PatternRewriter &rewriter) const override {
    // The unpack is expressible as a collapse_shape in any of three cases:
    // innermost-dim unpacking, 1-D unpacking, or an unpad-like unpack.
    RankedTensorType destType = unpackOp.getDestType();
    if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
        failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
                          unpackOp.getStaticTiles())) &&
        !unpackOp.isLikeUnPad()) {
      return failure();
    }

    RankedTensorType sourceType = unpackOp.getSourceType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    Value collapsed = insertCollapse(
        rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
        getReassociationIndicesAttribute(rewriter, *reassociation));
    rewriter.replaceOp(unpackOp, collapsed);
    return success();
  }
};

/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
/// the pad op
/// has zero low paddings, or if `pack` has no padding values.
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<PadOp>();

    // `nofold` pads are explicitly protected from folding; low padding would
    // shift element positions, which pack's padding cannot express.
    if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
      return failure();

    // Only constant padding values can be transferred onto the pack op.
    Value constantPaddingValue = padOp.getConstantPaddingValue();
    if (!constantPaddingValue)
      return failure();

    // If the pack already pads, its padding value must match the pad op's.
    if (auto paddingValue = packOp.getPaddingValue())
      if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
        return failure();

    // Rebuild the pack directly on the pad's source, carrying the pad value.
    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
        packOp.getMixedTiles(), constantPaddingValue,
        packOp.getOuterDimsPerm());
    return success();
  }
};

/// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
/// has extract_slice semantics.
struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
    if (!unpackOp)
      return failure();

    if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "rank-reduced folding is not supported");
    }

    // Check all offsets are zeros, and all strides are ones.
    if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
        !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expects offsets to be 0s and strides to be 1s");
    }

    // Create a new empty output tensor.
249 Type elementType = unpackOp.getDestType().getElementType(); 250 Value output = rewriter.create<EmptyOp>( 251 sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType); 252 rewriter.replaceOpWithNewOp<UnPackOp>( 253 sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(), 254 unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm()); 255 return success(); 256 } 257 }; 258 259 // Applies 'permutation' on 'inVec' and stores the result in resVec. 260 // 'inVec' may be empty, in that case it's one-to-one mapping with permutation. 261 // `rank` sets the boundary for permutation i.e., the permutation dim can't be 262 // greater than the rank specified. If it's so then return false. 263 // For e.g., permutation {1, 0, 3, 2} with rank 2 is allowed since the values in 264 // permutation[:rank] doesn't exceed rank, whereas, permutation {1, 3, 0, 2} is 265 // not allowed since `3` exceeds the value of the rank in the given range. 266 static bool checkAndPermute(ArrayRef<int64_t> permutation, 267 ArrayRef<int64_t> inVec, 268 SmallVectorImpl<int64_t> &resVec, int64_t rank) { 269 270 for (unsigned int i = 0; i < rank; ++i) { 271 int64_t remappedPosition = permutation[i]; 272 if (remappedPosition >= rank) 273 return false; 274 if (!inVec.empty()) 275 remappedPosition = inVec[remappedPosition]; 276 resVec.push_back(remappedPosition); 277 } 278 279 return true; 280 } 281 282 /// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose 283 /// semantics. 
284 struct FoldProducerPackWithConsumerLinalgTransposeOp 285 : public OpInterfaceRewritePattern<linalg::LinalgOp> { 286 using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern; 287 288 LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, 289 PatternRewriter &rewriter) const override { 290 auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>(); 291 292 if (!packOp) 293 return failure(); 294 295 FailureOr<SmallVector<int64_t>> maybePerm = 296 getTransposeOpPermutation(linalgOp); 297 if (failed(maybePerm)) 298 return failure(); 299 300 auto innerDimsPos = packOp.getInnerDimsPos(); 301 auto mixedInnerTiles = packOp.getMixedTiles(); 302 auto outerDimsPerm = packOp.getOuterDimsPerm(); 303 auto transposePerm = maybePerm.value(); 304 SmallVector<int64_t> newOuterDimsPermVec; 305 SmallVector<int64_t> newInnerDimsPosVec; 306 SmallVector<OpFoldResult> newMixedInnerTilesVec; 307 int64_t srcRank = packOp.getSourceRank(); 308 309 if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec, 310 srcRank)) 311 return rewriter.notifyMatchFailure( 312 linalgOp, 313 "Cannot fold in tensor.pack if a tile dimension was transposed " 314 "with a non-tile dimension in linalg.transpose."); 315 316 // Process transpose operation for tiled inner dimensions 317 for (unsigned int i = srcRank; i < transposePerm.size(); ++i) { 318 int64_t remappedPosition = transposePerm[i] - srcRank; 319 newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]); 320 newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]); 321 } 322 323 Value output = packOp.createDestinationTensor( 324 rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec, 325 newInnerDimsPosVec, newOuterDimsPermVec); 326 327 rewriter.replaceOpWithNewOp<PackOp>( 328 linalgOp, packOp.getSource(), output, newInnerDimsPosVec, 329 newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec); 330 331 return success(); 332 } 333 }; 334 335 /// Fold 'transpose' -> 'pack' 
into 'pack' since 'pack' already has transpose 336 /// semantics. 337 struct FoldConsumerPackWithProducerLinalgTransposeOp 338 : public OpRewritePattern<PackOp> { 339 using OpRewritePattern<PackOp>::OpRewritePattern; 340 341 LogicalResult matchAndRewrite(PackOp packOp, 342 PatternRewriter &rewriter) const override { 343 auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>(); 344 if (!linalgOp) 345 return failure(); 346 347 FailureOr<SmallVector<int64_t>> maybePerm = 348 getTransposeOpPermutation(linalgOp); 349 if (failed(maybePerm)) 350 return failure(); 351 352 auto transposePermutation = maybePerm.value(); 353 auto outerDimsPerm = packOp.getOuterDimsPerm(); 354 auto innerDimsPos = packOp.getInnerDimsPos(); 355 SmallVector<int64_t> newInnerDimsPosVec; 356 SmallVector<int64_t> newOuterDimsPermVec = 357 llvm::to_vector(transposePermutation); 358 359 if (!outerDimsPerm.empty()) 360 applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm); 361 362 // Can't use applyPermutationToVector for newInnerDimsPosVec since input and 363 // permutation rank won't necessarily be equal in all cases. 364 for (auto dim : innerDimsPos) 365 newInnerDimsPosVec.push_back(transposePermutation[dim]); 366 367 Value output = packOp.createDestinationTensor( 368 rewriter, packOp.getLoc(), linalgOp->getOperand(0), 369 packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec); 370 371 rewriter.replaceOpWithNewOp<PackOp>( 372 packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec, 373 packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec); 374 375 return success(); 376 } 377 }; 378 379 /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has 380 /// transpose semantics. 
381 struct FoldProducerUnPackWithConsumerLinalgTransposeOp 382 : public OpInterfaceRewritePattern<linalg::LinalgOp> { 383 using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern; 384 385 LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, 386 PatternRewriter &rewriter) const override { 387 auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>(); 388 389 if (!unPackOp) 390 return failure(); 391 392 FailureOr<SmallVector<int64_t>> maybePerm = 393 getTransposeOpPermutation(linalgOp); 394 if (failed(maybePerm)) 395 return failure(); 396 397 auto outerDimsPerm = unPackOp.getOuterDimsPerm(); 398 auto innerDimsPos = unPackOp.getInnerDimsPos(); 399 SmallVector<int64_t> newInnerDimsPosVec; 400 SmallVector<int64_t> newOuterDimsPermVec = 401 invertPermutationVector(maybePerm.value()); 402 403 // Can't use applyPermutationToVector for newInnerDimsPosVec since input and 404 // permutation rank won't necessarily be equal in all cases. 405 for (auto dim : innerDimsPos) 406 newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]); 407 408 if (!outerDimsPerm.empty()) 409 applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm); 410 411 // Reuse the destination of the transpose op. 412 rewriter.replaceOpWithNewOp<UnPackOp>( 413 linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0], 414 newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec); 415 416 return success(); 417 } 418 }; 419 420 /// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has 421 /// transpose semantics. 
422 struct FoldConsumerUnPackWithProducerLinalgTransposeOp 423 : public OpRewritePattern<UnPackOp> { 424 using OpRewritePattern<UnPackOp>::OpRewritePattern; 425 426 LogicalResult matchAndRewrite(UnPackOp unPackOp, 427 PatternRewriter &rewriter) const override { 428 auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>(); 429 if (!linalgOp) 430 return failure(); 431 432 FailureOr<SmallVector<int64_t>> maybePerm = 433 getTransposeOpPermutation(linalgOp); 434 if (failed(maybePerm)) 435 return failure(); 436 437 SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims; 438 if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) { 439 return failure(); 440 } 441 442 SmallVector<int64_t> inverseTransposePerm = 443 invertPermutationVector(maybePerm.value()); 444 auto outerDimsPerm = unPackOp.getOuterDimsPerm(); 445 auto innerDimsPos = unPackOp.getInnerDimsPos(); 446 int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size(); 447 auto mixedInnerTilesVec = unPackOp.getMixedTiles(); 448 SmallVector<int64_t> newOuterDimsPermVec; 449 SmallVector<int64_t> newInnerDimsPosVec; 450 SmallVector<OpFoldResult> newMixedInnerTilesVec; 451 if (!checkAndPermute(inverseTransposePerm, outerDimsPerm, 452 newOuterDimsPermVec, destRank)) 453 return rewriter.notifyMatchFailure( 454 unPackOp, 455 "Cannot fold in tensor.unpack if a tile dimension was transposed " 456 "with a non-tile dimension in linalg.transpose."); 457 458 // Process transpose operation for tiled inner dimensions 459 for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) { 460 int64_t remappedPosition = inverseTransposePerm[i] - destRank; 461 newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]); 462 newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]); 463 } 464 465 auto elemType = 466 cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType(); 467 Value output = rewriter.create<tensor::EmptyOp>( 468 unPackOp->getLoc(), unpackOpResultDims[0], 
        elemType);

    // Rebuild the unpack directly on the transpose's input; its permutation
    // attributes now absorb the transpose.
    rewriter.replaceOpWithNewOp<UnPackOp>(
        unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, newOuterDimsPermVec);

    return success();
  }
};
} // namespace

/// Registers patterns that fold adjacent pad/extract_slice/transpose ops into
/// tensor.pack / tensor.unpack, which already carry those semantics.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
                  FoldProducerPackWithConsumerLinalgTransposeOp,
                  FoldConsumerPackWithProducerLinalgTransposeOp,
                  FoldConsumerUnPackWithProducerLinalgTransposeOp,
                  FoldProducerUnPackWithConsumerLinalgTransposeOp>(
      patterns.getContext());
}

/// Registers patterns that rewrite trivial pack/unpack ops as plain
/// expand_shape/collapse_shape reshapes.
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
  patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
      patterns.getContext());
}

} // namespace tensor
} // namespace mlir