//===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace tensor {
namespace {

/// Returns the number of shape sizes that are either dynamic or greater
/// than 1.
static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
  return llvm::count_if(
      shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
}

/// Returns success() if the non-packed domain has at most one dimension size
/// greater than 1 and packing only happens on that dimension.
/// Note: this method should only be used by pack/unpack to reshape conversion.
/// It assumes that a non-unit inner tile size must be used by the non-unit
/// dimension.
static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
                                ArrayRef<int64_t> srcShape,
                                ArrayRef<int64_t> innerPackTileSize) {
  if (getNumGtOneDims(srcShape) > 1) {
    return rewriter.notifyMatchFailure(
        op, "expects non-packed domain to have at most one non-unit dim");
  }
  // A non-unit inner tile size must be used by the non-unit dimension. If
  // not, getting the reassociation maps will fail.
  if (getNumGtOneDims(innerPackTileSize) > 1) {
    return rewriter.notifyMatchFailure(
        op, "expects at most one non-unit inner tile");
  }
  return success();
}

/// If the `linalgOp` represents a transpose, return the permutation vector
/// for the transpose. Otherwise, return failure.
static FailureOr<SmallVector<int64_t>>
getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
  if (auto transposeOp =
          dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
    return SmallVector<int64_t>(transposeOp.getPermutation());
  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
    return failure();

  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return failure();
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
      mapRange.front() == mapRange.back()) {
    return failure();
  }
  if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
    return failure();
  AffineMap outMap = mapRange.back();
  AffineMap inMap = mapRange.front();
  // To get the permutation, look at each output index and find which
  // dimension in the input we're reading from for that index.
  return llvm::map_to_vector(outMap.getResults(),
                             [&](AffineExpr expr) -> int64_t {
                               return *inMap.getResultPosition(expr);
                             });
}

/// Packing a one-dimensional tensor can be expressed as an expand shape op.
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  FailureOr<Value>
  insertExpand(RewriterBase &rewriter, Location loc, Value operand,
               Type newOperandType,
               ArrayRef<ReassociationIndices> reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return rewriter
        .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
                                       reassociation)
        .getResult();
  }

  /// Returns success() if it is only packing on the innermost dimension.
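  /// For example (illustrative IR; the shapes below are assumptions for this
  /// sketch, not taken from a test):
  ///   %0 = tensor.pack %src inner_dims_pos = [1] inner_tiles = [8]
  ///       into %dst : tensor<4x256xf32> -> tensor<4x32x8xf32>
  /// qualifies, and is equivalent to a tensor.expand_shape with reassociation
  /// [[0], [1, 2]].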
  LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
                                     PackOp packOp) const {
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          packOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    int64_t srcRank = packOp.getSourceRank();
    ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
      return rewriter.notifyMatchFailure(
          packOp, "expects packing at the innermost dimension");
    }
    return success();
  }

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    if (packOp.getPaddingValue())
      return rewriter.notifyMatchFailure(packOp, "expects no padding value");

    RankedTensorType sourceType = packOp.getSourceType();
    if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
        failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
                          packOp.getStaticTiles())) &&
        !packOp.isLikePad()) {
      return failure();
    }

    RankedTensorType destType = packOp.getDestType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    FailureOr<Value> expanded =
        insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
                     *reassociation);
    if (failed(expanded)) {
      return rewriter.notifyMatchFailure(
          packOp, "unable to expand source of tensor.pack");
    }
    rewriter.replaceOp(packOp, *expanded);
    return success();
  }
};

struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
                       Type newOperandType, ArrayAttr reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
                                                    operand, reassociation);
  }

  /// Returns success() if it is unpacking on the innermost dimension.
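  /// For example (illustrative IR; the shapes below are assumptions for this
  /// sketch, not taken from a test):
  ///   %0 = tensor.unpack %src inner_dims_pos = [1] inner_tiles = [8]
  ///       into %dst : tensor<4x32x8xf32> -> tensor<4x256xf32>
  /// qualifies, and is equivalent to a tensor.collapse_shape with
  /// reassociation [[0], [1, 2]].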
  LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
                                       UnPackOp unpackOp) const {
    auto outerDimsPerm = unpackOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          unpackOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    RankedTensorType sourceType = unpackOp.getSourceType();
    RankedTensorType destType = unpackOp.getDestType();
    if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
      return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");

    ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
      return rewriter.notifyMatchFailure(
          unpackOp, "expects unpacking on the innermost dimension");
    }
    return success();
  }

  LogicalResult matchAndRewrite(UnPackOp unpackOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType destType = unpackOp.getDestType();
    if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
        failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
                          unpackOp.getStaticTiles())) &&
        !unpackOp.isLikeUnPad()) {
      return failure();
    }

    RankedTensorType sourceType = unpackOp.getSourceType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    Value collapsed = insertCollapse(
        rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
        getReassociationIndicesAttribute(rewriter, *reassociation));
    rewriter.replaceOp(unpackOp, collapsed);
    return success();
  }
};

/// Fold a `pad` -> `pack` into `pack` if the pad op has zero low paddings and
/// either `pack` has no padding value or the two padding values are equal.
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<PadOp>();

    if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
      return failure();

    Value constantPaddingValue = padOp.getConstantPaddingValue();
    if (!constantPaddingValue)
      return failure();

    if (auto paddingValue = packOp.getPaddingValue())
      if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
        return failure();

    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
        packOp.getMixedTiles(), constantPaddingValue,
        packOp.getOuterDimsPerm());
    return success();
  }
};

/// Fold an `unpack` -> `extract_slice` into the `unpack` since it already
/// has extract_slice semantics.
struct FoldUnpackWithExtractSliceOp
    : public OpRewritePattern<ExtractSliceOp> {
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
    if (!unpackOp)
      return failure();

    if (sliceOp.getResultType().getRank() !=
        unpackOp.getDestType().getRank()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "rank-reduced folding is not supported");
    }

    // Check all offsets are zeros, and all strides are ones.
    if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
        !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expects offsets to be 0s and strides to be 1s");
    }

    // Create a new empty output tensor.
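    // Its sizes come from the slice (not from the unpack destination), so the
    // new unpack below produces exactly the sub-tensor that the extract_slice
    // would have extracted.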
    Type elementType = unpackOp.getDestType().getElementType();
    Value output = rewriter.create<EmptyOp>(
        sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
    rewriter.replaceOpWithNewOp<UnPackOp>(
        sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
        unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
    return success();
  }
};

// Applies `permutation` to `inVec` and stores the result in `resVec`.
// `inVec` may be empty, in which case the mapping is one-to-one with
// `permutation`. `rank` bounds the permutation: if any entry in
// permutation[:rank] is greater than or equal to `rank`, return false.
// For example, permutation {1, 0, 3, 2} with rank 2 is allowed since no value
// in permutation[:rank] exceeds the rank, whereas permutation {1, 3, 0, 2} is
// not allowed since `3` exceeds the rank in that range.
static bool checkAndPermute(ArrayRef<int64_t> permutation,
                            ArrayRef<int64_t> inVec,
                            SmallVectorImpl<int64_t> &resVec, int64_t rank) {
  for (unsigned int i = 0; i < rank; ++i) {
    int64_t remappedPosition = permutation[i];
    if (remappedPosition >= rank)
      return false;
    if (!inVec.empty())
      remappedPosition = inVec[remappedPosition];
    resVec.push_back(remappedPosition);
  }
  return true;
}

/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
/// semantics.
struct FoldProducerPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern<
      linalg::LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();

    if (!packOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto innerDimsPos = packOp.getInnerDimsPos();
    auto mixedInnerTiles = packOp.getMixedTiles();
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    auto transposePerm = maybePerm.value();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    int64_t srcRank = packOp.getSourceRank();

    if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
                         srcRank))
      return rewriter.notifyMatchFailure(
          linalgOp,
          "Cannot fold in tensor.pack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // Process the transpose permutation for the tiled inner dimensions.
    for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
      int64_t remappedPosition = transposePerm[i] - srcRank;
      newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
    }

    Value output = packOp.createDestinationTensor(
        rewriter, linalgOp.getLoc(), packOp.getSource(),
        newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);

    rewriter.replaceOpWithNewOp<PackOp>(
        linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);

    return success();
  }
};

/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
/// semantics.
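/// For example (illustrative IR; shapes and the permutation are assumptions
/// for this sketch):
///   %t = linalg.transpose ins(%src : tensor<32x64xf32>)
///       outs(%init : tensor<64x32xf32>) permutation = [1, 0]
///   %p = tensor.pack %t inner_dims_pos = [1] inner_tiles = [8]
///       into %dst : tensor<64x32xf32> -> tensor<64x4x8xf32>
/// becomes a single tensor.pack of %src with the transpose permutation
/// composed into inner_dims_pos (here [0] instead of [1]) and into
/// outer_dims_perm (here [1, 0]).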
struct FoldConsumerPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
    if (!linalgOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto transposePermutation = maybePerm.value();
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    auto innerDimsPos = packOp.getInnerDimsPos();
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<int64_t> newOuterDimsPermVec =
        llvm::to_vector(transposePermutation);

    if (!outerDimsPerm.empty())
      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

    // Can't use applyPermutationToVector for newInnerDimsPosVec since the
    // input and permutation ranks won't necessarily be equal in all cases.
    for (auto dim : innerDimsPos)
      newInnerDimsPosVec.push_back(transposePermutation[dim]);

    Value output = packOp.createDestinationTensor(
        rewriter, packOp.getLoc(), linalgOp->getOperand(0),
        packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);

    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
        packOp.getMixedTiles(), packOp.getPaddingValue(),
        newOuterDimsPermVec);

    return success();
  }
};

/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
/// transpose semantics.
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern<
      linalg::LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();

    if (!unPackOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<int64_t> newOuterDimsPermVec =
        invertPermutationVector(maybePerm.value());

    // Can't use applyPermutationToVector for newInnerDimsPosVec since the
    // input and permutation ranks won't necessarily be equal in all cases.
    for (auto dim : innerDimsPos)
      newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);

    if (!outerDimsPerm.empty())
      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

    // Reuse the destination of the transpose op.
    rewriter.replaceOpWithNewOp<UnPackOp>(
        linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
        newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);

    return success();
  }
};

/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
/// transpose semantics.
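/// For example (illustrative IR; shapes and the permutation are assumptions
/// for this sketch):
///   %t = linalg.transpose ins(%src : tensor<4x32x8xf32>)
///       outs(%init : tensor<32x4x8xf32>) permutation = [1, 0, 2]
///   %u = tensor.unpack %t inner_dims_pos = [1] inner_tiles = [8]
///       into %dst : tensor<32x4x8xf32> -> tensor<32x32xf32>
/// becomes a single tensor.unpack of %src with the inverse of the transpose
/// permutation folded into outer_dims_perm (here [1, 0]); the unpack
/// destination is rebuilt from the reified result shape.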
struct FoldConsumerUnPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
    if (!linalgOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
    if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
      return failure();
    }

    SmallVector<int64_t> inverseTransposePerm =
        invertPermutationVector(maybePerm.value());
    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
    auto mixedInnerTilesVec = unPackOp.getMixedTiles();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;

    if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
                         newOuterDimsPermVec, destRank))
      return rewriter.notifyMatchFailure(
          unPackOp,
          "Cannot fold in tensor.unpack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // Process the transpose permutation for the tiled inner dimensions.
    for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
      int64_t remappedPosition = inverseTransposePerm[i] - destRank;
      newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
    }

    auto elemType =
        cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
    Value output = rewriter.create<EmptyOp>(
        unPackOp->getLoc(), unpackOpResultDims[0], elemType);

    rewriter.replaceOpWithNewOp<UnPackOp>(
        unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, newOuterDimsPermVec);

    return success();
  }
};
} // namespace

void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
                  FoldProducerPackWithConsumerLinalgTransposeOp,
                  FoldConsumerPackWithProducerLinalgTransposeOp,
                  FoldProducerUnPackWithConsumerLinalgTransposeOp,
                  FoldConsumerUnPackWithProducerLinalgTransposeOp>(
      patterns.getContext());
}

void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
  patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
      patterns.getContext());
}

} // namespace tensor
} // namespace mlir