// RUN: mlir-opt %s -split-input-file --pass-pipeline='builtin.module(func.func(nvgpu-optimize-shared-memory))' | FileCheck %s
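
// The CHECK lines below all encode the same rewrite: the column index of a
// shared-memory access becomes `col ^ ((row & mask) << shift)` (a right shift
// for the narrowest swizzled rows), where mask and shift depend on the row
// width in bytes of the allocation.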

// CHECK: @optimize_128x32xf16_32x128xf16([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @optimize_128x32xf16_32x128xf16(%arg0: memref<128x128xf16>,
                                          %ldRow: index, %ldCol: index,
                                          %stRow: index, %stCol: index,
                                          %fragRow: index, %fragCol: index)
                  -> (vector<4x2xf16>, vector<4x2xf16>) {
  // CHECK: [[shm:%.+]] = memref.alloc
  // CHECK: [[shmB:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<128x32xf16, 3>
  %shmB = memref.alloc() : memref<32x128xf16, 3>

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c6]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c2]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stColPerm]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 8
    : memref<128x128xf16> to memref<128x32xf16, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 { numGroups = 1 : i32}

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shm]][[[fragRow]], [[fragColPerm]]]
  %mat = nvgpu.ldmatrix %shm[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
    : memref<128x32xf16, 3> -> vector<4x2xf16>

  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c15]]
  // CHECK: [[c3:%.+]] = arith.constant 3 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c3]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shmB]][[[stRow]], [[stColPerm]]]
  %2 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shmB[%stRow, %stCol], 8
    : memref<128x128xf16> to memref<32x128xf16, 3>
  %3 = nvgpu.device_async_create_group %2
  nvgpu.device_async_wait %3 { numGroups = 1 : i32}

  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c15]]
  // CHECK: [[c3:%.+]] = arith.constant 3 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c3]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shmB]][[[fragRow]], [[fragColPerm]]]
  %matB = nvgpu.ldmatrix %shmB[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
    : memref<32x128xf16, 3> -> vector<4x2xf16>

  return %mat, %matB: vector<4x2xf16>, vector<4x2xf16>
}

// -----

// CHECK: @optimize_64x16xf32_16x64xf32([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @optimize_64x16xf32_16x64xf32(%arg0: memref<128x128xf32>,
                                        %ldRow: index, %ldCol: index,
                                        %stRow: index, %stCol: index,
                                        %fragRow: index, %fragCol: index)
                  -> (vector<4x1xf32>, vector<4x1xf32>, f32, vector<4xf32>, f32) {
  // CHECK: [[shm:%.+]] = memref.alloc
  // CHECK: [[shmB:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<64x16xf32, 3>
  %shmB = memref.alloc() : memref<16x64xf32, 3>

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c1]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stColPerm]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 4
    : memref<128x128xf32> to memref<64x16xf32, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 { numGroups = 1 : i32}

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shm]][[[fragRow]], [[fragColPerm]]]
  %mat = nvgpu.ldmatrix %shm[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
    : memref<64x16xf32, 3> -> vector<4x1xf32>

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.load [[shm]][[[fragRow]], [[fragColPerm]]]
  %elem = memref.load %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>

  // Verify vector operations.

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: vector.load [[shm]][[[fragRow]], [[fragColPerm]]]
  %elem2 = vector.load %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>, vector<4xf32>

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: vector.store %{{.+}}, [[shm]][[[fragRow]], [[fragColPerm]]]
  vector.store %elem2, %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>, vector<4xf32>

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.store %{{.+}}, [[shm]][[[fragRow]], [[fragColPerm]]]
  memref.store %elem, %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>

  // Verify the 16x64xf32 shared memory buffer.
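  // A row of 64 f32 spans 256 bytes, the same row width as the 32x128xf16
  // case above: expect the same mask (15) with the shift reduced from 3 to 2
  // for the 4-byte element type.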

  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c15]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c2]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shmB]][[[stRow]], [[stColPerm]]]
  %2 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shmB[%stRow, %stCol], 4
    : memref<128x128xf32> to memref<16x64xf32, 3>
  %3 = nvgpu.device_async_create_group %2
  nvgpu.device_async_wait %3 { numGroups = 1 : i32}

  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c15]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shmB]][[[fragRow]], [[fragColPerm]]]
  %matB = nvgpu.ldmatrix %shmB[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
    : memref<16x64xf32, 3> -> vector<4x1xf32>

  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c15]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.load [[shmB]][[[fragRow]], [[fragColPerm]]]
  %elemB = memref.load %shmB[%fragRow, %fragCol] : memref<16x64xf32, 3>

  return %mat, %matB, %elem, %elem2, %elemB: vector<4x1xf32>, vector<4x1xf32>, f32, vector<4xf32>, f32
}

// -----

// Small column edge cases

// CHECK: @small_column_size_f64([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @small_column_size_f64(%arg0: memref<32x32xf64>,
                                 %ldRow: index, %ldCol: index,
                                 %stRow: index, %stCol: index,
                                 %fragRow: index, %fragCol: index)
                  -> f64 {
  // CHECK: [[shm:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<32x4xf64, 3>

  // CHECK: [[c4:%.+]] = arith.constant 4 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c4]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shrui [[src_bits]], [[c1]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stColPerm]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 2
    : memref<32x32xf64> to memref<32x4xf64, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 { numGroups = 1 : i32}

  // CHECK: [[c4:%.+]] = arith.constant 4 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c4]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shrui [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.load [[shm]][[[fragRow]], [[fragColPerm]]]
  %el = memref.load %shm[%fragRow, %fragCol] : memref<32x4xf64, 3>

  return %el: f64
}
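
// With only 8 f16 (16 bytes) per row, the buffer is narrower than the swizzle
// unit, so the accesses below are expected to remain unpermuted.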

// CHECK: @too_small_column_size_f16([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @too_small_column_size_f16(%arg0: memref<128x128xf16>,
                                     %ldRow: index, %ldCol: index,
                                     %stRow: index, %stCol: index,
                                     %fragRow: index, %fragCol: index)
                  -> vector<1x2xf16> {
  // CHECK: [[shm:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<128x8xf16, 3>

  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stCol]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 8
    : memref<128x128xf16> to memref<128x8xf16, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 { numGroups = 1 : i32}

  // CHECK: nvgpu.ldmatrix [[shm]][[[fragRow]], [[fragCol]]]
  %mat = nvgpu.ldmatrix %shm[%fragRow, %fragCol] {numTiles = 1 : i32, transpose = false}
    : memref<128x8xf16, 3> -> vector<1x2xf16>

  return %mat: vector<1x2xf16>
}

// -----

// CHECK: @abort_if_subview([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @abort_if_subview(%arg0: memref<128x128xf16>,
                            %ldRow: index, %ldCol: index,
                            %stRow: index, %stCol: index,
                            %fragRow: index, %fragCol: index)
                  -> vector<1x2xf16> {
  // CHECK: [[shm:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<128x32xf16, 3>
  // CHECK: [[shmView:%.+]] = memref.subview
  %shmView = memref.subview %shm[0, 0][64, 32][1, 1] : memref<128x32xf16, 3> to memref<64x32xf16, 3>

  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stCol]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 8
    : memref<128x128xf16> to memref<128x32xf16, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 { numGroups = 1 : i32}

  // CHECK: nvgpu.ldmatrix [[shmView]][[[fragRow]], [[fragCol]]]
  %mat = nvgpu.ldmatrix %shmView[%fragRow, %fragCol] {numTiles = 1 : i32, transpose = false}
    : memref<64x32xf16, 3> -> vector<1x2xf16>

  return %mat: vector<1x2xf16>
}

// -----

// Ensure this case does not crash.

// CHECK-LABEL: func @test_0_d
func.func @test_0_d() -> memref<i32, #gpu.address_space<workgroup>> {
  %alloc = memref.alloc() : memref<i32, #gpu.address_space<workgroup>>
  return %alloc : memref<i32, #gpu.address_space<workgroup>>
}