1bb4c53b7SLei Zhang //===- MergeConsecutiveInsertExtractSlicePatterns.cpp ---------------------===//
2bb4c53b7SLei Zhang //
3bb4c53b7SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4bb4c53b7SLei Zhang // See https://llvm.org/LICENSE.txt for license information.
5bb4c53b7SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6bb4c53b7SLei Zhang //
7bb4c53b7SLei Zhang //===----------------------------------------------------------------------===//
8bb4c53b7SLei Zhang
9465ec4e0SLei Zhang #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
10bb4c53b7SLei Zhang #include "mlir/Dialect/Tensor/IR/Tensor.h"
11bb4c53b7SLei Zhang #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
1226864d8fSMatthias Springer #include "mlir/Dialect/Tensor/Utils/Utils.h"
13bb4c53b7SLei Zhang #include "mlir/IR/BuiltinTypes.h"
14bb4c53b7SLei Zhang #include "mlir/IR/OpDefinition.h"
15bb4c53b7SLei Zhang #include "mlir/IR/PatternMatch.h"
16bb4c53b7SLei Zhang
17bb4c53b7SLei Zhang using namespace mlir;
18bb4c53b7SLei Zhang using namespace mlir::tensor;
19bb4c53b7SLei Zhang
20bb4c53b7SLei Zhang namespace {
21bb4c53b7SLei Zhang /// Merges consecutive tensor.extract_slice ops into one.
224dc72d47SNicolas Vasilache // TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
23bb4c53b7SLei Zhang struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
24bb4c53b7SLei Zhang using OpRewritePattern::OpRewritePattern;
25bb4c53b7SLei Zhang
matchAndRewrite__anonf469d7b20111::MergeConsecutiveExtractSlice26bb4c53b7SLei Zhang LogicalResult matchAndRewrite(ExtractSliceOp nextOp,
27bb4c53b7SLei Zhang PatternRewriter &rewriter) const override {
28bb4c53b7SLei Zhang auto prevOp = nextOp.getSource().getDefiningOp<ExtractSliceOp>();
29bb4c53b7SLei Zhang if (!prevOp)
30bb4c53b7SLei Zhang return failure();
31bb4c53b7SLei Zhang
32bd81524eSLei Zhang SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
334c48f016SMatthias Springer if (failed(affine::mergeOffsetsSizesAndStrides(
344c48f016SMatthias Springer rewriter, nextOp.getLoc(), prevOp, nextOp, prevOp.getDroppedDims(),
35bd81524eSLei Zhang newOffsets, newSizes, newStrides)))
36bb4c53b7SLei Zhang return failure();
37bb4c53b7SLei Zhang
38bd81524eSLei Zhang rewriter.replaceOpWithNewOp<ExtractSliceOp>(nextOp, nextOp.getType(),
39bd81524eSLei Zhang prevOp.getSource(), newOffsets,
40bd81524eSLei Zhang newSizes, newStrides);
41bb4c53b7SLei Zhang return success();
42bb4c53b7SLei Zhang }
43bb4c53b7SLei Zhang };
44bb4c53b7SLei Zhang
45bb4c53b7SLei Zhang /// Merges consecutive tensor.insert_slice ops into one.
464dc72d47SNicolas Vasilache // TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
476176d6a9SMatthias Springer template <typename OpTy>
486176d6a9SMatthias Springer struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {
496176d6a9SMatthias Springer using OpRewritePattern<OpTy>::OpRewritePattern;
50bb4c53b7SLei Zhang
matchAndRewrite__anonf469d7b20111::MergeConsecutiveInsertSlice516176d6a9SMatthias Springer LogicalResult matchAndRewrite(OpTy nextOp,
52bb4c53b7SLei Zhang PatternRewriter &rewriter) const override {
536176d6a9SMatthias Springer auto prevOp = nextOp.getSource().template getDefiningOp<InsertSliceOp>();
54bb4c53b7SLei Zhang if (!prevOp)
55bb4c53b7SLei Zhang return failure();
56bb4c53b7SLei Zhang
57bb4c53b7SLei Zhang if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
58bb4c53b7SLei Zhang return failure();
59bb4c53b7SLei Zhang
60bb4c53b7SLei Zhang // The first insert_slice op should be rank reducing to make sure we cover
61bb4c53b7SLei Zhang // the full source tensor to be inserted in the second insert_slice op.
62bb4c53b7SLei Zhang SliceVerificationResult result =
63bb4c53b7SLei Zhang isRankReducedType(prevOp.getDestType(), prevOp.getSourceType());
64bb4c53b7SLei Zhang if (result != SliceVerificationResult::Success)
65bb4c53b7SLei Zhang return failure();
66bb4c53b7SLei Zhang
67bb4c53b7SLei Zhang // Dynamic dimensions can pass rank reducing check in the above, e.g,
68bb4c53b7SLei Zhang // inserting <?xf32> into <1x?x1xf32>. For such cases we cannot be certain
69bb4c53b7SLei Zhang // the dynamic size covers the full tensor.
70bb4c53b7SLei Zhang if (!prevOp.getSourceType().hasStaticShape() ||
71bb4c53b7SLei Zhang !prevOp.getDestType().hasStaticShape())
72bb4c53b7SLei Zhang return failure();
73bb4c53b7SLei Zhang
746176d6a9SMatthias Springer rewriter.replaceOpWithNewOp<OpTy>(
75bb4c53b7SLei Zhang nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(),
76bb4c53b7SLei Zhang nextOp.getMixedSizes(), nextOp.getMixedStrides());
77bb4c53b7SLei Zhang return success();
78bb4c53b7SLei Zhang }
79bb4c53b7SLei Zhang };
8026864d8fSMatthias Springer
81f566b079SJerry Wu /// Drop redundant rank expansion of insert_slice that are directly followed
82f566b079SJerry Wu /// by extract_slice. E.g.:
8326864d8fSMatthias Springer /// %0 = tensor.insert_slice ... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
8426864d8fSMatthias Springer /// %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1]
8526864d8fSMatthias Springer /// : tensor<1x1x5x10xf32> to tensor<2x2xf32>
86f566b079SJerry Wu struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
8726864d8fSMatthias Springer : public OpRewritePattern<ExtractSliceOp> {
8826864d8fSMatthias Springer using OpRewritePattern::OpRewritePattern;
8926864d8fSMatthias Springer
matchAndRewrite__anonf469d7b20111::DropRedundantRankExpansionOnExtractSliceOfInsertSlice9026864d8fSMatthias Springer LogicalResult matchAndRewrite(ExtractSliceOp extractSliceOp,
9126864d8fSMatthias Springer PatternRewriter &rewriter) const override {
9226864d8fSMatthias Springer // Nothing to do if no dims are dropped.
9326864d8fSMatthias Springer llvm::SmallBitVector droppedDims = extractSliceOp.getDroppedDims();
946343ee72SFelix Schneider if (droppedDims.none())
9526864d8fSMatthias Springer return failure();
9626864d8fSMatthias Springer
9726864d8fSMatthias Springer // Look for tensor.insert_slice op that has an inverse rank expansion.
9826864d8fSMatthias Springer auto insertSliceOp =
9926864d8fSMatthias Springer extractSliceOp.getSource().getDefiningOp<InsertSliceOp>();
10026864d8fSMatthias Springer if (!insertSliceOp)
10126864d8fSMatthias Springer return failure();
10226864d8fSMatthias Springer llvm::SmallBitVector expandedDims = insertSliceOp.getDroppedDims();
10326864d8fSMatthias Springer
10426864d8fSMatthias Springer // TODO: This could be extended to support cases where the dropped dims are
10526864d8fSMatthias Springer // a subset of the expanded dims.
10626864d8fSMatthias Springer if (expandedDims != droppedDims)
10726864d8fSMatthias Springer return failure();
10826864d8fSMatthias Springer
10926864d8fSMatthias Springer // The tensor.insert_slice may not be redundant if it has multiple users.
11026864d8fSMatthias Springer if (!insertSliceOp->hasOneUse())
11126864d8fSMatthias Springer return failure();
11226864d8fSMatthias Springer
11326864d8fSMatthias Springer // Only consider tensor.insert_slice ops that are pure rank-reductions.
11426864d8fSMatthias Springer // I.e., no elements are taken from the destination.
11526864d8fSMatthias Springer if (!isCastLikeInsertSliceOp(insertSliceOp))
11626864d8fSMatthias Springer return failure();
11726864d8fSMatthias Springer
11826864d8fSMatthias Springer // Extract directly from the source.
11926864d8fSMatthias Springer OpBuilder::InsertionGuard g(rewriter);
12026864d8fSMatthias Springer rewriter.setInsertionPoint(extractSliceOp);
12126864d8fSMatthias Springer SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
12226864d8fSMatthias Springer for (int64_t i = 0, e = extractSliceOp.getSourceType().getRank(); i < e;
12326864d8fSMatthias Springer ++i) {
12426864d8fSMatthias Springer if (droppedDims.test(i))
12526864d8fSMatthias Springer continue;
12626864d8fSMatthias Springer newOffsets.push_back(extractSliceOp.getMixedOffsets()[i]);
12726864d8fSMatthias Springer newSizes.push_back(extractSliceOp.getMixedSizes()[i]);
12826864d8fSMatthias Springer newStrides.push_back(extractSliceOp.getMixedStrides()[i]);
12926864d8fSMatthias Springer }
13026864d8fSMatthias Springer rewriter.replaceOpWithNewOp<ExtractSliceOp>(
13126864d8fSMatthias Springer extractSliceOp, /*source=*/insertSliceOp.getSource(), newOffsets,
13226864d8fSMatthias Springer newSizes, newStrides);
13326864d8fSMatthias Springer rewriter.eraseOp(insertSliceOp);
13426864d8fSMatthias Springer return success();
13526864d8fSMatthias Springer }
13626864d8fSMatthias Springer };
137f566b079SJerry Wu
138f566b079SJerry Wu /// Drop redundant rank expansion of insert_slice that direclty follows
139f566b079SJerry Wu /// extract_slice.
140f566b079SJerry Wu ///
141f566b079SJerry Wu /// This can be done when the insert_slice op purely expands ranks (adds unit
142f566b079SJerry Wu /// dims) and the extrace_slice drops corresponding unit dims. For example:
143f566b079SJerry Wu ///
144f566b079SJerry Wu /// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
145f566b079SJerry Wu /// : tensor<2x8xf32> to tensor<8xf32>
146f566b079SJerry Wu /// %inserted_slice = tensor.insert_slice %extracted_slice
147f566b079SJerry Wu /// into %dest[0, 0] [1, 8] [1, 1]
148f566b079SJerry Wu /// : tensor<8xf32> into tensor<1x8xf32>
149f566b079SJerry Wu ///
150f566b079SJerry Wu /// can be folded into:
151f566b079SJerry Wu ///
152f566b079SJerry Wu /// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
153f566b079SJerry Wu /// : tensor<2x8xf32> to tensor<1x8xf32>
154f566b079SJerry Wu struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
155f566b079SJerry Wu : public OpRewritePattern<tensor::InsertSliceOp> {
156f566b079SJerry Wu using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
157f566b079SJerry Wu
matchAndRewrite__anonf469d7b20111::DropRedundantRankExpansionOnInsertSliceOfExtractSlice158f566b079SJerry Wu LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
159*706c1302SKazu Hirata PatternRewriter &rewriter) const override {
160f566b079SJerry Wu auto extractSliceOp =
161f566b079SJerry Wu insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
162f566b079SJerry Wu if (!extractSliceOp) {
163f566b079SJerry Wu return rewriter.notifyMatchFailure(insertSliceOp,
164f566b079SJerry Wu "source is not extract_slice");
165f566b079SJerry Wu }
166f566b079SJerry Wu
167f566b079SJerry Wu // Can't fold if the extract_slice op has other users.
168f566b079SJerry Wu if (!extractSliceOp->hasOneUse()) {
169f566b079SJerry Wu return rewriter.notifyMatchFailure(insertSliceOp,
170f566b079SJerry Wu "source has multi-uses");
171f566b079SJerry Wu }
172f566b079SJerry Wu
173f566b079SJerry Wu // Check if the insert_slice op purely expands ranks (add unit dims).
174f566b079SJerry Wu if (!isCastLikeInsertSliceOp(insertSliceOp)) {
175f566b079SJerry Wu return rewriter.notifyMatchFailure(insertSliceOp,
176f566b079SJerry Wu "insert_slice is not cast-like");
177f566b079SJerry Wu }
178f566b079SJerry Wu
179f566b079SJerry Wu llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
180f566b079SJerry Wu llvm::SmallBitVector insertDroppedDims = insertSliceOp.getDroppedDims();
181f566b079SJerry Wu // Can't fold if the insert_slice op expands to more dims.
182f566b079SJerry Wu if (extractDroppedDims.size() < insertDroppedDims.size()) {
183f566b079SJerry Wu return rewriter.notifyMatchFailure(insertSliceOp,
184f566b079SJerry Wu "insert_slice expands more dims");
185f566b079SJerry Wu }
186f566b079SJerry Wu
187f566b079SJerry Wu // Try to match the extract dropped dims to the insert dropped dims. This is
188f566b079SJerry Wu // done by scanning the dims of extract_slice and find the left-most one can
189f566b079SJerry Wu // match the dim of insert_slice. If a match is found, advance the dim of
190f566b079SJerry Wu // insert_slice to match the next one.
191f566b079SJerry Wu unsigned insertDimPos = 0;
192f566b079SJerry Wu for (unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
193f566b079SJerry Wu ++extractDimPos) {
194f566b079SJerry Wu // Matched all dims.
195f566b079SJerry Wu if (insertDimPos == insertDroppedDims.size())
196f566b079SJerry Wu break;
197f566b079SJerry Wu
198f566b079SJerry Wu bool isExtractDropped = extractDroppedDims[extractDimPos];
199f566b079SJerry Wu bool isInsertDropped = insertDroppedDims[insertDimPos];
200f566b079SJerry Wu // Match if both sides drop/keep the dim. Advance and match the next dim
201f566b079SJerry Wu // of insert_slice.
202f566b079SJerry Wu if (isExtractDropped == isInsertDropped) {
203f566b079SJerry Wu insertDimPos += 1;
204f566b079SJerry Wu } else if (!isExtractDropped && isInsertDropped) {
205f566b079SJerry Wu // Not enough extract dropped dims to match the insert dropped dims.
206f566b079SJerry Wu return rewriter.notifyMatchFailure(insertSliceOp,
207f566b079SJerry Wu "insert_slice drops more unit dims");
208f566b079SJerry Wu }
209f566b079SJerry Wu // If the dim is dropped by extract_slice and not by insert_slice, look
210f566b079SJerry Wu // the next dim of extract_slice to see if it can match the current dim of
211f566b079SJerry Wu // insert_slice.
212f566b079SJerry Wu }
213f566b079SJerry Wu // Can't match some insert dims.
214f566b079SJerry Wu if (insertDimPos != insertDroppedDims.size()) {
215f566b079SJerry Wu return rewriter.notifyMatchFailure(insertSliceOp,
216f566b079SJerry Wu "insert_slice has unmatched dims");
217f566b079SJerry Wu }
218f566b079SJerry Wu
219f566b079SJerry Wu rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
220f566b079SJerry Wu insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
221f566b079SJerry Wu extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
222f566b079SJerry Wu extractSliceOp.getMixedStrides());
223f566b079SJerry Wu rewriter.eraseOp(extractSliceOp);
224f566b079SJerry Wu
225f566b079SJerry Wu return success();
226f566b079SJerry Wu }
227f566b079SJerry Wu };
228bb4c53b7SLei Zhang } // namespace
229bb4c53b7SLei Zhang
populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet & patterns)230bb4c53b7SLei Zhang void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
231bb4c53b7SLei Zhang RewritePatternSet &patterns) {
2326176d6a9SMatthias Springer patterns.add<MergeConsecutiveExtractSlice,
2336176d6a9SMatthias Springer MergeConsecutiveInsertSlice<InsertSliceOp>,
2346176d6a9SMatthias Springer MergeConsecutiveInsertSlice<ParallelInsertSliceOp>>(
235bb4c53b7SLei Zhang patterns.getContext());
236bb4c53b7SLei Zhang }
23726864d8fSMatthias Springer
populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet & patterns)23826864d8fSMatthias Springer void mlir::tensor::populateDropRedundantInsertSliceRankExpansionPatterns(
23926864d8fSMatthias Springer RewritePatternSet &patterns) {
240f566b079SJerry Wu patterns.add<DropRedundantRankExpansionOnExtractSliceOfInsertSlice,
241f566b079SJerry Wu DropRedundantRankExpansionOnInsertSliceOfExtractSlice>(
242f566b079SJerry Wu patterns.getContext());
24326864d8fSMatthias Springer }
244