# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s


# ===--- GEMM Hopper Tensor Core Integration Test ---===
#
# This test validates the correctness of the GEMM kernels supported by the
# NVGPU dialect, currently the Multistage and Warp Specialization kernels.
# The test constructs and metaprograms the IR using the MLIR Python bindings,
# which allows generic IR building: the shape, tile size, and data type of the
# GEMM can all be changed for testing purposes.
# The entry function is `matmul`, where one can specify the GEMM shape, tile
# size, data type, GEMM algorithm (Multistage or Warp Specialization), and the
# maximum number of stages.
# Results are verified against numpy's matmul operation.
#
# Example:
# matmul(input_type=np.float16,                 # input type
#        output_type=np.float32,                # output type
#        M=4096, N=4096, K=4096,                # Shape
#        BLOCK_M=128, BLOCK_N=128, BLOCK_K=64,  # Tile Size
#        use_warp_specialization=True,          # Enable Warp Specialization
#        max_num_stages=3)                      # Number of stages in shared memory
#
# ===--- Parallelism Across CTAs ---===
#
# GEMM includes three loops defining the shape of the GEMM, specified in the
# `matmul` function.
# The program builds IR using the following loop structure, tiling the loops
# with the given tile size and parallelizing the two outermost loops into the
# first and second dimensions of CTAs.
#
# for(bi = 0; bi < M; bi += BLOCK_M)          # parallelize across blockIdx.x
#     for(bj = 0; bj < N; bj += BLOCK_N)      # parallelize across blockIdx.y
#         for(bk = 0; bk < K; bk += BLOCK_K)
#             for(i = bi; i < (bi + BLOCK_M); ++i)
#                 for(j = bj; j < (bj + BLOCK_N); ++j)
#                     for(k = bk; k < (bk + BLOCK_K); ++k)
#
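# As an illustration (arithmetic only, not part of the test): with M = N = 4096
# and BLOCK_M = BLOCK_N = 128, the two parallelized loops above map to a launch
# grid of
#     gridDim.x = ceil(M / BLOCK_M) = 32
#     gridDim.y = ceil(N / BLOCK_N) = 32
# i.e. 1024 CTAs, while the loop over K stays sequential inside each CTA.
#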
# ===--- Multistage Kernel ---===
#
# This kernel launches a single warp group (128 threads). The primary thread
# (pthread) issues the TMA load requests. All threads collectively wait for the
# data and perform mma operations. After the main loop completes, the threads
# collectively store the fragmented accumulator registers to shared memory and
# then copy them from shared memory to global memory; this part is called the
# epilogue.
#
# Execution Timeline of Multistage Kernel with 3 stages:
# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
# |       |Prologue ---->  |MainLoop ---->                                                      |Epilogue               |
# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
# |pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
# |wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
#
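# The main loop implements a software pipeline over a circular buffer of
# `num_stages` shared-memory slots. A minimal sketch of the idea (illustrative
# pseudocode only; `tma_load`, `wait`, `mma`, and `num_k_tiles` are
# placeholders, and the slot numbering in the diagram above is schematic):
#
#     for t in range(num_stages):              # prologue: prefetch the first tiles
#         tma_load(slot=t, k_tile=t)
#     for k in range(num_k_tiles):             # main loop over K tiles
#         wait(slot=k % num_stages)            # wait for the oldest outstanding load
#         mma(slot=k % num_stages)
#         if k + num_stages < num_k_tiles:     # refill the slot just consumed
#             tma_load(slot=k % num_stages, k_tile=k + num_stages)
#     epilogue()                               # registers -> shmem -> global
#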
# ===--- Warp Specialization Kernel ---===
#
# This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
# as the `producer warp group` and the other as the `consumer warp group`. The
# `producer warp group` is responsible for requesting TMA loads, while the
# `consumer warp group` performs the mma operations. The epilogue section is
# handled by the `consumer warp group`, as its threads own the fragmented
# registers.
#
# Execution Timeline of Warp Specialization Kernel with 2 stages:
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
# |        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
# |pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
# |wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
# |wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global] |
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
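#
# A minimal sketch of the role split (illustrative pseudocode only; the
# barrier/load helpers below are placeholders, not this test's API):
#
#     if is_producer_warp_group():             # 128 producer threads
#         for k in range(num_k_tiles):
#             wait_slot_empty(k % num_stages)  # consumer released the slot
#             tma_load(slot=k % num_stages, k_tile=k)
#     else:                                    # 128 consumer threads
#         for k in range(num_k_tiles):
#             wait_slot_full(k % num_stages)   # producer filled the slot
#             mma(slot=k % num_stages)
#         epilogue()                           # reg -> shmem, then shmem -> global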

import errno
import numpy as np
import subprocess
import ctypes
from tools import nvgpucompiler
from tools import matmulBuilder
import contextlib
import os
import sys
import pathlib
from mlir import runtime as rt


def generate_matmul(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    use_warp_specialization=True,
    saveIR=False,
    max_num_stages=3,
    options="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3",
):
    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
        if use_warp_specialization:
            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
                input_type,
                output_type,
                M,
                N,
                K,
                BLOCK_M,
                BLOCK_N,
                BLOCK_K,
                max_num_stages,
            )
        else:
            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
                input_type,
                output_type,
                M,
                N,
                K,
                BLOCK_M,
                BLOCK_N,
                BLOCK_K,
                max_num_stages,
            )

        mlir_nvgpu_module.operation.verify()

        # Save the generated IR
        if saveIR:
            # print(mlir_nvgpu_module)
            original_stdout = sys.stdout
            with open("gemm.mlir", "w") as f:
                sys.stdout = f
                print(mlir_nvgpu_module)
            sys.stdout = original_stdout

        # Get the compiler
        support_lib = os.getenv("SUPPORT_LIB")
        if not os.path.exists(support_lib):
            raise FileNotFoundError(
                errno.ENOENT, os.strerror(errno.ENOENT), support_lib
            )
        compiler = nvgpucompiler.NvgpuCompiler(
            options, opt_level=3, shared_libs=[support_lib]
        )

        # Compile and JIT the module
        engine = compiler.compile_and_jit(mlir_nvgpu_module)
        return engine


def matmul(
    input_type=np.float16,
    output_type=np.float32,
    M=128,
    N=128,
    K=128,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    use_warp_specialization=True,
    saveIR=False,
    max_num_stages=3,
    print_results=False,
    no_verify=False,
):
    # Print the configuration. The stage count is capped by a rough heuristic:
    # the number of elements of A and B divided by the number of elements held
    # by one stage of A and B tiles.
    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
    num_stages = min(required_stages, max_num_stages)
    ity = "f16" if input_type == np.float16 else "f32"
    oty = "f16" if output_type == np.float16 else "f32"
    gemmty = "Warp specialization" if use_warp_specialization else "Multistage"
    print(
        f"===-- Running GEMM {gemmty} {oty} += {ity} * {ity}, "
        f"Size {M}x{N}x{K}, Tile {BLOCK_M}x{BLOCK_N}x{BLOCK_K}, "
        f"stages {num_stages} --==="
    )

    # Build IR and compile
    engine = generate_matmul(
        input_type,
        output_type,
        M,
        N,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        use_warp_specialization,
        saveIR,
        num_stages,
    )

    # Allocate matrices and invoke the matmul
    c = np.zeros((M, N), output_type)
    a = np.random.randn(M, K).astype(input_type)
    b = np.random.randn(K, N).astype(input_type)
    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
    kernelName = matmulBuilder.make_kernel_name(
        input_type,
        output_type,
        M,
        N,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        num_stages,
        use_warp_specialization,
    )

    # Launch the MLIR-generated kernel
    engine.invoke(kernelName, mem_a, mem_b, mem_c)

    float_formatter = "{:.2f}".format
    np.set_printoptions(formatter={"float_kind": float_formatter})

    if print_results:
        print(c)

    # Verify the results against numpy
    if not no_verify:
        ref = a.astype(input_type) @ b.astype(input_type)
        if print_results:
            print(ref)
        np.testing.assert_allclose(c, ref, rtol=5e-03, atol=1e-01)

    print("PASS ")


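# Worked example of the `required_stages` heuristic in `matmul` above
# (illustrative arithmetic only): for M=128, N=128, K=256 with 128x128x64 tiles,
#     required_stages = (128*256 + 256*128) // (128*64 + 64*128)
#                     = 65536 // 16384 = 4
# so max_num_stages=3 clamps the kernel to num_stages = min(4, 3) = 3, which is
# the "stages 3" reported for that size in the CHECK lines below.

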
# Takes a longer time to run
def test_long():
    for stages in range(1, 7):
        for M in [128, 512, 1024, 4096, 8192]:
            for N in [128, 512, 1024, 4096, 8192]:
                for K in [64, 128, 512, 1024, 4096, 8192]:
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=False,
                        no_verify=True,
                    )
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=True,
                    )


def test_short():
    for stages in [1, 3]:
        for M in [128, 512]:
            for N in [128]:
                for K in [64, 256]:
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=False,
                    )
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=True,
                    )


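# test_short() sweeps 2 stage settings x 2 M values x 1 N value x 2 K values,
# i.e. 8 configurations, each run once as Multistage and once as Warp
# specialization, which produces the 16 header/PASS pairs checked below.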
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS

test_short()