// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s

// Round-trip and verifier tests for the SPIR-V dialect composite operations.
// Positive tests match the printed form of the op; negative tests pin the
// diagnostic emitted for the op on the line that follows the annotation.

//===----------------------------------------------------------------------===//
// spirv.CompositeConstruct
//===----------------------------------------------------------------------===//

// Three scalars assembled into a three-element vector.
// CHECK-LABEL: func @composite_construct_vector
func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
  // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -> vector<3xf32>
  %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32>
  return %0: vector<3xf32>
}

// A vector, an array and a nested struct assembled into a struct.
// CHECK-LABEL: func @composite_construct_struct
func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> {
  // CHECK: spirv.CompositeConstruct
  %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
  return %0: !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
}

// Scalars and a two-element vector combined into a four-element vector.
// CHECK-LABEL: func @composite_construct_mixed_scalar_vector
func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2 : vector<2xf32>) -> vector<4xf32> {
  // CHECK: spirv.CompositeConstruct %{{.+}}, %{{.+}}, %{{.+}} : (f32, vector<2xf32>, f32) -> vector<4xf32>
  %0 = spirv.CompositeConstruct %arg0, %arg2, %arg1 : (f32, vector<2xf32>, f32) -> vector<4xf32>
  return %0: vector<4xf32>
}

// A single scalar operand producing a KHR cooperative matrix.
// CHECK-LABEL: func @composite_construct_coopmatrix_khr
func.func @composite_construct_coopmatrix_khr(%arg0 : f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> {
  // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
  %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
  return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
}

// -----

// Two scalar operands for a three-element vector result.
func.func @composite_construct_invalid_result_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
  // expected-error @+1 {{has incorrect number of operands: expected 3, but provided 2}}
  %0 = spirv.CompositeConstruct %arg0, %arg2 : (f32, f32) -> vector<3xf32>
  return %0: vector<3xf32>
}

// -----

// f32 operands for an i32 vector result.
func.func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xi32> {
  // expected-error @+1 {{operand type mismatch: expected operand type 'i32', but provided 'f32'}}
  %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xi32>
  return %0: vector<3xi32>
}

// -----

// Two operands for a cooperative matrix result.
func.func @composite_construct_khr_coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) ->
  !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> {
  // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
  %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
  return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
}

// -----

// i32 operand for an f32 cooperative matrix result.
func.func @composite_construct_khr_coopmatrix_incorrect_element_type(%arg0 : i32) ->
  !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB> {
  // expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}}
  %0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB>
  return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB>
}

// -----

// A single scalar operand for a four-element array result.
func.func @composite_construct_array(%arg0: f32) -> !spirv.array<4xf32> {
  // expected-error @+1 {{expected to return a vector or cooperative matrix when the number of constituents is less than what the result needs}}
  %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.array<4xf32>
  return %0: !spirv.array<4xf32>
}

// -----

// An i32 vector constituent for an f32 vector result.
func.func @composite_construct_vector_wrong_element_type(%arg0: f32, %arg1: f32, %arg2 : vector<2xi32>) -> vector<4xf32> {
  // expected-error @+1 {{operand element type mismatch: expected to be 'f32', but provided 'i32'}}
  %0 = spirv.CompositeConstruct %arg0, %arg2, %arg1 : (f32, vector<2xi32>, f32) -> vector<4xf32>
  return %0: vector<4xf32>
}

// -----

// One scalar plus a two-element vector for a four-element vector result.
func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2 : vector<2xf32>) -> vector<4xf32> {
  // expected-error @+1 {{op has incorrect number of operands: expected 4, but provided 3}}
  %0 = spirv.CompositeConstruct %arg0, %arg2 : (f32, vector<2xf32>) -> vector<4xf32>
  return %0: vector<4xf32>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.CompositeExtract
//===----------------------------------------------------------------------===//

// Extract one element from an array.
func.func @composite_extract_array(%arg0: !spirv.array<4xf32>) -> f32 {
  // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[1 : i32] : !spirv.array<4 x f32>
  %0 = spirv.CompositeExtract %arg0[1 : i32] : !spirv.array<4xf32>
  return %0: f32
}

// -----

// Extract through a struct member into the nested array.
func.func @composite_extract_struct(%arg0 : !spirv.struct<(f32, !spirv.array<4xf32>)>) -> f32 {
  // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[1 : i32, 2 : i32] : !spirv.struct<(f32, !spirv.array<4 x f32>)>
  %0 = spirv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spirv.struct<(f32, !spirv.array<4xf32>)>
  return %0 : f32
}

// -----

// Extract one element from a vector.
func.func @composite_extract_vector(%arg0 : vector<4xf32>) -> f32 {
  // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[1 : i32] : vector<4xf32>
  %0 = spirv.CompositeExtract %arg0[1 : i32] : vector<4xf32>
  return %0 : f32
}

// -----

// The composite operand is missing entirely.
func.func @composite_extract_no_ssa_operand() -> () {
  // expected-error @+1 {{expected SSA operand}}
  %0 = spirv.CompositeExtract [4 : i32, 1 : i32] : !spirv.array<4x!spirv.array<4xf32>>
  return
}

// -----

// An SSA value is used where a constant index attribute is required.
func.func @composite_extract_invalid_index_type_1() -> () {
  %0 = spirv.Constant 10 : i32
  %1 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
  %2 = spirv.Load "Function" %1 ["Volatile"] : !spirv.array<4x!spirv.array<4xf32>>
  // expected-error @+1 {{expected attribute value}}
  %3 = spirv.CompositeExtract %2[%0] : !spirv.array<4x!spirv.array<4xf32>>
  return
}

// -----

// The index literal carries no explicit i32 type.
func.func @composite_extract_invalid_index_type_2(%arg0 : !spirv.array<4x!spirv.array<4xf32>>) -> () {
  // expected-error @+1 {{attribute 'indices' failed to satisfy constraint: 32-bit integer array attribute}}
  %0 = spirv.CompositeExtract %arg0[1] : !spirv.array<4x!spirv.array<4xf32>>
  return
}

// -----

// The index list is deliberately malformed: it opens with `]` and closes with `)`.
func.func @composite_extract_invalid_index_identifier(%arg0 : !spirv.array<4x!spirv.array<4xf32>>) -> () {
  // expected-error @+1 {{expected attribute value}}
  %0 = spirv.CompositeExtract %arg0 ]1 : i32) : !spirv.array<4x!spirv.array<4xf32>>
  return
}

// -----

// The first index exceeds the outer array bound.
func.func @composite_extract_2D_array_out_of_bounds_access_1(%arg0: !spirv.array<4x!spirv.array<4xf32>>) -> () {
  // expected-error @+1 {{index 4 out of bounds for '!spirv.array<4 x !spirv.array<4 x f32>>'}}
  %0 = spirv.CompositeExtract %arg0[4 : i32, 1 : i32] : !spirv.array<4x!spirv.array<4xf32>>
  return
}

// -----

// The second index exceeds the inner array bound.
func.func @composite_extract_2D_array_out_of_bounds_access_2(%arg0: !spirv.array<4x!spirv.array<4xf32>>
) -> () {
  // expected-error @+1 {{index 4 out of bounds for '!spirv.array<4 x f32>'}}
  %0 = spirv.CompositeExtract %arg0[1 : i32, 4 : i32] : !spirv.array<4x!spirv.array<4xf32>>
  return
}

// -----

// The struct has two members; index 2 is past the end.
func.func @composite_extract_struct_element_out_of_bounds_access(%arg0 : !spirv.struct<(f32, !spirv.array<4xf32>)>) -> () {
  // expected-error @+1 {{index 2 out of bounds for '!spirv.struct<(f32, !spirv.array<4 x f32>)>'}}
  %0 = spirv.CompositeExtract %arg0[2 : i32, 0 : i32] : !spirv.struct<(f32, !spirv.array<4xf32>)>
  return
}

// -----

// The vector has four elements; index 4 is past the end.
func.func @composite_extract_vector_out_of_bounds_access(%arg0: vector<4xf32>) -> () {
  // expected-error @+1 {{index 4 out of bounds for 'vector<4xf32>'}}
  %0 = spirv.CompositeExtract %arg0[4 : i32] : vector<4xf32>
  return
}

// -----

// Three indices into a two-level array: the third one indexes into a scalar.
func.func @composite_extract_invalid_types_1(%arg0: !spirv.array<4x!spirv.array<4xf32>>) -> () {
  // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 3}}
  %0 = spirv.CompositeExtract %arg0[1 : i32, 2 : i32, 3 : i32] : !spirv.array<4x!spirv.array<4xf32>>
  return
}

// -----

// The operand itself is a scalar.
func.func @composite_extract_invalid_types_2(%arg0: f32) -> () {
  // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 1}}
  %0 = spirv.CompositeExtract %arg0[1 : i32] : f32
  return
}

// -----

// The index list is empty.
func.func @composite_extract_invalid_extracted_type(%arg0: !spirv.array<4x!spirv.array<4xf32>>) -> () {
  // expected-error @+1 {{expected at least one index for spirv.CompositeExtract}}
  %0 = spirv.CompositeExtract %arg0[] : !spirv.array<4x!spirv.array<4xf32>>
  return
}

// -----

// Generic op form is used so that a mismatching result type can be spelled out.
func.func @composite_extract_result_type_mismatch(%arg0: !spirv.array<4xf32>) -> i32 {
  // expected-error @+1 {{invalid result type: expected 'f32' but provided 'i32'}}
  %0 = "spirv.CompositeExtract"(%arg0) {indices = [2: i32]} : (!spirv.array<4xf32>) -> (i32)
  return %0: i32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.CompositeInsert
//===----------------------------------------------------------------------===//

// Insert a scalar into an array.
func.func @composite_insert_array(%arg0: !spirv.array<4xf32>, %arg1: f32) -> !spirv.array<4xf32> {
  // CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into !spirv.array<4 x f32>
  %0 = spirv.CompositeInsert %arg1, %arg0[1 : i32] : f32 into !spirv.array<4xf32>
  return %0: !spirv.array<4xf32>
}

// -----

// Insert an array into the first member of a struct.
func.func @composite_insert_struct(%arg0: !spirv.struct<(!spirv.array<4xf32>, f32)>, %arg1: !spirv.array<4xf32>) -> !spirv.struct<(!spirv.array<4xf32>, f32)> {
  // CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : !spirv.array<4 x f32> into !spirv.struct<(!spirv.array<4 x f32>, f32)>
  %0 = spirv.CompositeInsert %arg1, %arg0[0 : i32] : !spirv.array<4xf32> into !spirv.struct<(!spirv.array<4xf32>, f32)>
  return %0: !spirv.struct<(!spirv.array<4xf32>, f32)>
}

// -----

// The index list is empty.
func.func @composite_insert_no_indices(%arg0: !spirv.array<4xf32>, %arg1: f32) -> !spirv.array<4xf32> {
  // expected-error @+1 {{expected at least one index}}
  %0 = spirv.CompositeInsert %arg1, %arg0[] : f32 into !spirv.array<4xf32>
  return %0: !spirv.array<4xf32>
}

// -----

// The array has four elements; index 4 is past the end.
func.func @composite_insert_out_of_bounds(%arg0: !spirv.array<4xf32>, %arg1: f32) -> !spirv.array<4xf32> {
  // expected-error @+1 {{index 4 out of bounds}}
  %0 = spirv.CompositeInsert %arg1, %arg0[4 : i32] : f32 into !spirv.array<4xf32>
  return %0: !spirv.array<4xf32>
}

// -----

// An f64 object inserted into an array of f32.
func.func @composite_insert_invalid_object_type(%arg0: !spirv.array<4xf32>, %arg1: f64) -> !spirv.array<4xf32> {
  // expected-error @+1 {{object operand type should be 'f32', but found 'f64'}}
  %0 = spirv.CompositeInsert %arg1, %arg0[3 : i32] : f64 into !spirv.array<4xf32>
  return %0: !spirv.array<4xf32>
}

// -----

// Generic op form is used so that a mismatching result type can be spelled out.
func.func @composite_insert_invalid_result_type(%arg0: !spirv.array<4xf32>, %arg1 : f32) -> !spirv.array<4xf64> {
  // expected-error @+1 {{result type should be the same as the composite type, but found '!spirv.array<4 x f32>' vs '!spirv.array<4 x f64>'}}
  %0 = "spirv.CompositeInsert"(%arg1, %arg0) {indices = [0: i32]} : (f32, !spirv.array<4xf32>) -> !spirv.array<4xf64>
  return %0: !spirv.array<4xf64>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.VectorExtractDynamic
//===----------------------------------------------------------------------===//

// Extract a vector element at an index given by an SSA value.
func.func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 {
  // CHECK: spirv.VectorExtractDynamic %{{.*}}[%{{.*}}] : vector<4xf32>, i32
  %0 = spirv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32
  return %0 : f32
}

//===----------------------------------------------------------------------===//
// spirv.VectorInsertDynamic
//===----------------------------------------------------------------------===//

// Insert a scalar into a vector at an index given by an SSA value.
func.func @vector_dynamic_insert(%val: f32, %vec: vector<4xf32>, %id : i32) -> vector<4xf32> {
  // CHECK: spirv.VectorInsertDynamic %{{.*}}, %{{.*}}[%{{.*}}] : vector<4xf32>, i32
  %0 = spirv.VectorInsertDynamic %val, %vec[%id] : vector<4xf32>, i32
  return %0 : vector<4xf32>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.VectorShuffle
//===----------------------------------------------------------------------===//

// The selector written as 0xffffffff is expected to print back as -1.
func.func @vector_shuffle(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
  // CHECK: %{{.+}} = spirv.VectorShuffle [1 : i32, 3 : i32, -1 : i32] %{{.+}}, %{{.+}} : vector<4xf32>, vector<2xf32> -> vector<3xf32>
  %0 = spirv.VectorShuffle [1: i32, 3: i32, 0xffffffff: i32] %vector1, %vector2 : vector<4xf32>, vector<2xf32> -> vector<3xf32>
  return %0: vector<3xf32>
}

// -----

// Four selectors for a three-element result.
func.func @vector_shuffle_extra_selector(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
  // expected-error @+1 {{result type element count (3) mismatch with the number of component selectors (4)}}
  %0 = spirv.VectorShuffle [1: i32, 3: i32, 5: i32, 2: i32] %vector1, %vector2 : vector<4xf32>, vector<2xf32> -> vector<3xf32>
  return %0: vector<3xf32>
}

// -----

// Selector 7 is outside the six components supplied by the two input vectors.
func.func @vector_shuffle_out_of_range_selector(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
  // expected-error @+1 {{component selector 7 out of range: expected to be in [0, 6) or 0xffffffff}}
  %0 = spirv.VectorShuffle [1: i32, 7: i32, 5: i32] %vector1, %vector2 : vector<4xf32>, vector<2xf32> -> vector<3xf32>
  return %0: vector<3xf32>
}