# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
#  Chapter 4 : Multistage GEMM with Tensor Core
# ===----------------------------------------------------------------------===//
#
# This program exemplifies a GEMM operation for `f32+=f16*f16`, utilizing the
# multistage method with a tile size of 128x128x64. The code completely
# parallelizes the two outermost loops into thread blocks. It launches one warp
# group (128 threads in total) and allocates multiple slots (one per stage) in
# shared memory. The program consists of three main parts: prologue, mainloop,
# and epilogue. In the prologue, thread 0 issues TMA requests to load data into
# the shared memory slots. The mainloop executes MMA while TMA simultaneously
# loads the next tiles into slots that have already been consumed. This overlap
# of TMA and MMA operations enhances performance by maximizing computational
# throughput.
#
# Loops illustration:
#
#  for s in range(num_stages):
#    TMA_128x64_64x128...
#  for ti in range(M//128):   # -> blockIdx.x
#   for tj in range(N//128):  # -> blockIdx.y
#    for tk in range(K//64):
#      MMA_128x128x64...
#      TMA_128x64_64x128...
#  Epilogue...
#
# This chapter demonstrates:
#  1. Partition shape based on block IDs
#  2. Prologue
#    2.1 Execute TMA loads for the two input matrices for each stage
#  3. Main loop
#    3.1 Wait for completion of the TMA load with mbarrier
#    3.2 Perform Tensor Core GEMM 128x128x64 with a warpgroup
#    3.3 Load the next stage if needed
#  4. Epilogue
#    4.1 Store fragmented registers to shared memory
#    4.2 Store shared memory to global memory
#
# ===----------------------------------------------------------------------===//


from mlir import ir
from mlir.dialects import gpu, scf, nvgpu, nvvm
from mlir.extras import types as T
from tools.nvdsl import *
import numpy as np


def partition_shape():
    """
    Calculate the tile offsets handled by this thread block, based on the block IDs.

    It partitions the iteration space like below:
      for(.. i < M ...)   --> blockIdx.x
       for(.. j < N ...)  --> blockIdx.y
        for(.. k < K ...)

    Returns:
        dimX: Offset of this block's tile along the M dimension.
        dimY: Offset of this block's tile along the N dimension.
    """
    bidx = gpu.block_id(gpu.Dimension.x)
    bidy = gpu.block_id(gpu.Dimension.y)
    dimX = bidx * TILE_M
    dimY = bidy * TILE_N
    return dimX, dimY


def tma_load(
    mbar_group: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    slot,
    stage,
    num_stages,
    p=None,
):
    """
    TMA loads the two input matrices from global memory to shared memory. With
    x = blockIdx.x * TILE_M, y = blockIdx.y * TILE_N and k = stage * TILE_K,
    it performs the following operations:

       - tma.load a_shared_memory[off_a]  at coordinate [k, x]      (Loads 128x64)
       - tma.load b_shared_memory[off_b]  at coordinate [y, k]      (Loads 64x64)
       - tma.load b_shared_memory[off_b2] at coordinate [y + 64, k] (Loads 64x64)

       mbarrier.arrive with ta_count = sizeof(A tile) + 2 * sizeof(B tile)
                                     = 128*64*2 + 2*(64*64*2) bytes
    """
    dimX, dimY = partition_shape()

    tidx = gpu.thread_id(gpu.Dimension.x)
    begin_b = num_stages * get_type_size(a_tma.tma_memref)
    size_tma_a = get_type_size(a_tma.tma_memref)
    size_tma_b = get_type_size(b_tma.tma_memref)
    ta_count = size_tma_a + (size_tma_b * 2)

    p = tidx == 0 if p is None else p

    off_a = slot * size_tma_a
    off_b = (slot * size_tma_a) + begin_b
    off_b2 = off_b + size_tma_b
    a_elem_ty = a_tma.tma_memref.element_type
    b_elem_ty = b_tma.tma_memref.element_type
    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)

    mbar_group[slot].arrive(ta_count, predicate=p)

    c1 = stage * 64
    a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
    b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
    b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)

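
# A rough sketch (not called by the kernel) of the dynamic shared memory
# layout that tma_load() computes: all A slots first, then all B slots
# (starting at begin_b), and each B slot filled by two 64x64 TMA loads. The
# tile sizes below are assumptions matching this chapter (128x64 A tiles and
# 64x128 B tiles of f16); the kernel derives the real offsets from the TMA
# memref types.
def _sketch_smem_layout(num_stages=7):
    f16_bytes = 2
    size_a = 128 * 64 * f16_bytes        # one A slot, 16 KiB
    size_b = 64 * 64 * f16_bytes         # one 64x64 half of a B slot, 8 KiB
    begin_b = num_stages * size_a        # B region starts after all A slots
    layout = []
    for slot in range(num_stages):
        off_a = slot * size_a            # A tile of this slot
        off_b = slot * size_a + begin_b  # first 64x64 half of B
        off_b2 = off_b + size_b          # second 64x64 half of B
        ta_count = size_a + 2 * size_b   # bytes the slot's mbarrier expects (32 KiB)
        layout.append((slot, off_a, off_b, off_b2, ta_count))
    return layout
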
def initialize(a_tma: TMA, b_tma: TMA, num_stages):
    """
    Initialize mbarriers and prefetch TMA descriptors.
    """
    tidx = gpu.thread_id(gpu.Dimension.x)
    mbar_group = Mbarriers(number_of_barriers=num_stages)
    isThread0 = tidx == const(0)
    with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
        for i in scf.for_(0, num_stages, 1):
            mbar_group[i].init(1)
            scf.yield_([])
        a_tma.prefetch()
        b_tma.prefetch()
        scf.yield_([])

    return mbar_group


def prologue(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
    """
    Prologue of the GEMM kernel. It loads the two input matrices for each
    stage in a loop like below, filling all but one slot (or the single slot
    when num_stages == 1) so the main loop can overlap loads with compute:

    for stage in range(num_stages - 1):
        tma_load a, b, stage

    """
    ns = num_stages if num_stages == 1 else num_stages - 1
    for iv in scf.for_(0, ns, 1):
        tma_load(mbar_group, a_tma, b_tma, iv, iv, num_stages)
        scf.yield_([])

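
# A rough sketch (not called by the kernel) of the software pipelining
# schedule produced by prologue() and mainloop() together: the prologue fills
# num_stages - 1 slots, then every main-loop iteration computes on one slot
# while issuing the TMA load for the stage that is num_stages - 1 iterations
# ahead. The defaults are assumptions mirroring this chapter (num_stages=7,
# K // TILE_K = 16).
def _sketch_pipeline_schedule(num_stages=7, k_iters=16):
    ns = num_stages if num_stages == 1 else num_stages - 1
    schedule = []
    for stage in range(ns):  # prologue: fill the pipeline
        schedule.append(("tma_load", stage, stage % num_stages))
    for iv in range(k_iters):  # main loop
        schedule.append(("mma", iv, iv % num_stages))
        if iv + ns < k_iters:  # load the next stage if needed
            schedule.append(("tma_load", iv + ns, (iv + ns) % num_stages))
    return schedule  # list of (operation, stage, slot)
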
def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
    """
    Main loop of the multistage GEMM kernel. It iterates through the stages
    and performs matrix multiplication while TMA loads data for upcoming
    stages into shared memory. It works like the following:

    MatrixAccumulator D
    for k in range(K // TILE_K):

        try_wait(stage, ...)    # Wait for the TMA load

        Matrix A(stage, ...)    # Find the shared memory slot
        Matrix B(stage, ...)    # Find the shared memory slot
        D += A @ B              # Multiply and accumulate

        if needLoad:            # Load the next stage if needed
            tma_load(a, b, nextSlot, nextStage)

    """
    ns = num_stages if num_stages == 1 else num_stages - 1

    tidx = gpu.thread_id(gpu.Dimension.x)
    begin_b = num_stages * get_type_size(a_tma.tma_memref)

    size_a = TILE_M * TILE_K * get_type_size(T.f16())

    # Initialize A and B (input matrices) and D (accumulator)
    A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
    B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
    D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())

    phase = const(False, ty=T.bool())

    # Main loop
    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
    with ir.InsertionPoint(for_op.body):
        phase = for_op.inner_iter_args[1]
        iv = for_op.induction_variable
        stage = iv % num_stages

        # Wait for the TMA load of the current stage
        mbar_group[stage].try_wait(phase=phase)

        # Find the shared memory slot
        offset_a = stage * size_a
        offset_b = offset_a + begin_b
        a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
        b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)

        # Iterate input matrices, update accumulator
        A.update_smem(a_smem)
        B.update_smem(b_smem)
        D.update_accumulator(for_op.inner_iter_args[0])

        # Matrix multiply
        D += A @ B

        # Wait for the Tensor Core when there is only a single stage
        if num_stages == 1:
            nvvm.WgmmaWaitGroupSyncOp(0)

        # Load the next stage
        pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
        nextStage = iv + ns
        nextSlot = nextStage % num_stages
        tma_load(mbar_group, a_tma, b_tma, nextSlot, nextStage, num_stages, pred)

        # Switch phase parity for the mbarrier
        newPhase = arith.select(
            stage == (num_stages - 1),
            (phase ^ const(True, ty=T.bool())),
            phase,
        )
        scf.yield_([D.acc_op, newPhase])

    nvvm.WgmmaWaitGroupSyncOp(0)

    D.update_accumulator(for_op.results[0])
    return D

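
# A rough sketch (not called by the kernel) of the mbarrier phase bookkeeping
# in mainloop(): each slot is reused every num_stages iterations, and the
# consumer flips the parity it waits on whenever the loop wraps past the last
# slot, mirroring the arith.select on newPhase above. The defaults are
# assumptions matching this chapter (num_stages=7, K // TILE_K = 16).
def _sketch_phase_parity(num_stages=7, k_iters=16):
    phase = False
    waits = []
    for iv in range(k_iters):
        stage = iv % num_stages
        waits.append((iv, stage, phase))  # try_wait(slot=stage, phase=phase)
        if stage == num_stages - 1:       # wrapped around all the slots
            phase = not phase
    return waits
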
def epilogue(D: WGMMAMatrix, d_dev):
    """
    Epilogue of the GEMM kernel. It stores the fragmented registers to global
    memory through shared memory:

    MatrixAccumulator D              # Fragmented results
    store D -> Shared Memory         # Store registers to shared memory
    Shared Memory -> Z[dimX][dimY]   # Store shared memory to global memory

    """
    tidx = gpu.thread_id(gpu.Dimension.x)
    dimX, dimY = partition_shape()

    d_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
    d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])

    # Store (registers -> shared memory)
    D.store_accumulator(d_smem)
    gpu.barrier()

    # Store (shared memory --> global memory)
    for i in scf.for_(0, TILE_M, 1):
        val = memref.load(d_smem, [i, tidx])
        memref.store(val, d_gmem, [i, tidx])
        scf.yield_([])


# The decorator generates
#   a -> memref<MxKxf16>
#   b -> memref<KxNxf16>
#   d -> memref<MxNxf32>
@NVDSL.mlir_func
def gemm_multistage(a, b, d, num_stages):
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
    t7 = gpu.wait(token_ty, [t6])

    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
    a_tma = TMA([128, 64], a.type, swizzle=sw)
    b_tma = TMA([64, 64], b.type, swizzle=sw)
    a_tma.create_descriptor(a_dev)
    b_tma.create_descriptor(b_dev)

    grid = [(M // TILE_M), (N // TILE_N), 1]
    block = [128, 1, 1]

    size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
    size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
    smem_size_in_bytes = (size_a + size_b) * num_stages

    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
    def gemm_multistage_kernel():
        # Initialize mbarriers and prefetch TMA descriptors
        mbar_group = initialize(a_tma, b_tma, num_stages)

        # Fill the pipeline stages
        prologue(mbar_group, a_tma, b_tma, num_stages)

        # Main loop
        D = mainloop(mbar_group, a_tma, b_tma, num_stages)

        # Store registers to global memory
        epilogue(D, d_dev)

    gemm_multistage_kernel()

    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
    gpu.wait(None, [t8])


# Pass Python arguments to MLIR
N = 256
M = 512
K = 1024
TILE_M = 128
TILE_N = 128
TILE_K = 64
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)

gemm_multistage(a, b, d, num_stages=7)


# Verify the MLIR result against a NumPy reference computation
ref_d = a.astype(np.float16) @ b.astype(np.float16)
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)


print("PASS")
# CHECK-NOT: Mismatched elements
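
# A rough sanity check (an illustrative footnote; the sm_90 figure is an
# assumption): each stage holds one 128x64 and one 64x128 f16 tile, i.e.
# 32 KiB, so num_stages=7 requests 7 * 32 KiB = 224 KiB of dynamic shared
# memory, which stays under the roughly 227 KiB per-block opt-in shared
# memory limit of Hopper (sm_90); a larger stage count would not fit.
assert (TILE_M * TILE_K + TILE_K * TILE_N) * 2 * 7 <= 227 * 1024  # 2 bytes per f16, 7 stages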