# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
#  Chapter 4 : Multistage GEMM with Tensor Core
# ===----------------------------------------------------------------------===//
#
# This program demonstrates a GEMM operation for `f32 += f16 * f16`, using the
# multistage method with a tile size of 128x128x64. The code fully
# parallelizes the two outermost loops into thread blocks. It launches one
# warp group (128 threads in total) and allocates multiple shared memory
# slots, one per stage. The program consists of three main parts: prologue,
# main loop, and epilogue. In the prologue, thread 0 issues TMA loads into the
# shared memory slots. The main loop executes MMA while simultaneously issuing
# TMA loads for the slots that have been consumed. Overlapping TMA and MMA in
# this way improves performance by keeping the Tensor Cores busy.
#
# Loops illustration:
#
#  for s in range(num_stages):
#    TMA_128x64_64x128...
#  for ti in range(M//128):  # -> blockIdx.x
#   for tj in range(N//128): # -> blockIdx.y
#    for tk in range(K//64):
#      MMA_128x128x64...
#      TMA_128x64_64x128...
#  Epilogue...
#
# This chapter demonstrates:
#  1. Partitioning the shape based on block IDs
#  2. Prologue
#    2.1 Execute TMA loads of the two input matrices for each stage
#  3. Main loop
#    3.1 Wait for completion of the TMA load with an mbarrier
#    3.2 Perform a Tensor Core GEMM 64x128x64 per warp group
#    3.3 Load the next stage if needed
#  4. Epilogue
#    4.1 Store fragmented registers to shared memory
#    4.2 Store shared memory to global memory
#
# ===----------------------------------------------------------------------===//


from mlir import ir
from mlir.dialects import gpu, scf, nvgpu, nvvm
from mlir.extras import types as T
from tools.nvdsl import *
import numpy as np
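

# A plain-Python sketch of the multistage schedule described in the header
# (hypothetical helper, illustration only; it is not used by the kernel below
# and assumes num_stages > 1). The prologue fills num_stages - 1 slots of the
# circular buffer; the main loop consumes a slot and refills it with a future
# stage.
def _multistage_schedule_sketch(k_tiles, num_stages):
    events = []
    # Prologue: fill all but one slot.
    for stage in range(num_stages - 1):
        events.append(("tma_load", stage % num_stages, stage))
    # Main loop: wait on the current slot, compute, then prefetch.
    for k in range(k_tiles):
        slot = k % num_stages
        events.append(("wait+mma", slot, k))
        next_stage = k + num_stages - 1
        if next_stage < k_tiles:
            events.append(("tma_load", next_stage % num_stages, next_stage))
    return events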


def partition_shape():
    """
    Calculate the tile offsets for this thread block based on the block IDs.

    It partitions the iteration space as follows:
    for(.. i < M ...)   --> blockIdx.x
     for(.. j < N ...)  --> blockIdx.y
      for(.. k < K ...)

    Returns:
        dimX: Offset of this block's tile along the x-axis.
        dimY: Offset of this block's tile along the y-axis.
    """
    bidx = gpu.block_id(gpu.Dimension.x)
    bidy = gpu.block_id(gpu.Dimension.y)
    dimX = bidx * TILE_M
    dimY = bidy * TILE_N
    return dimX, dimY
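
# For example, with M=512, N=256 and TILE_M=TILE_N=128 (set at the bottom of
# this file), the kernel launches on a 4x2 grid; block (1, 0) gets dimX=128,
# dimY=0 and computes the output tile d[128:256, 0:128].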


def tma_load(
    mbar_group: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    slot,
    stage,
    num_stages,
    p=None,
):
    """
    TMA loads the two input matrices from global memory to shared memory. It performs the following operations:

       - tma.load a_shared_memory[off_a]  at coordinate [c1, dimX]      (loads 128x64)
       - tma.load b_shared_memory[off_b]  at coordinate [dimY, c1]      (loads 64x64)
       - tma.load b_shared_memory[off_b2] at coordinate [dimY + 64, c1] (loads 64x64)

       mbarrier.arrive with ta_count = size(A tile) + 2 * size(B tile) bytes
    """
    dimX, dimY = partition_shape()

    tidx = gpu.thread_id(gpu.Dimension.x)
    begin_b = num_stages * get_type_size(a_tma.tma_memref)
    size_tma_a = get_type_size(a_tma.tma_memref)
    size_tma_b = get_type_size(b_tma.tma_memref)
    ta_count = size_tma_a + (size_tma_b * 2)

    # Only thread 0 issues the TMA loads unless a predicate is provided.
    p = tidx == 0 if p is None else p

    off_a = slot * size_tma_a
    off_b = (slot * size_tma_a) + begin_b
    off_b2 = off_b + size_tma_b
    a_elem_ty = a_tma.tma_memref.element_type
    b_elem_ty = b_tma.tma_memref.element_type
    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)

    # Expect ta_count bytes to arrive at this barrier before it completes.
    mbar_group[slot].arrive(ta_count, predicate=p)

    c1 = stage * 64
    a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
    b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
    b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
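
# With the sizes used in this file (f16 inputs, a 128x64 A tile and two 64x64
# B tiles per stage): size_tma_a = 128*64*2 = 16384 bytes and size_tma_b =
# 64*64*2 = 8192 bytes, so ta_count = 32768 bytes are expected per stage.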


def initialize(a_tma: TMA, b_tma: TMA, num_stages):
    """
    Initialize mbarriers and prefetch TMA descriptors.
    """
    tidx = gpu.thread_id(gpu.Dimension.x)
    mbar_group = Mbarriers(number_of_barriers=num_stages)
    isThread0 = tidx == const(0)
    with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
        for i in scf.for_(0, num_stages, 1):
            mbar_group[i].init(1)
            scf.yield_([])
        a_tma.prefetch()
        b_tma.prefetch()
        scf.yield_([])

    return mbar_group
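
# Each mbarrier is initialized with an arrival count of 1: only thread 0
# arrives on it (see tma_load), while completion of the TMA loads themselves
# is tracked through the expected byte count (ta_count).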


def prologue(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
    """
    Prologue of the GEMM kernel. It loads the two input matrices for each stage in a loop, as below:

    for stage in range(NUM_STAGES):
        tma_load x, y, stage

    """
    # Fill all but the last slot; the main loop refills slots as they free up.
    ns = num_stages if num_stages == 1 else num_stages - 1
    for iv in scf.for_(0, ns, 1):
        tma_load(mbar_group, a_tma, b_tma, iv, iv, num_stages)
        scf.yield_([])


def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
    """
    Main loop of the multistage GEMM kernel. It iterates through the
    stages and performs matrix multiplication while loading data into
    shared memory via TMA. It works as follows:

    MatrixAccumulator D
    for k in range(K // TILE_K):

        try_wait(stage, ...)    # Wait for the TMA load

        Matrix A(stage, ...)    # Find the shared memory slot
        Matrix B(stage, ...)    # Find the shared memory slot
        D += A @ B              # Multiply and accumulate

        if(needLoad)            # Load the next stage if needed
            tma_load(x, y, nextSlot, nextStage)

    """
    ns = num_stages if num_stages == 1 else num_stages - 1

    tidx = gpu.thread_id(gpu.Dimension.x)
    begin_b = num_stages * get_type_size(a_tma.tma_memref)

    size_a = TILE_M * TILE_K * get_type_size(T.f16())

    # Initialize A and B (input matrices) and D (accumulator)
    A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
    B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
    D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())

    phase = const(False, ty=T.bool())

    # Main loop
    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
    with ir.InsertionPoint(for_op.body):
        phase = for_op.inner_iter_args[1]
        iv = for_op.induction_variable
        stage = iv % num_stages

        # Wait for the TMA load of the current stage
        mbar_group[stage].try_wait(phase=phase)

        # Find the shared memory slot
        offset_a = stage * size_a
        offset_b = offset_a + begin_b
        a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
        b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)

        # Point the input matrices at the current slot, update the accumulator
        A.update_smem(a_smem)
        B.update_smem(b_smem)
        D.update_accumulator(for_op.inner_iter_args[0])

        # Matrix multiply
        D += A @ B

        # With a single stage, wait for the Tensor Core before the slot is reused
        if num_stages == 1:
            nvvm.WgmmaWaitGroupSyncOp(0)

        # Load the next stage
        pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
        nextStage = iv + ns
        nextSlot = nextStage % num_stages
        tma_load(mbar_group, a_tma, b_tma, nextSlot, nextStage, num_stages, pred)

        # Switch phase parity for the mbarrier once the last slot is consumed
        newPhase = arith.select(
            stage == (num_stages - 1),
            (phase ^ const(True, ty=T.bool())),
            phase,
        )
        scf.yield_([D.acc_op, newPhase])

    nvvm.WgmmaWaitGroupSyncOp(0)

    D.update_accumulator(for_op.results[0])
    return D
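
# The phase bit flips each time the last slot (stage == num_stages - 1) has
# been consumed, i.e. once per full sweep over the circular buffer, matching
# the phase-parity protocol that mbarrier try_wait expects.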


def epilogue(D: WGMMAMatrix, d_dev):
    """
    Epilogue of the GEMM kernel. It stores the fragmented accumulator registers to global memory.

    MatrixAccumulator D               # Fragmented results
    store D -> Shared Memory          # Store registers to shared memory
    Shared Memory -> Z[dimX][dimY]    # Store shared memory to global memory

    """
    tidx = gpu.thread_id(gpu.Dimension.x)
    dimX, dimY = partition_shape()

    d_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
    d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])

    # Store (registers -> shared memory)
    D.store_accumulator(d_smem)
    gpu.barrier()

    # Store (shared memory -> global memory)
    for i in scf.for_(0, TILE_M, 1):
        val = memref.load(d_smem, [i, tidx])
        memref.store(val, d_gmem, [i, tidx])
        scf.yield_([])
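
# With block = [128, 1, 1] and TILE_N = 128, each thread owns one column of
# the 128x128 output tile and copies it row by row (TILE_M scalar load/store
# pairs per thread).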


# The decorator generates
#   a -> memref<MxKxf16>
#   b -> memref<KxNxf16>
#   d -> memref<MxNxf32>
@NVDSL.mlir_func
def gemm_multistage(a, b, d, num_stages):
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
    t7 = gpu.wait(token_ty, [t6])

    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
    a_tma = TMA([128, 64], a.type, swizzle=sw)
    b_tma = TMA([64, 64], b.type, swizzle=sw)
    a_tma.create_descriptor(a_dev)
    b_tma.create_descriptor(b_dev)
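
    # SWIZZLE_128B makes TMA shuffle each tile as it is written into shared
    # memory so that the layout matches what the WGMMA descriptors expect and
    # shared memory bank conflicts are avoided.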

    grid = [(M // TILE_M), (N // TILE_N), 1]
    block = [128, 1, 1]

    size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
    size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
    smem_size_in_bytes = (size_a + size_b) * num_stages
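    # With f16 inputs and 128x128x64 tiles: size_a = size_b = 2*128*64 = 16 KiB,
    # i.e. 32 KiB per stage; num_stages=7 (used below) requests 224 KiB of
    # dynamic shared memory, which fits within Hopper's per-SM budget.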

    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
    def gemm_multistage_kernel():
        # Initialize mbarriers and prefetch TMA descriptors
        mbar_group = initialize(a_tma, b_tma, num_stages)

        # Fill the pipeline stages
        prologue(mbar_group, a_tma, b_tma, num_stages)

        # Main loop
        D = mainloop(mbar_group, a_tma, b_tma, num_stages)

        # Store registers to global memory
        epilogue(D, d_dev)

    gemm_multistage_kernel()

    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
    gpu.wait(None, [t8])


# Pass arguments from Python to MLIR
N = 256
M = 512
K = 1024
TILE_M = 128
TILE_N = 128
TILE_K = 64
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)

gemm_multistage(a, b, d, num_stages=7)


# Verify the result against a NumPy reference computation
ref_d = a.astype(np.float16) @ b.astype(np.float16)
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
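# Note: the kernel accumulates in f32 while this reference matmul runs in f16,
# hence the relatively loose rtol/atol above.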


print("PASS")
# CHECK-NOT: Mismatched elements