xref: /llvm-project/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir (revision 1b70587ca13bc1d372ddd3928818b7d774e21595)
1// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
2
3// This test covers the Integer Dot Product ops defined in the
4// SPV_KHR_integer_dot_product extension.
5
6//===----------------------------------------------------------------------===//
7// spirv.SDot
8//===----------------------------------------------------------------------===//
9
// Scalar i32 operands with the PackedVectorFormat4x8Bit attribute, i32 result.
// CHECK: @sdot_scalar_i32
func.func @sdot_scalar_i32(%lhs: i32, %rhs: i32) -> i32 {
  // CHECK-NEXT: spirv.SDot
  %dot = spirv.SDot %lhs, %rhs, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %dot : i32
}
16
// Scalar i32 operands with the PackedVectorFormat4x8Bit attribute, widened
// i64 result.
// CHECK: @sdot_scalar_i64
func.func @sdot_scalar_i64(%lhs: i32, %rhs: i32) -> i64 {
  // CHECK-NEXT: spirv.SDot
  %dot = spirv.SDot %lhs, %rhs, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %dot : i64
}
23
// vector<4xi8> operands, no format attribute, i32 result.
// CHECK: @sdot_vector_4xi8
func.func @sdot_vector_4xi8(%lhs: vector<4xi8>, %rhs: vector<4xi8>) -> i32 {
  // CHECK-NEXT: spirv.SDot
  %dot = spirv.SDot %lhs, %rhs : vector<4xi8> -> i32
  return %dot : i32
}
30
// vector<4xi16> operands, no format attribute, i64 result.
// CHECK: @sdot_vector_4xi16
func.func @sdot_vector_4xi16(%lhs: vector<4xi16>, %rhs: vector<4xi16>) -> i64 {
  // CHECK-NEXT: spirv.SDot
  %dot = spirv.SDot %lhs, %rhs : vector<4xi16> -> i64
  return %dot : i64
}
37
// vector<8xi8> operands, no format attribute, i64 result.
// CHECK: @sdot_vector_8xi8
func.func @sdot_vector_8xi8(%lhs: vector<8xi8>, %rhs: vector<8xi8>) -> i64 {
  // CHECK-NEXT: spirv.SDot
  %dot = spirv.SDot %lhs, %rhs : vector<8xi8> -> i64
  return %dot : i64
}
44
45// -----
46
// Negative test: %b is declared as i64, while the op's type signature names a
// single operand type (i32) for both vector operands, so the use of %b is
// diagnosed as a type mismatch against its declaration.
// The packed-format attribute is spelled out, as in the other scalar tests,
// so that the operand type mismatch is the only thing wrong with the op.
// expected-note @+1 {{prior use here}}
func.func @sdot_scalar_bad_types(%a: i32, %b: i64) -> i32 {
  // expected-error @+1 {{use of value '%b' expects different type than prior uses: 'i32' vs 'i64'}}
  %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %r : i32
}
53// -----
54
// Negative test: a packed-format attribute is rejected when the operands are
// already of type vector<4xi8>.
func.func @sdot_vector_4xi8_bad_attr(%lhs: vector<4xi8>, %rhs: vector<4xi8>) -> i32 {
  // expected-error @+1 {{op with invalid format attribute for vector operands of type 'vector<4xi8>'}}
  %dot = spirv.SDot %lhs, %rhs, <PackedVectorFormat4x8Bit> : vector<4xi8> -> i32
  return %dot : i32
}
60
61// -----
62
// Negative test: an i16 result is rejected for packed i32 operands because the
// result is narrower than the operand type.
func.func @sdot_scalar_bad_types(%lhs: i32, %rhs: i32) -> i16 {
  // expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}}
  %dot = spirv.SDot %lhs, %rhs, <PackedVectorFormat4x8Bit> : i32 -> i16
  return %dot : i16
}
68
69// -----
70
// Negative test: PackedVectorFormat4x8Bit is rejected with i64 operands; the
// diagnostic requires 32-bit wide integer operands for this format.
func.func @sdot_scalar_bad_types(%lhs: i64, %rhs: i64) -> i64 {
  // expected-error @+1 {{op with specified Packed Vector Format (PackedVectorFormat4x8Bit) requires integer vector operands to be 32-bits wide}}
  %dot = spirv.SDot %lhs, %rhs, <PackedVectorFormat4x8Bit> : i64 -> i64
  return %dot : i64
}
76
77// -----
78
79//===----------------------------------------------------------------------===//
80// spirv.SUDot
81//===----------------------------------------------------------------------===//
82
// Scalar i32 operands with the PackedVectorFormat4x8Bit attribute, i32 result.
// CHECK: @sudot_scalar_i32
func.func @sudot_scalar_i32(%lhs: i32, %rhs: i32) -> i32 {
  // CHECK-NEXT: spirv.SUDot
  %dot = spirv.SUDot %lhs, %rhs, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %dot : i32
}
89
// Scalar i32 operands with the PackedVectorFormat4x8Bit attribute, widened
// i64 result.
// CHECK: @sudot_scalar_i64
func.func @sudot_scalar_i64(%lhs: i32, %rhs: i32) -> i64 {
  // CHECK-NEXT: spirv.SUDot
  %dot = spirv.SUDot %lhs, %rhs, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %dot : i64
}
96
// vector<4xi8> operands, no format attribute, i32 result.
// CHECK: @sudot_vector_4xi8
func.func @sudot_vector_4xi8(%lhs: vector<4xi8>, %rhs: vector<4xi8>) -> i32 {
  // CHECK-NEXT: spirv.SUDot
  %dot = spirv.SUDot %lhs, %rhs : vector<4xi8> -> i32
  return %dot : i32
}
103
// vector<4xi16> operands, no format attribute, i64 result.
// CHECK: @sudot_vector_4xi16
func.func @sudot_vector_4xi16(%lhs: vector<4xi16>, %rhs: vector<4xi16>) -> i64 {
  // CHECK-NEXT: spirv.SUDot
  %dot = spirv.SUDot %lhs, %rhs : vector<4xi16> -> i64
  return %dot : i64
}
110
// vector<8xi8> operands, no format attribute, i64 result.
// CHECK: @sudot_vector_8xi8
func.func @sudot_vector_8xi8(%lhs: vector<8xi8>, %rhs: vector<8xi8>) -> i64 {
  // CHECK-NEXT: spirv.SUDot
  %dot = spirv.SUDot %lhs, %rhs : vector<8xi8> -> i64
  return %dot : i64
}
117
118// -----
119
120//===----------------------------------------------------------------------===//
121// spirv.UDot
122//===----------------------------------------------------------------------===//
123
// Scalar i32 operands with the PackedVectorFormat4x8Bit attribute, i32 result.
// CHECK: @udot_scalar_i32
func.func @udot_scalar_i32(%lhs: i32, %rhs: i32) -> i32 {
  // CHECK-NEXT: spirv.UDot
  %dot = spirv.UDot %lhs, %rhs, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %dot : i32
}
130
// Scalar i32 operands with the PackedVectorFormat4x8Bit attribute, widened
// i64 result.
// CHECK: @udot_scalar_i64
func.func @udot_scalar_i64(%lhs: i32, %rhs: i32) -> i64 {
  // CHECK-NEXT: spirv.UDot
  %dot = spirv.UDot %lhs, %rhs, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %dot : i64
}
137
// vector<4xi8> operands, no format attribute, i32 result.
// CHECK: @udot_vector_4xi8
func.func @udot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
  // CHECK-NEXT: spirv.UDot
  %r = spirv.UDot %a, %b : vector<4xi8> -> i32
  return %r : i32
}

// vector<4xi16> operands with an i64 result; mirrors the corresponding
// spirv.SDot and spirv.SUDot positive tests in this file.
// CHECK: @udot_vector_4xi16
func.func @udot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 {
  // CHECK-NEXT: spirv.UDot
  %r = spirv.UDot %a, %b : vector<4xi16> -> i64
  return %r : i64
}

// vector<8xi8> operands with an i64 result; mirrors the corresponding
// spirv.SDot and spirv.SUDot positive tests in this file.
// CHECK: @udot_vector_8xi8
func.func @udot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 {
  // CHECK-NEXT: spirv.UDot
  %r = spirv.UDot %a, %b : vector<8xi8> -> i64
  return %r : i64
}
144
145// -----
146
147//===----------------------------------------------------------------------===//
148// spirv.SDotAccSat
149//===----------------------------------------------------------------------===//
150
// Packed scalar i32 operands with an i32 accumulator and i32 result.
// CHECK: @sdot_acc_sat_scalar_i32
func.func @sdot_acc_sat_scalar_i32(%lhs: i32, %rhs: i32, %accum: i32) -> i32 {
  // CHECK-NEXT: spirv.SDotAccSat
  %sum = spirv.SDotAccSat %lhs, %rhs, %accum, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %sum : i32
}
157
// Packed scalar i32 operands with an i64 accumulator and i64 result.
// CHECK: @sdot_acc_sat_scalar_i64
func.func @sdot_acc_sat_scalar_i64(%lhs: i32, %rhs: i32, %accum: i64) -> i64 {
  // CHECK-NEXT: spirv.SDotAccSat
  %sum = spirv.SDotAccSat %lhs, %rhs, %accum, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %sum : i64
}
164
// vector<4xi8> operands with an i32 accumulator and i32 result.
// CHECK: @sdot_acc_sat_vector_4xi8
func.func @sdot_acc_sat_vector_4xi8(%lhs: vector<4xi8>, %rhs: vector<4xi8>, %accum: i32) -> i32 {
  // CHECK-NEXT: spirv.SDotAccSat
  %sum = spirv.SDotAccSat %lhs, %rhs, %accum : vector<4xi8> -> i32
  return %sum : i32
}
171
// vector<4xi16> operands with an i64 accumulator and i64 result.
// CHECK: @sdot_acc_sat_vector_4xi16
func.func @sdot_acc_sat_vector_4xi16(%lhs: vector<4xi16>, %rhs: vector<4xi16>, %accum: i64) -> i64 {
  // CHECK-NEXT: spirv.SDotAccSat
  %sum = spirv.SDotAccSat %lhs, %rhs, %accum : vector<4xi16> -> i64
  return %sum : i64
}
178
// vector<8xi8> operands with an i64 accumulator and i64 result.
// CHECK: @sdot_acc_sat_vector_8xi8
func.func @sdot_acc_sat_vector_8xi8(%lhs: vector<8xi8>, %rhs: vector<8xi8>, %accum: i64) -> i64 {
  // CHECK-NEXT: spirv.SDotAccSat
  %sum = spirv.SDotAccSat %lhs, %rhs, %accum : vector<8xi8> -> i64
  return %sum : i64
}
185
186// -----
187
// Negative test: %b is declared as i64 while the op's type signature names
// i32 as the operand type, so the use of %b is diagnosed as a type mismatch.
// The name %b is kept because the diagnostic text quotes it.
// expected-note @+1 {{prior use here}}
func.func @sdot_acc_sat_scalar_bad_types(%lhs: i32, %b: i64, %accum: i32) -> i32 {
  // expected-error @+1 {{use of value '%b' expects different type than prior uses: 'i32' vs 'i64'}}
  %sum = spirv.SDotAccSat %lhs, %b, %accum, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %sum : i32
}
194
195// -----
196
// Negative test: an i16 accumulator/result is rejected for packed i32
// operands because the result is narrower than the operand type.
func.func @sdot_acc_sat_scalar_bad_types(%lhs: i32, %rhs: i32, %accum: i16) -> i16 {
  // expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}}
  %sum = spirv.SDotAccSat %lhs, %rhs, %accum, <PackedVectorFormat4x8Bit> : i32 -> i16
  return %sum : i16
}
202
203// -----
204
// Negative test: PackedVectorFormat4x8Bit is rejected with i64 operands; the
// diagnostic requires 32-bit wide integer operands for this format.
func.func @sdot_acc_sat_scalar_bad_types(%lhs: i64, %rhs: i64, %accum: i64) -> i64 {
  // expected-error @+1 {{op with specified Packed Vector Format (PackedVectorFormat4x8Bit) requires integer vector operands to be 32-bits wide}}
  %sum = spirv.SDotAccSat %lhs, %rhs, %accum, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %sum : i64
}
210
211// -----
212
// Negative test: %acc is declared as i32 while the op's result type is i64;
// per the diagnostic, the accumulator is expected to have type i64 here.
// The name %acc is kept because the diagnostic text quotes it.
// expected-note @+1 {{prior use here}}
func.func @sdot_acc_sat_scalar_bad_accumulator(%lhs: i32, %rhs: i32, %acc: i32) -> i64 {
  // expected-error @+1 {{use of value '%acc' expects different type than prior uses: 'i64' vs 'i32'}}
  %sum = spirv.SDotAccSat %lhs, %rhs, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %sum : i64
}
219
220// -----
221
222//===----------------------------------------------------------------------===//
223// spirv.SUDotAccSat
224//===----------------------------------------------------------------------===//
225
// Packed scalar i32 operands with an i32 accumulator and i32 result.
// CHECK: @sudot_acc_sat_scalar_i32
func.func @sudot_acc_sat_scalar_i32(%lhs: i32, %rhs: i32, %accum: i32) -> i32 {
  // CHECK-NEXT: spirv.SUDotAccSat
  %sum = spirv.SUDotAccSat %lhs, %rhs, %accum, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %sum : i32
}
232
// Packed scalar i32 operands with an i64 accumulator and i64 result.
// CHECK: @sudot_acc_sat_scalar_i64
func.func @sudot_acc_sat_scalar_i64(%lhs: i32, %rhs: i32, %accum: i64) -> i64 {
  // CHECK-NEXT: spirv.SUDotAccSat
  %sum = spirv.SUDotAccSat %lhs, %rhs, %accum, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %sum : i64
}
239
// vector<4xi8> operands with an i32 accumulator and i32 result.
// CHECK: @sudot_acc_sat_vector_4xi8
func.func @sudot_acc_sat_vector_4xi8(%lhs: vector<4xi8>, %rhs: vector<4xi8>, %accum: i32) -> i32 {
  // CHECK-NEXT: spirv.SUDotAccSat
  %sum = spirv.SUDotAccSat %lhs, %rhs, %accum : vector<4xi8> -> i32
  return %sum : i32
}
246
// vector<4xi16> operands with an i64 accumulator and i64 result.
// CHECK: @sudot_acc_sat_vector_4xi16
func.func @sudot_acc_sat_vector_4xi16(%lhs: vector<4xi16>, %rhs: vector<4xi16>, %accum: i64) -> i64 {
  // CHECK-NEXT: spirv.SUDotAccSat
  %sum = spirv.SUDotAccSat %lhs, %rhs, %accum : vector<4xi16> -> i64
  return %sum : i64
}
253
// vector<8xi8> operands with an i64 accumulator and i64 result.
// CHECK: @sudot_acc_sat_vector_8xi8
func.func @sudot_acc_sat_vector_8xi8(%lhs: vector<8xi8>, %rhs: vector<8xi8>, %accum: i64) -> i64 {
  // CHECK-NEXT: spirv.SUDotAccSat
  %sum = spirv.SUDotAccSat %lhs, %rhs, %accum : vector<8xi8> -> i64
  return %sum : i64
}
260
261// -----
262
263//===----------------------------------------------------------------------===//
264// spirv.UDotAccSat
265//===----------------------------------------------------------------------===//
266
// Packed scalar i32 operands with an i32 accumulator and i32 result.
// CHECK: @udot_acc_sat_scalar_i32
func.func @udot_acc_sat_scalar_i32(%lhs: i32, %rhs: i32, %accum: i32) -> i32 {
  // CHECK-NEXT: spirv.UDotAccSat
  %sum = spirv.UDotAccSat %lhs, %rhs, %accum, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %sum : i32
}
273
// Packed scalar i32 operands with an i64 accumulator and i64 result.
// CHECK: @udot_acc_sat_scalar_i64
func.func @udot_acc_sat_scalar_i64(%lhs: i32, %rhs: i32, %accum: i64) -> i64 {
  // CHECK-NEXT: spirv.UDotAccSat
  %sum = spirv.UDotAccSat %lhs, %rhs, %accum, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %sum : i64
}
280
// vector<4xi8> operands with an i32 accumulator and i32 result.
// CHECK: @udot_acc_sat_vector_4xi8
func.func @udot_acc_sat_vector_4xi8(%lhs: vector<4xi8>, %rhs: vector<4xi8>, %accum: i32) -> i32 {
  // CHECK-NEXT: spirv.UDotAccSat
  %sum = spirv.UDotAccSat %lhs, %rhs, %accum : vector<4xi8> -> i32
  return %sum : i32
}
287
// vector<4xi16> operands with an i64 accumulator and i64 result.
// CHECK: @udot_acc_sat_vector_4xi16
func.func @udot_acc_sat_vector_4xi16(%lhs: vector<4xi16>, %rhs: vector<4xi16>, %accum: i64) -> i64 {
  // CHECK-NEXT: spirv.UDotAccSat
  %sum = spirv.UDotAccSat %lhs, %rhs, %accum : vector<4xi16> -> i64
  return %sum : i64
}
294
// vector<8xi8> operands with an i64 accumulator and i64 result.
// CHECK: @udot_acc_sat_vector_8xi8
func.func @udot_acc_sat_vector_8xi8(%lhs: vector<8xi8>, %rhs: vector<8xi8>, %accum: i64) -> i64 {
  // CHECK-NEXT: spirv.UDotAccSat
  %sum = spirv.UDotAccSat %lhs, %rhs, %accum : vector<8xi8> -> i64
  return %sum : i64
}
301