14b14205bSHan-Chung Wang //===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===// 24b14205bSHan-Chung Wang // 34b14205bSHan-Chung Wang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 44b14205bSHan-Chung Wang // See https://llvm.org/LICENSE.txt for license information. 54b14205bSHan-Chung Wang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 64b14205bSHan-Chung Wang // 74b14205bSHan-Chung Wang //===----------------------------------------------------------------------===// 84b14205bSHan-Chung Wang 94b14205bSHan-Chung Wang #include "mlir/Dialect/Linalg/IR/Linalg.h" 104b14205bSHan-Chung Wang #include "mlir/Dialect/Tensor/IR/Tensor.h" 114b14205bSHan-Chung Wang #include "mlir/Dialect/Tensor/Transforms/Transforms.h" 12113bce0cSPrathamesh Tagore #include "mlir/Dialect/Utils/IndexingUtils.h" 134b14205bSHan-Chung Wang #include "mlir/IR/PatternMatch.h" 144b14205bSHan-Chung Wang 154b14205bSHan-Chung Wang namespace mlir { 164b14205bSHan-Chung Wang namespace tensor { 174b14205bSHan-Chung Wang namespace { 184b14205bSHan-Chung Wang 19f59eef65SHan-Chung Wang /// Returns the number of shape sizes that is either dynamic or greater than 1. 20f59eef65SHan-Chung Wang static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) { 21f59eef65SHan-Chung Wang return llvm::count_if( 22f59eef65SHan-Chung Wang shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; }); 23f59eef65SHan-Chung Wang } 24f59eef65SHan-Chung Wang 25ad3cda7aSHan-Chung Wang /// Returns success() if there is only 1 dimension size in non-packed domain 26ad3cda7aSHan-Chung Wang /// being greater than 1 and packing only happens on the dimension. 27ad3cda7aSHan-Chung Wang /// Note: this method should only be used by pack/unpack to reshape conversion. 28ad3cda7aSHan-Chung Wang /// It assumes that non-unit inner tile size must be used by the non-unit 29ad3cda7aSHan-Chung Wang /// dimension. 30ad3cda7aSHan-Chung Wang static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op, 31ad3cda7aSHan-Chung Wang ArrayRef<int64_t> srcShape, 32ad3cda7aSHan-Chung Wang ArrayRef<int64_t> innerPackTileSize) { 33ad3cda7aSHan-Chung Wang if (getNumGtOneDims(srcShape) > 1) { 34ad3cda7aSHan-Chung Wang return rewriter.notifyMatchFailure( 35ad3cda7aSHan-Chung Wang op, "expects non-packed domain to have at most one non-unit dims"); 36ad3cda7aSHan-Chung Wang } 37ad3cda7aSHan-Chung Wang // Non-unit inner tile size must be used by the non-unit dimension. If not, it 38ad3cda7aSHan-Chung Wang // will faill on getting reassociation maps. 39ad3cda7aSHan-Chung Wang if (getNumGtOneDims(innerPackTileSize) > 1) { 40ad3cda7aSHan-Chung Wang return rewriter.notifyMatchFailure( 41ad3cda7aSHan-Chung Wang op, "expects at most one non-unit inner tiles"); 42ad3cda7aSHan-Chung Wang } 43ad3cda7aSHan-Chung Wang return success(); 44ad3cda7aSHan-Chung Wang } 45ad3cda7aSHan-Chung Wang 467ef83f55SMax191 // If the `linalgOp` represents a transpose, return the permutation vector for 477ef83f55SMax191 // the transpose. Otherwise, return failure. 487ef83f55SMax191 static FailureOr<SmallVector<int64_t>> 497ef83f55SMax191 getTransposeOpPermutation(linalg::LinalgOp linalgOp) { 507ef83f55SMax191 if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation())) 517ef83f55SMax191 return SmallVector<int64_t>(transposeOp.getPermutation()); 527ef83f55SMax191 if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops()) 537ef83f55SMax191 return failure(); 547ef83f55SMax191 557ef83f55SMax191 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1) 567ef83f55SMax191 return failure(); 577ef83f55SMax191 auto mapRange = linalgOp.getIndexingMapsArray(); 587ef83f55SMax191 if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() || 597ef83f55SMax191 mapRange.front() == mapRange.back()) { 607ef83f55SMax191 return failure(); 617ef83f55SMax191 } 627ef83f55SMax191 if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations())) 637ef83f55SMax191 return failure(); 647ef83f55SMax191 AffineMap outMap = mapRange.back(); 657ef83f55SMax191 AffineMap inMap = mapRange.front(); 667ef83f55SMax191 // To get the permutation, look at each output index and find which 677ef83f55SMax191 // dimension in the input we're reading from for that index. 687ef83f55SMax191 return llvm::map_to_vector(outMap.getResults(), 697ef83f55SMax191 [&](AffineExpr expr) -> int64_t { 707ef83f55SMax191 return *inMap.getResultPosition(expr); 717ef83f55SMax191 }); 727ef83f55SMax191 } 737ef83f55SMax191 744b14205bSHan-Chung Wang /// Packing one-dimensional tensor can be expressed as an expand shape op. 754b14205bSHan-Chung Wang struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> { 764b14205bSHan-Chung Wang using OpRewritePattern<PackOp>::OpRewritePattern; 774b14205bSHan-Chung Wang 7897069a86SGaurav Shukla FailureOr<Value> 7997069a86SGaurav Shukla insertExpand(RewriterBase &rewriter, Location loc, Value operand, 8097069a86SGaurav Shukla Type newOperandType, 8197069a86SGaurav Shukla ArrayRef<ReassociationIndices> reassociation) const { 824b14205bSHan-Chung Wang if (operand.getType() == newOperandType) 834b14205bSHan-Chung Wang return operand; 8497069a86SGaurav Shukla return rewriter 8597069a86SGaurav Shukla .create<tensor::ExpandShapeOp>(loc, newOperandType, operand, 8697069a86SGaurav Shukla reassociation) 8797069a86SGaurav Shukla .getResult(); 884b14205bSHan-Chung Wang } 894b14205bSHan-Chung Wang 90f59eef65SHan-Chung Wang /// Returns success() if it is only packing on the innermost dimension. 91f59eef65SHan-Chung Wang LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter, 92f59eef65SHan-Chung Wang PackOp packOp) const { 932472c45bSHan-Chung Wang auto outerDimsPerm = packOp.getOuterDimsPerm(); 942472c45bSHan-Chung Wang if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) { 952472c45bSHan-Chung Wang return rewriter.notifyMatchFailure( 962472c45bSHan-Chung Wang packOp, 972472c45bSHan-Chung Wang "expects outer_dims_perm is empty or an identity permutation"); 982472c45bSHan-Chung Wang } 9978348b69SHan-Chung Wang 100f59eef65SHan-Chung Wang int64_t srcRank = packOp.getSourceRank(); 10178348b69SHan-Chung Wang ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos(); 102f59eef65SHan-Chung Wang if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) { 10378348b69SHan-Chung Wang return rewriter.notifyMatchFailure( 10478348b69SHan-Chung Wang packOp, "expects packing at the innermost dimension"); 10578348b69SHan-Chung Wang } 106f59eef65SHan-Chung Wang return success(); 107f59eef65SHan-Chung Wang } 10878348b69SHan-Chung Wang 109f59eef65SHan-Chung Wang LogicalResult matchAndRewrite(PackOp packOp, 110f59eef65SHan-Chung Wang PatternRewriter &rewriter) const override { 111f59eef65SHan-Chung Wang if (packOp.getPaddingValue()) 112f59eef65SHan-Chung Wang return rewriter.notifyMatchFailure(packOp, "expects no padding value"); 113f59eef65SHan-Chung Wang 114ad3cda7aSHan-Chung Wang RankedTensorType sourceType = packOp.getSourceType(); 115f59eef65SHan-Chung Wang if (failed(isPackOnInnerMostDim(rewriter, packOp)) && 116ad3cda7aSHan-Chung Wang failed(isPackOn1D(rewriter, packOp, sourceType.getShape(), 117a79a0c52SAdam Siemieniuk packOp.getStaticTiles())) && 118a79a0c52SAdam Siemieniuk !packOp.isLikePad()) { 119f59eef65SHan-Chung Wang return failure(); 120f59eef65SHan-Chung Wang } 121f59eef65SHan-Chung Wang 122f59eef65SHan-Chung Wang RankedTensorType destType = packOp.getDestType(); 1234b14205bSHan-Chung Wang auto reassociation = 1244b14205bSHan-Chung Wang getReassociationIndicesForReshape(sourceType, destType); 1254b14205bSHan-Chung Wang if (!reassociation) 1264b14205bSHan-Chung Wang return failure(); 12797069a86SGaurav Shukla FailureOr<Value> expanded = 12897069a86SGaurav Shukla insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType, 12997069a86SGaurav Shukla *reassociation); 13097069a86SGaurav Shukla if (failed(expanded)) { 13197069a86SGaurav Shukla return rewriter.notifyMatchFailure( 13297069a86SGaurav Shukla packOp, "unable to expand source of tensor.pack"); 13397069a86SGaurav Shukla } 13497069a86SGaurav Shukla rewriter.replaceOp(packOp, *expanded); 1354b14205bSHan-Chung Wang return success(); 1364b14205bSHan-Chung Wang } 1374b14205bSHan-Chung Wang }; 1384b14205bSHan-Chung Wang 13976cb0bb7SHan-Chung Wang struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> { 14076cb0bb7SHan-Chung Wang using OpRewritePattern<UnPackOp>::OpRewritePattern; 14176cb0bb7SHan-Chung Wang 14276cb0bb7SHan-Chung Wang Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand, 14376cb0bb7SHan-Chung Wang Type newOperandType, ArrayAttr reassociation) const { 14476cb0bb7SHan-Chung Wang if (operand.getType() == newOperandType) 14576cb0bb7SHan-Chung Wang return operand; 14676cb0bb7SHan-Chung Wang return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType, 14776cb0bb7SHan-Chung Wang operand, reassociation); 14876cb0bb7SHan-Chung Wang } 14976cb0bb7SHan-Chung Wang 150ad3cda7aSHan-Chung Wang /// Returns success() if it is unpacking on the innermost dimension. 151ad3cda7aSHan-Chung Wang LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter, 152ad3cda7aSHan-Chung Wang UnPackOp unpackOp) const { 1532472c45bSHan-Chung Wang auto outerDimsPerm = unpackOp.getOuterDimsPerm(); 1542472c45bSHan-Chung Wang if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) { 1552472c45bSHan-Chung Wang return rewriter.notifyMatchFailure( 1562472c45bSHan-Chung Wang unpackOp, 1572472c45bSHan-Chung Wang "expects outer_dims_perm is empty or an identity permutation"); 15876cb0bb7SHan-Chung Wang } 15976cb0bb7SHan-Chung Wang 16076cb0bb7SHan-Chung Wang RankedTensorType sourceType = unpackOp.getSourceType(); 16176cb0bb7SHan-Chung Wang RankedTensorType destType = unpackOp.getDestType(); 16276cb0bb7SHan-Chung Wang if (!sourceType.hasStaticShape() || !destType.hasStaticShape()) 16376cb0bb7SHan-Chung Wang return rewriter.notifyMatchFailure(unpackOp, "expects static shapes"); 16476cb0bb7SHan-Chung Wang 16576cb0bb7SHan-Chung Wang ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos(); 16676cb0bb7SHan-Chung Wang if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) { 16776cb0bb7SHan-Chung Wang return rewriter.notifyMatchFailure( 168ad3cda7aSHan-Chung Wang unpackOp, "expects unpacking on the innermost dimension"); 16976cb0bb7SHan-Chung Wang } 17076cb0bb7SHan-Chung Wang 171ad3cda7aSHan-Chung Wang return success(); 172ad3cda7aSHan-Chung Wang } 173ad3cda7aSHan-Chung Wang 174ad3cda7aSHan-Chung Wang LogicalResult matchAndRewrite(UnPackOp unpackOp, 175ad3cda7aSHan-Chung Wang PatternRewriter &rewriter) const override { 176ad3cda7aSHan-Chung Wang RankedTensorType destType = unpackOp.getDestType(); 177ad3cda7aSHan-Chung Wang if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) && 178ad3cda7aSHan-Chung Wang failed(isPackOn1D(rewriter, unpackOp, destType.getShape(), 179a79a0c52SAdam Siemieniuk unpackOp.getStaticTiles())) && 180a79a0c52SAdam Siemieniuk !unpackOp.isLikeUnPad()) { 181ad3cda7aSHan-Chung Wang return failure(); 182ad3cda7aSHan-Chung Wang } 183ad3cda7aSHan-Chung Wang 184ad3cda7aSHan-Chung Wang RankedTensorType sourceType = unpackOp.getSourceType(); 18576cb0bb7SHan-Chung Wang auto reassociation = 18676cb0bb7SHan-Chung Wang getReassociationIndicesForReshape(sourceType, destType); 18776cb0bb7SHan-Chung Wang if (!reassociation) 18876cb0bb7SHan-Chung Wang return failure(); 18976cb0bb7SHan-Chung Wang Value collapsed = insertCollapse( 19076cb0bb7SHan-Chung Wang rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType, 19176cb0bb7SHan-Chung Wang getReassociationIndicesAttribute(rewriter, *reassociation)); 19276cb0bb7SHan-Chung Wang rewriter.replaceOp(unpackOp, collapsed); 19376cb0bb7SHan-Chung Wang return success(); 19476cb0bb7SHan-Chung Wang } 19576cb0bb7SHan-Chung Wang }; 19676cb0bb7SHan-Chung Wang 1974b14205bSHan-Chung Wang /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and 1984b14205bSHan-Chung Wang /// the pad op has zero low paddings, or if `pack` has no padding values. 1994b14205bSHan-Chung Wang struct FoldPadWithPackOp : public OpRewritePattern<PackOp> { 2004b14205bSHan-Chung Wang using OpRewritePattern<PackOp>::OpRewritePattern; 2014b14205bSHan-Chung Wang 2024b14205bSHan-Chung Wang LogicalResult matchAndRewrite(PackOp packOp, 2034b14205bSHan-Chung Wang PatternRewriter &rewriter) const override { 2044b14205bSHan-Chung Wang auto padOp = packOp.getSource().getDefiningOp<PadOp>(); 2054b14205bSHan-Chung Wang 2064b14205bSHan-Chung Wang if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad()) 2074b14205bSHan-Chung Wang return failure(); 2084b14205bSHan-Chung Wang 2094b14205bSHan-Chung Wang Value constantPaddingValue = padOp.getConstantPaddingValue(); 2104b14205bSHan-Chung Wang if (!constantPaddingValue) 2114b14205bSHan-Chung Wang return failure(); 2124b14205bSHan-Chung Wang 2134b14205bSHan-Chung Wang if (auto paddingValue = packOp.getPaddingValue()) 2144b14205bSHan-Chung Wang if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue)) 2154b14205bSHan-Chung Wang return failure(); 2164b14205bSHan-Chung Wang 2174b14205bSHan-Chung Wang rewriter.replaceOpWithNewOp<PackOp>( 2184b14205bSHan-Chung Wang packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(), 2194b14205bSHan-Chung Wang packOp.getMixedTiles(), constantPaddingValue, 2204b14205bSHan-Chung Wang packOp.getOuterDimsPerm()); 2214b14205bSHan-Chung Wang return success(); 2224b14205bSHan-Chung Wang } 2234b14205bSHan-Chung Wang }; 2244b14205bSHan-Chung Wang 2254b14205bSHan-Chung Wang /// Fold a `unpack` -> `extract_slice` into the `unpack` since it already 2264b14205bSHan-Chung Wang /// has extract_slice semantics. 2274b14205bSHan-Chung Wang struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> { 2284b14205bSHan-Chung Wang using OpRewritePattern<ExtractSliceOp>::OpRewritePattern; 2294b14205bSHan-Chung Wang 2304b14205bSHan-Chung Wang LogicalResult matchAndRewrite(ExtractSliceOp sliceOp, 2314b14205bSHan-Chung Wang PatternRewriter &rewriter) const override { 2324b14205bSHan-Chung Wang auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>(); 2334b14205bSHan-Chung Wang if (!unpackOp) 2344b14205bSHan-Chung Wang return failure(); 2354b14205bSHan-Chung Wang 2364b14205bSHan-Chung Wang if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) { 2374b14205bSHan-Chung Wang return rewriter.notifyMatchFailure( 2384b14205bSHan-Chung Wang sliceOp, "rank-reduced folding is not supported"); 2394b14205bSHan-Chung Wang } 2404b14205bSHan-Chung Wang 2414b14205bSHan-Chung Wang // Check all offsets are zeros, and all strides are ones. 2424b14205bSHan-Chung Wang if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) || 2434b14205bSHan-Chung Wang !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) { 2444b14205bSHan-Chung Wang return rewriter.notifyMatchFailure( 2454b14205bSHan-Chung Wang sliceOp, "expects offsets to be 0s and strides to be 1s"); 2464b14205bSHan-Chung Wang } 2474b14205bSHan-Chung Wang 2484b14205bSHan-Chung Wang // Create a new empty output tensor. 2494b14205bSHan-Chung Wang Type elementType = unpackOp.getDestType().getElementType(); 2504b14205bSHan-Chung Wang Value output = rewriter.create<EmptyOp>( 2514b14205bSHan-Chung Wang sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType); 2524b14205bSHan-Chung Wang rewriter.replaceOpWithNewOp<UnPackOp>( 2534b14205bSHan-Chung Wang sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(), 2544b14205bSHan-Chung Wang unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm()); 2554b14205bSHan-Chung Wang return success(); 2564b14205bSHan-Chung Wang } 2574b14205bSHan-Chung Wang }; 2584b14205bSHan-Chung Wang 259aa7ae1baSPrashant Kumar // Applies 'permutation' on 'inVec' and stores the result in resVec. 260aa7ae1baSPrashant Kumar // 'inVec' may be empty, in that case it's one-to-one mapping with permutation. 261aa7ae1baSPrashant Kumar // `rank` sets the boundary for permutation i.e., the permutation dim can't be 262aa7ae1baSPrashant Kumar // greater than the rank specified. If it's so then return false. 263aa7ae1baSPrashant Kumar // For e.g., permutation {1, 0, 3, 2} with rank 2 is allowed since the values in 264aa7ae1baSPrashant Kumar // permutation[:rank] doesn't exceed rank, whereas, permutation {1, 3, 0, 2} is 265aa7ae1baSPrashant Kumar // not allowed since `3` exceeds the value of the rank in the given range. 266aa7ae1baSPrashant Kumar static bool checkAndPermute(ArrayRef<int64_t> permutation, 267aa7ae1baSPrashant Kumar ArrayRef<int64_t> inVec, 268aa7ae1baSPrashant Kumar SmallVectorImpl<int64_t> &resVec, int64_t rank) { 269aa7ae1baSPrashant Kumar 270aa7ae1baSPrashant Kumar for (unsigned int i = 0; i < rank; ++i) { 271aa7ae1baSPrashant Kumar int64_t remappedPosition = permutation[i]; 2727ef83f55SMax191 if (remappedPosition >= rank) 273aa7ae1baSPrashant Kumar return false; 2747ef83f55SMax191 if (!inVec.empty()) 275aa7ae1baSPrashant Kumar remappedPosition = inVec[remappedPosition]; 276aa7ae1baSPrashant Kumar resVec.push_back(remappedPosition); 277aa7ae1baSPrashant Kumar } 278aa7ae1baSPrashant Kumar 279aa7ae1baSPrashant Kumar return true; 280aa7ae1baSPrashant Kumar } 281aa7ae1baSPrashant Kumar 2824b14205bSHan-Chung Wang /// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose 2834b14205bSHan-Chung Wang /// semantics. 2844b14205bSHan-Chung Wang struct FoldProducerPackWithConsumerLinalgTransposeOp 2857ef83f55SMax191 : public OpInterfaceRewritePattern<linalg::LinalgOp> { 2867ef83f55SMax191 using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern; 2874b14205bSHan-Chung Wang 2887ef83f55SMax191 LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, 2894b14205bSHan-Chung Wang PatternRewriter &rewriter) const override { 2907ef83f55SMax191 auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>(); 2914b14205bSHan-Chung Wang 2924b14205bSHan-Chung Wang if (!packOp) 2934b14205bSHan-Chung Wang return failure(); 2944b14205bSHan-Chung Wang 2957ef83f55SMax191 FailureOr<SmallVector<int64_t>> maybePerm = 2967ef83f55SMax191 getTransposeOpPermutation(linalgOp); 2977ef83f55SMax191 if (failed(maybePerm)) 2987ef83f55SMax191 return failure(); 2997ef83f55SMax191 3004b14205bSHan-Chung Wang auto innerDimsPos = packOp.getInnerDimsPos(); 3014b14205bSHan-Chung Wang auto mixedInnerTiles = packOp.getMixedTiles(); 3024b14205bSHan-Chung Wang auto outerDimsPerm = packOp.getOuterDimsPerm(); 3037ef83f55SMax191 auto transposePerm = maybePerm.value(); 3044b14205bSHan-Chung Wang SmallVector<int64_t> newOuterDimsPermVec; 3054b14205bSHan-Chung Wang SmallVector<int64_t> newInnerDimsPosVec; 3064b14205bSHan-Chung Wang SmallVector<OpFoldResult> newMixedInnerTilesVec; 3074b14205bSHan-Chung Wang int64_t srcRank = packOp.getSourceRank(); 3084b14205bSHan-Chung Wang 309aa7ae1baSPrashant Kumar if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec, 310aa7ae1baSPrashant Kumar srcRank)) 3114b14205bSHan-Chung Wang return rewriter.notifyMatchFailure( 3127ef83f55SMax191 linalgOp, 3134b14205bSHan-Chung Wang "Cannot fold in tensor.pack if a tile dimension was transposed " 3144b14205bSHan-Chung Wang "with a non-tile dimension in linalg.transpose."); 3154b14205bSHan-Chung Wang 3164b14205bSHan-Chung Wang // Process transpose operation for tiled inner dimensions 3174b14205bSHan-Chung Wang for (unsigned int i = srcRank; i < transposePerm.size(); ++i) { 3184b14205bSHan-Chung Wang int64_t remappedPosition = transposePerm[i] - srcRank; 3194b14205bSHan-Chung Wang newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]); 3204b14205bSHan-Chung Wang newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]); 3214b14205bSHan-Chung Wang } 3224b14205bSHan-Chung Wang 3234b14205bSHan-Chung Wang Value output = packOp.createDestinationTensor( 3247ef83f55SMax191 rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec, 3257ef83f55SMax191 newInnerDimsPosVec, newOuterDimsPermVec); 3264b14205bSHan-Chung Wang 3274b14205bSHan-Chung Wang rewriter.replaceOpWithNewOp<PackOp>( 3287ef83f55SMax191 linalgOp, packOp.getSource(), output, newInnerDimsPosVec, 3294b14205bSHan-Chung Wang newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec); 3304b14205bSHan-Chung Wang 3314b14205bSHan-Chung Wang return success(); 3324b14205bSHan-Chung Wang } 3334b14205bSHan-Chung Wang }; 334113bce0cSPrathamesh Tagore 335113bce0cSPrathamesh Tagore /// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose 336113bce0cSPrathamesh Tagore /// semantics. 337113bce0cSPrathamesh Tagore struct FoldConsumerPackWithProducerLinalgTransposeOp 338113bce0cSPrathamesh Tagore : public OpRewritePattern<PackOp> { 339113bce0cSPrathamesh Tagore using OpRewritePattern<PackOp>::OpRewritePattern; 340113bce0cSPrathamesh Tagore 341113bce0cSPrathamesh Tagore LogicalResult matchAndRewrite(PackOp packOp, 342113bce0cSPrathamesh Tagore PatternRewriter &rewriter) const override { 3437ef83f55SMax191 auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>(); 3447ef83f55SMax191 if (!linalgOp) 345113bce0cSPrathamesh Tagore return failure(); 346113bce0cSPrathamesh Tagore 3477ef83f55SMax191 FailureOr<SmallVector<int64_t>> maybePerm = 3487ef83f55SMax191 getTransposeOpPermutation(linalgOp); 3497ef83f55SMax191 if (failed(maybePerm)) 3507ef83f55SMax191 return failure(); 3517ef83f55SMax191 3527ef83f55SMax191 auto transposePermutation = maybePerm.value(); 353113bce0cSPrathamesh Tagore auto outerDimsPerm = packOp.getOuterDimsPerm(); 354113bce0cSPrathamesh Tagore auto innerDimsPos = packOp.getInnerDimsPos(); 355113bce0cSPrathamesh Tagore SmallVector<int64_t> newInnerDimsPosVec; 356113bce0cSPrathamesh Tagore SmallVector<int64_t> newOuterDimsPermVec = 357113bce0cSPrathamesh Tagore llvm::to_vector(transposePermutation); 358113bce0cSPrathamesh Tagore 359113bce0cSPrathamesh Tagore if (!outerDimsPerm.empty()) 360113bce0cSPrathamesh Tagore applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm); 361113bce0cSPrathamesh Tagore 362113bce0cSPrathamesh Tagore // Can't use applyPermutationToVector for newInnerDimsPosVec since input and 363113bce0cSPrathamesh Tagore // permutation rank won't necessarily be equal in all cases. 364113bce0cSPrathamesh Tagore for (auto dim : innerDimsPos) 365113bce0cSPrathamesh Tagore newInnerDimsPosVec.push_back(transposePermutation[dim]); 366113bce0cSPrathamesh Tagore 367113bce0cSPrathamesh Tagore Value output = packOp.createDestinationTensor( 3687ef83f55SMax191 rewriter, packOp.getLoc(), linalgOp->getOperand(0), 369113bce0cSPrathamesh Tagore packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec); 370113bce0cSPrathamesh Tagore 371113bce0cSPrathamesh Tagore rewriter.replaceOpWithNewOp<PackOp>( 3727ef83f55SMax191 packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec, 373113bce0cSPrathamesh Tagore packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec); 374113bce0cSPrathamesh Tagore 375113bce0cSPrathamesh Tagore return success(); 376113bce0cSPrathamesh Tagore } 377113bce0cSPrathamesh Tagore }; 378aa7ae1baSPrashant Kumar 379aa7ae1baSPrashant Kumar /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has 380aa7ae1baSPrashant Kumar /// transpose semantics. 381aa7ae1baSPrashant Kumar struct FoldProducerUnPackWithConsumerLinalgTransposeOp 3827ef83f55SMax191 : public OpInterfaceRewritePattern<linalg::LinalgOp> { 3837ef83f55SMax191 using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern; 384aa7ae1baSPrashant Kumar 3857ef83f55SMax191 LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, 386aa7ae1baSPrashant Kumar PatternRewriter &rewriter) const override { 3877ef83f55SMax191 auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>(); 388aa7ae1baSPrashant Kumar 389aa7ae1baSPrashant Kumar if (!unPackOp) 390aa7ae1baSPrashant Kumar return failure(); 391aa7ae1baSPrashant Kumar 3927ef83f55SMax191 FailureOr<SmallVector<int64_t>> maybePerm = 3937ef83f55SMax191 getTransposeOpPermutation(linalgOp); 3947ef83f55SMax191 if (failed(maybePerm)) 3957ef83f55SMax191 return failure(); 3967ef83f55SMax191 397aa7ae1baSPrashant Kumar auto outerDimsPerm = unPackOp.getOuterDimsPerm(); 398aa7ae1baSPrashant Kumar auto innerDimsPos = unPackOp.getInnerDimsPos(); 399aa7ae1baSPrashant Kumar SmallVector<int64_t> newInnerDimsPosVec; 400aa7ae1baSPrashant Kumar SmallVector<int64_t> newOuterDimsPermVec = 4017ef83f55SMax191 invertPermutationVector(maybePerm.value()); 402aa7ae1baSPrashant Kumar 403aa7ae1baSPrashant Kumar // Can't use applyPermutationToVector for newInnerDimsPosVec since input and 404aa7ae1baSPrashant Kumar // permutation rank won't necessarily be equal in all cases. 405aa7ae1baSPrashant Kumar for (auto dim : innerDimsPos) 4067ef83f55SMax191 newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]); 4077ef83f55SMax191 4087ef83f55SMax191 if (!outerDimsPerm.empty()) 4097ef83f55SMax191 applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm); 410aa7ae1baSPrashant Kumar 41175f72954SQuinn Dawkins // Reuse the destination of the transpose op. 412aa7ae1baSPrashant Kumar rewriter.replaceOpWithNewOp<UnPackOp>( 4137ef83f55SMax191 linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0], 41475f72954SQuinn Dawkins newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec); 415aa7ae1baSPrashant Kumar 416aa7ae1baSPrashant Kumar return success(); 417aa7ae1baSPrashant Kumar } 418aa7ae1baSPrashant Kumar }; 419aa7ae1baSPrashant Kumar 420aa7ae1baSPrashant Kumar /// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has 421aa7ae1baSPrashant Kumar /// transpose semantics. 422aa7ae1baSPrashant Kumar struct FoldConsumerUnPackWithProducerLinalgTransposeOp 423aa7ae1baSPrashant Kumar : public OpRewritePattern<UnPackOp> { 424aa7ae1baSPrashant Kumar using OpRewritePattern<UnPackOp>::OpRewritePattern; 425aa7ae1baSPrashant Kumar 426aa7ae1baSPrashant Kumar LogicalResult matchAndRewrite(UnPackOp unPackOp, 427aa7ae1baSPrashant Kumar PatternRewriter &rewriter) const override { 4287ef83f55SMax191 auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>(); 4297ef83f55SMax191 if (!linalgOp) 430aa7ae1baSPrashant Kumar return failure(); 431aa7ae1baSPrashant Kumar 4327ef83f55SMax191 FailureOr<SmallVector<int64_t>> maybePerm = 4337ef83f55SMax191 getTransposeOpPermutation(linalgOp); 4347ef83f55SMax191 if (failed(maybePerm)) 4357ef83f55SMax191 return failure(); 4367ef83f55SMax191 437*c1667f90SBenoit Jacob SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims; 438*c1667f90SBenoit Jacob if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) { 439*c1667f90SBenoit Jacob return failure(); 440*c1667f90SBenoit Jacob } 441*c1667f90SBenoit Jacob 4427ef83f55SMax191 SmallVector<int64_t> inverseTransposePerm = 4437ef83f55SMax191 invertPermutationVector(maybePerm.value()); 444aa7ae1baSPrashant Kumar auto outerDimsPerm = unPackOp.getOuterDimsPerm(); 445aa7ae1baSPrashant Kumar auto innerDimsPos = unPackOp.getInnerDimsPos(); 446aa7ae1baSPrashant Kumar int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size(); 447aa7ae1baSPrashant Kumar auto mixedInnerTilesVec = unPackOp.getMixedTiles(); 448aa7ae1baSPrashant Kumar SmallVector<int64_t> newOuterDimsPermVec; 449aa7ae1baSPrashant Kumar SmallVector<int64_t> newInnerDimsPosVec; 450aa7ae1baSPrashant Kumar SmallVector<OpFoldResult> newMixedInnerTilesVec; 4517ef83f55SMax191 if (!checkAndPermute(inverseTransposePerm, outerDimsPerm, 452aa7ae1baSPrashant Kumar newOuterDimsPermVec, destRank)) 453aa7ae1baSPrashant Kumar return rewriter.notifyMatchFailure( 454aa7ae1baSPrashant Kumar unPackOp, 455aa7ae1baSPrashant Kumar "Cannot fold in tensor.unpack if a tile dimension was transposed " 456aa7ae1baSPrashant Kumar "with a non-tile dimension in linalg.transpose."); 457aa7ae1baSPrashant Kumar 458aa7ae1baSPrashant Kumar // Process transpose operation for tiled inner dimensions 4597ef83f55SMax191 for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) { 4607ef83f55SMax191 int64_t remappedPosition = inverseTransposePerm[i] - destRank; 461aa7ae1baSPrashant Kumar newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]); 462aa7ae1baSPrashant Kumar newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]); 463aa7ae1baSPrashant Kumar } 464aa7ae1baSPrashant Kumar 465*c1667f90SBenoit Jacob auto elemType = 466*c1667f90SBenoit Jacob cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType(); 467*c1667f90SBenoit Jacob Value output = rewriter.create<tensor::EmptyOp>( 468*c1667f90SBenoit Jacob unPackOp->getLoc(), unpackOpResultDims[0], elemType); 469aa7ae1baSPrashant Kumar 470aa7ae1baSPrashant Kumar rewriter.replaceOpWithNewOp<UnPackOp>( 4717ef83f55SMax191 unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec, 472aa7ae1baSPrashant Kumar newMixedInnerTilesVec, newOuterDimsPermVec); 473aa7ae1baSPrashant Kumar 474aa7ae1baSPrashant Kumar return success(); 475aa7ae1baSPrashant Kumar } 476aa7ae1baSPrashant Kumar }; 4774b14205bSHan-Chung Wang } // namespace 4784b14205bSHan-Chung Wang 4794b14205bSHan-Chung Wang void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) { 4804b14205bSHan-Chung Wang patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp, 481113bce0cSPrathamesh Tagore FoldProducerPackWithConsumerLinalgTransposeOp, 482aa7ae1baSPrashant Kumar FoldConsumerPackWithProducerLinalgTransposeOp, 483aa7ae1baSPrashant Kumar FoldConsumerUnPackWithProducerLinalgTransposeOp, 484aa7ae1baSPrashant Kumar FoldProducerUnPackWithConsumerLinalgTransposeOp>( 4854b14205bSHan-Chung Wang patterns.getContext()); 4864b14205bSHan-Chung Wang } 4874b14205bSHan-Chung Wang 4884b14205bSHan-Chung Wang void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) { 48976cb0bb7SHan-Chung Wang patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>( 49076cb0bb7SHan-Chung Wang patterns.getContext()); 4914b14205bSHan-Chung Wang } 4924b14205bSHan-Chung Wang 4934b14205bSHan-Chung Wang } // namespace tensor 4944b14205bSHan-Chung Wang } // namespace mlir 495