xref: /llvm-project/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir (revision 6e90f13cc9bc9dbc5c2c248d95c6e18a5fb021b4)
1// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
2
3//===----------------------------------------------------------------------===//
4// spirv.CompositeConstruct
5//===----------------------------------------------------------------------===//
6
// Positive test: spirv.CompositeConstruct builds a vector<3xf32> from three
// f32 operands. The FileCheck pattern below pins the number of operands and
// the printed operand/result types.
7// CHECK-LABEL: func @composite_construct_vector
8func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
9  // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -> vector<3xf32>
10  %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32>
11  return %0: vector<3xf32>
12}
13
// Positive test: spirv.CompositeConstruct builds a three-member struct from a
// vector, a !spirv.array and a nested !spirv.struct operand. The FileCheck
// pattern only requires that the op name is printed; operands and types are
// not pinned here.
14// CHECK-LABEL: func @composite_construct_struct
15func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> {
16  // CHECK: spirv.CompositeConstruct
17  %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
18  return %0: !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
19}
20
// Positive test: a vector<4xf32> is constructed from a mix of scalar and
// vector constituents (f32, vector<2xf32>, f32). Note the operands are passed
// in the order %arg0, %arg2, %arg1 so that the vector sits in the middle.
21// CHECK-LABEL: func @composite_construct_mixed_scalar_vector
22func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2 : vector<2xf32>) -> vector<4xf32> {
23  // CHECK: spirv.CompositeConstruct %{{.+}}, %{{.+}}, %{{.+}} : (f32, vector<2xf32>, f32) -> vector<4xf32>
24  %0 = spirv.CompositeConstruct %arg0, %arg2, %arg1 : (f32, vector<2xf32>, f32) -> vector<4xf32>
25  return %0: vector<4xf32>
26}
27
// Positive test: spirv.CompositeConstruct accepts a single f32 operand when
// the result is a !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>. The
// FileCheck pattern pins the single operand and the printed types.
28// CHECK-LABEL: func @composite_construct_coopmatrix_khr
29func.func @composite_construct_coopmatrix_khr(%arg0 : f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> {
30  // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
31  %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
32  return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
33}
34
35// -----
36
// Negative test: only two f32 operands are supplied for a vector<3xf32>
// result, and the verifier is expected to report an operand-count mismatch
// (expected 3, provided 2).
// NOTE(review): the function name says "invalid_result_type" but the
// diagnostic exercised here is about the operand count - consider renaming.
37func.func @composite_construct_invalid_result_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
38  // expected-error @+1 {{has incorrect number of operands: expected 3, but provided 2}}
39  %0 = spirv.CompositeConstruct %arg0, %arg2 : (f32, f32) -> vector<3xf32>
40  return %0: vector<3xf32>
41}
42
43// -----
44
// Negative test: three f32 operands are used to construct a vector<3xi32>.
// The expected diagnostic reports that the operand type should be i32 but
// f32 was provided.
45func.func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xi32> {
46  // expected-error @+1 {{operand type mismatch: expected operand type 'i32', but provided 'f32'}}
47  %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xi32>
48  return %0: vector<3xi32>
49}
50
51// -----
52
// Negative test: two f32 operands are supplied for a cooperative matrix
// result. The expected diagnostic states that exactly one operand is
// expected but two were provided.
53func.func @composite_construct_khr_coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) ->
54  !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> {
55  // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
56  %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
57  return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
58}
59
60// -----
61
// Negative test: an i32 operand is used to construct a cooperative matrix
// whose element type is f32. The expected diagnostic reports the operand
// type mismatch (expected f32, provided i32).
62func.func @composite_construct_khr_coopmatrix_incorrect_element_type(%arg0 : i32) ->
63  !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB> {
64  // expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}}
65  %0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB>
66  return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB>
67}
68
69// -----
70
// Negative test: a single f32 operand is supplied for a !spirv.array<4xf32>
// result. Per the expected diagnostic, supplying fewer constituents than the
// result needs is only accepted for vector or cooperative matrix results.
71func.func @composite_construct_array(%arg0: f32) -> !spirv.array<4xf32> {
72  // expected-error @+1 {{expected to return a vector or cooperative matrix when the number of constituents is less than what the result needs}}
73  %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.array<4xf32>
74  return %0: !spirv.array<4xf32>
75}
76
77// -----
78
// Negative test: one of the constituents of a vector<4xf32> result is a
// vector<2xi32>. The expected diagnostic reports the element type mismatch
// of that vector operand (expected f32, provided i32).
79func.func @composite_construct_vector_wrong_element_type(%arg0: f32, %arg1: f32, %arg2 : vector<2xi32>) -> vector<4xf32> {
80  // expected-error @+1 {{operand element type mismatch: expected to be 'f32', but provided 'i32'}}
81  %0 = spirv.CompositeConstruct %arg0, %arg2, %arg1 : (f32, vector<2xi32>, f32) -> vector<4xf32>
82  return %0: vector<4xf32>
83}
84
85// -----
86
// Negative test: an f32 and a vector<2xf32> are supplied for a vector<4xf32>
// result. The expected diagnostic reports "expected 4, but provided 3".
// NOTE(review): only two SSA operands are passed, so the "3" presumably
// counts scalar elements (1 + 2) rather than operands - confirm against the
// verifier implementation.
87func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2 : vector<2xf32>) -> vector<4xf32> {
88  // expected-error @+1 {{op has incorrect number of operands: expected 4, but provided 3}}
89  %0 = spirv.CompositeConstruct %arg0, %arg2 : (f32, vector<2xf32>) -> vector<4xf32>
90  return %0: vector<4xf32>
91}
92
93// -----
94
95//===----------------------------------------------------------------------===//
96// spirv.CompositeExtract
97//===----------------------------------------------------------------------===//
98
// Positive test: spirv.CompositeExtract reads element 1 of a
// !spirv.array<4xf32> and the function returns it as f32. The FileCheck
// pattern expects the array type to be printed with spaces around "x"
// even though the input spells it as 4xf32.
99func.func @composite_extract_array(%arg0: !spirv.array<4xf32>) -> f32 {
100  // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[1 : i32] : !spirv.array<4 x f32>
101  %0 = spirv.CompositeExtract %arg0[1 : i32] : !spirv.array<4xf32>
102  return %0: f32
103}
104
105// -----
106
// Positive test: spirv.CompositeExtract with two indices walks into a struct
// (member 1, the nested array) and then into that array (element 2); the
// function returns the extracted value as f32.
107func.func @composite_extract_struct(%arg0 : !spirv.struct<(f32, !spirv.array<4xf32>)>) -> f32 {
108  // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[1 : i32, 2 : i32] : !spirv.struct<(f32, !spirv.array<4 x f32>)>
109  %0 = spirv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spirv.struct<(f32, !spirv.array<4xf32>)>
110  return %0 : f32
111}
112
113// -----
114
// Positive test: spirv.CompositeExtract reads element 1 of a vector<4xf32>
// and the function returns it as f32.
115func.func @composite_extract_vector(%arg0 : vector<4xf32>) -> f32 {
116  // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[1 : i32] : vector<4xf32>
117  %0 = spirv.CompositeExtract %arg0[1 : i32] : vector<4xf32>
118  return %0 : f32
119}
120
121// -----
122
// Negative parser test: the composite SSA operand is omitted and the index
// list directly follows the op name. The expected diagnostic is
// "expected SSA operand".
123func.func @composite_extract_no_ssa_operand() -> () {
124  // expected-error @+1 {{expected SSA operand}}
125  %0 = spirv.CompositeExtract [4 : i32, 1 : i32] : !spirv.array<4x!spirv.array<4xf32>>
126  return
127}
128
129// -----
130
// Negative parser test: the index is given as an SSA value (%0) instead of a
// literal attribute. The constant/variable/load sequence only sets up the
// index value and a composite to extract from; the expected diagnostic is
// "expected attribute value" on the extract op.
131func.func @composite_extract_invalid_index_type_1() -> () {
132  %0 = spirv.Constant 10 : i32
133  %1 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
134  %2 = spirv.Load "Function" %1 ["Volatile"] : !spirv.array<4x!spirv.array<4xf32>>
135  // expected-error @+1 {{expected attribute value}}
136  %3 = spirv.CompositeExtract %2[%0] : !spirv.array<4x!spirv.array<4xf32>>
137  return
138}
139
140// -----
141
// Negative test: the index is written as a bare "1" without the ": i32"
// type suffix used by the passing tests. The expected diagnostic says the
// 'indices' attribute must be a 32-bit integer array attribute.
142func.func @composite_extract_invalid_index_type_2(%arg0 : !spirv.array<4x!spirv.array<4xf32>>) -> () {
143  // expected-error @+1 {{attribute 'indices' failed to satisfy constraint: 32-bit integer array attribute}}
144  %0 = spirv.CompositeExtract %arg0[1] : !spirv.array<4x!spirv.array<4xf32>>
145  return
146}
147
148// -----
149
// Negative parser test: the index list uses malformed delimiters
// ("]1 : i32)" instead of "[1 : i32]"). The expected diagnostic is
// "expected attribute value".
150func.func @composite_extract_invalid_index_identifier(%arg0 : !spirv.array<4x!spirv.array<4xf32>>) -> () {
151  // expected-error @+1 {{expected attribute value}}
152  %0 = spirv.CompositeExtract %arg0 ]1 : i32) : !spirv.array<4x!spirv.array<4xf32>>
153  return
154}
155
156// -----
157
// Negative test: the first index (4) is out of bounds for the outer
// four-element array. The expected diagnostic names the outer array type.
158func.func @composite_extract_2D_array_out_of_bounds_access_1(%arg0: !spirv.array<4x!spirv.array<4xf32>>) -> () {
159  // expected-error @+1 {{index 4 out of bounds for '!spirv.array<4 x !spirv.array<4 x f32>>'}}
160  %0 = spirv.CompositeExtract %arg0[4 : i32, 1 : i32] : !spirv.array<4x!spirv.array<4xf32>>
161  return
162}
163
164// -----
165
// Negative test: the first index (1) is valid but the second index (4) is
// out of bounds for the inner four-element array. The expected diagnostic
// names the inner array type.
166func.func @composite_extract_2D_array_out_of_bounds_access_2(%arg0: !spirv.array<4x!spirv.array<4xf32>>
167) -> () {
168  // expected-error @+1 {{index 4 out of bounds for '!spirv.array<4 x f32>'}}
169  %0 = spirv.CompositeExtract %arg0[1 : i32, 4 : i32] : !spirv.array<4x!spirv.array<4xf32>>
170  return
171}
172
173// -----
174
// Negative test: index 2 is used on a struct that has only two members
// (f32 and an array). The expected diagnostic reports the out-of-bounds
// index against the struct type.
175func.func @composite_extract_struct_element_out_of_bounds_access(%arg0 : !spirv.struct<(f32, !spirv.array<4xf32>)>) -> () {
176  // expected-error @+1 {{index 2 out of bounds for '!spirv.struct<(f32, !spirv.array<4 x f32>)>'}}
177  %0 = spirv.CompositeExtract %arg0[2 : i32, 0 : i32] : !spirv.struct<(f32, !spirv.array<4xf32>)>
178  return
179}
180
181// -----
182
// Negative test: index 4 is used on a vector<4xf32>. The expected diagnostic
// reports the out-of-bounds index against the vector type.
183func.func @composite_extract_vector_out_of_bounds_access(%arg0: vector<4xf32>) -> () {
184  // expected-error @+1 {{index 4 out of bounds for 'vector<4xf32>'}}
185  %0 = spirv.CompositeExtract %arg0[4 : i32] : vector<4xf32>
186  return
187}
188
189// -----
190
// Negative test: three indices are given for a two-level nested array, so
// the third index (3) would have to be applied to the scalar f32 element.
// The expected diagnostic reports extraction from a non-composite type.
191func.func @composite_extract_invalid_types_1(%arg0: !spirv.array<4x!spirv.array<4xf32>>) -> () {
192  // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 3}}
193  %0 = spirv.CompositeExtract %arg0[1 : i32, 2 : i32, 3 : i32] : !spirv.array<4x!spirv.array<4xf32>>
194  return
195}
196
197// -----
198
// Negative test: the composite operand itself is a scalar f32. The expected
// diagnostic reports extraction from a non-composite type with index 1.
199func.func @composite_extract_invalid_types_2(%arg0: f32) -> () {
200  // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 1}}
201  %0 = spirv.CompositeExtract %arg0[1 : i32] : f32
202  return
203}
204
205// -----
206
// Negative test: the index list is empty. The expected diagnostic requires
// at least one index for spirv.CompositeExtract.
// NOTE(review): the function name says "invalid_extracted_type" but the
// case exercised is an empty index list - consider renaming.
207func.func @composite_extract_invalid_extracted_type(%arg0: !spirv.array<4x!spirv.array<4xf32>>) -> () {
208  // expected-error @+1 {{expected at least one index for spirv.CompositeExtract}}
209  %0 = spirv.CompositeExtract %arg0[] : !spirv.array<4x!spirv.array<4xf32>>
210  return
211}
212
213// -----
214
// Negative test: written in the generic op form so that an explicit result
// type (i32) can be declared for an extraction from !spirv.array<4xf32>.
// The expected diagnostic says the result type should be f32.
// NOTE(review): presumably the generic form is needed because the custom
// syntax used elsewhere in this file does not spell the result type.
215func.func @composite_extract_result_type_mismatch(%arg0: !spirv.array<4xf32>) -> i32 {
216  // expected-error @+1 {{invalid result type: expected 'f32' but provided 'i32'}}
217  %0 = "spirv.CompositeExtract"(%arg0) {indices = [2: i32]} : (!spirv.array<4xf32>) -> (i32)
218  return %0: i32
219}
220
221// -----
222
223//===----------------------------------------------------------------------===//
224// spirv.CompositeInsert
225//===----------------------------------------------------------------------===//
226
// Positive test: spirv.CompositeInsert writes an f32 object (%arg1) into
// element 1 of a !spirv.array<4xf32> composite (%arg0) and returns the
// resulting array. The object operand comes first, then the composite.
227func.func @composite_insert_array(%arg0: !spirv.array<4xf32>, %arg1: f32) -> !spirv.array<4xf32> {
228  // CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into !spirv.array<4 x f32>
229  %0 = spirv.CompositeInsert %arg1, %arg0[1 : i32] : f32 into !spirv.array<4xf32>
230  return %0: !spirv.array<4xf32>
231}
232
233// -----
234
// Positive test: spirv.CompositeInsert writes a whole !spirv.array<4xf32>
// object into member 0 of a struct whose first member has that array type.
235func.func @composite_insert_struct(%arg0: !spirv.struct<(!spirv.array<4xf32>, f32)>, %arg1: !spirv.array<4xf32>) -> !spirv.struct<(!spirv.array<4xf32>, f32)> {
236  // CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : !spirv.array<4 x f32> into !spirv.struct<(!spirv.array<4 x f32>, f32)>
237  %0 = spirv.CompositeInsert %arg1, %arg0[0 : i32] : !spirv.array<4xf32> into !spirv.struct<(!spirv.array<4xf32>, f32)>
238  return %0: !spirv.struct<(!spirv.array<4xf32>, f32)>
239}
240
241// -----
242
// Negative test: the index list of spirv.CompositeInsert is empty. The
// expected diagnostic requires at least one index.
243func.func @composite_insert_no_indices(%arg0: !spirv.array<4xf32>, %arg1: f32) -> !spirv.array<4xf32> {
244  // expected-error @+1 {{expected at least one index}}
245  %0 = spirv.CompositeInsert %arg1, %arg0[] : f32 into !spirv.array<4xf32>
246  return %0: !spirv.array<4xf32>
247}
248
249// -----
250
// Negative test: index 4 is used to insert into a four-element array. The
// expected diagnostic reports the out-of-bounds index.
251func.func @composite_insert_out_of_bounds(%arg0: !spirv.array<4xf32>, %arg1: f32) -> !spirv.array<4xf32> {
252  // expected-error @+1 {{index 4 out of bounds}}
253  %0 = spirv.CompositeInsert %arg1, %arg0[4 : i32] : f32 into !spirv.array<4xf32>
254  return %0: !spirv.array<4xf32>
255}
256
257// -----
258
// Negative test: an f64 object is inserted into an array whose element type
// is f32. The expected diagnostic reports that the object operand type
// should be f32 but f64 was found.
259func.func @composite_insert_invalid_object_type(%arg0: !spirv.array<4xf32>, %arg1: f64) -> !spirv.array<4xf32> {
260  // expected-error @+1 {{object operand type should be 'f32', but found 'f64'}}
261  %0 = spirv.CompositeInsert %arg1, %arg0[3 : i32] : f64 into !spirv.array<4xf32>
262  return %0: !spirv.array<4xf32>
263}
264
265// -----
266
// Negative test: written in the generic op form so that the declared result
// type (!spirv.array<4xf64>) can differ from the composite operand type
// (!spirv.array<4xf32>). The expected diagnostic requires both to be the
// same type.
267func.func @composite_insert_invalid_result_type(%arg0: !spirv.array<4xf32>, %arg1 : f32) -> !spirv.array<4xf64> {
268  // expected-error @+1 {{result type should be the same as the composite type, but found '!spirv.array<4 x f32>' vs '!spirv.array<4 x f64>'}}
269  %0 = "spirv.CompositeInsert"(%arg1, %arg0) {indices = [0: i32]} : (f32, !spirv.array<4xf32>) -> !spirv.array<4xf64>
270  return %0: !spirv.array<4xf64>
271}
272
273// -----
274
275//===----------------------------------------------------------------------===//
276// spirv.VectorExtractDynamic
277//===----------------------------------------------------------------------===//
278
// Positive test: spirv.VectorExtractDynamic reads one element of a
// vector<4xf32> using an i32 SSA value (%id) as the index, and the function
// returns the element as f32.
279func.func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 {
280  // CHECK: spirv.VectorExtractDynamic %{{.*}}[%{{.*}}] : vector<4xf32>, i32
281  %0 = spirv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32
282  return %0 : f32
283}
284
285//===----------------------------------------------------------------------===//
286// spirv.VectorInsertDynamic
287//===----------------------------------------------------------------------===//
288
// Positive test: spirv.VectorInsertDynamic writes an f32 value (%val) into
// a vector<4xf32> at a position given by an i32 SSA value (%id) and returns
// the resulting vector.
289func.func @vector_dynamic_insert(%val: f32, %vec: vector<4xf32>, %id : i32) -> vector<4xf32> {
290  // CHECK: spirv.VectorInsertDynamic %{{.*}}, %{{.*}}[%{{.*}}] : vector<4xf32>, i32
291  %0 = spirv.VectorInsertDynamic %val, %vec[%id] : vector<4xf32>, i32
292  return %0 : vector<4xf32>
293}
294
295// -----
296
297//===----------------------------------------------------------------------===//
298// spirv.VectorShuffle
299//===----------------------------------------------------------------------===//
300
// Positive test: spirv.VectorShuffle selects three components from the
// pair (vector<4xf32>, vector<2xf32>) into a vector<3xf32>. The third
// selector is written as 0xffffffff in the input, and the FileCheck pattern
// expects it to be printed back as -1.
// NOTE(review): the pattern hard-codes the second operand as %arg1 while the
// first operand uses a wildcard; a wildcard for both would not depend on how
// the printer names block arguments.
301func.func @vector_shuffle(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
302  // CHECK: %{{.+}} = spirv.VectorShuffle [1 : i32, 3 : i32, -1 : i32] %{{.+}}, %arg1 : vector<4xf32>, vector<2xf32> -> vector<3xf32>
303  %0 = spirv.VectorShuffle [1: i32, 3: i32, 0xffffffff: i32] %vector1, %vector2 : vector<4xf32>, vector<2xf32> -> vector<3xf32>
304  return %0: vector<3xf32>
305}
306
307// -----
308
// Negative test: four component selectors are supplied for a vector<3xf32>
// result. The expected diagnostic reports the mismatch between the result
// element count (3) and the number of selectors (4).
309func.func @vector_shuffle_extra_selector(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
310  // expected-error @+1 {{result type element count (3) mismatch with the number of component selectors (4)}}
311  %0 = spirv.VectorShuffle [1: i32, 3: i32, 5: i32, 2: i32] %vector1, %vector2 : vector<4xf32>, vector<2xf32> -> vector<3xf32>
312  return %0: vector<3xf32>
313}
314
315// -----
316
// Negative test: selector 7 is outside the range accepted by the verifier.
// The expected diagnostic gives the valid range as [0, 6) or 0xffffffff;
// the bound 6 presumably is the combined element count of the two operand
// vectors (4 + 2) - confirm against the verifier implementation.
// NOTE(review): this function reuses the name of the previous test case
// (legal only because the cases live in separate split-input chunks); a name
// mentioning "out_of_range" would describe this case better.
317func.func @vector_shuffle_extra_selector(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
318  // expected-error @+1 {{component selector 7 out of range: expected to be in [0, 6) or 0xffffffff}}
319  %0 = spirv.VectorShuffle [1: i32, 7: i32, 5: i32] %vector1, %vector2 : vector<4xf32>, vector<2xf32> -> vector<3xf32>
320  return %0: vector<3xf32>
321}
322