// Source: /llvm-project/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (revision 27158edaa4e18a7d7275c77e8c483dd29145c3c4)
// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s

// Check that, with the proper compute and storage capabilities and extensions,
// we don't need to perform any special emulation tricks.

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.5,
      [
        Shader, Int8, Int16, Int64, Float16, Float64,
        StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16,
        StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8,
        PhysicalStorageBufferAddresses
      ],
      [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer]>,
      #spirv.resource_limits<>>
} {

// Zero-rank StorageBuffer memrefs are wrapped as a struct holding a 1-element
// array; both the load and the store address it with two zero indices.
// CHECK-LABEL: @load_store_zero_rank_float(
//  CHECK-SAME:     %[[OARG0:.*]]: memref{{.*}}, %[[OARG1:.*]]: memref
func.func @load_store_zero_rank_float(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>, %arg1: memref<f32, #spirv.storage_class<StorageBuffer>>) {
  //  CHECK-DAG: [[ARG0:%.*]] = builtin.unrealized_conversion_cast %[[OARG0]] : memref<f32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
  //  CHECK-DAG: [[ARG1:%.*]] = builtin.unrealized_conversion_cast %[[OARG1]] : memref<f32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
  //      CHECK: [[ZERO:%.*]] = spirv.Constant 0 : i32
  //      CHECK: spirv.AccessChain [[ARG0]][
  // CHECK-SAME: [[ZERO]], [[ZERO]]
  // CHECK-SAME: ] :
  //      CHECK: spirv.Load "StorageBuffer" %{{.*}} : f32
  %0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<StorageBuffer>>
  //      CHECK: spirv.AccessChain [[ARG1]][
  // CHECK-SAME: [[ZERO]], [[ZERO]]
  // CHECK-SAME: ] :
  //      CHECK: spirv.Store "StorageBuffer" %{{.*}} : f32
  memref.store %0, %arg1[] : memref<f32, #spirv.storage_class<StorageBuffer>>
  return
}

// Same zero-rank lowering for i32 elements.
// CHECK-LABEL: @load_store_zero_rank_int
//  CHECK-SAME:     %[[OARG0:.*]]: memref{{.*}}, %[[OARG1:.*]]: memref
func.func @load_store_zero_rank_int(%arg0: memref<i32, #spirv.storage_class<StorageBuffer>>, %arg1: memref<i32, #spirv.storage_class<StorageBuffer>>) {
  //  CHECK-DAG: [[ARG0:%.*]] = builtin.unrealized_conversion_cast %[[OARG0]] : memref<i32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>
  //  CHECK-DAG: [[ARG1:%.*]] = builtin.unrealized_conversion_cast %[[OARG1]] : memref<i32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>
  //      CHECK: [[ZERO:%.*]] = spirv.Constant 0 : i32
  //      CHECK: spirv.AccessChain [[ARG0]][
  // CHECK-SAME: [[ZERO]], [[ZERO]]
  // CHECK-SAME: ] :
  //      CHECK: spirv.Load "StorageBuffer" %{{.*}} : i32
  %0 = memref.load %arg0[] : memref<i32, #spirv.storage_class<StorageBuffer>>
  //      CHECK: spirv.AccessChain [[ARG1]][
  // CHECK-SAME: [[ZERO]], [[ZERO]]
  // CHECK-SAME: ] :
  //      CHECK: spirv.Store "StorageBuffer" %{{.*}} : i32
  memref.store %0, %arg1[] : memref<i32, #spirv.storage_class<StorageBuffer>>
  return
}

// Dynamically sized memrefs lower to a struct wrapping a runtime array.
// CHECK-LABEL: func @load_store_unknown_dim
//  CHECK-SAME:     %[[OARG0:.*]]: index, %[[OARG1:.*]]: memref{{.*}}, %[[OARG2:.*]]: memref
func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spirv.storage_class<StorageBuffer>>, %dest: memref<?xi32, #spirv.storage_class<StorageBuffer>>) {
  // CHECK-DAG: %[[SRC:.+]] = builtin.unrealized_conversion_cast %[[OARG1]] : memref<?xi32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer>
  // CHECK-DAG: %[[DST:.+]] = builtin.unrealized_conversion_cast %[[OARG2]] : memref<?xi32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer>
  // CHECK: %[[AC0:.+]] = spirv.AccessChain %[[SRC]]
  // CHECK: spirv.Load "StorageBuffer" %[[AC0]]
  %0 = memref.load %source[%i] : memref<?xi32, #spirv.storage_class<StorageBuffer>>
  // CHECK: %[[AC1:.+]] = spirv.AccessChain %[[DST]]
  // CHECK: spirv.Store "StorageBuffer" %[[AC1]]
  memref.store %0, %dest[%i]: memref<?xi32, #spirv.storage_class<StorageBuffer>>
  return
}

// With bool-num-bits=8, i1 elements are stored as i8: the loaded byte is
// compared against 0 to produce the i1 result.
// CHECK-LABEL: func @load_i1
//  CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %[[IDX:.+]]: index)
func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i : index) -> i1 {
  // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, StorageBuffer>
  // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
  // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]]
  // CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] : i8
  // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
  // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
  %0 = memref.load %src[%i] : memref<4xi1, #spirv.storage_class<StorageBuffer>>
  // CHECK: return %[[BOOL]]
  return %0: i1
}

// Storing `true` into an i1 memref writes the i8 constant 1.
// CHECK-LABEL: func @store_i1
//  CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spirv.storage_class<StorageBuffer>>,
//  CHECK-SAME: %[[IDX:.+]]: index
func.func @store_i1(%dst: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i: index) {
  %true = arith.constant true
  // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, StorageBuffer>
  // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
  // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[ZERO]], %[[IDX_CAST]]]
  // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
  // CHECK: spirv.Store "StorageBuffer" %[[ADDR]], %[[ONE_I8]] : i8
  memref.store %true, %dst[%i]: memref<4xi1, #spirv.storage_class<StorageBuffer>>
  return
}

// 16-bit storage is available in this environment, so an i16 load must not be
// emulated with division or arithmetic shifts.
// CHECK-LABEL: @load_i16
func.func @load_i16(%arg0: memref<i16, #spirv.storage_class<StorageBuffer>>) {
  // CHECK-NOT: spirv.SDiv
  //     CHECK: spirv.Load
  // CHECK-NOT: spirv.ShiftRightArithmetic
  %0 = memref.load %arg0[] : memref<i16, #spirv.storage_class<StorageBuffer>>
  return
}

// Likewise, an i16 store is a plain store with no atomic and/or emulation.
// CHECK-LABEL: @store_i16
func.func @store_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %index: index, %value: i16) {
  //     CHECK: spirv.Store
  // CHECK-NOT: spirv.AtomicAnd
  // CHECK-NOT: spirv.AtomicOr
  memref.store %value, %arg0[%index] : memref<10xi16, #spirv.storage_class<StorageBuffer>>
  return
}

// PhysicalStorageBuffer accesses carry an "Aligned" memory operand equal to the
// byte size of the converted element type (i1 is accessed as i8).
// CHECK-LABEL: @load_store_i32_physical
func.func @load_store_i32_physical(%arg0: memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>) {
  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : i32
  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : i32
  %0 = memref.load %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
  memref.store %0, %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
  return
}

// CHECK-LABEL: @load_store_i8_physical
func.func @load_store_i8_physical(%arg0: memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>) {
  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
  %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
  memref.store %0, %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
  return
}

// CHECK-LABEL: @load_store_i1_physical
func.func @load_store_i1_physical(%arg0: memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>) {
  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
  %0 = memref.load %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
  memref.store %0, %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
  return
}

// CHECK-LABEL: @load_store_f32_physical
func.func @load_store_f32_physical(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>) {
  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : f32
  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : f32
  %0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
  memref.store %0, %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
  return
}

// CHECK-LABEL: @load_store_f16_physical
func.func @load_store_f16_physical(%arg0: memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>) {
  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 2] : f16
  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 2] : f16
  %0 = memref.load %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
  memref.store %0, %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
  return
}

} // end module

// -----

// Check that, with the Kernel capability and the proper compute and storage
// extensions, we don't need to perform any special emulation tricks.

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0,
      [
        Kernel, Addresses, Int8, Int16, Int64, Float16, Float64], []>, #spirv.resource_limits<>>
} {

// Under the Kernel capability, zero-rank memrefs lower to a pointer to a plain
// 1-element array (no struct wrapper), indexed with a single zero.
// CHECK-LABEL: @load_store_zero_rank_float
//  CHECK-SAME:     %[[OARG0:.*]]: memref{{.*}}, %[[OARG1:.*]]: memref
func.func @load_store_zero_rank_float(%arg0: memref<f32, #spirv.storage_class<CrossWorkgroup>>, %arg1: memref<f32, #spirv.storage_class<CrossWorkgroup>>) {
  //  CHECK-DAG: [[ARG0:%.*]] = builtin.unrealized_conversion_cast %[[OARG0]] : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x f32>, CrossWorkgroup>
  //  CHECK-DAG: [[ARG1:%.*]] = builtin.unrealized_conversion_cast %[[OARG1]] : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x f32>, CrossWorkgroup>
  //      CHECK: [[ZERO:%.*]] = spirv.Constant 0 : i32
  //      CHECK: spirv.AccessChain [[ARG0]][
  // CHECK-SAME: [[ZERO]]
  // CHECK-SAME: ] :
  //      CHECK: spirv.Load "CrossWorkgroup" %{{.*}} : f32
  %0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<CrossWorkgroup>>
  //      CHECK: spirv.AccessChain [[ARG1]][
  // CHECK-SAME: [[ZERO]]
  // CHECK-SAME: ] :
  //      CHECK: spirv.Store "CrossWorkgroup" %{{.*}} : f32
  memref.store %0, %arg1[] : memref<f32, #spirv.storage_class<CrossWorkgroup>>
  return
}

// Same zero-rank lowering for i32 elements.
// CHECK-LABEL: @load_store_zero_rank_int
//  CHECK-SAME:     %[[OARG0:.*]]: memref{{.*}}, %[[OARG1:.*]]: memref
func.func @load_store_zero_rank_int(%arg0: memref<i32, #spirv.storage_class<CrossWorkgroup>>, %arg1: memref<i32, #spirv.storage_class<CrossWorkgroup>>) {
  //  CHECK-DAG: [[ARG0:%.*]] = builtin.unrealized_conversion_cast %[[OARG0]] : memref<i32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x i32>, CrossWorkgroup>
  //  CHECK-DAG: [[ARG1:%.*]] = builtin.unrealized_conversion_cast %[[OARG1]] : memref<i32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x i32>, CrossWorkgroup>
  //      CHECK: [[ZERO:%.*]] = spirv.Constant 0 : i32
  //      CHECK: spirv.AccessChain [[ARG0]][
  // CHECK-SAME: [[ZERO]]
  // CHECK-SAME: ] :
  //      CHECK: spirv.Load "CrossWorkgroup" %{{.*}} : i32
  %0 = memref.load %arg0[] : memref<i32, #spirv.storage_class<CrossWorkgroup>>
  //      CHECK: spirv.AccessChain [[ARG1]][
  // CHECK-SAME: [[ZERO]]
  // CHECK-SAME: ] :
  //      CHECK: spirv.Store "CrossWorkgroup" %{{.*}} : i32
  memref.store %0, %arg1[] : memref<i32, #spirv.storage_class<CrossWorkgroup>>
  return
}

// Dynamically sized memrefs lower to a bare element pointer that is addressed
// with spirv.PtrAccessChain.
// CHECK-LABEL: func @load_store_unknown_dim
//  CHECK-SAME:     %[[OARG0:.*]]: index, %[[OARG1:.*]]: memref{{.*}}, %[[OARG2:.*]]: memref
func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spirv.storage_class<CrossWorkgroup>>, %dest: memref<?xi32, #spirv.storage_class<CrossWorkgroup>>) {
  // CHECK-DAG: %[[SRC:.+]] = builtin.unrealized_conversion_cast %[[OARG1]] : memref<?xi32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<i32, CrossWorkgroup>
  // CHECK-DAG: %[[DST:.+]] = builtin.unrealized_conversion_cast %[[OARG2]] : memref<?xi32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<i32, CrossWorkgroup>
  // CHECK: %[[AC0:.+]] = spirv.PtrAccessChain %[[SRC]]
  // CHECK: spirv.Load "CrossWorkgroup" %[[AC0]]
  %0 = memref.load %source[%i] : memref<?xi32, #spirv.storage_class<CrossWorkgroup>>
  // CHECK: %[[AC1:.+]] = spirv.PtrAccessChain %[[DST]]
  // CHECK: spirv.Store "CrossWorkgroup" %[[AC1]]
  memref.store %0, %dest[%i]: memref<?xi32, #spirv.storage_class<CrossWorkgroup>>
  return
}

// i1 elements are stored as i8 and addressed with a single index; the loaded
// byte is compared against 0 to produce the i1 result.
// CHECK-LABEL: func @load_i1
//  CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %[[IDX:.+]]: index)
func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i : index) -> i1 {
  // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<4 x i8>, CrossWorkgroup>
  // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[IDX_CAST]]]
  // CHECK: %[[VAL:.+]] = spirv.Load "CrossWorkgroup" %[[ADDR]] : i8
  // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
  // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
  %0 = memref.load %src[%i] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>>
  // CHECK: return %[[BOOL]]
  return %0: i1
}

// Storing `true` into an i1 memref writes the i8 constant 1.
// CHECK-LABEL: func @store_i1
//  CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>,
//  CHECK-SAME: %[[IDX:.+]]: index
func.func @store_i1(%dst: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i: index) {
  %true = arith.constant true
  // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<4 x i8>, CrossWorkgroup>
  // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[IDX_CAST]]]
  // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
  // CHECK: spirv.Store "CrossWorkgroup" %[[ADDR]], %[[ONE_I8]] : i8
  memref.store %true, %dst[%i]: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>
  return
}

} // end module

// -----

// Check memref.memory_space_cast lowering (address space casts).

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0,
      [
        Kernel, Addresses, GenericPointer], []>, #spirv.resource_limits<>>
} {

// A CrossWorkgroup -> Function memory space cast is lowered as a cast into the
// Generic storage class followed by a cast from Generic to Function.
// CHECK-LABEL: func.func @memory_space_cast
func.func @memory_space_cast(%arg: memref<4xf32, #spirv.storage_class<CrossWorkgroup>>)
    -> memref<4xf32, #spirv.storage_class<Function>> {
  // CHECK: %[[ARG_CAST:.+]] = builtin.unrealized_conversion_cast {{.*}} to !spirv.ptr<!spirv.array<4 x f32>, CrossWorkgroup>
  // CHECK: %[[TO_GENERIC:.+]] = spirv.PtrCastToGeneric %[[ARG_CAST]] : !spirv.ptr<!spirv.array<4 x f32>, CrossWorkgroup> to !spirv.ptr<!spirv.array<4 x f32>, Generic>
  // CHECK: %[[TO_PRIVATE:.+]] = spirv.GenericCastToPtr %[[TO_GENERIC]] : !spirv.ptr<!spirv.array<4 x f32>, Generic> to !spirv.ptr<!spirv.array<4 x f32>, Function>
  // CHECK: %[[RET:.+]] = builtin.unrealized_conversion_cast %[[TO_PRIVATE]]
  // CHECK: return %[[RET]]
  %ret = memref.memory_space_cast %arg : memref<4xf32, #spirv.storage_class<CrossWorkgroup>>
    to memref<4xf32, #spirv.storage_class<Function>>
  return %ret : memref<4xf32, #spirv.storage_class<Function>>
}

} // end module

// -----

// Check that casts are properly inserted if only the corresponding **compute**
// capabilities (Int8/Int16) are allowed, without the matching storage ones.
module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader, Int8, Int16], [
      SPV_KHR_8bit_storage, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class
      ]>, #spirv.resource_limits<>>
} {

// The i1 result is produced by an i32 spirv.IEqual against the constant 1.
// CHECK-LABEL: @load_i1
func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1 {
  //     CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
  //     CHECK: %[[RES:.+]] = spirv.IEqual %{{.+}}, %[[ONE]] : i32
  //     CHECK: return %[[RES]]
  %0 = memref.load %arg0[] : memref<i1, #spirv.storage_class<StorageBuffer>>
  return %0 : i1
}

// The i8 result is narrowed from an i32 value with spirv.UConvert.
// CHECK-LABEL: @load_i8
func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
  //     CHECK: %[[RES:.+]] = spirv.UConvert %{{.+}} : i32 to i8
  //     CHECK: return %[[RES]]
  %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
  return %0 : i8
}

// The i16 result is narrowed from an i32 value with spirv.UConvert.
// CHECK-LABEL: @load_i16
func.func @load_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %index : index) -> i16 {
  //     CHECK: %[[RES:.+]] = spirv.UConvert %{{.+}} : i32 to i16
  //     CHECK: return %[[RES]]
  %0 = memref.load %arg0[%index] : memref<10xi16, #spirv.storage_class<StorageBuffer>>
  return %0: i16
}

} // end module

// -----

// Check memref.reinterpret_cast lowering.

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0,
      [Kernel, Addresses, GenericPointer], []>, #spirv.resource_limits<>>
} {

// A dynamic offset becomes a spirv.InBoundsPtrAccessChain on the element
// pointer, indexed by the offset converted to i32.
// CHECK-LABEL: func.func @reinterpret_cast
//  CHECK-SAME:  (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>, %[[OFF:.*]]: index)
func.func @reinterpret_cast(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>, %arg1: index) -> memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>> {
//   CHECK-DAG:  %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
//   CHECK-DAG:  %[[OFF1:.*]] = builtin.unrealized_conversion_cast %[[OFF]] : index to i32
//       CHECK:  %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF1]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
//       CHECK:  %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
//       CHECK:  return %[[RET1]]
  %ret = memref.reinterpret_cast %arg to offset: [%arg1], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
  return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
}

// A zero offset needs no address arithmetic: the source pointer is reused as is.
// CHECK-LABEL: func.func @reinterpret_cast_0
//  CHECK-SAME:  (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>)
func.func @reinterpret_cast_0(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>> {
//   CHECK-DAG:  %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
//   CHECK-DAG:  %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
//       CHECK:  return %[[RET]]
  %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
  return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
}

// A constant offset is materialized as an i32 spirv.Constant index.
// CHECK-LABEL: func.func @reinterpret_cast_5
//  CHECK-SAME:  (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>)
func.func @reinterpret_cast_5(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>> {
//       CHECK:  %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
//       CHECK:  %[[OFF:.*]] = spirv.Constant 5 : i32
//       CHECK:  %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
//       CHECK:  %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
//       CHECK:  return %[[RET1]]
  %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
  return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
}

} // end module

// -----

// Check memref.cast lowering.

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0,
      [Kernel, Addresses, GenericPointer], []>, #spirv.resource_limits<>>
} {

// Casting between shapes that both have a dynamic dimension reuses the same
// converted element pointer.
// CHECK-LABEL: func.func @cast
//  CHECK-SAME:  (%[[MEM:.*]]: memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>>)
func.func @cast(%arg: memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>> {
//   CHECK-DAG:  %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
//   CHECK-DAG:  %[[MEM2:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>>
//       CHECK:  return %[[MEM2]]
  %ret = memref.cast %arg : memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>>
  return %ret : memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>>
}

// TODO: Not supported yet
// The memref.cast is currently left unconverted.
// CHECK-LABEL: func.func @cast_from_static
//  CHECK-SAME:  (%[[MEM:.*]]: memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>>)
func.func @cast_from_static(%arg: memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>> {
//       CHECK:  %[[MEM1:.*]] = memref.cast %[[MEM]] : memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>>
//       CHECK:  return %[[MEM1]]
  %ret = memref.cast %arg : memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>>
  return %ret : memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>>
}

// TODO: Not supported yet
// The memref.cast is currently left unconverted.
// CHECK-LABEL: func.func @cast_to_static
//  CHECK-SAME:  (%[[MEM:.*]]: memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>>)
func.func @cast_to_static(%arg: memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>> {
//       CHECK:  %[[MEM1:.*]] = memref.cast %[[MEM]] : memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>>
//       CHECK:  return %[[MEM1]]
  %ret = memref.cast %arg : memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>>
  return %ret : memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>>
}

// TODO: Not supported yet
// The memref.cast is currently left unconverted.
// CHECK-LABEL: func.func @cast_to_static_zero_elems
//  CHECK-SAME:  (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>)
func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<0xf32, #spirv.storage_class<CrossWorkgroup>> {
//       CHECK:  %[[MEM1:.*]] = memref.cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<0xf32, #spirv.storage_class<CrossWorkgroup>>
//       CHECK:  return %[[MEM1]]
  %ret = memref.cast %arg : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<0xf32, #spirv.storage_class<CrossWorkgroup>>
  return %ret : memref<0xf32, #spirv.storage_class<CrossWorkgroup>>
}

} // end module

// -----

// Check that the nontemporal attribute on memref.load/memref.store is lowered
// to the Nontemporal memory access on spirv.Load/spirv.Store.

module attributes {
  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [
    Shader,
    PhysicalStorageBufferAddresses
  ], [
    SPV_KHR_storage_buffer_storage_class,
    SPV_KHR_physical_storage_buffer
  ]>, #spirv.resource_limits<>>
} {
  // In StorageBuffer, the attribute maps to a bare Nontemporal memory access
  // on both the load and the store.
  // CHECK-LABEL: @load_nontemporal(
  func.func @load_nontemporal(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
    // CHECK: spirv.Load "StorageBuffer" %{{.+}} ["Nontemporal"] : f32
    %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
    // CHECK: spirv.Store "StorageBuffer" %{{.+}}, %{{.+}} ["Nontemporal"] : f32
    memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
    return
  }

  // In PhysicalStorageBuffer, Nontemporal is combined with the Aligned access
  // (4 bytes for f32).
  // CHECK-LABEL: @load_nontemporal_aligned(
  func.func @load_nontemporal_aligned(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>) {
    // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned|Nontemporal", 4] : f32
    %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
    // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned|Nontemporal", 4] : f32
    memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
    return
  }
} // end module
