xref: /llvm-project/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py (revision d95e6d027486876559f1a2a96c33b8ad93cc0ae4)
1# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
2# RUN:   %PYTHON %s | FileCheck %s
3
4
5# ===--- GEMM Hopper Tensor Core Integration Test ---===
6#
7# This test aims to validate the correctness of the supported GEMM kernels in
8# NVGPU dialects, with current support for Multistage and Warp Specialization
9# kernels.
10# The test constructs and metaprograms IR using Python bindings, allowing
11# generic IR building. This flexibility enables changes to the shape,
12# tile size, or data type of the GEMM for testing purposes.
13# The entry function is `matmul`, where one can specify GEMM shape, tile size,
14# data type, GEMM algorithm (Multistage or Warp Specialization), and the maximum
15# number of stages.
16# Verification is done via numpy's matmul operation.
17#
18# Example:
19# matmul(input_type=np.float16,                # input types
20#        output_type=np.float32,               # output type
21#        M=4096, N=4096, K=4096,               # Shape
22#        BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
23#        use_warp_specialization=True,         # Enable Warp Specialization
24#        max_num_stages=3)                     # Number of stages in shared memory
25#
26# ===--- Parallelism Across CTAs  ---===
27#
28# GEMM includes three loops defining the shape of the GEMM, specified in the
29# `matmul` function.
30# The program builds IR using the following loop structure, tiling the loops
31# with the given tile size and parallelizing the two outermost loops into the
32# first and second dimensions of CTAs.
33#
# for(bi = 0; bi < M; bi += BLOCK_M)          # parallelize across blockIdx.x
#     for(bj = 0; bj < N; bj += BLOCK_N)      # parallelize across blockIdx.y
#         for(bk = 0; bk < K; bk += BLOCK_K)
37#             for(i = bi; i < (bi + BLOCK_M); ++i)
38#                 for(j = bj; j < (bj + BLOCK_N); ++j)
39#                     for(k = bk; k < (bk + BLOCK_K); ++k)
40#
41# ===--- Multistage Kernel ---===
42#
43# This kernel launches a single warp group (128 threads). The primary thread
44# (pthread) requests load from TMA. Threads collectively wait for the data and
# perform mma operations. After the main loop finishes, the threads together
# first store the fragmented registers to shared memory, then copy from shared
# memory to global memory; this part is called the epilogue.
48#
49# Execution Timeline of Multistage Kernel with 3 stages:
50# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
51# |       |Prologue ---->   |MainLoop ---->                                                                  |Epilogue  |
52# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
53# |pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
54# |wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
55# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
56#
57# ===--- Warp Specialization Kernel  ---===
58#
59# This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
60# as `producer warp group` and another as `consumer warp group`. The
61# `producer warp group` is responsible for requesting TMA load, while the
62# `consumer warp group` performs the mma operation. The epilogue section is
63# handled by the `consumer warp group` as its threads own the fragmented registers.
64#
65# Execution Timeline of Warp Specialization Kernel with 2 stages:
66# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
67# |        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
68# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
69# |pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
70# |wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
71# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
72# |wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global]|
73# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
74
75import errno
76import numpy as np
77import subprocess
78import ctypes
79from tools import nvgpucompiler
80from tools import matmulBuilder
81import contextlib
82import os
83import sys
84import pathlib
85import ctypes
86from mlir import runtime as rt
87
88
def generate_matmul(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    use_warp_specialization=True,
    saveIR=False,
    max_num_stages=3,
    options="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3",
):
    """Build the GEMM IR through `matmulBuilder`, then compile and JIT it.

    Args:
        input_type: Element type of the inputs, forwarded to the builder.
        output_type: Element type of the output, forwarded to the builder.
        M, N, K: GEMM shape, forwarded to the builder.
        BLOCK_M, BLOCK_N, BLOCK_K: Tile size, forwarded to the builder.
        use_warp_specialization: If true, build with
            `matmulBuilder.generate_matmul_ws`; otherwise with
            `matmulBuilder.generate_matmul_multistage`.
        saveIR: If true, write the generated module to `gemm.mlir` in the
            current working directory.
        max_num_stages: Stage count forwarded to the builder.
        options: Option string handed to `nvgpucompiler.NvgpuCompiler`.

    Returns:
        The engine returned by `NvgpuCompiler.compile_and_jit`.

    Raises:
        FileNotFoundError: If the `SUPPORT_LIB` environment variable is unset
            or empty, or if it names a path that does not exist.
    """
    with matmulBuilder.ir.Context(), matmulBuilder.ir.Location.unknown():
        # Both builders take the same positional arguments; only the kernel
        # flavour differs.
        if use_warp_specialization:
            build_module = matmulBuilder.generate_matmul_ws
        else:
            build_module = matmulBuilder.generate_matmul_multistage
        mlir_nvgpu_module = build_module(
            input_type,
            output_type,
            M,
            N,
            K,
            BLOCK_M,
            BLOCK_N,
            BLOCK_K,
            max_num_stages,
        )

        mlir_nvgpu_module.operation.verify()

        # Save generated IR. Printing straight into the file avoids swapping
        # sys.stdout, which would stay redirected if printing raised.
        if saveIR:
            with open("gemm.mlir", "w") as f:
                print(mlir_nvgpu_module, file=f)

        # Get compiler
        support_lib = os.getenv("SUPPORT_LIB")
        if not support_lib:
            # os.path.exists(None) would raise an opaque TypeError; report the
            # missing configuration explicitly instead.
            raise FileNotFoundError(
                errno.ENOENT, "SUPPORT_LIB environment variable is not set"
            )
        if not os.path.exists(support_lib):
            raise FileNotFoundError(
                errno.ENOENT, os.strerror(errno.ENOENT), support_lib
            )
        compiler = nvgpucompiler.NvgpuCompiler(
            options, opt_level=3, shared_libs=[support_lib]
        )

        # Compile
        return compiler.compile_and_jit(mlir_nvgpu_module)
153
154
def matmul(
    input_type=np.float16,
    output_type=np.float32,
    M=128,
    N=128,
    K=128,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    use_warp_specialization=True,
    saveIR=False,
    max_num_stages=3,
    print_results=False,
    no_verify=False,
):
    """Compile, launch and (optionally) verify one GEMM configuration.

    Prints a banner describing the configuration, builds and JIT-compiles the
    kernel via `generate_matmul`, invokes it on random inputs, and prints
    ``PASS `` once it completes.

    Args:
        input_type: numpy dtype used for the `a` and `b` operands.
        output_type: numpy dtype used for the `c` result.
        M, N, K: GEMM shape; `a` is MxK, `b` is KxN, `c` is MxN.
        BLOCK_M, BLOCK_N, BLOCK_K: Tile size.
        use_warp_specialization: Selects the warp-specialization kernel when
            true, the multistage kernel otherwise.
        saveIR: Forwarded to `generate_matmul`.
        max_num_stages: Upper bound on the stage count actually used.
        print_results: If true, print `c` (and the reference when verifying).
        no_verify: If true, skip the comparison against numpy's matmul.
    """
    # The stage count is the ratio of the operand element count to the
    # per-tile element count, capped by the caller-provided maximum.
    operand_elems = M * K + K * N
    tile_elems = BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N
    num_stages = min(operand_elems // tile_elems, max_num_stages)

    # Print the configuration
    if input_type == np.float16:
        ity = "f16"
    else:
        ity = "f32"
    if output_type == np.float16:
        oty = "f16"
    else:
        oty = "f32"
    if use_warp_specialization:
        gemmty = "Warp specialization"
    else:
        gemmty = "Multistage"
    print(
        f"===-- Running GEMM {gemmty} {oty} += {ity} * {ity}"
        f", Size {M}x{N}x{K}"
        f", Tile {BLOCK_M}x{BLOCK_N}x{BLOCK_K}"
        f", stages {num_stages} --==="
    )

    # Build IR and compile
    engine = generate_matmul(
        input_type=input_type,
        output_type=output_type,
        M=M,
        N=N,
        K=K,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        use_warp_specialization=use_warp_specialization,
        saveIR=saveIR,
        max_num_stages=num_stages,
    )

    def _memref_arg(array):
        # Wrap the ranked memref descriptor in a pointer-to-pointer, the form
        # handed to engine.invoke below.
        descriptor = rt.get_ranked_memref_descriptor(array)
        return ctypes.pointer(ctypes.pointer(descriptor))

    # Allocate matrices and invoke the matmul
    c = np.zeros((M, N), output_type)
    a = np.random.randn(M, K).astype(input_type)
    b = np.random.randn(K, N).astype(input_type)
    mem_a = _memref_arg(a)
    mem_b = _memref_arg(b)
    mem_c = _memref_arg(c)
    kernelName = matmulBuilder.make_kernel_name(
        input_type,
        output_type,
        M,
        N,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        num_stages,
        use_warp_specialization,
    )

    # Launch the MLIR generated kernel
    engine.invoke(kernelName, mem_a, mem_b, mem_c)

    # Print floats with two decimals from here on.
    np.set_printoptions(formatter={"float_kind": "{:.2f}".format})

    if print_results:
        print(c)

    # Verify the results against numpy's matmul on the same operands
    if not no_verify:
        lhs = a.astype(input_type)
        rhs = b.astype(input_type)
        ref = lhs @ rhs
        if print_results:
            print(ref)
        np.testing.assert_allclose(c, ref, rtol=5e-03, atol=1e-01)

    print("PASS ")
254
255
256# Takes longer time to run
def test_long():
    """Sweep stage counts 1..6 over the large GEMM shapes with both kernels.

    For every (stages, M, N, K) combination the multistage kernel runs first
    with `no_verify=True`, followed by the warp-specialization kernel with
    verification left enabled.
    """
    mn_sizes = [128, 512, 1024, 4096, 8192]
    k_sizes = [64, 128, 512, 1024, 4096, 8192]
    # (use_warp_specialization, no_verify) for each call, in execution order.
    variants = ((False, True), (True, False))

    for stages in range(1, 7):
        for M in mn_sizes:
            for N in mn_sizes:
                for K in k_sizes:
                    for warp_specialized, skip_verify in variants:
                        matmul(
                            np.float16,
                            np.float32,
                            M,
                            N,
                            K,
                            max_num_stages=stages,
                            use_warp_specialization=warp_specialized,
                            no_verify=skip_verify,
                        )
281
282
def test_short():
    """Run the small GEMM sweep whose output the CHECK lines match.

    For every (stages, M, N, K) combination the multistage kernel runs first,
    then the warp-specialization kernel; both are verified.
    """
    stage_counts = (1, 3)
    m_sizes = (128, 512)
    n_sizes = (128,)
    k_sizes = (64, 256)

    for stages in stage_counts:
        for M in m_sizes:
            for N in n_sizes:
                for K in k_sizes:
                    # Multistage first, then warp specialization.
                    for warp_specialized in (False, True):
                        matmul(
                            np.float16,
                            np.float32,
                            M,
                            N,
                            K,
                            max_num_stages=stages,
                            use_warp_specialization=warp_specialized,
                        )
306
307
308# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
309# CHECK: PASS
310# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
311# CHECK: PASS
312# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
313# CHECK: PASS
314# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
315# CHECK: PASS
316# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
317# CHECK: PASS
318# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
319# CHECK: PASS
320# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
321# CHECK: PASS
322# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
323# CHECK: PASS
324# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
325# CHECK: PASS
326# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
327# CHECK: PASS
328# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
329# CHECK: PASS
330# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
331# CHECK: PASS
332# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
333# CHECK: PASS
334# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
335# CHECK: PASS
336# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
337# CHECK: PASS
338# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
339# CHECK: PASS
340
# Run the short sweep when the file is executed; per the RUN lines at the top
# of the file, its stdout is piped to FileCheck and matched against the CHECK
# lines above.
test_short()
342