//===- PackAndUnpackPatterns.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace tensor {
namespace {

/// Returns the number of shape sizes that are either dynamic or greater than 1.
static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
  return llvm::count_if(
      shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
}

/// Returns success() if at most one dimension size in the non-packed domain
/// is greater than 1 and the packing only happens on that dimension.
/// Note: this method should only be used by the pack/unpack to reshape
/// conversion. It assumes that the non-unit inner tile size must be used by
/// the non-unit dimension.
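/// For illustration (shapes chosen arbitrarily), a source shape like
/// [1, 1, 32] with inner tile sizes [16] satisfies both checks below, while a
/// shape like [4, 32] (two non-unit dimensions) does not.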
static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
                                ArrayRef<int64_t> srcShape,
                                ArrayRef<int64_t> innerPackTileSize) {
  if (getNumGtOneDims(srcShape) > 1) {
    return rewriter.notifyMatchFailure(
        op, "expects non-packed domain to have at most one non-unit dims");
  }
  // The non-unit inner tile size must be used by the non-unit dimension. If
  // not, computing the reassociation maps will fail.
  if (getNumGtOneDims(innerPackTileSize) > 1) {
    return rewriter.notifyMatchFailure(
        op, "expects at most one non-unit inner tiles");
  }
  return success();
}

// If the `linalgOp` represents a transpose, return the permutation vector for
// the transpose. Otherwise, return failure.
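// For illustration (shapes assumed for the example), both a named transpose
//
//   linalg.transpose ins(%in : tensor<2x3xf32>)
//       outs(%out : tensor<3x2xf32>) permutation = [1, 0]
//
// and a single-input, single-init, all-parallel linalg.generic whose input
// and output indexing maps are distinct permutations and whose body contains
// only the yield are recognized here.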
static FailureOr<SmallVector<int64_t>>
getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
    return SmallVector<int64_t>(transposeOp.getPermutation());
  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
    return failure();

  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return failure();
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
      mapRange.front() == mapRange.back()) {
    return failure();
  }
  if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
    return failure();
  AffineMap outMap = mapRange.back();
  AffineMap inMap = mapRange.front();
  // To get the permutation, look at each output index and find which
  // dimension in the input we're reading from for that index.
  return llvm::map_to_vector(outMap.getResults(),
                             [&](AffineExpr expr) -> int64_t {
                               return *inMap.getResultPosition(expr);
                             });
}

/// Packing a one-dimensional tensor can be expressed as an expand_shape op.
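///
/// For illustration (sizes chosen arbitrarily), packing a 1-D tensor with a
/// single inner tile, e.g.
///
///   %0 = tensor.pack %src inner_dims_pos = [0] inner_tiles = [8]
///       into %dest : tensor<64xf32> -> tensor<8x8xf32>
///
/// amounts to expanding tensor<64xf32> into tensor<8x8xf32> with
/// reassociation [[0, 1]].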
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  FailureOr<Value>
  insertExpand(RewriterBase &rewriter, Location loc, Value operand,
               Type newOperandType,
               ArrayRef<ReassociationIndices> reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return rewriter
        .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
                                       reassociation)
        .getResult();
  }

  /// Returns success() if it is only packing on the innermost dimension.
  LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
                                     PackOp packOp) const {
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          packOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    int64_t srcRank = packOp.getSourceRank();
    ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
      return rewriter.notifyMatchFailure(
          packOp, "expects packing at the innermost dimension");
    }
    return success();
  }

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    if (packOp.getPaddingValue())
      return rewriter.notifyMatchFailure(packOp, "expects no padding value");

    RankedTensorType sourceType = packOp.getSourceType();
    if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
        failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
                          packOp.getStaticTiles())) &&
        !packOp.isLikePad()) {
      return failure();
    }

    RankedTensorType destType = packOp.getDestType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    FailureOr<Value> expanded =
        insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
                     *reassociation);
    if (failed(expanded)) {
      return rewriter.notifyMatchFailure(
          packOp, "unable to expand source of tensor.pack");
    }
    rewriter.replaceOp(packOp, *expanded);
    return success();
  }
};

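/// Un-packing into a one-dimensional tensor can be expressed as a
/// collapse_shape op. For illustration (sizes chosen arbitrarily),
///
///   %0 = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [8]
///       into %dest : tensor<8x8xf32> -> tensor<64xf32>
///
/// amounts to collapsing tensor<8x8xf32> into tensor<64xf32> with
/// reassociation [[0, 1]].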
struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
                       Type newOperandType, ArrayAttr reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
                                                    operand, reassociation);
  }

  /// Returns success() if it is unpacking on the innermost dimension.
  LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
                                       UnPackOp unpackOp) const {
    auto outerDimsPerm = unpackOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          unpackOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    RankedTensorType sourceType = unpackOp.getSourceType();
    RankedTensorType destType = unpackOp.getDestType();
    if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
      return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");

    ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
      return rewriter.notifyMatchFailure(
          unpackOp, "expects unpacking on the innermost dimension");
    }

    return success();
  }

  LogicalResult matchAndRewrite(UnPackOp unpackOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType destType = unpackOp.getDestType();
    if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
        failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
                          unpackOp.getStaticTiles())) &&
        !unpackOp.isLikeUnPad()) {
      return failure();
    }

    RankedTensorType sourceType = unpackOp.getSourceType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    Value collapsed = insertCollapse(
        rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
        getReassociationIndicesAttribute(rewriter, *reassociation));
    rewriter.replaceOp(unpackOp, collapsed);
    return success();
  }
};

/// Fold a `pad` -> `pack` into `pack` when the `pad` has zero low padding and
/// a constant padding value, and `pack` either has no padding value or uses
/// the same padding value.
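///
/// For illustration (a sketch with arbitrary sizes), padding %src from
/// tensor<60xf32> to tensor<64xf32> with %cst and then packing the result
/// with inner_tiles = [8] becomes a single op:
///
///   %0 = tensor.pack %src padding_value(%cst : f32)
///       inner_dims_pos = [0] inner_tiles = [8]
///       into %dest : tensor<60xf32> -> tensor<8x8xf32>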
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<PadOp>();

    if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
      return failure();

    Value constantPaddingValue = padOp.getConstantPaddingValue();
    if (!constantPaddingValue)
      return failure();

    if (auto paddingValue = packOp.getPaddingValue())
      if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
        return failure();

    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
        packOp.getMixedTiles(), constantPaddingValue,
        packOp.getOuterDimsPerm());
    return success();
  }
};

/// Fold an `unpack` -> `extract_slice` into the `unpack` since it already
/// has extract_slice semantics.
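///
/// For illustration (a sketch with arbitrary sizes):
///
///   %u = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [8]
///       into %d0 : tensor<8x8xf32> -> tensor<64xf32>
///   %s = tensor.extract_slice %u[0] [60] [1]
///       : tensor<64xf32> to tensor<60xf32>
///
/// becomes an unpack of %src directly into a tensor<60xf32> destination,
/// provided the slice has zero offsets and unit strides.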
struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
    if (!unpackOp)
      return failure();

    if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "rank-reduced folding is not supported");
    }

    // Check all offsets are zeros, and all strides are ones.
    if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
        !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expects offsets to be 0s and strides to be 1s");
    }

    // Create a new empty output tensor.
    Type elementType = unpackOp.getDestType().getElementType();
    Value output = rewriter.create<EmptyOp>(
        sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
    rewriter.replaceOpWithNewOp<UnPackOp>(
        sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
        unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
    return success();
  }
};

// Applies `permutation` to `inVec` and stores the result in `resVec`.
// `inVec` may be empty, in which case the permutation values are used
// directly (a one-to-one mapping).
// `rank` bounds the permutation: only the first `rank` entries are considered,
// and each of them must be smaller than `rank`; otherwise the function returns
// false.
// For example, permutation {1, 0, 3, 2} with rank 2 is allowed since the
// values in permutation[:rank] do not exceed the rank, whereas permutation
// {1, 3, 0, 2} is not allowed since `3` exceeds the rank within the
// considered range.
static bool checkAndPermute(ArrayRef<int64_t> permutation,
                            ArrayRef<int64_t> inVec,
                            SmallVectorImpl<int64_t> &resVec, int64_t rank) {

  for (unsigned int i = 0; i < rank; ++i) {
    int64_t remappedPosition = permutation[i];
    if (remappedPosition >= rank)
      return false;
    if (!inVec.empty())
      remappedPosition = inVec[remappedPosition];
    resVec.push_back(remappedPosition);
  }

  return true;
}

/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
/// semantics.
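///
/// For illustration (a sketch with arbitrary sizes and permutation):
///
///   %p = tensor.pack %src inner_dims_pos = [1] inner_tiles = [8]
///       into %d : tensor<16x64xf32> -> tensor<16x8x8xf32>
///   %t = linalg.transpose ins(%p : tensor<16x8x8xf32>)
///       outs(%init : tensor<8x16x8xf32>) permutation = [1, 0, 2]
///
/// folds into a single pack of %src with outer_dims_perm = [1, 0] and the
/// same inner_dims_pos and inner_tiles.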
struct FoldProducerPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();

    if (!packOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto innerDimsPos = packOp.getInnerDimsPos();
    auto mixedInnerTiles = packOp.getMixedTiles();
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    auto transposePerm = maybePerm.value();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    int64_t srcRank = packOp.getSourceRank();

    if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
                         srcRank))
      return rewriter.notifyMatchFailure(
          linalgOp,
          "Cannot fold in tensor.pack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // Process transpose operation for tiled inner dimensions
    for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
      int64_t remappedPosition = transposePerm[i] - srcRank;
      newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
    }

    Value output = packOp.createDestinationTensor(
        rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
        newInnerDimsPosVec, newOuterDimsPermVec);

    rewriter.replaceOpWithNewOp<PackOp>(
        linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);

    return success();
  }
};

/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
/// semantics.
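///
/// For illustration (a sketch with arbitrary sizes and permutation):
///
///   %t = linalg.transpose ins(%src : tensor<16x64xf32>)
///       outs(%init : tensor<64x16xf32>) permutation = [1, 0]
///   %p = tensor.pack %t inner_dims_pos = [1] inner_tiles = [8]
///       into %d : tensor<64x16xf32> -> tensor<64x2x8xf32>
///
/// folds into a pack of %src with outer_dims_perm = [1, 0] and
/// inner_dims_pos = [0].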
struct FoldConsumerPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
    if (!linalgOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto transposePermutation = maybePerm.value();
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    auto innerDimsPos = packOp.getInnerDimsPos();
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<int64_t> newOuterDimsPermVec =
        llvm::to_vector(transposePermutation);

    if (!outerDimsPerm.empty())
      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

    // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
    // permutation rank won't necessarily be equal in all cases.
    for (auto dim : innerDimsPos)
      newInnerDimsPosVec.push_back(transposePermutation[dim]);

    Value output = packOp.createDestinationTensor(
        rewriter, packOp.getLoc(), linalgOp->getOperand(0),
        packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);

    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
        packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);

    return success();
  }
};

/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
/// transpose semantics.
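///
/// For illustration (a sketch with arbitrary sizes and permutation):
///
///   %u = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [8]
///       into %d : tensor<2x64x8xf32> -> tensor<16x64xf32>
///   %t = linalg.transpose ins(%u : tensor<16x64xf32>)
///       outs(%init : tensor<64x16xf32>) permutation = [1, 0]
///
/// folds into an unpack of %src with outer_dims_perm = [1, 0] and
/// inner_dims_pos = [1], writing directly into the transpose destination.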
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();

    if (!unPackOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<int64_t> newOuterDimsPermVec =
        invertPermutationVector(maybePerm.value());

    // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
    // permutation rank won't necessarily be equal in all cases.
    for (auto dim : innerDimsPos)
      newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);

    if (!outerDimsPerm.empty())
      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

    // Reuse the destination of the transpose op.
    rewriter.replaceOpWithNewOp<UnPackOp>(
        linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
        newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);

    return success();
  }
};

/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
/// transpose semantics.
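///
/// For illustration (a sketch with arbitrary sizes and permutation):
///
///   %t = linalg.transpose ins(%src : tensor<64x2x8xf32>)
///       outs(%init : tensor<2x64x8xf32>) permutation = [1, 0, 2]
///   %u = tensor.unpack %t inner_dims_pos = [0] inner_tiles = [8]
///       into %d : tensor<2x64x8xf32> -> tensor<16x64xf32>
///
/// folds into an unpack of %src with outer_dims_perm = [1, 0] and the same
/// inner_dims_pos and inner_tiles.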
struct FoldConsumerUnPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
    if (!linalgOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
    if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
      return failure();
    }

    SmallVector<int64_t> inverseTransposePerm =
        invertPermutationVector(maybePerm.value());
    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
    auto mixedInnerTilesVec = unPackOp.getMixedTiles();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
                         newOuterDimsPermVec, destRank))
      return rewriter.notifyMatchFailure(
          unPackOp,
          "Cannot fold in tensor.unpack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // Process transpose operation for tiled inner dimensions
    for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
      int64_t remappedPosition = inverseTransposePerm[i] - destRank;
      newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
    }

    auto elemType =
        cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
    Value output = rewriter.create<tensor::EmptyOp>(
        unPackOp->getLoc(), unpackOpResultDims[0], elemType);

    rewriter.replaceOpWithNewOp<UnPackOp>(
        unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, newOuterDimsPermVec);

    return success();
  }
};
} // namespace

void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
                  FoldProducerPackWithConsumerLinalgTransposeOp,
                  FoldConsumerPackWithProducerLinalgTransposeOp,
                  FoldConsumerUnPackWithProducerLinalgTransposeOp,
                  FoldProducerUnPackWithConsumerLinalgTransposeOp>(
      patterns.getContext());
}

void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
  patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
      patterns.getContext());
}

} // namespace tensor
} // namespace mlir