// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s

//===----------------------------------------------------------------------===//
// spirv.FAdd
//===----------------------------------------------------------------------===//

// Scalar f32 FAdd must parse, verify and print back in its custom form.
// CHECK-LABEL: @fadd_scalar
func.func @fadd_scalar(%arg: f32) -> f32 {
  // CHECK: spirv.FAdd %{{.+}}, %{{.+}} : f32
  %0 = spirv.FAdd %arg, %arg : f32
  return %0 : f32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.FDiv
//===----------------------------------------------------------------------===//

// Scalar f32 FDiv must parse, verify and print back in its custom form.
// CHECK-LABEL: @fdiv_scalar
func.func @fdiv_scalar(%arg: f32) -> f32 {
  // CHECK: spirv.FDiv %{{.+}}, %{{.+}} : f32
  %0 = spirv.FDiv %arg, %arg : f32
  return %0 : f32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.FMod
//===----------------------------------------------------------------------===//

// Scalar f32 FMod must parse, verify and print back in its custom form.
// CHECK-LABEL: @fmod_scalar
func.func @fmod_scalar(%arg: f32) -> f32 {
  // CHECK: spirv.FMod %{{.+}}, %{{.+}} : f32
  %0 = spirv.FMod %arg, %arg : f32
  return %0 : f32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.FMul
//===----------------------------------------------------------------------===//

// Scalar f32 FMul must parse, verify and print back in its custom form.
// CHECK-LABEL: @fmul_scalar
func.func @fmul_scalar(%arg: f32) -> f32 {
  // CHECK: spirv.FMul %{{.+}}, %{{.+}} : f32
  %0 = spirv.FMul %arg, %arg : f32
  return %0 : f32
}

// Vector FMul keeps the vector<4xf32> type in the printed form.
// CHECK-LABEL: @fmul_vector
func.func @fmul_vector(%arg: vector<4xf32>) -> vector<4xf32> {
  // CHECK: spirv.FMul %{{.+}}, %{{.+}} : vector<4xf32>
  %0 = spirv.FMul %arg, %arg : vector<4xf32>
  return %0 : vector<4xf32>
}

// -----

// FMul rejects integer operands.
func.func @fmul_i32(%x: i32) -> i32 {
  %prod = spirv.FMul %x, %x : i32 // expected-error {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
  return %prod : i32
}

// -----

// bf16 is not among the float types FMul accepts.
func.func @fmul_bf16(%x: bf16) -> bf16 {
  %prod = spirv.FMul %x, %x : bf16 // expected-error {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
  return %prod : bf16
}

// -----

// FMul only takes scalars or vectors, so tensor operands are rejected.
func.func @fmul_tensor(%t: tensor<4xf32>) -> tensor<4xf32> {
  %prod = spirv.FMul %t, %t : tensor<4xf32> // expected-error {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
  return %prod : tensor<4xf32>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.FNegate
//===----------------------------------------------------------------------===//

// Unary scalar f32 FNegate must parse, verify and print back in custom form.
// CHECK-LABEL: @fnegate_scalar
func.func @fnegate_scalar(%arg: f32) -> f32 {
  // CHECK: spirv.FNegate %{{.+}} : f32
  %0 = spirv.FNegate %arg : f32
  return %0 : f32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.FRem
//===----------------------------------------------------------------------===//

// Scalar f32 FRem must parse, verify and print back in its custom form.
// CHECK-LABEL: @frem_scalar
func.func @frem_scalar(%arg: f32) -> f32 {
  // CHECK: spirv.FRem %{{.+}}, %{{.+}} : f32
  %0 = spirv.FRem %arg, %arg : f32
  return %0 : f32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.FSub
//===----------------------------------------------------------------------===//

// Scalar f32 FSub must parse, verify and print back in its custom form.
// CHECK-LABEL: @fsub_scalar
func.func @fsub_scalar(%arg: f32) -> f32 {
  // CHECK: spirv.FSub %{{.+}}, %{{.+}} : f32
  %0 = spirv.FSub %arg, %arg : f32
  return %0 : f32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.IAdd
//===----------------------------------------------------------------------===//

// Scalar i32 IAdd must parse, verify and print back in its custom form.
// CHECK-LABEL: @iadd_scalar
func.func @iadd_scalar(%arg: i32) -> i32 {
  // CHECK: spirv.IAdd %{{.+}}, %{{.+}} : i32
  %0 = spirv.IAdd %arg, %arg : i32
  return %0 : i32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.IMul
//===----------------------------------------------------------------------===//

// Scalar i32 IMul must parse, verify and print back in its custom form.
// CHECK-LABEL: @imul_scalar
func.func @imul_scalar(%arg: i32) -> i32 {
  // CHECK: spirv.IMul %{{.+}}, %{{.+}} : i32
  %0 = spirv.IMul %arg, %arg : i32
  return %0 : i32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.ISub
//===----------------------------------------------------------------------===//

// Scalar i32 ISub must parse, verify and print back in its custom form.
// CHECK-LABEL: @isub_scalar
func.func @isub_scalar(%arg: i32) -> i32 {
  // CHECK: spirv.ISub %{{.+}}, %{{.+}} : i32
  %0 = spirv.ISub %arg, %arg : i32
  return %0 : i32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.IAddCarry
//===----------------------------------------------------------------------===//

// Scalar operands produce a two-member struct whose members share the i32 type.
// CHECK-LABEL: @iadd_carry_scalar
func.func @iadd_carry_scalar(%x: i32) -> !spirv.struct<(i32, i32)> {
  %res = spirv.IAddCarry %x, %x : !spirv.struct<(i32, i32)>
  // CHECK: spirv.IAddCarry %{{.+}}, %{{.+}} : !spirv.struct<(i32, i32)>
  return %res : !spirv.struct<(i32, i32)>
}

// Vector operands produce a struct of two vectors matching the operand type.
// CHECK-LABEL: @iadd_carry_vector
func.func @iadd_carry_vector(%v: vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
  %res = spirv.IAddCarry %v, %v : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
  // CHECK: spirv.IAddCarry %{{.+}}, %{{.+}} : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
  return %res : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
}

// -----

// The custom form requires a two-member result struct; three members fail.
func.func @iadd_carry(%x: i32) -> !spirv.struct<(i32, i32, i32)> {
  %res = spirv.IAddCarry %x, %x : !spirv.struct<(i32, i32, i32)> // expected-error {{expected spirv.struct type with two members}}
  return %res : !spirv.struct<(i32, i32, i32)>
}

// -----

// Generic form with a one-member result struct must be diagnosed.
func.func @iadd_carry(%x: i32) -> !spirv.struct<(i32)> {
  %res = "spirv.IAddCarry"(%x, %x) : (i32, i32) -> !spirv.struct<(i32)> // expected-error {{expected result struct type containing two members}}
  return %res : !spirv.struct<(i32)>
}

// -----

// Every struct member must match the operand type; the i64 member does not.
func.func @iadd_carry(%x: i32) -> !spirv.struct<(i32, i64)> {
  %res = "spirv.IAddCarry"(%x, %x) : (i32, i32) -> !spirv.struct<(i32, i64)> // expected-error {{expected all operand types and struct member types are the same}}
  return %res : !spirv.struct<(i32, i64)>
}

// -----

// i64 operands do not match the i32 struct members.
func.func @iadd_carry(%x: i64) -> !spirv.struct<(i32, i32)> {
  %res = "spirv.IAddCarry"(%x, %x) : (i64, i64) -> !spirv.struct<(i32, i32)> // expected-error {{expected all operand types and struct member types are the same}}
  return %res : !spirv.struct<(i32, i32)>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.ISubBorrow
//===----------------------------------------------------------------------===//

// Scalar operands produce a two-member struct whose members share the i32 type.
// CHECK-LABEL: @isub_borrow_scalar
func.func @isub_borrow_scalar(%x: i32) -> !spirv.struct<(i32, i32)> {
  %res = spirv.ISubBorrow %x, %x : !spirv.struct<(i32, i32)>
  // CHECK: spirv.ISubBorrow %{{.+}}, %{{.+}} : !spirv.struct<(i32, i32)>
  return %res : !spirv.struct<(i32, i32)>
}

// Vector operands produce a struct of two vectors matching the operand type.
// CHECK-LABEL: @isub_borrow_vector
func.func @isub_borrow_vector(%v: vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
  %res = spirv.ISubBorrow %v, %v : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
  // CHECK: spirv.ISubBorrow %{{.+}}, %{{.+}} : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
  return %res : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
}

// -----

// The custom form requires a two-member result struct; three members fail.
func.func @isub_borrow(%x: i32) -> !spirv.struct<(i32, i32, i32)> {
  %res = spirv.ISubBorrow %x, %x : !spirv.struct<(i32, i32, i32)> // expected-error {{expected spirv.struct type with two members}}
  return %res : !spirv.struct<(i32, i32, i32)>
}

// -----

// Generic form with a one-member result struct must be diagnosed.
func.func @isub_borrow(%x: i32) -> !spirv.struct<(i32)> {
  %res = "spirv.ISubBorrow"(%x, %x) : (i32, i32) -> !spirv.struct<(i32)> // expected-error {{expected result struct type containing two members}}
  return %res : !spirv.struct<(i32)>
}

// -----

// Every struct member must match the operand type; the i64 member does not.
func.func @isub_borrow(%x: i32) -> !spirv.struct<(i32, i64)> {
  %res = "spirv.ISubBorrow"(%x, %x) : (i32, i32) -> !spirv.struct<(i32, i64)> // expected-error {{expected all operand types and struct member types are the same}}
  return %res : !spirv.struct<(i32, i64)>
}

// -----

// i64 operands do not match the i32 struct members.
func.func @isub_borrow(%x: i64) -> !spirv.struct<(i32, i32)> {
  %res = "spirv.ISubBorrow"(%x, %x) : (i64, i64) -> !spirv.struct<(i32, i32)> // expected-error {{expected all operand types and struct member types are the same}}
  return %res : !spirv.struct<(i32, i32)>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.Dot
//===----------------------------------------------------------------------===//

// Positive Dot case: the printed form must keep both the operand vector type
// and the scalar result type (previously this test had no CHECK at all).
// CHECK-LABEL: @dot
func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
  // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xf32> -> f32
  %0 = spirv.Dot %arg0, %arg1 : vector<4xf32> -> f32
  return %0 : f32
}

// -----

// %arg1 is declared as vector<3xf32> but used where the annotation implies
// vector<4xf32>; the parser points back at the declaration.
// expected-note @+1 {{prior use here}}
func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
  %d = spirv.Dot %arg0, %arg1 : vector<4xf32> -> f32 // expected-error {{use of value '%arg1' expects different type than prior uses}}
  return %d : f32
}

// -----

// The Dot result must have the same element type as the input vectors.
func.func @dot(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> f16 {
  %d = spirv.Dot %lhs, %rhs : vector<4xf32> -> f16 // expected-error {{'spirv.Dot' op failed to verify that all of {vector1, result} have same element type}}
  return %d : f16
}

// -----

// Dot only accepts float vectors; integer vectors are rejected.
func.func @dot(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> i32 {
  %d = spirv.Dot %lhs, %rhs : vector<4xi32> -> i32 // expected-error {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
  return %d : i32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.SMulExtended
//===----------------------------------------------------------------------===//

// Scalar operands produce a two-member struct whose members share the i32 type.
// CHECK-LABEL: @smul_extended_scalar
func.func @smul_extended_scalar(%x: i32) -> !spirv.struct<(i32, i32)> {
  %res = spirv.SMulExtended %x, %x : !spirv.struct<(i32, i32)>
  // CHECK: spirv.SMulExtended %{{.+}}, %{{.+}} : !spirv.struct<(i32, i32)>
  return %res : !spirv.struct<(i32, i32)>
}

// Vector operands produce a struct of two vectors matching the operand type.
// CHECK-LABEL: @smul_extended_vector
func.func @smul_extended_vector(%v: vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
  %res = spirv.SMulExtended %v, %v : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
  // CHECK: spirv.SMulExtended %{{.+}}, %{{.+}} : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
  return %res : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
}

// -----

// The custom form requires a two-member result struct; three members fail.
func.func @smul_extended(%x: i32) -> !spirv.struct<(i32, i32, i32)> {
  %res = spirv.SMulExtended %x, %x : !spirv.struct<(i32, i32, i32)> // expected-error {{expected spirv.struct type with two members}}
  return %res : !spirv.struct<(i32, i32, i32)>
}

// -----

// Generic form with a one-member result struct must be diagnosed.
func.func @smul_extended(%x: i32) -> !spirv.struct<(i32)> {
  %res = "spirv.SMulExtended"(%x, %x) : (i32, i32) -> !spirv.struct<(i32)> // expected-error {{expected result struct type containing two members}}
  return %res : !spirv.struct<(i32)>
}

// -----

// Every struct member must match the operand type; the i64 member does not.
func.func @smul_extended(%x: i32) -> !spirv.struct<(i32, i64)> {
  %res = "spirv.SMulExtended"(%x, %x) : (i32, i32) -> !spirv.struct<(i32, i64)> // expected-error {{expected all operand types and struct member types are the same}}
  return %res : !spirv.struct<(i32, i64)>
}

// -----

// i64 operands do not match the i32 struct members.
func.func @smul_extended(%x: i64) -> !spirv.struct<(i32, i32)> {
  %res = "spirv.SMulExtended"(%x, %x) : (i64, i64) -> !spirv.struct<(i32, i32)> // expected-error {{expected all operand types and struct member types are the same}}
  return %res : !spirv.struct<(i32, i32)>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.UMulExtended
//===----------------------------------------------------------------------===//

// Scalar operands produce a two-member struct whose members share the i32 type.
// CHECK-LABEL: @umul_extended_scalar
func.func @umul_extended_scalar(%x: i32) -> !spirv.struct<(i32, i32)> {
  %res = spirv.UMulExtended %x, %x : !spirv.struct<(i32, i32)>
  // CHECK: spirv.UMulExtended %{{.+}}, %{{.+}} : !spirv.struct<(i32, i32)>
  return %res : !spirv.struct<(i32, i32)>
}

// Vector operands produce a struct of two vectors matching the operand type.
// CHECK-LABEL: @umul_extended_vector
func.func @umul_extended_vector(%v: vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
  %res = spirv.UMulExtended %v, %v : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
  // CHECK: spirv.UMulExtended %{{.+}}, %{{.+}} : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
  return %res : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
}

// -----

// The custom form requires a two-member result struct; three members fail.
func.func @umul_extended(%x: i32) -> !spirv.struct<(i32, i32, i32)> {
  %res = spirv.UMulExtended %x, %x : !spirv.struct<(i32, i32, i32)> // expected-error {{expected spirv.struct type with two members}}
  return %res : !spirv.struct<(i32, i32, i32)>
}

// -----

// Generic form with a one-member result struct must be diagnosed.
func.func @umul_extended(%x: i32) -> !spirv.struct<(i32)> {
  %res = "spirv.UMulExtended"(%x, %x) : (i32, i32) -> !spirv.struct<(i32)> // expected-error {{expected result struct type containing two members}}
  return %res : !spirv.struct<(i32)>
}

// -----

// Every struct member must match the operand type; the i64 member does not.
func.func @umul_extended(%x: i32) -> !spirv.struct<(i32, i64)> {
  %res = "spirv.UMulExtended"(%x, %x) : (i32, i32) -> !spirv.struct<(i32, i64)> // expected-error {{expected all operand types and struct member types are the same}}
  return %res : !spirv.struct<(i32, i64)>
}

// -----

// i64 operands do not match the i32 struct members.
func.func @umul_extended(%x: i64) -> !spirv.struct<(i32, i32)> {
  %res = "spirv.UMulExtended"(%x, %x) : (i64, i64) -> !spirv.struct<(i32, i32)> // expected-error {{expected all operand types and struct member types are the same}}
  return %res : !spirv.struct<(i32, i32)>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.SDiv
//===----------------------------------------------------------------------===//

// Scalar i32 SDiv must parse, verify and print back in its custom form.
// CHECK-LABEL: @sdiv_scalar
func.func @sdiv_scalar(%arg: i32) -> i32 {
  // CHECK: spirv.SDiv %{{.+}}, %{{.+}} : i32
  %0 = spirv.SDiv %arg, %arg : i32
  return %0 : i32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.SMod
//===----------------------------------------------------------------------===//

// Scalar i32 SMod must parse, verify and print back in its custom form.
// CHECK-LABEL: @smod_scalar
func.func @smod_scalar(%arg: i32) -> i32 {
  // CHECK: spirv.SMod %{{.+}}, %{{.+}} : i32
  %0 = spirv.SMod %arg, %arg : i32
  return %0 : i32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.SNegate
//===----------------------------------------------------------------------===//

// Unary scalar i32 SNegate must parse, verify and print back in custom form.
// CHECK-LABEL: @snegate_scalar
func.func @snegate_scalar(%arg: i32) -> i32 {
  // CHECK: spirv.SNegate %{{.+}} : i32
  %0 = spirv.SNegate %arg : i32
  return %0 : i32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.SRem
//===----------------------------------------------------------------------===//

// Scalar i32 SRem must parse, verify and print back in its custom form.
// CHECK-LABEL: @srem_scalar
func.func @srem_scalar(%arg: i32) -> i32 {
  // CHECK: spirv.SRem %{{.+}}, %{{.+}} : i32
  %0 = spirv.SRem %arg, %arg : i32
  return %0 : i32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.UDiv
//===----------------------------------------------------------------------===//

// Scalar i32 UDiv must parse, verify and print back in its custom form.
// CHECK-LABEL: @udiv_scalar
func.func @udiv_scalar(%arg: i32) -> i32 {
  // CHECK: spirv.UDiv %{{.+}}, %{{.+}} : i32
  %0 = spirv.UDiv %arg, %arg : i32
  return %0 : i32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.UMod
//===----------------------------------------------------------------------===//

// Scalar i32 UMod must parse, verify and print back in its custom form.
// CHECK-LABEL: @umod_scalar
func.func @umod_scalar(%arg: i32) -> i32 {
  // CHECK: spirv.UMod %{{.+}}, %{{.+}} : i32
  %0 = spirv.UMod %arg, %arg : i32
  return %0 : i32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.VectorTimesScalar
//===----------------------------------------------------------------------===//

// Vector-times-scalar with matching element types prints its full signature.
// CHECK-LABEL: @vector_times_scalar
func.func @vector_times_scalar(%vector: vector<4xf32>, %scalar: f32) -> vector<4xf32> {
  // CHECK: spirv.VectorTimesScalar %{{.+}}, %{{.+}} : (vector<4xf32>, f32) -> vector<4xf32>
  %0 = spirv.VectorTimesScalar %vector, %scalar : (vector<4xf32>, f32) -> vector<4xf32>
  return %0 : vector<4xf32>
}

// -----

// The scalar operand (f16) must match the result element type (f32).
func.func @vector_times_scalar(%vec: vector<4xf32>, %s: f16) -> vector<4xf32> {
  %scaled = spirv.VectorTimesScalar %vec, %s : (vector<4xf32>, f16) -> vector<4xf32> // expected-error {{scalar operand and result element type match}}
  return %scaled : vector<4xf32>
}

// -----

// The result vector type must equal the vector operand type (3 vs 4 lanes).
func.func @vector_times_scalar(%vec: vector<4xf32>, %s: f32) -> vector<3xf32> {
  %scaled = spirv.VectorTimesScalar %vec, %s : (vector<4xf32>, f32) -> vector<3xf32> // expected-error {{vector operand and result type mismatch}}
  return %scaled : vector<3xf32>
}
