//=- Transforms.h - X86Vector Dialect Transformation Entrypoints -*- C++ -*-=//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
#define MLIR_DIALECT_X86VECTOR_TRANSFORMS_H

#include "mlir/IR/Value.h"

namespace mlir {

class ImplicitLocOpBuilder;
class LLVMConversionTarget;
class LLVMTypeConverter;
class RewritePatternSet;

namespace x86vector {

/// Helper class to factor out the creation and extraction of masks from small
/// bit fields.
struct MaskHelper {
  /// b0 captures the lowest bit, b7 captures the highest bit.
  /// Meant to be used with instructions such as mm256BlendPs.
  template <uint8_t b0, uint8_t b1, uint8_t b2, uint8_t b3, uint8_t b4,
            uint8_t b5, uint8_t b6, uint8_t b7>
  static uint8_t blend() {
    static_assert(b0 <= 1 && b1 <= 1 && b2 <= 1 && b3 <= 1, "overflow");
    static_assert(b4 <= 1 && b5 <= 1 && b6 <= 1 && b7 <= 1, "overflow");
    return static_cast<uint8_t>((b7 << 7) | (b6 << 6) | (b5 << 5) | (b4 << 4) |
                                (b3 << 3) | (b2 << 2) | (b1 << 1) | b0);
  }
  /// b0 captures the lowest bit, b7 captures the highest bit.
  /// Meant to be used with instructions such as mm256BlendPs.
  static void extractBlend(uint8_t mask, uint8_t &b0, uint8_t &b1, uint8_t &b2,
                           uint8_t &b3, uint8_t &b4, uint8_t &b5, uint8_t &b6,
                           uint8_t &b7) {
    b7 = (mask & (1 << 7)) >> 7;
    b6 = (mask & (1 << 6)) >> 6;
    b5 = (mask & (1 << 5)) >> 5;
    b4 = (mask & (1 << 4)) >> 4;
    b3 = (mask & (1 << 3)) >> 3;
    b2 = (mask & (1 << 2)) >> 2;
    b1 = (mask & (1 << 1)) >> 1;
    b0 = mask & 1;
  }
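
  // Illustrative round-trip (a sketch, not part of the API): to take the
  // upper four f32s from the second operand,
  //   uint8_t m = MaskHelper::blend<0, 0, 0, 0, 1, 1, 1, 1>(); // m == 0xf0
  //   uint8_t b0, b1, b2, b3, b4, b5, b6, b7;
  //   MaskHelper::extractBlend(m, b0, b1, b2, b3, b4, b5, b6, b7);
  //   // Now b0..b3 == 0 and b4..b7 == 1.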
  /// b01 captures the lower 2 bits, b67 captures the higher 2 bits.
  /// Meant to be used with instructions such as mm256ShufflePs.
  template <unsigned b67, unsigned b45, unsigned b23, unsigned b01>
  static uint8_t shuffle() {
    static_assert(b01 <= 0x03, "overflow");
    static_assert(b23 <= 0x03, "overflow");
    static_assert(b45 <= 0x03, "overflow");
    static_assert(b67 <= 0x03, "overflow");
    return static_cast<uint8_t>((b67 << 6) | (b45 << 4) | (b23 << 2) | b01);
  }
  /// b01 captures the lower 2 bits, b67 captures the higher 2 bits.
  static void extractShuffle(uint8_t mask, uint8_t &b01, uint8_t &b23,
                             uint8_t &b45, uint8_t &b67) {
    b67 = (mask & (0x03 << 6)) >> 6;
    b45 = (mask & (0x03 << 4)) >> 4;
    b23 = (mask & (0x03 << 2)) >> 2;
    b01 = mask & 0x03;
  }
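
  // Illustrative example (a sketch): the "swap adjacent pairs" selection,
  //   uint8_t m = MaskHelper::shuffle<2, 3, 0, 1>(); // m == 0xb1
  // i.e. the 2-bit fields, from lowest to highest, pick elements 1, 0, 3, 2.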
  /// b03 captures the lower 4 bits, b47 captures the higher 4 bits.
  /// Meant to be used with instructions such as mm256Permute2f128Ps.
  template <unsigned b47, unsigned b03>
  static uint8_t permute() {
    static_assert(b03 <= 0x0f, "overflow");
    static_assert(b47 <= 0x0f, "overflow");
    return static_cast<uint8_t>((b47 << 4) | b03);
  }
  /// b03 captures the lower 4 bits, b47 captures the higher 4 bits.
  static void extractPermute(uint8_t mask, uint8_t &b03, uint8_t &b47) {
    b47 = (mask & (0x0f << 4)) >> 4;
    b03 = mask & 0x0f;
  }
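
  // Illustrative example (a sketch): for mm256Permute2f128Ps,
  //   uint8_t lo = MaskHelper::permute<2, 0>(); // lo == 0x20
  //   uint8_t hi = MaskHelper::permute<3, 1>(); // hi == 0x31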
};

//===----------------------------------------------------------------------===//
/// Helpers extracted from:
///   - clang/lib/Headers/avxintrin.h
///   - clang/test/CodeGen/X86/avx-builtins.c
///   - clang/test/CodeGen/X86/avx2-builtins.c
///   - clang/test/CodeGen/X86/avx-shuffle-builtins.c
/// as well as the Intel Intrinsics Guide
/// (https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html).
/// They make it easier to implement known-good lowerings.
/// All intrinsics correspond 1-1 to the Intel definitions.
//===----------------------------------------------------------------------===//

namespace avx2 {

namespace inline_asm {
//===----------------------------------------------------------------------===//
/// Methods in the inline_asm namespace emit calls to LLVM::InlineAsmOp.
//===----------------------------------------------------------------------===//
/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
Value mm256BlendPsAsm(ImplicitLocOpBuilder &b, Value v1, Value v2,
                      uint8_t mask);
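
// Hedged usage sketch (assumes `b` is an ImplicitLocOpBuilder and that v1, v2
// are values of type vector<8xf32>):
//   Value r = mm256BlendPsAsm(b, v1, v2,
//                             MaskHelper::blend<0, 0, 0, 0, 1, 1, 1, 1>());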

} // namespace inline_asm

namespace intrin {
//===----------------------------------------------------------------------===//
/// Methods in the intrin namespace emulate clang's implementation of X86
/// intrinsics.
//===----------------------------------------------------------------------===//
/// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13].
Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2);

/// Lower to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15].
Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
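
// For reference (illustrative IR, assuming vector<8xf32> operands), the
// unpack helpers emit shuffles such as:
//   %lo = vector.shuffle %v1, %v2 [0, 8, 1, 9, 4, 12, 5, 13]
//         : vector<8xf32>, vector<8xf32>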

/// Takes an 8-bit mask with 2 bits per result position. In each 128-bit half,
/// positions 0-1 select from v1 and positions 2-3 select from v2:
///                     0:127                |            128:255
///   v1[b01]  v1[b23]  v2[b45]  v2[b67]     |  same selections offset by 4
Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask);
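
// Worked example (a sketch): mask = MaskHelper::shuffle<3, 2, 1, 0>() (0xe4,
// the identity selection) yields
//   v1[0] v1[1] v2[2] v2[3] | v1[4] v1[5] v2[6] v2[7].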

/// Takes an 8-bit mask selecting one 128-bit half for each result half:
///   result[0:127]   is picked by mask[1:0]:
///     0 -> v1[0:127], 1 -> v1[128:255], 2 -> v2[0:127], 3 -> v2[128:255];
///   result[128:255] is picked by mask[5:4] with the same encoding.
Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2,
                          uint8_t mask);
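
// Worked example (a sketch): mask 0x20 concatenates the two low halves
// (v1[0:127] | v2[0:127]) and mask 0x31 the two high halves, a pair commonly
// used when combining 128-bit lanes in 8x8 transposes.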

/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
Value mm256BlendPs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask);
} // namespace intrin

//===----------------------------------------------------------------------===//
/// Generic lowerings may either use intrin or inline_asm depending on needs.
//===----------------------------------------------------------------------===//
/// 4x8xf32-specific AVX2 transpose lowering.
void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);

/// 8x8xf32-specific AVX2 transpose lowering.
void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
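
// Hedged usage sketch (assumes `ib` is an ImplicitLocOpBuilder and r0..r7 are
// hypothetical vector<8xf32> row values; the rows are transposed in place):
//   SmallVector<Value> vs = {r0, r1, r2, r3, r4, r5, r6, r7};
//   transpose8x8xf32(ib, vs);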

/// Structure to control the behavior of specialized AVX2 transpose lowering.
struct TransposeLoweringOptions {
  bool lower4x8xf32_ = false;
  TransposeLoweringOptions &lower4x8xf32(bool lower = true) {
    lower4x8xf32_ = lower;
    return *this;
  }
  bool lower8x8xf32_ = false;
  TransposeLoweringOptions &lower8x8xf32(bool lower = true) {
    lower8x8xf32_ = lower;
    return *this;
  }
};

/// Options for controlling specialized AVX2 lowerings.
struct LoweringOptions {
  /// Configure specialized vector lowerings.
  TransposeLoweringOptions transposeOptions;
  LoweringOptions &setTransposeOptions(TransposeLoweringOptions options) {
    transposeOptions = options;
    return *this;
  }
};

/// Insert specialized transpose lowering patterns.
void populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options = LoweringOptions(),
    int benefit = 10);
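
// Hedged usage sketch (assumes a RewritePatternSet `patterns` built inside a
// pass):
//   populateSpecializedTransposeLoweringPatterns(
//       patterns,
//       LoweringOptions().setTransposeOptions(
//           TransposeLoweringOptions().lower4x8xf32().lower8x8xf32()));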

} // namespace avx2
} // namespace x86vector

/// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM
/// intrinsics.
void populateX86VectorLegalizeForLLVMExportPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns);

/// Configure the target to support lowering X86Vector ops to ops that map to
/// LLVM intrinsics.
void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target);
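
// Hedged usage sketch (assumes an LLVM conversion pass with `converter`,
// `patterns`, and `target` already set up):
//   populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns);
//   configureX86VectorLegalizeForExportTarget(target);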

} // namespace mlir

#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMS_H