// RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8" -cse %s -o - | FileCheck %s
// RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8 use-64bit-index" -cse %s -o - | FileCheck %s --check-prefix=INDEX64

// Check that access chain indices are properly adjusted if non-32-bit types are
// emulated via 32-bit types.
// The target env below enables only the Shader and Int64 capabilities (no Int8
// or Int16), so i1/i8/i16 elements are accessed through 32-bit words: loads
// extract and sign-extend the element, stores use an atomic AND (clearing the
// element's bits) followed by an atomic OR (writing the new bits).
// TODO: Test i64 types.
module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader, Int64], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
} {

// i1 is stored as an 8-bit value (bool-num-bits=8): the low byte is masked,
// sign-extended via shl/ashr by 24, and compared against 1 to yield the i1.
// CHECK-LABEL: @load_i1
func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1 {
  //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
  //     CHECK: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]]
  //     CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
  //     CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
  //     CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
  //     CHECK: %[[T4:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
  // Convert to i1 type.
  //     CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
  //     CHECK: %[[RES:.+]]  = spirv.IEqual %[[T4]], %[[ONE]] : i32
  //     CHECK: return %[[RES]]
  %0 = memref.load %arg0[] : memref<i1, #spirv.storage_class<StorageBuffer>>
  return %0 : i1
}

// i8 load: mask the low byte, then sign-extend it with shl/ashr by 24. With
// use-64bit-index the access chain indices become i64, while the loaded word
// is still i32.
// CHECK-LABEL: @load_i8
// INDEX64-LABEL: @load_i8
func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
  //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
  //     CHECK: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]]
  //     CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
  //     CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
  //     CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
  //     CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
  //     CHECK: builtin.unrealized_conversion_cast %[[SR]]

  //   INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
  //   INDEX64: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]] : i32
  //   INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32
  //   INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
  //   INDEX64: %[[T2:.+]] = spirv.Constant 24 : i32
  //   INDEX64: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
  //   INDEX64: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
  //   INDEX64: builtin.unrealized_conversion_cast %[[SR]]
  %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
  return %0 : i8
}

// i16 load: the element lives in 32-bit word (index / 2) at bit offset
// (index % 2) * 16. The word is shifted down by that offset, masked to 16 bits,
// and sign-extended with shl/ashr by 16.
// CHECK-LABEL: @load_i16
//       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: index)
func.func @load_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %index : index) -> i16 {
  //     CHECK: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
  //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  //     CHECK: %[[TWO:.+]] = spirv.Constant 2 : i32
  //     CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ARG1_CAST]], %[[TWO]] : i32
  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
  //     CHECK: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]]
  //     CHECK: %[[SIXTEEN:.+]] = spirv.Constant 16 : i32
  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ARG1_CAST]], %[[TWO]] : i32
  //     CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[SIXTEEN]] : i32
  //     CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
  //     CHECK: %[[MASK:.+]] = spirv.Constant 65535 : i32
  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
  //     CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[SIXTEEN]] : i32, i32
  //     CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[SIXTEEN]] : i32, i32
  //     CHECK: builtin.unrealized_conversion_cast %[[SR]]
  %0 = memref.load %arg0[%index] : memref<10xi16, #spirv.storage_class<StorageBuffer>>
  return %0: i16
}

// 32-bit element types need no emulation: no index division and no shifting.
// CHECK-LABEL: @load_f32
func.func @load_f32(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
  // CHECK-NOT: spirv.SDiv
  //     CHECK: spirv.Load
  // CHECK-NOT: spirv.ShiftRightArithmetic
  %0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<StorageBuffer>>
  return
}

// i1 store: the bool is turned into 1/0 with a select, the low byte of the
// word is cleared with an atomic AND against -256 (0xFFFFFF00), and the new
// value is written with an atomic OR.
// CHECK-LABEL: @store_i1
//       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i1)
func.func @store_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>, %value: i1) {
  //     CHECK: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
  //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  //     CHECK: %[[MASK:.+]] = spirv.Constant -256 : i32
  //     CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
  //     CHECK: %[[CASTED_ARG1:.+]] = spirv.Select %[[ARG1]], %[[ONE]], %[[ZERO]] : i1, i32
  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
  //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CASTED_ARG1]]
  memref.store %value, %arg0[] : memref<i1, #spirv.storage_class<StorageBuffer>>
  return
}

// i8 store: the value is clamped to 8 bits (AND 255) before the atomic
// AND (-256) / atomic OR sequence on the containing 32-bit word.
// CHECK-LABEL: @store_i8
//       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
// INDEX64-LABEL: @store_i8
//       INDEX64: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %value: i8) {
  //     CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
  //     CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
  //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  //     CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
  //     CHECK: %[[MASK2:.+]] = spirv.Constant -256 : i32
  //     CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
  //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CLAMPED_VAL]]

  //   INDEX64-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
  //   INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
  //   INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
  //   INDEX64: %[[MASK1:.+]] = spirv.Constant 255 : i32
  //   INDEX64: %[[MASK2:.+]] = spirv.Constant -256 : i32
  //   INDEX64: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
  //   INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
  //   INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CLAMPED_VAL]]
  memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
  return
}

// i16 store: the bit offset is (index % 2) * 16; the clear mask is
// ~(65535 << offset) and the stored bits are (value & 65535) << offset, both
// applied atomically to word (index / 2).
// CHECK-LABEL: @store_i16
//       CHECK: (%[[ARG0:.+]]: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i16)
func.func @store_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %index: index, %value: i16) {
  //     CHECK-DAG: %[[ARG2_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : i16 to i32
  //     CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
  //     CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
  //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  //     CHECK: %[[TWO:.+]] = spirv.Constant 2 : i32
  //     CHECK: %[[SIXTEEN:.+]] = spirv.Constant 16 : i32
  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ARG1_CAST]], %[[TWO]] : i32
  //     CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[SIXTEEN]] : i32
  //     CHECK: %[[MASK1:.+]] = spirv.Constant 65535 : i32
  //     CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
  //     CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
  //     CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG2_CAST]], %[[MASK1]] : i32
  //     CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
  //     CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ARG1_CAST]], %[[TWO]] : i32
  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
  //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
  memref.store %value, %arg0[%index] : memref<10xi16, #spirv.storage_class<StorageBuffer>>
  return
}

// 32-bit stores lower to a plain spirv.Store without atomic read-modify-write.
// CHECK-LABEL: @store_f32
func.func @store_f32(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>, %value: f32) {
  //     CHECK: spirv.Store
  // CHECK-NOT: spirv.AtomicAnd
  // CHECK-NOT: spirv.AtomicOr
  memref.store %value, %arg0[] : memref<f32, #spirv.storage_class<StorageBuffer>>
  return
}

} // end module


// -----

// Check that access chain indices are properly adjusted if sub-byte types are
// emulated via 32-bit types. Eight i4 elements share one 32-bit word: element
// `i` lives in word (i / 8) at bit offset (i % 8) * 4.
module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader, Int64], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
} {

// i4 load: shift the word down by the bit offset, mask 4 bits (15), then
// sign-extend with shl/ashr by 28.
// CHECK-LABEL: @load_i4
func.func @load_i4(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>, %i: index) -> i4 {
  // CHECK: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %{{.+}} : index to i32
  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
  // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[INDEX]], %[[EIGHT]] : i32
  // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
  // CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32
  // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
  // CHECK: %[[IDX:.+]] = spirv.UMod %[[INDEX]], %[[EIGHT]] : i32
  // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[FOUR]] : i32
  // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
  // CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
  // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
  // CHECK: %[[C28:.+]] = spirv.Constant 28 : i32
  // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[AND]], %[[C28]] : i32, i32
  // CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[SL]], %[[C28]] : i32, i32
  // CHECK: builtin.unrealized_conversion_cast %[[SR]]
  %0 = memref.load %arg0[%i] : memref<?xi4, #spirv.storage_class<StorageBuffer>>
  return %0 : i4
}

// i4 store: the clear mask is ~(15 << offset) and the stored bits are
// (value & 15) << offset, applied with an atomic AND then an atomic OR.
// CHECK-LABEL: @store_i4
func.func @store_i4(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>, %value: i4, %i: index) {
  // CHECK-DAG: %[[VAL:.+]] = builtin.unrealized_conversion_cast %{{.+}} : i4 to i32
  // CHECK-DAG: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %{{.+}} : index to i32
  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
  // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
  // CHECK: %[[IDX:.+]] = spirv.UMod %[[INDEX]], %[[EIGHT]] : i32
  // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[FOUR]] : i32
  // CHECK: %[[MASK1:.+]] = spirv.Constant 15 : i32
  // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[BITS]] : i32, i32
  // CHECK: %[[MASK2:.+]] = spirv.Not %[[SL]] : i32
  // CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[VAL]], %[[MASK1]] : i32
  // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[BITS]] : i32, i32
  // CHECK: %[[ACCESS_INDEX:.+]] = spirv.SDiv %[[INDEX]], %[[EIGHT]] : i32
  // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ACCESS_INDEX]]]
  // CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
  // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
  memref.store %value, %arg0[%i] : memref<?xi4, #spirv.storage_class<StorageBuffer>>
  return
}

} // end module

// -----

// Check that we can access i8 storage with i8 types available but without
// 8-bit storage capabilities.
// Int8 is enabled in the target env below, so the i8 <-> i32 value conversions
// around the 32-bit word access are spirv.UConvert ops.
module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader, Int64, Int8], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
} {

// i8 load: mask the low byte, sign-extend with shl/ashr by 24, then narrow the
// i32 result to i8 with spirv.UConvert.
// CHECK-LABEL: @load_i8
// INDEX64-LABEL: @load_i8
func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
  //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
  //     CHECK: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]]
  //     CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
  //     CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
  //     CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
  //     CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
  //     CHECK: %[[CAST:.+]] = spirv.UConvert %[[SR]] : i32 to i8
  //     CHECK: return %[[CAST]] : i8

  //   INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
  //   INDEX64: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]] : i32
  //   INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32
  //   INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
  //   INDEX64: %[[T2:.+]] = spirv.Constant 24 : i32
  //   INDEX64: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
  //   INDEX64: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
  //   INDEX64: %[[CAST:.+]] = spirv.UConvert %[[SR]] : i32 to i8
  //   INDEX64: return %[[CAST]] : i8
  %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
  return %0 : i8
}

// i8 store: the value is widened to i32 with spirv.UConvert, the low byte of
// the word is cleared with an atomic AND against -256, and the widened value
// is written with an atomic OR.
// CHECK-LABEL: @store_i8
//       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
// INDEX64-LABEL: @store_i8
//       INDEX64: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %value: i8) {
  //     CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
  //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  //     CHECK: %[[MASK1:.+]] = spirv.Constant -256 : i32
  //     CHECK: %[[ARG1_CAST:.+]] = spirv.UConvert %[[ARG1]] : i8 to i32
  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
  //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK1]]
  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[ARG1_CAST]]

  //   INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
  //   INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
  //   INDEX64: %[[MASK1:.+]] = spirv.Constant -256 : i32
  //   INDEX64: %[[ARG1_CAST:.+]] = spirv.UConvert %[[ARG1]] : i8 to i32
  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
  //   INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK1]]
  //   INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[ARG1_CAST]]
  memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
  return
}

} // end module