// xref: /llvm-project/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir (revision c63febb1025564b078a5c8e52e6df638e8a1d808)
// RUN: mlir-opt -split-input-file -convert-memref-to-spirv -canonicalize -verify-diagnostics %s -o - | FileCheck %s

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
  }
{
  // Statically shaped f32 Workgroup buffer in a Shader environment: the alloc
  // becomes a module-level global (a struct wrapping a 4*5 = 20 element f32
  // array), load/store go through spirv.AccessChain on that global, and the
  // dealloc disappears.
  func.func @alloc_dealloc_workgroup_mem_shader_f32(%arg0 : index, %arg1 : index) {
    %0 = memref.alloc() : memref<4x5xf32, #spirv.storage_class<Workgroup>>
    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, #spirv.storage_class<Workgroup>>
    memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, #spirv.storage_class<Workgroup>>
    memref.dealloc %0 : memref<4x5xf32, #spirv.storage_class<Workgroup>>
    return
  }
}
//       CHECK: spirv.GlobalVariable @[[$VAR:.+]] : !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32>)>, Workgroup>
// CHECK-LABEL: func @alloc_dealloc_workgroup_mem_shader_f32
//   CHECK-NOT:   memref.alloc
//       CHECK:   %[[PTR:.+]] = spirv.mlir.addressof @[[$VAR]]
//       CHECK:   %[[LOADPTR:.+]] = spirv.AccessChain %[[PTR]]
//       CHECK:   %[[VAL:.+]] = spirv.Load "Workgroup" %[[LOADPTR]] : f32
//       CHECK:   %[[STOREPTR:.+]] = spirv.AccessChain %[[PTR]]
//       CHECK:   spirv.Store "Workgroup" %[[STOREPTR]], %[[VAL]] : f32
//   CHECK-NOT:   memref.dealloc

// -----

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
  }
{
  // Only the Shader capability is enabled (no Int16), and the checks below
  // expect the 4x5 i16 buffer (20 elements) to be stored as 10 x i32.
  // Loads compute the containing word with spirv.SDiv; stores are lowered to
  // spirv.AtomicAnd followed by spirv.AtomicOr on that word.
  func.func @alloc_dealloc_workgroup_mem_shader_i16(%arg0 : index, %arg1 : index) {
    %0 = memref.alloc() : memref<4x5xi16, #spirv.storage_class<Workgroup>>
    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xi16, #spirv.storage_class<Workgroup>>
    memref.store %1, %0[%arg0, %arg1] : memref<4x5xi16, #spirv.storage_class<Workgroup>>
    memref.dealloc %0 : memref<4x5xi16, #spirv.storage_class<Workgroup>>
    return
  }
}

// The generated global's name is captured instead of hard-coding its numeric
// suffix, so the address-of check stays tied to the global matched here.
//       CHECK: spirv.GlobalVariable @[[$I16_VAR:__workgroup_mem__[0-9]+]]
//  CHECK-SAME:   !spirv.ptr<!spirv.struct<(!spirv.array<10 x i32>)>, Workgroup>
// CHECK-LABEL: func @alloc_dealloc_workgroup_mem_shader_i16
//   CHECK-NOT:   memref.alloc
//       CHECK:   %[[BASE:.+]] = spirv.mlir.addressof @[[$I16_VAR]]
//       CHECK:   %[[LOC:.+]] = spirv.SDiv
//       CHECK:   %[[PTR:.+]] = spirv.AccessChain %[[BASE]][%{{.+}}, %[[LOC]]]
//       CHECK:   %{{.+}} = spirv.Load "Workgroup" %[[PTR]] : i32
//       CHECK:   %[[LOC:.+]] = spirv.SDiv
//       CHECK:   %[[PTR:.+]] = spirv.AccessChain %[[BASE]][%{{.+}}, %[[LOC]]]
//       CHECK:   %{{.+}} = spirv.AtomicAnd <Workgroup> <AcquireRelease> %[[PTR]], %{{.+}} : !spirv.ptr<i32, Workgroup>
//       CHECK:   %{{.+}} = spirv.AtomicOr <Workgroup> <AcquireRelease> %[[PTR]], %{{.+}} : !spirv.ptr<i32, Workgroup>
//   CHECK-NOT:   memref.dealloc

// -----

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
  }
{
  // Two independent Workgroup allocations must produce two distinct globals:
  // 4x5 f32 -> 20 x f32 and 2x3 i32 -> 6 x i32.
  func.func @two_allocs() {
    %0 = memref.alloc() : memref<4x5xf32, #spirv.storage_class<Workgroup>>
    %1 = memref.alloc() : memref<2x3xi32, #spirv.storage_class<Workgroup>>
    return
  }
}

// Each global's full pattern sits on a single DAG directive. A DAG directive
// followed by a SAME continuation would be matched in a fixed order (FileCheck
// does not backtrack), which defeats the purpose of an unordered match.
//   CHECK-DAG: spirv.GlobalVariable @__workgroup_mem__{{[0-9]+}} : !spirv.ptr<!spirv.struct<(!spirv.array<6 x i32>)>, Workgroup>
//   CHECK-DAG: spirv.GlobalVariable @__workgroup_mem__{{[0-9]+}} : !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32>)>, Workgroup>
// CHECK-LABEL: func @two_allocs()

// -----

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
  }
{
  // Vector element types are kept as array elements of the generated globals:
  // 4 x vector<4xf32> and 2 x vector<2xi32>.
  func.func @two_allocs_vector() {
    %0 = memref.alloc() : memref<4xvector<4xf32>, #spirv.storage_class<Workgroup>>
    %1 = memref.alloc() : memref<2xvector<2xi32>, #spirv.storage_class<Workgroup>>
    return
  }
}

// One DAG directive per global (full pattern on one line) so the two globals
// really are matched independently of their emission order.
//   CHECK-DAG: spirv.GlobalVariable @__workgroup_mem__{{[0-9]+}} : !spirv.ptr<!spirv.struct<(!spirv.array<2 x vector<2xi32>>)>, Workgroup>
//   CHECK-DAG: spirv.GlobalVariable @__workgroup_mem__{{[0-9]+}} : !spirv.ptr<!spirv.struct<(!spirv.array<4 x vector<4xf32>>)>, Workgroup>
// CHECK-LABEL: func @two_allocs_vector()

// -----

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
  }
{
  // A Workgroup allocation with a dynamic dimension is not converted: the
  // memref.alloc is expected to survive the pass unchanged.
  // CHECK-LABEL: func @alloc_dynamic_size
  func.func @alloc_dynamic_size(%arg0 : index) -> f32 {
    // CHECK: memref.alloc
    %0 = memref.alloc(%arg0) : memref<4x?xf32, #spirv.storage_class<Workgroup>>
    %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, #spirv.storage_class<Workgroup>>
    return %1: f32
  }
}

// -----

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
  }
{
  // An allocation in the StorageBuffer storage class (rather than Workgroup)
  // is not converted: the memref.alloc is expected to remain.
  // CHECK-LABEL: func @alloc_unsupported_memory_space
  func.func @alloc_unsupported_memory_space(%arg0: index) -> f32 {
    // CHECK: memref.alloc
    %0 = memref.alloc() : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>
    %1 = memref.load %0[%arg0, %arg0] : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>
    return %1: f32
  }
}

// -----

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
  }
{
  // Deallocating a dynamically shaped Workgroup memref is not converted: the
  // memref.dealloc is expected to remain.
  // CHECK-LABEL: func @dealloc_dynamic_size
  func.func @dealloc_dynamic_size(%arg0 : memref<4x?xf32, #spirv.storage_class<Workgroup>>) {
    // CHECK: memref.dealloc
    memref.dealloc %arg0 : memref<4x?xf32, #spirv.storage_class<Workgroup>>
    return
  }
}

// -----

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
  }
{
  // Deallocating a StorageBuffer memref is not converted: the memref.dealloc
  // is expected to remain.
  // CHECK-LABEL: func @dealloc_unsupported_memory_space
  func.func @dealloc_unsupported_memory_space(%arg0 : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>) {
    // CHECK: memref.dealloc
    memref.dealloc %arg0 : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>
    return
  }
}

// -----

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Kernel], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
  }
{
  // Same program as the f32 Shader case, but with the Kernel capability: the
  // checks expect the global to hold a bare 20 x f32 array with no wrapping
  // struct. The dealloc still disappears.
  func.func @alloc_dealloc_workgroup_mem_kernel(%arg0 : index, %arg1 : index) {
    %0 = memref.alloc() : memref<4x5xf32, #spirv.storage_class<Workgroup>>
    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, #spirv.storage_class<Workgroup>>
    memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, #spirv.storage_class<Workgroup>>
    memref.dealloc %0 : memref<4x5xf32, #spirv.storage_class<Workgroup>>
    return
  }
}
//       CHECK: spirv.GlobalVariable @[[$VAR:.+]] : !spirv.ptr<!spirv.array<20 x f32>, Workgroup>
// CHECK-LABEL: func @alloc_dealloc_workgroup_mem_kernel
//   CHECK-NOT:   memref.alloc
//       CHECK:   %[[PTR:.+]] = spirv.mlir.addressof @[[$VAR]]
//       CHECK:   %[[LOADPTR:.+]] = spirv.AccessChain %[[PTR]]
//       CHECK:   %[[VAL:.+]] = spirv.Load "Workgroup" %[[LOADPTR]] : f32
//       CHECK:   %[[STOREPTR:.+]] = spirv.AccessChain %[[PTR]]
//       CHECK:   spirv.Store "Workgroup" %[[STOREPTR]], %[[VAL]] : f32
//   CHECK-NOT:   memref.dealloc

// -----

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
  }
{
  func.func @zero_size() {
    %0 = memref.alloc() : memref<0xf32, #spirv.storage_class<Workgroup>>
    %1 = memref.alloc() : memref<0xi1, #spirv.storage_class<Workgroup>>
    %2 = memref.alloc() : memref<0xi4, #spirv.storage_class<Workgroup>>
    return
  }
}

// Zero-sized allocations (here with f32, i1 and i4 elements) are not handled
// yet. This case only makes sure the conversion does not crash on them; the
// output is not inspected beyond the function label.
// CHECK-LABEL: func @zero_size()
