// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s

// This test covers the Integer Dot Product ops defined in the
// SPV_KHR_integer_dot_product extension.

//===----------------------------------------------------------------------===//
// spirv.SDot
//===----------------------------------------------------------------------===//

// The round-trip cases in this section match the whole printed op (operands,
// the optional packed vector format, and the type signature) rather than only
// the op name, so a printer/parser regression such as dropping or spuriously
// adding the format attribute is caught. CHECK-LABEL anchors each function
// (the trailing '(' keeps e.g. @sdot_vector_4xi8 from matching a longer name).

// CHECK-LABEL: @sdot_scalar_i32(
func.func @sdot_scalar_i32(%a: i32, %b: i32) -> i32 {
  // CHECK-NEXT: spirv.SDot %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, <PackedVectorFormat4x8Bit> : i32 -> i32
  %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %r : i32
}

// CHECK-LABEL: @sdot_scalar_i64(
func.func @sdot_scalar_i64(%a: i32, %b: i32) -> i64 {
  // CHECK-NEXT: spirv.SDot %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, <PackedVectorFormat4x8Bit> : i32 -> i64
  %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %r : i64
}

// CHECK-LABEL: @sdot_vector_4xi8(
func.func @sdot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
  // CHECK-NEXT: spirv.SDot %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}} : vector<4xi8> -> i32
  %r = spirv.SDot %a, %b : vector<4xi8> -> i32
  return %r : i32
}

// CHECK-LABEL: @sdot_vector_4xi16(
func.func @sdot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 {
  // CHECK-NEXT: spirv.SDot %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}} : vector<4xi16> -> i64
  %r = spirv.SDot %a, %b : vector<4xi16> -> i64
  return %r : i64
}

// CHECK-LABEL: @sdot_vector_8xi8(
func.func @sdot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 {
  // CHECK-NEXT: spirv.SDot %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}} : vector<8xi8> -> i64
  %r = spirv.SDot %a, %b : vector<8xi8> -> i64
  return %r : i64
}

// -----

// The assembly format spells a single operand type after ':', so an i64 %b
// clashes with the i32 that the signature assigns to both operands.
// expected-note @+1 {{prior use here}}
func.func @sdot_scalar_bad_types(%a: i32, %b: i64) -> i32 {
  // expected-error @+1 {{use of value '%b' expects different type than prior uses: 'i32' vs 'i64'}}
  %r = spirv.SDot %a, %b : i32 -> i32
  return %r : i32
}

// -----

// A packed vector format may not be given when the operands are vectors.
func.func @sdot_vector_4xi8_bad_attr(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
  // expected-error @+1 {{op with invalid format attribute for vector operands of type 'vector<4xi8>'}}
  %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : vector<4xi8> -> i32
  return %r : i32
}

// -----

// A result narrower than the 32-bit scalar operand type is rejected.
func.func @sdot_scalar_bad_result_width(%a: i32, %b: i32) -> i16 {
  // expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}}
  %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i16
  return %r : i16
}

// -----

// With PackedVectorFormat4x8Bit the scalar operands must be exactly 32 bits
// wide, so i64 operands are rejected.
func.func @sdot_scalar_bad_operand_width(%a: i64, %b: i64) -> i64 {
  // expected-error @+1 {{op with specified Packed Vector Format (PackedVectorFormat4x8Bit) requires integer vector operands to be 32-bits wide}}
  %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i64 -> i64
  return %r : i64
}

// -----

//===----------------------------------------------------------------------===//
// spirv.SUDot
//===----------------------------------------------------------------------===//

// Round-trip cases: match the whole printed op (operands, optional packed
// vector format, type signature), not just the op name.

// CHECK-LABEL: @sudot_scalar_i32(
func.func @sudot_scalar_i32(%a: i32, %b: i32) -> i32 {
  // CHECK-NEXT: spirv.SUDot %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, <PackedVectorFormat4x8Bit> : i32 -> i32
  %r = spirv.SUDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %r : i32
}

// CHECK-LABEL: @sudot_scalar_i64(
func.func @sudot_scalar_i64(%a: i32, %b: i32) -> i64 {
  // CHECK-NEXT: spirv.SUDot %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, <PackedVectorFormat4x8Bit> : i32 -> i64
  %r = spirv.SUDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %r : i64
}

// CHECK-LABEL: @sudot_vector_4xi8(
func.func @sudot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
  // CHECK-NEXT: spirv.SUDot %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}} : vector<4xi8> -> i32
  %r = spirv.SUDot %a, %b : vector<4xi8> -> i32
  return %r : i32
}

// CHECK-LABEL: @sudot_vector_4xi16(
func.func @sudot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 {
  // CHECK-NEXT: spirv.SUDot %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}} : vector<4xi16> -> i64
  %r = spirv.SUDot %a, %b : vector<4xi16> -> i64
  return %r : i64
}

// CHECK-LABEL: @sudot_vector_8xi8(
func.func @sudot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 {
  // CHECK-NEXT: spirv.SUDot %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}} : vector<8xi8> -> i64
  %r = spirv.SUDot %a, %b : vector<8xi8> -> i64
  return %r : i64
}

// -----

//===----------------------------------------------------------------------===//
// spirv.UDot
//===----------------------------------------------------------------------===//

// Round-trip cases: match the whole printed op (operands, optional packed
// vector format, type signature), not just the op name. The vector<4xi16> and
// vector<8xi8> cases mirror the coverage of every other op in this file.

// CHECK-LABEL: @udot_scalar_i32(
func.func @udot_scalar_i32(%a: i32, %b: i32) -> i32 {
  // CHECK-NEXT: spirv.UDot %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, <PackedVectorFormat4x8Bit> : i32 -> i32
  %r = spirv.UDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %r : i32
}

// CHECK-LABEL: @udot_scalar_i64(
func.func @udot_scalar_i64(%a: i32, %b: i32) -> i64 {
  // CHECK-NEXT: spirv.UDot %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, <PackedVectorFormat4x8Bit> : i32 -> i64
  %r = spirv.UDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %r : i64
}

// CHECK-LABEL: @udot_vector_4xi8(
func.func @udot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
  // CHECK-NEXT: spirv.UDot %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}} : vector<4xi8> -> i32
  %r = spirv.UDot %a, %b : vector<4xi8> -> i32
  return %r : i32
}

// CHECK-LABEL: @udot_vector_4xi16(
func.func @udot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 {
  // CHECK-NEXT: spirv.UDot %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}} : vector<4xi16> -> i64
  %r = spirv.UDot %a, %b : vector<4xi16> -> i64
  return %r : i64
}

// CHECK-LABEL: @udot_vector_8xi8(
func.func @udot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 {
  // CHECK-NEXT: spirv.UDot %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}} : vector<8xi8> -> i64
  %r = spirv.UDot %a, %b : vector<8xi8> -> i64
  return %r : i64
}

// -----

//===----------------------------------------------------------------------===//
// spirv.SDotAccSat
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @sdot_acc_sat_scalar_i32(
func.func @sdot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc : i32) -> i32 {
  // CHECK-NEXT: spirv.SDotAccSat %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, <PackedVectorFormat4x8Bit> : i32 -> i32
  %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %r : i32
}

// CHECK-LABEL: @sdot_acc_sat_scalar_i64(
func.func @sdot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc : i64) -> i64 {
  // CHECK-NEXT: spirv.SDotAccSat %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, <PackedVectorFormat4x8Bit> : i32 -> i64
  %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %r : i64
}

// CHECK-LABEL: @sdot_acc_sat_vector_4xi8(
func.func @sdot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc : i32) -> i32 {
  // CHECK-NEXT: spirv.SDotAccSat %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}} : vector<4xi8> -> i32
  %r = spirv.SDotAccSat %a, %b, %acc : vector<4xi8> -> i32
  return %r : i32
}

// CHECK-LABEL: @sdot_acc_sat_vector_4xi16(
func.func @sdot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc : i64) -> i64 {
  // CHECK-NEXT: spirv.SDotAccSat %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}} : vector<4xi16> -> i64
  %r = spirv.SDotAccSat %a, %b, %acc : vector<4xi16> -> i64
  return %r : i64
}

// CHECK-LABEL: @sdot_acc_sat_vector_8xi8(
func.func @sdot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc : i64) -> i64 {
  // CHECK-NEXT: spirv.SDotAccSat %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}} : vector<8xi8> -> i64
  %r = spirv.SDotAccSat %a, %b, %acc : vector<8xi8> -> i64
  return %r : i64
}

// -----

// The signature spells a single operand type, so an i64 %b clashes with the
// i32 it assigns to both vector operands.
// expected-note @+1 {{prior use here}}
func.func @sdot_acc_sat_scalar_bad_types(%a: i32, %b: i64, %acc : i32) -> i32 {
  // expected-error @+1 {{use of value '%b' expects different type than prior uses: 'i32' vs 'i64'}}
  %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %r : i32
}

// -----

// A result narrower than the 32-bit scalar operand type is rejected.
func.func @sdot_acc_sat_scalar_bad_result_width(%a: i32, %b: i32, %acc : i16) -> i16 {
  // expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}}
  %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i16
  return %r : i16
}

// -----

// With PackedVectorFormat4x8Bit the scalar operands must be exactly 32 bits
// wide, so i64 operands are rejected.
func.func @sdot_acc_sat_scalar_bad_operand_width(%a: i64, %b: i64, %acc : i64) -> i64 {
  // expected-error @+1 {{op with specified Packed Vector Format (PackedVectorFormat4x8Bit) requires integer vector operands to be 32-bits wide}}
  %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i64 -> i64
  return %r : i64
}

// -----

// The accumulator is typed by the result type after '->' (i64 here), so an
// i32 %acc clashes with it.
// expected-note @+1 {{prior use here}}
func.func @sdot_acc_sat_scalar_bad_accumulator(%a: i32, %b: i32, %acc : i32) -> i64 {
  // expected-error @+1 {{use of value '%acc' expects different type than prior uses: 'i64' vs 'i32'}}
  %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %r : i64
}

// -----

//===----------------------------------------------------------------------===//
// spirv.SUDotAccSat
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @sudot_acc_sat_scalar_i32(
func.func @sudot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc : i32) -> i32 {
  // CHECK-NEXT: spirv.SUDotAccSat %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, <PackedVectorFormat4x8Bit> : i32 -> i32
  %r = spirv.SUDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %r : i32
}

// CHECK-LABEL: @sudot_acc_sat_scalar_i64(
func.func @sudot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc : i64) -> i64 {
  // CHECK-NEXT: spirv.SUDotAccSat %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, <PackedVectorFormat4x8Bit> : i32 -> i64
  %r = spirv.SUDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %r : i64
}

// CHECK-LABEL: @sudot_acc_sat_vector_4xi8(
func.func @sudot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc : i32) -> i32 {
  // CHECK-NEXT: spirv.SUDotAccSat %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}} : vector<4xi8> -> i32
  %r = spirv.SUDotAccSat %a, %b, %acc : vector<4xi8> -> i32
  return %r : i32
}

// CHECK-LABEL: @sudot_acc_sat_vector_4xi16(
func.func @sudot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc : i64) -> i64 {
  // CHECK-NEXT: spirv.SUDotAccSat %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}} : vector<4xi16> -> i64
  %r = spirv.SUDotAccSat %a, %b, %acc : vector<4xi16> -> i64
  return %r : i64
}

// CHECK-LABEL: @sudot_acc_sat_vector_8xi8(
func.func @sudot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc : i64) -> i64 {
  // CHECK-NEXT: spirv.SUDotAccSat %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}} : vector<8xi8> -> i64
  %r = spirv.SUDotAccSat %a, %b, %acc : vector<8xi8> -> i64
  return %r : i64
}

// -----

//===----------------------------------------------------------------------===//
// spirv.UDotAccSat
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @udot_acc_sat_scalar_i32(
func.func @udot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc : i32) -> i32 {
  // CHECK-NEXT: spirv.UDotAccSat %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, <PackedVectorFormat4x8Bit> : i32 -> i32
  %r = spirv.UDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %r : i32
}

// CHECK-LABEL: @udot_acc_sat_scalar_i64(
func.func @udot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc : i64) -> i64 {
  // CHECK-NEXT: spirv.UDotAccSat %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, <PackedVectorFormat4x8Bit> : i32 -> i64
  %r = spirv.UDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %r : i64
}

// CHECK-LABEL: @udot_acc_sat_vector_4xi8(
func.func @udot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc : i32) -> i32 {
  // CHECK-NEXT: spirv.UDotAccSat %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}} : vector<4xi8> -> i32
  %r = spirv.UDotAccSat %a, %b, %acc : vector<4xi8> -> i32
  return %r : i32
}

// CHECK-LABEL: @udot_acc_sat_vector_4xi16(
func.func @udot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc : i64) -> i64 {
  // CHECK-NEXT: spirv.UDotAccSat %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}} : vector<4xi16> -> i64
  %r = spirv.UDotAccSat %a, %b, %acc : vector<4xi16> -> i64
  return %r : i64
}

// CHECK-LABEL: @udot_acc_sat_vector_8xi8(
func.func @udot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc : i64) -> i64 {
  // CHECK-NEXT: spirv.UDotAccSat %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}}, %{{[a-z0-9_]+}} : vector<8xi8> -> i64
  %r = spirv.UDotAccSat %a, %b, %acc : vector<8xi8> -> i64
  return %r : i64
}