xref: /llvm-project/mlir/include/mlir/Dialect/X86Vector/Transforms.h (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
18508a63bSEmilio Cota //=- Transforms.h - X86Vector Dialect Transformation Entrypoints -*- C++ -*-=//
28508a63bSEmilio Cota //
38508a63bSEmilio Cota // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
48508a63bSEmilio Cota // See https://llvm.org/LICENSE.txt for license information.
58508a63bSEmilio Cota // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68508a63bSEmilio Cota //
78508a63bSEmilio Cota //===----------------------------------------------------------------------===//
88508a63bSEmilio Cota 
98508a63bSEmilio Cota #ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
108508a63bSEmilio Cota #define MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
118508a63bSEmilio Cota 
1234ff8573SNicolas Vasilache #include "mlir/IR/Value.h"
1334ff8573SNicolas Vasilache 
148508a63bSEmilio Cota namespace mlir {
158508a63bSEmilio Cota 
// Forward declarations. This header only refers to these types by reference,
// so their full definitions are not needed here.
class ImplicitLocOpBuilder;
class LLVMConversionTarget;
class LLVMTypeConverter;
class RewritePatternSet;
208508a63bSEmilio Cota 
2134ff8573SNicolas Vasilache namespace x86vector {
2234ff8573SNicolas Vasilache 
/// Helper class to factor out the creation and extraction of 8-bit immediate
/// masks from their individual fields (single bits, 2-bit pairs or nibbles).
struct MaskHelper {
  /// Packs eight single-bit fields into one mask.
  /// b0 captures the lowest bit, b7 captures the highest bit. Every field
  /// must be 0 or 1; this is enforced at compile time.
  /// Meant to be used with instructions such as mm256BlendPs.
  template <uint8_t b0, uint8_t b1, uint8_t b2, uint8_t b3, uint8_t b4,
            uint8_t b5, uint8_t b6, uint8_t b7>
  static constexpr uint8_t blend() {
    static_assert(b0 <= 1 && b1 <= 1 && b2 <= 1 && b3 <= 1, "overflow");
    static_assert(b4 <= 1 && b5 <= 1 && b6 <= 1 && b7 <= 1, "overflow");
    return static_cast<uint8_t>((b7 << 7) | (b6 << 6) | (b5 << 5) | (b4 << 4) |
                                (b3 << 3) | (b2 << 2) | (b1 << 1) | b0);
  }
  /// Unpacks a blend mask into its eight single-bit fields.
  /// b0 captures the lowest bit, b7 captures the highest bit.
  /// Each output is normalized to 0 or 1, so the results satisfy the
  /// precondition of blend() and remain usable as truth values.
  /// Meant to be used with instructions such as mm256BlendPs.
  static void extractBlend(uint8_t mask, uint8_t &b0, uint8_t &b1, uint8_t &b2,
                           uint8_t &b3, uint8_t &b4, uint8_t &b5, uint8_t &b6,
                           uint8_t &b7) {
    // Shift first, then mask: `mask & (1 << k)` would yield 1 << k, not 1.
    b7 = (mask >> 7) & 1;
    b6 = (mask >> 6) & 1;
    b5 = (mask >> 5) & 1;
    b4 = (mask >> 4) & 1;
    b3 = (mask >> 3) & 1;
    b2 = (mask >> 2) & 1;
    b1 = (mask >> 1) & 1;
    b0 = mask & 1;
  }
  /// Packs four 2-bit fields into one mask.
  /// b01 captures the lower 2 bits, b67 captures the higher 2 bits. Note that
  /// the template arguments are listed from the highest field to the lowest.
  /// Every field must be at most 0x03; this is enforced at compile time.
  /// Meant to be used with instructions such as mm256ShufflePs.
  template <unsigned b67, unsigned b45, unsigned b23, unsigned b01>
  static constexpr uint8_t shuffle() {
    static_assert(b01 <= 0x03, "overflow");
    static_assert(b23 <= 0x03, "overflow");
    static_assert(b45 <= 0x03, "overflow");
    static_assert(b67 <= 0x03, "overflow");
    return static_cast<uint8_t>((b67 << 6) | (b45 << 4) | (b23 << 2) | b01);
  }
  /// Unpacks a shuffle mask into its four 2-bit fields.
  /// b01 captures the lower 2 bits, b67 captures the higher 2 bits. Note that
  /// the outputs are listed from the lowest field to the highest, i.e. in the
  /// opposite order of the template arguments of shuffle().
  static void extractShuffle(uint8_t mask, uint8_t &b01, uint8_t &b23,
                             uint8_t &b45, uint8_t &b67) {
    b67 = (mask >> 6) & 0x03;
    b45 = (mask >> 4) & 0x03;
    b23 = (mask >> 2) & 0x03;
    b01 = mask & 0x03;
  }
  /// Packs two 4-bit fields into one mask.
  /// b03 captures the lower 4 bits, b47 captures the higher 4 bits. Note that
  /// the template arguments are listed from the highest field to the lowest.
  /// Every field must be at most 0x0f; this is enforced at compile time.
  /// Meant to be used with instructions such as mm256Permute2f128Ps.
  template <unsigned b47, unsigned b03>
  static constexpr uint8_t permute() {
    static_assert(b03 <= 0x0f, "overflow");
    static_assert(b47 <= 0x0f, "overflow");
    return static_cast<uint8_t>((b47 << 4) | b03);
  }
  /// Unpacks a permute mask into its two 4-bit fields.
  /// b03 captures the lower 4 bits, b47 captures the higher 4 bits.
  static void extractPermute(uint8_t mask, uint8_t &b03, uint8_t &b47) {
    b47 = (mask >> 4) & 0x0f;
    b03 = mask & 0x0f;
  }
};
8134ff8573SNicolas Vasilache 
//===----------------------------------------------------------------------===//
/// The helpers below are extracted from:
///   - clang/lib/Headers/avxintrin.h
///   - clang/test/CodeGen/X86/avx-builtins.c
///   - clang/test/CodeGen/X86/avx2-builtins.c
///   - clang/test/CodeGen/X86/avx-shuffle-builtins.c
/// as well as from the Intel Intrinsics Guide
/// (https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html).
/// They make it easier to just implement known good lowerings.
/// All intrinsics correspond 1-1 to the Intel definition.
//===----------------------------------------------------------------------===//
9334ff8573SNicolas Vasilache 
9434ff8573SNicolas Vasilache namespace avx2 {
9534ff8573SNicolas Vasilache 
96b2729fdaSNicolas Vasilache namespace inline_asm {
97b2729fdaSNicolas Vasilache //===----------------------------------------------------------------------===//
98b2729fdaSNicolas Vasilache /// Methods in the inline_asm namespace  emit calls to LLVM::InlineAsmOp.
99b2729fdaSNicolas Vasilache //===----------------------------------------------------------------------===//
100b2729fdaSNicolas Vasilache /// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
101b2729fdaSNicolas Vasilache Value mm256BlendPsAsm(ImplicitLocOpBuilder &b, Value v1, Value v2,
102b2729fdaSNicolas Vasilache                       uint8_t mask);
103b2729fdaSNicolas Vasilache 
104b2729fdaSNicolas Vasilache } // namespace inline_asm
105b2729fdaSNicolas Vasilache 
106b2729fdaSNicolas Vasilache namespace intrin {
107b2729fdaSNicolas Vasilache //===----------------------------------------------------------------------===//
108b2729fdaSNicolas Vasilache /// Methods in the intrin namespace emulate clang's impl. of X86 intrinsics.
109b2729fdaSNicolas Vasilache //===----------------------------------------------------------------------===//
11034ff8573SNicolas Vasilache /// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13].
11134ff8573SNicolas Vasilache Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
11234ff8573SNicolas Vasilache 
11334ff8573SNicolas Vasilache /// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13].
11434ff8573SNicolas Vasilache Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
11534ff8573SNicolas Vasilache 
11634ff8573SNicolas Vasilache ///                            a  a   b   b  a  a   b   b
11734ff8573SNicolas Vasilache /// Take an 8 bit mask, 2 bit for each position of a[0, 3)  **and** b[0, 4):
11834ff8573SNicolas Vasilache ///                                 0:127    |         128:255
11934ff8573SNicolas Vasilache ///                            b01  b23  C8  D8  |  b01+4 b23+4 C8+4 D8+4
120b2729fdaSNicolas Vasilache Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask);
12134ff8573SNicolas Vasilache 
12234ff8573SNicolas Vasilache // imm[0:1] out of imm[0:3] is:
12334ff8573SNicolas Vasilache //    0             1           2             3
12434ff8573SNicolas Vasilache // a[0:127] or a[128:255] or b[0:127] or b[128:255]    |
12534ff8573SNicolas Vasilache //          a[0:127] or a[128:255] or b[0:127] or b[128:255]
12634ff8573SNicolas Vasilache //             0             1           2             3
12734ff8573SNicolas Vasilache // imm[0:1] out of imm[4:7].
12834ff8573SNicolas Vasilache Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2,
129b2729fdaSNicolas Vasilache                           uint8_t mask);
13034ff8573SNicolas Vasilache 
131b2729fdaSNicolas Vasilache /// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
132b2729fdaSNicolas Vasilache Value mm256BlendPs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask);
133b2729fdaSNicolas Vasilache } // namespace intrin
134b2729fdaSNicolas Vasilache 
135b2729fdaSNicolas Vasilache //===----------------------------------------------------------------------===//
136b2729fdaSNicolas Vasilache /// Generic lowerings may either use intrin or inline_asm depending on needs.
137b2729fdaSNicolas Vasilache //===----------------------------------------------------------------------===//
13834ff8573SNicolas Vasilache /// 4x8xf32-specific AVX2 transpose lowering.
13934ff8573SNicolas Vasilache void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
14034ff8573SNicolas Vasilache 
14134ff8573SNicolas Vasilache /// 8x8xf32-specific AVX2 transpose lowering.
14234ff8573SNicolas Vasilache void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
14334ff8573SNicolas Vasilache 
/// Structure to control the behavior of specialized AVX2 transpose lowering.
/// Both specializations are disabled by default. Each setter returns `*this`
/// so that calls can be chained.
struct TransposeLoweringOptions {
  /// Whether the 4x8xf32-specific lowering is requested.
  bool lower4x8xf32_ = false;
  /// Sets `lower4x8xf32_` to `lower` (true when called without argument).
  TransposeLoweringOptions &lower4x8xf32(bool lower = true) {
    lower4x8xf32_ = lower;
    return *this;
  }
  /// Whether the 8x8xf32-specific lowering is requested.
  bool lower8x8xf32_ = false;
  /// Sets `lower8x8xf32_` to `lower` (true when called without argument).
  TransposeLoweringOptions &lower8x8xf32(bool lower = true) {
    lower8x8xf32_ = lower;
    return *this;
  }
};
15734ff8573SNicolas Vasilache 
15834ff8573SNicolas Vasilache /// Options for controlling specialized AVX2 lowerings.
15934ff8573SNicolas Vasilache struct LoweringOptions {
16034ff8573SNicolas Vasilache   /// Configure specialized vector lowerings.
16134ff8573SNicolas Vasilache   TransposeLoweringOptions transposeOptions;
16234ff8573SNicolas Vasilache   LoweringOptions &setTransposeOptions(TransposeLoweringOptions options) {
16334ff8573SNicolas Vasilache     transposeOptions = options;
16434ff8573SNicolas Vasilache     return *this;
16534ff8573SNicolas Vasilache   }
16634ff8573SNicolas Vasilache };
16734ff8573SNicolas Vasilache 
16834ff8573SNicolas Vasilache /// Insert specialized transpose lowering patterns.
16934ff8573SNicolas Vasilache void populateSpecializedTransposeLoweringPatterns(
17034ff8573SNicolas Vasilache     RewritePatternSet &patterns, LoweringOptions options = LoweringOptions(),
17134ff8573SNicolas Vasilache     int benefit = 10);
17234ff8573SNicolas Vasilache 
17334ff8573SNicolas Vasilache } // namespace avx2
17434ff8573SNicolas Vasilache } // namespace x86vector
17534ff8573SNicolas Vasilache 
1768508a63bSEmilio Cota /// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM
1778508a63bSEmilio Cota /// intrinsics.
1788508a63bSEmilio Cota void populateX86VectorLegalizeForLLVMExportPatterns(
179*206fad0eSMatthias Springer     const LLVMTypeConverter &converter, RewritePatternSet &patterns);
1808508a63bSEmilio Cota 
1818508a63bSEmilio Cota /// Configure the target to support lowering X86Vector ops to ops that map to
1828508a63bSEmilio Cota /// LLVM intrinsics.
1838508a63bSEmilio Cota void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target);
1848508a63bSEmilio Cota 
1858508a63bSEmilio Cota } // namespace mlir
1868508a63bSEmilio Cota 
1878508a63bSEmilio Cota #endif // MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
188