18508a63bSEmilio Cota //=- Transforms.h - X86Vector Dialect Transformation Entrypoints -*- C++ -*-=// 28508a63bSEmilio Cota // 38508a63bSEmilio Cota // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 48508a63bSEmilio Cota // See https://llvm.org/LICENSE.txt for license information. 58508a63bSEmilio Cota // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 68508a63bSEmilio Cota // 78508a63bSEmilio Cota //===----------------------------------------------------------------------===// 88508a63bSEmilio Cota 98508a63bSEmilio Cota #ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMS_H 108508a63bSEmilio Cota #define MLIR_DIALECT_X86VECTOR_TRANSFORMS_H 118508a63bSEmilio Cota 1234ff8573SNicolas Vasilache #include "mlir/IR/Value.h" 1334ff8573SNicolas Vasilache 148508a63bSEmilio Cota namespace mlir { 158508a63bSEmilio Cota 1634ff8573SNicolas Vasilache class ImplicitLocOpBuilder; 178508a63bSEmilio Cota class LLVMConversionTarget; 188508a63bSEmilio Cota class LLVMTypeConverter; 198508a63bSEmilio Cota class RewritePatternSet; 208508a63bSEmilio Cota 2134ff8573SNicolas Vasilache namespace x86vector { 2234ff8573SNicolas Vasilache 2334ff8573SNicolas Vasilache /// Helper class to factor out the creation and extraction of masks from nibs. 2434ff8573SNicolas Vasilache struct MaskHelper { 25b2729fdaSNicolas Vasilache /// b0 captures the lowest bit, b7 captures the highest bit. 26b2729fdaSNicolas Vasilache /// Meant to be used with instructions such as mm256BlendPs. 27b2729fdaSNicolas Vasilache template <uint8_t b0, uint8_t b1, uint8_t b2, uint8_t b3, uint8_t b4, 28b2729fdaSNicolas Vasilache uint8_t b5, uint8_t b6, uint8_t b7> 29b2729fdaSNicolas Vasilache static uint8_t blend() { 30b2729fdaSNicolas Vasilache static_assert(b0 <= 1 && b1 <= 1 && b2 <= 1 && b3 <= 1, "overflow"); 31b2729fdaSNicolas Vasilache static_assert(b4 <= 1 && b5 <= 1 && b6 <= 1 && b7 <= 1, "overflow"); 32b2729fdaSNicolas Vasilache return static_cast<uint8_t>((b7 << 7) | (b6 << 6) | (b5 << 5) | (b4 << 4) | 33b2729fdaSNicolas Vasilache (b3 << 3) | (b2 << 2) | (b1 << 1) | b0); 34b2729fdaSNicolas Vasilache } 35b2729fdaSNicolas Vasilache /// b0 captures the lowest bit, b7 captures the highest bit. 36b2729fdaSNicolas Vasilache /// Meant to be used with instructions such as mm256BlendPs. 37b2729fdaSNicolas Vasilache static void extractBlend(uint8_t mask, uint8_t &b0, uint8_t &b1, uint8_t &b2, 38b2729fdaSNicolas Vasilache uint8_t &b3, uint8_t &b4, uint8_t &b5, uint8_t &b6, 39b2729fdaSNicolas Vasilache uint8_t &b7) { 40b2729fdaSNicolas Vasilache b7 = mask & (1 << 7); 41b2729fdaSNicolas Vasilache b6 = mask & (1 << 6); 42b2729fdaSNicolas Vasilache b5 = mask & (1 << 5); 43b2729fdaSNicolas Vasilache b4 = mask & (1 << 4); 44b2729fdaSNicolas Vasilache b3 = mask & (1 << 3); 45b2729fdaSNicolas Vasilache b2 = mask & (1 << 2); 46b2729fdaSNicolas Vasilache b1 = mask & (1 << 1); 47b2729fdaSNicolas Vasilache b0 = mask & 1; 48b2729fdaSNicolas Vasilache } 4934ff8573SNicolas Vasilache /// b01 captures the lower 2 bits, b67 captures the higher 2 bits. 5034ff8573SNicolas Vasilache /// Meant to be used with instructions such as mm256ShufflePs. 5134ff8573SNicolas Vasilache template <unsigned b67, unsigned b45, unsigned b23, unsigned b01> 52b2729fdaSNicolas Vasilache static uint8_t shuffle() { 5334ff8573SNicolas Vasilache static_assert(b01 <= 0x03, "overflow"); 5434ff8573SNicolas Vasilache static_assert(b23 <= 0x03, "overflow"); 5534ff8573SNicolas Vasilache static_assert(b45 <= 0x03, "overflow"); 5634ff8573SNicolas Vasilache static_assert(b67 <= 0x03, "overflow"); 57b2729fdaSNicolas Vasilache return static_cast<uint8_t>((b67 << 6) | (b45 << 4) | (b23 << 2) | b01); 5834ff8573SNicolas Vasilache } 5934ff8573SNicolas Vasilache /// b01 captures the lower 2 bits, b67 captures the higher 2 bits. 60b2729fdaSNicolas Vasilache static void extractShuffle(uint8_t mask, uint8_t &b01, uint8_t &b23, 61b2729fdaSNicolas Vasilache uint8_t &b45, uint8_t &b67) { 6234ff8573SNicolas Vasilache b67 = (mask & (0x03 << 6)) >> 6; 6334ff8573SNicolas Vasilache b45 = (mask & (0x03 << 4)) >> 4; 6434ff8573SNicolas Vasilache b23 = (mask & (0x03 << 2)) >> 2; 6534ff8573SNicolas Vasilache b01 = mask & 0x03; 6634ff8573SNicolas Vasilache } 6734ff8573SNicolas Vasilache /// b03 captures the lower 4 bits, b47 captures the higher 4 bits. 6834ff8573SNicolas Vasilache /// Meant to be used with instructions such as mm256Permute2f128Ps. 6934ff8573SNicolas Vasilache template <unsigned b47, unsigned b03> 70b2729fdaSNicolas Vasilache static uint8_t permute() { 7134ff8573SNicolas Vasilache static_assert(b03 <= 0x0f, "overflow"); 7234ff8573SNicolas Vasilache static_assert(b47 <= 0x0f, "overflow"); 73b2729fdaSNicolas Vasilache return static_cast<uint8_t>((b47 << 4) + b03); 7434ff8573SNicolas Vasilache } 7534ff8573SNicolas Vasilache /// b03 captures the lower 4 bits, b47 captures the higher 4 bits. 76b2729fdaSNicolas Vasilache static void extractPermute(uint8_t mask, uint8_t &b03, uint8_t &b47) { 7734ff8573SNicolas Vasilache b47 = (mask & (0x0f << 4)) >> 4; 7834ff8573SNicolas Vasilache b03 = mask & 0x0f; 7934ff8573SNicolas Vasilache } 8034ff8573SNicolas Vasilache }; 8134ff8573SNicolas Vasilache 8234ff8573SNicolas Vasilache //===----------------------------------------------------------------------===// 8334ff8573SNicolas Vasilache /// Helpers extracted from: 8434ff8573SNicolas Vasilache /// - clang/lib/Headers/avxintrin.h 8534ff8573SNicolas Vasilache /// - clang/test/CodeGen/X86/avx-builtins.c 8634ff8573SNicolas Vasilache /// - clang/test/CodeGen/X86/avx2-builtins.c 8734ff8573SNicolas Vasilache /// - clang/test/CodeGen/X86/avx-shuffle-builtins.c 8834ff8573SNicolas Vasilache /// as well as the Intel Intrinsics Guide 8934ff8573SNicolas Vasilache /// (https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html) 9034ff8573SNicolas Vasilache /// make it easier to just implement known good lowerings. 9134ff8573SNicolas Vasilache /// All intrinsics correspond 1-1 to the Intel definition. 9234ff8573SNicolas Vasilache //===----------------------------------------------------------------------===// 9334ff8573SNicolas Vasilache 9434ff8573SNicolas Vasilache namespace avx2 { 9534ff8573SNicolas Vasilache 96b2729fdaSNicolas Vasilache namespace inline_asm { 97b2729fdaSNicolas Vasilache //===----------------------------------------------------------------------===// 98b2729fdaSNicolas Vasilache /// Methods in the inline_asm namespace emit calls to LLVM::InlineAsmOp. 99b2729fdaSNicolas Vasilache //===----------------------------------------------------------------------===// 100b2729fdaSNicolas Vasilache /// If bit i of `mask` is zero, take f32@i from v1 else take it from v2. 101b2729fdaSNicolas Vasilache Value mm256BlendPsAsm(ImplicitLocOpBuilder &b, Value v1, Value v2, 102b2729fdaSNicolas Vasilache uint8_t mask); 103b2729fdaSNicolas Vasilache 104b2729fdaSNicolas Vasilache } // namespace inline_asm 105b2729fdaSNicolas Vasilache 106b2729fdaSNicolas Vasilache namespace intrin { 107b2729fdaSNicolas Vasilache //===----------------------------------------------------------------------===// 108b2729fdaSNicolas Vasilache /// Methods in the intrin namespace emulate clang's impl. of X86 intrinsics. 109b2729fdaSNicolas Vasilache //===----------------------------------------------------------------------===// 11034ff8573SNicolas Vasilache /// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13]. 11134ff8573SNicolas Vasilache Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2); 11234ff8573SNicolas Vasilache 11334ff8573SNicolas Vasilache /// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13]. 11434ff8573SNicolas Vasilache Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2); 11534ff8573SNicolas Vasilache 11634ff8573SNicolas Vasilache /// a a b b a a b b 11734ff8573SNicolas Vasilache /// Take an 8 bit mask, 2 bit for each position of a[0, 3) **and** b[0, 4): 11834ff8573SNicolas Vasilache /// 0:127 | 128:255 11934ff8573SNicolas Vasilache /// b01 b23 C8 D8 | b01+4 b23+4 C8+4 D8+4 120b2729fdaSNicolas Vasilache Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask); 12134ff8573SNicolas Vasilache 12234ff8573SNicolas Vasilache // imm[0:1] out of imm[0:3] is: 12334ff8573SNicolas Vasilache // 0 1 2 3 12434ff8573SNicolas Vasilache // a[0:127] or a[128:255] or b[0:127] or b[128:255] | 12534ff8573SNicolas Vasilache // a[0:127] or a[128:255] or b[0:127] or b[128:255] 12634ff8573SNicolas Vasilache // 0 1 2 3 12734ff8573SNicolas Vasilache // imm[0:1] out of imm[4:7]. 12834ff8573SNicolas Vasilache Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2, 129b2729fdaSNicolas Vasilache uint8_t mask); 13034ff8573SNicolas Vasilache 131b2729fdaSNicolas Vasilache /// If bit i of `mask` is zero, take f32@i from v1 else take it from v2. 132b2729fdaSNicolas Vasilache Value mm256BlendPs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask); 133b2729fdaSNicolas Vasilache } // namespace intrin 134b2729fdaSNicolas Vasilache 135b2729fdaSNicolas Vasilache //===----------------------------------------------------------------------===// 136b2729fdaSNicolas Vasilache /// Generic lowerings may either use intrin or inline_asm depending on needs. 137b2729fdaSNicolas Vasilache //===----------------------------------------------------------------------===// 13834ff8573SNicolas Vasilache /// 4x8xf32-specific AVX2 transpose lowering. 13934ff8573SNicolas Vasilache void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs); 14034ff8573SNicolas Vasilache 14134ff8573SNicolas Vasilache /// 8x8xf32-specific AVX2 transpose lowering. 14234ff8573SNicolas Vasilache void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs); 14334ff8573SNicolas Vasilache 14434ff8573SNicolas Vasilache /// Structure to control the behavior of specialized AVX2 transpose lowering. 14534ff8573SNicolas Vasilache struct TransposeLoweringOptions { 14634ff8573SNicolas Vasilache bool lower4x8xf32_ = false; 14734ff8573SNicolas Vasilache TransposeLoweringOptions &lower4x8xf32(bool lower = true) { 14834ff8573SNicolas Vasilache lower4x8xf32_ = lower; 14934ff8573SNicolas Vasilache return *this; 15034ff8573SNicolas Vasilache } 15134ff8573SNicolas Vasilache bool lower8x8xf32_ = false; 15234ff8573SNicolas Vasilache TransposeLoweringOptions &lower8x8xf32(bool lower = true) { 15334ff8573SNicolas Vasilache lower8x8xf32_ = lower; 15434ff8573SNicolas Vasilache return *this; 15534ff8573SNicolas Vasilache } 15634ff8573SNicolas Vasilache }; 15734ff8573SNicolas Vasilache 15834ff8573SNicolas Vasilache /// Options for controlling specialized AVX2 lowerings. 15934ff8573SNicolas Vasilache struct LoweringOptions { 16034ff8573SNicolas Vasilache /// Configure specialized vector lowerings. 16134ff8573SNicolas Vasilache TransposeLoweringOptions transposeOptions; 16234ff8573SNicolas Vasilache LoweringOptions &setTransposeOptions(TransposeLoweringOptions options) { 16334ff8573SNicolas Vasilache transposeOptions = options; 16434ff8573SNicolas Vasilache return *this; 16534ff8573SNicolas Vasilache } 16634ff8573SNicolas Vasilache }; 16734ff8573SNicolas Vasilache 16834ff8573SNicolas Vasilache /// Insert specialized transpose lowering patterns. 16934ff8573SNicolas Vasilache void populateSpecializedTransposeLoweringPatterns( 17034ff8573SNicolas Vasilache RewritePatternSet &patterns, LoweringOptions options = LoweringOptions(), 17134ff8573SNicolas Vasilache int benefit = 10); 17234ff8573SNicolas Vasilache 17334ff8573SNicolas Vasilache } // namespace avx2 17434ff8573SNicolas Vasilache } // namespace x86vector 17534ff8573SNicolas Vasilache 1768508a63bSEmilio Cota /// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM 1778508a63bSEmilio Cota /// intrinsics. 1788508a63bSEmilio Cota void populateX86VectorLegalizeForLLVMExportPatterns( 179*206fad0eSMatthias Springer const LLVMTypeConverter &converter, RewritePatternSet &patterns); 1808508a63bSEmilio Cota 1818508a63bSEmilio Cota /// Configure the target to support lowering X86Vector ops to ops that map to 1828508a63bSEmilio Cota /// LLVM intrinsics. 1838508a63bSEmilio Cota void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target); 1848508a63bSEmilio Cota 1858508a63bSEmilio Cota } // namespace mlir 1868508a63bSEmilio Cota 1878508a63bSEmilio Cota #endif // MLIR_DIALECT_X86VECTOR_TRANSFORMS_H 188