# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
#  Chapter 5 : Warp Specialized GEMM with Tensor Core
# ===----------------------------------------------------------------------===//
#
# This program demonstrates a GEMM operation for `f32 += f16 * f16`, using the
# warp-specialized method with a tile size of 128x128x64. The code fully
# parallelizes the two outermost loops across thread blocks. It launches two
# warpgroups (256 threads in total): one producer and one consumer. Each group
# follows a different control flow. The producer warpgroup is responsible for
# loading data into shared memory, while the consumer warpgroup executes the
# Tensor Core GEMM operation and the epilogue.
#
#  for ti in range(M//128):    # -> blockIdx.x
#   for tj in range(N//128):   # -> blockIdx.y
#    with wg_producer:
#     for tk in range(K//64):
#        TMA_128x64_64x128...
#    with wg_consumer:
#     for tk in range(K//64):
#        MMA_128x128x64...
#     Epilogue..
#
# This chapter demonstrates:
#  2 warpgroups (WGs)
#   Producer:
#    1. Wait on the MMA barrier (buffer slot released by the consumer)
#    2. Load tiles with TMA, tracked by the TMA barrier
#    3. Arrive at the TMA barrier with the transaction count (txcount)
#   Consumer:
#    Loop
#     Wait on the TMA barrier
#     Perform the 128x128x64 Tensor Core GEMM with one warpgroup
#     Arrive at the MMA barrier
#    Epilogue
#     Store fragmented registers to shared memory
#     Store shared memory to global memory
#
# ===----------------------------------------------------------------------===//


from mlir import ir
from mlir.dialects import gpu, scf, nvgpu, nvvm, arith, memref
from mlir.extras import types as T
from tools.nvdsl import *
import numpy as np


def partition_shape():
    """
    Compute this thread block's output-tile origin from the block IDs.

    The two outermost loops are parallelized across thread blocks:
      for ti in range(M//128):     # -> blockIdx.x
       for tj in range(N//128):    # -> blockIdx.y
        D = 0
        for tk in range(K//64):
         for i in range(128):
          for j in range(128):
           for k in range(64):
            FMA

    Returns:
        dimX: Offset of the tile along the M dimension (blockIdx.x * TILE_M).
        dimY: Offset of the tile along the N dimension (blockIdx.y * TILE_N).
    """
    bidx = gpu.block_id(gpu.Dimension.x)
    bidy = gpu.block_id(gpu.Dimension.y)
    dimX = bidx * TILE_M
    dimY = bidy * TILE_N
    return dimX, dimY
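

# For illustration only (not used by the kernel): with the sizes defined at the
# bottom of this file (M=512, N=256, TILE_M=TILE_N=128), the grid is 4x2 thread
# blocks, and block (bidx, bidy) owns the output tile that a host-side NumPy
# version would write as
#
#   d[bidx * TILE_M : (bidx + 1) * TILE_M,
#     bidy * TILE_N : (bidy + 1) * TILE_N]
#
# partition_shape() returns exactly that tile's origin (dimX, dimY).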


def tma_load(
    mbar_group: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    slot,
    stage,
    num_stages,
    p=None,
):
    """
    TMA loads two input matrices from global memory to shared memory. It
    performs the following operations:

       - tma.load a_shared_memory[off_a]  at coordinate [tk * 64, dimX]      (loads 128x64)
       - tma.load b_shared_memory[off_b]  at coordinate [dimY, tk * 64]      (loads 64x64)
       - tma.load b_shared_memory[off_b2] at coordinate [dimY + 64, tk * 64] (loads 64x64)

       mbarrier.arrive with transaction count
       ta_count = sizeof(128x64 A tile) + 2 * sizeof(64x64 B tile) in bytes
    """
    dimX, dimY = partition_shape()

    tidx = gpu.thread_id(gpu.Dimension.x)
    begin_b = num_stages * get_type_size(a_tma.tma_memref)
    size_tma_a = get_type_size(a_tma.tma_memref)
    size_tma_b = get_type_size(b_tma.tma_memref)
    ta_count = size_tma_a + (size_tma_b * 2)

    off_a = slot * size_tma_a
    off_b = (slot * size_tma_a) + begin_b
    off_b2 = off_b + size_tma_b
    a_elem_ty = a_tma.tma_memref.element_type
    b_elem_ty = b_tma.tma_memref.element_type
    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)

    # Arm the barrier with the number of bytes this stage expects
    mbar_group[slot].arrive(ta_count, predicate=p)
    # Only the primary thread of the warpgroup issues the TMA loads
    p = (tidx % WARP_GROUP_SIZE) == 0
    c1 = stage * 64
    a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
    b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
    b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
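

# Shared memory layout assumed by the offsets in tma_load() (a sketch; the byte
# sizes follow from the 128x64 A tile, the two 64x64 B tiles, and f16 elements):
#
#   | A slot 0 | ... | A slot S-1 | B slot 0 | ... | B slot S-1 |
#    128x64xf16 = 16 KB each        2 x 64x64xf16 = 16 KB each
#
# where S = num_stages. `begin_b` is the byte offset at which the B slots start,
# and `ta_count` = sizeof(A tile) + 2 * sizeof(B half-tile) is the number of
# bytes one stage expects, used as the mbarrier's TMA transaction count.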


def initialize(a_tma: TMA, b_tma: TMA, num_stages):
    """
    Initialize the TMA and MMA mbarrier groups (one barrier per pipeline stage)
    and prefetch the TMA descriptors. Only thread 0 performs the initialization.
    """
    tidx = gpu.thread_id(gpu.Dimension.x)
    mbar_group_tma = Mbarriers(number_of_barriers=num_stages)
    mbar_group_mma = Mbarriers(number_of_barriers=num_stages)
    isThread0 = tidx == const(0)
    with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
        for i in scf.for_(0, num_stages, 1):
            mbar_group_tma[i].init(1)
            mbar_group_mma[i].init(1)
            scf.yield_([])
        a_tma.prefetch()
        b_tma.prefetch()
        scf.yield_([])

    return mbar_group_tma, mbar_group_mma


def switch_phase(stage, phase, num_stages):
    """
    Flip the mbarrier parity bit after the last stage, i.e. every time the
    circular buffer wraps around.
    """
    p = stage == (num_stages - 1)
    phase = arith.select(
        p,
        (phase ^ const(True, ty=T.bool())),
        phase,
    )
    return phase


def producer_loop(
    mbar_tma: Mbarriers,
    mbar_mma: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    wg_me: Warpgroup,
    num_stages,
):
    """
    Producer warpgroup: for every K-tile, wait until the consumer has released
    the buffer slot, then TMA-load the next A/B tiles into it.
    """
    phase = const(True, ty=T.bool())

    for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]):
        stage = iv % num_stages
        # Wait for the consumer to finish the MMA that used this buffer slot
        mbar_mma[stage].try_wait(phase)
        # New phase for the mbarrier when the circular buffer wraps
        phase = switch_phase(stage, phase, num_stages)
        # TMA load into the slot
        tma_load(mbar_tma, a_tma, b_tma, stage, iv, num_stages, wg_me.is_wg_primary)
        scf.yield_([phase])
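

# The producer and consumer loops form an mbarrier handshake, sketched here
# (the exact parity rules live in the try_wait/arrive helpers of tools.nvdsl):
#
#   producer, stage s:  wait   mbar_mma[s]        # consumer released slot s
#                       TMA-load slot s, arm mbar_tma[s] with ta_count bytes
#   consumer, stage s:  wait   mbar_tma[s]        # data for slot s has arrived
#                       issue WGMMA on slot s
#                       arrive mbar_mma[s - 1]    # only when iv > 0: release the
#                                                 # previous slot, keeping one
#                                                 # stage in flight for the
#                                                 # asynchronous WGMMA
#
# The differing initial phases (True for the producer, False for the consumer)
# let the producer's first pass over the empty slots proceed without blocking,
# while the consumer blocks until the first TMA loads land.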


def consumer_loop(
    mbar_tma: Mbarriers,
    mbar_mma: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    wg_me: Warpgroup,
    num_stages,
):
    """
    Consumer warpgroup: for every K-tile, wait for the TMA data, accumulate
    into D with warpgroup MMA, and release the previous slot to the producer.
    """
    begin_b = num_stages * get_type_size(a_tma.tma_memref)

    size_a = TILE_M * TILE_K * get_type_size(T.f16())

    phase = const(False, ty=T.bool())
    A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
    B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
    D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())

    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
    with ir.InsertionPoint(for_op.body):
        phase = for_op.inner_iter_args[1]
        iv = for_op.induction_variable
        stage = iv % num_stages

        # Wait until the TMA data for the current stage has arrived
        mbar_tma[stage].try_wait(phase)

        # Find the shared memory slot for this stage
        offset_a = stage * size_a
        offset_b = offset_a + begin_b
        a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
        b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)

        # Point the WGMMA descriptors at this stage's tiles, carry the accumulator
        A.update_smem(a_smem)
        B.update_smem(b_smem)
        D.update_accumulator(for_op.inner_iter_args[0])

        # Matrix multiply
        D += A @ B

        # MMA barrier arrive: release the previous stage's buffers to the producer
        p_arrive = (iv > 0) & wg_me.is_wg_primary
        with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
            barId = arith.select((stage == 0), const(num_stages - 1), (stage - 1))
            mbar_mma[barId].arrive()
            scf.yield_([])

        phase = switch_phase(stage, phase, num_stages)
        scf.yield_([D.acc_op, phase])

    nvvm.WgmmaWaitGroupSyncOp(0)
    D.update_accumulator(for_op.results[0])
    return D


def epilogue(D: WGMMAMatrix, d_dev):
    """
    Epilogue of the GEMM kernel. It stores the fragmented accumulator registers
    to global memory, staging them through shared memory:

    MatrixAccumulator D          # Fragmented results in registers
    D -> d_smem                  # Store registers to shared memory
    d_smem -> d_dev[dimX][dimY]  # Store the shared memory tile to global memory
    """
    tidx = gpu.thread_id(gpu.Dimension.x)
    dimX, dimY = partition_shape()
    # s = tidx - WARP_GROUP_SIZE
    # debug_print("[Epilogue] store to global memory @ s={}", s)

    d_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
    d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])

    # Store (registers -> shared memory)
    D.store_accumulator(d_smem)
    gpu.barrier()

    # Store (shared memory -> global memory); each of the 128 consumer threads
    # copies one column of the 128x128 tile.
    for i in scf.for_(0, TILE_M, 1):
        val = memref.load(d_smem, [i, tidx])
        memref.store(val, d_gmem, [i, tidx])
        scf.yield_([])
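

# Host-side shared memory budget (a sketch of the arithmetic done below in
# gemm_warp_specialized): one pipeline stage holds a 128x64xf16 A tile plus a
# 64x128xf16 B tile, i.e. 16 KB + 16 KB = 32 KB, so num_stages=7 requests
# 7 * 32 KB = 224 KB of dynamic shared memory for the launch. A budget of this
# size needs a Hopper-class GPU (sm_90a), which the TMA and warpgroup MMA
# features require anyway.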


@NVDSL.mlir_func
def gemm_warp_specialized(a, b, d, num_stages):
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
    t7 = gpu.wait(token_ty, [t6])

    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
    a_tma = TMA([128, 64], a.type, swizzle=sw)
    b_tma = TMA([64, 64], b.type, swizzle=sw)
    a_tma.create_descriptor(a_dev)
    b_tma.create_descriptor(b_dev)

    grid = [(M // TILE_M), (N // TILE_N), 1]
    block = [256, 1, 1]

    size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
    size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
    smem_size_in_bytes = (size_a + size_b) * num_stages

    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
    def gemm_warp_specialized_kernel():
        # Init warpgroups
        wg_producer = Warpgroup(primary_thread=128, register_size=40)
        wg_consumer = Warpgroup(primary_thread=0, register_size=232)

        # Initialize mbarriers and prefetch TMA descriptors
        mbar_tma, mbar_mma = initialize(a_tma, b_tma, num_stages)

        # Producer performs TMA
        with wg_producer:
            producer_loop(mbar_tma, mbar_mma, a_tma, b_tma, wg_producer, num_stages)

        # Consumer performs MMA/Tensor Core
        with wg_consumer:
            D = consumer_loop(mbar_tma, mbar_mma, a_tma, b_tma, wg_consumer, num_stages)
            epilogue(D, d_dev)

    gemm_warp_specialized_kernel()

    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
    gpu.wait(None, [t8])


# Pass the GEMM problem and tile sizes from Python to MLIR
N = 256
M = 512
K = 1024
TILE_M = 128
TILE_N = 128
TILE_K = 64
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)

gemm_warp_specialized(a, b, d, num_stages=7)


# Verify the MLIR result against a NumPy reference computation
ref_d = a.astype(np.float16) @ b.astype(np.float16)
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)


print("PASS")
# CHECK-NOT: Mismatched elements