# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
#  Chapter 5 : Warp Specialized GEMM with Tensor Core
# ===----------------------------------------------------------------------===//
#
# This program demonstrates a GEMM operation for `f32+=f16*f16`, utilizing the
# Warp Specialized method with a tile size of 128x128x64. The code completely
# parallelizes the two outermost loops into thread blocks. It launches two Warp
# Groups (256 threads in total): one for the producer and the other for the
# consumer. Each group takes a different control-flow. The producer warpgroup
# is responsible for loading data into shared memory, while the consumer
# warpgroup executes the Tensor Core GEMM operation and the epilogue.
#
#  for ti in range(M//128):  # -> blockIdx.x
#   for tj in range(N//128):  # -> blockIdx.y
#    with wg_producer:
#     for tk in range(K//64):
#        TMA_128x64_64x128...
#    with wg_consumer:
#     for tk in range(K//64):
#        MMA_128x128x64...
#     Epilogue..
#
# This chapter demonstrates:
#  2 WG (warpgroups)
#    Producer:
#       Loop
#           Wait MMA barrier (the shared memory slot has been released)
#           Arrive TMA barrier with the expected transaction byte count
#           Issue TMA loads that complete on the TMA barrier
#    Consumer:
#       Loop
#           Wait TMA barrier
#           Performs Tensor Core GEMM 128x128x64 by warpgroup
#           Arrive MMA barrier (releases the previously used slot)
#       Epilogue
#           Store fragmented registers to shared memory
#           Store shared memory to global
#
# ===----------------------------------------------------------------------===//


from mlir import ir
from mlir.dialects import gpu, scf, nvgpu, nvvm
from mlir.extras import types as T
from tools.nvdsl import *
import numpy as np


def partition_shape():
    """
    Return the origin of this thread block's output tile.

    It parallelizes the two outermost loops into thread blocks:
    for ti in range(M//128):    # -> blockIdx.x
     for tj in range(N//128):   # -> blockIdx.y
      D = 0
      for tk in range(K//64):
       for i in range(128):
        for j in range(128):
         for k in range(64):
           FMA

    Returns:
        dimX: Row offset of the tile, ``blockIdx.x * TILE_M`` (an IR value).
        dimY: Column offset of the tile, ``blockIdx.y * TILE_N`` (an IR value).
    """
    bidx = gpu.block_id(gpu.Dimension.x)
    bidy = gpu.block_id(gpu.Dimension.y)
    dimX = bidx * TILE_M
    dimY = bidy * TILE_N
    return dimX, dimY


def tma_load(
    mbar_group: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    slot,
    stage,
    num_stages,
    p=None,
):
    """
    Load one K-slice of A and B from global to shared memory via TMA.

    Dynamic shared memory holds `num_stages` A buffers followed by
    `num_stages` B buffers; `slot` selects the buffer in each region.
    With `k = stage * 64` it issues:

       - tma.load A[slot]        at coordinate [k, dimX]      (A box, 128x64)
       - tma.load B[slot] half 0 at coordinate [dimY, k]      (B box, 64x64)
       - tma.load B[slot] half 1 at coordinate [dimY + 64, k] (B box, 64x64)

    and arrives on `mbar_group[slot]` with the expected byte count
       ta_count = size(A box) + 2 * size(B box)
                = 128x64x2 + 2x64x64x2 = 32768 bytes (f16 inputs).

    Args:
        mbar_group: TMA barriers, one per slot.
        a_tma: TMA descriptor for A.
        b_tma: TMA descriptor for B.
        slot: Shared-memory buffer index (K iteration modulo `num_stages`).
        stage: K-loop iteration index; selects the K coordinate.
        num_stages: Number of shared-memory pipeline buffers.
        p: Predicate for the barrier arrive (the caller passes its
            warpgroup's primary-thread predicate).
    """
    dimX, dimY = partition_shape()

    tidx = gpu.thread_id(gpu.Dimension.x)
    size_tma_a = get_type_size(a_tma.tma_memref)
    size_tma_b = get_type_size(b_tma.tma_memref)
    # B buffers start after all `num_stages` A buffers.
    begin_b = num_stages * size_tma_a
    # Bytes the TMA barrier must see before it completes: A + two B halves.
    ta_count = size_tma_a + (size_tma_b * 2)

    off_a = slot * size_tma_a
    # B slots are strided by size_tma_a: this assumes the two B halves
    # together are exactly as large as one A box (true for 128x64 / 64x64
    # f16). consumer_loop computes its B offset the same way.
    off_b = (slot * size_tma_a) + begin_b
    off_b2 = off_b + size_tma_b
    a_elem_ty = a_tma.tma_memref.element_type
    b_elem_ty = b_tma.tma_memref.element_type
    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)

    mbar_group[slot].arrive(ta_count, predicate=p)
    # Only the first thread of each warpgroup issues the TMA copies.
    is_leader = (tidx % WARP_GROUP_SIZE) == 0
    k_coord = stage * 64
    a_tma.load(a, mbar_group[slot], coords=[k_coord, dimX], predicate=is_leader)
    b_tma.load(b1, mbar_group[slot], coords=[dimY, k_coord], predicate=is_leader)
    b_tma.load(
        b2, mbar_group[slot], coords=[dimY + 64, k_coord], predicate=is_leader
    )


def initialize(a_tma: TMA, b_tma: TMA, num_stages):
    """
    Initialize mbarriers and prefetch TMA descriptors.

    Thread 0 initializes `num_stages` TMA barriers and `num_stages` MMA
    barriers (each with an arrival count of 1) and prefetches both TMA
    descriptors.

    Returns:
        (mbar_group_tma, mbar_group_mma) -- in this order.
    """
    tidx = gpu.thread_id(gpu.Dimension.x)
    mbar_group_tma = Mbarriers(number_of_barriers=num_stages)
    mbar_group_mma = Mbarriers(number_of_barriers=num_stages)
    isThread0 = tidx == const(0)
    with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
        for i in scf.for_(0, num_stages, 1):
            mbar_group_tma[i].init(1)
            mbar_group_mma[i].init(1)
            scf.yield_([])
        a_tma.prefetch()
        b_tma.prefetch()
        scf.yield_([])

    return mbar_group_tma, mbar_group_mma


def switch_phase(stage, phase, num_stages):
    """
    Return the mbarrier parity to wait on in the next iteration.

    The parity flips whenever the ring of `num_stages` buffers wraps around,
    i.e. right after the last slot (`stage == num_stages - 1`) was used.
    """
    p = stage == (num_stages - 1)
    phase = arith.select(
        p,
        (phase ^ const(True, ty=T.bool())),
        phase,
    )
    return phase


def producer_loop(
    mbar_tma: Mbarriers,
    mbar_mma: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    wg_me: Warpgroup,
    num_stages,
):
    """
    Producer warpgroup main loop.

    For each of the K // TILE_K slices: wait on the MMA barrier of the slot
    (the consumer has released it), then TMA-load A and B into that slot.
    """
    # Starts at the opposite parity of the consumer; presumably so the first
    # wait on each freshly initialized MMA barrier does not block -- confirm
    # against mbarrier parity-wait semantics.
    phase = const(True, ty=T.bool())

    for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]):
        stage = iv % num_stages
        # Wait MMA to be done
        mbar_mma[stage].try_wait(phase)
        # New phase for mbarrier
        phase = switch_phase(stage, phase, num_stages)
        # TMA Load: `stage` is the buffer slot, `iv` selects the K coordinate.
        tma_load(mbar_tma, a_tma, b_tma, stage, iv, num_stages, wg_me.is_wg_primary)
        scf.yield_([phase])


def consumer_loop(
    mbar_tma: Mbarriers,
    mbar_mma: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    wg_me: Warpgroup,
    num_stages,
):
    """
    Consumer warpgroup main loop.

    For each of the K // TILE_K slices: wait on the TMA barrier of the slot,
    accumulate D += A @ B with WGMMA, and release the previously used slot
    through its MMA barrier.

    Returns:
        The WGMMA accumulator D holding the final TILE_M x TILE_N f32 tile.
    """
    # Must agree with tma_load's shared-memory layout: all A buffers first,
    # then all B buffers, both strided by the A tile size.
    begin_b = num_stages * get_type_size(a_tma.tma_memref)

    size_a = TILE_M * TILE_K * get_type_size(T.f16())

    phase = const(False, ty=T.bool())
    A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
    B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
    D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())

    # Loop-carried values: the accumulator and the TMA barrier parity.
    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
    with ir.InsertionPoint(for_op.body):
        phase = for_op.inner_iter_args[1]
        iv = for_op.induction_variable
        stage = iv % num_stages

        # Wait TMA for current stage
        mbar_tma[stage].try_wait(phase)

        # Find shared memory slot
        offset_a = stage * size_a
        offset_b = offset_a + begin_b
        a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
        b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)

        # Iterate input matrices, update accumulator
        A.update_smem(a_smem)
        B.update_smem(b_smem)
        D.update_accumulator(for_op.inner_iter_args[0])

        # Matrix Multiply
        D += A @ B

        # MMA Barrier Arrive: release the slot used in the previous iteration
        # (wrapping from slot 0 to slot num_stages - 1). Nothing to release
        # on the first iteration.
        p_arrive = (iv > 0) & wg_me.is_wg_primary
        with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
            barId = arith.select((stage == 0), const(num_stages - 1), (stage - 1))
            mbar_mma[barId].arrive()
            scf.yield_([])

        phase = switch_phase(stage, phase, num_stages)
        scf.yield_([D.acc_op, phase])

    # Wait for all outstanding WGMMA groups before reading the accumulator.
    nvvm.WgmmaWaitGroupSyncOp(0)
    D.update_accumulator(for_op.results[0])
    return D


def epilogue(D: WGMMAMatrix, d_dev):
    """
    Epilogue of the GEMM kernel: write the accumulator tile to global memory.

    MatrixAccumulator D                  # Fragmented results
    store D -> Shared Memory             # Store Shared Memory
    Shared Memory -> d_dev[dimX][dimY]   # Store Shared Memory to Global Memory

    The staging buffer is requested without an offset, so (presumably at
    offset 0) it aliases the A/B pipeline buffers; this relies on
    consumer_loop having drained all WGMMAs first.
    """
    tidx = gpu.thread_id(gpu.Dimension.x)
    dimX, dimY = partition_shape()

    d_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
    d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])

    # Store (registers -> shared memory)
    D.store_accumulator(d_smem)
    gpu.barrier()

    # Store (shared memory --> global memory): thread `tidx` copies column
    # `tidx` of the tile. Assumes the caller's threads have tidx in
    # [0, TILE_N) (the consumer warpgroup starts at thread 0) -- confirm.
    for i in scf.for_(0, TILE_M, 1):
        val = memref.load(d_smem, [i, tidx])
        memref.store(val, d_gmem, [i, tidx])
        scf.yield_([])


@NVDSL.mlir_func
def gemm_warp_specialized(a, b, d, num_stages):
    """
    Host-side driver: copy `a` and `b` to the device, launch the
    warp-specialized GEMM kernel on an (M/TILE_M, N/TILE_N) grid of
    256-thread blocks, and copy the result back into `d`.
    """
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
    t7 = gpu.wait(token_ty, [t6])

    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
    a_tma = TMA([128, 64], a.type, swizzle=sw)
    b_tma = TMA([64, 64], b.type, swizzle=sw)
    a_tma.create_descriptor(a_dev)
    b_tma.create_descriptor(b_dev)

    grid = [(M // TILE_M), (N // TILE_N), 1]
    # Two warpgroups: producer and consumer.
    block = [256, 1, 1]

    # Dynamic shared memory: `num_stages` A tiles plus `num_stages` B tiles.
    size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
    size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
    smem_size_in_bytes = (size_a + size_b) * num_stages
    # The epilogue stages the whole TILE_M x TILE_N accumulator tile through
    # the same dynamic shared memory, so it must fit as well (matters only
    # for very small `num_stages`).
    size_d = get_type_size(d.type.element_type) * TILE_M * TILE_N
    smem_size_in_bytes = max(smem_size_in_bytes, size_d)

    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
    def gemm_warp_specialized_kernel():
        # Init Warpgroups
        wg_producer = Warpgroup(primary_thread=128, register_size=40)
        wg_consumer = Warpgroup(primary_thread=0, register_size=232)

        # Initialize mbarriers and prefetch TMA descriptors.
        # `initialize` returns (tma barriers, mma barriers) in this order.
        mbar_tma, mbar_mma = initialize(a_tma, b_tma, num_stages)

        # Producer performs TMA
        with wg_producer:
            producer_loop(mbar_tma, mbar_mma, a_tma, b_tma, wg_producer, num_stages)

        # Consumer performs MMA/Tensor Core
        with wg_consumer:
            D = consumer_loop(mbar_tma, mbar_mma, a_tma, b_tma, wg_consumer, num_stages)
            epilogue(D, d_dev)

    gemm_warp_specialized_kernel()

    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
    gpu.wait(None, [t8])


# Problem and tile sizes; the kernel-building functions above read these
# module-level names when gemm_warp_specialized is traced.
N = 256
M = 512
K = 1024
TILE_M = 128
TILE_N = 128
TILE_K = 64
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)

gemm_warp_specialized(a, b, d, num_stages=7)


# Verify MLIR with reference computation. The kernel computes f32 += f16*f16,
# so the reference multiplies the same f16 inputs but accumulates in f32.
ref_d = a.astype(np.float32) @ b.astype(np.float32)
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)


print("PASS")
# CHECK-NOT: Mismatched elements