# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
#  Chapter 4 : Multistage GEMM with Tensor Core
# ===----------------------------------------------------------------------===//
#
# This program exemplifies a GEMM operation for `f32+=f16*f16`, using the
# multistage method with a tile size of 128x128x64. The code fully
# parallelizes the two outermost loops into thread blocks. It launches one
# warp group (128 threads in total) and allocates multiple slots/stages in
# shared memory. The program consists of three main parts: prologue, main
# loop, and epilogue. In the prologue, thread0 issues TMA loads to fill the
# shared memory slots. The main loop executes MMA while TMA simultaneously
# refills the slots that have already been consumed. This overlap of TMA and
# MMA operations enhances performance by maximizing computational throughput.
#
# Loops illustration:
#
#  for s in range(num_stages):
#    TMA_128x64_64x128...
#  for ti in range(M//128):  # -> blockIdx.x
#   for tj in range(N//128): # -> blockIdx.y
#    for tk in range(K//64):
#      MMA_128x128x64...
#      TMA_128x64_64x128...
#  Epilogue...
#
# This chapter demonstrates:
#  1. Partition shape based on block IDs
#  2. Prologue
#    2.1 Execute TMA loads for the two input matrices for each stage
#  3. Main loop
#    3.1 Wait for completion of the TMA load with an mbarrier
#    3.2 Perform the 128x128x64 Tensor Core GEMM with the warpgroup
#    3.3 Load the next stage if needed
#  4. Epilogue
#    4.1 Store fragmented registers to shared memory
#    4.2 Store shared memory to global memory
#
# ===----------------------------------------------------------------------===//
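
# With the sizes defined at the bottom of this file (M=512, N=256, K=1024,
# TILE_M=TILE_N=128, TILE_K=64, num_stages=7) the kernel launches a 4x2 grid
# of blocks, each running a single 128-thread warp group, and the main loop
# executes K // TILE_K = 16 iterations.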


from mlir import ir
from mlir.dialects import gpu, scf, nvgpu, nvvm
from mlir.extras import types as T
from tools.nvdsl import *
import numpy as np


def partition_shape():
    """
    Calculate the partition start offsets based on the block IDs.

    It partitions the iteration space like below:
    for(.. i < M ...)   --> blockIdx.x
     for(.. j < N ...)  --> blockIdx.y
      for(.. k < K ...)

    Returns:
        dimX: Start of this block's tile along the M dimension (blockIdx.x * TILE_M).
        dimY: Start of this block's tile along the N dimension (blockIdx.y * TILE_N).
    """
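    # For example, with TILE_M = TILE_N = 128 the block with blockIdx = (1, 1)
    # gets dimX = 128 and dimY = 128, i.e. it owns the 128x128 output tile
    # starting at row 128 and column 128.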
    bidx = gpu.block_id(gpu.Dimension.x)
    bidy = gpu.block_id(gpu.Dimension.y)
    dimX = bidx * TILE_M
    dimY = bidy * TILE_N
    return dimX, dimY


def tma_load(
    mbar_group: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    slot,
    stage,
    num_stages,
    p=None,
):
    """
    TMA loads the two input matrices from global memory to shared memory. It
    performs the following operations:

       - tma.load a_shared_memory[off_a]  at coordinate [k, dimX]      (loads 128x64)
       - tma.load b_shared_memory[off_b]  at coordinate [dimY, k]      (loads 64x64)
       - tma.load b_shared_memory[off_b2] at coordinate [dimY + 64, k] (loads 64x64)
       (k = stage * 64)

       mbarrier.arrive with ta_count = bytes(A tile) + 2 * bytes(B tile)
    """
    dimX, dimY = partition_shape()

    tidx = gpu.thread_id(gpu.Dimension.x)
    begin_b = num_stages * get_type_size(a_tma.tma_memref)
    size_tma_a = get_type_size(a_tma.tma_memref)
    size_tma_b = get_type_size(b_tma.tma_memref)
    ta_count = size_tma_a + (size_tma_b * 2)
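    # With the tile shapes used in this chapter (A: 128x64 f16, B: 64x64 f16):
    #   size_tma_a = 128 * 64 * 2 bytes = 16384
    #   size_tma_b =  64 * 64 * 2 bytes =  8192
    #   ta_count   = 16384 + 2 * 8192   = 32768 bytes expected by the mbarrier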

    p = tidx == 0 if p is None else p

    off_a = slot * size_tma_a
    off_b = (slot * size_tma_a) + begin_b
    off_b2 = off_b + size_tma_b
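    # The offsets above follow the dynamic shared memory layout: the A tiles of
    # all stages are packed first, followed (at begin_b) by the B tiles of all
    # stages, where each stage's B region holds two 64x64 tiles back to back.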
    a_elem_ty = a_tma.tma_memref.element_type
    b_elem_ty = b_tma.tma_memref.element_type
    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)

    mbar_group[slot].arrive(ta_count, predicate=p)

    c1 = stage * 64
    a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
    b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
    b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)


def initialize(a_tma: TMA, b_tma: TMA, num_stages):
    """
    Initialize mbarriers and prefetch TMA descriptors.
    """
    tidx = gpu.thread_id(gpu.Dimension.x)
    mbar_group = Mbarriers(number_of_barriers=num_stages)
    isThread0 = tidx == const(0)
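    # Each barrier is created with an arrival count of 1: only thread 0 arrives
    # on it (together with the expected transaction byte count) when it issues
    # the TMA loads in tma_load().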
    with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
        for i in scf.for_(0, num_stages, 1):
            mbar_group[i].init(1)
            scf.yield_([])
        a_tma.prefetch()
        b_tma.prefetch()
        scf.yield_([])

    return mbar_group


def prologue(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
    """
    Prologue of the GEMM kernel. It issues TMA loads for the two input matrices
    of each stage, in a loop like the one below:

    for stage in range(num_stages - 1):
        tma_load(..., slot=stage, stage=stage)

    """
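    # All but one stage is filled here; the main loop prefetches into the slot
    # that is left empty. (With a single stage there is nothing to overlap, so
    # the only slot is filled.)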
    ns = num_stages if num_stages == 1 else num_stages - 1
    for iv in scf.for_(0, ns, 1):
        tma_load(mbar_group, a_tma, b_tma, iv, iv, num_stages)
        scf.yield_([])


def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
    """
    Main loop of the multistage GEMM kernel. It iterates over the K dimension,
    performing matrix multiplication while TMA loads data for upcoming stages
    into shared memory. It works like the following:

    MatrixAccumulator D
    for k in range(K // TILE_K):

        try_wait(stage, ...)    # Wait for the TMA load of this stage

        Matrix A(stage, ...)    # Find the shared memory slot
        Matrix B(stage, ...)    # Find the shared memory slot
        D += A @ B              # Multiply and accumulate

        if needLoad:            # Load the next stage if needed
            tma_load(x, y, nextSlot, nextStage)

    """
    ns = num_stages if num_stages == 1 else num_stages - 1

    tidx = gpu.thread_id(gpu.Dimension.x)
    begin_b = num_stages * get_type_size(a_tma.tma_memref)

    size_a = TILE_M * TILE_K * get_type_size(T.f16())

    # Initialize A and B (input matrices) and D (accumulator)
    A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
    B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
    D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())

    phase = const(False, ty=T.bool())

    # Main Loop
    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
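    # The scf.for carries two values across iterations: the accumulator tile
    # and the mbarrier phase bit used by try_wait.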
    with ir.InsertionPoint(for_op.body):
        phase = for_op.inner_iter_args[1]
        iv = for_op.induction_variable
        stage = iv % num_stages

        # Wait for the TMA load of the current stage
        mbar_group[stage].try_wait(phase=phase)

        # Find the shared memory slot
        offset_a = stage * size_a
        offset_b = offset_a + begin_b
        a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
        b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)

        # Bind the shared memory slots to A and B and the carried value to D
        A.update_smem(a_smem)
        B.update_smem(b_smem)
        D.update_accumulator(for_op.inner_iter_args[0])

        # Matrix Multiply
        D += A @ B

        # With a single stage, wait for the Tensor Core before the slot is reused
        if num_stages == 1:
            nvvm.WgmmaWaitGroupSyncOp(0)

        # Load the next stage
        pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
        nextStage = iv + ns
        nextSlot = nextStage % num_stages
        tma_load(mbar_group, a_tma, b_tma, nextSlot, nextStage, num_stages, pred)
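        # The prefetch targets the stage `ns` iterations ahead; `pred` restricts
        # it to thread 0 and to iterations that still have K tiles left to load.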

        # Flip the mbarrier phase parity after the last slot has been consumed
        newPhase = arith.select(
            stage == (num_stages - 1),
            (phase ^ const(True, ty=T.bool())),
            phase,
        )
        scf.yield_([D.acc_op, newPhase])

    nvvm.WgmmaWaitGroupSyncOp(0)

    D.update_accumulator(for_op.results[0])
    return D


def epilogue(D: WGMMAMatrix, d_dev):
    """
    Epilogue of the GEMM kernel. It stores the fragmented registers to global memory.

    MatrixAccumulator D               # Fragmented results
    store D -> shared memory          # Registers to shared memory
    shared memory -> d[dimX][dimY]    # Shared memory to global memory

    """
    tidx = gpu.thread_id(gpu.Dimension.x)
    dimX, dimY = partition_shape()

    d_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
    d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
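    # d_gmem is this block's TILE_M x TILE_N tile of the output matrix,
    # starting at row dimX and column dimY.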

    # Store (registers -> shared memory)
    D.store_accumulator(d_smem)
    gpu.barrier()

    # Store (shared memory -> global memory)
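    # Each of the 128 threads copies the column indexed by tidx, one row per
    # loop iteration.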
    for i in scf.for_(0, TILE_M, 1):
        val = memref.load(d_smem, [i, tidx])
        memref.store(val, d_gmem, [i, tidx])
        scf.yield_([])


# The decorator generates
#   a -> memref<MxKxf16>
#   b -> memref<KxNxf16>
#   d -> memref<MxNxf32>
@NVDSL.mlir_func
def gemm_multistage(a, b, d, num_stages):
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
    t7 = gpu.wait(token_ty, [t6])

    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
    a_tma = TMA([128, 64], a.type, swizzle=sw)
    b_tma = TMA([64, 64], b.type, swizzle=sw)
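    # Note: with 128B swizzling one swizzled row spans 128 bytes (64 f16
    # elements), which is why B is described as 64x64 boxes here and each
    # stage loads two of them in tma_load().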
    a_tma.create_descriptor(a_dev)
    b_tma.create_descriptor(b_dev)

    grid = [(M // TILE_M), (N // TILE_N), 1]
    block = [128, 1, 1]

    size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
    size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
    smem_size_in_bytes = (size_a + size_b) * num_stages
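    # With f16 inputs and 128x128x64 tiles, each stage needs
    #   size_a = 2 * 128 * 64 = 16384 bytes
    #   size_b = 2 * 128 * 64 = 16384 bytes
    # i.e. 32 KB per stage, or 224 KB of dynamic shared memory for 7 stages.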

    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
    def gemm_multistage_kernel():
        # Initialize mbarriers and prefetch TMA descriptors
        mbar_group = initialize(a_tma, b_tma, num_stages)

        # Fill the pipeline stages
        prologue(mbar_group, a_tma, b_tma, num_stages)

        # Main loop
        D = mainloop(mbar_group, a_tma, b_tma, num_stages)

        # Store registers to global memory
        epilogue(D, d_dev)

    gemm_multistage_kernel()

    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
    gpu.wait(None, [t8])


# Pass Python arguments to MLIR
N = 256
M = 512
K = 1024
TILE_M = 128
TILE_N = 128
TILE_K = 64
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)

gemm_multistage(a, b, d, num_stages=7)


# Verify the MLIR kernel against a NumPy reference computation
ref_d = a.astype(np.float16) @ b.astype(np.float16)
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)


print("PASS")
# CHECK-NOT: Mismatched elements