1 //===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/MemRef/IR/MemRef.h"
11 #include "mlir/Dialect/Utils/IndexingUtils.h"
12 #include "mlir/Dialect/Vector/IR/VectorOps.h"
13 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
14 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace mlir::vector;
20 
21 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
22 /// have different ranks.
23 ///
24 /// When ranks are different, InsertStridedSlice needs to extract a properly
25 /// ranked vector from the destination vector into which to insert. This pattern
26 /// only takes care of this extraction part and forwards the rest to
27 /// [ConvertSameRankInsertStridedSliceIntoShuffle].
28 ///
29 /// For a k-D source and n-D destination vector (k < n), we emit:
30 ///   1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
31 ///      insert the k-D source.
32 ///   2. k-D -> (n-1)-D InsertStridedSlice op
33 ///   3. InsertOp that is the reverse of 1.
34 class DecomposeDifferentRankInsertStridedSlice
35     : public OpRewritePattern<InsertStridedSliceOp> {
36 public:
37   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
38 
39   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
40                                 PatternRewriter &rewriter) const override {
41     auto srcType = op.getSourceVectorType();
42     auto dstType = op.getDestVectorType();
43 
44     if (op.getOffsets().getValue().empty())
45       return failure();
46 
47     auto loc = op.getLoc();
48     int64_t rankDiff = dstType.getRank() - srcType.getRank();
49     assert(rankDiff >= 0);
50     if (rankDiff == 0)
51       return failure();
52 
53     int64_t rankRest = dstType.getRank() - rankDiff;
54     // Extract / insert the subvector of matching rank and InsertStridedSlice
55     // on it.
56     Value extracted = rewriter.create<ExtractOp>(
57         loc, op.getDest(),
58         getI64SubArray(op.getOffsets(), /*dropFront=*/0,
59                        /*dropBack=*/rankRest));
60 
61     // A different pattern will kick in for InsertStridedSlice with matching
62     // ranks.
63     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
64         loc, op.getSource(), extracted,
65         getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
66         getI64SubArray(op.getStrides(), /*dropFront=*/0));
67 
68     rewriter.replaceOpWithNewOp<InsertOp>(
69         op, stridedSliceInnerOp.getResult(), op.getDest(),
70         getI64SubArray(op.getOffsets(), /*dropFront=*/0,
71                        /*dropBack=*/rankRest));
72     return success();
73   }
74 };
75 
76 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
77 /// have the same rank. For each outermost index in the slice:
78 ///   begin    end             stride
79 /// [offset : offset+size*stride : stride]
80 ///   1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
81 ///   2. InsertStridedSlice (k-1)-D into (n-1)-D
82 ///   3. the destination subvector is inserted back in the proper place
83 ///   3. InsertOp that is the reverse of 1.
84 class ConvertSameRankInsertStridedSliceIntoShuffle
85     : public OpRewritePattern<InsertStridedSliceOp> {
86 public:
87   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
88 
89   void initialize() {
90     // This pattern creates recursive InsertStridedSliceOp, but the recursion is
91     // bounded as the rank is strictly decreasing.
92     setHasBoundedRewriteRecursion();
93   }
94 
95   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
96                                 PatternRewriter &rewriter) const override {
97     auto srcType = op.getSourceVectorType();
98     auto dstType = op.getDestVectorType();
99 
100     if (op.getOffsets().getValue().empty())
101       return failure();
102 
103     int64_t srcRank = srcType.getRank();
104     int64_t dstRank = dstType.getRank();
105     assert(dstRank >= srcRank);
106     if (dstRank != srcRank)
107       return failure();
108 
109     if (srcType == dstType) {
110       rewriter.replaceOp(op, op.getSource());
111       return success();
112     }
113 
114     int64_t offset =
115         cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
116     int64_t size = srcType.getShape().front();
117     int64_t stride =
118         cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
119 
120     auto loc = op.getLoc();
121     Value res = op.getDest();
122 
123     if (srcRank == 1) {
124       int nSrc = srcType.getShape().front();
125       int nDest = dstType.getShape().front();
126       // 1. Scale source to destType so we can shufflevector them together.
127       SmallVector<int64_t> offsets(nDest, 0);
128       for (int64_t i = 0; i < nSrc; ++i)
129         offsets[i] = i;
130       Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(),
131                                                       op.getSource(), offsets);
132 
133       // 2. Create a mask where we take the value from scaledSource of dest
134       // depending on the offset.
135       offsets.clear();
136       for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
137         if (i < offset || i >= e || (i - offset) % stride != 0)
138           offsets.push_back(nDest + i);
139         else
140           offsets.push_back((i - offset) / stride);
141       }
142 
143       // 3. Replace with a ShuffleOp.
144       rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.getDest(),
145                                              offsets);
146 
147       return success();
148     }
149 
150     // For each slice of the source vector along the most major dimension.
151     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
152          off += stride, ++idx) {
153       // 1. extract the proper subvector (or element) from source
154       Value extractedSource =
155           rewriter.create<ExtractOp>(loc, op.getSource(), idx);
156       if (isa<VectorType>(extractedSource.getType())) {
157         // 2. If we have a vector, extract the proper subvector from destination
158         // Otherwise we are at the element level and no need to recurse.
159         Value extractedDest =
160             rewriter.create<ExtractOp>(loc, op.getDest(), off);
161         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
162         // smaller rank.
163         extractedSource = rewriter.create<InsertStridedSliceOp>(
164             loc, extractedSource, extractedDest,
165             getI64SubArray(op.getOffsets(), /* dropFront=*/1),
166             getI64SubArray(op.getStrides(), /* dropFront=*/1));
167       }
168       // 4. Insert the extractedSource into the res vector.
169       res = rewriter.create<InsertOp>(loc, extractedSource, res, off);
170     }
171 
172     rewriter.replaceOp(op, res);
173     return success();
174   }
175 };
176 
177 /// RewritePattern for ExtractStridedSliceOp where source and destination
178 /// vectors are 1-D. For such cases, we can lower it to a ShuffleOp.
179 class Convert1DExtractStridedSliceIntoShuffle
180     : public OpRewritePattern<ExtractStridedSliceOp> {
181 public:
182   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
183 
184   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
185                                 PatternRewriter &rewriter) const override {
186     auto dstType = op.getType();
187 
188     assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
189 
190     int64_t offset =
191         cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
192     int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
193     int64_t stride =
194         cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
195 
196     assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
197 
198     // Single offset can be more efficiently shuffled.
199     if (op.getOffsets().getValue().size() != 1)
200       return failure();
201 
202     SmallVector<int64_t, 4> offsets;
203     offsets.reserve(size);
204     for (int64_t off = offset, e = offset + size * stride; off < e;
205          off += stride)
206       offsets.push_back(off);
207     rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
208                                            op.getVector(), offsets);
209     return success();
210   }
211 };
212 
213 /// For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops
214 /// to extract each element from the source, and then a chain of Insert ops
215 /// to insert to the target vector.
216 class Convert1DExtractStridedSliceIntoExtractInsertChain final
217     : public OpRewritePattern<ExtractStridedSliceOp> {
218 public:
219   Convert1DExtractStridedSliceIntoExtractInsertChain(
220       MLIRContext *context,
221       std::function<bool(ExtractStridedSliceOp)> controlFn,
222       PatternBenefit benefit)
223       : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
224 
225   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
226                                 PatternRewriter &rewriter) const override {
227     if (controlFn && !controlFn(op))
228       return failure();
229 
230     // Only handle 1-D cases.
231     if (op.getOffsets().getValue().size() != 1)
232       return failure();
233 
234     int64_t offset =
235         cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
236     int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
237     int64_t stride =
238         cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
239 
240     Location loc = op.getLoc();
241     SmallVector<Value> elements;
242     elements.reserve(size);
243     for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
244       elements.push_back(rewriter.create<ExtractOp>(loc, op.getVector(), i));
245 
246     Value result = rewriter.create<arith::ConstantOp>(
247         loc, rewriter.getZeroAttr(op.getType()));
248     for (int64_t i = 0; i < size; ++i)
249       result = rewriter.create<InsertOp>(loc, elements[i], result, i);
250 
251     rewriter.replaceOp(op, result);
252     return success();
253   }
254 
255 private:
256   std::function<bool(ExtractStridedSliceOp)> controlFn;
257 };
258 
259 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
260 /// For such cases, we can rewrite it to ExtractOp + lower rank
261 /// ExtractStridedSliceOp + InsertOp for the n-D case.
262 class DecomposeNDExtractStridedSlice
263     : public OpRewritePattern<ExtractStridedSliceOp> {
264 public:
265   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
266 
267   void initialize() {
268     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
269     // is bounded as the rank is strictly decreasing.
270     setHasBoundedRewriteRecursion();
271   }
272 
273   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
274                                 PatternRewriter &rewriter) const override {
275     auto dstType = op.getType();
276 
277     assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
278 
279     int64_t offset =
280         cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
281     int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
282     int64_t stride =
283         cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
284 
285     auto loc = op.getLoc();
286     auto elemType = dstType.getElementType();
287     assert(elemType.isSignlessIntOrIndexOrFloat());
288 
289     // Single offset can be more efficiently shuffled. It's handled in
290     // Convert1DExtractStridedSliceIntoShuffle.
291     if (op.getOffsets().getValue().size() == 1)
292       return failure();
293 
294     // Extract/insert on a lower ranked extract strided slice op.
295     Value zero = rewriter.create<arith::ConstantOp>(
296         loc, elemType, rewriter.getZeroAttr(elemType));
297     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
298     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
299          off += stride, ++idx) {
300       Value one = rewriter.create<ExtractOp>(loc, op.getVector(), off);
301       Value extracted = rewriter.create<ExtractStridedSliceOp>(
302           loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
303           getI64SubArray(op.getSizes(), /* dropFront=*/1),
304           getI64SubArray(op.getStrides(), /* dropFront=*/1));
305       res = rewriter.create<InsertOp>(loc, extracted, res, idx);
306     }
307     rewriter.replaceOp(op, res);
308     return success();
309   }
310 };
311 
312 void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
313     RewritePatternSet &patterns, PatternBenefit benefit) {
314   patterns.add<DecomposeDifferentRankInsertStridedSlice,
315                DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
316 }
317 
318 void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
319     RewritePatternSet &patterns,
320     std::function<bool(ExtractStridedSliceOp)> controlFn,
321     PatternBenefit benefit) {
322   patterns.add<Convert1DExtractStridedSliceIntoExtractInsertChain>(
323       patterns.getContext(), std::move(controlFn), benefit);
324 }
325 
326 /// Populate the given list with patterns that convert from Vector to LLVM.
327 void vector::populateVectorInsertExtractStridedSliceTransforms(
328     RewritePatternSet &patterns, PatternBenefit benefit) {
329   populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns,
330                                                                benefit);
331   patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
332                Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(),
333                                                         benefit);
334 }
335