xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp (revision 706c1302f99d79af21ddf22e23c53d33329f225a)
1bb4c53b7SLei Zhang //===- MergeConsecutiveInsertExtractSlicePatterns.cpp ---------------------===//
2bb4c53b7SLei Zhang //
3bb4c53b7SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4bb4c53b7SLei Zhang // See https://llvm.org/LICENSE.txt for license information.
5bb4c53b7SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6bb4c53b7SLei Zhang //
7bb4c53b7SLei Zhang //===----------------------------------------------------------------------===//
8bb4c53b7SLei Zhang 
9465ec4e0SLei Zhang #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
10bb4c53b7SLei Zhang #include "mlir/Dialect/Tensor/IR/Tensor.h"
11bb4c53b7SLei Zhang #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
1226864d8fSMatthias Springer #include "mlir/Dialect/Tensor/Utils/Utils.h"
13bb4c53b7SLei Zhang #include "mlir/IR/BuiltinTypes.h"
14bb4c53b7SLei Zhang #include "mlir/IR/OpDefinition.h"
15bb4c53b7SLei Zhang #include "mlir/IR/PatternMatch.h"
16bb4c53b7SLei Zhang 
17bb4c53b7SLei Zhang using namespace mlir;
18bb4c53b7SLei Zhang using namespace mlir::tensor;
19bb4c53b7SLei Zhang 
20bb4c53b7SLei Zhang namespace {
21bb4c53b7SLei Zhang /// Merges consecutive tensor.extract_slice ops into one.
224dc72d47SNicolas Vasilache // TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
23bb4c53b7SLei Zhang struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
24bb4c53b7SLei Zhang   using OpRewritePattern::OpRewritePattern;
25bb4c53b7SLei Zhang 
matchAndRewrite__anonf469d7b20111::MergeConsecutiveExtractSlice26bb4c53b7SLei Zhang   LogicalResult matchAndRewrite(ExtractSliceOp nextOp,
27bb4c53b7SLei Zhang                                 PatternRewriter &rewriter) const override {
28bb4c53b7SLei Zhang     auto prevOp = nextOp.getSource().getDefiningOp<ExtractSliceOp>();
29bb4c53b7SLei Zhang     if (!prevOp)
30bb4c53b7SLei Zhang       return failure();
31bb4c53b7SLei Zhang 
32bd81524eSLei Zhang     SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
334c48f016SMatthias Springer     if (failed(affine::mergeOffsetsSizesAndStrides(
344c48f016SMatthias Springer             rewriter, nextOp.getLoc(), prevOp, nextOp, prevOp.getDroppedDims(),
35bd81524eSLei Zhang             newOffsets, newSizes, newStrides)))
36bb4c53b7SLei Zhang       return failure();
37bb4c53b7SLei Zhang 
38bd81524eSLei Zhang     rewriter.replaceOpWithNewOp<ExtractSliceOp>(nextOp, nextOp.getType(),
39bd81524eSLei Zhang                                                 prevOp.getSource(), newOffsets,
40bd81524eSLei Zhang                                                 newSizes, newStrides);
41bb4c53b7SLei Zhang     return success();
42bb4c53b7SLei Zhang   }
43bb4c53b7SLei Zhang };
44bb4c53b7SLei Zhang 
45bb4c53b7SLei Zhang /// Merges consecutive tensor.insert_slice ops into one.
464dc72d47SNicolas Vasilache // TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
476176d6a9SMatthias Springer template <typename OpTy>
486176d6a9SMatthias Springer struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {
496176d6a9SMatthias Springer   using OpRewritePattern<OpTy>::OpRewritePattern;
50bb4c53b7SLei Zhang 
matchAndRewrite__anonf469d7b20111::MergeConsecutiveInsertSlice516176d6a9SMatthias Springer   LogicalResult matchAndRewrite(OpTy nextOp,
52bb4c53b7SLei Zhang                                 PatternRewriter &rewriter) const override {
536176d6a9SMatthias Springer     auto prevOp = nextOp.getSource().template getDefiningOp<InsertSliceOp>();
54bb4c53b7SLei Zhang     if (!prevOp)
55bb4c53b7SLei Zhang       return failure();
56bb4c53b7SLei Zhang 
57bb4c53b7SLei Zhang     if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
58bb4c53b7SLei Zhang       return failure();
59bb4c53b7SLei Zhang 
60bb4c53b7SLei Zhang     // The first insert_slice op should be rank reducing to make sure we cover
61bb4c53b7SLei Zhang     // the full source tensor to be inserted in the second insert_slice op.
62bb4c53b7SLei Zhang     SliceVerificationResult result =
63bb4c53b7SLei Zhang         isRankReducedType(prevOp.getDestType(), prevOp.getSourceType());
64bb4c53b7SLei Zhang     if (result != SliceVerificationResult::Success)
65bb4c53b7SLei Zhang       return failure();
66bb4c53b7SLei Zhang 
67bb4c53b7SLei Zhang     // Dynamic dimensions can pass rank reducing check in the above, e.g,
68bb4c53b7SLei Zhang     // inserting <?xf32> into <1x?x1xf32>. For such cases we cannot be certain
69bb4c53b7SLei Zhang     // the dynamic size covers the full tensor.
70bb4c53b7SLei Zhang     if (!prevOp.getSourceType().hasStaticShape() ||
71bb4c53b7SLei Zhang         !prevOp.getDestType().hasStaticShape())
72bb4c53b7SLei Zhang       return failure();
73bb4c53b7SLei Zhang 
746176d6a9SMatthias Springer     rewriter.replaceOpWithNewOp<OpTy>(
75bb4c53b7SLei Zhang         nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(),
76bb4c53b7SLei Zhang         nextOp.getMixedSizes(), nextOp.getMixedStrides());
77bb4c53b7SLei Zhang     return success();
78bb4c53b7SLei Zhang   }
79bb4c53b7SLei Zhang };
8026864d8fSMatthias Springer 
81f566b079SJerry Wu /// Drop redundant rank expansion of insert_slice that are directly followed
82f566b079SJerry Wu /// by extract_slice. E.g.:
8326864d8fSMatthias Springer /// %0 = tensor.insert_slice ... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
8426864d8fSMatthias Springer /// %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1]
8526864d8fSMatthias Springer ///     : tensor<1x1x5x10xf32> to tensor<2x2xf32>
86f566b079SJerry Wu struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
8726864d8fSMatthias Springer     : public OpRewritePattern<ExtractSliceOp> {
8826864d8fSMatthias Springer   using OpRewritePattern::OpRewritePattern;
8926864d8fSMatthias Springer 
matchAndRewrite__anonf469d7b20111::DropRedundantRankExpansionOnExtractSliceOfInsertSlice9026864d8fSMatthias Springer   LogicalResult matchAndRewrite(ExtractSliceOp extractSliceOp,
9126864d8fSMatthias Springer                                 PatternRewriter &rewriter) const override {
9226864d8fSMatthias Springer     // Nothing to do if no dims are dropped.
9326864d8fSMatthias Springer     llvm::SmallBitVector droppedDims = extractSliceOp.getDroppedDims();
946343ee72SFelix Schneider     if (droppedDims.none())
9526864d8fSMatthias Springer       return failure();
9626864d8fSMatthias Springer 
9726864d8fSMatthias Springer     // Look for tensor.insert_slice op that has an inverse rank expansion.
9826864d8fSMatthias Springer     auto insertSliceOp =
9926864d8fSMatthias Springer         extractSliceOp.getSource().getDefiningOp<InsertSliceOp>();
10026864d8fSMatthias Springer     if (!insertSliceOp)
10126864d8fSMatthias Springer       return failure();
10226864d8fSMatthias Springer     llvm::SmallBitVector expandedDims = insertSliceOp.getDroppedDims();
10326864d8fSMatthias Springer 
10426864d8fSMatthias Springer     // TODO: This could be extended to support cases where the dropped dims are
10526864d8fSMatthias Springer     // a subset of the expanded dims.
10626864d8fSMatthias Springer     if (expandedDims != droppedDims)
10726864d8fSMatthias Springer       return failure();
10826864d8fSMatthias Springer 
10926864d8fSMatthias Springer     // The tensor.insert_slice may not be redundant if it has multiple users.
11026864d8fSMatthias Springer     if (!insertSliceOp->hasOneUse())
11126864d8fSMatthias Springer       return failure();
11226864d8fSMatthias Springer 
11326864d8fSMatthias Springer     // Only consider tensor.insert_slice ops that are pure rank-reductions.
11426864d8fSMatthias Springer     // I.e., no elements are taken from the destination.
11526864d8fSMatthias Springer     if (!isCastLikeInsertSliceOp(insertSliceOp))
11626864d8fSMatthias Springer       return failure();
11726864d8fSMatthias Springer 
11826864d8fSMatthias Springer     // Extract directly from the source.
11926864d8fSMatthias Springer     OpBuilder::InsertionGuard g(rewriter);
12026864d8fSMatthias Springer     rewriter.setInsertionPoint(extractSliceOp);
12126864d8fSMatthias Springer     SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
12226864d8fSMatthias Springer     for (int64_t i = 0, e = extractSliceOp.getSourceType().getRank(); i < e;
12326864d8fSMatthias Springer          ++i) {
12426864d8fSMatthias Springer       if (droppedDims.test(i))
12526864d8fSMatthias Springer         continue;
12626864d8fSMatthias Springer       newOffsets.push_back(extractSliceOp.getMixedOffsets()[i]);
12726864d8fSMatthias Springer       newSizes.push_back(extractSliceOp.getMixedSizes()[i]);
12826864d8fSMatthias Springer       newStrides.push_back(extractSliceOp.getMixedStrides()[i]);
12926864d8fSMatthias Springer     }
13026864d8fSMatthias Springer     rewriter.replaceOpWithNewOp<ExtractSliceOp>(
13126864d8fSMatthias Springer         extractSliceOp, /*source=*/insertSliceOp.getSource(), newOffsets,
13226864d8fSMatthias Springer         newSizes, newStrides);
13326864d8fSMatthias Springer     rewriter.eraseOp(insertSliceOp);
13426864d8fSMatthias Springer     return success();
13526864d8fSMatthias Springer   }
13626864d8fSMatthias Springer };
137f566b079SJerry Wu 
138f566b079SJerry Wu /// Drop redundant rank expansion of insert_slice that direclty follows
139f566b079SJerry Wu /// extract_slice.
140f566b079SJerry Wu ///
141f566b079SJerry Wu /// This can be done when the insert_slice op purely expands ranks (adds unit
142f566b079SJerry Wu /// dims) and the extrace_slice drops corresponding unit dims. For example:
143f566b079SJerry Wu ///
144f566b079SJerry Wu /// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
145f566b079SJerry Wu ///     : tensor<2x8xf32> to tensor<8xf32>
146f566b079SJerry Wu /// %inserted_slice = tensor.insert_slice %extracted_slice
147f566b079SJerry Wu ///     into %dest[0, 0] [1, 8] [1, 1]
148f566b079SJerry Wu ///     : tensor<8xf32> into tensor<1x8xf32>
149f566b079SJerry Wu ///
150f566b079SJerry Wu /// can be folded into:
151f566b079SJerry Wu ///
152f566b079SJerry Wu /// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
153f566b079SJerry Wu ///     : tensor<2x8xf32> to tensor<1x8xf32>
154f566b079SJerry Wu struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
155f566b079SJerry Wu     : public OpRewritePattern<tensor::InsertSliceOp> {
156f566b079SJerry Wu   using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
157f566b079SJerry Wu 
matchAndRewrite__anonf469d7b20111::DropRedundantRankExpansionOnInsertSliceOfExtractSlice158f566b079SJerry Wu   LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
159*706c1302SKazu Hirata                                 PatternRewriter &rewriter) const override {
160f566b079SJerry Wu     auto extractSliceOp =
161f566b079SJerry Wu         insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
162f566b079SJerry Wu     if (!extractSliceOp) {
163f566b079SJerry Wu       return rewriter.notifyMatchFailure(insertSliceOp,
164f566b079SJerry Wu                                          "source is not extract_slice");
165f566b079SJerry Wu     }
166f566b079SJerry Wu 
167f566b079SJerry Wu     // Can't fold if the extract_slice op has other users.
168f566b079SJerry Wu     if (!extractSliceOp->hasOneUse()) {
169f566b079SJerry Wu       return rewriter.notifyMatchFailure(insertSliceOp,
170f566b079SJerry Wu                                          "source has multi-uses");
171f566b079SJerry Wu     }
172f566b079SJerry Wu 
173f566b079SJerry Wu     // Check if the insert_slice op purely expands ranks (add unit dims).
174f566b079SJerry Wu     if (!isCastLikeInsertSliceOp(insertSliceOp)) {
175f566b079SJerry Wu       return rewriter.notifyMatchFailure(insertSliceOp,
176f566b079SJerry Wu                                          "insert_slice is not cast-like");
177f566b079SJerry Wu     }
178f566b079SJerry Wu 
179f566b079SJerry Wu     llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
180f566b079SJerry Wu     llvm::SmallBitVector insertDroppedDims = insertSliceOp.getDroppedDims();
181f566b079SJerry Wu     // Can't fold if the insert_slice op expands to more dims.
182f566b079SJerry Wu     if (extractDroppedDims.size() < insertDroppedDims.size()) {
183f566b079SJerry Wu       return rewriter.notifyMatchFailure(insertSliceOp,
184f566b079SJerry Wu                                          "insert_slice expands more dims");
185f566b079SJerry Wu     }
186f566b079SJerry Wu 
187f566b079SJerry Wu     // Try to match the extract dropped dims to the insert dropped dims. This is
188f566b079SJerry Wu     // done by scanning the dims of extract_slice and find the left-most one can
189f566b079SJerry Wu     // match the dim of insert_slice. If a match is found, advance the dim of
190f566b079SJerry Wu     // insert_slice to match the next one.
191f566b079SJerry Wu     unsigned insertDimPos = 0;
192f566b079SJerry Wu     for (unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
193f566b079SJerry Wu          ++extractDimPos) {
194f566b079SJerry Wu       // Matched all dims.
195f566b079SJerry Wu       if (insertDimPos == insertDroppedDims.size())
196f566b079SJerry Wu         break;
197f566b079SJerry Wu 
198f566b079SJerry Wu       bool isExtractDropped = extractDroppedDims[extractDimPos];
199f566b079SJerry Wu       bool isInsertDropped = insertDroppedDims[insertDimPos];
200f566b079SJerry Wu       // Match if both sides drop/keep the dim. Advance and match the next dim
201f566b079SJerry Wu       // of insert_slice.
202f566b079SJerry Wu       if (isExtractDropped == isInsertDropped) {
203f566b079SJerry Wu         insertDimPos += 1;
204f566b079SJerry Wu       } else if (!isExtractDropped && isInsertDropped) {
205f566b079SJerry Wu         // Not enough extract dropped dims to match the insert dropped dims.
206f566b079SJerry Wu         return rewriter.notifyMatchFailure(insertSliceOp,
207f566b079SJerry Wu                                            "insert_slice drops more unit dims");
208f566b079SJerry Wu       }
209f566b079SJerry Wu       // If the dim is dropped by extract_slice and not by insert_slice, look
210f566b079SJerry Wu       // the next dim of extract_slice to see if it can match the current dim of
211f566b079SJerry Wu       // insert_slice.
212f566b079SJerry Wu     }
213f566b079SJerry Wu     // Can't match some insert dims.
214f566b079SJerry Wu     if (insertDimPos != insertDroppedDims.size()) {
215f566b079SJerry Wu       return rewriter.notifyMatchFailure(insertSliceOp,
216f566b079SJerry Wu                                          "insert_slice has unmatched dims");
217f566b079SJerry Wu     }
218f566b079SJerry Wu 
219f566b079SJerry Wu     rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
220f566b079SJerry Wu         insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
221f566b079SJerry Wu         extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
222f566b079SJerry Wu         extractSliceOp.getMixedStrides());
223f566b079SJerry Wu     rewriter.eraseOp(extractSliceOp);
224f566b079SJerry Wu 
225f566b079SJerry Wu     return success();
226f566b079SJerry Wu   }
227f566b079SJerry Wu };
228bb4c53b7SLei Zhang } // namespace
229bb4c53b7SLei Zhang 
populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet & patterns)230bb4c53b7SLei Zhang void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
231bb4c53b7SLei Zhang     RewritePatternSet &patterns) {
2326176d6a9SMatthias Springer   patterns.add<MergeConsecutiveExtractSlice,
2336176d6a9SMatthias Springer                MergeConsecutiveInsertSlice<InsertSliceOp>,
2346176d6a9SMatthias Springer                MergeConsecutiveInsertSlice<ParallelInsertSliceOp>>(
235bb4c53b7SLei Zhang       patterns.getContext());
236bb4c53b7SLei Zhang }
23726864d8fSMatthias Springer 
populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet & patterns)23826864d8fSMatthias Springer void mlir::tensor::populateDropRedundantInsertSliceRankExpansionPatterns(
23926864d8fSMatthias Springer     RewritePatternSet &patterns) {
240f566b079SJerry Wu   patterns.add<DropRedundantRankExpansionOnExtractSliceOfInsertSlice,
241f566b079SJerry Wu                DropRedundantRankExpansionOnInsertSliceOfExtractSlice>(
242f566b079SJerry Wu       patterns.getContext());
24326864d8fSMatthias Springer }
244