// RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8" -cse %s -o - | FileCheck %s
// RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8 use-64bit-index" -cse %s -o - | FileCheck %s --check-prefix=INDEX64

// The first invocation is verified with the default prefix. The second one
// adds the `use-64bit-index` option and is verified with the INDEX64 prefix;
// only the i8 tests carry expectations for that prefix.

// Check that access chain indices are properly adjusted if non-32-bit types are
// emulated via 32-bit types.
// TODO: Test i64 types.
module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader, Int64], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
} {

// Scalar i1 load: a 32-bit word is loaded at [0, 0], masked with 255,
// sign-extended with a shift-left / arithmetic-shift-right pair by 24, and
// finally compared against 1 to produce the i1 result.
// CHECK-LABEL: @load_i1
func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1 {
  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
  // CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]]
  // CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
  // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
  // CHECK: %[[C24:.+]] = spirv.Constant 24 : i32
  // CHECK: %[[SHL:.+]] = spirv.ShiftLeftLogical %[[MASKED]], %[[C24]] : i32, i32
  // CHECK: %[[SEXT:.+]] = spirv.ShiftRightArithmetic %[[SHL]], %[[C24]] : i32, i32
  // Convert to i1 type.
  // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
  // CHECK: %[[RES:.+]] = spirv.IEqual %[[SEXT]], %[[ONE]] : i32
  // CHECK: return %[[RES]]
  %0 = memref.load %arg0[] : memref<i1, #spirv.storage_class<StorageBuffer>>
  return %0 : i1
}

// Scalar i8 load: same mask-and-sign-extend sequence as the i1 case, ending in
// an unrealized conversion cast of the extended value. With 64-bit indices the
// access chain indices are i64 while the loaded word stays i32.
// CHECK-LABEL: @load_i8
// INDEX64-LABEL: @load_i8
func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
  // CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]]
  // CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
  // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
  // CHECK: %[[C24:.+]] = spirv.Constant 24 : i32
  // CHECK: %[[SHL:.+]] = spirv.ShiftLeftLogical %[[MASKED]], %[[C24]] : i32, i32
  // CHECK: %[[SEXT:.+]] = spirv.ShiftRightArithmetic %[[SHL]], %[[C24]] : i32, i32
  // CHECK: builtin.unrealized_conversion_cast %[[SEXT]]

  // INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
  // INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
  // INDEX64: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32
  // INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32
  // INDEX64: %[[MASKED:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
  // INDEX64: %[[C24:.+]] = spirv.Constant 24 : i32
  // INDEX64: %[[SHL:.+]] = spirv.ShiftLeftLogical %[[MASKED]], %[[C24]] : i32, i32
  // INDEX64: %[[SEXT:.+]] = spirv.ShiftRightArithmetic %[[SHL]], %[[C24]] : i32, i32
  // INDEX64: builtin.unrealized_conversion_cast %[[SEXT]]
  %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
  return %0 : i8
}

// Indexed i16 load: the word index is the element index divided by 2, the bit
// offset is (index mod 2) * 16, and the extracted 16 bits are sign-extended.
// CHECK-LABEL: @load_i16
// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: index)
func.func @load_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %index : index) -> i16 {
  // CHECK: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  // CHECK: %[[TWO:.+]] = spirv.Constant 2 : i32
  // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ARG1_CAST]], %[[TWO]] : i32
  // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
  // CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]]
  // CHECK: %[[SIXTEEN:.+]] = spirv.Constant 16 : i32
  // CHECK: %[[IDX:.+]] = spirv.UMod %[[ARG1_CAST]], %[[TWO]] : i32
  // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[SIXTEEN]] : i32
  // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
  // CHECK: %[[MASK:.+]] = spirv.Constant 65535 : i32
  // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
  // CHECK: %[[SHL:.+]] = spirv.ShiftLeftLogical %[[MASKED]], %[[SIXTEEN]] : i32, i32
  // CHECK: %[[SEXT:.+]] = spirv.ShiftRightArithmetic %[[SHL]], %[[SIXTEEN]] : i32, i32
  // CHECK: builtin.unrealized_conversion_cast %[[SEXT]]
  %0 = memref.load %arg0[%index] : memref<10xi16, #spirv.storage_class<StorageBuffer>>
  return %0: i16
}

// f32 load: a plain load is expected, with no index division before it and no
// arithmetic right shift after it.
// CHECK-LABEL: @load_f32
func.func @load_f32(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
  // CHECK-NOT: spirv.SDiv
  //     CHECK: spirv.Load
  // CHECK-NOT: spirv.ShiftRightArithmetic
  %0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<StorageBuffer>>
  return
}

// Scalar i1 store: the i1 is turned into 0/1 with a select, the low 8 bits of
// the destination word are cleared with an atomic AND against -256, and the
// value is merged in with an atomic OR.
// CHECK-LABEL: @store_i1
// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i1)
func.func @store_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>, %value: i1) {
  // CHECK: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  // CHECK: %[[MASK:.+]] = spirv.Constant -256 : i32
  // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
  // CHECK: %[[CASTED_ARG1:.+]] = spirv.Select %[[ARG1]], %[[ONE]], %[[ZERO]] : i1, i32
  // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
  // CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
  // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CASTED_ARG1]]
  memref.store %value, %arg0[] : memref<i1, #spirv.storage_class<StorageBuffer>>
  return
}

// Scalar i8 store: the value is clamped to 8 bits, then written with the
// atomic AND / atomic OR pair. With 64-bit indices only the access chain
// indices become i64.
// CHECK-LABEL: @store_i8
// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
// INDEX64-LABEL: @store_i8
// INDEX64: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %value: i8) {
  // CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
  // CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
  //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  //     CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
  //     CHECK: %[[MASK2:.+]] = spirv.Constant -256 : i32
  //     CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
  //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CLAMPED_VAL]]

  // INDEX64-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
  // INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
  //     INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
  //     INDEX64: %[[MASK1:.+]] = spirv.Constant 255 : i32
  //     INDEX64: %[[MASK2:.+]] = spirv.Constant -256 : i32
  //     INDEX64: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
  //     INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
  //     INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
  //     INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CLAMPED_VAL]]
  memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
  return
}

// Indexed i16 store: the bit offset is (index mod 2) * 16; the clear mask is
// the complement of 65535 shifted by that offset, and the clamped value is
// shifted into place before the atomic AND / atomic OR pair on word index / 2.
// CHECK-LABEL: @store_i16
// CHECK: (%[[ARG0:.+]]: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i16)
func.func @store_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %index: index, %value: i16) {
  // CHECK-DAG: %[[ARG2_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : i16 to i32
  // CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
  // CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
  //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  //     CHECK: %[[TWO:.+]] = spirv.Constant 2 : i32
  //     CHECK: %[[SIXTEEN:.+]] = spirv.Constant 16 : i32
  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ARG1_CAST]], %[[TWO]] : i32
  //     CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[SIXTEEN]] : i32
  //     CHECK: %[[MASK1:.+]] = spirv.Constant 65535 : i32
  //     CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
  //     CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
  //     CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG2_CAST]], %[[MASK1]] : i32
  //     CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
  //     CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ARG1_CAST]], %[[TWO]] : i32
  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
  //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
  memref.store %value, %arg0[%index] : memref<10xi16, #spirv.storage_class<StorageBuffer>>
  return
}

// f32 store: a plain store is expected, with no atomic AND / OR after it.
// CHECK-LABEL: @store_f32
func.func @store_f32(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>, %value: f32) {
  //     CHECK: spirv.Store
  // CHECK-NOT: spirv.AtomicAnd
  // CHECK-NOT: spirv.AtomicOr
  memref.store %value, %arg0[] : memref<f32, #spirv.storage_class<StorageBuffer>>
  return
}

} // end module


// -----

// Check that access chain indices are properly adjusted if sub-byte types are
// emulated via 32-bit types.
module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader, Int64], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
} {

// Indexed i4 load: the word index is the element index divided by 8, the bit
// offset is (index mod 8) * 4, and the extracted 4 bits are sign-extended with
// a shift-left / arithmetic-shift-right pair by 28.
// CHECK-LABEL: @load_i4
func.func @load_i4(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>, %i: index) -> i4 {
  // CHECK: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %{{.+}} : index to i32
  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
  // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[INDEX]], %[[EIGHT]] : i32
  // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
  // CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32
  // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
  // CHECK: %[[IDX:.+]] = spirv.UMod %[[INDEX]], %[[EIGHT]] : i32
  // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[FOUR]] : i32
  // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
  // CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
  // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
  // CHECK: %[[C28:.+]] = spirv.Constant 28 : i32
  // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[AND]], %[[C28]] : i32, i32
  // CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[SL]], %[[C28]] : i32, i32
  // CHECK: builtin.unrealized_conversion_cast %[[SR]]
  %0 = memref.load %arg0[%i] : memref<?xi4, #spirv.storage_class<StorageBuffer>>
  return %0 : i4
}

// Indexed i4 store: the bit offset is (index mod 8) * 4; the clear mask is the
// complement of 15 shifted by that offset, and the clamped value is shifted
// into place before the atomic AND / atomic OR pair on word index / 8.
// CHECK-LABEL: @store_i4
func.func @store_i4(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>, %value: i4, %i: index) {
  // CHECK-DAG: %[[VAL:.+]] = builtin.unrealized_conversion_cast %{{.+}} : i4 to i32
  // CHECK-DAG: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %{{.+}} : index to i32
  //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  //     CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
  //     CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[INDEX]], %[[EIGHT]] : i32
  //     CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[FOUR]] : i32
  //     CHECK: %[[MASK1:.+]] = spirv.Constant 15 : i32
  //     CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[BITS]] : i32, i32
  //     CHECK: %[[MASK2:.+]] = spirv.Not %[[SL]] : i32
  //     CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[VAL]], %[[MASK1]] : i32
  //     CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[BITS]] : i32, i32
  //     CHECK: %[[ACCESS_INDEX:.+]] = spirv.SDiv %[[INDEX]], %[[EIGHT]] : i32
  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ACCESS_INDEX]]]
  //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
  memref.store %value, %arg0[%i] : memref<?xi4, #spirv.storage_class<StorageBuffer>>
  return
}

} // end module

// -----

// Check that we can access i8 storage with i8 types available but without
// 8-bit storage capabilities.
module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader, Int64, Int8], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
} {

// Scalar i8 load with Int8 listed in the target environment: the word is still
// loaded as i32 and sign-extended, but the result is narrowed with an explicit
// convert to i8 and returned directly.
// CHECK-LABEL: @load_i8
// INDEX64-LABEL: @load_i8
func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
  // CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]]
  // CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
  // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
  // CHECK: %[[C24:.+]] = spirv.Constant 24 : i32
  // CHECK: %[[SHL:.+]] = spirv.ShiftLeftLogical %[[MASKED]], %[[C24]] : i32, i32
  // CHECK: %[[SEXT:.+]] = spirv.ShiftRightArithmetic %[[SHL]], %[[C24]] : i32, i32
  // CHECK: %[[CAST:.+]] = spirv.UConvert %[[SEXT]] : i32 to i8
  // CHECK: return %[[CAST]] : i8

  // INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
  // INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
  // INDEX64: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32
  // INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32
  // INDEX64: %[[MASKED:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
  // INDEX64: %[[C24:.+]] = spirv.Constant 24 : i32
  // INDEX64: %[[SHL:.+]] = spirv.ShiftLeftLogical %[[MASKED]], %[[C24]] : i32, i32
  // INDEX64: %[[SEXT:.+]] = spirv.ShiftRightArithmetic %[[SHL]], %[[C24]] : i32, i32
  // INDEX64: %[[CAST:.+]] = spirv.UConvert %[[SEXT]] : i32 to i8
  // INDEX64: return %[[CAST]] : i8
  %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
  return %0 : i8
}

// Scalar i8 store with Int8 listed in the target environment: the i8 value is
// widened with an explicit convert to i32 (no separate 255 mask is expected)
// and written with the atomic AND / atomic OR pair.
// CHECK-LABEL: @store_i8
// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
// INDEX64-LABEL: @store_i8
// INDEX64: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %value: i8) {
  // CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
  //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
  //     CHECK: %[[MASK1:.+]] = spirv.Constant -256 : i32
  //     CHECK: %[[ARG1_CAST:.+]] = spirv.UConvert %[[ARG1]] : i8 to i32
  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
  //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK1]]
  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[ARG1_CAST]]

  // INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
  //     INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
  //     INDEX64: %[[MASK1:.+]] = spirv.Constant -256 : i32
  //     INDEX64: %[[ARG1_CAST:.+]] = spirv.UConvert %[[ARG1]] : i8 to i32
  //     INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
  //     INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK1]]
  //     INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[ARG1_CAST]]
  memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
  return
}

} // end module