//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.shape_cast' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#define DEBUG_TYPE "vector-shape-cast-lowering"

using namespace mlir;
using namespace mlir::vector;

namespace {
/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively on the way to targeting llvm.matrix intrinsics.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
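/// For example (illustrative sketch only; exact constant spelling and SSA
/// names may differ):
/// ```
/// %0 = vector.shape_cast %src : vector<2x4xf32> to vector<8xf32>
/// ```
/// is rewritten to:
/// ```
/// %cst = arith.constant dense<0.000000e+00> : vector<8xf32>
/// %0 = vector.extract %src[0] : vector<4xf32> from vector<2x4xf32>
/// %1 = vector.insert_strided_slice %0, %cst
///        {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
/// %2 = vector.extract %src[1] : vector<4xf32> from vector<2x4xf32>
/// %3 = vector.insert_strided_slice %2, %1
///        {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32>
/// ```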
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
      desc = rewriter.create<vector::InsertStridedSliceOp>(
          loc, vec, desc,
          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 1-D into 2-D
/// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
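/// For example (illustrative sketch only):
/// ```
/// %0 = vector.shape_cast %src : vector<8xf32> to vector<2x4xf32>
/// ```
/// is rewritten to:
/// ```
/// %cst = arith.constant dense<0.000000e+00> : vector<2x4xf32>
/// %0 = vector.extract_strided_slice %src
///        {offsets = [0], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
/// %1 = vector.insert %0, %cst [0] : vector<4xf32> into vector<2x4xf32>
/// %2 = vector.extract_strided_slice %src
///        {offsets = [4], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
/// %3 = vector.insert %2, %1 [1] : vector<4xf32> into vector<2x4xf32>
/// ```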
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
          /*sizes=*/mostMinorVectorSize,
          /*strides=*/1);
      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

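/// Increments the n-D index `idx` within the shape of vector type `tp`,
/// starting at dimension `dimIdx` and carrying overflow towards the major
/// dimensions, i.e. a "row-major" counter. `initialStep` is the amount added
/// to dimension `dimIdx` (subvector-sized steps pass a step > 1).
/// For example, with tp = vector<2x3xf32> and dimIdx = 1, repeated calls take
/// idx through (0,0) -> (0,1) -> (0,2) -> (1,0) -> (1,1) -> (1,2).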
static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp,
                   int dimIdx, int initialStep = 1) {
  int step = initialStep;
  for (int d = dimIdx; d >= 0; d--) {
    idx[d] += step;
    if (idx[d] >= tp.getDimSize(d)) {
      idx[d] = 0;
      step = 1;
    } else {
      break;
    }
  }
}

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
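// For example (illustrative sketch only), lowering
//   %0 = vector.shape_cast %src : vector<1x2x2xf32> to vector<2x2xf32>
// produces, roughly, one scalar extract/insert pair per element:
//   %cst = arith.constant dense<0.000000e+00> : vector<2x2xf32>
//   %e0 = vector.extract %src[0, 0, 0] : f32 from vector<1x2x2xf32>
//   %r0 = vector.insert %e0, %cst [0, 0] : f32 into vector<2x2xf32>
//   %e1 = vector.extract %src[0, 0, 1] : f32 from vector<1x2x2xf32>
//   %r1 = vector.insert %e1, %r0 [0, 1] : f32 into vector<2x2xf32>
//   ... (and so on for the remaining elements)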
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    // Special case 2D / 1D lowerings with better implementations.
    // TODO: make this ND / 1D to allow generic ND -> 1D -> MD.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1D source or dest and
    // drop this potentially very expensive lowering.
    // Compute number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //    x[0,0,0] = y[0,0]
    //    x[0,0,1] = y[0,1]
    //    x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
    SmallVector<int64_t> srcIdx(srcRank);
    SmallVector<int64_t> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }

      Value extract;
      if (srcRank == 0) {
        // 0-D vector special case
        assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
        extract = rewriter.create<vector::ExtractElementOp>(
            loc, op.getSourceVectorType().getElementType(), op.getSource());
      } else {
        extract =
            rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
      }

      if (resRank == 0) {
        // 0-D vector special case
        assert(resIdx.empty() && "Unexpected indices for 0-D vector");
        result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
      } else {
        result =
            rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// A shape_cast lowering for scalable vectors with a single trailing scalable
/// dimension. This is similar to the general shape_cast lowering but makes use
/// of vector.scalable.insert and vector.scalable.extract to move elements a
/// subvector at a time.
///
/// E.g.:
/// ```
/// // Flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Flatten scalable vector
/// %c = arith.constant dense<0> : vector<[8]xi32>
/// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
/// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
/// ```
/// or:
/// ```
/// // Un-flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Un-flatten scalable vector
/// %c = arith.constant dense<0> : vector<2x1x[4]xi32>
/// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
/// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
/// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// ```
class ScalableShapeCastOpRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    auto srcRank = sourceVectorType.getRank();
    auto resRank = resultVectorType.getRank();

    // This can only lower shape_casts where both the source and result types
    // have a single trailing scalable dimension. This is because there is no
    // legal representation of other scalable types in LLVM (and likely won't be
    // soon). There are also (currently) no operations that can index or extract
    // from >= 2-D scalable vectors or scalable vectors of fixed vectors.
    if (!isTrailingDimScalable(sourceVectorType) ||
        !isTrailingDimScalable(resultVectorType)) {
      return failure();
    }

    // The sizes of the trailing dimensions of the source and result vectors,
    // the size of the subvector to move, and the number of elements in the
    // vectors. These are "min" sizes, i.e. the sizes when vscale == 1.
    auto minSourceTrailingSize = sourceVectorType.getShape().back();
    auto minResultTrailingSize = resultVectorType.getShape().back();
    auto minExtractionSize =
        std::min(minSourceTrailingSize, minResultTrailingSize);
    int64_t minNumElts = 1;
    for (auto size : sourceVectorType.getShape())
      minNumElts *= size;

    // The subvector type to move from the source to the result. Note that this
    // is a scalable vector. This rewrite will generate code in terms of the
    // "min" size (the vscale == 1 case), which scales to any vscale.
    auto extractionVectorType = VectorType::get(
        {minExtractionSize}, sourceVectorType.getElementType(), {true});

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));

    SmallVector<int64_t> srcIdx(srcRank);
    SmallVector<int64_t> resIdx(resRank);

    // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
    // once D150000 lands.
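    // The scalable (trailing-dim) source/result rows currently being read from
    // and written to. They are cached across loop iterations so that each row
    // is extracted from / inserted into the n-D vectors only once.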
    Value currentResultScalableVector;
    Value currentSourceScalableVector;
    for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
      // 1. Extract a scalable subvector from the source vector.
      if (!currentSourceScalableVector) {
        if (srcRank != 1) {
          currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
              loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
        } else {
          currentSourceScalableVector = op.getSource();
        }
      }
      Value sourceSubVector = currentSourceScalableVector;
      if (minExtractionSize < minSourceTrailingSize) {
        sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
            loc, extractionVectorType, sourceSubVector, srcIdx.back());
      }

      // 2. Insert the scalable subvector into the result vector.
      if (!currentResultScalableVector) {
        if (minExtractionSize == minResultTrailingSize) {
          currentResultScalableVector = sourceSubVector;
        } else if (resRank != 1) {
          currentResultScalableVector = rewriter.create<vector::ExtractOp>(
              loc, result, llvm::ArrayRef(resIdx).drop_back());
        } else {
          currentResultScalableVector = result;
        }
      }
      if (minExtractionSize < minResultTrailingSize) {
        currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
            loc, sourceSubVector, currentResultScalableVector, resIdx.back());
      }

      // 3. Update the source and result scalable vectors if needed.
      if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
          currentResultScalableVector != result) {
        // Finished row of result. Insert complete scalable vector into result
        // (n-D) vector.
        result = rewriter.create<vector::InsertOp>(
            loc, currentResultScalableVector, result,
            llvm::ArrayRef(resIdx).drop_back());
        currentResultScalableVector = {};
      }
      if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
        // Finished row of source.
        currentSourceScalableVector = {};
      }

      // 4. Increment the insert/extract indices, stepping by minExtractionSize
      // for the trailing dimensions.
      incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
      incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

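  /// Returns true if the only scalable dimension of `type` is the trailing
  /// (most minor) one, i.e. all leading dimensions are fixed-size.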
  static bool isTrailingDimScalable(VectorType type) {
    return type.getRank() >= 1 && type.getScalableDims().back() &&
           !llvm::is_contained(type.getScalableDims().drop_back(), true);
  }
};

} // namespace

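// Typical usage (an illustrative sketch, not prescribed by this file): collect
// the patterns into a RewritePatternSet and run a rewrite driver over the
// enclosing op, e.g.
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorShapeCastLoweringPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));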
void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ShapeCastOp2DDownCastRewritePattern,
               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
               ScalableShapeCastOpRewritePattern>(patterns.getContext(),
                                                  benefit);
}