xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp (revision 706c1302f99d79af21ddf22e23c53d33329f225a)
1 //===- MergeConsecutiveInsertExtractSlicePatterns.cpp ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
10 #include "mlir/Dialect/Tensor/IR/Tensor.h"
11 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
12 #include "mlir/Dialect/Tensor/Utils/Utils.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/OpDefinition.h"
15 #include "mlir/IR/PatternMatch.h"
16 
17 using namespace mlir;
18 using namespace mlir::tensor;
19 
20 namespace {
21 /// Merges consecutive tensor.extract_slice ops into one.
22 // TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
23 struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
24   using OpRewritePattern::OpRewritePattern;
25 
matchAndRewrite__anonf469d7b20111::MergeConsecutiveExtractSlice26   LogicalResult matchAndRewrite(ExtractSliceOp nextOp,
27                                 PatternRewriter &rewriter) const override {
28     auto prevOp = nextOp.getSource().getDefiningOp<ExtractSliceOp>();
29     if (!prevOp)
30       return failure();
31 
32     SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
33     if (failed(affine::mergeOffsetsSizesAndStrides(
34             rewriter, nextOp.getLoc(), prevOp, nextOp, prevOp.getDroppedDims(),
35             newOffsets, newSizes, newStrides)))
36       return failure();
37 
38     rewriter.replaceOpWithNewOp<ExtractSliceOp>(nextOp, nextOp.getType(),
39                                                 prevOp.getSource(), newOffsets,
40                                                 newSizes, newStrides);
41     return success();
42   }
43 };
44 
45 /// Merges consecutive tensor.insert_slice ops into one.
46 // TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
47 template <typename OpTy>
48 struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {
49   using OpRewritePattern<OpTy>::OpRewritePattern;
50 
matchAndRewrite__anonf469d7b20111::MergeConsecutiveInsertSlice51   LogicalResult matchAndRewrite(OpTy nextOp,
52                                 PatternRewriter &rewriter) const override {
53     auto prevOp = nextOp.getSource().template getDefiningOp<InsertSliceOp>();
54     if (!prevOp)
55       return failure();
56 
57     if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
58       return failure();
59 
60     // The first insert_slice op should be rank reducing to make sure we cover
61     // the full source tensor to be inserted in the second insert_slice op.
62     SliceVerificationResult result =
63         isRankReducedType(prevOp.getDestType(), prevOp.getSourceType());
64     if (result != SliceVerificationResult::Success)
65       return failure();
66 
67     // Dynamic dimensions can pass rank reducing check in the above, e.g,
68     // inserting <?xf32> into <1x?x1xf32>. For such cases we cannot be certain
69     // the dynamic size covers the full tensor.
70     if (!prevOp.getSourceType().hasStaticShape() ||
71         !prevOp.getDestType().hasStaticShape())
72       return failure();
73 
74     rewriter.replaceOpWithNewOp<OpTy>(
75         nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(),
76         nextOp.getMixedSizes(), nextOp.getMixedStrides());
77     return success();
78   }
79 };
80 
81 /// Drop redundant rank expansion of insert_slice that are directly followed
82 /// by extract_slice. E.g.:
83 /// %0 = tensor.insert_slice ... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
84 /// %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1]
85 ///     : tensor<1x1x5x10xf32> to tensor<2x2xf32>
86 struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
87     : public OpRewritePattern<ExtractSliceOp> {
88   using OpRewritePattern::OpRewritePattern;
89 
matchAndRewrite__anonf469d7b20111::DropRedundantRankExpansionOnExtractSliceOfInsertSlice90   LogicalResult matchAndRewrite(ExtractSliceOp extractSliceOp,
91                                 PatternRewriter &rewriter) const override {
92     // Nothing to do if no dims are dropped.
93     llvm::SmallBitVector droppedDims = extractSliceOp.getDroppedDims();
94     if (droppedDims.none())
95       return failure();
96 
97     // Look for tensor.insert_slice op that has an inverse rank expansion.
98     auto insertSliceOp =
99         extractSliceOp.getSource().getDefiningOp<InsertSliceOp>();
100     if (!insertSliceOp)
101       return failure();
102     llvm::SmallBitVector expandedDims = insertSliceOp.getDroppedDims();
103 
104     // TODO: This could be extended to support cases where the dropped dims are
105     // a subset of the expanded dims.
106     if (expandedDims != droppedDims)
107       return failure();
108 
109     // The tensor.insert_slice may not be redundant if it has multiple users.
110     if (!insertSliceOp->hasOneUse())
111       return failure();
112 
113     // Only consider tensor.insert_slice ops that are pure rank-reductions.
114     // I.e., no elements are taken from the destination.
115     if (!isCastLikeInsertSliceOp(insertSliceOp))
116       return failure();
117 
118     // Extract directly from the source.
119     OpBuilder::InsertionGuard g(rewriter);
120     rewriter.setInsertionPoint(extractSliceOp);
121     SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
122     for (int64_t i = 0, e = extractSliceOp.getSourceType().getRank(); i < e;
123          ++i) {
124       if (droppedDims.test(i))
125         continue;
126       newOffsets.push_back(extractSliceOp.getMixedOffsets()[i]);
127       newSizes.push_back(extractSliceOp.getMixedSizes()[i]);
128       newStrides.push_back(extractSliceOp.getMixedStrides()[i]);
129     }
130     rewriter.replaceOpWithNewOp<ExtractSliceOp>(
131         extractSliceOp, /*source=*/insertSliceOp.getSource(), newOffsets,
132         newSizes, newStrides);
133     rewriter.eraseOp(insertSliceOp);
134     return success();
135   }
136 };
137 
138 /// Drop redundant rank expansion of insert_slice that direclty follows
139 /// extract_slice.
140 ///
141 /// This can be done when the insert_slice op purely expands ranks (adds unit
142 /// dims) and the extrace_slice drops corresponding unit dims. For example:
143 ///
144 /// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
145 ///     : tensor<2x8xf32> to tensor<8xf32>
146 /// %inserted_slice = tensor.insert_slice %extracted_slice
147 ///     into %dest[0, 0] [1, 8] [1, 1]
148 ///     : tensor<8xf32> into tensor<1x8xf32>
149 ///
150 /// can be folded into:
151 ///
152 /// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
153 ///     : tensor<2x8xf32> to tensor<1x8xf32>
154 struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
155     : public OpRewritePattern<tensor::InsertSliceOp> {
156   using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
157 
matchAndRewrite__anonf469d7b20111::DropRedundantRankExpansionOnInsertSliceOfExtractSlice158   LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
159                                 PatternRewriter &rewriter) const override {
160     auto extractSliceOp =
161         insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
162     if (!extractSliceOp) {
163       return rewriter.notifyMatchFailure(insertSliceOp,
164                                          "source is not extract_slice");
165     }
166 
167     // Can't fold if the extract_slice op has other users.
168     if (!extractSliceOp->hasOneUse()) {
169       return rewriter.notifyMatchFailure(insertSliceOp,
170                                          "source has multi-uses");
171     }
172 
173     // Check if the insert_slice op purely expands ranks (add unit dims).
174     if (!isCastLikeInsertSliceOp(insertSliceOp)) {
175       return rewriter.notifyMatchFailure(insertSliceOp,
176                                          "insert_slice is not cast-like");
177     }
178 
179     llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
180     llvm::SmallBitVector insertDroppedDims = insertSliceOp.getDroppedDims();
181     // Can't fold if the insert_slice op expands to more dims.
182     if (extractDroppedDims.size() < insertDroppedDims.size()) {
183       return rewriter.notifyMatchFailure(insertSliceOp,
184                                          "insert_slice expands more dims");
185     }
186 
187     // Try to match the extract dropped dims to the insert dropped dims. This is
188     // done by scanning the dims of extract_slice and find the left-most one can
189     // match the dim of insert_slice. If a match is found, advance the dim of
190     // insert_slice to match the next one.
191     unsigned insertDimPos = 0;
192     for (unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
193          ++extractDimPos) {
194       // Matched all dims.
195       if (insertDimPos == insertDroppedDims.size())
196         break;
197 
198       bool isExtractDropped = extractDroppedDims[extractDimPos];
199       bool isInsertDropped = insertDroppedDims[insertDimPos];
200       // Match if both sides drop/keep the dim. Advance and match the next dim
201       // of insert_slice.
202       if (isExtractDropped == isInsertDropped) {
203         insertDimPos += 1;
204       } else if (!isExtractDropped && isInsertDropped) {
205         // Not enough extract dropped dims to match the insert dropped dims.
206         return rewriter.notifyMatchFailure(insertSliceOp,
207                                            "insert_slice drops more unit dims");
208       }
209       // If the dim is dropped by extract_slice and not by insert_slice, look
210       // the next dim of extract_slice to see if it can match the current dim of
211       // insert_slice.
212     }
213     // Can't match some insert dims.
214     if (insertDimPos != insertDroppedDims.size()) {
215       return rewriter.notifyMatchFailure(insertSliceOp,
216                                          "insert_slice has unmatched dims");
217     }
218 
219     rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
220         insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
221         extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
222         extractSliceOp.getMixedStrides());
223     rewriter.eraseOp(extractSliceOp);
224 
225     return success();
226   }
227 };
228 } // namespace
229 
populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet & patterns)230 void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
231     RewritePatternSet &patterns) {
232   patterns.add<MergeConsecutiveExtractSlice,
233                MergeConsecutiveInsertSlice<InsertSliceOp>,
234                MergeConsecutiveInsertSlice<ParallelInsertSliceOp>>(
235       patterns.getContext());
236 }
237 
populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet & patterns)238 void mlir::tensor::populateDropRedundantInsertSliceRankExpansionPatterns(
239     RewritePatternSet &patterns) {
240   patterns.add<DropRedundantRankExpansionOnExtractSliceOfInsertSlice,
241                DropRedundantRankExpansionOnInsertSliceOfExtractSlice>(
242       patterns.getContext());
243 }
244