# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
#  Chapter 5 : Warp Specialized GEMM with Tensor Core
# ===----------------------------------------------------------------------===//
#
# This program demonstrates a GEMM operation for `f32+=f16*f16`, utilizing the
# Warp Specialized method with a tile size of 128x128x64. The code completely
# parallelizes the two outermost loops into thread blocks. It launches two Warp
# Groups (256 threads in total): one for the producer and the other for the consumer.
# Each group follows a different control flow. The producer warp group is responsible
# for loading data into shared memory, while the consumer warp group executes the
# Tensor Core GEMM operation and the epilogue.
#
#  for ti in range(M//128):  # -> blockIdx.x
#   for tj in range(N//128): # -> blockIdx.y
#    with wg_producer:
#     for tk in range(K//64):
#        TMA_128x64_64x128...
#    with wg_consumer:
#     for tk in range(K//64):
#        MMA_128x128x64...
#     Epilogue..
#
#  (A plain-NumPy sketch of this loop nest follows the imports below.)
#
# This chapter demonstrates:
#  2 WG (warpgroups)
#    Producer:
#       2.1.1 Wait MMA barrier
#       2.1.2 Load TMA with TMA barrier
#       2.1.3 Arrive TMA barrier with txcount
#    Consumer:
#       Loop
#           Wait TMA barrier
#           Perform Tensor Core GEMM 128x128x64 by warpgroup
#           Arrive MMA barrier
#       Epilogue
#           Store fragmented registers to shared memory
#           Store shared memory to global memory
#
# ===----------------------------------------------------------------------===//


from mlir import ir
from mlir.dialects import arith, gpu, memref, scf, nvgpu, nvvm
from mlir.extras import types as T
from tools.nvdsl import *
import numpy as np


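# The MLIR kernel below implements the tiled loop nest described in the chapter
# header. For reference, the same computation can be sketched in plain NumPy as
# follows. This helper is illustrative only and is not called by the test; its
# name and default tile sizes are choices of this sketch, not part of the NVDSL
# API.
def _numpy_tiled_gemm_sketch(a, b, tile_m=128, tile_n=128, tile_k=64):
    m, k = a.shape
    _, n = b.shape
    d = np.zeros((m, n), np.float32)
    for ti in range(m // tile_m):  # -> blockIdx.x
        for tj in range(n // tile_n):  # -> blockIdx.y
            acc = np.zeros((tile_m, tile_n), np.float32)
            for tk in range(k // tile_k):  # pipelined across num_stages buffers on the GPU
                a_tile = a[ti * tile_m : (ti + 1) * tile_m, tk * tile_k : (tk + 1) * tile_k]
                b_tile = b[tk * tile_k : (tk + 1) * tile_k, tj * tile_n : (tj + 1) * tile_n]
                acc += a_tile.astype(np.float32) @ b_tile.astype(np.float32)
            d[ti * tile_m : (ti + 1) * tile_m, tj * tile_n : (tj + 1) * tile_n] = acc
    return d

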
def partition_shape():
    """
    Compute the tile offsets of this thread block from the block IDs.

    It parallelizes the two outermost loops into thread blocks.
    for ti in range(M//128):    # -> blockIdx.x
     for tj in range(N//128):   # -> blockIdx.y
      D = 0
      for tk in range(K//64):
       for i in range(128):
        for j in range(128):
         for k in range(64):
           FMA

    Returns:
        dimX: Offset of this block's tile along the M dimension (blockIdx.x * TILE_M).
        dimY: Offset of this block's tile along the N dimension (blockIdx.y * TILE_N).
    """
    bidx = gpu.block_id(gpu.Dimension.x)
    bidy = gpu.block_id(gpu.Dimension.y)
    dimX = bidx * TILE_M
    dimY = bidy * TILE_N
    return dimX, dimY


def tma_load(
    mbar_group: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    slot,
    stage,
    num_stages,
    p=None,
):
    """
    TMA loads two input matrices from global memory to shared memory. With
    x = blockIdx.x * TILE_M, y = blockIdx.y * TILE_N and k = stage * 64, it
    performs the following operations:

       - tma.load a_shared_memory[off_a]  at coordinate [k, x]      (loads 128x64)
       - tma.load b_shared_memory[off_b]  at coordinate [y, k]      (loads 64x64)
       - tma.load b_shared_memory[off_b2] at coordinate [y + 64, k] (loads 64x64)

       mbarrier.arrive ta_count = (128x64 + 2x64x64) x sizeof(f16) = 32768 bytes
    """
    dimX, dimY = partition_shape()

    tidx = gpu.thread_id(gpu.Dimension.x)
    begin_b = num_stages * get_type_size(a_tma.tma_memref)
    size_tma_a = get_type_size(a_tma.tma_memref)
    size_tma_b = get_type_size(b_tma.tma_memref)
    ta_count = size_tma_a + (size_tma_b * 2)

    off_a = slot * size_tma_a
    off_b = (slot * size_tma_a) + begin_b
    off_b2 = off_b + size_tma_b
    a_elem_ty = a_tma.tma_memref.element_type
    b_elem_ty = b_tma.tma_memref.element_type
    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)

    mbar_group[slot].arrive(ta_count, predicate=p)
    p = (tidx % WARP_GROUP_SIZE) == 0
    c1 = stage * 64
    a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
    b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
    b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)


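# Shared-memory layout assumed by tma_load above: the A tiles of all pipeline
# stages come first, followed by the B tiles, so slot `s` keeps its A tile at
# s * sizeof(A tile) and its two 64x64 B halves right after the whole A region.
# The helper below spells out that byte arithmetic in plain Python; it is
# illustrative only and is not called by the test.
def _smem_offsets_sketch(slot, num_stages, tile_m=128, tile_n=128, tile_k=64, elt_bytes=2):
    size_a = tile_m * tile_k * elt_bytes  # one 128x64 f16 A tile = 16 KiB
    size_b_half = (tile_n // 2) * tile_k * elt_bytes  # one 64x64 f16 B half = 8 KiB
    begin_b = num_stages * size_a  # B region starts after all A slots
    off_a = slot * size_a
    off_b = begin_b + slot * (2 * size_b_half)  # equals begin_b + slot * size_a here
    off_b2 = off_b + size_b_half
    return off_a, off_b, off_b2

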
def initialize(a_tma: TMA, b_tma: TMA, num_stages):
    """
    Initialize mbarriers and prefetch TMA descriptors.
    """
    tidx = gpu.thread_id(gpu.Dimension.x)
    mbar_group_tma = Mbarriers(number_of_barriers=num_stages)
    mbar_group_mma = Mbarriers(number_of_barriers=num_stages)
    isThread0 = tidx == const(0)
    with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
        for i in scf.for_(0, num_stages, 1):
            mbar_group_tma[i].init(1)
            mbar_group_mma[i].init(1)
            scf.yield_([])
        a_tma.prefetch()
        b_tma.prefetch()
        scf.yield_([])

    return mbar_group_tma, mbar_group_mma


def switch_phase(stage, phase, num_stages):
    """
    Flip the mbarrier parity phase once the last pipeline stage has been
    reached, so the next pass over the circular buffer waits on the opposite
    phase.
    """
    p = stage == (num_stages - 1)
    phase = arith.select(
        p,
        (phase ^ const(True, ty=T.bool())),
        phase,
    )
    return phase


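# Plain-Python illustration of the parity bookkeeping done by switch_phase
# above (not called by the test): a warpgroup walks the circular buffer of
# `num_stages` mbarriers and must flip the phase it waits on every time it
# wraps around, because each barrier completes exactly once per pass. The
# producer below starts with phase=True, the consumer with phase=False.
def _phase_schedule_sketch(num_iterations, num_stages, phase=True):
    schedule = []
    for iv in range(num_iterations):
        stage = iv % num_stages
        schedule.append((iv, stage, phase))  # (iteration, barrier index, expected phase)
        if stage == num_stages - 1:  # same condition as switch_phase
            phase = not phase
    return schedule

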
def producer_loop(
    mbar_tma: Mbarriers,
    mbar_mma: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    wg_me: Warpgroup,
    num_stages,
):
    phase = const(True, ty=T.bool())

    for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]):
        stage = iv % num_stages
        # Wait for the MMA of this slot to be done
        mbar_mma[stage].try_wait(phase)
        # New phase for mbarrier
        phase = switch_phase(stage, phase, num_stages)
        # TMA Load
        tma_load(mbar_tma, a_tma, b_tma, stage, iv, num_stages, wg_me.is_wg_primary)
        scf.yield_([phase])


def consumer_loop(
    mbar_tma: Mbarriers,
    mbar_mma: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    wg_me: Warpgroup,
    num_stages,
):
    begin_b = num_stages * get_type_size(a_tma.tma_memref)

    size_a = TILE_M * TILE_K * get_type_size(T.f16())

    phase = const(False, ty=T.bool())
    A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
    B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
    D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())

    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
    with ir.InsertionPoint(for_op.body):
        phase = for_op.inner_iter_args[1]
        iv = for_op.induction_variable
        stage = iv % num_stages

        # Wait for the TMA loads of the current stage
        mbar_tma[stage].try_wait(phase)

        # Find the shared memory slot of this stage
        offset_a = stage * size_a
        offset_b = offset_a + begin_b
        a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
        b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)

        # Bind the input tiles and the carried accumulator for this iteration
        A.update_smem(a_smem)
        B.update_smem(b_smem)
        D.update_accumulator(for_op.inner_iter_args[0])

        # Matrix Multiply
        D += A @ B

        # Signal the MMA barrier of the previous stage so the producer can reuse its slot
        p_arrive = (iv > 0) & wg_me.is_wg_primary
        with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
            barId = arith.select((stage == 0), const(num_stages - 1), (stage - 1))
            mbar_mma[barId].arrive()
            scf.yield_([])

        phase = switch_phase(stage, phase, num_stages)
        scf.yield_([D.acc_op, phase])

    nvvm.WgmmaWaitGroupSyncOp(0)
    D.update_accumulator(for_op.results[0])
    return D


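# The consumer hands a shared-memory slot back to the producer one iteration
# late: at iteration iv it signals the MMA barrier of the *previous* stage
# (wrapping to num_stages - 1 when stage == 0), which is exactly the barrier
# the producer waits on before overwriting that slot. A plain-Python sketch of
# that indexing (illustrative only, not called by the test):
def _mma_release_schedule_sketch(num_iterations, num_stages):
    released = []
    for iv in range(num_iterations):
        stage = iv % num_stages
        if iv > 0:  # same predicate as p_arrive in consumer_loop
            bar_id = num_stages - 1 if stage == 0 else stage - 1
            released.append((iv, bar_id))  # (iteration, MMA barrier signaled)
    return released

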
def epilogue(D: WGMMAMatrix, d_dev):
    """
    Epilogue of the GEMM kernel. It stores the fragmented accumulator registers
    to global memory through shared memory.

    MatrixAccumulator D               # Fragmented results
    store D -> Shared Memory          # Registers to shared memory
    Shared Memory -> d[dimX][dimY]    # Shared memory to global memory
    """
    tidx = gpu.thread_id(gpu.Dimension.x)
    dimX, dimY = partition_shape()
    # s = tidx - WARP_GROUP_SIZE
    # debug_print("[Epilogue] store to global memory @ s={}", s)

    d_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
    d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])

    # Store (registers -> shared memory)
    D.store_accumulator(d_smem)
    gpu.barrier()

    # Store (shared memory -> global memory)
    for i in scf.for_(0, TILE_M, 1):
        val = memref.load(d_smem, [i, tidx])
        memref.store(val, d_gmem, [i, tidx])
        scf.yield_([])


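# NumPy analogue of the epilogue above (illustrative only, not called by the
# test): each block writes its TILE_M x TILE_N f32 result into the output at
# row offset dimX and column offset dimY, one row per loop iteration and one
# column per thread of the consumer warpgroup.
def _epilogue_store_sketch(d_host, tile, dim_x, dim_y):
    tile_m, tile_n = tile.shape
    d_host[dim_x : dim_x + tile_m, dim_y : dim_y + tile_n] = tile
    return d_host

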
@NVDSL.mlir_func
def gemm_warp_specialized(a, b, d, num_stages):
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
    t7 = gpu.wait(token_ty, [t6])

    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
    a_tma = TMA([128, 64], a.type, swizzle=sw)
    b_tma = TMA([64, 64], b.type, swizzle=sw)
    a_tma.create_descriptor(a_dev)
    b_tma.create_descriptor(b_dev)

    grid = [(M // TILE_M), (N // TILE_N), 1]
    block = [256, 1, 1]

    size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
    size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
    smem_size_in_bytes = (size_a + size_b) * num_stages
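    # For the f16 128x128x64 tiles and num_stages=7 used below, smem_size_in_bytes
    # is (16384 + 16384) * 7 = 229376 bytes (224 KiB) of dynamic shared memory.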

    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
    def gemm_warp_specialized_kernel():
        # Init Warpgroups
        wg_producer = Warpgroup(primary_thread=128, register_size=40)
        wg_consumer = Warpgroup(primary_thread=0, register_size=232)

        # Initialize mbarriers and prefetch TMA descriptors
        mbar_tma, mbar_mma = initialize(a_tma, b_tma, num_stages)

        # Producer performs TMA
        with wg_producer:
            producer_loop(mbar_tma, mbar_mma, a_tma, b_tma, wg_producer, num_stages)

        # Consumer performs MMA/Tensor Core
        with wg_consumer:
            D = consumer_loop(mbar_tma, mbar_mma, a_tma, b_tma, wg_consumer, num_stages)
            epilogue(D, d_dev)

    gemm_warp_specialized_kernel()

    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
    gpu.wait(None, [t8])


# Python passes the GEMM problem sizes and input arrays to MLIR
N = 256
M = 512
K = 1024
TILE_M = 128
TILE_N = 128
TILE_K = 64
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)

gemm_warp_specialized(a, b, d, num_stages=7)


# Verify the MLIR result against a NumPy reference computation
ref_d = a.astype(np.float16) @ b.astype(np.float16)
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)


print("PASS")
# CHECK-NOT: Mismatched elements
322