xref: /llvm-project/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py (revision d95e6d027486876559f1a2a96c33b8ad93cc0ae4)
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s


# ===--- GEMM Hopper Tensor Core Integration Test ---===
#
# This test validates the correctness of the GEMM kernels supported by the
# NVGPU dialect. Two kernel flavors are currently covered: Multistage and
# Warp Specialization.
# The test constructs and metaprograms IR using the Python bindings, allowing
# generic IR building. This flexibility makes it easy to change the shape,
# tile size, or data type of the GEMM for testing purposes.
# The entry function is `matmul`, where one can specify the GEMM shape, tile
# size, data type, GEMM algorithm (Multistage or Warp Specialization), and the
# maximum number of stages.
# Verification is done against numpy's matmul operation.
#
# Example:
# matmul(input_type=np.float16,                # input types
#        output_type=np.float32,               # output type
#        M=4096, N=4096, K=4096,               # Shape
#        BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
#        use_warp_specialization=True,         # Enable Warp Specialization
#        max_num_stages=3)                     # Number of stages in shared memory
#
# ===--- Parallelism Across CTAs  ---===
#
# A GEMM consists of three nested loops whose bounds (M, N, K) define its
# shape; they are specified in the `matmul` function.
# The program builds IR using the following loop structure, tiling the loops
# with the given tile size and parallelizing the two outermost loops into the
# first and second dimensions of CTAs.
#
# for(bi = 0; bi < M; bi += BLOCK_M)          # parallelize across blockIdx.x
#     for(bj = 0; bj < N; bj += BLOCK_N)      # parallelize across blockIdx.y
#         for(bk = 0; bk < K; bk += BLOCK_K)
#             for(i = bi; i < (bi + BLOCK_M); ++i)
#                 for(j = bj; j < (bj + BLOCK_N); ++j)
#                     for(k = bk; k < (bk + BLOCK_K); ++k)
#
# ===--- Multistage Kernel ---===
#
# This kernel launches a single warp group (128 threads). The primary thread
# (pthread) issues the TMA load requests. All threads collectively wait for the
# data and perform the mma operations. After the main loop completes, the
# threads first store their fragmented registers to shared memory and then copy
# from shared memory to global memory; this part is called the epilogue.
#
# Execution Timeline of Multistage Kernel with 3 stages:
# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
# |       |Prologue ---->   |MainLoop ---->                                                                  |Epilogue  |
# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
# |pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
# |wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
#
# ===--- Warp Specialization Kernel  ---===
#
# This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
# as the `producer warp group` and the other as the `consumer warp group`. The
# `producer warp group` is responsible for issuing the TMA load requests, while
# the `consumer warp group` performs the mma operations. The epilogue is handled
# by the `consumer warp group`, as its threads own the fragmented registers.
#
# Execution Timeline of Warp Specialization Kernel with 2 stages:
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
# |        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
# |pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
# |wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
# |wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global]|
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+

import ctypes
import errno
import os
import sys

import numpy as np

from mlir import runtime as rt
from tools import matmulBuilder
from tools import nvgpucompiler


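# Illustrative sketch (not part of the original test and never called by it):
# the tiled loop nest from the "Parallelism Across CTAs" header comment, written
# as plain Python over nested lists. The name `tiled_matmul_reference` is
# hypothetical, and the sketch assumes M, N and K are multiples of the tile
# sizes. In the generated kernel the two outermost loops are not loops at all;
# they are distributed across blockIdx.x and blockIdx.y.
def tiled_matmul_reference(a, b, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K):
    c = [[0.0] * N for _ in range(M)]
    for bi in range(0, M, BLOCK_M):  # blockIdx.x in the generated kernel
        for bj in range(0, N, BLOCK_N):  # blockIdx.y in the generated kernel
            for bk in range(0, K, BLOCK_K):  # sequential reduction over K tiles
                for i in range(bi, bi + BLOCK_M):
                    for j in range(bj, bj + BLOCK_N):
                        for k in range(bk, bk + BLOCK_K):
                            c[i][j] += a[i][k] * b[k][j]
    return c

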
def generate_matmul(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    use_warp_specialization=True,
    saveIR=False,
    max_num_stages=3,
    options="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3",
):
    # Build the NVGPU module for the requested kernel flavor and JIT-compile it.
    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
        if use_warp_specialization:
            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
                input_type,
                output_type,
                M,
                N,
                K,
                BLOCK_M,
                BLOCK_N,
                BLOCK_K,
                max_num_stages,
            )
        else:
            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
                input_type,
                output_type,
                M,
                N,
                K,
                BLOCK_M,
                BLOCK_N,
                BLOCK_K,
                max_num_stages,
            )

        mlir_nvgpu_module.operation.verify()

        # Save generated IR
        if saveIR:
            with open("gemm.mlir", "w") as f:
                print(mlir_nvgpu_module, file=f)

        # Get compiler. SUPPORT_LIB is set by the RUN line at the top of the file.
        support_lib = os.getenv("SUPPORT_LIB")
        if support_lib is None:
            raise RuntimeError("SUPPORT_LIB environment variable is not set")
        if not os.path.exists(support_lib):
            raise FileNotFoundError(
                errno.ENOENT, os.strerror(errno.ENOENT), support_lib
            )
        compiler = nvgpucompiler.NvgpuCompiler(
            options, opt_level=3, shared_libs=[support_lib]
        )

        # Compile
        engine = compiler.compile_and_jit(mlir_nvgpu_module)
        return engine


def matmul(
    input_type=np.float16,
    output_type=np.float32,
    M=128,
    N=128,
    K=128,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    use_warp_specialization=True,
    saveIR=False,
    max_num_stages=3,
    print_results=False,
    no_verify=False,
):
    # Clamp the number of stages to what the problem size can actually use.
    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
    num_stages = min(required_stages, max_num_stages)

    # Print the configuration. FileCheck matches this line, so keep its format.
    ity = "f16" if input_type == np.float16 else "f32"
    oty = "f16" if output_type == np.float16 else "f32"
    gemmty = "Warp specialization" if use_warp_specialization else "Multistage"
    print(
        f"===-- Running GEMM {gemmty} {oty} += {ity} * {ity}, "
        f"Size {M}x{N}x{K}, "
        f"Tile {BLOCK_M}x{BLOCK_N}x{BLOCK_K}, "
        f"stages {num_stages} --==="
    )

    # Build IR and compile
    engine = generate_matmul(
        input_type,
        output_type,
        M,
        N,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        use_warp_specialization,
        saveIR,
        num_stages,
    )

    # Allocate matrices and invoke the matmul
    c = np.zeros((M, N), output_type)
    a = np.random.randn(M, K).astype(input_type)
    b = np.random.randn(K, N).astype(input_type)
    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
    kernelName = matmulBuilder.make_kernel_name(
        input_type,
        output_type,
        M,
        N,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        num_stages,
        use_warp_specialization,
    )

    # Launch the MLIR generated kernel
    engine.invoke(kernelName, mem_a, mem_b, mem_c)

    float_formatter = "{:.2f}".format
    np.set_printoptions(formatter={"float_kind": float_formatter})

    if print_results:
        print(c)

    # Verify the results against numpy's matmul
    if not no_verify:
        ref = a.astype(input_type) @ b.astype(input_type)
        if print_results:
            print(ref)
        np.testing.assert_allclose(c, ref, rtol=5e-03, atol=1e-01)

    print("PASS")


# Exhaustive sweep over shapes and stage counts; takes much longer to run.
def test_long():
    for stages in range(1, 7):
        for M in [128, 512, 1024, 4096, 8192]:
            for N in [128, 512, 1024, 4096, 8192]:
                for K in [64, 128, 512, 1024, 4096, 8192]:
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=False,
                        no_verify=True,
                    )
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=True,
                    )


def test_short():
    for stages in [1, 3]:
        for M in [128, 512]:
            for N in [128]:
                for K in [64, 256]:
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=False,
                    )
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=True,
                    )


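# Illustrative sketch (not part of the original test and never called by it):
# how the "stages" value in the CHECK lines below is derived. It mirrors the
# arithmetic at the top of `matmul`; the helper name `expected_num_stages` is
# hypothetical.
def expected_num_stages(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, max_num_stages):
    # Total elements of A and B divided by the elements held by one stage.
    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
    return min(required_stages, max_num_stages)


# With 128x128x64 tiles and max_num_stages=3:
#   128x128x64  -> min(16384 // 16384, 3)  = 1
#   128x128x256 -> min(65536 // 16384, 3)  = 3
#   512x128x64  -> min(40960 // 16384, 3)  = 2
#   512x128x256 -> min(163840 // 16384, 3) = 3
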
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS

test_short()