//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for particular
// sizes of interest.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;
using namespace mlir::x86vector::avx2::inline_asm;
using namespace mlir::x86vector::avx2::intrin;

Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
  const auto *asmTp = "vblendps $0, $1, $2, {0}";
  const auto *asmCstr =
      "=x,x,x"; // Careful: the constraint parser is very brittle: no whitespace!
  SmallVector<Value> asmVals{v1, v2};
  auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
  auto asmOp = b.create<LLVM::InlineAsmOp>(
      v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
      /*constraints=*/asmCstr, /*has_side_effects=*/false,
      /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
  return asmOp.getResult(0);
}
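
// Note: with constraints "=x,x,x", $0 names the result register and $1/$2 the
// two inputs; the "{0}" placeholder is substituted by formatv before the
// string reaches LLVM. For example (illustrative), mask = 0xCC yields the asm
// string "vblendps $0, $1, $2, 0xcc".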

Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
}

Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
}
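
// Note: vector.shuffle indices address the concatenation [v1, v2], so index 8
// is v2[0]. Per 128-bit lane, the two functions above interleave the low
// (resp. high) halves of v1 and v2, mirroring _mm256_unpacklo_ps and
// _mm256_unpackhi_ps. With v1 = {a0..a7} and v2 = {b0..b7}:
//   lo: {a0, b0, a1, b1, a4, b4, a5, b5}
//   hi: {a2, b2, a3, b3, a6, b6, a7, b7}
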
/// Takes an 8-bit mask split into four 2-bit fields (b01, b23, b45, b67).
/// Within each 128-bit lane, the first two result elements are selected from
/// v1 and the last two from v2:
///   result[0:127]   = v1[b01],   v1[b23],   v2[b45],   v2[b67]
///   result[128:255] = v1[b01+4], v1[b23+4], v2[b45+4], v2[b67+4]
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
                                                    Value v1, Value v2,
                                                    uint8_t mask) {
  uint8_t b01, b23, b45, b67;
  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
  SmallVector<int64_t> shuffleMask = {
      b01, b23, b45 + 8, b67 + 8, b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
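
// Example: MaskHelper::shuffle<1, 0, 1, 0>() produces mask 0b01'00'01'00
// (0x44), which MaskHelper::extractShuffle decodes as b01 = 0, b23 = 1,
// b45 = 0, b67 = 1, yielding the shuffle mask {0, 1, 8, 9, 4, 5, 12, 13}:
// per lane {a0, a1, b0, b1}, matching _mm256_shuffle_ps(a, b, 0x44).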

// The 8-bit mask encodes two 2-bit lane selectors: imm[1:0] picks the lower
// 128 bits of the result and imm[5:4] the upper 128 bits, each from:
//   0 -> a[0:127], 1 -> a[128:255], 2 -> b[0:127], 3 -> b[128:255].
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  SmallVector<int64_t> shuffleMask;
  auto appendToMask = [&](uint8_t control) {
    if (control == 0)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
    else if (control == 1)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
    else if (control == 2)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
    else if (control == 3)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
    else
      llvm_unreachable("control > 3 : overflow");
  };
  uint8_t b03, b47;
  MaskHelper::extractPermute(mask, b03, b47);
  appendToMask(b03);
  appendToMask(b47);
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
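
// Example: MaskHelper::permute<2, 0>() produces mask 0x20, i.e. b03 = 0 and
// b47 = 2, and hence the shuffle mask {0, 1, 2, 3, 8, 9, 10, 11}: the low
// 128-bit half of v1 followed by the low half of v2, matching
// _mm256_permute2f128_ps(a, b, 0x20).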

/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
                                                  Value v1, Value v2,
                                                  uint8_t mask) {
  SmallVector<int64_t, 8> shuffleMask;
  for (int i = 0; i < 8; ++i) {
    bool isSet = mask & (1 << i);
    shuffleMask.push_back(!isSet ? i : i + 8);
  }
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
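
// Example: mask = 0b11001100 keeps v1 at positions 0, 1, 4, 5 and takes v2 at
// positions 2, 3, 6, 7, i.e. shuffle mask {0, 1, 10, 11, 4, 5, 14, 15}.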

/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
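/// On return, vs[i] holds rows 2*i and 2*i+1 of the transposed 8x4 matrix,
/// i.e. the four result vectors concatenate to the row-major 8x4 transpose.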
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
#ifndef NDEBUG
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  assert(vs.size() == 4 && "expects 4 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");
#endif

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>());
  Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>());
  vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>());
  vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>());
}

/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
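/// On return, vs[i] holds row i of the transposed matrix, i.e. column i of
/// the input.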
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  (void)vt;
  assert(vs.size() == 8 && "expects 8 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
  Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
  Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
  Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);

  using inline_asm::mm256BlendPsAsm;
  Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());

  Value s0 =
      mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s1 =
      mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s2 =
      mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s3 =
      mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s4 =
      mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s5 =
      mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s6 =
      mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s7 =
      mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

  vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
  vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
  vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
  vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
  vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
  vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
}

/// Rewrite AVX2-specific vector.transpose, for the supported cases and
/// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
/// transpose cases and n-D cases that have been decomposed into 2-D
/// transposition slices. For example, a 3-D transpose:
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
///
/// could be sliced into 2-D transposes by tiling two of its dimensions to one
/// of the vector lengths supported by the AVX2 patterns (e.g., 4x8):
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1x4x8xf32> to vector<8x1x4xf32>
///
/// This lowering will analyze the n-D vector.transpose and determine if it's a
/// supported 2-D transposition slice where any of the AVX2 patterns can be
/// applied.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
                      int benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        loweringOptions(loweringOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    // Check if the source vector type is supported. AVX2 patterns can only be
    // applied to f32 vector types with two dimensions greater than one.
    VectorType srcType = op.getSourceVectorType();
    if (!srcType.getElementType().isF32())
      return rewriter.notifyMatchFailure(op, "unsupported vector element type");

    auto srcGtOneDims = mlir::vector::isTranspose2DSlice(op);
    if (failed(srcGtOneDims))
      return rewriter.notifyMatchFailure(
          op, "expected transposition on a 2D slice");

    // Retrieve the sizes of the two dimensions greater than one to be
    // transposed.
    int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
    int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));

    auto applyRewrite = [&]() {
      ImplicitLocOpBuilder ib(loc, rewriter);
      SmallVector<Value> vs;

      // Reshape the n-D input vector with only two dimensions greater than one
      // to a 2-D vector.
      auto flattenedType =
          VectorType::get({n * m}, op.getSourceVectorType().getElementType());
      auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
      auto reshInput =
          ib.create<vector::ShapeCastOp>(flattenedType, op.getVector());
      reshInput = ib.create<vector::ShapeCastOp>(reshInputType, reshInput);
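      // E.g., a vector<1x4x8xf32> input is first flattened to vector<32xf32>
      // and then reshaped to vector<4x8xf32>.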

      // Extract 1-D vectors from the higher-order dimension of the input
      // vector.
      for (int64_t i = 0; i < m; ++i)
        vs.push_back(ib.create<vector::ExtractOp>(reshInput, i));

      // Transpose set of 1-D vectors.
      if (m == 4)
        transpose4x8xf32(ib, vs);
      if (m == 8)
        transpose8x8xf32(ib, vs);

      // Insert transposed 1-D vectors into the higher-order dimension of the
      // output vector.
      Value res = ib.create<arith::ConstantOp>(reshInputType,
                                               ib.getZeroAttr(reshInputType));
      for (int64_t i = 0; i < m; ++i)
        res = ib.create<vector::InsertOp>(vs[i], res, i);

      // The output vector still has the shape of the input vector (e.g., 4x8).
      // We have to transpose its dimensions and restore the original rank
      // (e.g., 1x8x1x4x1).
      res = ib.create<vector::ShapeCastOp>(flattenedType, res);
      res = ib.create<vector::ShapeCastOp>(op.getResultVectorType(), res);
      rewriter.replaceOp(op, res);
      return success();
    };

    if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
      return applyRewrite();
    if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
      return applyRewrite();
    return failure();
  }

private:
  LoweringOptions loweringOptions;
};

void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}
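
// A minimal usage sketch (assuming the `LoweringOptions` and
// `TransposeLoweringOptions` builder methods declared in
// mlir/Dialect/X86Vector/Transforms.h): a pass might populate and greedily
// apply these patterns roughly as follows.
//
//   RewritePatternSet patterns(ctx);
//   x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
//       patterns,
//       x86vector::avx2::LoweringOptions().setTransposeOptions(
//           x86vector::avx2::TransposeLoweringOptions()
//               .lower4x8xf32()
//               .lower8x8xf32()),
//       /*benefit=*/10);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));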