xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp (revision a7a4c16c672bdd8e245af533a1f170522e26e42a)
1 //===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.shape_cast' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#include <numeric>
24 
25 #define DEBUG_TYPE "vector-shape-cast-lowering"
26 
27 using namespace mlir;
28 using namespace mlir::vector;
29 
30 /// Increments n-D `indices` by `step` starting from the innermost dimension.
31 static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
32                    int step = 1) {
33   for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
34     assert(indices[dim] < vecType.getDimSize(dim) &&
35            "Indices are out of bound");
36     indices[dim] += step;
37     if (indices[dim] < vecType.getDimSize(dim))
38       break;
39 
40     indices[dim] = 0;
41     step = 1;
42   }
43 }
44 
45 namespace {
46 /// ShapeOp n-D -> 1-D downcast serves the purpose of flattening N-D to 1-D
47 /// vectors progressively. This iterates over the n-1 major dimensions of the
48 /// n-D vector and performs rewrites into:
49 ///   vector.extract from n-D + vector.insert_strided_slice offset into 1-D
50 class ShapeCastOpNDDownCastRewritePattern
51     : public OpRewritePattern<vector::ShapeCastOp> {
52 public:
53   using OpRewritePattern::OpRewritePattern;
54 
55   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
56                                 PatternRewriter &rewriter) const override {
57     auto sourceVectorType = op.getSourceVectorType();
58     auto resultVectorType = op.getResultVectorType();
59     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
60       return failure();
61 
62     int64_t srcRank = sourceVectorType.getRank();
63     int64_t resRank = resultVectorType.getRank();
64     if (srcRank < 2 || resRank != 1)
65       return failure();
66 
67     // Compute the number of 1-D vector elements involved in the reshape.
68     int64_t numElts = 1;
69     for (int64_t dim = 0; dim < srcRank - 1; ++dim)
70       numElts *= sourceVectorType.getDimSize(dim);
71 
72     auto loc = op.getLoc();
73     SmallVector<int64_t> srcIdx(srcRank - 1, 0);
74     SmallVector<int64_t> resIdx(resRank, 0);
75     int64_t extractSize = sourceVectorType.getShape().back();
76     Value result = rewriter.create<arith::ConstantOp>(
77         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
78 
79     // Compute the indices of each 1-D vector element of the source extraction
80     // and destination slice insertion and generate such instructions.
81     for (int64_t i = 0; i < numElts; ++i) {
82       if (i != 0) {
83         incIdx(srcIdx, sourceVectorType, /*step=*/1);
84         incIdx(resIdx, resultVectorType, /*step=*/extractSize);
85       }
86 
87       Value extract =
88           rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
89       result = rewriter.create<vector::InsertStridedSliceOp>(
90           loc, extract, result,
91           /*offsets=*/resIdx, /*strides=*/1);
92     }
93 
94     rewriter.replaceOp(op, result);
95     return success();
96   }
97 };
98 
99 /// ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D
100 /// vectors progressively. This iterates over the n-1 major dimension of the n-D
101 /// vector and performs rewrites into:
102 ///   vector.extract_strided_slice from 1-D + vector.insert into n-D
103 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
104 class ShapeCastOpNDUpCastRewritePattern
105     : public OpRewritePattern<vector::ShapeCastOp> {
106 public:
107   using OpRewritePattern::OpRewritePattern;
108 
109   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
110                                 PatternRewriter &rewriter) const override {
111     auto sourceVectorType = op.getSourceVectorType();
112     auto resultVectorType = op.getResultVectorType();
113     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
114       return failure();
115 
116     int64_t srcRank = sourceVectorType.getRank();
117     int64_t resRank = resultVectorType.getRank();
118     if (srcRank != 1 || resRank < 2)
119       return failure();
120 
121     // Compute the number of 1-D vector elements involved in the reshape.
122     int64_t numElts = 1;
123     for (int64_t dim = 0; dim < resRank - 1; ++dim)
124       numElts *= resultVectorType.getDimSize(dim);
125 
126     // Compute the indices of each 1-D vector element of the source slice
127     // extraction and destination insertion and generate such instructions.
128     auto loc = op.getLoc();
129     SmallVector<int64_t> srcIdx(srcRank, 0);
130     SmallVector<int64_t> resIdx(resRank - 1, 0);
131     int64_t extractSize = resultVectorType.getShape().back();
132     Value result = rewriter.create<arith::ConstantOp>(
133         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
134     for (int64_t i = 0; i < numElts; ++i) {
135       if (i != 0) {
136         incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
137         incIdx(resIdx, resultVectorType, /*step=*/1);
138       }
139 
140       Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
141           loc, op.getSource(), /*offsets=*/srcIdx, /*sizes=*/extractSize,
142           /*strides=*/1);
143       result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
144     }
145     rewriter.replaceOp(op, result);
146     return success();
147   }
148 };
149 
150 // We typically should not lower general shape cast operations into data
151 // movement instructions, since the assumption is that these casts are
152 // optimized away during progressive lowering. For completeness, however,
153 // we fall back to a reference implementation that moves all elements
154 // into the right place if we get here.
155 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
156 public:
157   using OpRewritePattern::OpRewritePattern;
158 
159   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
160                                 PatternRewriter &rewriter) const override {
161     Location loc = op.getLoc();
162     auto sourceVectorType = op.getSourceVectorType();
163     auto resultVectorType = op.getResultVectorType();
164 
165     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
166       return failure();
167 
168     // Special case for n-D / 1-D lowerings with better implementations.
169     int64_t srcRank = sourceVectorType.getRank();
170     int64_t resRank = resultVectorType.getRank();
171     if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
172       return failure();
173 
174     // Generic ShapeCast lowering path goes all the way down to unrolled scalar
175     // extract/insert chains.
176     int64_t numElts = 1;
177     for (int64_t r = 0; r < srcRank; r++)
178       numElts *= sourceVectorType.getDimSize(r);
179     // Replace with data movement operations:
180     //    x[0,0,0] = y[0,0]
181     //    x[0,0,1] = y[0,1]
182     //    x[0,1,0] = y[0,2]
183     // etc., incrementing the two index vectors "row-major"
184     // within the source and result shape.
185     SmallVector<int64_t> srcIdx(srcRank, 0);
186     SmallVector<int64_t> resIdx(resRank, 0);
187     Value result = rewriter.create<arith::ConstantOp>(
188         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
189     for (int64_t i = 0; i < numElts; i++) {
190       if (i != 0) {
191         incIdx(srcIdx, sourceVectorType);
192         incIdx(resIdx, resultVectorType);
193       }
194 
195       Value extract;
196       if (srcRank == 0) {
197         // 0-D vector special case
198         assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
199         extract = rewriter.create<vector::ExtractElementOp>(
200             loc, op.getSourceVectorType().getElementType(), op.getSource());
201       } else {
202         extract =
203             rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
204       }
205 
206       if (resRank == 0) {
207         // 0-D vector special case
208         assert(resIdx.empty() && "Unexpected indices for 0-D vector");
209         result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
210       } else {
211         result =
212             rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
213       }
214     }
215     rewriter.replaceOp(op, result);
216     return success();
217   }
218 };
219 
220 /// A shape_cast lowering for scalable vectors with a single trailing scalable
221 /// dimension. This is similar to the general shape_cast lowering but makes use
222 /// of vector.scalable.insert and vector.scalable.extract to move elements a
223 /// subvector at a time.
224 ///
225 /// E.g.:
226 /// ```
227 /// // Flatten scalable vector
228 /// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
229 /// ```
230 /// is rewritten to:
231 /// ```
232 /// // Flatten scalable vector
233 /// %c = arith.constant dense<0> : vector<[8]xi32>
234 /// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
235 /// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
236 /// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
237 /// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
238 /// ```
239 /// or:
240 /// ```
241 /// // Un-flatten scalable vector
242 /// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
243 /// ```
244 /// is rewritten to:
245 /// ```
246 /// // Un-flatten scalable vector
247 /// %c = arith.constant dense<0> : vector<2x1x[4]xi32>
248 /// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
249 /// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
250 /// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
251 /// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
252 /// ```
253 class ScalableShapeCastOpRewritePattern
254     : public OpRewritePattern<vector::ShapeCastOp> {
255 public:
256   using OpRewritePattern::OpRewritePattern;
257 
258   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
259                                 PatternRewriter &rewriter) const override {
260 
261     Location loc = op.getLoc();
262     auto sourceVectorType = op.getSourceVectorType();
263     auto resultVectorType = op.getResultVectorType();
264     auto srcRank = sourceVectorType.getRank();
265     auto resRank = resultVectorType.getRank();
266 
267     // This can only lower shape_casts where both the source and result types
268     // have a single trailing scalable dimension. This is because there are no
269     // legal representation of other scalable types in LLVM (and likely won't be
270     // soon). There are also (currently) no operations that can index or extract
271     // from >= 2-D scalable vectors or scalable vectors of fixed vectors.
272     if (!isTrailingDimScalable(sourceVectorType) ||
273         !isTrailingDimScalable(resultVectorType)) {
274       return failure();
275     }
276 
277     // The sizes of the trailing dimension of the source and result vectors, the
278     // size of subvector to move, and the number of elements in the vectors.
279     // These are "min" sizes as they are the size when vscale == 1.
280     auto minSourceTrailingSize = sourceVectorType.getShape().back();
281     auto minResultTrailingSize = resultVectorType.getShape().back();
282     auto minExtractionSize =
283         std::min(minSourceTrailingSize, minResultTrailingSize);
284     int64_t minNumElts = 1;
285     for (auto size : sourceVectorType.getShape())
286       minNumElts *= size;
287 
288     // The subvector type to move from the source to the result. Note that this
289     // is a scalable vector. This rewrite will generate code in terms of the
290     // "min" size (vscale == 1 case), that scales to any vscale.
291     auto extractionVectorType = VectorType::get(
292         {minExtractionSize}, sourceVectorType.getElementType(), {true});
293 
294     Value result = rewriter.create<arith::ConstantOp>(
295         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
296 
297     SmallVector<int64_t> srcIdx(srcRank, 0);
298     SmallVector<int64_t> resIdx(resRank, 0);
299 
300     // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
301     // once D150000 lands.
302     Value currentResultScalableVector;
303     Value currentSourceScalableVector;
304     for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
305       // 1. Extract a scalable subvector from the source vector.
306       if (!currentSourceScalableVector) {
307         if (srcRank != 1) {
308           currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
309               loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
310         } else {
311           currentSourceScalableVector = op.getSource();
312         }
313       }
314       Value sourceSubVector = currentSourceScalableVector;
315       if (minExtractionSize < minSourceTrailingSize) {
316         sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
317             loc, extractionVectorType, sourceSubVector, srcIdx.back());
318       }
319 
320       // 2. Insert the scalable subvector into the result vector.
321       if (!currentResultScalableVector) {
322         if (minExtractionSize == minResultTrailingSize) {
323           currentResultScalableVector = sourceSubVector;
324         } else if (resRank != 1) {
325           currentResultScalableVector = rewriter.create<vector::ExtractOp>(
326               loc, result, llvm::ArrayRef(resIdx).drop_back());
327         } else {
328           currentResultScalableVector = result;
329         }
330       }
331       if (minExtractionSize < minResultTrailingSize) {
332         currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
333             loc, sourceSubVector, currentResultScalableVector, resIdx.back());
334       }
335 
336       // 3. Update the source and result scalable vectors if needed.
337       if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
338           currentResultScalableVector != result) {
339         // Finished row of result. Insert complete scalable vector into result
340         // (n-D) vector.
341         result = rewriter.create<vector::InsertOp>(
342             loc, currentResultScalableVector, result,
343             llvm::ArrayRef(resIdx).drop_back());
344         currentResultScalableVector = {};
345       }
346       if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
347         // Finished row of source.
348         currentSourceScalableVector = {};
349       }
350 
351       // 4. Increment the insert/extract indices, stepping by minExtractionSize
352       // for the trailing dimensions.
353       incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
354       incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
355     }
356 
357     rewriter.replaceOp(op, result);
358     return success();
359   }
360 
361   static bool isTrailingDimScalable(VectorType type) {
362     return type.getRank() >= 1 && type.getScalableDims().back() &&
363            !llvm::is_contained(type.getScalableDims().drop_back(), true);
364   }
365 };
366 
367 } // namespace
368 
369 void mlir::vector::populateVectorShapeCastLoweringPatterns(
370     RewritePatternSet &patterns, PatternBenefit benefit) {
371   patterns.add<ShapeCastOpNDDownCastRewritePattern,
372                ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern,
373                ScalableShapeCastOpRewritePattern>(patterns.getContext(),
374                                                   benefit);
375 }
376