xref: /llvm-project/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py (revision 13d6233e77982f2a596922a79365373e1466a968)
1d95e6d02SGuray Ozenimport numpy as np
2d95e6d02SGuray Ozenfrom mlir import ir
3d95e6d02SGuray Ozenfrom mlir.dialects import arith
4d95e6d02SGuray Ozenfrom mlir.dialects import func
5d95e6d02SGuray Ozenfrom mlir.dialects import gpu
6d95e6d02SGuray Ozenfrom mlir.dialects import memref
7d95e6d02SGuray Ozenfrom mlir.dialects import nvgpu
8d95e6d02SGuray Ozenfrom mlir.dialects import nvvm
9d95e6d02SGuray Ozenfrom mlir.dialects import llvm
10d95e6d02SGuray Ozenfrom mlir.dialects import builtin
11d95e6d02SGuray Ozenfrom mlir.dialects import scf
12d95e6d02SGuray Ozenfrom mlir.dialects import vector
13d95e6d02SGuray Ozenfrom mlir.extras import types as T
14d95e6d02SGuray Ozen
# Number of float16 elements along the innermost TMA box dimension:
# 64 elements * 2 bytes = 128 bytes.
TMA_LAST_DIM_F16 = 64
WARP_SIZE = 32
# A warpgroup is four consecutive warps (128 threads).
WARP_GROUP_SIZE = WARP_SIZE * 4

# Per-thread register limits handed to nvvm.setmaxregister: the producer (TMA)
# warpgroup decreases its limit, the consumer (MMA) warpgroup increases it.
PRODUCER_REGISTER_SIZE = 40
CONSUMER_REGISTER_SIZE = 232

# x thread ids of the single threads that issue TMA loads / mbarrier arrives
# for the producer warpgroup and the consumer warpgroup respectively.
PRODUCER_PRIMARY_THREAD = 128
CONSUMER_PRIMARY_THREAD = 0

# Sentinel for a dynamic dimension in MLIR shaped types (INT64_MIN, the value
# C++ uses to decide whether a dimension is dynamic).
MLIR_DYNAMIC = -9223372036854775808

# When True, debug_print emits device-side printf ops; otherwise it is a no-op
# unless called with forcePrint=True.
DEBUG = False
29d95e6d02SGuray Ozen
30d95e6d02SGuray Ozen
class TmaDescriptorBuilder:
    """Builds an nvgpu TMA tensor-map descriptor for a memref.

    The descriptor describes transfers of boxes shaped ``tma_box_shape`` out of
    a buffer of type ``memref_ty``; the box memref type uses memory space 3.
    """

    def __init__(self, swizzle, l2promo, oob, interleave, tma_box_shape, memref_ty):
        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
        # Shape of the box moved by a single TMA transfer.
        self.tma_box_shape = tma_box_shape
        # MemRefType of the whole source buffer.
        self.memref_ty = memref_ty

    @property
    def tensormap_descriptor_ty(self):
        """The nvgpu.TensorMapDescriptorType for this configuration."""
        box_memref_ty = ir.MemRefType.get(
            self.tma_box_shape,
            self.memref_ty.element_type,
            memory_space=ir.Attribute.parse("3"),
        )
        return nvgpu.TensorMapDescriptorType.get(
            box_memref_ty, self.swizzle, self.l2promo, self.oob, self.interleave
        )

    def tma_descriptor_op(self, device_ptr):
        """Emit a TMA create-descriptor op for ``device_ptr`` and return its result."""
        descriptor_ty = self.tensormap_descriptor_ty
        unranked_ty = ir.UnrankedMemRefType.get(
            self.memref_ty.element_type, self.memref_ty.memory_space
        )
        # The create-descriptor op takes an unranked memref, so cast first.
        unranked_ptr = memref.CastOp(unranked_ty, device_ptr)
        # One index constant per box dimension.
        box_dims = [c(extent) for extent in self.tma_box_shape]
        create_op = nvgpu.TmaCreateDescriptorOp(descriptor_ty, unranked_ptr, box_dims)
        return create_op.result
71c82f45f9SGuray Ozen
72c82f45f9SGuray Ozen
def _printf_format(ty):
    """Return the printf conversion debug_print uses for MLIR type ``ty``.

    Raises:
        NotImplementedError: if ``ty`` is not index, i64, i32, i1 or f32.
    """
    if ir.IndexType.isinstance(ty):
        return "%llu"
    if ir.IntegerType.isinstance(ty):
        width = ir.IntegerType(ty).width
        if width == 64:
            return "%llu"
        if width == 32:
            return "%d"
        if width == 1:
            return "%i"
    if ir.F32Type.isinstance(ty):
        return "%f"
    raise NotImplementedError(ty)


def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
    """Emit a device-side ``gpu.printf``, optionally guarded by a predicate.

    ``fmt`` is a ``str.format`` template with one ``{}`` per value in ``args``.
    Each ``{}`` is replaced by the printf conversion for that value's type
    (index/i64 -> ``%llu``, i32 -> ``%d``, i1 -> ``%i``, f32 -> ``%f``), and a
    newline is appended.

    Args:
        fmt: Format template, e.g. ``"iv={} stage={}"``.
        *args: MLIR values to print.
        predicate: Optional i1 value; if given, the print is wrapped in an
            ``scf.if`` on it. If None (and ``threadNumber`` is -1), the print is
            emitted unconditionally.
        threadNumber: If not -1, print only where the x thread id equals this
            value; this overrides ``predicate``.
        forcePrint: Emit the print even when the module-level DEBUG is False.

    Raises:
        NotImplementedError: if an argument's type has no known conversion.
    """
    if not DEBUG and not forcePrint:
        return
    type_formats = [_printf_format(arg.type) for arg in args]
    message = fmt.format(*type_formats) + "\n"
    if threadNumber != -1:
        tidx = gpu.thread_id(gpu.Dimension.x)
        predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
        # No scf.yield here: the insertion point is the caller's block, and a
        # terminator in the middle of it would make the IR invalid.
    if predicate is None:
        gpu.printf(message, args)
        return
    if_op = scf.IfOp(predicate)
    with ir.InsertionPoint(if_op.then_block):
        gpu.printf(message, args)
        scf.yield_([])
102d95e6d02SGuray Ozen
103d95e6d02SGuray Ozen
def get_type_size(ty):
    """Return the byte size of a scalar MLIR float or integer type.

    The size is the type's bit width floor-divided by 8.

    Raises:
        NotImplementedError: for any other kind of type.
    """
    for type_cls in (ir.FloatType, ir.IntegerType):
        if type_cls.isinstance(ty):
            return type_cls(ty).width // 8
    raise NotImplementedError(ty)
110d95e6d02SGuray Ozen
111d95e6d02SGuray Ozen
def get_mlir_ty(dtype):
    """Map a NumPy scalar type to the matching MLIR builtin type.

    Supported: float16, float32, float64, int32, int64.

    Raises:
        NotImplementedError: for any other dtype.
    """
    # (numpy type, name of the mlir.extras.types factory); checked in order.
    candidates = (
        (np.float16, "f16"),
        (np.float32, "f32"),
        (np.float64, "f64"),
        (np.int32, "i32"),
        (np.int64, "i64"),
    )
    for np_type, factory_name in candidates:
        if dtype == np_type:
            return getattr(T, factory_name)()
    raise NotImplementedError(dtype)
124d95e6d02SGuray Ozen
125d95e6d02SGuray Ozen
def c(value, ty=None):
    """Create an ``arith.constant`` holding ``value``.

    The constant has type ``ty`` when given, otherwise ``index``.
    """
    if ty is None:
        ty = T.index()
    return arith.constant(ty, value)
129d95e6d02SGuray Ozen
130d95e6d02SGuray Ozen
def make_kernel_name(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=128,
    num_stages=3,
    use_warp_specialization=False,
):
    """Build the symbol name of a generated matmul kernel.

    The name encodes the kernel flavour, problem size, tile size and number of
    pipeline stages, e.g. ``multistage_4096x4096x4096_128x128x128_3``.
    ``input_type`` and ``output_type`` are accepted but not encoded in the name.
    """
    flavour = "warpspecialized" if use_warp_specialization else "multistage"
    problem = "x".join(str(dim) for dim in (M, N, K))
    tile = "x".join(str(dim) for dim in (BLOCK_M, BLOCK_N, BLOCK_K))
    return "_".join((flavour, problem, tile, str(num_stages)))
161d95e6d02SGuray Ozen
162d95e6d02SGuray Ozen
163d95e6d02SGuray Ozendef generate_matmul_ws(
164d95e6d02SGuray Ozen    input_type=np.float16,
165d95e6d02SGuray Ozen    output_type=np.float32,
166d95e6d02SGuray Ozen    M=4096,
167d95e6d02SGuray Ozen    N=4096,
168d95e6d02SGuray Ozen    K=4096,
169d95e6d02SGuray Ozen    BLOCK_M=128,
170d95e6d02SGuray Ozen    BLOCK_N=128,
171d95e6d02SGuray Ozen    BLOCK_K=128,
172d95e6d02SGuray Ozen    num_stages=3,
173d95e6d02SGuray Ozen):
174d95e6d02SGuray Ozen    # Limitaitons for now
175d95e6d02SGuray Ozen    assert input_type == np.float16
176d95e6d02SGuray Ozen    assert output_type == np.float32
177d95e6d02SGuray Ozen    assert BLOCK_M == 128
178d95e6d02SGuray Ozen    assert BLOCK_N == 128
179d95e6d02SGuray Ozen    assert BLOCK_K == 64
180d95e6d02SGuray Ozen    assert M % BLOCK_M == 0
181d95e6d02SGuray Ozen    assert N % BLOCK_N == 0
182d95e6d02SGuray Ozen    assert K % BLOCK_K == 0
183d95e6d02SGuray Ozen
184d95e6d02SGuray Ozen    module = ir.Module.create()
185f8ff9094SGuray Ozen    token_ty = gpu.AsyncTokenType.get()
186d95e6d02SGuray Ozen    a_elem_ty = get_mlir_ty(input_type)
187d95e6d02SGuray Ozen    b_elem_ty = get_mlir_ty(input_type)
188d95e6d02SGuray Ozen    c_elem_ty = get_mlir_ty(output_type)
189d95e6d02SGuray Ozen    a_ty = ir.MemRefType.get([M, K], a_elem_ty)
190d95e6d02SGuray Ozen    b_ty = ir.MemRefType.get((K, N), b_elem_ty)
191d95e6d02SGuray Ozen    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
192d95e6d02SGuray Ozen    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
193d95e6d02SGuray Ozen    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
194d95e6d02SGuray Ozen    b_tile_shape = (BLOCK_K, BLOCK_N)
195d95e6d02SGuray Ozen    txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(a_elem_ty)) + (
196d95e6d02SGuray Ozen        a_tile_shape[0] * a_tile_shape[1] * get_type_size(b_elem_ty)
197d95e6d02SGuray Ozen    )
198d95e6d02SGuray Ozen    smem_space_str = "#gpu.address_space<workgroup>"
199d95e6d02SGuray Ozen    smem_space = ir.Attribute.parse(smem_space_str)
200d95e6d02SGuray Ozen    mbar_ty = ir.Type.parse(
201d95e6d02SGuray Ozen        "!nvgpu.mbarrier.group<memorySpace = "
202d95e6d02SGuray Ozen        + str(smem_space)
203d95e6d02SGuray Ozen        + ", num_barriers = "
204d95e6d02SGuray Ozen        + str(num_stages)
205d95e6d02SGuray Ozen        + ">"
206d95e6d02SGuray Ozen    )
207d95e6d02SGuray Ozen    acc_ty = ir.Type.parse(
208d95e6d02SGuray Ozen        "!nvgpu.warpgroup.accumulator<fragmented=vector<"
209d95e6d02SGuray Ozen        + str(BLOCK_M)
210d95e6d02SGuray Ozen        + "x"
211d95e6d02SGuray Ozen        + str(BLOCK_N)
212d95e6d02SGuray Ozen        + "x"
213d95e6d02SGuray Ozen        + str(c_elem_ty)
214d95e6d02SGuray Ozen        + ">>"
215d95e6d02SGuray Ozen    )
216d95e6d02SGuray Ozen    a_wgmma_ty = ir.Type.parse(
217d95e6d02SGuray Ozen        "!nvgpu.warpgroup.descriptor<tensor=memref<"
218d95e6d02SGuray Ozen        + str(BLOCK_M)
219d95e6d02SGuray Ozen        + "x"
220d95e6d02SGuray Ozen        + str(BLOCK_K)
221d95e6d02SGuray Ozen        + "x"
222d95e6d02SGuray Ozen        + str(a_elem_ty)
223d95e6d02SGuray Ozen        + ", "
224d95e6d02SGuray Ozen        + smem_space_str
225d95e6d02SGuray Ozen        + ">>"
226d95e6d02SGuray Ozen    )
227d95e6d02SGuray Ozen    b_wgmma_ty = ir.Type.parse(
228d95e6d02SGuray Ozen        "!nvgpu.warpgroup.descriptor<tensor=memref<"
229d95e6d02SGuray Ozen        + str(BLOCK_K)
230d95e6d02SGuray Ozen        + "x"
231d95e6d02SGuray Ozen        + str(BLOCK_N)
232d95e6d02SGuray Ozen        + "x"
233d95e6d02SGuray Ozen        + str(a_elem_ty)
234d95e6d02SGuray Ozen        + ", "
235d95e6d02SGuray Ozen        + smem_space_str
236d95e6d02SGuray Ozen        + ">>"
237d95e6d02SGuray Ozen    )
238d95e6d02SGuray Ozen    kernelName = make_kernel_name(
239d95e6d02SGuray Ozen        input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_stages, True
240d95e6d02SGuray Ozen    )
241d95e6d02SGuray Ozen    with ir.InsertionPoint(module.body):
242d95e6d02SGuray Ozen        fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
243d95e6d02SGuray Ozen        with ir.InsertionPoint(fop.add_entry_block()):
244d95e6d02SGuray Ozen            a_host = fop.arguments[0]
245d95e6d02SGuray Ozen            b_host = fop.arguments[1]
246d95e6d02SGuray Ozen            c_host = fop.arguments[2]
247d95e6d02SGuray Ozen            lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
248d95e6d02SGuray Ozen            rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
249d95e6d02SGuray Ozen            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
250d95e6d02SGuray Ozen            smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
251d95e6d02SGuray Ozen            smem_size = max(smem_size_input, smem_size_output)
252d95e6d02SGuray Ozen
253d95e6d02SGuray Ozen            # Step 1. Allocate device memory and memcpy
254d95e6d02SGuray Ozen            t1 = gpu.wait(token_ty, [])
255d95e6d02SGuray Ozen            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
256d95e6d02SGuray Ozen            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
257d95e6d02SGuray Ozen            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
258d95e6d02SGuray Ozen            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
259d95e6d02SGuray Ozen            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
260d95e6d02SGuray Ozen            t7 = gpu.wait(token_ty, [t6])
261d95e6d02SGuray Ozen
262d95e6d02SGuray Ozen            # Step 2. Create TMA Descriptors
263c82f45f9SGuray Ozen            a_tma_desc = TmaDescriptorBuilder(
264c82f45f9SGuray Ozen                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
265c82f45f9SGuray Ozen                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
266c82f45f9SGuray Ozen                nvgpu.TensorMapOOBKind.OOB_ZERO,
267c82f45f9SGuray Ozen                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
268c82f45f9SGuray Ozen                a_tma_shape,
269c82f45f9SGuray Ozen                a_ty,
270d95e6d02SGuray Ozen            )
271c82f45f9SGuray Ozen
272c82f45f9SGuray Ozen            b_tma_desc = TmaDescriptorBuilder(
273c82f45f9SGuray Ozen                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
274c82f45f9SGuray Ozen                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
275c82f45f9SGuray Ozen                nvgpu.TensorMapOOBKind.OOB_ZERO,
276c82f45f9SGuray Ozen                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
277c82f45f9SGuray Ozen                b_tma_shape,
278c82f45f9SGuray Ozen                b_ty,
279d95e6d02SGuray Ozen            )
280c82f45f9SGuray Ozen
281c82f45f9SGuray Ozen            a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
282c82f45f9SGuray Ozen            b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)
283d95e6d02SGuray Ozen
284d95e6d02SGuray Ozen            # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
285d95e6d02SGuray Ozen            cta_m = M // BLOCK_M
286d95e6d02SGuray Ozen            cta_n = N // BLOCK_N
287d95e6d02SGuray Ozen            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
288d95e6d02SGuray Ozen            grid = (cta_m, cta_n, 1)
289d95e6d02SGuray Ozen            block = (WARP_GROUP_SIZE * 2, 1, 1)
290d95e6d02SGuray Ozen            launch_op = gpu.LaunchOp(
291d95e6d02SGuray Ozen                token_ty,
292d95e6d02SGuray Ozen                [t7],
293d95e6d02SGuray Ozen                *map(c, grid),
294d95e6d02SGuray Ozen                *map(c, block),
295c82f45f9SGuray Ozen                dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
296d95e6d02SGuray Ozen            )
297d95e6d02SGuray Ozen            launch_op.body.blocks.append(*([T.index()] * 12))
298d95e6d02SGuray Ozen            with ir.InsertionPoint(launch_op.body.blocks[0]):
299d95e6d02SGuray Ozen                # GPU Step 0. This is need for vectorized ld/st
300d95e6d02SGuray Ozen                memref.assume_alignment(c_device, 16)
301d95e6d02SGuray Ozen                dynamic_smem = gpu.dynamic_shared_memory(
302d95e6d02SGuray Ozen                    ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
303d95e6d02SGuray Ozen                )
304d95e6d02SGuray Ozen                ticks = c(10000000)
305d95e6d02SGuray Ozen
306d95e6d02SGuray Ozen                # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups and etc.
307d95e6d02SGuray Ozen                tidx = gpu.thread_id(gpu.Dimension.x)
308d95e6d02SGuray Ozen                wgPrimaryThread = arith.cmpi(
309d95e6d02SGuray Ozen                    arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0)
310d95e6d02SGuray Ozen                )
311d95e6d02SGuray Ozen                warp_id = arith.divui(tidx, c(32))
312d95e6d02SGuray Ozen                warpgroup_id = arith.divui(warp_id, c(4))
313d95e6d02SGuray Ozen                is_producer = arith.cmpi(
314d95e6d02SGuray Ozen                    arith.CmpIPredicate.eq,
315d95e6d02SGuray Ozen                    warpgroup_id,
316d95e6d02SGuray Ozen                    c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0),
317d95e6d02SGuray Ozen                )
318d95e6d02SGuray Ozen                is_consumer = arith.cmpi(
319d95e6d02SGuray Ozen                    arith.CmpIPredicate.eq,
320d95e6d02SGuray Ozen                    warpgroup_id,
321d95e6d02SGuray Ozen                    c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1),
322d95e6d02SGuray Ozen                )
323d95e6d02SGuray Ozen                producerPrimaryThread = arith.cmpi(
324d95e6d02SGuray Ozen                    arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD)
325d95e6d02SGuray Ozen                )
326d95e6d02SGuray Ozen                consumerPrimaryThread = arith.cmpi(
327d95e6d02SGuray Ozen                    arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD)
328d95e6d02SGuray Ozen                )
329d95e6d02SGuray Ozen                bidx = gpu.block_id(gpu.Dimension.x)
330d95e6d02SGuray Ozen                bidy = gpu.block_id(gpu.Dimension.y)
331d95e6d02SGuray Ozen                dimX = arith.muli(bidx, c(BLOCK_M))
332d95e6d02SGuray Ozen                dimY = arith.muli(bidy, c(BLOCK_N))
333d95e6d02SGuray Ozen
334d95e6d02SGuray Ozen                # GPU Step 2. Initialize mbarrier groups
335d95e6d02SGuray Ozen                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
336d95e6d02SGuray Ozen                mbarDONE = nvgpu.mbarrier_create(mbar_ty)
337d95e6d02SGuray Ozen                for i in range(num_stages):
338d95e6d02SGuray Ozen                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
339d95e6d02SGuray Ozen                    nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
340d95e6d02SGuray Ozen                gpu.barrier()
341d95e6d02SGuray Ozen
342d95e6d02SGuray Ozen                # GPU Step 3. Prefetch TMA descriptors
343c82f45f9SGuray Ozen                nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=wgPrimaryThread)
344c82f45f9SGuray Ozen                nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=wgPrimaryThread)
345d95e6d02SGuray Ozen
346d95e6d02SGuray Ozen                ns = num_stages if num_stages == 1 else num_stages - 1
347d95e6d02SGuray Ozen                # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
348d95e6d02SGuray Ozen                with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
349d95e6d02SGuray Ozen                    # Step 5.1. Reduce register size
350d95e6d02SGuray Ozen                    nvvm.setmaxregister(
351d95e6d02SGuray Ozen                        PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease
352d95e6d02SGuray Ozen                    )
353d95e6d02SGuray Ozen
354d95e6d02SGuray Ozen                    # Step 5.2. TMA Main Loop
355d95e6d02SGuray Ozen                    for_op = scf.ForOp(
356d95e6d02SGuray Ozen                        c(0), c(K // BLOCK_K), c(1), [arith.constant(T.bool(), 1)]
357d95e6d02SGuray Ozen                    )
358d95e6d02SGuray Ozen                    with ir.InsertionPoint(for_op.body):
359d95e6d02SGuray Ozen                        phaseParity = for_op.inner_iter_args[0]
360d95e6d02SGuray Ozen                        iv = for_op.induction_variable
361d95e6d02SGuray Ozen                        stage = arith.remui(iv, c(num_stages))
362d95e6d02SGuray Ozen
363d95e6d02SGuray Ozen                        # Step 5.2.1. Wait mbarDONE
364d95e6d02SGuray Ozen                        debug_print(
365d95e6d02SGuray Ozen                            "[prod] iv={}  | mbarDONE[{}] try_wait  phase={}",
366d95e6d02SGuray Ozen                            iv,
367d95e6d02SGuray Ozen                            stage,
368d95e6d02SGuray Ozen                            phaseParity,
369d95e6d02SGuray Ozen                            predicate=producerPrimaryThread,
370d95e6d02SGuray Ozen                        )
371d95e6d02SGuray Ozen                        nvgpu.MBarrierTryWaitParityOp(
372d95e6d02SGuray Ozen                            mbarDONE, phaseParity, ticks, mbarId=stage
373d95e6d02SGuray Ozen                        )
374d95e6d02SGuray Ozen                        debug_print(
375d95e6d02SGuray Ozen                            "[prod] iv={}  | mbarDONE[{}] try_wait  phase={} [done]",
376d95e6d02SGuray Ozen                            iv,
377d95e6d02SGuray Ozen                            stage,
378d95e6d02SGuray Ozen                            phaseParity,
379d95e6d02SGuray Ozen                            predicate=producerPrimaryThread,
380d95e6d02SGuray Ozen                        )
381d95e6d02SGuray Ozen                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
382d95e6d02SGuray Ozen                        phaseParity = arith.select(
383d95e6d02SGuray Ozen                            p,
384d95e6d02SGuray Ozen                            arith.xori(phaseParity, arith.constant(T.bool(), 1)),
385d95e6d02SGuray Ozen                            phaseParity,
386d95e6d02SGuray Ozen                        )
387d95e6d02SGuray Ozen
388d95e6d02SGuray Ozen                        # Step 5.2.2. Load TMA
389d95e6d02SGuray Ozen                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
390d95e6d02SGuray Ozen                        a_tma_slice = memref.view(
391d95e6d02SGuray Ozen                            ir.MemRefType.get(
392d95e6d02SGuray Ozen                                a_tma_shape, a_elem_ty, memory_space=smem_space
393d95e6d02SGuray Ozen                            ),
394d95e6d02SGuray Ozen                            dynamic_smem,
395d95e6d02SGuray Ozen                            a_offset,
396d95e6d02SGuray Ozen                            [],
397d95e6d02SGuray Ozen                        )
398d95e6d02SGuray Ozen                        b_offset = arith.addi(
399d95e6d02SGuray Ozen                            arith.muli(stage, c(rhs_tile_bytes)),
400d95e6d02SGuray Ozen                            c(lhs_tile_bytes * num_stages),
401d95e6d02SGuray Ozen                        )
402d95e6d02SGuray Ozen                        b_tma_slice_1 = memref.view(
403d95e6d02SGuray Ozen                            ir.MemRefType.get(
404d95e6d02SGuray Ozen                                b_tma_shape, b_elem_ty, memory_space=smem_space
405d95e6d02SGuray Ozen                            ),
406d95e6d02SGuray Ozen                            dynamic_smem,
407d95e6d02SGuray Ozen                            b_offset,
408d95e6d02SGuray Ozen                            [],
409d95e6d02SGuray Ozen                        )
410d95e6d02SGuray Ozen                        b_offset2 = arith.addi(
411d95e6d02SGuray Ozen                            b_offset,
412d95e6d02SGuray Ozen                            c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
413d95e6d02SGuray Ozen                        )
414d95e6d02SGuray Ozen                        b_tma_slice_2 = memref.view(
415d95e6d02SGuray Ozen                            ir.MemRefType.get(
416d95e6d02SGuray Ozen                                b_tma_shape, b_elem_ty, memory_space=smem_space
417d95e6d02SGuray Ozen                            ),
418d95e6d02SGuray Ozen                            dynamic_smem,
419d95e6d02SGuray Ozen                            b_offset2,
420d95e6d02SGuray Ozen                            [],
421d95e6d02SGuray Ozen                        )
422d95e6d02SGuray Ozen                        debug_print(
423d95e6d02SGuray Ozen                            "[prod] a_offset={} b_offset={} b_offset2={}",
424d95e6d02SGuray Ozen                            a_offset,
425d95e6d02SGuray Ozen                            b_offset,
426d95e6d02SGuray Ozen                            b_offset2,
427d95e6d02SGuray Ozen                            predicate=producerPrimaryThread,
428d95e6d02SGuray Ozen                        )
429d95e6d02SGuray Ozen                        coord = arith.muli(c(64), iv)
430d95e6d02SGuray Ozen                        nvgpu.TmaAsyncLoadOp(
431d95e6d02SGuray Ozen                            a_tma_slice,
432d95e6d02SGuray Ozen                            mbarTMA,
433c82f45f9SGuray Ozen                            a_tma_desc_op,
434d95e6d02SGuray Ozen                            coordinates=[coord, dimX],
435d95e6d02SGuray Ozen                            mbarId=stage,
436d95e6d02SGuray Ozen                            predicate=producerPrimaryThread,
437d95e6d02SGuray Ozen                        )
438d95e6d02SGuray Ozen                        nvgpu.TmaAsyncLoadOp(
439d95e6d02SGuray Ozen                            b_tma_slice_1,
440d95e6d02SGuray Ozen                            mbarTMA,
441c82f45f9SGuray Ozen                            b_tma_desc_op,
442d95e6d02SGuray Ozen                            coordinates=[dimY, coord],
443d95e6d02SGuray Ozen                            mbarId=stage,
444d95e6d02SGuray Ozen                            predicate=producerPrimaryThread,
445d95e6d02SGuray Ozen                        )
446d95e6d02SGuray Ozen                        dimY2 = arith.addi(dimY, c(64))
447d95e6d02SGuray Ozen                        nvgpu.TmaAsyncLoadOp(
448d95e6d02SGuray Ozen                            b_tma_slice_2,
449d95e6d02SGuray Ozen                            mbarTMA,
450c82f45f9SGuray Ozen                            b_tma_desc_op,
451d95e6d02SGuray Ozen                            coordinates=[dimY2, coord],
452d95e6d02SGuray Ozen                            mbarId=stage,
453d95e6d02SGuray Ozen                            predicate=producerPrimaryThread,
454d95e6d02SGuray Ozen                        )
455d95e6d02SGuray Ozen
456d95e6d02SGuray Ozen                        # Step 5.2.3. Arrive mbarTMA
457d95e6d02SGuray Ozen                        debug_print(
458d95e6d02SGuray Ozen                            "[prod] iv={}  | mbarTMA[{}] arrive",
459d95e6d02SGuray Ozen                            iv,
460d95e6d02SGuray Ozen                            stage,
461d95e6d02SGuray Ozen                            predicate=producerPrimaryThread,
462d95e6d02SGuray Ozen                        )
463d95e6d02SGuray Ozen                        nvgpu.mbarrier_arrive_expect_tx(
464d95e6d02SGuray Ozen                            mbarTMA, c(txcount), stage, predicate=producerPrimaryThread
465d95e6d02SGuray Ozen                        )
466d95e6d02SGuray Ozen                        debug_print(
467d95e6d02SGuray Ozen                            "[prod] iv={}  | mbarTMA[{}] arrive [done]",
468d95e6d02SGuray Ozen                            iv,
469d95e6d02SGuray Ozen                            stage,
470d95e6d02SGuray Ozen                            predicate=producerPrimaryThread,
471d95e6d02SGuray Ozen                        )
472d95e6d02SGuray Ozen                        scf.yield_([phaseParity])
473d95e6d02SGuray Ozen                    scf.yield_([])
474d95e6d02SGuray Ozen
475d95e6d02SGuray Ozen                # GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
476d95e6d02SGuray Ozen                if_op = scf.IfOp(is_consumer)
477d95e6d02SGuray Ozen                with ir.InsertionPoint(if_op.then_block):
478d95e6d02SGuray Ozen                    # Step 6.1. Increase register size
479d95e6d02SGuray Ozen                    nvvm.setmaxregister(
480d95e6d02SGuray Ozen                        CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase
481d95e6d02SGuray Ozen                    )
482d95e6d02SGuray Ozen
483d95e6d02SGuray Ozen                    # GPU Step 6.2. Initialize MMA registers
484d95e6d02SGuray Ozen                    acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
485d95e6d02SGuray Ozen
486d95e6d02SGuray Ozen                    # Step 6.3. MMA Main Loop
487d95e6d02SGuray Ozen                    for_op = scf.ForOp(
488d95e6d02SGuray Ozen                        c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
489d95e6d02SGuray Ozen                    )
490d95e6d02SGuray Ozen                    with ir.InsertionPoint(for_op.body):
491d95e6d02SGuray Ozen                        # Step 6.3.1. Wait mbar1
492d95e6d02SGuray Ozen                        phaseParity = for_op.inner_iter_args[1]
493d95e6d02SGuray Ozen                        iv = for_op.induction_variable
494d95e6d02SGuray Ozen                        stage = arith.remui(iv, c(num_stages))
495d95e6d02SGuray Ozen                        debug_print(
496d95e6d02SGuray Ozen                            "[cons] iv={}  | mbarTMA[{}] try_wait   phase={}",
497d95e6d02SGuray Ozen                            iv,
498d95e6d02SGuray Ozen                            stage,
499d95e6d02SGuray Ozen                            phaseParity,
500d95e6d02SGuray Ozen                            predicate=consumerPrimaryThread,
501d95e6d02SGuray Ozen                        )
502d95e6d02SGuray Ozen                        nvgpu.MBarrierTryWaitParityOp(
503d95e6d02SGuray Ozen                            mbarTMA, phaseParity, ticks, mbarId=stage
504d95e6d02SGuray Ozen                        )
505d95e6d02SGuray Ozen                        debug_print(
506d95e6d02SGuray Ozen                            "[cons] iv={}  | mbarTMA[{}] try_wait   phase={} [done]",
507d95e6d02SGuray Ozen                            iv,
508d95e6d02SGuray Ozen                            stage,
509d95e6d02SGuray Ozen                            phaseParity,
510d95e6d02SGuray Ozen                            predicate=consumerPrimaryThread,
511d95e6d02SGuray Ozen                        )
512d95e6d02SGuray Ozen
513d95e6d02SGuray Ozen                        # Step 6.3.2. Create WGMMA Descriptors
514d95e6d02SGuray Ozen                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
515d95e6d02SGuray Ozen                        a_tile_slice = memref.view(
516d95e6d02SGuray Ozen                            ir.MemRefType.get(
517d95e6d02SGuray Ozen                                a_tile_shape, a_elem_ty, memory_space=smem_space
518d95e6d02SGuray Ozen                            ),
519d95e6d02SGuray Ozen                            dynamic_smem,
520d95e6d02SGuray Ozen                            a_offset,
521d95e6d02SGuray Ozen                            [],
522d95e6d02SGuray Ozen                        )
523d95e6d02SGuray Ozen                        b_offset = arith.addi(
524d95e6d02SGuray Ozen                            arith.muli(stage, c(rhs_tile_bytes)),
525d95e6d02SGuray Ozen                            c(lhs_tile_bytes * num_stages),
526d95e6d02SGuray Ozen                        )
527d95e6d02SGuray Ozen                        b_tile_slice = memref.view(
528d95e6d02SGuray Ozen                            ir.MemRefType.get(
529d95e6d02SGuray Ozen                                b_tile_shape, b_elem_ty, memory_space=smem_space
530d95e6d02SGuray Ozen                            ),
531d95e6d02SGuray Ozen                            dynamic_smem,
532d95e6d02SGuray Ozen                            b_offset,
533d95e6d02SGuray Ozen                            [],
534d95e6d02SGuray Ozen                        )
535d95e6d02SGuray Ozen                        debug_print(
536d95e6d02SGuray Ozen                            "[cons] a_offset={} b_offset={}",
537d95e6d02SGuray Ozen                            a_offset,
538d95e6d02SGuray Ozen                            b_offset,
539d95e6d02SGuray Ozen                            predicate=consumerPrimaryThread,
540d95e6d02SGuray Ozen                        )
541d95e6d02SGuray Ozen                        da = nvgpu.WarpgroupGenerateDescriptorOp(
542c82f45f9SGuray Ozen                            a_wgmma_ty, a_tile_slice, a_tma_desc_op
543d95e6d02SGuray Ozen                        )
544d95e6d02SGuray Ozen                        db = nvgpu.WarpgroupGenerateDescriptorOp(
545c82f45f9SGuray Ozen                            b_wgmma_ty, b_tile_slice, b_tma_desc_op
546d95e6d02SGuray Ozen                        )
547d95e6d02SGuray Ozen
548d95e6d02SGuray Ozen                        # Step 6.3.3. MMA
549d95e6d02SGuray Ozen                        carry_acc = for_op.inner_iter_args[0]
550d95e6d02SGuray Ozen                        new_acc = nvgpu.WarpgroupMmaOp(
551d95e6d02SGuray Ozen                            acc.type, da, db, carry_acc, transposeB=True
552d95e6d02SGuray Ozen                        )
553d95e6d02SGuray Ozen
554d95e6d02SGuray Ozen                        # Step 6.3.4. Arrive mbarDONE
555d95e6d02SGuray Ozen                        if num_stages == 1:
556d95e6d02SGuray Ozen                            p_arrive = consumerPrimaryThread
557d95e6d02SGuray Ozen                        else:
558d95e6d02SGuray Ozen                            p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
559d95e6d02SGuray Ozen                            p_arrive = arith.andi(consumerPrimaryThread, p1)
560d95e6d02SGuray Ozen                        with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
561d95e6d02SGuray Ozen                            p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
562d95e6d02SGuray Ozen                            barId = arith.select(
563d95e6d02SGuray Ozen                                p, c(num_stages - 1), arith.subi(stage, c(1))
564d95e6d02SGuray Ozen                            )
565d95e6d02SGuray Ozen                            debug_print(
566d95e6d02SGuray Ozen                                "[cons] iv={}  | mbarDONE[{}] arrive ",
567d95e6d02SGuray Ozen                                iv,
568d95e6d02SGuray Ozen                                barId,
569d95e6d02SGuray Ozen                                predicate=consumerPrimaryThread,
570d95e6d02SGuray Ozen                            )
571*13d6233eSDurgadoss R                            nvgpu.mbarrier_arrive(mbarDONE, barId)
572d95e6d02SGuray Ozen                            debug_print(
573d95e6d02SGuray Ozen                                "[cons] iv={}  | mbarDONE[{}] arrive [done]",
574d95e6d02SGuray Ozen                                iv,
575d95e6d02SGuray Ozen                                barId,
576d95e6d02SGuray Ozen                                predicate=consumerPrimaryThread,
577d95e6d02SGuray Ozen                            )
578d95e6d02SGuray Ozen                            scf.yield_([])
579d95e6d02SGuray Ozen
580d95e6d02SGuray Ozen                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
581d95e6d02SGuray Ozen                        phaseParity = arith.select(
582d95e6d02SGuray Ozen                            p,
583d95e6d02SGuray Ozen                            arith.xori(phaseParity, arith.constant(T.bool(), 1)),
584d95e6d02SGuray Ozen                            phaseParity,
585d95e6d02SGuray Ozen                        )
586d95e6d02SGuray Ozen
587d95e6d02SGuray Ozen                        # Step 6.3.5. Yield
588d95e6d02SGuray Ozen                        scf.yield_([new_acc, phaseParity])
589d95e6d02SGuray Ozen
590d95e6d02SGuray Ozen                    with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
591d95e6d02SGuray Ozen                        barId = c((K // BLOCK_K) % num_stages)
592*13d6233eSDurgadoss R                        nvgpu.mbarrier_arrive(mbarDONE, barId)
593d95e6d02SGuray Ozen                        scf.yield_([])
594d95e6d02SGuray Ozen
595d95e6d02SGuray Ozen                    # Step 6.4. Epilogue (registers --> shared memory)
596d95e6d02SGuray Ozen                    acc_smem_ty = ir.MemRefType.get(
597d95e6d02SGuray Ozen                        (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
598d95e6d02SGuray Ozen                    )
599d95e6d02SGuray Ozen                    acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
600d95e6d02SGuray Ozen                    debug_print("[cons]  | Storing", predicate=consumerPrimaryThread)
601d95e6d02SGuray Ozen                    nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
602d95e6d02SGuray Ozen                    scf.yield_([])
603d95e6d02SGuray Ozen                gpu.barrier()
604d95e6d02SGuray Ozen
605d95e6d02SGuray Ozen                # GPU Step 9. Epilogue (shared memory --> global memory)
606d95e6d02SGuray Ozen                fd = ir.MemRefType.get(
607d95e6d02SGuray Ozen                    [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
608d95e6d02SGuray Ozen                )
609d95e6d02SGuray Ozen                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
610d95e6d02SGuray Ozen                rty = ir.MemRefType.get(
611d95e6d02SGuray Ozen                    (BLOCK_M, BLOCK_N),
612d95e6d02SGuray Ozen                    c_elem_ty,
613d95e6d02SGuray Ozen                    ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
614d95e6d02SGuray Ozen                )
615d95e6d02SGuray Ozen                c_device_per_block = memref.SubViewOp(
616d95e6d02SGuray Ozen                    rty,
617d95e6d02SGuray Ozen                    c_device,
618d95e6d02SGuray Ozen                    [dimX, dimY],
619d95e6d02SGuray Ozen                    [],
620d95e6d02SGuray Ozen                    [],
621d95e6d02SGuray Ozen                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
622d95e6d02SGuray Ozen                    [BLOCK_M, BLOCK_N],
623d95e6d02SGuray Ozen                    [1, 1],
624d95e6d02SGuray Ozen                )
625d95e6d02SGuray Ozen                vlen = 1
626d95e6d02SGuray Ozen                for_op = scf.ForOp(
627d95e6d02SGuray Ozen                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2)
628d95e6d02SGuray Ozen                )
629d95e6d02SGuray Ozen                with ir.InsertionPoint(for_op.body):
630d95e6d02SGuray Ozen                    x = arith.divui(for_op.induction_variable, c(BLOCK_M))
631d95e6d02SGuray Ozen                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
632d95e6d02SGuray Ozen                    vdata = vector.load(
633d95e6d02SGuray Ozen                        ir.VectorType.get((vlen,), c_elem_ty),
634d95e6d02SGuray Ozen                        collapsed_smem,
635d95e6d02SGuray Ozen                        [for_op.induction_variable],
636d95e6d02SGuray Ozen                    )
637d95e6d02SGuray Ozen                    vector.store(vdata, c_device_per_block, [x, y])
638d95e6d02SGuray Ozen                    scf.yield_([])
639d95e6d02SGuray Ozen
640d95e6d02SGuray Ozen                gpu.terminator()
641d95e6d02SGuray Ozen
642d95e6d02SGuray Ozen            # Step 4. Copy back to host
643d95e6d02SGuray Ozen            t8 = gpu.wait(token_ty, [launch_op])
644d95e6d02SGuray Ozen            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
645d95e6d02SGuray Ozen            gpu.dealloc(token_ty, [t8], a_device)
646d95e6d02SGuray Ozen            gpu.dealloc(token_ty, [t8], b_device)
647d95e6d02SGuray Ozen            gpu.wait(token_ty, [t9])
648d95e6d02SGuray Ozen            gpu.dealloc(token_ty, [t8], c_device)
649d95e6d02SGuray Ozen            func.ReturnOp([])
650d95e6d02SGuray Ozen
651d95e6d02SGuray Ozen    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
652d95e6d02SGuray Ozen    module.operation.verify()
653d95e6d02SGuray Ozen    return module
654d95e6d02SGuray Ozen
655d95e6d02SGuray Ozen
656d95e6d02SGuray Ozendef generate_matmul_multistage(
657d95e6d02SGuray Ozen    input_type=np.float16,
658d95e6d02SGuray Ozen    output_type=np.float32,
659d95e6d02SGuray Ozen    M=4096,
660d95e6d02SGuray Ozen    N=4096,
661d95e6d02SGuray Ozen    K=4096,
662d95e6d02SGuray Ozen    BLOCK_M=128,
663d95e6d02SGuray Ozen    BLOCK_N=128,
664d95e6d02SGuray Ozen    BLOCK_K=64,
665d95e6d02SGuray Ozen    num_stages=3,
666d95e6d02SGuray Ozen):
667d95e6d02SGuray Ozen    # Limitaitons for now
668d95e6d02SGuray Ozen    assert input_type == np.float16
669d95e6d02SGuray Ozen    assert output_type == np.float32
670d95e6d02SGuray Ozen    assert BLOCK_M == 128
671d95e6d02SGuray Ozen    assert BLOCK_N == 128
672d95e6d02SGuray Ozen    assert BLOCK_K == 64
673d95e6d02SGuray Ozen    assert M % BLOCK_M == 0
674d95e6d02SGuray Ozen    assert N % BLOCK_N == 0
675d95e6d02SGuray Ozen    assert K % BLOCK_K == 0
676d95e6d02SGuray Ozen
677d95e6d02SGuray Ozen    module = ir.Module.create()
678f8ff9094SGuray Ozen    token_ty = gpu.AsyncTokenType.get()
679d95e6d02SGuray Ozen    a_elem_ty = get_mlir_ty(input_type)
680d95e6d02SGuray Ozen    b_elem_ty = get_mlir_ty(input_type)
681d95e6d02SGuray Ozen    c_elem_ty = get_mlir_ty(output_type)
682d95e6d02SGuray Ozen    a_ty = ir.MemRefType.get([M, K], a_elem_ty)
683d95e6d02SGuray Ozen    b_ty = ir.MemRefType.get((K, N), b_elem_ty)
684d95e6d02SGuray Ozen    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
685d95e6d02SGuray Ozen    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
686d95e6d02SGuray Ozen    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
687d95e6d02SGuray Ozen    b_tile_shape = (BLOCK_K, BLOCK_N)
688d95e6d02SGuray Ozen    txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(a_elem_ty)) + (
689d95e6d02SGuray Ozen        a_tile_shape[0] * a_tile_shape[1] * get_type_size(b_elem_ty)
690d95e6d02SGuray Ozen    )
691d95e6d02SGuray Ozen    smem_space_str = "#gpu.address_space<workgroup>"
692d95e6d02SGuray Ozen    smem_space = ir.Attribute.parse(smem_space_str)
693d95e6d02SGuray Ozen    mbar_ty = ir.Type.parse(
694d95e6d02SGuray Ozen        "!nvgpu.mbarrier.group<memorySpace = "
695d95e6d02SGuray Ozen        + str(smem_space)
696d95e6d02SGuray Ozen        + ", num_barriers = "
697d95e6d02SGuray Ozen        + str(num_stages)
698d95e6d02SGuray Ozen        + ">"
699d95e6d02SGuray Ozen    )
700d95e6d02SGuray Ozen    acc_ty = ir.Type.parse(
701d95e6d02SGuray Ozen        "!nvgpu.warpgroup.accumulator<fragmented=vector<"
702d95e6d02SGuray Ozen        + str(BLOCK_M)
703d95e6d02SGuray Ozen        + "x"
704d95e6d02SGuray Ozen        + str(BLOCK_N)
705d95e6d02SGuray Ozen        + "x"
706d95e6d02SGuray Ozen        + str(c_elem_ty)
707d95e6d02SGuray Ozen        + ">>"
708d95e6d02SGuray Ozen    )
709d95e6d02SGuray Ozen    a_wgmma_ty = ir.Type.parse(
710d95e6d02SGuray Ozen        "!nvgpu.warpgroup.descriptor<tensor=memref<"
711d95e6d02SGuray Ozen        + str(BLOCK_M)
712d95e6d02SGuray Ozen        + "x"
713d95e6d02SGuray Ozen        + str(BLOCK_K)
714d95e6d02SGuray Ozen        + "x"
715d95e6d02SGuray Ozen        + str(a_elem_ty)
716d95e6d02SGuray Ozen        + ", "
717d95e6d02SGuray Ozen        + smem_space_str
718d95e6d02SGuray Ozen        + ">>"
719d95e6d02SGuray Ozen    )
720d95e6d02SGuray Ozen    b_wgmma_ty = ir.Type.parse(
721d95e6d02SGuray Ozen        "!nvgpu.warpgroup.descriptor<tensor=memref<"
722d95e6d02SGuray Ozen        + str(BLOCK_K)
723d95e6d02SGuray Ozen        + "x"
724d95e6d02SGuray Ozen        + str(BLOCK_N)
725d95e6d02SGuray Ozen        + "x"
726d95e6d02SGuray Ozen        + str(a_elem_ty)
727d95e6d02SGuray Ozen        + ", "
728d95e6d02SGuray Ozen        + smem_space_str
729d95e6d02SGuray Ozen        + ">>"
730d95e6d02SGuray Ozen    )
731d95e6d02SGuray Ozen
732d95e6d02SGuray Ozen    with ir.InsertionPoint(module.body):
733d95e6d02SGuray Ozen        kernelName = make_kernel_name(
734d95e6d02SGuray Ozen            input_type,
735d95e6d02SGuray Ozen            output_type,
736d95e6d02SGuray Ozen            M,
737d95e6d02SGuray Ozen            N,
738d95e6d02SGuray Ozen            K,
739d95e6d02SGuray Ozen            BLOCK_M,
740d95e6d02SGuray Ozen            BLOCK_N,
741d95e6d02SGuray Ozen            BLOCK_K,
742d95e6d02SGuray Ozen            num_stages,
743d95e6d02SGuray Ozen            False,
744d95e6d02SGuray Ozen        )
745d95e6d02SGuray Ozen        fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
746d95e6d02SGuray Ozen        with ir.InsertionPoint(fop.add_entry_block()):
747d95e6d02SGuray Ozen            a_host = fop.arguments[0]
748d95e6d02SGuray Ozen            b_host = fop.arguments[1]
749d95e6d02SGuray Ozen            c_host = fop.arguments[2]
750d95e6d02SGuray Ozen            lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
751d95e6d02SGuray Ozen            rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
752d95e6d02SGuray Ozen            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
753d95e6d02SGuray Ozen            smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
754d95e6d02SGuray Ozen            smem_size = max(smem_size_input, smem_size_output)
755d95e6d02SGuray Ozen
756d95e6d02SGuray Ozen            # Step 1. Allocate device memory and memcpy
757d95e6d02SGuray Ozen            t1 = gpu.wait(token_ty, [])
758d95e6d02SGuray Ozen            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
759d95e6d02SGuray Ozen            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
760d95e6d02SGuray Ozen            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
761d95e6d02SGuray Ozen            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
762d95e6d02SGuray Ozen            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
763d95e6d02SGuray Ozen            t7 = gpu.wait(token_ty, [t6])
764d95e6d02SGuray Ozen
765d95e6d02SGuray Ozen            # Step 2. Create TMA Descriptors
766c82f45f9SGuray Ozen            a_tma_desc = TmaDescriptorBuilder(
767c82f45f9SGuray Ozen                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
768c82f45f9SGuray Ozen                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
769c82f45f9SGuray Ozen                nvgpu.TensorMapOOBKind.OOB_ZERO,
770c82f45f9SGuray Ozen                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
771c82f45f9SGuray Ozen                a_tma_shape,
772c82f45f9SGuray Ozen                a_ty,
773d95e6d02SGuray Ozen            )
774c82f45f9SGuray Ozen
775c82f45f9SGuray Ozen            b_tma_desc = TmaDescriptorBuilder(
776c82f45f9SGuray Ozen                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
777c82f45f9SGuray Ozen                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
778c82f45f9SGuray Ozen                nvgpu.TensorMapOOBKind.OOB_ZERO,
779c82f45f9SGuray Ozen                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
780c82f45f9SGuray Ozen                b_tma_shape,
781c82f45f9SGuray Ozen                b_ty,
782d95e6d02SGuray Ozen            )
783c82f45f9SGuray Ozen
784c82f45f9SGuray Ozen            a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
785c82f45f9SGuray Ozen            b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)
786d95e6d02SGuray Ozen
787d95e6d02SGuray Ozen            # Step 3. Launch Kernel with 1 Warpgroup
788d95e6d02SGuray Ozen            cta_m = M // BLOCK_M
789d95e6d02SGuray Ozen            cta_n = N // BLOCK_N
790d95e6d02SGuray Ozen            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
791d95e6d02SGuray Ozen            grid = (cta_m, cta_n, 1)
792d95e6d02SGuray Ozen            block = (WARP_GROUP_SIZE, 1, 1)
793d95e6d02SGuray Ozen            launch_op = gpu.LaunchOp(
794d95e6d02SGuray Ozen                token_ty,
795d95e6d02SGuray Ozen                [t7],
796d95e6d02SGuray Ozen                *map(c, grid),
797d95e6d02SGuray Ozen                *map(c, block),
798c82f45f9SGuray Ozen                dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
799d95e6d02SGuray Ozen            )
800d95e6d02SGuray Ozen            launch_op.body.blocks.append(*([T.index()] * 12))
801d95e6d02SGuray Ozen            with ir.InsertionPoint(launch_op.body.blocks[0]):
802d95e6d02SGuray Ozen                # GPU Step 0. Bootstrapping
803d95e6d02SGuray Ozen                memref.assume_alignment(c_device, 16)
804d95e6d02SGuray Ozen                dynamic_smem = gpu.dynamic_shared_memory(
805d95e6d02SGuray Ozen                    ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
806d95e6d02SGuray Ozen                )
807d95e6d02SGuray Ozen                ticks = c(10000000)
808d95e6d02SGuray Ozen                tidx = gpu.thread_id(gpu.Dimension.x)
809d95e6d02SGuray Ozen                primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
810d95e6d02SGuray Ozen                warpId = arith.divui(tidx, c(32))
811d95e6d02SGuray Ozen                bidx = gpu.block_id(gpu.Dimension.x)
812d95e6d02SGuray Ozen                bidy = gpu.block_id(gpu.Dimension.y)
813d95e6d02SGuray Ozen                dimX = arith.muli(bidx, c(BLOCK_M))
814d95e6d02SGuray Ozen                dimY = arith.muli(bidy, c(BLOCK_N))
815d95e6d02SGuray Ozen
816d95e6d02SGuray Ozen                # GPU Step 1. Initialize mbarrier groups
817d95e6d02SGuray Ozen                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
818d95e6d02SGuray Ozen                for i in range(num_stages):
819d95e6d02SGuray Ozen                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
820d95e6d02SGuray Ozen                gpu.barrier()
821d95e6d02SGuray Ozen
822d95e6d02SGuray Ozen                # GPU Step 2. Prefetch TMA descriptors
823c82f45f9SGuray Ozen                nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=primaryThread)
824c82f45f9SGuray Ozen                nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=primaryThread)
825d95e6d02SGuray Ozen
826d95e6d02SGuray Ozen                # GPU Step 3. Prologue (global memory --> shared memory)
827d95e6d02SGuray Ozen                ns = num_stages if num_stages == 1 else num_stages - 1
828d95e6d02SGuray Ozen                for_op = scf.ForOp(c(0), c(ns), c(1))
829d95e6d02SGuray Ozen                with ir.InsertionPoint(for_op.body):
830d95e6d02SGuray Ozen                    iv = for_op.induction_variable
831d95e6d02SGuray Ozen
832d95e6d02SGuray Ozen                    # Step 3.1. Calculate offsets
833d95e6d02SGuray Ozen                    a_offset = arith.muli(iv, c(lhs_tile_bytes))
834d95e6d02SGuray Ozen                    a_tma_slice = memref.view(
835d95e6d02SGuray Ozen                        ir.MemRefType.get(
836d95e6d02SGuray Ozen                            a_tma_shape, a_elem_ty, memory_space=smem_space
837d95e6d02SGuray Ozen                        ),
838d95e6d02SGuray Ozen                        dynamic_smem,
839d95e6d02SGuray Ozen                        a_offset,
840d95e6d02SGuray Ozen                        [],
841d95e6d02SGuray Ozen                    )
842d95e6d02SGuray Ozen                    b_offset = arith.addi(
843d95e6d02SGuray Ozen                        arith.muli(iv, c(rhs_tile_bytes)),
844d95e6d02SGuray Ozen                        c(lhs_tile_bytes * num_stages),
845d95e6d02SGuray Ozen                    )
846d95e6d02SGuray Ozen                    b_tma_slice_1 = memref.view(
847d95e6d02SGuray Ozen                        ir.MemRefType.get(
848d95e6d02SGuray Ozen                            b_tma_shape, b_elem_ty, memory_space=smem_space
849d95e6d02SGuray Ozen                        ),
850d95e6d02SGuray Ozen                        dynamic_smem,
851d95e6d02SGuray Ozen                        b_offset,
852d95e6d02SGuray Ozen                        [],
853d95e6d02SGuray Ozen                    )
854d95e6d02SGuray Ozen                    b_offset2 = arith.addi(
855d95e6d02SGuray Ozen                        b_offset,
856d95e6d02SGuray Ozen                        c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
857d95e6d02SGuray Ozen                    )
858d95e6d02SGuray Ozen                    b_tma_slice_2 = memref.view(
859d95e6d02SGuray Ozen                        ir.MemRefType.get(
860d95e6d02SGuray Ozen                            b_tma_shape, b_elem_ty, memory_space=smem_space
861d95e6d02SGuray Ozen                        ),
862d95e6d02SGuray Ozen                        dynamic_smem,
863d95e6d02SGuray Ozen                        b_offset2,
864d95e6d02SGuray Ozen                        [],
865d95e6d02SGuray Ozen                    )
866d95e6d02SGuray Ozen
867d95e6d02SGuray Ozen                    # Step 3.2. TMA Load
868d95e6d02SGuray Ozen                    coord = arith.muli(c(64), iv)
869d95e6d02SGuray Ozen                    dimY2 = arith.addi(dimY, c(64))
870d95e6d02SGuray Ozen                    debug_print(
871d95e6d02SGuray Ozen                        "[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
872d95e6d02SGuray Ozen                        a_offset,
873d95e6d02SGuray Ozen                        b_offset,
874d95e6d02SGuray Ozen                        b_offset2,
875d95e6d02SGuray Ozen                        coord,
876d95e6d02SGuray Ozen                        dimX,
877d95e6d02SGuray Ozen                        dimY,
878d95e6d02SGuray Ozen                        coord,
879d95e6d02SGuray Ozen                        predicate=primaryThread,
880d95e6d02SGuray Ozen                    )
881d95e6d02SGuray Ozen                    nvgpu.TmaAsyncLoadOp(
882d95e6d02SGuray Ozen                        a_tma_slice,
883d95e6d02SGuray Ozen                        mbarTMA,
884c82f45f9SGuray Ozen                        a_tma_desc_op,
885d95e6d02SGuray Ozen                        coordinates=[coord, dimX],
886d95e6d02SGuray Ozen                        mbarId=iv,
887d95e6d02SGuray Ozen                        predicate=primaryThread,
888d95e6d02SGuray Ozen                    )
889d95e6d02SGuray Ozen                    nvgpu.TmaAsyncLoadOp(
890d95e6d02SGuray Ozen                        b_tma_slice_1,
891d95e6d02SGuray Ozen                        mbarTMA,
892c82f45f9SGuray Ozen                        b_tma_desc_op,
893d95e6d02SGuray Ozen                        coordinates=[dimY, coord],
894d95e6d02SGuray Ozen                        mbarId=iv,
895d95e6d02SGuray Ozen                        predicate=primaryThread,
896d95e6d02SGuray Ozen                    )
897d95e6d02SGuray Ozen                    nvgpu.TmaAsyncLoadOp(
898d95e6d02SGuray Ozen                        b_tma_slice_2,
899d95e6d02SGuray Ozen                        mbarTMA,
900c82f45f9SGuray Ozen                        b_tma_desc_op,
901d95e6d02SGuray Ozen                        coordinates=[dimY2, coord],
902d95e6d02SGuray Ozen                        mbarId=iv,
903d95e6d02SGuray Ozen                        predicate=primaryThread,
904d95e6d02SGuray Ozen                    )
905d95e6d02SGuray Ozen
906d95e6d02SGuray Ozen                    # Step 3.2. mbarTMA arrive
907d95e6d02SGuray Ozen                    debug_print(
908d95e6d02SGuray Ozen                        "[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread
909d95e6d02SGuray Ozen                    )
910d95e6d02SGuray Ozen                    nvgpu.mbarrier_arrive_expect_tx(
911d95e6d02SGuray Ozen                        mbarTMA, c(txcount), iv, predicate=primaryThread
912d95e6d02SGuray Ozen                    )
913d95e6d02SGuray Ozen                    debug_print(
914d95e6d02SGuray Ozen                        "[Prologue] mbarTMA[{}] arrive [done]",
915d95e6d02SGuray Ozen                        iv,
916d95e6d02SGuray Ozen                        predicate=primaryThread,
917d95e6d02SGuray Ozen                    )
918d95e6d02SGuray Ozen                    scf.yield_([])
919d95e6d02SGuray Ozen
920d95e6d02SGuray Ozen                # GPU Step 4. Main Loop
921d95e6d02SGuray Ozen                acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
922d95e6d02SGuray Ozen                for_op = scf.ForOp(
923d95e6d02SGuray Ozen                    c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
924d95e6d02SGuray Ozen                )
925d95e6d02SGuray Ozen                with ir.InsertionPoint(for_op.body):
926d95e6d02SGuray Ozen                    # Step 4.1. Wait mbarTMA
927d95e6d02SGuray Ozen                    phaseParity = for_op.inner_iter_args[1]
928d95e6d02SGuray Ozen                    iv = for_op.induction_variable
929d95e6d02SGuray Ozen                    stage = arith.remui(iv, c(num_stages))
930d95e6d02SGuray Ozen                    debug_print(
931d95e6d02SGuray Ozen                        "[MainLoop] mbarTMA[{}] try_wait   phase={}",
932d95e6d02SGuray Ozen                        stage,
933d95e6d02SGuray Ozen                        phaseParity,
934d95e6d02SGuray Ozen                        predicate=primaryThread,
935d95e6d02SGuray Ozen                    )
936d95e6d02SGuray Ozen                    nvgpu.MBarrierTryWaitParityOp(
937d95e6d02SGuray Ozen                        mbarTMA, phaseParity, ticks, mbarId=stage
938d95e6d02SGuray Ozen                    )
939d95e6d02SGuray Ozen                    debug_print(
940d95e6d02SGuray Ozen                        "[MainLoop] mbarTMA[{}] try_wait   phase={} [done]",
941d95e6d02SGuray Ozen                        stage,
942d95e6d02SGuray Ozen                        phaseParity,
943d95e6d02SGuray Ozen                        predicate=primaryThread,
944d95e6d02SGuray Ozen                    )
945d95e6d02SGuray Ozen
946d95e6d02SGuray Ozen                    # Step 4.2. Create WGMMA Descriptors
947d95e6d02SGuray Ozen                    a_offset = arith.muli(stage, c(lhs_tile_bytes))
948d95e6d02SGuray Ozen                    a_tile_slice = memref.view(
949d95e6d02SGuray Ozen                        ir.MemRefType.get(
950d95e6d02SGuray Ozen                            a_tile_shape, a_elem_ty, memory_space=smem_space
951d95e6d02SGuray Ozen                        ),
952d95e6d02SGuray Ozen                        dynamic_smem,
953d95e6d02SGuray Ozen                        a_offset,
954d95e6d02SGuray Ozen                        [],
955d95e6d02SGuray Ozen                    )
956d95e6d02SGuray Ozen                    b_offset = arith.addi(
957d95e6d02SGuray Ozen                        arith.muli(stage, c(rhs_tile_bytes)),
958d95e6d02SGuray Ozen                        c(lhs_tile_bytes * num_stages),
959d95e6d02SGuray Ozen                    )
960d95e6d02SGuray Ozen                    b_tile_slice = memref.view(
961d95e6d02SGuray Ozen                        ir.MemRefType.get(
962d95e6d02SGuray Ozen                            b_tile_shape, b_elem_ty, memory_space=smem_space
963d95e6d02SGuray Ozen                        ),
964d95e6d02SGuray Ozen                        dynamic_smem,
965d95e6d02SGuray Ozen                        b_offset,
966d95e6d02SGuray Ozen                        [],
967d95e6d02SGuray Ozen                    )
968d95e6d02SGuray Ozen                    debug_print(
969d95e6d02SGuray Ozen                        "[MainLoop] iv={} MMA a_offset={} b_offset={}",
970d95e6d02SGuray Ozen                        iv,
971d95e6d02SGuray Ozen                        a_offset,
972d95e6d02SGuray Ozen                        b_offset,
973d95e6d02SGuray Ozen                        predicate=primaryThread,
974d95e6d02SGuray Ozen                    )
975d95e6d02SGuray Ozen                    da = nvgpu.WarpgroupGenerateDescriptorOp(
976c82f45f9SGuray Ozen                        a_wgmma_ty, a_tile_slice, a_tma_desc_op
977d95e6d02SGuray Ozen                    )
978d95e6d02SGuray Ozen                    db = nvgpu.WarpgroupGenerateDescriptorOp(
979c82f45f9SGuray Ozen                        b_wgmma_ty, b_tile_slice, b_tma_desc_op
980d95e6d02SGuray Ozen                    )
981d95e6d02SGuray Ozen
982d95e6d02SGuray Ozen                    # Step 4.3. MMA
983d95e6d02SGuray Ozen                    carry_acc = for_op.inner_iter_args[0]
984d95e6d02SGuray Ozen                    new_acc = nvgpu.WarpgroupMmaOp(
985d95e6d02SGuray Ozen                        acc.type, da, db, carry_acc, transposeB=True
986d95e6d02SGuray Ozen                    )
987d95e6d02SGuray Ozen                    if num_stages == 1:
988d95e6d02SGuray Ozen                        nvvm.WgmmaWaitGroupSyncOp(0)
989d95e6d02SGuray Ozen
990d95e6d02SGuray Ozen                    # Step 4.4. Load TMA for next stage
991d95e6d02SGuray Ozen                    p1 = arith.cmpi(
992d95e6d02SGuray Ozen                        arith.CmpIPredicate.ult,
993d95e6d02SGuray Ozen                        arith.addi(iv, c(ns)),
994d95e6d02SGuray Ozen                        c(K // BLOCK_K),
995d95e6d02SGuray Ozen                    )
996d95e6d02SGuray Ozen                    p = arith.andi(primaryThread, p1)
997d95e6d02SGuray Ozen                    nextStage = arith.addi(iv, c(ns))
998d95e6d02SGuray Ozen                    nextSlot = arith.remui(nextStage, c(num_stages))
999d95e6d02SGuray Ozen                    a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
1000d95e6d02SGuray Ozen
1001d95e6d02SGuray Ozen                    debug_print(
1002d95e6d02SGuray Ozen                        "[MainLoop] mbarTMA[{}] arrive",
1003d95e6d02SGuray Ozen                        nextSlot,
1004d95e6d02SGuray Ozen                        predicate=p,
1005d95e6d02SGuray Ozen                    )
1006d95e6d02SGuray Ozen                    nvgpu.mbarrier_arrive_expect_tx(
1007d95e6d02SGuray Ozen                        mbarTMA, c(txcount), nextSlot, predicate=p
1008d95e6d02SGuray Ozen                    )
1009d95e6d02SGuray Ozen                    debug_print(
1010d95e6d02SGuray Ozen                        "[MainLoop] mbarTMA[{}] arrive [done]",
1011d95e6d02SGuray Ozen                        nextSlot,
1012d95e6d02SGuray Ozen                        predicate=p,
1013d95e6d02SGuray Ozen                    )
1014d95e6d02SGuray Ozen
1015d95e6d02SGuray Ozen                    a_tma_slice = memref.view(
1016d95e6d02SGuray Ozen                        ir.MemRefType.get(
1017d95e6d02SGuray Ozen                            a_tma_shape, a_elem_ty, memory_space=smem_space
1018d95e6d02SGuray Ozen                        ),
1019d95e6d02SGuray Ozen                        dynamic_smem,
1020d95e6d02SGuray Ozen                        a_offset,
1021d95e6d02SGuray Ozen                        [],
1022d95e6d02SGuray Ozen                    )
1023d95e6d02SGuray Ozen                    b_offset = arith.addi(
1024d95e6d02SGuray Ozen                        arith.muli(nextSlot, c(rhs_tile_bytes)),
1025d95e6d02SGuray Ozen                        c(lhs_tile_bytes * num_stages),
1026d95e6d02SGuray Ozen                    )
1027d95e6d02SGuray Ozen                    b_tma_slice_1 = memref.view(
1028d95e6d02SGuray Ozen                        ir.MemRefType.get(
1029d95e6d02SGuray Ozen                            b_tma_shape, b_elem_ty, memory_space=smem_space
1030d95e6d02SGuray Ozen                        ),
1031d95e6d02SGuray Ozen                        dynamic_smem,
1032d95e6d02SGuray Ozen                        b_offset,
1033d95e6d02SGuray Ozen                        [],
1034d95e6d02SGuray Ozen                    )
1035d95e6d02SGuray Ozen                    b_offset2 = arith.addi(
1036d95e6d02SGuray Ozen                        b_offset,
1037d95e6d02SGuray Ozen                        c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
1038d95e6d02SGuray Ozen                    )
1039d95e6d02SGuray Ozen                    b_tma_slice_2 = memref.view(
1040d95e6d02SGuray Ozen                        ir.MemRefType.get(
1041d95e6d02SGuray Ozen                            b_tma_shape, b_elem_ty, memory_space=smem_space
1042d95e6d02SGuray Ozen                        ),
1043d95e6d02SGuray Ozen                        dynamic_smem,
1044d95e6d02SGuray Ozen                        b_offset2,
1045d95e6d02SGuray Ozen                        [],
1046d95e6d02SGuray Ozen                    )
1047d95e6d02SGuray Ozen
1048d95e6d02SGuray Ozen                    coord = arith.muli(c(64), nextStage)
1049d95e6d02SGuray Ozen                    debug_print(
1050d95e6d02SGuray Ozen                        "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
1051d95e6d02SGuray Ozen                        iv,
1052d95e6d02SGuray Ozen                        a_offset,
1053d95e6d02SGuray Ozen                        b_offset,
1054d95e6d02SGuray Ozen                        b_offset2,
1055d95e6d02SGuray Ozen                        coord,
1056d95e6d02SGuray Ozen                        dimX,
1057d95e6d02SGuray Ozen                        dimY,
1058d95e6d02SGuray Ozen                        coord,
1059d95e6d02SGuray Ozen                        predicate=p,
1060d95e6d02SGuray Ozen                    )
1061d95e6d02SGuray Ozen                    nvgpu.TmaAsyncLoadOp(
1062d95e6d02SGuray Ozen                        a_tma_slice,
1063d95e6d02SGuray Ozen                        mbarTMA,
1064c82f45f9SGuray Ozen                        a_tma_desc_op,
1065d95e6d02SGuray Ozen                        coordinates=[coord, dimX],
1066d95e6d02SGuray Ozen                        mbarId=nextSlot,
1067d95e6d02SGuray Ozen                        predicate=p,
1068d95e6d02SGuray Ozen                    )
1069d95e6d02SGuray Ozen                    nvgpu.TmaAsyncLoadOp(
1070d95e6d02SGuray Ozen                        b_tma_slice_1,
1071d95e6d02SGuray Ozen                        mbarTMA,
1072c82f45f9SGuray Ozen                        b_tma_desc_op,
1073d95e6d02SGuray Ozen                        coordinates=[dimY, coord],
1074d95e6d02SGuray Ozen                        mbarId=nextSlot,
1075d95e6d02SGuray Ozen                        predicate=p,
1076d95e6d02SGuray Ozen                    )
1077d95e6d02SGuray Ozen                    dimY2 = arith.addi(dimY, c(64))
1078d95e6d02SGuray Ozen                    nvgpu.TmaAsyncLoadOp(
1079d95e6d02SGuray Ozen                        b_tma_slice_2,
1080d95e6d02SGuray Ozen                        mbarTMA,
1081c82f45f9SGuray Ozen                        b_tma_desc_op,
1082d95e6d02SGuray Ozen                        coordinates=[dimY2, coord],
1083d95e6d02SGuray Ozen                        mbarId=nextSlot,
1084d95e6d02SGuray Ozen                        predicate=p,
1085d95e6d02SGuray Ozen                    )
1086d95e6d02SGuray Ozen                    # Step 4.5. Change the phaseParity
1087d95e6d02SGuray Ozen                    p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
1088d95e6d02SGuray Ozen                    phaseParity = arith.select(
1089d95e6d02SGuray Ozen                        p,
1090d95e6d02SGuray Ozen                        arith.xori(phaseParity, arith.constant(T.bool(), 1)),
1091d95e6d02SGuray Ozen                        phaseParity,
1092d95e6d02SGuray Ozen                    )
1093d95e6d02SGuray Ozen
1094d95e6d02SGuray Ozen                    # Step 4.5. Yield
1095d95e6d02SGuray Ozen                    scf.yield_([new_acc, phaseParity])
1096d95e6d02SGuray Ozen
1097d95e6d02SGuray Ozen                # Step 5. Wait All WGMMA groups
1098d95e6d02SGuray Ozen                nvvm.WgmmaWaitGroupSyncOp(0)
1099d95e6d02SGuray Ozen
1100d95e6d02SGuray Ozen                # Step 6. Epilogue (registers --> shared memory)
1101d95e6d02SGuray Ozen                acc_smem_ty = ir.MemRefType.get(
1102d95e6d02SGuray Ozen                    (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
1103d95e6d02SGuray Ozen                )
1104d95e6d02SGuray Ozen                acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
1105d95e6d02SGuray Ozen                debug_print("Storing", predicate=primaryThread)
1106d95e6d02SGuray Ozen                nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
1107d95e6d02SGuray Ozen                gpu.barrier()
1108d95e6d02SGuray Ozen
1109d95e6d02SGuray Ozen                # GPU Step 7. Epilogue (shared memory --> global memory)
1110d95e6d02SGuray Ozen                fd = ir.MemRefType.get(
1111d95e6d02SGuray Ozen                    [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
1112d95e6d02SGuray Ozen                )
1113d95e6d02SGuray Ozen                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
1114d95e6d02SGuray Ozen                rty = ir.MemRefType.get(
1115d95e6d02SGuray Ozen                    (BLOCK_M, BLOCK_N),
1116d95e6d02SGuray Ozen                    c_elem_ty,
1117d95e6d02SGuray Ozen                    ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
1118d95e6d02SGuray Ozen                )
1119d95e6d02SGuray Ozen                c_device_per_block = memref.SubViewOp(
1120d95e6d02SGuray Ozen                    rty,
1121d95e6d02SGuray Ozen                    c_device,
1122d95e6d02SGuray Ozen                    [dimX, dimY],
1123d95e6d02SGuray Ozen                    [],
1124d95e6d02SGuray Ozen                    [],
1125d95e6d02SGuray Ozen                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
1126d95e6d02SGuray Ozen                    [BLOCK_M, BLOCK_N],
1127d95e6d02SGuray Ozen                    [1, 1],
1128d95e6d02SGuray Ozen                )
1129d95e6d02SGuray Ozen                vlen = 1
1130d95e6d02SGuray Ozen                for_op = scf.ForOp(
1131d95e6d02SGuray Ozen                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE)
1132d95e6d02SGuray Ozen                )
1133d95e6d02SGuray Ozen                with ir.InsertionPoint(for_op.body):
1134d95e6d02SGuray Ozen                    x = arith.divui(for_op.induction_variable, c(BLOCK_M))
1135d95e6d02SGuray Ozen                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
1136d95e6d02SGuray Ozen                    vdata = vector.load(
1137d95e6d02SGuray Ozen                        ir.VectorType.get((vlen,), c_elem_ty),
1138d95e6d02SGuray Ozen                        collapsed_smem,
1139d95e6d02SGuray Ozen                        [for_op.induction_variable],
1140d95e6d02SGuray Ozen                    )
1141d95e6d02SGuray Ozen                    vector.store(vdata, c_device_per_block, [x, y])
1142d95e6d02SGuray Ozen                    scf.yield_([])
1143d95e6d02SGuray Ozen
1144d95e6d02SGuray Ozen                gpu.terminator()
1145d95e6d02SGuray Ozen
1146d95e6d02SGuray Ozen            # Step 4. Copy back to host
1147d95e6d02SGuray Ozen            t8 = gpu.wait(token_ty, [launch_op])
1148d95e6d02SGuray Ozen            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
1149d95e6d02SGuray Ozen            gpu.dealloc(token_ty, [t8], a_device)
1150d95e6d02SGuray Ozen            gpu.dealloc(token_ty, [t8], b_device)
1151d95e6d02SGuray Ozen            gpu.wait(token_ty, [t9])
1152d95e6d02SGuray Ozen            gpu.dealloc(token_ty, [t8], c_device)
1153d95e6d02SGuray Ozen            func.ReturnOp([])
1154d95e6d02SGuray Ozen
1155d95e6d02SGuray Ozen    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
1156d95e6d02SGuray Ozen    module.operation.verify()
1157d95e6d02SGuray Ozen    return module
1158