//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for particular
// sizes of interest.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;
using namespace mlir::x86vector::avx2::inline_asm;
using namespace mlir::x86vector::avx2::intrin;

/// Emit AVX `vblendps` as an `llvm.inline_asm` op in Intel dialect.
/// `mask` becomes the 8-bit immediate of the instruction, formatted as a
/// 2-digit hex literal in the asm string. The op is marked side-effect free
/// and takes/returns values of `v1`'s type.
Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
  // $0 is the result, $1/$2 the two inputs; {0} is substituted below with the
  // hex-formatted immediate via llvm::formatv.
  const auto *asmTp = "vblendps $0, $1, $2, {0}";
  const auto *asmCstr =
      "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
  SmallVector<Value> asmVals{v1, v2};
  auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
  auto asmOp = b.create<LLVM::InlineAsmOp>(
      v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
      /*constraints=*/asmCstr, /*has_side_effects=*/false,
      /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
  return asmOp.getResult(0);
}

/// Model of AVX `_mm256_unpacklo_ps` as a vector.shuffle: interleave the low
/// pairs of each 128-bit lane of `v1` and `v2` (indices 8..15 refer to `v2`).
Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
}

/// Model of AVX `_mm256_unpackhi_ps` as a vector.shuffle: interleave the high
/// pairs of each 128-bit lane of `v1` and `v2` (indices 8..15 refer to `v2`).
Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
}

/// Model of AVX `_mm256_shuffle_ps`. The 8-bit `mask` holds four 2-bit
/// selectors (split out by MaskHelper::extractShuffle — presumably
/// b01 = mask[0:1], b23 = mask[2:3], etc.; see Transforms.h). Per 128-bit
/// lane, the result is {a[b01], a[b23], b[b45], b[b67]}: the +8 offsets select
/// from `v2`, the +4 offsets repeat the pattern for the upper lane.
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
                                                    Value v1, Value v2,
                                                    uint8_t mask) {
  uint8_t b01, b23, b45, b67;
  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
  SmallVector<int64_t> shuffleMask = {
      b01, b23, b45 + 8, b67 + 8, b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// Model of AVX `_mm256_permute2f128_ps`: each 2-bit field of the immediate
/// selects one 128-bit lane for the corresponding result half:
///   0 -> a[0:127], 1 -> a[128:255], 2 -> b[0:127], 3 -> b[128:255]
/// imm[0:3] controls the low half of the result, imm[4:7] the high half.
80 Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps( 81 ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) { 82 SmallVector<int64_t> shuffleMask; 83 auto appendToMask = [&](uint8_t control) { 84 if (control == 0) 85 llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3}); 86 else if (control == 1) 87 llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7}); 88 else if (control == 2) 89 llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11}); 90 else if (control == 3) 91 llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15}); 92 else 93 llvm_unreachable("control > 3 : overflow"); 94 }; 95 uint8_t b03, b47; 96 MaskHelper::extractPermute(mask, b03, b47); 97 appendToMask(b03); 98 appendToMask(b47); 99 return b.create<vector::ShuffleOp>(v1, v2, shuffleMask); 100 } 101 102 /// If bit i of `mask` is zero, take f32@i from v1 else take it from v2. 103 Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b, 104 Value v1, Value v2, 105 uint8_t mask) { 106 SmallVector<int64_t, 8> shuffleMask; 107 for (int i = 0; i < 8; ++i) { 108 bool isSet = mask & (1 << i); 109 shuffleMask.push_back(!isSet ? i : i + 8); 110 } 111 return b.create<vector::ShuffleOp>(v1, v2, shuffleMask); 112 } 113 114 /// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model. 
115 void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib, 116 MutableArrayRef<Value> vs) { 117 #ifndef NDEBUG 118 auto vt = VectorType::get({8}, Float32Type::get(ib.getContext())); 119 assert(vs.size() == 4 && "expects 4 vectors"); 120 assert(llvm::all_of(ValueRange{vs}.getTypes(), 121 [&](Type t) { return t == vt; }) && 122 "expects all types to be vector<8xf32>"); 123 #endif 124 125 Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]); 126 Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]); 127 Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]); 128 Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]); 129 Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>()); 130 Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>()); 131 Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>()); 132 Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>()); 133 vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>()); 134 vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>()); 135 vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>()); 136 vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>()); 137 } 138 139 /// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model. 
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  (void)vt; // Only used by the asserts below; silence -Wunused in NDEBUG.
  assert(vs.size() == 8 && "expects 8 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");

  // Stage 1: pairwise interleave of adjacent rows within 128-bit lanes.
  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
  Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
  Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
  Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);

  // Stage 2: shuffle + blend. The blends use the inline-asm variant of
  // vblendps rather than the vector.shuffle-based intrinsic model.
  using inline_asm::mm256BlendPsAsm;
  Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());

  Value s0 =
      mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s1 =
      mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s2 =
      mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s3 =
      mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s4 =
      mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s5 =
      mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s6 =
      mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s7 =
      mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

  // Stage 3: recombine 128-bit lanes to produce the transposed rows in place.
  vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
  vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
  vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
  vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
  vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
  vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
}

/// Rewrite AVX2-specific vector.transpose, for the supported cases and
/// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
/// transpose cases and n-D cases that have been decomposed into 2-D
/// transposition slices. For example, a 3-D transpose:
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
///
/// could be sliced into 2-D transposes by tiling two of its dimensions to one
/// of the vector lengths supported by the AVX2 patterns (e.g., 4x8):
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<1x4x8xf32> to vector<8x1x4xf32>
///
/// This lowering will analyze the n-D vector.transpose and determine if it's a
/// supported 2-D transposition slice where any of the AVX2 patterns can be
/// applied.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
                      int benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        loweringOptions(loweringOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    // Check if the source vector type is supported. AVX2 patterns can only be
    // applied to f32 vector types with two dimensions greater than one.
    VectorType srcType = op.getSourceVectorType();
    if (!srcType.getElementType().isF32())
      return rewriter.notifyMatchFailure(op, "Unsupported vector element type");

    auto srcGtOneDims = mlir::vector::isTranspose2DSlice(op);
    if (failed(srcGtOneDims))
      return rewriter.notifyMatchFailure(
          op, "expected transposition on a 2D slice");

    // Retrieve the sizes of the two dimensions greater than one to be
    // transposed.
    int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
    int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));

    // Deferred so that the m/n support checks below run before any IR is
    // created.
    auto applyRewrite = [&]() {
      ImplicitLocOpBuilder ib(loc, rewriter);
      SmallVector<Value> vs;

      // Reshape the n-D input vector with only two dimensions greater than one
      // to a 2-D vector (via a flat 1-D shape_cast, then to m x n).
      auto flattenedType =
          VectorType::get({n * m}, op.getSourceVectorType().getElementType());
      auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
      auto reshInput =
          ib.create<vector::ShapeCastOp>(flattenedType, op.getVector());
      reshInput = ib.create<vector::ShapeCastOp>(reshInputType, reshInput);

      // Extract 1-D vectors from the higher-order dimension of the input
      // vector.
      for (int64_t i = 0; i < m; ++i)
        vs.push_back(ib.create<vector::ExtractOp>(reshInput, i));

      // Transpose set of 1-D vectors in place. Only m == 4 or m == 8 can
      // reach this point (checked by the callers of applyRewrite below).
      if (m == 4)
        transpose4x8xf32(ib, vs);
      if (m == 8)
        transpose8x8xf32(ib, vs);

      // Insert transposed 1-D vectors into the higher-order dimension of the
      // output vector, starting from an all-zero constant.
      Value res = ib.create<arith::ConstantOp>(reshInputType,
                                               ib.getZeroAttr(reshInputType));
      for (int64_t i = 0; i < m; ++i)
        res = ib.create<vector::InsertOp>(vs[i], res, i);

      // The output vector still has the shape of the input vector (e.g., 4x8).
      // We have to transpose their dimensions and retrieve its original rank
      // (e.g., 1x8x1x4x1).
      res = ib.create<vector::ShapeCastOp>(flattenedType, res);
      res = ib.create<vector::ShapeCastOp>(op.getResultVectorType(), res);
      rewriter.replaceOp(op, res);
      return success();
    };

    // Dispatch on the (m, n) shapes enabled in the lowering options.
    if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
      return applyRewrite();
    if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
      return applyRewrite();
    return failure();
  }

private:
  LoweringOptions loweringOptions;
};

/// Register the AVX2-specialized transpose lowering pattern with the given
/// options and pattern benefit.
void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}