//===- ReshapePatterns.cpp - Patterns related to rank reductions ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {
/// Fold expand_shape(extract_slice) ops that cancel each other out.
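/// A sketch of the rewrite with illustrative shapes (the pattern itself is
/// shape-agnostic, as long as the expand exactly re-adds the unit dims the
/// rank-reducing slice dropped):
///
///   %0 = tensor.extract_slice %t[0, 0, 0][1, 8, 16][1, 1, 1]
///       : tensor<2x8x16xf32> to tensor<8x16xf32>
///   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [1, 8, 16]
///       : tensor<8x16xf32> into tensor<1x8x16xf32>
///
/// folds to a single non-rank-reducing slice:
///
///   %1 = tensor.extract_slice %t[0, 0, 0][1, 8, 16][1, 1, 1]
///       : tensor<2x8x16xf32> to tensor<1x8x16xf32>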
struct FoldExpandOfRankReducingExtract
    : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = expandShapeOp.getResultType();
    auto extractSliceOp =
        expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
    if (!extractSliceOp)
      return failure();
    RankedTensorType srcType = extractSliceOp.getSourceType();

    // Only cases where the ExpandShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // ExtractSliceOp is no longer rank-reducing are supported at the moment.
    RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
        srcType, extractSliceOp.getStaticOffsets(),
        extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
    if (nonReducingExtractType != resultType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
        mixedStrides);
    return success();
  }
};

/// Fold collapse_shape which only removes static dimensions of size `1`
/// into extract_slice.
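/// A sketch with illustrative shapes (assuming the extract_slice has a
/// single use):
///
///   %0 = tensor.extract_slice %t[0, 0][1, 16][1, 1]
///       : tensor<8x16xf32> to tensor<1x16xf32>
///   %1 = tensor.collapse_shape %0 [[0, 1]]
///       : tensor<1x16xf32> into tensor<16xf32>
///
/// folds to a single rank-reducing slice:
///
///   %1 = tensor.extract_slice %t[0, 0][1, 16][1, 1]
///       : tensor<8x16xf32> to tensor<16xf32>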
struct FoldUnPaddingCollapseIntoExtract
    : public OpRewritePattern<tensor::CollapseShapeOp> {
  using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto extractSliceOp =
        collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
    // The collapse cannot be folded away when the extract_slice has multiple
    // users, and it is not necessarily beneficial to only convert the
    // collapse into another extract_slice.
    if (!extractSliceOp || !extractSliceOp->hasOneUse())
      return failure();

    // Only fold away a simple collapse where all removed dimensions have
    // static size `1`.
    SliceVerificationResult res = isRankReducedType(
        collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(collapseShapeOp,
                                         "expected unpadding collapse");

    Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
        extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
        extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
    rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
    return success();
  }
};

/// Fold insert_slice(collapse_shape) ops that cancel each other out.
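/// A sketch with illustrative shapes (the collapse must exactly remove the
/// unit dims that the rank-reducing insert would otherwise drop):
///
///   %0 = tensor.collapse_shape %src [[0, 1], [2]]
///       : tensor<1x8x16xf32> into tensor<8x16xf32>
///   %1 = tensor.insert_slice %0 into %dest[0, 0, 0][1, 8, 16][1, 1, 1]
///       : tensor<8x16xf32> into tensor<2x8x16xf32>
///
/// folds to a single non-rank-reducing insert:
///
///   %1 = tensor.insert_slice %src into %dest[0, 0, 0][1, 8, 16][1, 1, 1]
///       : tensor<1x8x16xf32> into tensor<2x8x16xf32>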
template <typename OpTy>
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();
    RankedTensorType srcType = collapseShapeOp.getSrcType();

    // Only cases where the CollapseShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // InsertSliceOp is no longer rank-reducing are supported at the moment.
    RankedTensorType nonReducingInsertType =
        RankedTensorType::get(insertSliceOp.getStaticSizes(),
                              insertSliceOp.getDestType().getElementType());
    if (nonReducingInsertType != srcType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
                                      insertSliceOp.getDest(), mixedOffsets,
                                      mixedSizes, mixedStrides);
    return success();
  }
};

/// Fold expand_shape which only adds static dimensions of size `1`
/// into insert_slice.
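/// A sketch with illustrative shapes:
///
///   %0 = tensor.expand_shape %src [[0, 1]] output_shape [1, 16]
///       : tensor<16xf32> into tensor<1x16xf32>
///   %1 = tensor.insert_slice %0 into %dest[0, 0][1, 16][1, 1]
///       : tensor<1x16xf32> into tensor<8x16xf32>
///
/// folds to a single rank-reducing insert:
///
///   %1 = tensor.insert_slice %src into %dest[0, 0][1, 16][1, 1]
///       : tensor<16xf32> into tensor<8x16xf32>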
template <typename OpTy>
struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = insertSliceOp.getSource()
                             .template getDefiningOp<tensor::ExpandShapeOp>();
    if (!expandShapeOp)
      return failure();

    // Only fold away a simple expansion where all added dimensions have
    // static size `1`.
    SliceVerificationResult res = isRankReducedType(
        expandShapeOp.getResultType(), expandShapeOp.getSrcType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "expected rank increasing expansion");

    rewriter.modifyOpInPlace(insertSliceOp, [&]() {
      insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
    });
    return success();
  }
};

/// Pattern to bubble up a tensor.expand_shape op through a producer
/// tensor.collapse_shape op that has non-intersecting reassociations.
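/// A sketch with illustrative shapes; each reassociation group is non-trivial
/// for at most one of the two reshapes:
///
///   %0 = tensor.collapse_shape %src [[0, 1], [2]]
///       : tensor<2x3x16xf32> into tensor<6x16xf32>
///   %1 = tensor.expand_shape %0 [[0], [1, 2]] output_shape [6, 4, 4]
///       : tensor<6x16xf32> into tensor<6x4x4xf32>
///
/// becomes
///
///   %0 = tensor.expand_shape %src [[0], [1], [2, 3]]
///       output_shape [2, 3, 4, 4]
///       : tensor<2x3x16xf32> into tensor<2x3x4x4xf32>
///   %1 = tensor.collapse_shape %0 [[0, 1], [2], [3]]
///       : tensor<2x3x4x4xf32> into tensor<6x4x4xf32>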
struct BubbleUpExpandThroughParallelCollapse
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto collapseOp =
        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    auto expandReInds = expandOp.getReassociationIndices();
    auto collapseReInds = collapseOp.getReassociationIndices();

    // The reshapes are parallel to each other if, for each reassociation
    // group, at most one of the two reshapes has more than one index in
    // that group.
    for (auto [expandReassociation, collapseReassociation] :
         llvm::zip_equal(expandReInds, collapseReInds)) {
      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
        return failure();
    }

    // Compute new reassociation indices and expanded/collapsed shapes.
    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
    Location loc = expandOp->getLoc();
    SmallVector<OpFoldResult> collapseSizes =
        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
    SmallVector<OpFoldResult> expandSizes(getMixedValues(
        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
    SmallVector<OpFoldResult> newExpandSizes;
    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
      if (collapseReassociation.size() != 1) {
        ReassociationIndices newCollapseReassociation;
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
          newCollapseReassociation.push_back(index);
          newExpandReInds.push_back({index++});
          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
        }
        newCollapseReInds.push_back(newCollapseReassociation);
        expandIndex++;
        continue;
      }
      ReassociationIndices newExpandReassociation;
      auto expandReassociation = expandReInds[idx];
      for (size_t i = 0; i < expandReassociation.size(); ++i) {
        newExpandReassociation.push_back(index);
        newCollapseReInds.push_back({index++});
        newExpandSizes.push_back(expandSizes[expandIndex++]);
      }
      newExpandReInds.push_back(newExpandReassociation);
      collapseIndex++;
    }

    // Swap reshape order.
    SmallVector<Value> dynamicSizes;
    SmallVector<int64_t> staticSizes;
    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
    auto expandResultType = expandOp.getResultType().clone(staticSizes);
    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
        newExpandSizes);
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        expandOp, newExpand.getResult(), newCollapseReInds);
    return success();
  }
};

} // namespace

void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
           FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
           FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
          patterns.getContext());
}

void mlir::tensor::populateBubbleUpExpandShapePatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}