1 //===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/MemRef/IR/MemRef.h"
11 #include "mlir/Dialect/Utils/IndexingUtils.h"
12 #include "mlir/Dialect/Vector/IR/VectorOps.h"
13 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
14 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 
17 using namespace mlir;
18 using namespace mlir::vector;
19 
20 // Helper that picks the proper sequence for inserting.
21 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
22                        Value into, int64_t offset) {
23   auto vectorType = into.getType().cast<VectorType>();
24   if (vectorType.getRank() > 1)
25     return rewriter.create<InsertOp>(loc, from, into, offset);
26   return rewriter.create<vector::InsertElementOp>(
27       loc, vectorType, from, into,
28       rewriter.create<arith::ConstantIndexOp>(loc, offset));
29 }
30 
31 // Helper that picks the proper sequence for extracting.
32 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
33                         int64_t offset) {
34   auto vectorType = vector.getType().cast<VectorType>();
35   if (vectorType.getRank() > 1)
36     return rewriter.create<ExtractOp>(loc, vector, offset);
37   return rewriter.create<vector::ExtractElementOp>(
38       loc, vectorType.getElementType(), vector,
39       rewriter.create<arith::ConstantIndexOp>(loc, offset));
40 }
41 
42 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
43 /// have different ranks.
44 ///
45 /// When ranks are different, InsertStridedSlice needs to extract a properly
46 /// ranked vector from the destination vector into which to insert. This pattern
47 /// only takes care of this extraction part and forwards the rest to
48 /// [ConvertSameRankInsertStridedSliceIntoShuffle].
49 ///
50 /// For a k-D source and n-D destination vector (k < n), we emit:
51 ///   1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
52 ///      insert the k-D source.
53 ///   2. k-D -> (n-1)-D InsertStridedSlice op
54 ///   3. InsertOp that is the reverse of 1.
55 class DecomposeDifferentRankInsertStridedSlice
56     : public OpRewritePattern<InsertStridedSliceOp> {
57 public:
58   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
59 
60   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
61                                 PatternRewriter &rewriter) const override {
62     auto srcType = op.getSourceVectorType();
63     auto dstType = op.getDestVectorType();
64 
65     if (op.getOffsets().getValue().empty())
66       return failure();
67 
68     auto loc = op.getLoc();
69     int64_t rankDiff = dstType.getRank() - srcType.getRank();
70     assert(rankDiff >= 0);
71     if (rankDiff == 0)
72       return failure();
73 
74     int64_t rankRest = dstType.getRank() - rankDiff;
75     // Extract / insert the subvector of matching rank and InsertStridedSlice
76     // on it.
77     Value extracted = rewriter.create<ExtractOp>(
78         loc, op.getDest(),
79         getI64SubArray(op.getOffsets(), /*dropFront=*/0,
80                        /*dropBack=*/rankRest));
81 
82     // A different pattern will kick in for InsertStridedSlice with matching
83     // ranks.
84     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
85         loc, op.getSource(), extracted,
86         getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
87         getI64SubArray(op.getStrides(), /*dropFront=*/0));
88 
89     rewriter.replaceOpWithNewOp<InsertOp>(
90         op, stridedSliceInnerOp.getResult(), op.getDest(),
91         getI64SubArray(op.getOffsets(), /*dropFront=*/0,
92                        /*dropBack=*/rankRest));
93     return success();
94   }
95 };
96 
97 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
98 /// have the same rank. For each outermost index in the slice:
99 ///   begin    end             stride
100 /// [offset : offset+size*stride : stride]
101 ///   1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
102 ///   2. InsertStridedSlice (k-1)-D into (n-1)-D
103 ///   3. the destination subvector is inserted back in the proper place
104 ///   3. InsertOp that is the reverse of 1.
105 class ConvertSameRankInsertStridedSliceIntoShuffle
106     : public OpRewritePattern<InsertStridedSliceOp> {
107 public:
108   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
109 
110   void initialize() {
111     // This pattern creates recursive InsertStridedSliceOp, but the recursion is
112     // bounded as the rank is strictly decreasing.
113     setHasBoundedRewriteRecursion();
114   }
115 
116   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
117                                 PatternRewriter &rewriter) const override {
118     auto srcType = op.getSourceVectorType();
119     auto dstType = op.getDestVectorType();
120 
121     if (op.getOffsets().getValue().empty())
122       return failure();
123 
124     int64_t srcRank = srcType.getRank();
125     int64_t dstRank = dstType.getRank();
126     assert(dstRank >= srcRank);
127     if (dstRank != srcRank)
128       return failure();
129 
130     if (srcType == dstType) {
131       rewriter.replaceOp(op, op.getSource());
132       return success();
133     }
134 
135     int64_t offset =
136         op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
137     int64_t size = srcType.getShape().front();
138     int64_t stride =
139         op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
140 
141     auto loc = op.getLoc();
142     Value res = op.getDest();
143 
144     if (srcRank == 1) {
145       int nSrc = srcType.getShape().front();
146       int nDest = dstType.getShape().front();
147       // 1. Scale source to destType so we can shufflevector them together.
148       SmallVector<int64_t> offsets(nDest, 0);
149       for (int64_t i = 0; i < nSrc; ++i)
150         offsets[i] = i;
151       Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(),
152                                                       op.getSource(), offsets);
153 
154       // 2. Create a mask where we take the value from scaledSource of dest
155       // depending on the offset.
156       offsets.clear();
157       for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
158         if (i < offset || i >= e || (i - offset) % stride != 0)
159           offsets.push_back(nDest + i);
160         else
161           offsets.push_back((i - offset) / stride);
162       }
163 
164       // 3. Replace with a ShuffleOp.
165       rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.getDest(),
166                                              offsets);
167 
168       return success();
169     }
170 
171     // For each slice of the source vector along the most major dimension.
172     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
173          off += stride, ++idx) {
174       // 1. extract the proper subvector (or element) from source
175       Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx);
176       if (extractedSource.getType().isa<VectorType>()) {
177         // 2. If we have a vector, extract the proper subvector from destination
178         // Otherwise we are at the element level and no need to recurse.
179         Value extractedDest = extractOne(rewriter, loc, op.getDest(), off);
180         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
181         // smaller rank.
182         extractedSource = rewriter.create<InsertStridedSliceOp>(
183             loc, extractedSource, extractedDest,
184             getI64SubArray(op.getOffsets(), /* dropFront=*/1),
185             getI64SubArray(op.getStrides(), /* dropFront=*/1));
186       }
187       // 4. Insert the extractedSource into the res vector.
188       res = insertOne(rewriter, loc, extractedSource, res, off);
189     }
190 
191     rewriter.replaceOp(op, res);
192     return success();
193   }
194 };
195 
196 /// RewritePattern for ExtractStridedSliceOp where source and destination
197 /// vectors are 1-D. For such cases, we can lower it to a ShuffleOp.
198 class Convert1DExtractStridedSliceIntoShuffle
199     : public OpRewritePattern<ExtractStridedSliceOp> {
200 public:
201   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
202 
203   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
204                                 PatternRewriter &rewriter) const override {
205     auto dstType = op.getType();
206 
207     assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
208 
209     int64_t offset =
210         op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
211     int64_t size =
212         op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
213     int64_t stride =
214         op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
215 
216     assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
217 
218     // Single offset can be more efficiently shuffled.
219     if (op.getOffsets().getValue().size() != 1)
220       return failure();
221 
222     SmallVector<int64_t, 4> offsets;
223     offsets.reserve(size);
224     for (int64_t off = offset, e = offset + size * stride; off < e;
225          off += stride)
226       offsets.push_back(off);
227     rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
228                                            op.getVector(),
229                                            rewriter.getI64ArrayAttr(offsets));
230     return success();
231   }
232 };
233 
234 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
235 /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
236 /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
237 class DecomposeNDExtractStridedSlice
238     : public OpRewritePattern<ExtractStridedSliceOp> {
239 public:
240   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
241 
242   void initialize() {
243     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
244     // is bounded as the rank is strictly decreasing.
245     setHasBoundedRewriteRecursion();
246   }
247 
248   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
249                                 PatternRewriter &rewriter) const override {
250     auto dstType = op.getType();
251 
252     assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
253 
254     int64_t offset =
255         op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
256     int64_t size =
257         op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
258     int64_t stride =
259         op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
260 
261     auto loc = op.getLoc();
262     auto elemType = dstType.getElementType();
263     assert(elemType.isSignlessIntOrIndexOrFloat());
264 
265     // Single offset can be more efficiently shuffled. It's handled in
266     // Convert1DExtractStridedSliceIntoShuffle.
267     if (op.getOffsets().getValue().size() == 1)
268       return failure();
269 
270     // Extract/insert on a lower ranked extract strided slice op.
271     Value zero = rewriter.create<arith::ConstantOp>(
272         loc, elemType, rewriter.getZeroAttr(elemType));
273     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
274     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
275          off += stride, ++idx) {
276       Value one = extractOne(rewriter, loc, op.getVector(), off);
277       Value extracted = rewriter.create<ExtractStridedSliceOp>(
278           loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
279           getI64SubArray(op.getSizes(), /* dropFront=*/1),
280           getI64SubArray(op.getStrides(), /* dropFront=*/1));
281       res = insertOne(rewriter, loc, extracted, res, idx);
282     }
283     rewriter.replaceOp(op, res);
284     return success();
285   }
286 };
287 
288 void mlir::vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
289     RewritePatternSet &patterns, PatternBenefit benefit) {
290   patterns.add<DecomposeDifferentRankInsertStridedSlice,
291                DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
292 }
293 
294 /// Populate the given list with patterns that convert from Vector to LLVM.
295 void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
296     RewritePatternSet &patterns, PatternBenefit benefit) {
297   populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns,
298                                                                benefit);
299   patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
300                Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(),
301                                                         benefit);
302 }
303