1 //===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/MemRef/IR/MemRef.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Utils/IndexingUtils.h"
13 #include "mlir/Dialect/Vector/IR/VectorOps.h"
14 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
15 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 
18 using namespace mlir;
19 using namespace mlir::vector;
20 
21 // Helper that picks the proper sequence for inserting.
22 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
23                        Value into, int64_t offset) {
24   auto vectorType = into.getType().cast<VectorType>();
25   if (vectorType.getRank() > 1)
26     return rewriter.create<InsertOp>(loc, from, into, offset);
27   return rewriter.create<vector::InsertElementOp>(
28       loc, vectorType, from, into,
29       rewriter.create<arith::ConstantIndexOp>(loc, offset));
30 }
31 
32 // Helper that picks the proper sequence for extracting.
33 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
34                         int64_t offset) {
35   auto vectorType = vector.getType().cast<VectorType>();
36   if (vectorType.getRank() > 1)
37     return rewriter.create<ExtractOp>(loc, vector, offset);
38   return rewriter.create<vector::ExtractElementOp>(
39       loc, vectorType.getElementType(), vector,
40       rewriter.create<arith::ConstantIndexOp>(loc, offset));
41 }
42 
43 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
44 /// have different ranks.
45 ///
46 /// When ranks are different, InsertStridedSlice needs to extract a properly
47 /// ranked vector from the destination vector into which to insert. This pattern
48 /// only takes care of this extraction part and forwards the rest to
49 /// [VectorInsertStridedSliceOpSameRankRewritePattern].
50 ///
51 /// For a k-D source and n-D destination vector (k < n), we emit:
52 ///   1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
53 ///      insert the k-D source.
54 ///   2. k-D -> (n-1)-D InsertStridedSlice op
55 ///   3. InsertOp that is the reverse of 1.
56 class VectorInsertStridedSliceOpDifferentRankRewritePattern
57     : public OpRewritePattern<InsertStridedSliceOp> {
58 public:
59   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
60 
61   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
62                                 PatternRewriter &rewriter) const override {
63     auto srcType = op.getSourceVectorType();
64     auto dstType = op.getDestVectorType();
65 
66     if (op.offsets().getValue().empty())
67       return failure();
68 
69     auto loc = op.getLoc();
70     int64_t rankDiff = dstType.getRank() - srcType.getRank();
71     assert(rankDiff >= 0);
72     if (rankDiff == 0)
73       return failure();
74 
75     int64_t rankRest = dstType.getRank() - rankDiff;
76     // Extract / insert the subvector of matching rank and InsertStridedSlice
77     // on it.
78     Value extracted =
79         rewriter.create<ExtractOp>(loc, op.dest(),
80                                    getI64SubArray(op.offsets(), /*dropFront=*/0,
81                                                   /*dropBack=*/rankRest));
82 
83     // A different pattern will kick in for InsertStridedSlice with matching
84     // ranks.
85     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
86         loc, op.source(), extracted,
87         getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
88         getI64SubArray(op.strides(), /*dropFront=*/0));
89 
90     rewriter.replaceOpWithNewOp<InsertOp>(
91         op, stridedSliceInnerOp.getResult(), op.dest(),
92         getI64SubArray(op.offsets(), /*dropFront=*/0,
93                        /*dropBack=*/rankRest));
94     return success();
95   }
96 };
97 
98 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
99 /// have the same rank. For each outermost index in the slice:
100 ///   begin    end             stride
101 /// [offset : offset+size*stride : stride]
102 ///   1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
103 ///   2. InsertStridedSlice (k-1)-D into (n-1)-D
104 ///   3. the destination subvector is inserted back in the proper place
105 ///   3. InsertOp that is the reverse of 1.
106 class VectorInsertStridedSliceOpSameRankRewritePattern
107     : public OpRewritePattern<InsertStridedSliceOp> {
108 public:
109   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
110 
111   void initialize() {
112     // This pattern creates recursive InsertStridedSliceOp, but the recursion is
113     // bounded as the rank is strictly decreasing.
114     setHasBoundedRewriteRecursion();
115   }
116 
117   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
118                                 PatternRewriter &rewriter) const override {
119     auto srcType = op.getSourceVectorType();
120     auto dstType = op.getDestVectorType();
121 
122     if (op.offsets().getValue().empty())
123       return failure();
124 
125     int64_t srcRank = srcType.getRank();
126     int64_t dstRank = dstType.getRank();
127     assert(dstRank >= srcRank);
128     if (dstRank != srcRank)
129       return failure();
130 
131     if (srcType == dstType) {
132       rewriter.replaceOp(op, op.source());
133       return success();
134     }
135 
136     int64_t offset =
137         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
138     int64_t size = srcType.getShape().front();
139     int64_t stride =
140         op.strides().getValue().front().cast<IntegerAttr>().getInt();
141 
142     auto loc = op.getLoc();
143     Value res = op.dest();
144 
145     if (srcRank == 1) {
146       int nSrc = srcType.getShape().front();
147       int nDest = dstType.getShape().front();
148       // 1. Scale source to destType so we can shufflevector them together.
149       SmallVector<int64_t> offsets(nDest, 0);
150       for (int64_t i = 0; i < nSrc; ++i)
151         offsets[i] = i;
152       Value scaledSource =
153           rewriter.create<ShuffleOp>(loc, op.source(), op.source(), offsets);
154 
155       // 2. Create a mask where we take the value from scaledSource of dest
156       // depending on the offset.
157       offsets.clear();
158       for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
159         if (i < offset || i >= e || (i - offset) % stride != 0)
160           offsets.push_back(nDest + i);
161         else
162           offsets.push_back((i - offset) / stride);
163       }
164 
165       // 3. Replace with a ShuffleOp.
166       rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.dest(),
167                                              offsets);
168 
169       return success();
170     }
171 
172     // For each slice of the source vector along the most major dimension.
173     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
174          off += stride, ++idx) {
175       // 1. extract the proper subvector (or element) from source
176       Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
177       if (extractedSource.getType().isa<VectorType>()) {
178         // 2. If we have a vector, extract the proper subvector from destination
179         // Otherwise we are at the element level and no need to recurse.
180         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
181         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
182         // smaller rank.
183         extractedSource = rewriter.create<InsertStridedSliceOp>(
184             loc, extractedSource, extractedDest,
185             getI64SubArray(op.offsets(), /* dropFront=*/1),
186             getI64SubArray(op.strides(), /* dropFront=*/1));
187       }
188       // 4. Insert the extractedSource into the res vector.
189       res = insertOne(rewriter, loc, extractedSource, res, off);
190     }
191 
192     rewriter.replaceOp(op, res);
193     return success();
194   }
195 };
196 
197 /// Progressive lowering of ExtractStridedSliceOp to either:
198 ///   1. single offset extract as a direct vector::ShuffleOp.
199 ///   2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp +
200 ///      InsertOp/InsertElementOp for the n-D case.
201 class VectorExtractStridedSliceOpRewritePattern
202     : public OpRewritePattern<ExtractStridedSliceOp> {
203 public:
204   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
205 
206   void initialize() {
207     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
208     // is bounded as the rank is strictly decreasing.
209     setHasBoundedRewriteRecursion();
210   }
211 
212   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
213                                 PatternRewriter &rewriter) const override {
214     auto dstType = op.getType();
215 
216     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
217 
218     int64_t offset =
219         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
220     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
221     int64_t stride =
222         op.strides().getValue().front().cast<IntegerAttr>().getInt();
223 
224     auto loc = op.getLoc();
225     auto elemType = dstType.getElementType();
226     assert(elemType.isSignlessIntOrIndexOrFloat());
227 
228     // Single offset can be more efficiently shuffled.
229     if (op.offsets().getValue().size() == 1) {
230       SmallVector<int64_t, 4> offsets;
231       offsets.reserve(size);
232       for (int64_t off = offset, e = offset + size * stride; off < e;
233            off += stride)
234         offsets.push_back(off);
235       rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
236                                              op.vector(),
237                                              rewriter.getI64ArrayAttr(offsets));
238       return success();
239     }
240 
241     // Extract/insert on a lower ranked extract strided slice op.
242     Value zero = rewriter.create<arith::ConstantOp>(
243         loc, elemType, rewriter.getZeroAttr(elemType));
244     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
245     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
246          off += stride, ++idx) {
247       Value one = extractOne(rewriter, loc, op.vector(), off);
248       Value extracted = rewriter.create<ExtractStridedSliceOp>(
249           loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
250           getI64SubArray(op.sizes(), /* dropFront=*/1),
251           getI64SubArray(op.strides(), /* dropFront=*/1));
252       res = insertOne(rewriter, loc, extracted, res, idx);
253     }
254     rewriter.replaceOp(op, res);
255     return success();
256   }
257 };
258 
259 /// Populate the given list with patterns that convert from Vector to LLVM.
260 void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
261     RewritePatternSet &patterns) {
262   patterns.add<VectorInsertStridedSliceOpDifferentRankRewritePattern,
263                VectorInsertStridedSliceOpSameRankRewritePattern,
264                VectorExtractStridedSliceOpRewritePattern>(
265       patterns.getContext());
266 }
267