# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN: %PYTHON %s | FileCheck %s


# ===--- GEMM Hopper Tensor Core Integration Test ---===
#
# This test aims to validate the correctness of the supported GEMM kernels in
# NVGPU dialects, with current support for Multistage and Warp Specialization
# kernels.
# The test constructs and metaprograms IR using Python bindings, allowing
# generic IR building. This flexibility enables changes to the shape,
# tile size, or data type of the GEMM for testing purposes.
# The entry function is `matmul`, where one can specify GEMM shape, tile size,
# data type, GEMM algorithm (Multistage or Warp Specialization), and the maximum
# number of stages.
# Verification is done via numpy's matmul operation.
#
# Example:
# matmul(input_type=np.float16,                # input types
#        output_type=np.float32,               # output type
#        M=4096, N=4096, K=4096,               # Shape
#        BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
#        use_warp_specialization=True,         # Enable Warp Specialization
#        max_num_stages=3)                     # Number of stages in shared memory
#
# ===--- Parallelism Across CTAs ---===
#
# GEMM includes three loops defining the shape of the GEMM, specified in the
# `matmul` function.
# The program builds IR using the following loop structure, tiling the loops
# with the given tile size and parallelizing the two outermost loops into the
# first and second dimensions of CTAs.
#
# for(bi = 0; bi < M; bi += BLOCK_M)          # parallelize across blockIdx.x
#     for(bj = 0; bj < N; bj += BLOCK_N)      # parallelize across blockIdx.y
#         for(bk = 0; bk < K; bk += BLOCK_K)
#             for(i = bi; i < (bi + BLOCK_M); ++i)
#                 for(j = bj; j < (bj + BLOCK_N); ++j)
#                     for(k = bk; k < (bk + BLOCK_K); ++k)
#
# ===--- Multistage Kernel ---===
#
# This kernel launches a single warp group (128 threads). The primary thread
# (pthread) issues the TMA loads. Threads collectively wait for the data and
# perform mma operations. After the main loop completes, the threads together
# first store the fragmented registers to shared memory, then copy from shared
# memory to global memory; this part is called the epilogue.
#
# Execution Timeline of Multistage Kernel with 3 stages:
# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
# |       |Prologue ---->  |MainLoop ---->                                                      |Epilogue               |
# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
# |pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
# |wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
#
# ===--- Warp Specialization Kernel ---===
#
# This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
# as the `producer warp group` and the other as the `consumer warp group`. The
# `producer warp group` is responsible for issuing TMA loads, while the
# `consumer warp group` performs the mma operations. The epilogue starts in the
# `consumer warp group`, as its threads own the fragmented registers; as the
# timeline below shows, it moves registers to shared memory first, and then
# the shared-memory-to-global copy follows.
#
# Execution Timeline of Warp Specialization Kernel with 2 stages:
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
# |        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
# |pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
# |wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
# |wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global] |
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+

import errno
import numpy as np
import subprocess
import ctypes
from tools import nvgpucompiler
from tools import matmulBuilder
import contextlib
import os
import sys
import pathlib
import ctypes
from mlir import runtime as rt


def generate_matmul(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    use_warp_specialization=True,
    saveIR=False,
    max_num_stages=3,
    options="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3",
):
    """Build, verify and JIT-compile one GEMM kernel module.

    The IR comes from ``matmulBuilder.generate_matmul_ws`` when
    ``use_warp_specialization`` is true. Otherwise it comes from
    ``matmulBuilder.generate_matmul_multistage``. Both builders receive the
    element types, the problem shape, the tile shape and ``max_num_stages``.

    Args:
        input_type: numpy dtype of the A and B operands.
        output_type: numpy dtype of the C result.
        M, N, K: GEMM problem shape.
        BLOCK_M, BLOCK_N, BLOCK_K: tile shape.
        use_warp_specialization: select the warp-specialized kernel instead
            of the multistage kernel.
        saveIR: if true, also write the generated module to ``gemm.mlir`` in
            the current working directory.
        max_num_stages: stage count forwarded to the IR builder.
        options: option string forwarded to ``nvgpucompiler.NvgpuCompiler``.

    Returns:
        The engine returned by ``NvgpuCompiler.compile_and_jit``.

    Raises:
        FileNotFoundError: if the ``SUPPORT_LIB`` environment variable is
            unset or empty, or does not name an existing path.
    """
    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
        if use_warp_specialization:
            build_module = matmulBuilder.generate_matmul_ws
        else:
            build_module = matmulBuilder.generate_matmul_multistage
        mlir_nvgpu_module = build_module(
            input_type,
            output_type,
            M,
            N,
            K,
            BLOCK_M,
            BLOCK_N,
            BLOCK_K,
            max_num_stages,
        )

        mlir_nvgpu_module.operation.verify()

        # Save generated IR. Printing with file= (rather than swapping
        # sys.stdout) means stdout can never be left redirected if printing
        # the module raises.
        if saveIR:
            with open("gemm.mlir", "w") as f:
                print(mlir_nvgpu_module, file=f)

        # The runtime support library path comes from the RUN line
        # (`env SUPPORT_LIB=%mlir_cuda_runtime`). os.path.exists(None) raises
        # TypeError, so an unset variable is checked first and reported as a
        # missing file like any other bad path.
        support_lib = os.getenv("SUPPORT_LIB")
        if not support_lib or not os.path.exists(support_lib):
            raise FileNotFoundError(
                errno.ENOENT,
                os.strerror(errno.ENOENT),
                support_lib or "$SUPPORT_LIB (unset)",
            )
        compiler = nvgpucompiler.NvgpuCompiler(
            options, opt_level=3, shared_libs=[support_lib]
        )

        # Compile
        engine = compiler.compile_and_jit(mlir_nvgpu_module)
        return engine


def matmul(
    input_type=np.float16,
    output_type=np.float32,
    M=128,
    N=128,
    K=128,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    use_warp_specialization=True,
    saveIR=False,
    max_num_stages=3,
    print_results=False,
    no_verify=False,
):
    """Compile and run one GEMM configuration, then check it against numpy.

    A (M x K) and B (K x N) are filled with standard-normal values cast to
    ``input_type``. C (M x N, ``output_type``) starts at zero and is written
    by the generated kernel. A one-line banner describing the configuration
    is printed first, and "PASS" is printed at the end. The lit CHECK lines
    in this file match both.

    Args:
        input_type: numpy dtype of A and B.
        output_type: numpy dtype of C.
        M, N, K: GEMM problem shape.
        BLOCK_M, BLOCK_N, BLOCK_K: tile shape.
        use_warp_specialization: run the warp-specialized kernel instead of
            the multistage kernel.
        saveIR: forwarded to ``generate_matmul`` (dump IR to ``gemm.mlir``).
        max_num_stages: upper bound on the stage count. The value actually
            used is shown as "stages" in the banner.
        print_results: print C, and the reference when verifying.
        no_verify: skip the comparison against the numpy reference.

    Raises:
        AssertionError: raised by ``np.testing.assert_allclose`` when C does
            not match the reference.
    """
    # Cap the stage count by the ratio of total operand elements to the
    # elements loaded per stage (one A tile plus one B tile).
    # NOTE(review): for M == BLOCK_M and N == BLOCK_N this equals
    # K // BLOCK_K. It also grows with M and N, although the loop structure in
    # the header iterates bk only K // BLOCK_K times. The CHECK lines below
    # encode its current values (e.g. "Size 512x128x64 ... stages 2"), so
    # confirm the intent before changing it.
    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
    num_stages = min(required_stages, max_num_stages)
    ity = "f16" if input_type == np.float16 else "f32"
    oty = "f16" if output_type == np.float16 else "f32"
    gemmty = "Warp specialization" if use_warp_specialization else "Multistage"
    print(
        f"===-- Running GEMM {gemmty} {oty} += {ity} * {ity}, "
        f"Size {M}x{N}x{K}, Tile {BLOCK_M}x{BLOCK_N}x{BLOCK_K}, "
        f"stages {num_stages} --==="
    )

    # Build IR and compile
    engine = generate_matmul(
        input_type,
        output_type,
        M,
        N,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        use_warp_specialization=use_warp_specialization,
        saveIR=saveIR,
        max_num_stages=num_stages,
    )

    # Allocate matrices and invoke the matmul
    c = np.zeros((M, N), output_type)
    a = np.random.randn(M, K).astype(input_type)
    b = np.random.randn(K, N).astype(input_type)
    # Each argument is passed as a pointer to a pointer to the array's ranked
    # memref descriptor.
    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
    kernelName = matmulBuilder.make_kernel_name(
        input_type,
        output_type,
        M,
        N,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        num_stages,
        use_warp_specialization,
    )

    # Launch the MLIR generated kernel
    engine.invoke(kernelName, mem_a, mem_b, mem_c)

    float_formatter = "{:.2f}".format
    np.set_printoptions(formatter={"float_kind": float_formatter})

    if print_results:
        print(c)

    # Verify the results
    if not no_verify:
        # Compute the reference in the output precision, matching the
        # "<output> += <input> * <input>" contract in the banner. With
        # float16 operands, `a @ b` would produce a float16 result. Widening
        # float16 inputs to float32 is exact.
        ref = a.astype(output_type) @ b.astype(output_type)
        if print_results:
            print(ref)
        np.testing.assert_allclose(c, ref, rtol=5e-03, atol=1e-01)

    print("PASS ")


# Takes longer time to run
def test_long():
    """Sweep many shapes and stage counts with both kernels.

    This sweep is not run by default; only ``test_short`` is invoked below.
    """
    for stages in range(1, 7):
        for M in [128, 512, 1024, 4096, 8192]:
            for N in [128, 512, 1024, 4096, 8192]:
                for K in [64, 128, 512, 1024, 4096, 8192]:
                    # NOTE(review): verification is disabled for the
                    # multistage kernel in this sweep; the reason is not
                    # recorded here.
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=False,
                        no_verify=True,
                    )
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=True,
                    )


def test_short():
    """Run a small sweep with both kernels; the CHECK lines below expect
    exactly this sequence of configurations."""
    for stages in [1, 3]:
        for M in [128, 512]:
            for N in [128]:
                for K in [64, 256]:
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=False,
                    )
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=True,
                    )


# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS

# lit runs this file as a script (`%PYTHON %s`), so the sweep still executes
# there. Importing the module no longer launches GPU work.
if __name__ == "__main__":
    test_short()