xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.transpose' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Arith/Utils/Utils.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23 #include "mlir/Dialect/Vector/IR/VectorOps.h"
24 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
25 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
26 #include "mlir/IR/BuiltinAttributeInterfaces.h"
27 #include "mlir/IR/BuiltinTypes.h"
28 #include "mlir/IR/ImplicitLocOpBuilder.h"
29 #include "mlir/IR/Location.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
33 #include "mlir/Interfaces/VectorInterfaces.h"
34 
35 #define DEBUG_TYPE "lower-vector-transpose"
36 
37 using namespace mlir;
38 using namespace mlir::vector;
39 
40 /// Given a 'transpose' pattern, prune the rightmost dimensions that are not
41 /// transposed.
pruneNonTransposedDims(ArrayRef<int64_t> transpose,SmallVectorImpl<int64_t> & result)42 static void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
43                                    SmallVectorImpl<int64_t> &result) {
44   size_t numTransposedDims = transpose.size();
45   for (size_t transpDim : llvm::reverse(transpose)) {
46     if (transpDim != numTransposedDims - 1)
47       break;
48     numTransposedDims--;
49   }
50 
51   result.append(transpose.begin(), transpose.begin() + numTransposedDims);
52 }
53 
54 /// Returns true if the lowering option is a vector shuffle based approach.
isShuffleLike(VectorTransposeLowering lowering)55 static bool isShuffleLike(VectorTransposeLowering lowering) {
56   return lowering == VectorTransposeLowering::Shuffle1D ||
57          lowering == VectorTransposeLowering::Shuffle16x16;
58 }
59 
60 /// Returns a shuffle mask that builds on `vals`. `vals` is the offset base of
61 /// shuffle ops, i.e., the unpack pattern. The method iterates with `vals` to
62 /// create the mask for `numBits` bits vector. The `numBits` have to be a
63 /// multiple of 128. For example, if `vals` is {0, 1, 16, 17} and `numBits` is
64 /// 512, there should be 16 elements in the final result. It constructs the
65 /// below mask to get the unpack elements.
66 ///   [0,    1,    16,    17,
67 ///    0+4,  1+4,  16+4,  17+4,
68 ///    0+8,  1+8,  16+8,  17+8,
69 ///    0+12, 1+12, 16+12, 17+12]
70 static SmallVector<int64_t>
getUnpackShufflePermFor128Lane(ArrayRef<int64_t> vals,int numBits)71 getUnpackShufflePermFor128Lane(ArrayRef<int64_t> vals, int numBits) {
72   assert(numBits % 128 == 0 && "expected numBits is a multiple of 128");
73   int numElem = numBits / 32;
74   SmallVector<int64_t> res;
75   for (int i = 0; i < numElem; i += 4)
76     for (int64_t v : vals)
77       res.push_back(v + i);
78   return res;
79 }
80 
81 /// Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask. For
82 /// example, if it is targeting 512 bit vector, returns
83 ///   vector.shuffle on v1, v2, [0,    1,    16,    17,
84 ///                              0+4,  1+4,  16+4,  17+4,
85 ///                              0+8,  1+8,  16+8,  17+8,
86 ///                              0+12, 1+12, 16+12, 17+12].
createUnpackLoPd(ImplicitLocOpBuilder & b,Value v1,Value v2,int numBits)87 static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
88                               int numBits) {
89   int numElem = numBits / 32;
90   return b.create<vector::ShuffleOp>(
91       v1, v2,
92       getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits));
93 }
94 
95 /// Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask. For
96 /// example, if it is targeting 512 bit vector, returns
97 ///   vector.shuffle, v1, v2, [2,    3,    18,    19,
98 ///                            2+4,  3+4,  18+4,  19+4,
99 ///                            2+8,  3+8,  18+8,  19+8,
100 ///                            2+12, 3+12, 18+12, 19+12].
createUnpackHiPd(ImplicitLocOpBuilder & b,Value v1,Value v2,int numBits)101 static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
102                               int numBits) {
103   int numElem = numBits / 32;
104   return b.create<vector::ShuffleOp>(
105       v1, v2,
106       getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3},
107                                      numBits));
108 }
109 
110 /// Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask. For
111 /// example, if it is targeting 512 bit vector, returns
112 ///   vector.shuffle, v1, v2, [0,    16,    1,    17,
113 ///                            0+4,  16+4,  1+4,  17+4,
114 ///                            0+8,  16+8,  1+8,  17+8,
115 ///                            0+12, 16+12, 1+12, 17+12].
createUnpackLoPs(ImplicitLocOpBuilder & b,Value v1,Value v2,int numBits)116 static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
117                               int numBits) {
118   int numElem = numBits / 32;
119   auto shuffle = b.create<vector::ShuffleOp>(
120       v1, v2,
121       getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits));
122   return shuffle;
123 }
124 
125 /// Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask. For
126 /// example, if it is targeting 512 bit vector, returns
127 ///   vector.shuffle, v1, v2, [2,    18,    3,    19,
128 ///                            2+4,  18+4,  3+4,  19+4,
129 ///                            2+8,  18+8,  3+8,  19+8,
130 ///                            2+12, 18+12, 3+12, 19+12].
createUnpackHiPs(ImplicitLocOpBuilder & b,Value v1,Value v2,int numBits)131 static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
132                               int numBits) {
133   int numElem = numBits / 32;
134   return b.create<vector::ShuffleOp>(
135       v1, v2,
136       getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3},
137                                      numBits));
138 }
139 
140 /// Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit
141 /// elements) selected by `mask` from `v1` and `v2`. I.e.,
142 ///
143 /// DEFINE SELECT4(src, control) {
144 ///	CASE(control[1:0]) OF
145 ///	0:	tmp[127:0] := src[127:0]
146 ///	1:	tmp[127:0] := src[255:128]
147 ///	2:	tmp[127:0] := src[383:256]
148 ///	3:	tmp[127:0] := src[511:384]
149 ///	ESAC
150 ///	RETURN tmp[127:0]
151 /// }
152 /// dst[127:0]   := SELECT4(v1[511:0], mask[1:0])
153 /// dst[255:128] := SELECT4(v1[511:0], mask[3:2])
154 /// dst[383:256] := SELECT4(v2[511:0], mask[5:4])
155 /// dst[511:384] := SELECT4(v2[511:0], mask[7:6])
create4x128BitSuffle(ImplicitLocOpBuilder & b,Value v1,Value v2,uint8_t mask)156 static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2,
157                                   uint8_t mask) {
158   assert(cast<VectorType>(v1.getType()).getShape()[0] == 16 &&
159          "expected a vector with length=16");
160   SmallVector<int64_t> shuffleMask;
161   auto appendToMask = [&](int64_t base, uint8_t control) {
162     switch (control) {
163     case 0:
164       llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 0, base + 1,
165                                                         base + 2, base + 3});
166       break;
167     case 1:
168       llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 4, base + 5,
169                                                         base + 6, base + 7});
170       break;
171     case 2:
172       llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 8, base + 9,
173                                                         base + 10, base + 11});
174       break;
175     case 3:
176       llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 12, base + 13,
177                                                         base + 14, base + 15});
178       break;
179     default:
180       llvm_unreachable("control > 3 : overflow");
181     }
182   };
183   uint8_t b01 = mask & 0x3;
184   uint8_t b23 = (mask >> 2) & 0x3;
185   uint8_t b45 = (mask >> 4) & 0x3;
186   uint8_t b67 = (mask >> 6) & 0x3;
187   appendToMask(0, b01);
188   appendToMask(0, b23);
189   appendToMask(16, b45);
190   appendToMask(16, b67);
191   return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
192 }
193 
194 /// Lowers the value to a vector.shuffle op. The `source` is expected to be a
195 /// 1-D vector and have `m`x`n` elements.
transposeToShuffle1D(OpBuilder & b,Value source,int m,int n)196 static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) {
197   SmallVector<int64_t> mask;
198   mask.reserve(m * n);
199   for (int64_t j = 0; j < n; ++j)
200     for (int64_t i = 0; i < m; ++i)
201       mask.push_back(i * n + j);
202   return b.create<vector::ShuffleOp>(source.getLoc(), source, source, mask);
203 }
204 
/// Lowers the value to a sequence of vector.shuffle ops. The `source` is
/// expected to be a 16x16 vector. The shuffle network below mirrors the
/// classic AVX512 16x16 32-bit transpose built from unpack and 128-bit-lane
/// shuffle intrinsics (see the per-stage comments); each stage doubles the
/// granularity of interleaving until whole rows are transposed.
static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
                                     int n) {
  ImplicitLocOpBuilder b(source.getLoc(), builder);
  // Extract the 16 rows of the source as individual 1-D vectors.
  SmallVector<Value> vs;
  for (int64_t i = 0; i < m; ++i)
    vs.push_back(b.create<vector::ExtractOp>(source, i));

  // Interleave 32-bit lanes using
  //   8x _mm512_unpacklo_epi32
  //   8x _mm512_unpackhi_epi32
  Value t0 = createUnpackLoPs(b, vs[0x0], vs[0x1], 512);
  Value t1 = createUnpackHiPs(b, vs[0x0], vs[0x1], 512);
  Value t2 = createUnpackLoPs(b, vs[0x2], vs[0x3], 512);
  Value t3 = createUnpackHiPs(b, vs[0x2], vs[0x3], 512);
  Value t4 = createUnpackLoPs(b, vs[0x4], vs[0x5], 512);
  Value t5 = createUnpackHiPs(b, vs[0x4], vs[0x5], 512);
  Value t6 = createUnpackLoPs(b, vs[0x6], vs[0x7], 512);
  Value t7 = createUnpackHiPs(b, vs[0x6], vs[0x7], 512);
  Value t8 = createUnpackLoPs(b, vs[0x8], vs[0x9], 512);
  Value t9 = createUnpackHiPs(b, vs[0x8], vs[0x9], 512);
  Value ta = createUnpackLoPs(b, vs[0xa], vs[0xb], 512);
  Value tb = createUnpackHiPs(b, vs[0xa], vs[0xb], 512);
  Value tc = createUnpackLoPs(b, vs[0xc], vs[0xd], 512);
  Value td = createUnpackHiPs(b, vs[0xc], vs[0xd], 512);
  Value te = createUnpackLoPs(b, vs[0xe], vs[0xf], 512);
  Value tf = createUnpackHiPs(b, vs[0xe], vs[0xf], 512);

  // Interleave 64-bit lanes using
  //   8x _mm512_unpacklo_epi64
  //   8x _mm512_unpackhi_epi64
  Value r0 = createUnpackLoPd(b, t0, t2, 512);
  Value r1 = createUnpackHiPd(b, t0, t2, 512);
  Value r2 = createUnpackLoPd(b, t1, t3, 512);
  Value r3 = createUnpackHiPd(b, t1, t3, 512);
  Value r4 = createUnpackLoPd(b, t4, t6, 512);
  Value r5 = createUnpackHiPd(b, t4, t6, 512);
  Value r6 = createUnpackLoPd(b, t5, t7, 512);
  Value r7 = createUnpackHiPd(b, t5, t7, 512);
  Value r8 = createUnpackLoPd(b, t8, ta, 512);
  Value r9 = createUnpackHiPd(b, t8, ta, 512);
  Value ra = createUnpackLoPd(b, t9, tb, 512);
  Value rb = createUnpackHiPd(b, t9, tb, 512);
  Value rc = createUnpackLoPd(b, tc, te, 512);
  Value rd = createUnpackHiPd(b, tc, te, 512);
  Value re = createUnpackLoPd(b, td, tf, 512);
  Value rf = createUnpackHiPd(b, td, tf, 512);

  // Permute 128-bit lanes using
  //   16x _mm512_shuffle_i32x4
  // Mask 0x88 selects lanes {0, 2} of both inputs; 0xdd selects lanes {1, 3}.
  t0 = create4x128BitSuffle(b, r0, r4, 0x88);
  t1 = create4x128BitSuffle(b, r1, r5, 0x88);
  t2 = create4x128BitSuffle(b, r2, r6, 0x88);
  t3 = create4x128BitSuffle(b, r3, r7, 0x88);
  t4 = create4x128BitSuffle(b, r0, r4, 0xdd);
  t5 = create4x128BitSuffle(b, r1, r5, 0xdd);
  t6 = create4x128BitSuffle(b, r2, r6, 0xdd);
  t7 = create4x128BitSuffle(b, r3, r7, 0xdd);
  t8 = create4x128BitSuffle(b, r8, rc, 0x88);
  t9 = create4x128BitSuffle(b, r9, rd, 0x88);
  ta = create4x128BitSuffle(b, ra, re, 0x88);
  tb = create4x128BitSuffle(b, rb, rf, 0x88);
  tc = create4x128BitSuffle(b, r8, rc, 0xdd);
  td = create4x128BitSuffle(b, r9, rd, 0xdd);
  te = create4x128BitSuffle(b, ra, re, 0xdd);
  tf = create4x128BitSuffle(b, rb, rf, 0xdd);

  // Permute 256-bit lanes using again
  //   16x _mm512_shuffle_i32x4
  vs[0x0] = create4x128BitSuffle(b, t0, t8, 0x88);
  vs[0x1] = create4x128BitSuffle(b, t1, t9, 0x88);
  vs[0x2] = create4x128BitSuffle(b, t2, ta, 0x88);
  vs[0x3] = create4x128BitSuffle(b, t3, tb, 0x88);
  vs[0x4] = create4x128BitSuffle(b, t4, tc, 0x88);
  vs[0x5] = create4x128BitSuffle(b, t5, td, 0x88);
  vs[0x6] = create4x128BitSuffle(b, t6, te, 0x88);
  vs[0x7] = create4x128BitSuffle(b, t7, tf, 0x88);
  vs[0x8] = create4x128BitSuffle(b, t0, t8, 0xdd);
  vs[0x9] = create4x128BitSuffle(b, t1, t9, 0xdd);
  vs[0xa] = create4x128BitSuffle(b, t2, ta, 0xdd);
  vs[0xb] = create4x128BitSuffle(b, t3, tb, 0xdd);
  vs[0xc] = create4x128BitSuffle(b, t4, tc, 0xdd);
  vs[0xd] = create4x128BitSuffle(b, t5, td, 0xdd);
  vs[0xe] = create4x128BitSuffle(b, t6, te, 0xdd);
  vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd);

  // Reassemble the transposed rows into a single 2-D vector, starting from a
  // zero constant and inserting one row at a time.
  auto reshInputType = VectorType::get(
      {m, n}, cast<VectorType>(source.getType()).getElementType());
  Value res =
      b.create<arith::ConstantOp>(reshInputType, b.getZeroAttr(reshInputType));
  for (int64_t i = 0; i < m; ++i)
    res = b.create<vector::InsertOp>(vs[i], res, i);
  return res;
}
300 
301 namespace {
/// Progressive lowering of TransposeOp.
/// One:
///   %x = vector.transpose %y, [1, 0]
/// is replaced by:
///   %z = arith.constant dense<0.000000e+00>
///   %0 = vector.extract %y[0, 0]
///   %1 = vector.insert %0, %z [0, 0]
///   ..
///   %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// Constructs the pattern; `vectorTransformOptions` selects which lowering
  /// strategy (unrolled extract/insert, flat transpose, or shuffle) applies.
  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                      MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    Value input = op.getVector();
    VectorType inputType = op.getSourceVectorType();
    VectorType resType = op.getResultVectorType();

    // Element-by-element unrolling requires a static element count.
    if (inputType.isScalable())
      return rewriter.notifyMatchFailure(
          op, "This lowering does not support scalable vectors");

    // Set up convenience transposition table.
    ArrayRef<int64_t> transp = op.getPermutation();

    // If a shuffle-based lowering was requested and this op qualifies as a
    // 2-D-slice transpose, leave it for TransposeOp2DToShuffleLowering.
    if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
        succeeded(isTranspose2DSlice(op)))
      return rewriter.notifyMatchFailure(
          op, "Options specifies lowering to shuffle");

    // Handle a true 2-D matrix transpose differently when requested.
    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Flat &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
      // Flatten to 1-D, emit vector.flat_transpose, then cast back to 2-D.
      Type flattenedType =
          VectorType::get(resType.getNumElements(), resType.getElementType());
      auto matrix =
          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
      Value trans = rewriter.create<vector::FlatTransposeOp>(
          loc, flattenedType, matrix, rows, columns);
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
      return success();
    }

    // Generate unrolled extract/insert ops. We do not unroll the rightmost
    // (i.e., highest-order) dimensions that are not transposed and leave them
    // in vector form to improve performance. Therefore, we prune those
    // dimensions from the shape/transpose data structures used to generate the
    // extract/insert ops.
    SmallVector<int64_t> prunedTransp;
    pruneNonTransposedDims(transp, prunedTransp);
    size_t numPrunedDims = transp.size() - prunedTransp.size();
    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
    auto prunedInStrides = computeStrides(prunedInShape);

    // Generates the extract/insert operations for every scalar/vector element
    // of the leftmost transposed dimensions. We traverse every transpose
    // element using a linearized index that we delinearize to generate the
    // appropriate indices for the extract/insert operations.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);

    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
         ++linearIdx) {
      // Extract indices follow the pruned input shape; insert indices are the
      // same indices permuted by the pruned transposition.
      auto extractIdxs = delinearize(linearIdx, prunedInStrides);
      SmallVector<int64_t> insertIdxs(extractIdxs);
      applyPermutationToVector(insertIdxs, prunedTransp);
      Value extractOp =
          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
      result =
          rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};
394 
395 /// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
396 /// to 2D vectors with at least one unit dim. For example:
397 ///
398 /// Replace:
399 ///   vector.transpose %0, [1, 0] : vector<4x1xi32>> to
400 ///                                 vector<1x4xi32>
401 /// with:
402 ///   vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
403 ///
404 /// Source with leading unit dim (inverse) is also replaced. Unit dim must
405 /// be fixed. Non-unit dim can be scalable.
406 ///
407 /// TODO: This pattern was introduced specifically to help lower scalable
408 /// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
409 /// to cancel out) would be preferable:
410 ///
411 ///  BEFORE:
412 ///     %0 = some_op
413 ///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
414 ///     %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
415 ///  AFTER:
416 ///     %0 = some_op
417 ///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
418 ///
419 /// Given the context above, we may want to consider (re-)moving this pattern
420 /// at some later time. I am leaving it for now in case there are other users
421 /// that I am not aware of.
422 class Transpose2DWithUnitDimToShapeCast
423     : public OpRewritePattern<vector::TransposeOp> {
424 public:
425   using OpRewritePattern::OpRewritePattern;
426 
Transpose2DWithUnitDimToShapeCast(MLIRContext * context,PatternBenefit benefit=1)427   Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
428                                     PatternBenefit benefit = 1)
429       : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
430 
matchAndRewrite(vector::TransposeOp op,PatternRewriter & rewriter) const431   LogicalResult matchAndRewrite(vector::TransposeOp op,
432                                 PatternRewriter &rewriter) const override {
433     Value input = op.getVector();
434     VectorType resType = op.getResultVectorType();
435 
436     // Set up convenience transposition table.
437     ArrayRef<int64_t> transp = op.getPermutation();
438 
439     if (resType.getRank() == 2 &&
440         ((resType.getShape().front() == 1 &&
441           !resType.getScalableDims().front()) ||
442          (resType.getShape().back() == 1 &&
443           !resType.getScalableDims().back())) &&
444         transp == ArrayRef<int64_t>({1, 0})) {
445       rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
446       return success();
447     }
448 
449     return failure();
450   }
451 };
452 
453 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
454 /// If the strategy is Shuffle1D, it will be lowered to:
455 ///   vector.shape_cast 2D -> 1D
456 ///   vector.shuffle
457 ///   vector.shape_cast 1D -> 2D
458 /// If the strategy is Shuffle16x16, it will be lowered to a sequence of shuffle
459 /// ops on 16xf32 vectors.
460 class TransposeOp2DToShuffleLowering
461     : public OpRewritePattern<vector::TransposeOp> {
462 public:
463   using OpRewritePattern::OpRewritePattern;
464 
TransposeOp2DToShuffleLowering(vector::VectorTransformsOptions vectorTransformOptions,MLIRContext * context,PatternBenefit benefit=1)465   TransposeOp2DToShuffleLowering(
466       vector::VectorTransformsOptions vectorTransformOptions,
467       MLIRContext *context, PatternBenefit benefit = 1)
468       : OpRewritePattern<vector::TransposeOp>(context, benefit),
469         vectorTransformOptions(vectorTransformOptions) {}
470 
matchAndRewrite(vector::TransposeOp op,PatternRewriter & rewriter) const471   LogicalResult matchAndRewrite(vector::TransposeOp op,
472                                 PatternRewriter &rewriter) const override {
473     if (!isShuffleLike(vectorTransformOptions.vectorTransposeLowering))
474       return rewriter.notifyMatchFailure(
475           op, "not using vector shuffle based lowering");
476 
477     if (op.getSourceVectorType().isScalable())
478       return rewriter.notifyMatchFailure(
479           op, "vector shuffle lowering not supported for scalable vectors");
480 
481     auto srcGtOneDims = isTranspose2DSlice(op);
482     if (failed(srcGtOneDims))
483       return rewriter.notifyMatchFailure(
484           op, "expected transposition on a 2D slice");
485 
486     VectorType srcType = op.getSourceVectorType();
487     int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
488     int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
489 
490     // Reshape the n-D input vector with only two dimensions greater than one
491     // to a 2-D vector.
492     Location loc = op.getLoc();
493     auto flattenedType = VectorType::get({n * m}, srcType.getElementType());
494     auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
495     auto reshInput = rewriter.create<vector::ShapeCastOp>(loc, flattenedType,
496                                                           op.getVector());
497 
498     Value res;
499     if (vectorTransformOptions.vectorTransposeLowering ==
500             VectorTransposeLowering::Shuffle16x16 &&
501         m == 16 && n == 16) {
502       reshInput =
503           rewriter.create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
504       res = transposeToShuffle16x16(rewriter, reshInput, m, n);
505     } else {
506       // Fallback to shuffle on 1D approach.
507       res = transposeToShuffle1D(rewriter, reshInput, m, n);
508     }
509 
510     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
511         op, op.getResultVectorType(), res);
512 
513     return success();
514   }
515 
516 private:
517   /// Options to control the vector patterns.
518   vector::VectorTransformsOptions vectorTransformOptions;
519 };
520 } // namespace
521 
populateVectorTransposeLoweringPatterns(RewritePatternSet & patterns,VectorTransformsOptions options,PatternBenefit benefit)522 void mlir::vector::populateVectorTransposeLoweringPatterns(
523     RewritePatternSet &patterns, VectorTransformsOptions options,
524     PatternBenefit benefit) {
525   patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
526                                                   benefit);
527   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
528       options, patterns.getContext(), benefit);
529 }
530