# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
#  Chapter 3 : GEMM 128x128x64 with Tensor Core
# ===----------------------------------------------------------------------===//
#
# This program performs a 128x128x64 matrix multiplication on the Tensor Cores.
#
# This chapter demonstrates how to:
# 1. Execute TMA loads for the two input matrices
# 2. Perform a 128x128x64 Tensor Core GEMM with a warpgroup
# 3. Store the fragmented accumulator registers to global memory with the warpgroup
#
# ===----------------------------------------------------------------------===//


from mlir import ir
from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
from tools.nvdsl import *
from mlir.extras import types as T
import numpy as np


def tma_load(
    mbar_group: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    p,
):
    """
    TMA loads the two input matrices from global memory to shared memory. It performs the following operations:

       - tma.load a_shared_memory[0] at coordinate [0, 0]  (Loads 128x64)
       - tma.load b_shared_memory[0] at coordinate [0, 0]  (Loads 64x64)
       - tma.load b_shared_memory[0] at coordinate [64, 0] (Loads 64x64)

       mbarrier.arrive with ta_count = sizeof(128x64xf16) + sizeof(64x128xf16) bytes
    """

    size_tma_a = get_type_size(a_tma.tma_memref)
    size_tma_b = get_type_size(b_tma.tma_memref)
    ta_count = size_tma_a + (size_tma_b * 2)
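    # For the tile shapes used in this chapter (f16 elements) this works out to:
    #   size_tma_a = 128 * 64 * 2 bytes = 16384
    #   size_tma_b =  64 * 64 * 2 bytes =  8192
    #   ta_count   = 16384 + 2 * 8192   = 32768 bytes (one A tile plus both B halves)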

    off_b = size_tma_a
    off_b2 = off_b + size_tma_b
    a_elem_ty = a_tma.tma_memref.element_type
    b_elem_ty = b_tma.tma_memref.element_type
    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty)
    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)
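    # Resulting dynamic shared memory layout for the 128x64 / 64x64 f16 tiles:
    #   [    0, 16384)  A   (128x64xf16)
    #   [16384, 24576)  B.0 ( 64x64xf16)
    #   [24576, 32768)  B.1 ( 64x64xf16)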

    mbar_group[0].arrive(ta_count, predicate=p)

    a_tma.load(a, mbar_group[0], coords=[0, 0], predicate=p)
    b_tma.load(b1, mbar_group[0], coords=[0, 0], predicate=p)
    b_tma.load(b2, mbar_group[0], coords=[64, 0], predicate=p)


@NVDSL.mlir_func
def gemm_128_128_64(a, b, d):
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
    t7 = gpu.wait(token_ty, [t6])
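    # a_dev and b_dev now hold the input matrices on the device; each gpu op
    # above is ordered through its async token (t1..t7).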

    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
    a_tma = TMA([128, 64], a.type, swizzle=sw)
    b_tma = TMA([64, 64], b.type, swizzle=sw)
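    # With 128B swizzle a TMA box row cannot exceed 128 bytes (64 f16 elements),
    # which is why the 64x128 B matrix is described as a 64x64 box and loaded in
    # two halves along N inside tma_load().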
    a_tma.create_descriptor(a_dev)
    b_tma.create_descriptor(b_dev)
    a_size = get_type_size(a.type)
    b_size = get_type_size(b.type)
    smem_size_in_bytes = a_size + b_size
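    # = 128*64*2 + 64*128*2 = 32768 bytes (32 KiB) of dynamic shared memory for
    # the f16 inputs, requested at the kernel launch below.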

    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=smem_size_in_bytes)
    def gemm_tma_kernel():
        tidx = gpu.thread_id(gpu.Dimension.x)

        mbar_group = Mbarriers(number_of_barriers=1)
        isThread0 = tidx == 0

        mbar_group[0].init(1, predicate=isThread0)
        a_tma.prefetch(predicate=isThread0)
        b_tma.prefetch(predicate=isThread0)
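        # Only thread 0 initializes the mbarrier (arrive count 1) and prefetches
        # the TMA descriptors; all 128 threads wait on the same barrier below.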

        a_smem = get_dynamic_shared_memory((M, K), T.f16())
        b_smem = get_dynamic_shared_memory((K, N), T.f16(), offset=a_size)

        # 1. TMA load for the two input matrices
        tma_load(mbar_group, a_tma, b_tma, isThread0)

        # 2. All threads wait for TMA load completion
        mbar_group[0].try_wait()
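        # The barrier completes once all ta_count bytes issued by the three TMA
        # loads have arrived, so the A and B tiles are now resident in shared memory.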

        # 3. Perform the 128x128x64 Tensor Core GEMM with the warpgroup
        A = WGMMAMatrix(WGMMAType.Descriptor, [M, K], desc=a_tma, smem=a_smem)
        B = WGMMAMatrix(WGMMAType.Descriptor, [K, N], desc=b_tma, smem=b_smem)
        D = WGMMAMatrix(WGMMAType.Accumulator, shape=[M, N], ty=T.f32())

        # Matrix multiply
        D += A @ B
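        # D += A @ B lowers to warpgroup-level Tensor Core MMA (wgmma) ops: the
        # 128 threads of the warpgroup compute the 128x128 tile and keep the f32
        # accumulator fragments in their registers.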

        # 4. Store the fragmented accumulator registers to global memory with the warpgroup
        D.store_accumulator(d_dev)

    gemm_tma_kernel()

    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
    gpu.wait(None, [t8])


# Pass the Python arguments to the MLIR function
M = 128
N = 128
K = 64
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)
gemm_128_128_64(a, b, d)

ref_d = a.astype(np.float16) @ b.astype(np.float16)
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
print("PASS")
# CHECK-NOT: Mismatched elements