1 //===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/MemRef/IR/MemRef.h"
11 #include "mlir/Dialect/Utils/IndexingUtils.h"
12 #include "mlir/Dialect/Vector/IR/VectorOps.h"
13 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
14 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace mlir::vector;
20 
21 // Helper that picks the proper sequence for inserting.
22 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
23                        Value into, int64_t offset) {
24   auto vectorType = cast<VectorType>(into.getType());
25   if (vectorType.getRank() > 1)
26     return rewriter.create<InsertOp>(loc, from, into, offset);
27   return rewriter.create<vector::InsertElementOp>(
28       loc, vectorType, from, into,
29       rewriter.create<arith::ConstantIndexOp>(loc, offset));
30 }
31 
32 // Helper that picks the proper sequence for extracting.
33 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
34                         int64_t offset) {
35   auto vectorType = cast<VectorType>(vector.getType());
36   if (vectorType.getRank() > 1)
37     return rewriter.create<ExtractOp>(loc, vector, offset);
38   return rewriter.create<vector::ExtractElementOp>(
39       loc, vectorType.getElementType(), vector,
40       rewriter.create<arith::ConstantIndexOp>(loc, offset));
41 }
42 
43 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
44 /// have different ranks.
45 ///
46 /// When ranks are different, InsertStridedSlice needs to extract a properly
47 /// ranked vector from the destination vector into which to insert. This pattern
48 /// only takes care of this extraction part and forwards the rest to
49 /// [ConvertSameRankInsertStridedSliceIntoShuffle].
50 ///
51 /// For a k-D source and n-D destination vector (k < n), we emit:
52 ///   1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
53 ///      insert the k-D source.
54 ///   2. k-D -> (n-1)-D InsertStridedSlice op
55 ///   3. InsertOp that is the reverse of 1.
56 class DecomposeDifferentRankInsertStridedSlice
57     : public OpRewritePattern<InsertStridedSliceOp> {
58 public:
59   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
60 
61   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
62                                 PatternRewriter &rewriter) const override {
63     auto srcType = op.getSourceVectorType();
64     auto dstType = op.getDestVectorType();
65 
66     if (op.getOffsets().getValue().empty())
67       return failure();
68 
69     auto loc = op.getLoc();
70     int64_t rankDiff = dstType.getRank() - srcType.getRank();
71     assert(rankDiff >= 0);
72     if (rankDiff == 0)
73       return failure();
74 
75     int64_t rankRest = dstType.getRank() - rankDiff;
76     // Extract / insert the subvector of matching rank and InsertStridedSlice
77     // on it.
78     Value extracted = rewriter.create<ExtractOp>(
79         loc, op.getDest(),
80         getI64SubArray(op.getOffsets(), /*dropFront=*/0,
81                        /*dropBack=*/rankRest));
82 
83     // A different pattern will kick in for InsertStridedSlice with matching
84     // ranks.
85     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
86         loc, op.getSource(), extracted,
87         getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
88         getI64SubArray(op.getStrides(), /*dropFront=*/0));
89 
90     rewriter.replaceOpWithNewOp<InsertOp>(
91         op, stridedSliceInnerOp.getResult(), op.getDest(),
92         getI64SubArray(op.getOffsets(), /*dropFront=*/0,
93                        /*dropBack=*/rankRest));
94     return success();
95   }
96 };
97 
98 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
99 /// have the same rank. For each outermost index in the slice:
100 ///   begin    end             stride
101 /// [offset : offset+size*stride : stride]
102 ///   1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
103 ///   2. InsertStridedSlice (k-1)-D into (n-1)-D
104 ///   3. the destination subvector is inserted back in the proper place
105 ///   3. InsertOp that is the reverse of 1.
106 class ConvertSameRankInsertStridedSliceIntoShuffle
107     : public OpRewritePattern<InsertStridedSliceOp> {
108 public:
109   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
110 
111   void initialize() {
112     // This pattern creates recursive InsertStridedSliceOp, but the recursion is
113     // bounded as the rank is strictly decreasing.
114     setHasBoundedRewriteRecursion();
115   }
116 
117   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
118                                 PatternRewriter &rewriter) const override {
119     auto srcType = op.getSourceVectorType();
120     auto dstType = op.getDestVectorType();
121 
122     if (op.getOffsets().getValue().empty())
123       return failure();
124 
125     int64_t srcRank = srcType.getRank();
126     int64_t dstRank = dstType.getRank();
127     assert(dstRank >= srcRank);
128     if (dstRank != srcRank)
129       return failure();
130 
131     if (srcType == dstType) {
132       rewriter.replaceOp(op, op.getSource());
133       return success();
134     }
135 
136     int64_t offset =
137         cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
138     int64_t size = srcType.getShape().front();
139     int64_t stride =
140         cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
141 
142     auto loc = op.getLoc();
143     Value res = op.getDest();
144 
145     if (srcRank == 1) {
146       int nSrc = srcType.getShape().front();
147       int nDest = dstType.getShape().front();
148       // 1. Scale source to destType so we can shufflevector them together.
149       SmallVector<int64_t> offsets(nDest, 0);
150       for (int64_t i = 0; i < nSrc; ++i)
151         offsets[i] = i;
152       Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(),
153                                                       op.getSource(), offsets);
154 
155       // 2. Create a mask where we take the value from scaledSource of dest
156       // depending on the offset.
157       offsets.clear();
158       for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
159         if (i < offset || i >= e || (i - offset) % stride != 0)
160           offsets.push_back(nDest + i);
161         else
162           offsets.push_back((i - offset) / stride);
163       }
164 
165       // 3. Replace with a ShuffleOp.
166       rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.getDest(),
167                                              offsets);
168 
169       return success();
170     }
171 
172     // For each slice of the source vector along the most major dimension.
173     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
174          off += stride, ++idx) {
175       // 1. extract the proper subvector (or element) from source
176       Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx);
177       if (isa<VectorType>(extractedSource.getType())) {
178         // 2. If we have a vector, extract the proper subvector from destination
179         // Otherwise we are at the element level and no need to recurse.
180         Value extractedDest = extractOne(rewriter, loc, op.getDest(), off);
181         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
182         // smaller rank.
183         extractedSource = rewriter.create<InsertStridedSliceOp>(
184             loc, extractedSource, extractedDest,
185             getI64SubArray(op.getOffsets(), /* dropFront=*/1),
186             getI64SubArray(op.getStrides(), /* dropFront=*/1));
187       }
188       // 4. Insert the extractedSource into the res vector.
189       res = insertOne(rewriter, loc, extractedSource, res, off);
190     }
191 
192     rewriter.replaceOp(op, res);
193     return success();
194   }
195 };
196 
197 /// RewritePattern for ExtractStridedSliceOp where source and destination
198 /// vectors are 1-D. For such cases, we can lower it to a ShuffleOp.
199 class Convert1DExtractStridedSliceIntoShuffle
200     : public OpRewritePattern<ExtractStridedSliceOp> {
201 public:
202   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
203 
204   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
205                                 PatternRewriter &rewriter) const override {
206     auto dstType = op.getType();
207 
208     assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
209 
210     int64_t offset =
211         cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
212     int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
213     int64_t stride =
214         cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
215 
216     assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
217 
218     // Single offset can be more efficiently shuffled.
219     if (op.getOffsets().getValue().size() != 1)
220       return failure();
221 
222     SmallVector<int64_t, 4> offsets;
223     offsets.reserve(size);
224     for (int64_t off = offset, e = offset + size * stride; off < e;
225          off += stride)
226       offsets.push_back(off);
227     rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
228                                            op.getVector(), offsets);
229     return success();
230   }
231 };
232 
233 /// For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops
234 /// to extract each element from the source, and then a chain of Insert ops
235 /// to insert to the target vector.
236 class Convert1DExtractStridedSliceIntoExtractInsertChain final
237     : public OpRewritePattern<ExtractStridedSliceOp> {
238 public:
239   Convert1DExtractStridedSliceIntoExtractInsertChain(
240       MLIRContext *context,
241       std::function<bool(ExtractStridedSliceOp)> controlFn,
242       PatternBenefit benefit)
243       : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
244 
245   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
246                                 PatternRewriter &rewriter) const override {
247     if (controlFn && !controlFn(op))
248       return failure();
249 
250     // Only handle 1-D cases.
251     if (op.getOffsets().getValue().size() != 1)
252       return failure();
253 
254     int64_t offset =
255         cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
256     int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
257     int64_t stride =
258         cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
259 
260     Location loc = op.getLoc();
261     SmallVector<Value> elements;
262     elements.reserve(size);
263     for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
264       elements.push_back(rewriter.create<ExtractOp>(loc, op.getVector(), i));
265 
266     Value result = rewriter.create<arith::ConstantOp>(
267         loc, rewriter.getZeroAttr(op.getType()));
268     for (int64_t i = 0; i < size; ++i)
269       result = rewriter.create<InsertOp>(loc, elements[i], result, i);
270 
271     rewriter.replaceOp(op, result);
272     return success();
273   }
274 
275 private:
276   std::function<bool(ExtractStridedSliceOp)> controlFn;
277 };
278 
279 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
280 /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
281 /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
282 class DecomposeNDExtractStridedSlice
283     : public OpRewritePattern<ExtractStridedSliceOp> {
284 public:
285   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
286 
287   void initialize() {
288     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
289     // is bounded as the rank is strictly decreasing.
290     setHasBoundedRewriteRecursion();
291   }
292 
293   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
294                                 PatternRewriter &rewriter) const override {
295     auto dstType = op.getType();
296 
297     assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
298 
299     int64_t offset =
300         cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
301     int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
302     int64_t stride =
303         cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
304 
305     auto loc = op.getLoc();
306     auto elemType = dstType.getElementType();
307     assert(elemType.isSignlessIntOrIndexOrFloat());
308 
309     // Single offset can be more efficiently shuffled. It's handled in
310     // Convert1DExtractStridedSliceIntoShuffle.
311     if (op.getOffsets().getValue().size() == 1)
312       return failure();
313 
314     // Extract/insert on a lower ranked extract strided slice op.
315     Value zero = rewriter.create<arith::ConstantOp>(
316         loc, elemType, rewriter.getZeroAttr(elemType));
317     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
318     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
319          off += stride, ++idx) {
320       Value one = extractOne(rewriter, loc, op.getVector(), off);
321       Value extracted = rewriter.create<ExtractStridedSliceOp>(
322           loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
323           getI64SubArray(op.getSizes(), /* dropFront=*/1),
324           getI64SubArray(op.getStrides(), /* dropFront=*/1));
325       res = insertOne(rewriter, loc, extracted, res, idx);
326     }
327     rewriter.replaceOp(op, res);
328     return success();
329   }
330 };
331 
332 /// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
333 /// slice is contiguous, into extract and shape_cast.
334 class ContiguousExtractStridedSliceToExtract final
335     : public OpRewritePattern<ExtractStridedSliceOp> {
336 public:
337   using OpRewritePattern::OpRewritePattern;
338 
339   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
340                                 PatternRewriter &rewriter) const override {
341     if (op.hasNonUnitStrides()) {
342       return failure();
343     }
344     Value source = op.getOperand();
345     auto sourceType = cast<VectorType>(source.getType());
346     if (sourceType.isScalable()) {
347       return failure();
348     }
349 
350     // Compute the number of offsets to pass to ExtractOp::build. That is the
351     // difference between the source rank and the desired slice rank. We walk
352     // the dimensions from innermost out, and stop when the next slice dimension
353     // is not full-size.
354     SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
355     int numOffsets;
356     for (numOffsets = sourceType.getRank(); numOffsets > 0; --numOffsets) {
357       if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
358         break;
359       }
360     }
361 
362     // If not even the inner-most dimension is full-size, this op can't be
363     // rewritten as an ExtractOp.
364     if (numOffsets == sourceType.getRank()) {
365       return failure();
366     }
367 
368     // Avoid generating slices that have unit outer dimensions. The shape_cast
369     // op that we create below would take bad generic fallback patterns
370     // (ShapeCastOpRewritePattern).
371     while (sizes[numOffsets] == 1 && numOffsets < sourceType.getRank() - 1) {
372       ++numOffsets;
373     }
374 
375     SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
376     auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
377     Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
378                                                        extractOffsets);
379     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
380     return success();
381   }
382 };
383 
384 void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
385     RewritePatternSet &patterns, PatternBenefit benefit) {
386   patterns.add<DecomposeDifferentRankInsertStridedSlice,
387                DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
388 }
389 
390 void vector::populateVectorContiguousExtractStridedSliceToExtractPatterns(
391     RewritePatternSet &patterns, PatternBenefit benefit) {
392   patterns.add<ContiguousExtractStridedSliceToExtract>(patterns.getContext(),
393                                                        benefit);
394 }
395 
396 void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
397     RewritePatternSet &patterns,
398     std::function<bool(ExtractStridedSliceOp)> controlFn,
399     PatternBenefit benefit) {
400   patterns.add<Convert1DExtractStridedSliceIntoExtractInsertChain>(
401       patterns.getContext(), std::move(controlFn), benefit);
402 }
403 
404 /// Populate the given list with patterns that convert from Vector to LLVM.
405 void vector::populateVectorInsertExtractStridedSliceTransforms(
406     RewritePatternSet &patterns, PatternBenefit benefit) {
407   populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns,
408                                                                benefit);
409   patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
410                Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(),
411                                                         benefit);
412 }
413