// RUN: mlir-opt %s -split-input-file --pass-pipeline='builtin.module(func.func(nvgpu-optimize-shared-memory))' | FileCheck %s
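// Note: the checks in this file all verify the same rewrite shape. The pass
// permutes the column index of shared-memory accesses as
//   newCol = col ^ shift(row & mask)
// where the mask and shift amount depend on the row size (in bytes) of the
// shared-memory allocation, presumably to spread accesses across banks.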

// CHECK: @optimize_128x32xf16_32x128xf16([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @optimize_128x32xf16_32x128xf16(%arg0: memref<128x128xf16>,
                               %ldRow: index, %ldCol: index,
                               %stRow: index, %stCol: index,
                               %fragRow: index, %fragCol: index)
                                -> (vector<4x2xf16>, vector<4x2xf16>) {
  // CHECK: [[shm:%.+]] = memref.alloc
  // CHECK: [[shmB:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<128x32xf16, 3>
  %shmB = memref.alloc() : memref<32x128xf16, 3>

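  // 128x32xf16 rows are 64 bytes wide; the expected permutation is
  // stColPerm = stCol ^ ((stRow & 6) << 2), applied to the async copy below
  // and to the ldmatrix read that follows.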
  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c6]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c2]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stColPerm]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 8
      : memref<128x128xf16> to memref<128x32xf16, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 { numGroups = 1 : i32}

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shm]][[[fragRow]], [[fragColPerm]]]
  %mat = nvgpu.ldmatrix %shm[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
      : memref<128x32xf16, 3> -> vector<4x2xf16>

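  // 32x128xf16 rows are 256 bytes wide; the permutation widens to
  // stColPerm = stCol ^ ((stRow & 15) << 3).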
  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c15]]
  // CHECK: [[c3:%.+]] = arith.constant 3 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c3]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shmB]][[[stRow]], [[stColPerm]]]
  %2 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shmB[%stRow, %stCol], 8
      : memref<128x128xf16> to memref<32x128xf16, 3>
  %3 = nvgpu.device_async_create_group %2
  nvgpu.device_async_wait %3 { numGroups = 1 : i32}

  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c15]]
  // CHECK: [[c3:%.+]] = arith.constant 3 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c3]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shmB]][[[fragRow]], [[fragColPerm]]]
  %matB = nvgpu.ldmatrix %shmB[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
      : memref<32x128xf16, 3> -> vector<4x2xf16>

  return %mat, %matB: vector<4x2xf16>, vector<4x2xf16>
}


// -----

// CHECK: @optimize_64x16xf32_16x64xf32([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @optimize_64x16xf32_16x64xf32(%arg0: memref<128x128xf32>,
                               %ldRow: index, %ldCol: index,
                               %stRow: index, %stCol: index,
                               %fragRow: index, %fragCol: index)
                                -> (vector<4x1xf32>, vector<4x1xf32>, f32, vector<4xf32>, f32) {
  // CHECK: [[shm:%.+]] = memref.alloc
  // CHECK: [[shmB:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<64x16xf32, 3>
  %shmB = memref.alloc() : memref<16x64xf32, 3>

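  // 64x16xf32 rows are 64 bytes wide; with f32 elements the expected
  // permutation is stColPerm = stCol ^ ((stRow & 6) << 1).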
  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c1]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stColPerm]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 4
      : memref<128x128xf32> to memref<64x16xf32, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 { numGroups = 1 : i32}

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shm]][[[fragRow]], [[fragColPerm]]]
  %mat = nvgpu.ldmatrix %shm[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
      : memref<64x16xf32, 3> -> vector<4x1xf32>

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.load [[shm]][[[fragRow]], [[fragColPerm]]]
  %elem = memref.load %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>

  // Verify vector operations.
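  // vector.load/vector.store (and the memref.store below) are expected to use
  // the same permuted column as the memref.load above.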

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: vector.load [[shm]][[[fragRow]], [[fragColPerm]]]
  %elem2 = vector.load %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>, vector<4xf32>

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: vector.store %{{.+}}, [[shm]][[[fragRow]], [[fragColPerm]]]
  vector.store %elem2, %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>, vector<4xf32>

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.store %{{.+}}, [[shm]][[[fragRow]], [[fragColPerm]]]
  memref.store %elem, %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>

  // Verify 16x64xf32 memory size.

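  // 16x64xf32 rows are 256 bytes wide; the expected permutation is
  // stColPerm = stCol ^ ((stRow & 15) << 2).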
  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c15]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c2]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shmB]][[[stRow]], [[stColPerm]]]
  %2 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shmB[%stRow, %stCol], 4
      : memref<128x128xf32> to memref<16x64xf32, 3>
  %3 = nvgpu.device_async_create_group %2
  nvgpu.device_async_wait %3 { numGroups = 1 : i32}

  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c15]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shmB]][[[fragRow]], [[fragColPerm]]]
  %matB = nvgpu.ldmatrix %shmB[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
      : memref<16x64xf32, 3> -> vector<4x1xf32>

  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c15]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.load [[shmB]][[[fragRow]], [[fragColPerm]]]
  %elemB = memref.load %shmB[%fragRow, %fragCol] : memref<16x64xf32, 3>

  return %mat, %matB, %elem, %elem2, %elemB: vector<4x1xf32>, vector<4x1xf32>, f32, vector<4xf32>, f32
}


// -----

// Small column edge cases

// CHECK: @small_column_size_f64([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @small_column_size_f64(%arg0: memref<32x32xf64>,
                               %ldRow: index, %ldCol: index,
                               %stRow: index, %stCol: index,
                               %fragRow: index, %fragCol: index)
                                -> f64 {
  // CHECK: [[shm:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<32x4xf64, 3>

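  // 32x4xf64 rows are only 32 bytes wide; here the row bits are shifted right
  // instead of left: stColPerm = stCol ^ ((stRow & 4) >> 1).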
  // CHECK: [[c4:%.+]] = arith.constant 4 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c4]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shrui [[src_bits]], [[c1]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stColPerm]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 2
      : memref<32x32xf64> to memref<32x4xf64, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 { numGroups = 1 : i32}

  // CHECK: [[c4:%.+]] = arith.constant 4 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c4]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shrui [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.load [[shm]][[[fragRow]], [[fragColPerm]]]
  %el = memref.load %shm[%fragRow, %fragCol] : memref<32x4xf64, 3>

  return %el: f64
}

// CHECK: @too_small_column_size_f16([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @too_small_column_size_f16(%arg0: memref<128x128xf16>,
                               %ldRow: index, %ldCol: index,
                               %stRow: index, %stCol: index,
                               %fragRow: index, %fragCol: index)
                                -> vector<1x2xf16> {
  // CHECK: [[shm:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<128x8xf16, 3>

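  // 128x8xf16 rows are only 16 bytes wide, which appears to be too narrow to
  // swizzle, so the copy and the ldmatrix keep their original indices.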
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stCol]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 8
      : memref<128x128xf16> to memref<128x8xf16, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 { numGroups = 1 : i32}

  // CHECK: nvgpu.ldmatrix [[shm]][[[fragRow]], [[fragCol]]]
  %mat = nvgpu.ldmatrix %shm[%fragRow, %fragCol] {numTiles = 1 : i32, transpose = false}
      : memref<128x8xf16, 3> -> vector<1x2xf16>

  return %mat: vector<1x2xf16>
}

// -----

// CHECK: @abort_if_subview([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @abort_if_subview(%arg0: memref<128x128xf16>,
                               %ldRow: index, %ldCol: index,
                               %stRow: index, %stCol: index,
                               %fragRow: index, %fragCol: index)
                                -> vector<1x2xf16> {
  // CHECK: [[shm:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<128x32xf16, 3>
  // CHECK: [[shmView:%.+]] = memref.subview
  %shmView = memref.subview %shm[0, 0][64, 32][1, 1] : memref<128x32xf16, 3> to memref<64x32xf16, 3>

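  // The allocation is also read through a memref.subview, which the pass
  // apparently cannot rewrite, so all accesses are left unpermuted.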
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stCol]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 8
      : memref<128x128xf16> to memref<128x32xf16, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 { numGroups = 1 : i32}

  // CHECK: nvgpu.ldmatrix [[shmView]][[[fragRow]], [[fragCol]]]
  %mat = nvgpu.ldmatrix %shmView[%fragRow, %fragCol] {numTiles = 1 : i32, transpose = false}
      : memref<64x32xf16, 3> -> vector<1x2xf16>

  return %mat: vector<1x2xf16>
}

// -----

// Ensure this case does not crash.

// CHECK-LABEL: func @test_0_d
func.func @test_0_d() -> memref<i32, #gpu.address_space<workgroup>> {
  %alloc = memref.alloc() : memref<i32, #gpu.address_space<workgroup>>
  return %alloc : memref<i32, #gpu.address_space<workgroup>>
}