# /llvm-project/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
# (revision 13d6233e77982f2a596922a79365373e1466a968)
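#
# Builds NVIDIA Hopper (sm90) GEMM kernels with the MLIR Python bindings: a
# warp-specialized variant (one warpgroup issuing TMA copies, one issuing
# WGMMA) and a single-warpgroup multistage variant.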
import numpy as np
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import gpu
from mlir.dialects import memref
from mlir.dialects import nvgpu
from mlir.dialects import nvvm
from mlir.dialects import llvm
from mlir.dialects import builtin
from mlir.dialects import scf
from mlir.dialects import vector
from mlir.extras import types as T

TMA_LAST_DIM_F16 = 64  # 128B float16
WARP_SIZE = 32
WARP_GROUP_SIZE = WARP_SIZE * 4

PRODUCER_REGISTER_SIZE = 40
CONSUMER_REGISTER_SIZE = 232

PRODUCER_PRIMARY_THREAD = 128
CONSUMER_PRIMARY_THREAD = 0

# Sentinel for a dynamic dimension (INT64_MIN, matching MLIR's
# ShapedType::kDynamic on the C++ side).
MLIR_DYNAMIC = -9223372036854775808

DEBUG = False


class TmaDescriptorBuilder:
    """A class that builds a TMA descriptor."""

    def __init__(self, swizzle, l2promo, oob, interleave, tma_box_shape, memref_ty):
        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
        self.tma_box_shape = tma_box_shape
        self.memref_ty = memref_ty  # MemRefType

    @property
    def tensormap_descriptor_ty(self):
        """Returns a tensormap descriptor type."""
        tensorMemrefType = ir.MemRefType.get(
            self.tma_box_shape,
            self.memref_ty.element_type,
            memory_space=ir.Attribute.parse("3"),
        )
        return nvgpu.TensorMapDescriptorType.get(
            tensorMemrefType,
            self.swizzle,
            self.l2promo,
            self.oob,
            self.interleave,
        )

    def tma_descriptor_op(self, device_ptr):
        """Returns a tensormap descriptor op."""
        tma_descriptor_ty = self.tensormap_descriptor_ty
        device_unranked_memref = memref.CastOp(
            ir.UnrankedMemRefType.get(
                self.memref_ty.element_type, self.memref_ty.memory_space
            ),
            device_ptr,
        )
        tma_descriptor_op = nvgpu.TmaCreateDescriptorOp(
            tma_descriptor_ty, device_unranked_memref, map(c, self.tma_box_shape)
        )
        return tma_descriptor_op.result

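# Example use (a sketch mirroring the calls made below; `a_device` stands in
# for a device buffer of type `a_ty`):
#
#   a_tma_desc = TmaDescriptorBuilder(
#       nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
#       nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
#       nvgpu.TensorMapOOBKind.OOB_ZERO,
#       nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
#       (128, TMA_LAST_DIM_F16),
#       a_ty,
#   )
#   a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)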

def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
    if not DEBUG and not forcePrint:
        return
    type_formats = []
    for arg in args:
        ty_format = None
        if ir.IndexType.isinstance(arg.type):
            ty_format = "%llu"
        if ir.IntegerType.isinstance(arg.type):
            width = ir.IntegerType(arg.type).width
            if width == 64:
                ty_format = "%llu"
            elif width == 32:
                ty_format = "%d"
            elif width == 1:
                ty_format = "%i"
        if ir.F32Type.isinstance(arg.type):
            ty_format = "%f"
        if ty_format is None:
            raise NotImplementedError(arg.type)
        type_formats.append(ty_format)
    if threadNumber != -1:
        tidx = gpu.thread_id(gpu.Dimension.x)
        predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
    if_op = scf.IfOp(predicate)
    with ir.InsertionPoint(if_op.then_block):
        gpu.printf(fmt.format(*type_formats) + "\n", args)
        scf.yield_([])


def get_type_size(ty):
    if ir.FloatType.isinstance(ty):
        return ir.FloatType(ty).width // 8
    if ir.IntegerType.isinstance(ty):
        return ir.IntegerType(ty).width // 8
    raise NotImplementedError(ty)


def get_mlir_ty(dtype):
    if dtype == np.float16:
        return T.f16()
    if dtype == np.float32:
        return T.f32()
    if dtype == np.float64:
        return T.f64()
    if dtype == np.int32:
        return T.i32()
    if dtype == np.int64:
        return T.i64()
    raise NotImplementedError(dtype)


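# Shorthand for building an arith.constant; defaults to index type.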
def c(value, ty=None):
    ty = T.index() if ty is None else ty
    return arith.constant(ty, value)


def make_kernel_name(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=128,
    num_stages=3,
    use_warp_specialization=False,
):
    kernelName = "warpspecialized" if use_warp_specialization else "multistage"
    return f"{kernelName}_{M}x{N}x{K}_{BLOCK_M}x{BLOCK_N}x{BLOCK_K}_{num_stages}"


def generate_matmul_ws(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    num_stages=3,
):
    # Limitations for now
    assert input_type == np.float16
    assert output_type == np.float32
    assert BLOCK_M == 128
    assert BLOCK_N == 128
    assert BLOCK_K == 64
    assert M % BLOCK_M == 0
    assert N % BLOCK_N == 0
    assert K % BLOCK_K == 0

    module = ir.Module.create()
    token_ty = gpu.AsyncTokenType.get()
    a_elem_ty = get_mlir_ty(input_type)
    b_elem_ty = get_mlir_ty(input_type)
    c_elem_ty = get_mlir_ty(output_type)
    a_ty = ir.MemRefType.get((M, K), a_elem_ty)
    b_ty = ir.MemRefType.get((K, N), b_elem_ty)
    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
    b_tile_shape = (BLOCK_K, BLOCK_N)
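    # Bytes moved per pipeline stage (one A tile plus one full B tile); this is
    # the transaction count each mbarrier arrive.expect_tx waits on.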
    txcount = (a_tile_shape[0] * a_tile_shape[1] * get_type_size(a_elem_ty)) + (
        b_tile_shape[0] * b_tile_shape[1] * get_type_size(b_elem_ty)
    )
    smem_space_str = "#gpu.address_space<workgroup>"
    smem_space = ir.Attribute.parse(smem_space_str)
    mbar_ty = ir.Type.parse(
        f"!nvgpu.mbarrier.group<memorySpace = {smem_space}, num_barriers = {num_stages}>"
    )
    acc_ty = ir.Type.parse(
        f"!nvgpu.warpgroup.accumulator<fragmented=vector<{BLOCK_M}x{BLOCK_N}x{c_elem_ty}>>"
    )
    a_wgmma_ty = ir.Type.parse(
        f"!nvgpu.warpgroup.descriptor<tensor=memref<{BLOCK_M}x{BLOCK_K}x{a_elem_ty}, {smem_space_str}>>"
    )
    b_wgmma_ty = ir.Type.parse(
        f"!nvgpu.warpgroup.descriptor<tensor=memref<{BLOCK_K}x{BLOCK_N}x{b_elem_ty}, {smem_space_str}>>"
    )
    kernelName = make_kernel_name(
        input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_stages, True
    )
    with ir.InsertionPoint(module.body):
        fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
        with ir.InsertionPoint(fop.add_entry_block()):
            a_host = fop.arguments[0]
            b_host = fop.arguments[1]
            c_host = fop.arguments[2]
            lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
            rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
            smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
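            # The input staging buffers and the output tile reuse the same
            # dynamic shared memory, so only the larger footprint is needed.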
            smem_size = max(smem_size_input, smem_size_output)

            # Step 1. Allocate device memory and memcpy
            t1 = gpu.wait(token_ty, [])
            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
            t7 = gpu.wait(token_ty, [t6])

            # Step 2. Create TMA Descriptors
            a_tma_desc = TmaDescriptorBuilder(
                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
                nvgpu.TensorMapOOBKind.OOB_ZERO,
                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
                a_tma_shape,
                a_ty,
            )

            b_tma_desc = TmaDescriptorBuilder(
                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
                nvgpu.TensorMapOOBKind.OOB_ZERO,
                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
                b_tma_shape,
                b_ty,
            )

            a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
            b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)

            # Step 3. Launch Kernel with 2 Warpgroups: 1 Producer, 1 Consumer
            cta_m = M // BLOCK_M
            cta_n = N // BLOCK_N
            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
            grid = (cta_m, cta_n, 1)
            block = (WARP_GROUP_SIZE * 2, 1, 1)
            launch_op = gpu.LaunchOp(
                token_ty,
                [t7],
                *map(c, grid),
                *map(c, block),
                dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
            )
            launch_op.body.blocks.append(*([T.index()] * 12))
            with ir.InsertionPoint(launch_op.body.blocks[0]):
                # GPU Step 0. This is needed for vectorized ld/st
                memref.assume_alignment(c_device, 16)
                dynamic_smem = gpu.dynamic_shared_memory(
                    ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
                )
                ticks = c(10000000)

                # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups, etc.
                tidx = gpu.thread_id(gpu.Dimension.x)
                wgPrimaryThread = arith.cmpi(
                    arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0)
                )
                warp_id = arith.divui(tidx, c(32))
                warpgroup_id = arith.divui(warp_id, c(4))
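                # Warpgroup 1 (threads 128-255) is the producer: it issues the
                # TMA copies. Warpgroup 0 (threads 0-127) is the consumer: it
                # issues the WGMMA operations.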
                is_producer = arith.cmpi(
                    arith.CmpIPredicate.eq,
                    warpgroup_id,
                    c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0),
                )
                is_consumer = arith.cmpi(
                    arith.CmpIPredicate.eq,
                    warpgroup_id,
                    c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1),
                )
                producerPrimaryThread = arith.cmpi(
                    arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD)
                )
                consumerPrimaryThread = arith.cmpi(
                    arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD)
                )
                bidx = gpu.block_id(gpu.Dimension.x)
                bidy = gpu.block_id(gpu.Dimension.y)
                dimX = arith.muli(bidx, c(BLOCK_M))
                dimY = arith.muli(bidy, c(BLOCK_N))

                # GPU Step 2. Initialize mbarrier groups
                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                mbarDONE = nvgpu.mbarrier_create(mbar_ty)
                for i in range(num_stages):
                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
                    nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
                gpu.barrier()

                # GPU Step 3. Prefetch TMA descriptors
                nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=wgPrimaryThread)
                nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=wgPrimaryThread)

                # GPU Step 4. Producer Warpgroup (TMA Warpgroup)
                with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
                    # Step 4.1. Reduce register size
                    nvvm.setmaxregister(
                        PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease
                    )

                    # Step 4.2. TMA Main Loop
                    for_op = scf.ForOp(
                        c(0), c(K // BLOCK_K), c(1), [arith.constant(T.bool(), 1)]
                    )
                    with ir.InsertionPoint(for_op.body):
                        phaseParity = for_op.inner_iter_args[0]
                        iv = for_op.induction_variable
                        stage = arith.remui(iv, c(num_stages))

                        # Step 4.2.1. Wait mbarDONE
                        debug_print(
                            "[prod] iv={}  | mbarDONE[{}] try_wait  phase={}",
                            iv,
                            stage,
                            phaseParity,
                            predicate=producerPrimaryThread,
                        )
                        nvgpu.MBarrierTryWaitParityOp(
                            mbarDONE, phaseParity, ticks, mbarId=stage
                        )
                        debug_print(
                            "[prod] iv={}  | mbarDONE[{}] try_wait  phase={} [done]",
                            iv,
                            stage,
                            phaseParity,
                            predicate=producerPrimaryThread,
                        )
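                        # An mbarrier's phase parity flips once per full pass
                        # over the ring of num_stages slots, so toggle the
                        # tracked parity after the last slot is used.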
                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                        phaseParity = arith.select(
                            p,
                            arith.xori(phaseParity, arith.constant(T.bool(), 1)),
                            phaseParity,
                        )

                        # Step 4.2.2. Load TMA
                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
                        a_tma_slice = memref.view(
                            ir.MemRefType.get(
                                a_tma_shape, a_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            a_offset,
                            [],
                        )
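                        # BLOCK_N (128) is twice the 128B-swizzle TMA box width
                        # (64 f16 elements), so the B tile is fetched as two
                        # adjacent boxes.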
                        b_offset = arith.addi(
                            arith.muli(stage, c(rhs_tile_bytes)),
                            c(lhs_tile_bytes * num_stages),
                        )
                        b_tma_slice_1 = memref.view(
                            ir.MemRefType.get(
                                b_tma_shape, b_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            b_offset,
                            [],
                        )
                        b_offset2 = arith.addi(
                            b_offset,
                            c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                        )
                        b_tma_slice_2 = memref.view(
                            ir.MemRefType.get(
                                b_tma_shape, b_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            b_offset2,
                            [],
                        )
                        debug_print(
                            "[prod] a_offset={} b_offset={} b_offset2={}",
                            a_offset,
                            b_offset,
                            b_offset2,
                            predicate=producerPrimaryThread,
                        )
                        coord = arith.muli(c(64), iv)
                        nvgpu.TmaAsyncLoadOp(
                            a_tma_slice,
                            mbarTMA,
                            a_tma_desc_op,
                            coordinates=[coord, dimX],
                            mbarId=stage,
                            predicate=producerPrimaryThread,
                        )
                        nvgpu.TmaAsyncLoadOp(
                            b_tma_slice_1,
                            mbarTMA,
                            b_tma_desc_op,
                            coordinates=[dimY, coord],
                            mbarId=stage,
                            predicate=producerPrimaryThread,
                        )
                        dimY2 = arith.addi(dimY, c(64))
                        nvgpu.TmaAsyncLoadOp(
                            b_tma_slice_2,
                            mbarTMA,
                            b_tma_desc_op,
                            coordinates=[dimY2, coord],
                            mbarId=stage,
                            predicate=producerPrimaryThread,
                        )

                        # Step 4.2.3. Arrive mbarTMA
                        debug_print(
                            "[prod] iv={}  | mbarTMA[{}] arrive",
                            iv,
                            stage,
                            predicate=producerPrimaryThread,
                        )
                        nvgpu.mbarrier_arrive_expect_tx(
                            mbarTMA, c(txcount), stage, predicate=producerPrimaryThread
                        )
                        debug_print(
                            "[prod] iv={}  | mbarTMA[{}] arrive [done]",
                            iv,
                            stage,
                            predicate=producerPrimaryThread,
                        )
                        scf.yield_([phaseParity])
                    scf.yield_([])

                # GPU Step 5. Consumer Warpgroup (MMA Warpgroup)
                if_op = scf.IfOp(is_consumer)
                with ir.InsertionPoint(if_op.then_block):
                    # Step 5.1. Increase register size
                    nvvm.setmaxregister(
                        CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase
                    )

                    # Step 5.2. Initialize MMA registers
                    acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)

                    # Step 5.3. MMA Main Loop
                    for_op = scf.ForOp(
                        c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
                    )
                    with ir.InsertionPoint(for_op.body):
                        # Step 5.3.1. Wait mbarTMA
                        phaseParity = for_op.inner_iter_args[1]
                        iv = for_op.induction_variable
                        stage = arith.remui(iv, c(num_stages))
                        debug_print(
                            "[cons] iv={}  | mbarTMA[{}] try_wait   phase={}",
                            iv,
                            stage,
                            phaseParity,
                            predicate=consumerPrimaryThread,
                        )
                        nvgpu.MBarrierTryWaitParityOp(
                            mbarTMA, phaseParity, ticks, mbarId=stage
                        )
                        debug_print(
                            "[cons] iv={}  | mbarTMA[{}] try_wait   phase={} [done]",
                            iv,
                            stage,
                            phaseParity,
                            predicate=consumerPrimaryThread,
                        )

                        # Step 5.3.2. Create WGMMA Descriptors
                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
                        a_tile_slice = memref.view(
                            ir.MemRefType.get(
                                a_tile_shape, a_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            a_offset,
                            [],
                        )
                        b_offset = arith.addi(
                            arith.muli(stage, c(rhs_tile_bytes)),
                            c(lhs_tile_bytes * num_stages),
                        )
                        b_tile_slice = memref.view(
                            ir.MemRefType.get(
                                b_tile_shape, b_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            b_offset,
                            [],
                        )
                        debug_print(
                            "[cons] a_offset={} b_offset={}",
                            a_offset,
                            b_offset,
                            predicate=consumerPrimaryThread,
                        )
                        da = nvgpu.WarpgroupGenerateDescriptorOp(
                            a_wgmma_ty, a_tile_slice, a_tma_desc_op
                        )
                        db = nvgpu.WarpgroupGenerateDescriptorOp(
                            b_wgmma_ty, b_tile_slice, b_tma_desc_op
                        )

                        # Step 5.3.3. MMA
                        carry_acc = for_op.inner_iter_args[0]
                        new_acc = nvgpu.WarpgroupMmaOp(
                            acc.type, da, db, carry_acc, transposeB=True
                        )

                        # Step 5.3.4. Arrive mbarDONE
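                        # Tell the producer that the previous stage's buffers
                        # have been consumed and may be overwritten.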
                        if num_stages == 1:
                            p_arrive = consumerPrimaryThread
                        else:
                            p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
                            p_arrive = arith.andi(consumerPrimaryThread, p1)
                        with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
                            p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
                            barId = arith.select(
                                p, c(num_stages - 1), arith.subi(stage, c(1))
                            )
                            debug_print(
                                "[cons] iv={}  | mbarDONE[{}] arrive ",
                                iv,
                                barId,
                                predicate=consumerPrimaryThread,
                            )
                            nvgpu.mbarrier_arrive(mbarDONE, barId)
                            debug_print(
                                "[cons] iv={}  | mbarDONE[{}] arrive [done]",
                                iv,
                                barId,
                                predicate=consumerPrimaryThread,
                            )
                            scf.yield_([])

                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                        phaseParity = arith.select(
                            p,
                            arith.xori(phaseParity, arith.constant(T.bool(), 1)),
                            phaseParity,
                        )

                        # Step 5.3.5. Yield
                        scf.yield_([new_acc, phaseParity])

                    with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
                        barId = c((K // BLOCK_K) % num_stages)
                        nvgpu.mbarrier_arrive(mbarDONE, barId)
                        scf.yield_([])

                    # Step 5.4. Epilogue (registers --> shared memory)
                    acc_smem_ty = ir.MemRefType.get(
                        (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
                    )
                    acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
                    debug_print("[cons]  | Storing", predicate=consumerPrimaryThread)
                    nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                    scf.yield_([])
                gpu.barrier()

                # GPU Step 6. Epilogue (shared memory --> global memory)
                fd = ir.MemRefType.get(
                    [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
                )
                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
                rty = ir.MemRefType.get(
                    (BLOCK_M, BLOCK_N),
                    c_elem_ty,
                    ir.Attribute.parse(f"strided<[{N}, 1], offset: ?>"),
                )
                c_device_per_block = memref.SubViewOp(
                    rty,
                    c_device,
                    [dimX, dimY],
                    [],
                    [],
                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
                    [BLOCK_M, BLOCK_N],
                    [1, 1],
                )
                vlen = 1
                for_op = scf.ForOp(
                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2)
                )
                with ir.InsertionPoint(for_op.body):
                    x = arith.divui(for_op.induction_variable, c(BLOCK_N))
                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
                    vdata = vector.load(
                        ir.VectorType.get((vlen,), c_elem_ty),
                        collapsed_smem,
                        [for_op.induction_variable],
                    )
                    vector.store(vdata, c_device_per_block, [x, y])
                    scf.yield_([])

                gpu.terminator()

            # Step 4. Copy back to host
            t8 = gpu.wait(token_ty, [launch_op])
            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
            gpu.dealloc(token_ty, [t8], a_device)
            gpu.dealloc(token_ty, [t8], b_device)
            gpu.wait(token_ty, [t9])
            gpu.dealloc(token_ty, [t8], c_device)
            func.ReturnOp([])

    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    module.operation.verify()
    return module


def generate_matmul_multistage(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    num_stages=3,
):
    # Limitations for now
    assert input_type == np.float16
    assert output_type == np.float32
    assert BLOCK_M == 128
    assert BLOCK_N == 128
    assert BLOCK_K == 64
    assert M % BLOCK_M == 0
    assert N % BLOCK_N == 0
    assert K % BLOCK_K == 0

    module = ir.Module.create()
    token_ty = gpu.AsyncTokenType.get()
    a_elem_ty = get_mlir_ty(input_type)
    b_elem_ty = get_mlir_ty(input_type)
    c_elem_ty = get_mlir_ty(output_type)
    a_ty = ir.MemRefType.get((M, K), a_elem_ty)
    b_ty = ir.MemRefType.get((K, N), b_elem_ty)
    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
    b_tile_shape = (BLOCK_K, BLOCK_N)
    txcount = (a_tile_shape[0] * a_tile_shape[1] * get_type_size(a_elem_ty)) + (
        b_tile_shape[0] * b_tile_shape[1] * get_type_size(b_elem_ty)
    )
    smem_space_str = "#gpu.address_space<workgroup>"
    smem_space = ir.Attribute.parse(smem_space_str)
    mbar_ty = ir.Type.parse(
        f"!nvgpu.mbarrier.group<memorySpace = {smem_space}, num_barriers = {num_stages}>"
    )
    acc_ty = ir.Type.parse(
        f"!nvgpu.warpgroup.accumulator<fragmented=vector<{BLOCK_M}x{BLOCK_N}x{c_elem_ty}>>"
    )
    a_wgmma_ty = ir.Type.parse(
        f"!nvgpu.warpgroup.descriptor<tensor=memref<{BLOCK_M}x{BLOCK_K}x{a_elem_ty}, {smem_space_str}>>"
    )
    b_wgmma_ty = ir.Type.parse(
        f"!nvgpu.warpgroup.descriptor<tensor=memref<{BLOCK_K}x{BLOCK_N}x{b_elem_ty}, {smem_space_str}>>"
    )

    with ir.InsertionPoint(module.body):
        kernelName = make_kernel_name(
            input_type,
            output_type,
            M,
            N,
            K,
            BLOCK_M,
            BLOCK_N,
            BLOCK_K,
            num_stages,
            False,
        )
        fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
        with ir.InsertionPoint(fop.add_entry_block()):
            a_host = fop.arguments[0]
            b_host = fop.arguments[1]
            c_host = fop.arguments[2]
            lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
            rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
            smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
            smem_size = max(smem_size_input, smem_size_output)

            # Step 1. Allocate device memory and memcpy
            t1 = gpu.wait(token_ty, [])
            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
            t7 = gpu.wait(token_ty, [t6])

            # Step 2. Create TMA Descriptors
            a_tma_desc = TmaDescriptorBuilder(
                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
                nvgpu.TensorMapOOBKind.OOB_ZERO,
                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
                a_tma_shape,
                a_ty,
            )

            b_tma_desc = TmaDescriptorBuilder(
                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
                nvgpu.TensorMapOOBKind.OOB_ZERO,
                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
                b_tma_shape,
                b_ty,
            )

            a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
            b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)

            # Step 3. Launch Kernel with 1 Warpgroup
            cta_m = M // BLOCK_M
            cta_n = N // BLOCK_N
            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
            grid = (cta_m, cta_n, 1)
            block = (WARP_GROUP_SIZE, 1, 1)
            launch_op = gpu.LaunchOp(
                token_ty,
                [t7],
                *map(c, grid),
                *map(c, block),
                dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
            )
            launch_op.body.blocks.append(*([T.index()] * 12))
            with ir.InsertionPoint(launch_op.body.blocks[0]):
                # GPU Step 0. Bootstrapping
                memref.assume_alignment(c_device, 16)
                dynamic_smem = gpu.dynamic_shared_memory(
                    ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
                )
                ticks = c(10000000)
                tidx = gpu.thread_id(gpu.Dimension.x)
                primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
                warpId = arith.divui(tidx, c(32))
                bidx = gpu.block_id(gpu.Dimension.x)
                bidy = gpu.block_id(gpu.Dimension.y)
                dimX = arith.muli(bidx, c(BLOCK_M))
                dimY = arith.muli(bidy, c(BLOCK_N))

                # GPU Step 1. Initialize mbarrier groups
                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                for i in range(num_stages):
                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
                gpu.barrier()

                # GPU Step 2. Prefetch TMA descriptors
                nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=primaryThread)
                nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=primaryThread)

                # GPU Step 3. Prologue (global memory --> shared memory)
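                # Software pipelining: pre-issue ns stages of TMA copies
                # (num_stages - 1, or the single stage when num_stages == 1) so
                # the main loop always has a tile in flight.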
                ns = num_stages if num_stages == 1 else num_stages - 1
                for_op = scf.ForOp(c(0), c(ns), c(1))
                with ir.InsertionPoint(for_op.body):
                    iv = for_op.induction_variable

                    # Step 3.1. Calculate offsets
                    a_offset = arith.muli(iv, c(lhs_tile_bytes))
                    a_tma_slice = memref.view(
                        ir.MemRefType.get(
                            a_tma_shape, a_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        a_offset,
                        [],
                    )
                    b_offset = arith.addi(
                        arith.muli(iv, c(rhs_tile_bytes)),
                        c(lhs_tile_bytes * num_stages),
                    )
                    b_tma_slice_1 = memref.view(
                        ir.MemRefType.get(
                            b_tma_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset,
                        [],
                    )
                    b_offset2 = arith.addi(
                        b_offset,
                        c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                    )
                    b_tma_slice_2 = memref.view(
                        ir.MemRefType.get(
                            b_tma_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset2,
                        [],
                    )

                    # Step 3.2. TMA Load
                    coord = arith.muli(c(64), iv)
                    dimY2 = arith.addi(dimY, c(64))
                    debug_print(
                        "[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
                        a_offset,
                        b_offset,
                        b_offset2,
                        coord,
                        dimX,
                        dimY,
                        coord,
                        predicate=primaryThread,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        a_tma_slice,
                        mbarTMA,
                        a_tma_desc_op,
                        coordinates=[coord, dimX],
                        mbarId=iv,
                        predicate=primaryThread,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        b_tma_slice_1,
                        mbarTMA,
                        b_tma_desc_op,
                        coordinates=[dimY, coord],
                        mbarId=iv,
                        predicate=primaryThread,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        b_tma_slice_2,
                        mbarTMA,
                        b_tma_desc_op,
                        coordinates=[dimY2, coord],
                        mbarId=iv,
                        predicate=primaryThread,
                    )

                    # Step 3.3. mbarTMA arrive
                    debug_print(
                        "[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread
                    )
                    nvgpu.mbarrier_arrive_expect_tx(
                        mbarTMA, c(txcount), iv, predicate=primaryThread
                    )
                    debug_print(
                        "[Prologue] mbarTMA[{}] arrive [done]",
                        iv,
                        predicate=primaryThread,
                    )
                    scf.yield_([])

                # GPU Step 4. Main Loop
                acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
                for_op = scf.ForOp(
                    c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
                )
                with ir.InsertionPoint(for_op.body):
                    # Step 4.1. Wait mbarTMA
                    phaseParity = for_op.inner_iter_args[1]
                    iv = for_op.induction_variable
                    stage = arith.remui(iv, c(num_stages))
                    debug_print(
                        "[MainLoop] mbarTMA[{}] try_wait   phase={}",
                        stage,
                        phaseParity,
                        predicate=primaryThread,
                    )
                    nvgpu.MBarrierTryWaitParityOp(
                        mbarTMA, phaseParity, ticks, mbarId=stage
                    )
                    debug_print(
                        "[MainLoop] mbarTMA[{}] try_wait   phase={} [done]",
                        stage,
                        phaseParity,
                        predicate=primaryThread,
                    )

                    # Step 4.2. Create WGMMA Descriptors
                    a_offset = arith.muli(stage, c(lhs_tile_bytes))
                    a_tile_slice = memref.view(
                        ir.MemRefType.get(
                            a_tile_shape, a_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        a_offset,
                        [],
                    )
                    b_offset = arith.addi(
                        arith.muli(stage, c(rhs_tile_bytes)),
                        c(lhs_tile_bytes * num_stages),
                    )
                    b_tile_slice = memref.view(
                        ir.MemRefType.get(
                            b_tile_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset,
                        [],
                    )
                    debug_print(
                        "[MainLoop] iv={} MMA a_offset={} b_offset={}",
                        iv,
                        a_offset,
                        b_offset,
                        predicate=primaryThread,
                    )
                    da = nvgpu.WarpgroupGenerateDescriptorOp(
                        a_wgmma_ty, a_tile_slice, a_tma_desc_op
                    )
                    db = nvgpu.WarpgroupGenerateDescriptorOp(
                        b_wgmma_ty, b_tile_slice, b_tma_desc_op
                    )

                    # Step 4.3. MMA
                    carry_acc = for_op.inner_iter_args[0]
                    new_acc = nvgpu.WarpgroupMmaOp(
                        acc.type, da, db, carry_acc, transposeB=True
                    )
                    if num_stages == 1:
                        nvvm.WgmmaWaitGroupSyncOp(0)

                    # Step 4.4. Load TMA for next stage
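                    # Prefetch only while a not-yet-loaded iteration remains,
                    # i.e. iv + ns < K // BLOCK_K.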
                    p1 = arith.cmpi(
                        arith.CmpIPredicate.ult,
                        arith.addi(iv, c(ns)),
                        c(K // BLOCK_K),
                    )
                    p = arith.andi(primaryThread, p1)
                    nextStage = arith.addi(iv, c(ns))
                    nextSlot = arith.remui(nextStage, c(num_stages))
                    a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))

                    debug_print(
                        "[MainLoop] mbarTMA[{}] arrive",
                        nextSlot,
                        predicate=p,
                    )
                    nvgpu.mbarrier_arrive_expect_tx(
                        mbarTMA, c(txcount), nextSlot, predicate=p
                    )
                    debug_print(
                        "[MainLoop] mbarTMA[{}] arrive [done]",
                        nextSlot,
                        predicate=p,
                    )

                    a_tma_slice = memref.view(
                        ir.MemRefType.get(
                            a_tma_shape, a_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        a_offset,
                        [],
                    )
                    b_offset = arith.addi(
                        arith.muli(nextSlot, c(rhs_tile_bytes)),
                        c(lhs_tile_bytes * num_stages),
                    )
                    b_tma_slice_1 = memref.view(
                        ir.MemRefType.get(
                            b_tma_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset,
                        [],
                    )
                    b_offset2 = arith.addi(
                        b_offset,
                        c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                    )
                    b_tma_slice_2 = memref.view(
                        ir.MemRefType.get(
                            b_tma_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset2,
                        [],
                    )

                    coord = arith.muli(c(64), nextStage)
                    debug_print(
                        "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
                        iv,
                        a_offset,
                        b_offset,
                        b_offset2,
                        coord,
                        dimX,
                        dimY,
                        coord,
                        predicate=p,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        a_tma_slice,
                        mbarTMA,
                        a_tma_desc_op,
                        coordinates=[coord, dimX],
                        mbarId=nextSlot,
                        predicate=p,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        b_tma_slice_1,
                        mbarTMA,
                        b_tma_desc_op,
                        coordinates=[dimY, coord],
                        mbarId=nextSlot,
                        predicate=p,
                    )
                    dimY2 = arith.addi(dimY, c(64))
                    nvgpu.TmaAsyncLoadOp(
                        b_tma_slice_2,
                        mbarTMA,
                        b_tma_desc_op,
                        coordinates=[dimY2, coord],
                        mbarId=nextSlot,
                        predicate=p,
                    )
                    # Step 4.5. Change the phaseParity
                    p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                    phaseParity = arith.select(
                        p,
                        arith.xori(phaseParity, arith.constant(T.bool(), 1)),
                        phaseParity,
                    )

                    # Step 4.6. Yield
                    scf.yield_([new_acc, phaseParity])

                # Step 5. Wait for all WGMMA groups
                nvvm.WgmmaWaitGroupSyncOp(0)

                # Step 6. Epilogue (registers --> shared memory)
                acc_smem_ty = ir.MemRefType.get(
                    (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
                )
                acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
                debug_print("Storing", predicate=primaryThread)
                nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                gpu.barrier()

                # GPU Step 7. Epilogue (shared memory --> global memory)
                fd = ir.MemRefType.get(
                    [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
                )
                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
                rty = ir.MemRefType.get(
                    (BLOCK_M, BLOCK_N),
                    c_elem_ty,
                    ir.Attribute.parse(f"strided<[{N}, 1], offset: ?>"),
                )
                c_device_per_block = memref.SubViewOp(
                    rty,
                    c_device,
                    [dimX, dimY],
                    [],
                    [],
                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
                    [BLOCK_M, BLOCK_N],
                    [1, 1],
                )
                vlen = 1
                for_op = scf.ForOp(
                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE)
                )
                with ir.InsertionPoint(for_op.body):
                    x = arith.divui(for_op.induction_variable, c(BLOCK_N))
                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
                    vdata = vector.load(
                        ir.VectorType.get((vlen,), c_elem_ty),
                        collapsed_smem,
                        [for_op.induction_variable],
                    )
                    vector.store(vdata, c_device_per_block, [x, y])
                    scf.yield_([])

                gpu.terminator()

            # Step 4. Copy back to host
            t8 = gpu.wait(token_ty, [launch_op])
            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
            gpu.dealloc(token_ty, [t8], a_device)
            gpu.dealloc(token_ty, [t8], b_device)
            gpu.wait(token_ty, [t9])
            gpu.dealloc(token_ty, [t8], c_device)
            func.ReturnOp([])

    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    module.operation.verify()
    return module
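
# Example driver (a minimal sketch; in the upstream tree the modules built here
# are compiled and executed by the accompanying test harness, so the call below
# is illustrative only):
#
#   with ir.Context(), ir.Location.unknown():
#       module = generate_matmul_ws(M=256, N=256, K=256)
#       print(module)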