xref: /llvm-project/mlir/include/mlir/Dialect/X86Vector/X86Vector.td (revision 87782b216fd3e7a8f8b2de04d4af467b390e9a34)
//===-- X86Vector.td - X86Vector dialect operation defs ----*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the basic operations for the X86Vector dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef X86VECTOR_OPS
14#define X86VECTOR_OPS
15
16include "mlir/Interfaces/InferTypeOpInterface.td"
17include "mlir/Interfaces/SideEffectInterfaces.td"
18include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
19
20//===----------------------------------------------------------------------===//
21// X86Vector dialect definition
22//===----------------------------------------------------------------------===//
23
// Dialect record for X86Vector. Ops defined against it use the `x86vector`
// name prefix and the `::mlir::x86vector` C++ namespace.
def X86Vector_Dialect : Dialect {
  let cppNamespace = "::mlir::x86vector";
  let name = "x86vector";
}
28
29//===----------------------------------------------------------------------===//
30// AVX512 op definitions
31//===----------------------------------------------------------------------===//
32
// Base class for AVX512 ops that belong to the input dialect. The op name is
// formed by prefixing `mnemonic` with `avx512.`.
class AVX512_Op<string mnemonic, list<Trait> traits = []>
    : Op<X86Vector_Dialect, !strconcat("avx512.", mnemonic), traits>;
36
// Base class for AVX512 intrinsic ops used while lowering to LLVM IR.
//   mnemonic   - intrinsic suffix; the op name becomes `avx512.intr.<mnemonic>`.
//   numResults - number of results of the op.
//   traits     - extra op traits.
//   extension  - optional ISA extension tag spliced in right after
//                `x86_avx512` in the intrinsic identifier (empty by default).
// The intrinsic identifier is `x86_avx512<extension>_<mnemonic>` with every
// `.` in the mnemonic turned into `_`. No result or operand is overloaded.
class AVX512_IntrOp<string mnemonic, int numResults,
                    list<Trait> traits = [],
                    string extension = "">
    : LLVM_IntrOpBase<
        X86Vector_Dialect,
        !strconcat("avx512.intr.", mnemonic),
        !strconcat("x86_avx512", extension, "_", !subst(".", "_", mnemonic)),
        /*list<int> overloadedResults=*/[],
        /*list<int> overloadedOperands=*/[],
        traits,
        numResults>;
44
// Base class for single-result AVX512 intrinsic ops that are overloaded on
// the type of result #0. Other overload patterns are not handled here and may
// need to be added for future instructions.
// The intrinsic identifier is `x86_avx512<extension>_<mnemonic>` with every
// `.` in the mnemonic turned into `_`.
class AVX512_IntrOverloadedOp<string mnemonic,
                              list<Trait> traits = [],
                              string extension = "">
    : LLVM_IntrOpBase<
        X86Vector_Dialect,
        !strconcat("avx512.intr.", mnemonic),
        !strconcat("x86_avx512", extension, "_", !subst(".", "_", mnemonic)),
        /*list<int> overloadedResults=*/[0],
        /*list<int> overloadedOperands=*/[],
        traits,
        /*numResults=*/1>;
55
56//----------------------------------------------------------------------------//
57// MaskCompressOp
58//----------------------------------------------------------------------------//
59
// User-facing masked compress op.
//   k            - i1 mask vector; its type is derived from `dst` by the
//                  TypesMatchWith trait below.
//   a            - input vector; must have the same type as `dst`.
//   src          - optional pass-through vector.
//   constant_src - optional elements attribute usable instead of `src`.
// `hasVerifier = 1` declares a custom verifier implemented outside this file.
def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
  // TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could
  // then be removed from assemblyFormat.
  AllTypesMatch<["a", "dst"]>,
  TypesMatchWith<"`k` has the same number of bits as elements in `dst`",
                 "dst", "k",
                 "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
                 "IntegerType::get($_self.getContext(), 1))">]> {
  let summary = "Masked compress op";
  let description = [{
  The mask.compress op is an AVX512 specific op that can lower to the
  `llvm.mask.compress` instruction. Instead of `src`, a constant vector
  attribute `constant_src` may be specified. If neither `src` nor
  `constant_src` is specified, the remaining elements in the result vector are
  set to zero.

  #### From the Intel Intrinsics Guide:

  Contiguously store the active integer/floating-point elements in `a` (those
  with their respective bit set in writemask `k`) to `dst`, and pass through the
  remaining elements from `src`.
  }];
  let arguments = (ins VectorOfLengthAndType<[16, 8],
                                             [I1]>:$k,
                   VectorOfLengthAndType<[16, 8],
                                         [F32, I32, F64, I64]>:$a,
                   Optional<VectorOfLengthAndType<[16, 8],
                                                  [F32, I32, F64, I64]>>:$src,
                   OptionalAttr<ElementsAttr>:$constant_src);
  let results = (outs VectorOfLengthAndType<[16, 8],
                                            [F32, I32, F64, I64]>:$dst);
  // The type of `src` is printed only when `src` is present.
  let assemblyFormat = "$k `,` $a (`,` $src^)? attr-dict"
                       " `:` type($dst) (`,` type($src)^)?";
  let hasVerifier = 1;
}
95
// Intrinsic form of masked compress, overloaded on its result type. Here
// `src` is a mandatory operand and the i1 mask vector `k` is the last operand.
def MaskCompressIntrOp
    : AVX512_IntrOverloadedOp<"mask.compress", [
        Pure,
        AllTypesMatch<["a", "src", "res"]>,
        TypesMatchWith<
          "`k` has the same number of bits as elements in `res`",
          "res", "k",
          "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
          "IntegerType::get($_self.getContext(), 1))">]> {
  let arguments = (ins
    VectorOfLengthAndType<[16, 8], [F32, I32, F64, I64]>:$a,
    VectorOfLengthAndType<[16, 8], [F32, I32, F64, I64]>:$src,
    VectorOfLengthAndType<[16, 8], [I1]>:$k);
}
110
111//----------------------------------------------------------------------------//
112// MaskRndScaleOp
113//----------------------------------------------------------------------------//
114
// User-facing masked roundscale op. `src`, `a` and `dst` share one type, and
// the integer type of `imm` is derived from the leading dimension of `dst`.
// NOTE(review): `k` is declared as a plain I32 while the width-tied operand is
// `imm`; confirm the operand naming against the intrinsic signature.
def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [
    Pure,
    AllTypesMatch<["src", "a", "dst"]>,
    TypesMatchWith<
      "imm has the same number of bits as elements in dst",
      "dst", "imm",
      "IntegerType::get($_self.getContext(), "
      "(::llvm::cast<VectorType>($_self).getShape()[0]))">]> {
  // Supports vector<16xf32> and vector<8xf64>.
  let arguments = (ins
    VectorOfLengthAndType<[16, 8], [F32, F64]>:$src,
    I32:$k,
    VectorOfLengthAndType<[16, 8], [F32, F64]>:$a,
    AnyTypeOf<[I16, I8]>:$imm,
    // TODO: figure rounding out (optional operand?).
    I32:$rounding);
  let results = (outs VectorOfLengthAndType<[16, 8], [F32, F64]>:$dst);
  // Only the result type is spelled out; the remaining types follow from the
  // traits above and the fixed I32 operands.
  let assemblyFormat =
    "$src `,` $k `,` $a `,` $imm `,` $rounding attr-dict"
    " `:` type($dst)";
  let summary = "Masked roundscale op";
  let description = [{
    The mask.rndscale op is an AVX512 specific op that can lower to the proper
    LLVMAVX512 operation: `llvm.mask.rndscale.ps.512` or
    `llvm.mask.rndscale.pd.512` instruction depending on the type of vectors it
    is applied to.

    #### From the Intel Intrinsics Guide:

    Round packed floating-point elements in `a` to the number of fraction bits
    specified by `imm`, and store the results in `dst` using writemask `k`
    (elements are copied from src when the corresponding mask bit is not set).
  }];
}
146
// Intrinsic op for the 16 x f32 roundscale variant (`imm` is I16).
def MaskRndScalePSIntrOp
    : AVX512_IntrOp<"mask.rndscale.ps.512", /*numResults=*/1,
                    [Pure, AllTypesMatch<["src", "a", "res"]>]> {
  let arguments = (ins
    VectorOfLengthAndType<[16], [F32]>:$src,
    I32:$k,
    VectorOfLengthAndType<[16], [F32]>:$a,
    I16:$imm,
    LLVM_Type:$rounding);
}
156
// Intrinsic op for the 8 x f64 roundscale variant (`imm` is I8).
def MaskRndScalePDIntrOp
    : AVX512_IntrOp<"mask.rndscale.pd.512", /*numResults=*/1,
                    [Pure, AllTypesMatch<["src", "a", "res"]>]> {
  let arguments = (ins
    VectorOfLengthAndType<[8], [F64]>:$src,
    I32:$k,
    VectorOfLengthAndType<[8], [F64]>:$a,
    I8:$imm,
    LLVM_Type:$rounding);
}
166
167//----------------------------------------------------------------------------//
168// MaskScaleFOp
169//----------------------------------------------------------------------------//
170
// User-facing masked scalef op. `src`, `a`, `b` and `dst` share one type, and
// the integer type of the mask `k` is derived from the leading dimension of
// `dst`.
def MaskScaleFOp : AVX512_Op<"mask.scalef", [
    Pure,
    AllTypesMatch<["src", "a", "b", "dst"]>,
    TypesMatchWith<
      "k has the same number of bits as elements in dst",
      "dst", "k",
      "IntegerType::get($_self.getContext(), "
      "(::llvm::cast<VectorType>($_self).getShape()[0]))">]> {
  // Supports vector<16xf32> and vector<8xf64>.
  let arguments = (ins
    VectorOfLengthAndType<[16, 8], [F32, F64]>:$src,
    VectorOfLengthAndType<[16, 8], [F32, F64]>:$a,
    VectorOfLengthAndType<[16, 8], [F32, F64]>:$b,
    AnyTypeOf<[I16, I8]>:$k,
    // TODO: figure rounding out (optional operand?).
    I32:$rounding);
  let results = (outs VectorOfLengthAndType<[16, 8], [F32, F64]>:$dst);
  // Fully specified by traits.
  let assemblyFormat =
    "$src `,` $a `,` $b `,` $k `,` $rounding attr-dict"
    " `:` type($dst)";
  let summary = "ScaleF op";
  let description = [{
    The `mask.scalef` op is an AVX512 specific op that can lower to the proper
    LLVMAVX512 operation: `llvm.mask.scalef.ps.512` or
    `llvm.mask.scalef.pd.512` depending on the type of MLIR vectors it is
    applied to.

    #### From the Intel Intrinsics Guide:

    Scale the packed floating-point elements in `a` using values from `b`, and
    store the results in `dst` using writemask `k` (elements are copied from src
    when the corresponding mask bit is not set).
  }];
}
203
// Intrinsic op for the 16 x f32 scalef variant (mask `k` is I16).
def MaskScaleFPSIntrOp
    : AVX512_IntrOp<"mask.scalef.ps.512", /*numResults=*/1,
                    [Pure, AllTypesMatch<["src", "a", "b", "res"]>]> {
  let arguments = (ins
    VectorOfLengthAndType<[16], [F32]>:$src,
    VectorOfLengthAndType<[16], [F32]>:$a,
    VectorOfLengthAndType<[16], [F32]>:$b,
    I16:$k,
    LLVM_Type:$rounding);
}
213
// Intrinsic op for the 8 x f64 scalef variant (mask `k` is I8).
def MaskScaleFPDIntrOp
    : AVX512_IntrOp<"mask.scalef.pd.512", /*numResults=*/1,
                    [Pure, AllTypesMatch<["src", "a", "b", "res"]>]> {
  let arguments = (ins
    VectorOfLengthAndType<[8], [F64]>:$src,
    VectorOfLengthAndType<[8], [F64]>:$a,
    VectorOfLengthAndType<[8], [F64]>:$b,
    I8:$k,
    LLVM_Type:$rounding);
}
223
224//----------------------------------------------------------------------------//
225// Vp2IntersectOp
226//----------------------------------------------------------------------------//
227
// User-facing vp2intersect op: two integer vectors of the same type in, two
// i1 mask vectors out. Both mask types are derived from the type of `a`, the
// only type that appears in the assembly format.
def Vp2IntersectOp : AVX512_Op<"vp2intersect", [
    Pure,
    AllTypesMatch<["a", "b"]>,
    TypesMatchWith<
      "k1 has the same number of bits as elements in a",
      "a", "k1",
      "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
      "IntegerType::get($_self.getContext(), 1))">,
    TypesMatchWith<
      "k2 has the same number of bits as elements in b",
      // Should use `b` instead of `a`, but that would require
      // adding `type($b)` to assemblyFormat.
      "a", "k2",
      "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
      "IntegerType::get($_self.getContext(), 1))">]> {
  let arguments = (ins
    VectorOfLengthAndType<[16, 8], [I32, I64]>:$a,
    VectorOfLengthAndType<[16, 8], [I32, I64]>:$b);
  let results = (outs
    VectorOfLengthAndType<[16, 8], [I1]>:$k1,
    VectorOfLengthAndType<[16, 8], [I1]>:$k2);
  let assemblyFormat = "$a `,` $b attr-dict `:` type($a)";
  let summary = "Vp2Intersect op";
  let description = [{
    The `vp2intersect` op is an AVX512 specific op that can lower to the proper
    LLVMAVX512 operation: `llvm.vp2intersect.d.512` or
    `llvm.vp2intersect.q.512` depending on the type of MLIR vectors it is
    applied to.

    #### From the Intel Intrinsics Guide:

    Compute intersection of packed integer vectors `a` and `b`, and store
    indication of match in the corresponding bit of two mask registers
    specified by `k1` and `k2`. A match in corresponding elements of `a` and
    `b` is indicated by a set bit in the corresponding bit of the mask
    registers.
  }];
}
264
// Two-result intrinsic op for vp2intersect on 16 x i32 vectors.
def Vp2IntersectDIntrOp
    : AVX512_IntrOp<"vp2intersect.d.512", /*numResults=*/2, [Pure]> {
  let arguments = (ins
    VectorOfLengthAndType<[16], [I32]>:$a,
    VectorOfLengthAndType<[16], [I32]>:$b);
}
270
// Two-result intrinsic op for vp2intersect on 8 x i64 vectors.
def Vp2IntersectQIntrOp
    : AVX512_IntrOp<"vp2intersect.q.512", /*numResults=*/2, [Pure]> {
  let arguments = (ins
    VectorOfLengthAndType<[8], [I64]>:$a,
    VectorOfLengthAndType<[8], [I64]>:$b);
}
276
277//----------------------------------------------------------------------------//
278// Dot BF16
279//----------------------------------------------------------------------------//
280
// User-facing BF16 dot-product op.
//   src - f32 accumulator vector; must have the same type as `dst`.
//   a/b - bf16 input vectors of one shared type. The type of `a` is derived
//         from `src` as a bf16 vector with twice the leading dimension.
def DotBF16Op : AVX512_Op<"dot", [Pure,
  AllTypesMatch<["a", "b"]>,
  AllTypesMatch<["src", "dst"]>,
  TypesMatchWith<"`a` has twice as many elements as `src`",
                 "src", "a",
                 "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 2}, "
                 "BFloat16Type::get($_self.getContext()))">]> {
  let summary = "Dot BF16 op";
  let description = [{
    The `dot` op is an AVX512-BF16 specific op that can lower to the proper
    LLVMAVX512BF16 operation `llvm.dpbf16ps` depending on the width of MLIR
    vectors it is applied to.

    #### From the Intel Intrinsics Guide:

    Compute dot-product of BF16 (16-bit) floating-point pairs in `a` and `b`,
    accumulating the intermediate single-precision (32-bit) floating-point
    elements with elements in `src`, and store the results in `dst`.

    Example:
    ```mlir
    %0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
    ```
  }];
  let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
                   VectorOfLengthAndType<[8, 16, 32], [BF16]>:$a,
                   VectorOfLengthAndType<[8, 16, 32], [BF16]>:$b
                   );
  let results = (outs VectorOfLengthAndType<[4, 8, 16], [F32]>:$dst);
  // Both the bf16 input type and the f32 accumulator type are printed.
  let assemblyFormat =
    "$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
}
313
// Intrinsic op for the BF16 dot product with a 4 x f32 accumulator and
// 8 x bf16 inputs; uses the "bf16" extension tag.
def DotBF16Ps128IntrOp
    : AVX512_IntrOp<"dpbf16ps.128", /*numResults=*/1,
                    [Pure,
                     AllTypesMatch<["a", "b"]>,
                     AllTypesMatch<["src", "res"]>],
                    /*extension=*/"bf16"> {
  let arguments = (ins
    VectorOfLengthAndType<[4], [F32]>:$src,
    VectorOfLengthAndType<[8], [BF16]>:$a,
    VectorOfLengthAndType<[8], [BF16]>:$b);
  let results = (outs VectorOfLengthAndType<[4], [F32]>:$res);
}
323
// Intrinsic op for the BF16 dot product with an 8 x f32 accumulator and
// 16 x bf16 inputs; uses the "bf16" extension tag.
def DotBF16Ps256IntrOp
    : AVX512_IntrOp<"dpbf16ps.256", /*numResults=*/1,
                    [Pure,
                     AllTypesMatch<["a", "b"]>,
                     AllTypesMatch<["src", "res"]>],
                    /*extension=*/"bf16"> {
  let arguments = (ins
    VectorOfLengthAndType<[8], [F32]>:$src,
    VectorOfLengthAndType<[16], [BF16]>:$a,
    VectorOfLengthAndType<[16], [BF16]>:$b);
  let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
}
333
// Intrinsic op for the BF16 dot product with a 16 x f32 accumulator and
// 32 x bf16 inputs; uses the "bf16" extension tag.
def DotBF16Ps512IntrOp
    : AVX512_IntrOp<"dpbf16ps.512", /*numResults=*/1,
                    [Pure,
                     AllTypesMatch<["a", "b"]>,
                     AllTypesMatch<["src", "res"]>],
                    /*extension=*/"bf16"> {
  let arguments = (ins
    VectorOfLengthAndType<[16], [F32]>:$src,
    VectorOfLengthAndType<[32], [BF16]>:$a,
    VectorOfLengthAndType<[32], [BF16]>:$b);
  let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
}
343
344//===----------------------------------------------------------------------===//
345// AVX op definitions
346//===----------------------------------------------------------------------===//
347
// Base class for AVX ops that belong to the input dialect. The op name is
// formed by prefixing `mnemonic` with `avx.`.
class AVX_Op<string mnemonic, list<Trait> traits = []>
    : Op<X86Vector_Dialect, !strconcat("avx.", mnemonic), traits>;
351
// Base class for AVX ops that sit between the user-level view of an operation
// and the actual low-level LLVM IR intrinsic; such ops may still appear in the
// input dialect. They are named under the `avx.intr.` prefix.
class AVX_LowOp<string mnemonic, list<Trait> traits = []>
    : Op<X86Vector_Dialect, !strconcat("avx.intr.", mnemonic), traits>;
357
// Base class for AVX intrinsic ops used while lowering to LLVM IR. The op is
// named `avx.intr.<mnemonic>`, and the intrinsic identifier is
// `x86_avx_<mnemonic>` with every `.` in the mnemonic turned into `_`.
// No result or operand is overloaded.
class AVX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []>
    : LLVM_IntrOpBase<
        X86Vector_Dialect,
        !strconcat("avx.intr.", mnemonic),
        !strconcat("x86_avx_", !subst(".", "_", mnemonic)),
        /*list<int> overloadedResults=*/[],
        /*list<int> overloadedOperands=*/[],
        traits,
        numResults>;
363
364//----------------------------------------------------------------------------//
365// AVX Rsqrt
366//----------------------------------------------------------------------------//
367
// User-facing rsqrt op on an 8 x f32 vector. Operand and result share one
// type (SameOperandsAndResultType), so only the operand type is printed.
def RsqrtOp : AVX_Op<"rsqrt", [Pure, SameOperandsAndResultType]> {
  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
  let results = (outs VectorOfLengthAndType<[8], [F32]>:$b);
  let assemblyFormat = "$a attr-dict `:` type($a)";
  let summary = "Rsqrt";
}
374
// Single-result intrinsic op for rsqrt on an 8 x f32 vector.
def RsqrtIntrOp
    : AVX_IntrOp<"rsqrt.ps.256", /*numResults=*/1,
                 [Pure, SameOperandsAndResultType]> {
  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
}
379
380//----------------------------------------------------------------------------//
381// AVX Dot
382//----------------------------------------------------------------------------//
383
// Intermediate-level AVX dot op (registered as `avx.intr.dot`) on two
// 8 x f32 vectors; operands and result all share one type.
def DotOp : AVX_LowOp<"dot", [Pure, SameOperandsAndResultType]> {
  let arguments = (ins
    VectorOfLengthAndType<[8], [F32]>:$a,
    VectorOfLengthAndType<[8], [F32]>:$b);
  let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
  let assemblyFormat = "$a `,` $b attr-dict `:` type($res)";
  let summary = "Dot";
  let description = [{
    Computes the 4-way dot products of the lower and higher parts of the source
    vectors and broadcasts the two results to the lower and higher elements of
    the destination vector, respectively. Adding one element of the lower part
    to one element of the higher part in the destination vector yields the full
    dot product of the two source vectors.

    Example:

    ```mlir
    %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
    %1 = vector.extractelement %0[%i0 : i32]: vector<8xf32>
    %2 = vector.extractelement %0[%i4 : i32]: vector<8xf32>
    %d = arith.addf %1, %2 : f32
    ```
  }];
}
407
// Single-result intrinsic op for the AVX dot product on 8 x f32 vectors. It
// takes an extra I8 operand `c` on top of the two vector inputs.
def DotIntrOp
    : AVX_IntrOp<"dp.ps.256", /*numResults=*/1,
                 [Pure, AllTypesMatch<["a", "b", "res"]>]> {
  let arguments = (ins
    VectorOfLengthAndType<[8], [F32]>:$a,
    VectorOfLengthAndType<[8], [F32]>:$b,
    I8:$c);
  let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
}
414
415#endif // X86VECTOR_OPS
416