//=- Transforms.h - X86Vector Dialect Transformation Entrypoints -*- C++ -*-=//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
#define MLIR_DIALECT_X86VECTOR_TRANSFORMS_H

#include "mlir/IR/Value.h"

#include <cstdint>

namespace mlir {

class ImplicitLocOpBuilder;
class LLVMConversionTarget;
class LLVMTypeConverter;
class RewritePatternSet;

namespace x86vector {

/// Helper class to factor out the composition and extraction of the 8-bit
/// immediate masks used by the intrinsics below (single bits for blends,
/// 2-bit fields for shuffles, nibbles for permutes).
struct MaskHelper {
  /// b0 captures the lowest bit, b7 captures the highest bit.
  /// Meant to be used with instructions such as mm256BlendPs.
  template <uint8_t b0, uint8_t b1, uint8_t b2, uint8_t b3, uint8_t b4,
            uint8_t b5, uint8_t b6, uint8_t b7>
  static uint8_t blend() {
    static_assert(b0 <= 1 && b1 <= 1 && b2 <= 1 && b3 <= 1, "overflow");
    static_assert(b4 <= 1 && b5 <= 1 && b6 <= 1 && b7 <= 1, "overflow");
    return static_cast<uint8_t>((b7 << 7) | (b6 << 6) | (b5 << 5) | (b4 << 4) |
                                (b3 << 3) | (b2 << 2) | (b1 << 1) | b0);
  }
  /// b0 captures the lowest bit, b7 captures the highest bit.
  /// Meant to be used with instructions such as mm256BlendPs.
  static void extractBlend(uint8_t mask, uint8_t &b0, uint8_t &b1, uint8_t &b2,
                           uint8_t &b3, uint8_t &b4, uint8_t &b5, uint8_t &b6,
                           uint8_t &b7) {
    b7 = (mask >> 7) & 1;
    b6 = (mask >> 6) & 1;
    b5 = (mask >> 5) & 1;
    b4 = (mask >> 4) & 1;
    b3 = (mask >> 3) & 1;
    b2 = (mask >> 2) & 1;
    b1 = (mask >> 1) & 1;
    b0 = mask & 1;
  }
  /// b01 captures the lower 2 bits, b67 captures the higher 2 bits.
  /// Meant to be used with instructions such as mm256ShufflePs.
  template <unsigned b67, unsigned b45, unsigned b23, unsigned b01>
  static uint8_t shuffle() {
    static_assert(b01 <= 0x03, "overflow");
    static_assert(b23 <= 0x03, "overflow");
    static_assert(b45 <= 0x03, "overflow");
    static_assert(b67 <= 0x03, "overflow");
    return static_cast<uint8_t>((b67 << 6) | (b45 << 4) | (b23 << 2) | b01);
  }
  /// b01 captures the lower 2 bits, b67 captures the higher 2 bits.
  static void extractShuffle(uint8_t mask, uint8_t &b01, uint8_t &b23,
                             uint8_t &b45, uint8_t &b67) {
    b67 = (mask & (0x03 << 6)) >> 6;
    b45 = (mask & (0x03 << 4)) >> 4;
    b23 = (mask & (0x03 << 2)) >> 2;
    b01 = mask & 0x03;
  }
  /// b03 captures the lower 4 bits, b47 captures the higher 4 bits.
  /// Meant to be used with instructions such as mm256Permute2f128Ps.
  template <unsigned b47, unsigned b03>
  static uint8_t permute() {
    static_assert(b03 <= 0x0f, "overflow");
    static_assert(b47 <= 0x0f, "overflow");
    return static_cast<uint8_t>((b47 << 4) | b03);
  }
  /// b03 captures the lower 4 bits, b47 captures the higher 4 bits.
  static void extractPermute(uint8_t mask, uint8_t &b03, uint8_t &b47) {
    b47 = (mask & (0x0f << 4)) >> 4;
    b03 = mask & 0x0f;
  }
};
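
// A minimal usage sketch (illustrative only, not part of the upstream API):
// compose an 8-bit blend immediate at compile time and split it back into its
// per-position selector bits at run time. The function name
// `composeAndExtractBlendMask` is hypothetical.
inline uint8_t composeAndExtractBlendMask() {
  // Select f32 elements 1, 3, 5 and 7 from the second operand.
  uint8_t mask = MaskHelper::blend<0, 1, 0, 1, 0, 1, 0, 1>(); // 0b10101010
  // Recover the per-position selectors; the odd positions come back nonzero.
  uint8_t b0, b1, b2, b3, b4, b5, b6, b7;
  MaskHelper::extractBlend(mask, b0, b1, b2, b3, b4, b5, b6, b7);
  return mask;
}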

//===----------------------------------------------------------------------===//
/// Helpers extracted from:
///   - clang/lib/Headers/avxintrin.h
///   - clang/test/CodeGen/X86/avx-builtins.c
///   - clang/test/CodeGen/X86/avx2-builtins.c
///   - clang/test/CodeGen/X86/avx-shuffle-builtins.c
/// as well as the Intel Intrinsics Guide
/// (https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html).
/// They make it easier to implement known-good lowerings.
/// All intrinsics correspond 1-1 to the Intel definitions.
//===----------------------------------------------------------------------===//

namespace avx2 {

namespace inline_asm {
//===----------------------------------------------------------------------===//
/// Methods in the inline_asm namespace emit calls to LLVM::InlineAsmOp.
//===----------------------------------------------------------------------===//
/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
Value mm256BlendPsAsm(ImplicitLocOpBuilder &b, Value v1, Value v2,
                      uint8_t mask);

} // namespace inline_asm

namespace intrin {
//===----------------------------------------------------------------------===//
/// Methods in the intrin namespace emulate clang's implementation of the X86
/// intrinsics.
//===----------------------------------------------------------------------===//
/// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13].
Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2);

/// Lower to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15].
Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2);

/// Take an 8-bit `mask` with 2 bits per result position. Within each 128-bit
/// half, the two low f32 results are selected from v1 (by mask[1:0] and
/// mask[3:2]) and the two high results from v2 (by mask[5:4] and mask[7:6]);
/// the upper half applies the same selectors to the upper halves of v1 and v2.
Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask);

/// Select each 128-bit half of the result from
/// {v1[0:127], v1[128:255], v2[0:127], v2[128:255]}: mask[1:0] (values 0-3)
/// picks the lower half and mask[5:4] picks the upper half.
Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2,
                          uint8_t mask);

/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
Value mm256BlendPs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask);
} // namespace intrin

//===----------------------------------------------------------------------===//
/// Generic lowerings may use either intrin or inline_asm, depending on needs.
//===----------------------------------------------------------------------===//
/// 4x8xf32-specific AVX2 transpose lowering.
void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);

/// 8x8xf32-specific AVX2 transpose lowering.
void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
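
// A minimal composition sketch (hypothetical, not part of the upstream
// header): it only illustrates how the intrin-namespace helpers and
// MaskHelper combine into the kind of interleave/blend steps the transpose
// lowerings above are built from. `b` must be positioned where the ops should
// be created, and `v1`/`v2` are assumed to be vector<8xf32> values; the
// function name `interleaveThenBlend` is illustrative.
inline Value interleaveThenBlend(ImplicitLocOpBuilder &b, Value v1, Value v2) {
  // Interleave the low (resp. high) elements of each 128-bit lane of v1/v2.
  Value lo = intrin::mm256UnpackLoPs(b, v1, v2);
  Value hi = intrin::mm256UnpackHiPs(b, v1, v2);
  // Keep even f32 positions from `lo` and take odd positions from `hi`.
  uint8_t mask = MaskHelper::blend<0, 1, 0, 1, 0, 1, 0, 1>();
  return intrin::mm256BlendPs(b, lo, hi, mask);
}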

/// Structure to control the behavior of specialized AVX2 transpose lowerings.
struct TransposeLoweringOptions {
  bool lower4x8xf32_ = false;
  TransposeLoweringOptions &lower4x8xf32(bool lower = true) {
    lower4x8xf32_ = lower;
    return *this;
  }
  bool lower8x8xf32_ = false;
  TransposeLoweringOptions &lower8x8xf32(bool lower = true) {
    lower8x8xf32_ = lower;
    return *this;
  }
};

/// Options for controlling specialized AVX2 lowerings.
struct LoweringOptions {
  /// Configure specialized vector lowerings.
  TransposeLoweringOptions transposeOptions;
  LoweringOptions &setTransposeOptions(TransposeLoweringOptions options) {
    transposeOptions = options;
    return *this;
  }
};

/// Insert specialized transpose lowering patterns.
void populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options = LoweringOptions(),
    int benefit = 10);

} // namespace avx2
} // namespace x86vector

/// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM
/// intrinsics.
void populateX86VectorLegalizeForLLVMExportPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns);

/// Configure the target to support lowering X86Vector ops to ops that map to
/// LLVM intrinsics.
void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target);

} // namespace mlir

#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
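
// Hypothetical usage sketch (not part of this header; in a real pass it would
// live in a .cpp file and is shown here only to illustrate how the
// builder-style options feed the pattern population entry point above; the
// function name `addAvx2TransposePatterns` is illustrative):
//
//   void addAvx2TransposePatterns(mlir::RewritePatternSet &patterns) {
//     using namespace mlir::x86vector::avx2;
//     LoweringOptions options;
//     options.setTransposeOptions(
//         TransposeLoweringOptions().lower4x8xf32().lower8x8xf32());
//     populateSpecializedTransposeLoweringPatterns(patterns, options,
//                                                  /*benefit=*/10);
//   }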