// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s

// Check that with proper compute and storage extensions, we don't need to
// perform special tricks.

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.5,
      [
        Shader, Int8, Int16, Int64, Float16, Float64,
        StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16,
        StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8,
        PhysicalStorageBufferAddresses
      ],
      [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer]>,
    #spirv.resource_limits<>>
} {

// Zero-rank f32 memrefs are expected to convert to a struct-wrapped
// one-element array; both the load and the store index it with [0, 0].
// CHECK-LABEL: @load_store_zero_rank_float(
// CHECK-SAME: %[[OARG0:.*]]: memref{{.*}}, %[[OARG1:.*]]: memref
func.func @load_store_zero_rank_float(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>, %arg1: memref<f32, #spirv.storage_class<StorageBuffer>>) {
  // CHECK-DAG: [[ARG0:%.*]] = builtin.unrealized_conversion_cast %[[OARG0]] : memref<f32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
  // CHECK-DAG: [[ARG1:%.*]] = builtin.unrealized_conversion_cast %[[OARG1]] : memref<f32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
  // CHECK: [[ZERO:%.*]] = spirv.Constant 0 : i32
  // CHECK: spirv.AccessChain [[ARG0]][
  // CHECK-SAME: [[ZERO]], [[ZERO]]
  // CHECK-SAME: ] :
  // CHECK: spirv.Load "StorageBuffer" %{{.*}} : f32
  %0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<StorageBuffer>>
  // CHECK: spirv.AccessChain [[ARG1]][
  // CHECK-SAME: [[ZERO]], [[ZERO]]
  // CHECK-SAME: ] :
  // CHECK: spirv.Store "StorageBuffer" %{{.*}} : f32
  memref.store %0, %arg1[] : memref<f32, #spirv.storage_class<StorageBuffer>>
  return
}

// Same as above, with an i32 element type.
// CHECK-LABEL: @load_store_zero_rank_int
// CHECK-SAME: %[[OARG0:.*]]: memref{{.*}}, %[[OARG1:.*]]: memref
func.func @load_store_zero_rank_int(%arg0: memref<i32, #spirv.storage_class<StorageBuffer>>, %arg1: memref<i32, #spirv.storage_class<StorageBuffer>>) {
  // CHECK-DAG: [[ARG0:%.*]] = builtin.unrealized_conversion_cast %[[OARG0]] : memref<i32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>
  // CHECK-DAG: [[ARG1:%.*]] = builtin.unrealized_conversion_cast %[[OARG1]] : memref<i32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>
  // CHECK: [[ZERO:%.*]] = spirv.Constant 0 : i32
  // CHECK: spirv.AccessChain [[ARG0]][
  // CHECK-SAME: [[ZERO]], [[ZERO]]
  // CHECK-SAME: ] :
  // CHECK: spirv.Load "StorageBuffer" %{{.*}} : i32
  %0 = memref.load %arg0[] : memref<i32, #spirv.storage_class<StorageBuffer>>
  // CHECK: spirv.AccessChain [[ARG1]][
  // CHECK-SAME: [[ZERO]], [[ZERO]]
  // CHECK-SAME: ] :
  // CHECK: spirv.Store "StorageBuffer" %{{.*}} : i32
  memref.store %0, %arg1[] : memref<i32, #spirv.storage_class<StorageBuffer>>
  return
}

// Dynamically-sized memrefs are expected to convert to a struct-wrapped
// runtime array.
// CHECK-LABEL: func @load_store_unknown_dim
// CHECK-SAME: %[[OARG0:.*]]: index, %[[OARG1:.*]]: memref{{.*}}, %[[OARG2:.*]]: memref
func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spirv.storage_class<StorageBuffer>>, %dest: memref<?xi32, #spirv.storage_class<StorageBuffer>>) {
  // CHECK-DAG: %[[SRC:.+]] = builtin.unrealized_conversion_cast %[[OARG1]] : memref<?xi32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer>
  // CHECK-DAG: %[[DST:.+]] = builtin.unrealized_conversion_cast %[[OARG2]] : memref<?xi32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer>
  // CHECK: %[[AC0:.+]] = spirv.AccessChain %[[SRC]]
  // CHECK: spirv.Load "StorageBuffer" %[[AC0]]
  %0 = memref.load %source[%i] : memref<?xi32, #spirv.storage_class<StorageBuffer>>
  // CHECK: %[[AC1:.+]] = spirv.AccessChain %[[DST]]
  // CHECK: spirv.Store "StorageBuffer" %[[AC1]]
  memref.store %0, %dest[%i]: memref<?xi32, #spirv.storage_class<StorageBuffer>>
  return
}

// With bool-num-bits=8, i1 elements are stored as i8; a load compares the
// loaded i8 against zero to produce the boolean result.
// CHECK-LABEL: func @load_i1
// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %[[IDX:.+]]: index)
func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i : index) -> i1 {
  // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, StorageBuffer>
  // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
  // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]]
  // CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] : i8
  // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
  // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
  %0 = memref.load %src[%i] : memref<4xi1, #spirv.storage_class<StorageBuffer>>
  // CHECK: return %[[BOOL]]
  return %0: i1
}

// Storing the constant `true` is expected to store the i8 constant 1.
// CHECK-LABEL: func @store_i1
// CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spirv.storage_class<StorageBuffer>>,
// CHECK-SAME: %[[IDX:.+]]: index
func.func @store_i1(%dst: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i: index) {
  %true = arith.constant true
  // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, StorageBuffer>
  // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
  // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[ZERO]], %[[IDX_CAST]]]
  // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
  // CHECK: spirv.Store "StorageBuffer" %[[ADDR]], %[[ONE_I8]] : i8
  memref.store %true, %dst[%i]: memref<4xi1, #spirv.storage_class<StorageBuffer>>
  return
}

// A 16-bit load must not be surrounded by index-division or shift ops.
// CHECK-LABEL: @load_i16
func.func @load_i16(%arg0: memref<i16, #spirv.storage_class<StorageBuffer>>) {
  // CHECK-NOT: spirv.SDiv
  // CHECK: spirv.Load
  // CHECK-NOT: spirv.ShiftRightArithmetic
  %0 = memref.load %arg0[] : memref<i16, #spirv.storage_class<StorageBuffer>>
  return
}

// A 16-bit store must be a plain store, with no atomic and/or ops after it.
// CHECK-LABEL: @store_i16
func.func @store_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %index: index, %value: i16) {
  // CHECK: spirv.Store
  // CHECK-NOT: spirv.AtomicAnd
  // CHECK-NOT: spirv.AtomicOr
  memref.store %value, %arg0[%index] : memref<10xi16, #spirv.storage_class<StorageBuffer>>
  return
}

// PhysicalStorageBuffer accesses are expected to carry an "Aligned" memory
// operand; the functions below check the alignment value for each element type.
// CHECK-LABEL: @load_store_i32_physical
func.func @load_store_i32_physical(%arg0: memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>) {
  // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : i32
  // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : i32
  %0 = memref.load %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
  memref.store %0, %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
  return
}

// CHECK-LABEL: @load_store_i8_physical
func.func @load_store_i8_physical(%arg0: memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>) {
  // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
  // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
  %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
  memref.store %0, %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
  return
}

// i1 is accessed as i8 here as well, so the expected alignment is 1.
// CHECK-LABEL: @load_store_i1_physical
func.func @load_store_i1_physical(%arg0: memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>) {
  // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
  // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
  %0 = memref.load %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
  memref.store %0, %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
  return
}

// CHECK-LABEL: @load_store_f32_physical
func.func @load_store_f32_physical(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>) {
  // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : f32
  // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : f32
  %0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
  memref.store %0, %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
  return
}

// CHECK-LABEL: @load_store_f16_physical
func.func @load_store_f16_physical(%arg0: memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>) {
  // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 2] : f16
  // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 2] : f16
  %0 = memref.load %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
  memref.store %0, %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
  return
}

} // end module

// -----

// Check that with the Kernel capability and proper compute capabilities, we
// don't need to perform special tricks.
module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0,
      [
        Kernel, Addresses, Int8, Int16, Int64, Float16, Float64], []>, #spirv.resource_limits<>>
} {

// Under this target environment zero-rank memrefs are expected to convert to a
// pointer to a plain one-element array (no struct wrapper), indexed with [0].
// CHECK-LABEL: @load_store_zero_rank_float
// CHECK-SAME: %[[OARG0:.*]]: memref{{.*}}, %[[OARG1:.*]]: memref
func.func @load_store_zero_rank_float(%arg0: memref<f32, #spirv.storage_class<CrossWorkgroup>>, %arg1: memref<f32, #spirv.storage_class<CrossWorkgroup>>) {
  // CHECK-DAG: [[ARG0:%.*]] = builtin.unrealized_conversion_cast %[[OARG0]] : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x f32>, CrossWorkgroup>
  // CHECK-DAG: [[ARG1:%.*]] = builtin.unrealized_conversion_cast %[[OARG1]] : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x f32>, CrossWorkgroup>
  // CHECK: [[ZERO:%.*]] = spirv.Constant 0 : i32
  // CHECK: spirv.AccessChain [[ARG0]][
  // CHECK-SAME: [[ZERO]]
  // CHECK-SAME: ] :
  // CHECK: spirv.Load "CrossWorkgroup" %{{.*}} : f32
  %0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<CrossWorkgroup>>
  // CHECK: spirv.AccessChain [[ARG1]][
  // CHECK-SAME: [[ZERO]]
  // CHECK-SAME: ] :
  // CHECK: spirv.Store "CrossWorkgroup" %{{.*}} : f32
  memref.store %0, %arg1[] : memref<f32, #spirv.storage_class<CrossWorkgroup>>
  return
}

// Same as above, with an i32 element type.
// CHECK-LABEL: @load_store_zero_rank_int
// CHECK-SAME: %[[OARG0:.*]]: memref{{.*}}, %[[OARG1:.*]]: memref
func.func @load_store_zero_rank_int(%arg0: memref<i32, #spirv.storage_class<CrossWorkgroup>>, %arg1: memref<i32, #spirv.storage_class<CrossWorkgroup>>) {
  // CHECK-DAG: [[ARG0:%.*]] = builtin.unrealized_conversion_cast %[[OARG0]] : memref<i32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x i32>, CrossWorkgroup>
  // CHECK-DAG: [[ARG1:%.*]] = builtin.unrealized_conversion_cast %[[OARG1]] : memref<i32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x i32>, CrossWorkgroup>
  // CHECK: [[ZERO:%.*]] = spirv.Constant 0 : i32
  // CHECK: spirv.AccessChain [[ARG0]][
  // CHECK-SAME: [[ZERO]]
  // CHECK-SAME: ] :
  // CHECK: spirv.Load "CrossWorkgroup" %{{.*}} : i32
  %0 = memref.load %arg0[] : memref<i32, #spirv.storage_class<CrossWorkgroup>>
  // CHECK: spirv.AccessChain [[ARG1]][
  // CHECK-SAME: [[ZERO]]
  // CHECK-SAME: ] :
  // CHECK: spirv.Store "CrossWorkgroup" %{{.*}} : i32
  memref.store %0, %arg1[] : memref<i32, #spirv.storage_class<CrossWorkgroup>>
  return
}

// Dynamically-sized memrefs are expected to convert to a bare element pointer
// that is indexed with spirv.PtrAccessChain.
// CHECK-LABEL: func @load_store_unknown_dim
// CHECK-SAME: %[[OARG0:.*]]: index, %[[OARG1:.*]]: memref{{.*}}, %[[OARG2:.*]]: memref
func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spirv.storage_class<CrossWorkgroup>>, %dest: memref<?xi32, #spirv.storage_class<CrossWorkgroup>>) {
  // CHECK-DAG: %[[SRC:.+]] = builtin.unrealized_conversion_cast %[[OARG1]] : memref<?xi32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<i32, CrossWorkgroup>
  // CHECK-DAG: %[[DST:.+]] = builtin.unrealized_conversion_cast %[[OARG2]] : memref<?xi32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<i32, CrossWorkgroup>
  // CHECK: %[[AC0:.+]] = spirv.PtrAccessChain %[[SRC]]
  // CHECK: spirv.Load "CrossWorkgroup" %[[AC0]]
  %0 = memref.load %source[%i] : memref<?xi32, #spirv.storage_class<CrossWorkgroup>>
  // CHECK: %[[AC1:.+]] = spirv.PtrAccessChain %[[DST]]
  // CHECK: spirv.Store "CrossWorkgroup" %[[AC1]]
  memref.store %0, %dest[%i]: memref<?xi32, #spirv.storage_class<CrossWorkgroup>>
  return
}

// With bool-num-bits=8, i1 elements are stored as i8; a load compares the
// loaded i8 against zero to produce the boolean result.
// CHECK-LABEL: func @load_i1
// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %[[IDX:.+]]: index)
func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i : index) -> i1 {
  // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<4 x i8>, CrossWorkgroup>
  // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[IDX_CAST]]]
  // CHECK: %[[VAL:.+]] = spirv.Load "CrossWorkgroup" %[[ADDR]] : i8
  // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
  // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
  %0 = memref.load %src[%i] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>>
  // CHECK: return %[[BOOL]]
  return %0: i1
}

// Storing the constant `true` is expected to store the i8 constant 1.
// CHECK-LABEL: func @store_i1
// CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>,
// CHECK-SAME: %[[IDX:.+]]: index
func.func @store_i1(%dst: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i: index) {
  %true = arith.constant true
  // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<4 x i8>, CrossWorkgroup>
  // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[IDX_CAST]]]
  // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
  // CHECK: spirv.Store "CrossWorkgroup" %[[ADDR]], %[[ONE_I8]] : i8
  memref.store %true, %dst[%i]: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>
  return
}

} // end module

// -----

// Check address space casts

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0,
      [
        Kernel, Addresses, GenericPointer], []>, #spirv.resource_limits<>>
} {

// A CrossWorkgroup -> Function cast is expected to go through the Generic
// storage class in two steps.
// CHECK-LABEL: func.func @memory_space_cast
func.func @memory_space_cast(%arg: memref<4xf32, #spirv.storage_class<CrossWorkgroup>>)
    -> memref<4xf32, #spirv.storage_class<Function>> {
  // CHECK: %[[ARG_CAST:.+]] = builtin.unrealized_conversion_cast {{.*}} to !spirv.ptr<!spirv.array<4 x f32>, CrossWorkgroup>
  // CHECK: %[[TO_GENERIC:.+]] = spirv.PtrCastToGeneric %[[ARG_CAST]] : !spirv.ptr<!spirv.array<4 x f32>, CrossWorkgroup> to !spirv.ptr<!spirv.array<4 x f32>, Generic>
  // CHECK: %[[TO_FUNCTION:.+]] = spirv.GenericCastToPtr %[[TO_GENERIC]] : !spirv.ptr<!spirv.array<4 x f32>, Generic> to !spirv.ptr<!spirv.array<4 x f32>, Function>
  // CHECK: %[[RET:.+]] = builtin.unrealized_conversion_cast %[[TO_FUNCTION]]
  // CHECK: return %[[RET]]
  %ret = memref.memory_space_cast %arg : memref<4xf32, #spirv.storage_class<CrossWorkgroup>>
    to memref<4xf32, #spirv.storage_class<Function>>
  return %ret : memref<4xf32, #spirv.storage_class<Function>>
}

} // end module

// -----

// Check that casts are properly inserted if the corresponding **compute**
// capability is allowed.
module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader, Int8, Int16], [
      SPV_KHR_8bit_storage, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class
    ]>, #spirv.resource_limits<>>
} {

// The i1 result is expected to come from comparing an i32 value against 1.
// CHECK-LABEL: @load_i1
func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1 {
  // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
  // CHECK: %[[RES:.+]] = spirv.IEqual %{{.+}}, %[[ONE]] : i32
  // CHECK: return %[[RES]]
  %0 = memref.load %arg0[] : memref<i1, #spirv.storage_class<StorageBuffer>>
  return %0 : i1
}

// The i8 result is expected to be converted down from an i32 value.
// CHECK-LABEL: @load_i8
func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
  // CHECK: %[[RES:.+]] = spirv.UConvert %{{.+}} : i32 to i8
  // CHECK: return %[[RES]]
  %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
  return %0 : i8
}

// The i16 result is expected to be converted down from an i32 value.
// CHECK-LABEL: @load_i16
func.func @load_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %index : index) -> i16 {
  // CHECK: %[[RES:.+]] = spirv.UConvert %{{.+}} : i32 to i16
  // CHECK: return %[[RES]]
  %0 = memref.load %arg0[%index] : memref<10xi16, #spirv.storage_class<StorageBuffer>>
  return %0: i16
}

} // end module

// -----

// Check reinterpret_casts

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0,
      [Kernel, Addresses, GenericPointer], []>, #spirv.resource_limits<>>
} {

// A dynamic offset is expected to become the index of an
// spirv.InBoundsPtrAccessChain on the element pointer.
// CHECK-LABEL: func.func @reinterpret_cast
// CHECK-SAME: (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>, %[[OFF:.*]]: index)
func.func @reinterpret_cast(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>, %arg1: index) -> memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>> {
// CHECK-DAG: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
// CHECK-DAG: %[[OFF1:.*]] = builtin.unrealized_conversion_cast %[[OFF]] : index to i32
// CHECK: %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF1]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
// CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
// CHECK: return %[[RET1]]
  %ret = memref.reinterpret_cast %arg to offset: [%arg1], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
  return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
}

// A static zero offset is expected to produce no access chain at all; the
// converted pointer is cast straight back to the result memref type.
// CHECK-LABEL: func.func @reinterpret_cast_0
// CHECK-SAME: (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>)
func.func @reinterpret_cast_0(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>> {
// CHECK-DAG: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
// CHECK-DAG: %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
// CHECK: return %[[RET]]
  %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
  return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
}

// A static non-zero offset is expected to be materialized as an i32 constant.
// CHECK-LABEL: func.func @reinterpret_cast_5
// CHECK-SAME: (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>)
func.func @reinterpret_cast_5(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>> {
// CHECK: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
// CHECK: %[[OFF:.*]] = spirv.Constant 5 : i32
// CHECK: %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
// CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
// CHECK: return %[[RET1]]
  %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
  return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
}

} // end module


// -----

// Check casts

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0,
      [Kernel, Addresses, GenericPointer], []>, #spirv.resource_limits<>>
} {

// A cast between two dynamically-shaped memrefs is expected to fold away,
// leaving only the conversion casts to and from the element pointer.
// CHECK-LABEL: func.func @cast
// CHECK-SAME: (%[[MEM:.*]]: memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>>)
func.func @cast(%arg: memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>> {
// CHECK-DAG: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
// CHECK-DAG: %[[MEM2:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>>
// CHECK: return %[[MEM2]]
  %ret = memref.cast %arg : memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>>
  return %ret : memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>>
}

// TODO: Not supported yet; the memref.cast is expected to be left unconverted.
// CHECK-LABEL: func.func @cast_from_static
// CHECK-SAME: (%[[MEM:.*]]: memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>>)
func.func @cast_from_static(%arg: memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>> {
// CHECK: %[[MEM1:.*]] = memref.cast %[[MEM]] : memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>>
// CHECK: return %[[MEM1]]
  %ret = memref.cast %arg : memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>>
  return %ret : memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>>
}

// TODO: Not supported yet; the memref.cast is expected to be left unconverted.
// CHECK-LABEL: func.func @cast_to_static
// CHECK-SAME: (%[[MEM:.*]]: memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>>)
func.func @cast_to_static(%arg: memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>> {
// CHECK: %[[MEM1:.*]] = memref.cast %[[MEM]] : memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>>
// CHECK: return %[[MEM1]]
  %ret = memref.cast %arg : memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>>
  return %ret : memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>>
}

// TODO: Not supported yet; the memref.cast is expected to be left unconverted.
// CHECK-LABEL: func.func @cast_to_static_zero_elems
// CHECK-SAME: (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>)
func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<0xf32, #spirv.storage_class<CrossWorkgroup>> {
// CHECK: %[[MEM1:.*]] = memref.cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<0xf32, #spirv.storage_class<CrossWorkgroup>>
// CHECK: return %[[MEM1]]
  %ret = memref.cast %arg : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<0xf32, #spirv.storage_class<CrossWorkgroup>>
  return %ret : memref<0xf32, #spirv.storage_class<CrossWorkgroup>>
}

} // end module

// -----

// Check nontemporal attribute

module attributes {
  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [
    Shader,
    PhysicalStorageBufferAddresses
  ], [
    SPV_KHR_storage_buffer_storage_class,
    SPV_KHR_physical_storage_buffer
  ]>, #spirv.resource_limits<>>
} {
  // The nontemporal attribute is expected to become a "Nontemporal" memory
  // operand on both the load and the store.
  // CHECK-LABEL: func.func @load_nontemporal(
  func.func @load_nontemporal(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
    %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Load "StorageBuffer" %{{.+}} ["Nontemporal"] : f32
    memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Store "StorageBuffer" %{{.+}}, %{{.+}} ["Nontemporal"] : f32
    return
  }

  // For PhysicalStorageBuffer the operand is expected to be combined with the
  // alignment operand as "Aligned|Nontemporal".
  // CHECK-LABEL: func.func @load_nontemporal_aligned(
  func.func @load_nontemporal_aligned(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>) {
    %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
// CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned|Nontemporal", 4] : f32
    memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
// CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned|Nontemporal", 4] : f32
    return
  }
} // end module