# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
#  Chapter 3 : GEMM 128x128x64 with Tensor Core
# ===----------------------------------------------------------------------===//
#
# This program performs a 128x128x64 GEMM on Tensor Cores, with f16 inputs
# and f32 accumulation.
#
# This chapter demonstrates how to:
# 1. Execute TMA loads for the two input matrices
# 2. Perform the 128x128x64 Tensor Core GEMM with a warpgroup
# 3. Store the fragmented accumulator registers to global memory with the warpgroup
#
# ===----------------------------------------------------------------------===//
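
# Dynamic shared-memory layout (derived from the tile sizes used below):
#   A tile : one 128x64xf16  = 16384 bytes at byte offset 0
#   B tile : two 64x64xf16   = 16384 bytes starting at byte offset 16384
#   Total  : 32768 bytes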


from mlir import ir
from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
from tools.nvdsl import *
from mlir.extras import types as T
import numpy as np


def tma_load(
    mbar_group: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    p,
):
    """
    TMA loads the two input matrices from global memory into shared memory:

       - tma.load a_shared_memory[0] at coordinate [0, 0]  (loads 128x64)
       - tma.load b_shared_memory[0] at coordinate [0, 0]  (loads 64x64)
       - tma.load b_shared_memory[1] at coordinate [64, 0] (loads 64x64)

    then arrives at the mbarrier with the expected transaction count:

       mbarrier.arrive ta_count = sizeof(128x64xf16) + sizeof(64x128xf16)
    """

    # Total bytes the three TMA loads will deposit: one A tile plus two B tiles.
    size_tma_a = get_type_size(a_tma.tma_memref)
    size_tma_b = get_type_size(b_tma.tma_memref)
    ta_count = size_tma_a + (size_tma_b * 2)
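    # For this chapter's shapes: 128*64*2 + 2*(64*64*2) = 16384 + 16384 = 32768 bytes.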

    # Shared-memory views: B follows A, and the second B tile follows the first.
    off_b = size_tma_a
    off_b2 = off_b + size_tma_b
    a_elem_ty = a_tma.tma_memref.element_type
    b_elem_ty = b_tma.tma_memref.element_type
    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty)
    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)
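    # For f16 tiles this places a at byte 0, b1 at byte 16384, and b2 at byte 24576.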

    # A single thread (predicate `p`) arrives with the expected byte count and
    # issues the three TMA loads; the same mbarrier tracks their completion.
    mbar_group[0].arrive(ta_count, predicate=p)

    a_tma.load(a, mbar_group[0], coords=[0, 0], predicate=p)
    b_tma.load(b1, mbar_group[0], coords=[0, 0], predicate=p)
    b_tma.load(b2, mbar_group[0], coords=[64, 0], predicate=p)


@NVDSL.mlir_func
def gemm_128_128_64(a, b, d):
    # Allocate device buffers and copy the host inputs, chaining async tokens.
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
    t7 = gpu.wait(token_ty, [t6])

    # Create TMA descriptors: A is loaded as one 128x64 tile and B as two
    # 64x64 tiles, both using 128-byte swizzling.
    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
    a_tma = TMA([128, 64], a.type, swizzle=sw)
    b_tma = TMA([64, 64], b.type, swizzle=sw)
    a_tma.create_descriptor(a_dev)
    b_tma.create_descriptor(b_dev)
    a_size = get_type_size(a.type)
    b_size = get_type_size(b.type)
    smem_size_in_bytes = a_size + b_size
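    # = 128*64*2 + 64*128*2 = 32768 bytes, matching the layout sketched above.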

    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=smem_size_in_bytes)
    def gemm_tma_kernel():
        tidx = gpu.thread_id(gpu.Dimension.x)

        # One mbarrier tracks completion of all three TMA loads. Only thread 0
        # initializes it and prefetches the TMA descriptors.
        mbar_group = Mbarriers(number_of_barriers=1)
        isThread0 = tidx == 0

        mbar_group[0].init(1, predicate=isThread0)
        a_tma.prefetch(predicate=isThread0)
        b_tma.prefetch(predicate=isThread0)

        a_smem = get_dynamic_shared_memory((M, K), T.f16())
        b_smem = get_dynamic_shared_memory((K, N), T.f16(), offset=a_size)

        # 1. TMA load of the two input matrices (issued by thread 0 only)
        tma_load(mbar_group, a_tma, b_tma, isThread0)

        # 2. All threads wait for TMA load completion
        mbar_group[0].try_wait()

        # 3. Perform the 128x128x64 Tensor Core GEMM with the warpgroup
        A = WGMMAMatrix(WGMMAType.Descriptor, [M, K], desc=a_tma, smem=a_smem)
        B = WGMMAMatrix(WGMMAType.Descriptor, [K, N], desc=b_tma, smem=b_smem)
        D = WGMMAMatrix(WGMMAType.Accumulator, shape=[M, N], ty=T.f32())
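
        # Note: under the NVDSL wrappers, `A @ B` describes the shared-memory
        # operands and `D += ...` emits the warpgroup-level MMA
        # (nvgpu.warpgroup.mma), accumulating into register fragments.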

        # Matrix Multiply
        D += A @ B

        # 4. Store the fragmented accumulator registers to global memory
        D.store_accumulator(d_dev)
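        # store_accumulator lowers to an nvgpu warpgroup store
        # (nvgpu.warpgroup.mma.store) that writes the 128x128xf32 result tile.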

    gemm_tma_kernel()

    # Copy the result back to the host and wait for completion.
    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
    gpu.wait(None, [t8])


# Pass host arguments (NumPy arrays) to the MLIR function
M = 128
N = 128
K = 64
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)
gemm_128_128_64(a, b, d)

# Verify the GPU result against a NumPy reference (a and b are already f16).
ref_d = a @ b
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
print("PASS")
# CHECK-NOT: Mismatched elements