//===-- X86VectorOps.td - X86Vector dialect operation defs -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the basic operations for the X86Vector dialect.
//
//===----------------------------------------------------------------------===//

#ifndef X86VECTOR_OPS
#define X86VECTOR_OPS

include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"

//===----------------------------------------------------------------------===//
// X86Vector dialect definition
//===----------------------------------------------------------------------===//

// The dialect record. Its ops are spelled with the `x86vector` prefix and the
// C++ namespace is set to `::mlir::x86vector`.
def X86Vector_Dialect : Dialect {
  let name = "x86vector";
  let cppNamespace = "::mlir::x86vector";
}

//===----------------------------------------------------------------------===//
// AVX512 op definitions
//===----------------------------------------------------------------------===//

// Base class for AVX512 operations that belong to the input dialect. The
// mnemonic is prefixed with `avx512.`.
class AVX512_Op<string mnemonic, list<Trait> traits = []>
    : Op<X86Vector_Dialect, "avx512." # mnemonic, traits>;

// Base class for AVX512 intrinsic operations used while lowering to LLVM IR.
// The op mnemonic is prefixed with `avx512.intr.`. The third template
// argument is assembled as `x86_avx512<extension>_<mnemonic>` with every `.`
// of the mnemonic replaced by `_`; e.g. mnemonic `dpbf16ps.128` with extension
// `bf16` yields `x86_avx512bf16_dpbf16ps_128`.
class AVX512_IntrOp<string mnemonic,
                    int numResults,
                    list<Trait> traits = [],
                    string extension = "">
    : LLVM_IntrOpBase<
          X86Vector_Dialect,
          "avx512.intr." # mnemonic,
          !subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
          /*overloadedResults=*/[],
          /*overloadedOperands=*/[],
          traits,
          numResults>;

// Defined by first result overload. May have to be extended for other
// instructions in the future.
class AVX512_IntrOverloadedOp<string mnemonic,
                              list<Trait> traits = [],
                              string extension = ""> :
  LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
                  !subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
                  /*list<int> overloadedResults=*/[0],
                  /*list<int> overloadedOperands=*/[],
                  traits, /*numResults=*/1>;

//----------------------------------------------------------------------------//
// MaskCompressOp
//----------------------------------------------------------------------------//

// High-level masked compress. `src` is optional and `constant_src` is an
// optional attribute; `hasVerifier = 1` requests an additional hand-written
// verifier whose checks are not visible in this file.
def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
  // TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could
  // then be removed from assemblyFormat.
  AllTypesMatch<["a", "dst"]>,
  // The type of `k` is derived from `dst`: a vector of i1 whose length is the
  // leading dimension of `dst`.
  TypesMatchWith<"`k` has the same number of bits as elements in `dst`",
                 "dst", "k",
                 "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
                 "IntegerType::get($_self.getContext(), 1))">]> {
  let summary = "Masked compress op";
  let description = [{
    The mask.compress op is an AVX512 specific op that can lower to the
    `llvm.mask.compress` instruction. Instead of `src`, a constant vector
    attribute `constant_src` may be specified. If neither `src` nor
    `constant_src` is specified, the remaining elements in the result vector are
    set to zero.

    #### From the Intel Intrinsics Guide:

    Contiguously store the active integer/floating-point elements in `a` (those
    with their respective bit set in writemask `k`) to `dst`, and pass through the
    remaining elements from `src`.
  }];
  let arguments = (ins VectorOfLengthAndType<[16, 8],
                                             [I1]>:$k,
                   VectorOfLengthAndType<[16, 8],
                                         [F32, I32, F64, I64]>:$a,
                   Optional<VectorOfLengthAndType<[16, 8],
                                                  [F32, I32, F64, I64]>>:$src,
                   OptionalAttr<ElementsAttr>:$constant_src);
  let results = (outs VectorOfLengthAndType<[16, 8],
                                            [F32, I32, F64, I64]>:$dst);
  // Only type($dst) is always printed; type($src) is printed when `src` is
  // present (see the TODO on the trait list).
  let assemblyFormat = "$k `,` $a (`,` $src^)? attr-dict"
                       " `:` type($dst) (`,` type($src)^)?";
  let hasVerifier = 1;
}

// Intrinsic-level counterpart of MaskCompressOp. Unlike the op above, `src` is
// a mandatory operand and the mask `k` comes last.
def MaskCompressIntrOp : AVX512_IntrOverloadedOp<"mask.compress", [
  Pure,
  AllTypesMatch<["a", "src", "res"]>,
  TypesMatchWith<"`k` has the same number of bits as elements in `res`",
                 "res", "k",
                 "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
                 "IntegerType::get($_self.getContext(), 1))">]> {
  let arguments = (ins VectorOfLengthAndType<[16, 8],
                                             [F32, I32, F64, I64]>:$a,
                   VectorOfLengthAndType<[16, 8],
                                         [F32, I32, F64, I64]>:$src,
                   VectorOfLengthAndType<[16, 8],
                                         [I1]>:$k);
}

//----------------------------------------------------------------------------//
// MaskRndScaleOp
//----------------------------------------------------------------------------//

// NOTE(review): the trait below ties the bit width of `imm` (i16 or i8) to the
// element count of `dst`, which is what one would expect of a per-element
// writemask, while `k` is a plain i32. The operand names therefore look
// swapped relative to the Intel text quoted in `description` -- confirm
// against the lowering before relying on the names.
def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
  AllTypesMatch<["src", "a", "dst"]>,
  // The type of `imm` is derived from `dst`: a signless integer with one bit
  // per element of the leading dimension of `dst`.
  TypesMatchWith<"imm has the same number of bits as elements in dst",
                 "dst", "imm",
                 "IntegerType::get($_self.getContext(), "
                 "(::llvm::cast<VectorType>($_self).getShape()[0]))">]> {
  let summary = "Masked roundscale op";
  let description = [{
    The mask.rndscale op is an AVX512 specific op that can lower to the proper
    LLVMAVX512 operation: `llvm.mask.rndscale.ps.512` or
    `llvm.mask.rndscale.pd.512` instruction depending on the type of vectors it
    is applied to.

    #### From the Intel Intrinsics Guide:

    Round packed floating-point elements in `a` to the number of fraction bits
    specified by `imm`, and store the results in `dst` using writemask `k`
    (elements are copied from src when the corresponding mask bit is not set).
  }];
  // Supports vector<16xf32> and vector<8xf64>.
  let arguments = (ins VectorOfLengthAndType<[16, 8], [F32, F64]>:$src,
                   I32:$k,
                   VectorOfLengthAndType<[16, 8], [F32, F64]>:$a,
                   AnyTypeOf<[I16, I8]>:$imm,
                   // TODO: figure rounding out (optional operand?).
                   I32:$rounding
            );
  let results = (outs VectorOfLengthAndType<[16, 8], [F32, F64]>:$dst);
  // Only type($dst) is spelled out; the remaining types are fixed by the
  // argument constraints or derived through the traits above.
  let assemblyFormat =
    "$src `,` $k `,` $a `,` $imm `,` $rounding attr-dict `:` type($dst)";
}

// f32 intrinsic variant: 16-element vectors, i16 `imm`.
def MaskRndScalePSIntrOp : AVX512_IntrOp<"mask.rndscale.ps.512", 1, [
  Pure,
  AllTypesMatch<["src", "a", "res"]>]> {
  let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
                   I32:$k,
                   VectorOfLengthAndType<[16], [F32]>:$a,
                   I16:$imm,
                   LLVM_Type:$rounding);
}

// f64 intrinsic variant: 8-element vectors, i8 `imm`.
def MaskRndScalePDIntrOp : AVX512_IntrOp<"mask.rndscale.pd.512", 1, [
  Pure,
  AllTypesMatch<["src", "a", "res"]>]> {
  let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src,
                   I32:$k,
                   VectorOfLengthAndType<[8], [F64]>:$a,
                   I8:$imm,
                   LLVM_Type:$rounding);
}

//----------------------------------------------------------------------------//
// MaskScaleFOp
//----------------------------------------------------------------------------//

def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
  AllTypesMatch<["src", "a", "b", "dst"]>,
  // The type of `k` is derived from `dst`: a signless integer with one bit per
  // element of the leading dimension of `dst`.
  TypesMatchWith<"k has the same number of bits as elements in dst",
                 "dst", "k",
                 "IntegerType::get($_self.getContext(), "
                 "(::llvm::cast<VectorType>($_self).getShape()[0]))">]> {
  let summary = "ScaleF op";
  let description = [{
    The `mask.scalef` op is an AVX512 specific op that can lower to the proper
    LLVMAVX512 operation: `llvm.mask.scalef.ps.512` or
    `llvm.mask.scalef.pd.512` depending on the type of MLIR vectors it is
    applied to.

    #### From the Intel Intrinsics Guide:

    Scale the packed floating-point elements in `a` using values from `b`, and
    store the results in `dst` using writemask `k` (elements are copied from src
    when the corresponding mask bit is not set).
  }];
  // Supports vector<16xf32> and vector<8xf64>.
  let arguments = (ins VectorOfLengthAndType<[16, 8], [F32, F64]>:$src,
                   VectorOfLengthAndType<[16, 8], [F32, F64]>:$a,
                   VectorOfLengthAndType<[16, 8], [F32, F64]>:$b,
                   AnyTypeOf<[I16, I8]>:$k,
                   // TODO: figure rounding out (optional operand?).
                   I32:$rounding
            );
  let results = (outs VectorOfLengthAndType<[16, 8], [F32, F64]>:$dst);
  // Fully specified by traits.
  let assemblyFormat =
    "$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)";
}

// f32 intrinsic variant: 16-element vectors, i16 mask `k`.
def MaskScaleFPSIntrOp : AVX512_IntrOp<"mask.scalef.ps.512", 1, [
  Pure,
  AllTypesMatch<["src", "a", "b", "res"]>]> {
  let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
                   VectorOfLengthAndType<[16], [F32]>:$a,
                   VectorOfLengthAndType<[16], [F32]>:$b,
                   I16:$k,
                   LLVM_Type:$rounding);
}

// f64 intrinsic variant: 8-element vectors, i8 mask `k`.
def MaskScaleFPDIntrOp : AVX512_IntrOp<"mask.scalef.pd.512", 1, [
  Pure,
  AllTypesMatch<["src", "a", "b", "res"]>]> {
  let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src,
                   VectorOfLengthAndType<[8], [F64]>:$a,
                   VectorOfLengthAndType<[8], [F64]>:$b,
                   I8:$k,
                   LLVM_Type:$rounding);
}

//----------------------------------------------------------------------------//
// Vp2IntersectOp
//----------------------------------------------------------------------------//

// Two-result op: both mask results are vectors of i1 whose length is the
// leading dimension of `a` (and, through AllTypesMatch, of `b`).
def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
  AllTypesMatch<["a", "b"]>,
  TypesMatchWith<"k1 has the same number of bits as elements in a",
                 "a", "k1",
                 "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
                 "IntegerType::get($_self.getContext(), 1))">,
  TypesMatchWith<"k2 has the same number of bits as elements in b",
                 // Should use `b` instead of `a`, but that would require
                 // adding `type($b)` to assemblyFormat.
                 "a", "k2",
                 "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
                 "IntegerType::get($_self.getContext(), 1))">]> {
  let summary = "Vp2Intersect op";
  let description = [{
    The `vp2intersect` op is an AVX512 specific op that can lower to the proper
    LLVMAVX512 operation: `llvm.vp2intersect.d.512` or
    `llvm.vp2intersect.q.512` depending on the type of MLIR vectors it is
    applied to.

    #### From the Intel Intrinsics Guide:

    Compute intersection of packed integer vectors `a` and `b`, and store
    indication of match in the corresponding bit of two mask registers
    specified by `k1` and `k2`. A match in corresponding elements of `a` and
    `b` is indicated by a set bit in the corresponding bit of the mask
    registers.
  }];
  let arguments = (ins VectorOfLengthAndType<[16, 8], [I32, I64]>:$a,
                   VectorOfLengthAndType<[16, 8], [I32, I64]>:$b
                   );
  let results = (outs VectorOfLengthAndType<[16, 8], [I1]>:$k1,
                 VectorOfLengthAndType<[16, 8], [I1]>:$k2
                 );
  let assemblyFormat =
    "$a `,` $b attr-dict `:` type($a)";
}

// i32 intrinsic variant (16 elements); declared with two results.
def Vp2IntersectDIntrOp : AVX512_IntrOp<"vp2intersect.d.512", 2, [
  Pure]> {
  let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$a,
                   VectorOfLengthAndType<[16], [I32]>:$b);
}

// i64 intrinsic variant (8 elements); declared with two results.
def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
  Pure]> {
  let arguments = (ins VectorOfLengthAndType<[8], [I64]>:$a,
                   VectorOfLengthAndType<[8], [I64]>:$b);
}

//----------------------------------------------------------------------------//
// Dot BF16
//----------------------------------------------------------------------------//

// NOTE(review): "twice an many" in the diagnostic below is a typo for "twice
// as many". The string is left untouched because it is the message emitted on
// a type mismatch and existing tests may match on it.
def DotBF16Op : AVX512_Op<"dot", [Pure,
  AllTypesMatch<["a", "b"]>,
  AllTypesMatch<["src", "dst"]>,
  // The type of `a` is derived from `src`: a bf16 vector with twice the
  // leading dimension of `src`.
  TypesMatchWith<"`a` has twice an many elements as `src`",
                 "src", "a",
                 "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 2}, "
                 "BFloat16Type::get($_self.getContext()))">]> {
  let summary = "Dot BF16 op";
  let description = [{
    The `dot` op is an AVX512-BF16 specific op that can lower to the proper
    LLVMAVX512BF16 operation `llvm.dpbf16ps` depending on the width of MLIR
    vectors it is applied to.

    #### From the Intel Intrinsics Guide:

    Compute dot-product of BF16 (16-bit) floating-point pairs in `a` and `b`,
    accumulating the intermediate single-precision (32-bit) floating-point
    elements with elements in `src`, and store the results in `dst`.

    Example:
    ```mlir
    %0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
    ```
  }];
  let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
                   VectorOfLengthAndType<[8, 16, 32], [BF16]>:$a,
                   VectorOfLengthAndType<[8, 16, 32], [BF16]>:$b
                   );
  let results = (outs VectorOfLengthAndType<[4, 8, 16], [F32]>:$dst);
  let assemblyFormat =
    "$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
}

// The three intrinsic variants below differ only in vector widths:
// 4xf32/8xbf16, 8xf32/16xbf16 and 16xf32/32xbf16. All pass "bf16" as the
// `extension` argument of AVX512_IntrOp.
def DotBF16Ps128IntrOp : AVX512_IntrOp<"dpbf16ps.128", 1, [Pure,
    AllTypesMatch<["a", "b"]>,
    AllTypesMatch<["src", "res"]>],
    /*extension=*/"bf16"> {
  let arguments = (ins VectorOfLengthAndType<[4], [F32]>:$src,
                       VectorOfLengthAndType<[8], [BF16]>:$a,
                       VectorOfLengthAndType<[8], [BF16]>:$b);
  let results = (outs VectorOfLengthAndType<[4], [F32]>:$res);
}

def DotBF16Ps256IntrOp : AVX512_IntrOp<"dpbf16ps.256", 1, [Pure,
    AllTypesMatch<["a", "b"]>,
    AllTypesMatch<["src", "res"]>],
    /*extension=*/"bf16"> {
  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$src,
                       VectorOfLengthAndType<[16], [BF16]>:$a,
                       VectorOfLengthAndType<[16], [BF16]>:$b);
  let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
}

def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
    AllTypesMatch<["a", "b"]>,
    AllTypesMatch<["src", "res"]>],
    /*extension=*/"bf16"> {
  let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
                       VectorOfLengthAndType<[32], [BF16]>:$a,
                       VectorOfLengthAndType<[32], [BF16]>:$b);
  let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
}

//===----------------------------------------------------------------------===//
// AVX op definitions
//===----------------------------------------------------------------------===//

// Operation that is part of the input dialect.
class AVX_Op<string mnemonic, list<Trait> traits = []> :
  Op<X86Vector_Dialect, "avx." # mnemonic, traits> {}

// Operation that may be part of the input dialect, but whose
// form is somewhere between the user view of the operation
// and the actual lower level intrinsic in LLVM IR.
class AVX_LowOp<string mnemonic, list<Trait> traits = []> :
  Op<X86Vector_Dialect, "avx.intr." # mnemonic, traits> {}

// Intrinsic operation used during lowering to LLVM IR.
// The third template argument is `x86_avx_` followed by the mnemonic with
// every `.` replaced by `_`.
class AVX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
  LLVM_IntrOpBase<X86Vector_Dialect, "avx.intr." # mnemonic,
                  "x86_avx_" # !subst(".", "_", mnemonic),
                  [], [], traits, numResults>;

//----------------------------------------------------------------------------//
// AVX Rsqrt
//----------------------------------------------------------------------------//

// Single-operand op on vector<8xf32>; operand and result share one type.
def RsqrtOp : AVX_Op<"rsqrt", [Pure, SameOperandsAndResultType]> {
  let summary = "Rsqrt";
  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
  let results = (outs VectorOfLengthAndType<[8], [F32]>:$b);
  let assemblyFormat = "$a attr-dict `:` type($a)";
}

// Intrinsic-level counterpart of RsqrtOp.
def RsqrtIntrOp : AVX_IntrOp<"rsqrt.ps.256", 1, [Pure,
  SameOperandsAndResultType]> {
  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
}

//----------------------------------------------------------------------------//
// AVX Dot
//----------------------------------------------------------------------------//

// Declared through AVX_LowOp, so the op is spelled `x86vector.avx.intr.dot`
// even though it is not an LLVM_IntrOpBase record.
def DotOp : AVX_LowOp<"dot", [Pure, SameOperandsAndResultType]> {
  let summary = "Dot";
  let description = [{
    Computes the 4-way dot products of the lower and higher parts of the source
    vectors and broadcasts the two results to the lower and higher elements of
    the destination vector, respectively. Adding one element of the lower part
    to one element of the higher part in the destination vector yields the full
    dot product of the two source vectors.

    Example:

    ```mlir
    %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
    %1 = vector.extractelement %0[%i0 : i32]: vector<8xf32>
    %2 = vector.extractelement %0[%i4 : i32]: vector<8xf32>
    %d = arith.addf %1, %2 : f32
    ```
  }];
  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a,
                       VectorOfLengthAndType<[8], [F32]>:$b);
  let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
  let assemblyFormat = "$a `,` $b attr-dict `:` type($res)";
}

// Intrinsic-level dot product. It takes an extra i8 operand `c` that DotOp
// does not have (presumably the immediate control byte of the underlying
// instruction -- confirm against the lowering).
def DotIntrOp : AVX_IntrOp<"dp.ps.256", 1, [Pure,
    AllTypesMatch<["a", "b", "res"]>]> {
  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a,
                       VectorOfLengthAndType<[8], [F32]>:$b, I8:$c);
  let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
}

#endif // X86VECTOR_OPS