import numpy as np
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import gpu
from mlir.dialects import memref
from mlir.dialects import nvgpu
from mlir.dialects import nvvm
from mlir.dialects import llvm
from mlir.dialects import builtin
from mlir.dialects import scf
from mlir.dialects import vector
from mlir.extras import types as T

TMA_LAST_DIM_F16 = 64  # 128B float16
WARP_SIZE = 32
WARP_GROUP_SIZE = WARP_SIZE * 4

PRODUCER_REGISTER_SIZE = 40
CONSUMER_REGISTER_SIZE = 232

PRODUCER_PRIMARY_THREAD = 128
CONSUMER_PRIMARY_THREAD = 0

# C++ uses this value to understand whether it's dynamic or not.
MLIR_DYNAMIC = -9223372036854775808

DEBUG = False


class TmaDescriptorBuilder:
    """A class that builds a TMA descriptor."""

    def __init__(self, swizzle, l2promo, oob, interleave, tma_box_shape, memref_ty):
        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
        self.tma_box_shape = tma_box_shape
        self.memref_ty = memref_ty  # MemRefType

    @property
    def tensormap_descriptor_ty(self):
        """Returns a tensormap descriptor type."""
        tensorMemrefType = ir.MemRefType.get(
            self.tma_box_shape,
            self.memref_ty.element_type,
            memory_space=ir.Attribute.parse("3"),
        )
        return nvgpu.TensorMapDescriptorType.get(
            tensorMemrefType,
            self.swizzle,
            self.l2promo,
            self.oob,
            self.interleave,
        )

    def tma_descriptor_op(self, device_ptr):
        """Returns a tensormap descriptor op."""
        tma_descriptor_ty = self.tensormap_descriptor_ty
        device_unranked_memref = memref.CastOp(
            ir.UnrankedMemRefType.get(
                self.memref_ty.element_type, self.memref_ty.memory_space
            ),
            device_ptr,
        )
        tma_descriptor_op = nvgpu.TmaCreateDescriptorOp(
            tma_descriptor_ty, device_unranked_memref, map(c, self.tma_box_shape)
        )
        return tma_descriptor_op.result


def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
    if not DEBUG and not forcePrint:
        return
    type_formats = []
    for arg in args:
        ty_format = None
        if ir.IndexType.isinstance(arg.type):
            ty_format = "%llu"
        if ir.IntegerType.isinstance(arg.type):
            width = ir.IntegerType(arg.type).width
            if width == 64:
                ty_format = "%llu"
            elif width == 32:
                ty_format = "%d"
            elif width == 1:
                ty_format = "%i"
        if ir.F32Type.isinstance(arg.type):
            ty_format = "%f"
        if ty_format is None:
            raise NotImplementedError(arg.type)
        type_formats.append(ty_format)
    if threadNumber != -1:
        tidx = gpu.thread_id(gpu.Dimension.x)
        predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
        scf.yield_([])
    if_op = scf.IfOp(predicate)
    with ir.InsertionPoint(if_op.then_block):
        gpu.printf(fmt.format(*type_formats) + "\n", args)
        scf.yield_([])
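
# debug_print is a device-side helper: it rewrites "{}" placeholders into
# printf-style specifiers based on the MLIR types of its arguments and emits a
# gpu.printf guarded by scf.if. It is a no-op unless DEBUG (or forcePrint) is
# set; passing threadNumber restricts printing to that single thread.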

def get_type_size(ty):
    if ir.FloatType.isinstance(ty):
        return ir.FloatType(ty).width // 8
    if ir.IntegerType.isinstance(ty):
        return ir.IntegerType(ty).width // 8
    raise NotImplementedError(ty)


def get_mlir_ty(dtype):
    if dtype == np.float16:
        return T.f16()
    if dtype == np.float32:
        return T.f32()
    if dtype == np.float64:
        return T.f64()
    if dtype == np.int32:
        return T.i32()
    if dtype == np.int64:
        return T.i64()
    raise NotImplementedError(dtype)


def c(value, ty=None):
    ty = T.index() if ty is None else ty
    return arith.constant(ty, value)


def make_kernel_name(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=128,
    num_stages=3,
    use_warp_specialization=False,
):
    kernelName = "warpspecialized" if use_warp_specialization else "multistage"
    return (
        kernelName
        + "_"
        + str(M)
        + "x"
        + str(N)
        + "x"
        + str(K)
        + "_"
        + str(BLOCK_M)
        + "x"
        + str(BLOCK_N)
        + "x"
        + str(BLOCK_K)
        + "_"
        + str(num_stages)
    )


def generate_matmul_ws(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    num_stages=3,
):
    # Limitations for now
    assert input_type == np.float16
    assert output_type == np.float32
    assert BLOCK_M == 128
    assert BLOCK_N == 128
    assert BLOCK_K == 64
    assert M % BLOCK_M == 0
    assert N % BLOCK_N == 0
    assert K % BLOCK_K == 0

    module = ir.Module.create()
    token_ty = gpu.AsyncTokenType.get()
    a_elem_ty = get_mlir_ty(input_type)
    b_elem_ty = get_mlir_ty(input_type)
    c_elem_ty = get_mlir_ty(output_type)
    a_ty = ir.MemRefType.get([M, K], a_elem_ty)
    b_ty = ir.MemRefType.get((K, N), b_elem_ty)
    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
    b_tile_shape = (BLOCK_K, BLOCK_N)
    txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(a_elem_ty)) + (
        a_tile_shape[0] * a_tile_shape[1] * get_type_size(b_elem_ty)
    )
    smem_space_str = "#gpu.address_space<workgroup>"
    smem_space = ir.Attribute.parse(smem_space_str)
    mbar_ty = ir.Type.parse(
        "!nvgpu.mbarrier.group<memorySpace = "
        + str(smem_space)
        + ", num_barriers = "
        + str(num_stages)
        + ">"
    )
    acc_ty = ir.Type.parse(
        "!nvgpu.warpgroup.accumulator<fragmented=vector<"
        + str(BLOCK_M)
        + "x"
        + str(BLOCK_N)
        + "x"
        + str(c_elem_ty)
        + ">>"
    )
    a_wgmma_ty = ir.Type.parse(
        "!nvgpu.warpgroup.descriptor<tensor=memref<"
        + str(BLOCK_M)
        + "x"
        + str(BLOCK_K)
        + "x"
        + str(a_elem_ty)
        + ", "
        + smem_space_str
        + ">>"
    )
    b_wgmma_ty = ir.Type.parse(
        "!nvgpu.warpgroup.descriptor<tensor=memref<"
        + str(BLOCK_K)
        + "x"
        + str(BLOCK_N)
        + "x"
        + str(b_elem_ty)
        + ", "
        + smem_space_str
        + ">>"
    )
    kernelName = make_kernel_name(
        input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_stages, True
    )
    with ir.InsertionPoint(module.body):
        fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
        with ir.InsertionPoint(fop.add_entry_block()):
            a_host = fop.arguments[0]
            b_host = fop.arguments[1]
            c_host = fop.arguments[2]
            lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
            rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
            smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
            smem_size = max(smem_size_input, smem_size_output)

            # Step 1. Allocate device memory and memcpy
            t1 = gpu.wait(token_ty, [])
            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
            t7 = gpu.wait(token_ty, [t6])

            # Step 2. Create TMA Descriptors
            a_tma_desc = TmaDescriptorBuilder(
                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
                nvgpu.TensorMapOOBKind.OOB_ZERO,
                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
                a_tma_shape,
                a_ty,
            )

            b_tma_desc = TmaDescriptorBuilder(
                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
                nvgpu.TensorMapOOBKind.OOB_ZERO,
                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
                b_tma_shape,
                b_ty,
            )

            a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
            b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)

            # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
            cta_m = M // BLOCK_M
            cta_n = N // BLOCK_N
            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
            grid = (cta_m, cta_n, 1)
            block = (WARP_GROUP_SIZE * 2, 1, 1)
            launch_op = gpu.LaunchOp(
                token_ty,
                [t7],
                *map(c, grid),
                *map(c, block),
                dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
            )
            launch_op.body.blocks.append(*([T.index()] * 12))
            with ir.InsertionPoint(launch_op.body.blocks[0]):
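                # This block runs two warpgroups (256 threads): threads
                # [128, 256) form the producer warpgroup that issues the TMA
                # copies, and threads [0, 128) form the consumer warpgroup that
                # runs the WGMMA pipeline (PRODUCER_PRIMARY_THREAD = 128,
                # CONSUMER_PRIMARY_THREAD = 0).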
                # GPU Step 0. This is needed for vectorized ld/st
                memref.assume_alignment(c_device, 16)
                dynamic_smem = gpu.dynamic_shared_memory(
                    ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
                )
                ticks = c(10000000)

                # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups, etc.
                tidx = gpu.thread_id(gpu.Dimension.x)
                wgPrimaryThread = arith.cmpi(
                    arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0)
                )
                warp_id = arith.divui(tidx, c(32))
                warpgroup_id = arith.divui(warp_id, c(4))
                is_producer = arith.cmpi(
                    arith.CmpIPredicate.eq,
                    warpgroup_id,
                    c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0),
                )
                is_consumer = arith.cmpi(
                    arith.CmpIPredicate.eq,
                    warpgroup_id,
                    c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1),
                )
                producerPrimaryThread = arith.cmpi(
                    arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD)
                )
                consumerPrimaryThread = arith.cmpi(
                    arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD)
                )
                bidx = gpu.block_id(gpu.Dimension.x)
                bidy = gpu.block_id(gpu.Dimension.y)
                dimX = arith.muli(bidx, c(BLOCK_M))
                dimY = arith.muli(bidy, c(BLOCK_N))

                # GPU Step 2. Initialize mbarrier groups
                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                mbarDONE = nvgpu.mbarrier_create(mbar_ty)
                for i in range(num_stages):
                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
                    nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
                gpu.barrier()

                # GPU Step 3. Prefetch TMA descriptors
                nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=wgPrimaryThread)
                nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=wgPrimaryThread)

                ns = num_stages if num_stages == 1 else num_stages - 1
                # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
                with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
                    # Step 5.1. Reduce register size
                    nvvm.setmaxregister(
                        PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease
                    )

                    # Step 5.2. TMA Main Loop
                    for_op = scf.ForOp(
                        c(0), c(K // BLOCK_K), c(1), [arith.constant(T.bool(), 1)]
                    )
                    with ir.InsertionPoint(for_op.body):
                        phaseParity = for_op.inner_iter_args[0]
                        iv = for_op.induction_variable
                        stage = arith.remui(iv, c(num_stages))

                        # Step 5.2.1. Wait mbarDONE
                        debug_print(
                            "[prod] iv={} | mbarDONE[{}] try_wait phase={}",
                            iv,
                            stage,
                            phaseParity,
                            predicate=producerPrimaryThread,
                        )
                        nvgpu.MBarrierTryWaitParityOp(
                            mbarDONE, phaseParity, ticks, mbarId=stage
                        )
                        debug_print(
                            "[prod] iv={} | mbarDONE[{}] try_wait phase={} [done]",
                            iv,
                            stage,
                            phaseParity,
                            predicate=producerPrimaryThread,
                        )
                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                        phaseParity = arith.select(
                            p,
                            arith.xori(phaseParity, arith.constant(T.bool(), 1)),
                            phaseParity,
                        )

                        # Step 5.2.2. Load TMA
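                        # The A tile is a single (BLOCK_M, 64) TMA box. The B
                        # tile is (BLOCK_K, BLOCK_N) = (64, 128), but a
                        # 128B-swizzled box spans at most 64 f16 elements in
                        # the last dimension (TMA_LAST_DIM_F16), so B is
                        # fetched as two (64, 64) boxes at column offsets dimY
                        # and dimY+64.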
                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
                        a_tma_slice = memref.view(
                            ir.MemRefType.get(
                                a_tma_shape, a_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            a_offset,
                            [],
                        )
                        b_offset = arith.addi(
                            arith.muli(stage, c(rhs_tile_bytes)),
                            c(lhs_tile_bytes * num_stages),
                        )
                        b_tma_slice_1 = memref.view(
                            ir.MemRefType.get(
                                b_tma_shape, b_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            b_offset,
                            [],
                        )
                        b_offset2 = arith.addi(
                            b_offset,
                            c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                        )
                        b_tma_slice_2 = memref.view(
                            ir.MemRefType.get(
                                b_tma_shape, b_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            b_offset2,
                            [],
                        )
                        debug_print(
                            "[prod] a_offset={} b_offset={} b_offset2={}",
                            a_offset,
                            b_offset,
                            b_offset2,
                            predicate=producerPrimaryThread,
                        )
                        coord = arith.muli(c(64), iv)
                        nvgpu.TmaAsyncLoadOp(
                            a_tma_slice,
                            mbarTMA,
                            a_tma_desc_op,
                            coordinates=[coord, dimX],
                            mbarId=stage,
                            predicate=producerPrimaryThread,
                        )
                        nvgpu.TmaAsyncLoadOp(
                            b_tma_slice_1,
                            mbarTMA,
                            b_tma_desc_op,
                            coordinates=[dimY, coord],
                            mbarId=stage,
                            predicate=producerPrimaryThread,
                        )
                        dimY2 = arith.addi(dimY, c(64))
                        nvgpu.TmaAsyncLoadOp(
                            b_tma_slice_2,
                            mbarTMA,
                            b_tma_desc_op,
                            coordinates=[dimY2, coord],
                            mbarId=stage,
                            predicate=producerPrimaryThread,
                        )

                        # Step 5.2.3. Arrive mbarTMA
                        debug_print(
                            "[prod] iv={} | mbarTMA[{}] arrive",
                            iv,
                            stage,
                            predicate=producerPrimaryThread,
                        )
                        nvgpu.mbarrier_arrive_expect_tx(
                            mbarTMA, c(txcount), stage, predicate=producerPrimaryThread
                        )
                        debug_print(
                            "[prod] iv={} | mbarTMA[{}] arrive [done]",
                            iv,
                            stage,
                            predicate=producerPrimaryThread,
                        )
                        scf.yield_([phaseParity])
                    scf.yield_([])

                # GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
                if_op = scf.IfOp(is_consumer)
                with ir.InsertionPoint(if_op.then_block):
                    # Step 6.1. Increase register size
                    nvvm.setmaxregister(
                        CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase
                    )

                    # GPU Step 6.2. Initialize MMA registers
                    acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)

                    # Step 6.3. MMA Main Loop
                    for_op = scf.ForOp(
                        c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
                    )
                    with ir.InsertionPoint(for_op.body):
                        # Step 6.3.1. Wait mbarTMA
                        phaseParity = for_op.inner_iter_args[1]
                        iv = for_op.induction_variable
                        stage = arith.remui(iv, c(num_stages))
                        debug_print(
                            "[cons] iv={} | mbarTMA[{}] try_wait phase={}",
                            iv,
                            stage,
                            phaseParity,
                            predicate=consumerPrimaryThread,
                        )
                        nvgpu.MBarrierTryWaitParityOp(
                            mbarTMA, phaseParity, ticks, mbarId=stage
                        )
                        debug_print(
                            "[cons] iv={} | mbarTMA[{}] try_wait phase={} [done]",
                            iv,
                            stage,
                            phaseParity,
                            predicate=consumerPrimaryThread,
                        )

                        # Step 6.3.2. Create WGMMA Descriptors
                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
                        a_tile_slice = memref.view(
                            ir.MemRefType.get(
                                a_tile_shape, a_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            a_offset,
                            [],
                        )
                        b_offset = arith.addi(
                            arith.muli(stage, c(rhs_tile_bytes)),
                            c(lhs_tile_bytes * num_stages),
                        )
                        b_tile_slice = memref.view(
                            ir.MemRefType.get(
                                b_tile_shape, b_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            b_offset,
                            [],
                        )
                        debug_print(
                            "[cons] a_offset={} b_offset={}",
                            a_offset,
                            b_offset,
                            predicate=consumerPrimaryThread,
                        )
                        da = nvgpu.WarpgroupGenerateDescriptorOp(
                            a_wgmma_ty, a_tile_slice, a_tma_desc_op
                        )
                        db = nvgpu.WarpgroupGenerateDescriptorOp(
                            b_wgmma_ty, b_tile_slice, b_tma_desc_op
                        )

                        # Step 6.3.3. MMA
                        carry_acc = for_op.inner_iter_args[0]
                        new_acc = nvgpu.WarpgroupMmaOp(
                            acc.type, da, db, carry_acc, transposeB=True
                        )

                        # Step 6.3.4. Arrive mbarDONE
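                        # Tell the producer that the previous stage's buffers
                        # have been consumed and may be overwritten. The first
                        # iteration has no previous stage, hence the iv > 0
                        # guard below (except in the single-stage case).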
                        if num_stages == 1:
                            p_arrive = consumerPrimaryThread
                        else:
                            p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
                            p_arrive = arith.andi(consumerPrimaryThread, p1)
                        with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
                            p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
                            barId = arith.select(
                                p, c(num_stages - 1), arith.subi(stage, c(1))
                            )
                            debug_print(
                                "[cons] iv={} | mbarDONE[{}] arrive ",
                                iv,
                                barId,
                                predicate=consumerPrimaryThread,
                            )
                            nvgpu.mbarrier_arrive(mbarDONE, barId)
                            debug_print(
                                "[cons] iv={} | mbarDONE[{}] arrive [done]",
                                iv,
                                barId,
                                predicate=consumerPrimaryThread,
                            )
                            scf.yield_([])

                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                        phaseParity = arith.select(
                            p,
                            arith.xori(phaseParity, arith.constant(T.bool(), 1)),
                            phaseParity,
                        )

                        # Step 6.3.5. Yield
                        scf.yield_([new_acc, phaseParity])

                    with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
                        barId = c((K // BLOCK_K) % num_stages)
                        nvgpu.mbarrier_arrive(mbarDONE, barId)
                        scf.yield_([])

                    # Step 6.4. Epilogue (registers --> shared memory)
                    acc_smem_ty = ir.MemRefType.get(
                        (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
                    )
                    acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
                    debug_print("[cons] | Storing", predicate=consumerPrimaryThread)
                    nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                    scf.yield_([])
                gpu.barrier()

                # GPU Step 9. Epilogue (shared memory --> global memory)
                fd = ir.MemRefType.get(
                    [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
                )
                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
                rty = ir.MemRefType.get(
                    (BLOCK_M, BLOCK_N),
                    c_elem_ty,
                    ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
                )
                c_device_per_block = memref.SubViewOp(
                    rty,
                    c_device,
                    [dimX, dimY],
                    [],
                    [],
                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
                    [BLOCK_M, BLOCK_N],
                    [1, 1],
                )
                vlen = 1
                for_op = scf.ForOp(
                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2)
                )
                with ir.InsertionPoint(for_op.body):
                    x = arith.divui(for_op.induction_variable, c(BLOCK_M))
                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
                    vdata = vector.load(
                        ir.VectorType.get((vlen,), c_elem_ty),
                        collapsed_smem,
                        [for_op.induction_variable],
                    )
                    vector.store(vdata, c_device_per_block, [x, y])
                    scf.yield_([])

                gpu.terminator()
            # Step 4. Copy back to host
            t8 = gpu.wait(token_ty, [launch_op])
            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
            gpu.dealloc(token_ty, [t8], a_device)
            gpu.dealloc(token_ty, [t8], b_device)
            gpu.wait(token_ty, [t9])
            gpu.dealloc(token_ty, [t8], c_device)
            func.ReturnOp([])

    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    module.operation.verify()
    return module


def generate_matmul_multistage(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    num_stages=3,
):
    # Limitations for now
    assert input_type == np.float16
    assert output_type == np.float32
    assert BLOCK_M == 128
    assert BLOCK_N == 128
    assert BLOCK_K == 64
    assert M % BLOCK_M == 0
    assert N % BLOCK_N == 0
    assert K % BLOCK_K == 0

    module = ir.Module.create()
    token_ty = gpu.AsyncTokenType.get()
    a_elem_ty = get_mlir_ty(input_type)
    b_elem_ty = get_mlir_ty(input_type)
    c_elem_ty = get_mlir_ty(output_type)
    a_ty = ir.MemRefType.get([M, K], a_elem_ty)
    b_ty = ir.MemRefType.get((K, N), b_elem_ty)
    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
    b_tile_shape = (BLOCK_K, BLOCK_N)
    txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(a_elem_ty)) + (
        a_tile_shape[0] * a_tile_shape[1] * get_type_size(b_elem_ty)
    )
    smem_space_str = "#gpu.address_space<workgroup>"
    smem_space = ir.Attribute.parse(smem_space_str)
    mbar_ty = ir.Type.parse(
        "!nvgpu.mbarrier.group<memorySpace = "
        + str(smem_space)
        + ", num_barriers = "
        + str(num_stages)
        + ">"
    )
    acc_ty = ir.Type.parse(
        "!nvgpu.warpgroup.accumulator<fragmented=vector<"
        + str(BLOCK_M)
        + "x"
        + str(BLOCK_N)
        + "x"
        + str(c_elem_ty)
        + ">>"
    )
    a_wgmma_ty = ir.Type.parse(
        "!nvgpu.warpgroup.descriptor<tensor=memref<"
        + str(BLOCK_M)
        + "x"
        + str(BLOCK_K)
        + "x"
        + str(a_elem_ty)
        + ", "
        + smem_space_str
        + ">>"
    )
    b_wgmma_ty = ir.Type.parse(
        "!nvgpu.warpgroup.descriptor<tensor=memref<"
        + str(BLOCK_K)
        + "x"
        + str(BLOCK_N)
        + "x"
        + str(b_elem_ty)
        + ", "
        + smem_space_str
        + ">>"
    )

    with ir.InsertionPoint(module.body):
        kernelName = make_kernel_name(
            input_type,
            output_type,
            M,
            N,
            K,
            BLOCK_M,
            BLOCK_N,
            BLOCK_K,
            num_stages,
            False,
        )
        fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
        with ir.InsertionPoint(fop.add_entry_block()):
            a_host = fop.arguments[0]
            b_host = fop.arguments[1]
            c_host = fop.arguments[2]
            lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
            rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
            smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
            smem_size = max(smem_size_input, smem_size_output)

            # Step 1. Allocate device memory and memcpy
            t1 = gpu.wait(token_ty, [])
            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
            t7 = gpu.wait(token_ty, [t6])

            # Step 2. Create TMA Descriptors
            a_tma_desc = TmaDescriptorBuilder(
                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
                nvgpu.TensorMapOOBKind.OOB_ZERO,
                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
                a_tma_shape,
                a_ty,
            )

            b_tma_desc = TmaDescriptorBuilder(
                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
                nvgpu.TensorMapOOBKind.OOB_ZERO,
                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
                b_tma_shape,
                b_ty,
            )

            a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
            b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)
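
            # Unlike the warp-specialized kernel above, this kernel launches a
            # single warpgroup that both issues TMA loads and performs WGMMA,
            # overlapping the two through a num_stages-deep software pipeline.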
            # Step 3. Launch Kernel with 1 Warpgroup
            cta_m = M // BLOCK_M
            cta_n = N // BLOCK_N
            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
            grid = (cta_m, cta_n, 1)
            block = (WARP_GROUP_SIZE, 1, 1)
            launch_op = gpu.LaunchOp(
                token_ty,
                [t7],
                *map(c, grid),
                *map(c, block),
                dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
            )
            launch_op.body.blocks.append(*([T.index()] * 12))
            with ir.InsertionPoint(launch_op.body.blocks[0]):
                # GPU Step 0. Bootstrapping
                memref.assume_alignment(c_device, 16)
                dynamic_smem = gpu.dynamic_shared_memory(
                    ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
                )
                ticks = c(10000000)
                tidx = gpu.thread_id(gpu.Dimension.x)
                primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
                warpId = arith.divui(tidx, c(32))
                bidx = gpu.block_id(gpu.Dimension.x)
                bidy = gpu.block_id(gpu.Dimension.y)
                dimX = arith.muli(bidx, c(BLOCK_M))
                dimY = arith.muli(bidy, c(BLOCK_N))

                # GPU Step 1. Initialize mbarrier groups
                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                for i in range(num_stages):
                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
                gpu.barrier()

                # GPU Step 2. Prefetch TMA descriptors
                nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=primaryThread)
                nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=primaryThread)

                # GPU Step 3. Prologue (global memory --> shared memory)
                ns = num_stages if num_stages == 1 else num_stages - 1
                for_op = scf.ForOp(c(0), c(ns), c(1))
                with ir.InsertionPoint(for_op.body):
                    iv = for_op.induction_variable

                    # Step 3.1. Calculate offsets
                    a_offset = arith.muli(iv, c(lhs_tile_bytes))
                    a_tma_slice = memref.view(
                        ir.MemRefType.get(
                            a_tma_shape, a_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        a_offset,
                        [],
                    )
                    b_offset = arith.addi(
                        arith.muli(iv, c(rhs_tile_bytes)),
                        c(lhs_tile_bytes * num_stages),
                    )
                    b_tma_slice_1 = memref.view(
                        ir.MemRefType.get(
                            b_tma_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset,
                        [],
                    )
                    b_offset2 = arith.addi(
                        b_offset,
                        c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                    )
                    b_tma_slice_2 = memref.view(
                        ir.MemRefType.get(
                            b_tma_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset2,
                        [],
                    )

                    # Step 3.2. TMA Load
                    coord = arith.muli(c(64), iv)
                    dimY2 = arith.addi(dimY, c(64))
                    debug_print(
                        "[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
                        a_offset,
                        b_offset,
                        b_offset2,
                        coord,
                        dimX,
                        dimY,
                        coord,
                        predicate=primaryThread,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        a_tma_slice,
                        mbarTMA,
                        a_tma_desc_op,
                        coordinates=[coord, dimX],
                        mbarId=iv,
                        predicate=primaryThread,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        b_tma_slice_1,
                        mbarTMA,
                        b_tma_desc_op,
                        coordinates=[dimY, coord],
                        mbarId=iv,
                        predicate=primaryThread,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        b_tma_slice_2,
                        mbarTMA,
                        b_tma_desc_op,
                        coordinates=[dimY2, coord],
                        mbarId=iv,
                        predicate=primaryThread,
                    )

                    # Step 3.3. mbarTMA arrive
                    debug_print(
                        "[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread
                    )
                    nvgpu.mbarrier_arrive_expect_tx(
                        mbarTMA, c(txcount), iv, predicate=primaryThread
                    )
                    debug_print(
                        "[Prologue] mbarTMA[{}] arrive [done]",
                        iv,
                        predicate=primaryThread,
                    )
                    scf.yield_([])

                # GPU Step 4. Main Loop
                acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
                for_op = scf.ForOp(
                    c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
                )
                with ir.InsertionPoint(for_op.body):
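                    # mbarrier parity protocol: try_wait on a slot completes
                    # once the barrier's phase matches phaseParity, and the
                    # parity bit is flipped each time the loop wraps past the
                    # last stage (see Step 4.5 below).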
                    # Step 4.1. Wait mbarTMA
                    phaseParity = for_op.inner_iter_args[1]
                    iv = for_op.induction_variable
                    stage = arith.remui(iv, c(num_stages))
                    debug_print(
                        "[MainLoop] mbarTMA[{}] try_wait phase={}",
                        stage,
                        phaseParity,
                        predicate=primaryThread,
                    )
                    nvgpu.MBarrierTryWaitParityOp(
                        mbarTMA, phaseParity, ticks, mbarId=stage
                    )
                    debug_print(
                        "[MainLoop] mbarTMA[{}] try_wait phase={} [done]",
                        stage,
                        phaseParity,
                        predicate=primaryThread,
                    )

                    # Step 4.2. Create WGMMA Descriptors
                    a_offset = arith.muli(stage, c(lhs_tile_bytes))
                    a_tile_slice = memref.view(
                        ir.MemRefType.get(
                            a_tile_shape, a_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        a_offset,
                        [],
                    )
                    b_offset = arith.addi(
                        arith.muli(stage, c(rhs_tile_bytes)),
                        c(lhs_tile_bytes * num_stages),
                    )
                    b_tile_slice = memref.view(
                        ir.MemRefType.get(
                            b_tile_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset,
                        [],
                    )
                    debug_print(
                        "[MainLoop] iv={} MMA a_offset={} b_offset={}",
                        iv,
                        a_offset,
                        b_offset,
                        predicate=primaryThread,
                    )
                    da = nvgpu.WarpgroupGenerateDescriptorOp(
                        a_wgmma_ty, a_tile_slice, a_tma_desc_op
                    )
                    db = nvgpu.WarpgroupGenerateDescriptorOp(
                        b_wgmma_ty, b_tile_slice, b_tma_desc_op
                    )

                    # Step 4.3. MMA
                    carry_acc = for_op.inner_iter_args[0]
                    new_acc = nvgpu.WarpgroupMmaOp(
                        acc.type, da, db, carry_acc, transposeB=True
                    )
                    if num_stages == 1:
                        nvvm.WgmmaWaitGroupSyncOp(0)

                    # Step 4.4. Load TMA for next stage
                    p1 = arith.cmpi(
                        arith.CmpIPredicate.ult,
                        arith.addi(iv, c(ns)),
                        c(K // BLOCK_K),
                    )
                    p = arith.andi(primaryThread, p1)
                    nextStage = arith.addi(iv, c(ns))
                    nextSlot = arith.remui(nextStage, c(num_stages))
                    a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
                    debug_print(
                        "[MainLoop] mbarTMA[{}] arrive",
                        nextSlot,
                        predicate=p,
                    )
                    nvgpu.mbarrier_arrive_expect_tx(
                        mbarTMA, c(txcount), nextSlot, predicate=p
                    )
                    debug_print(
                        "[MainLoop] mbarTMA[{}] arrive [done]",
                        nextSlot,
                        predicate=p,
                    )
                    a_tma_slice = memref.view(
                        ir.MemRefType.get(
                            a_tma_shape, a_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        a_offset,
                        [],
                    )
                    b_offset = arith.addi(
                        arith.muli(nextSlot, c(rhs_tile_bytes)),
                        c(lhs_tile_bytes * num_stages),
                    )
                    b_tma_slice_1 = memref.view(
                        ir.MemRefType.get(
                            b_tma_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset,
                        [],
                    )
                    b_offset2 = arith.addi(
                        b_offset,
                        c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                    )
                    b_tma_slice_2 = memref.view(
                        ir.MemRefType.get(
                            b_tma_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset2,
                        [],
                    )
                    coord = arith.muli(c(64), nextStage)
                    debug_print(
                        "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
                        iv,
                        a_offset,
                        b_offset,
                        b_offset2,
                        coord,
                        dimX,
                        dimY,
                        coord,
                        predicate=p,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        a_tma_slice,
                        mbarTMA,
                        a_tma_desc_op,
                        coordinates=[coord, dimX],
                        mbarId=nextSlot,
                        predicate=p,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        b_tma_slice_1,
                        mbarTMA,
                        b_tma_desc_op,
                        coordinates=[dimY, coord],
                        mbarId=nextSlot,
                        predicate=p,
                    )
                    dimY2 = arith.addi(dimY, c(64))
                    nvgpu.TmaAsyncLoadOp(
                        b_tma_slice_2,
                        mbarTMA,
                        b_tma_desc_op,
                        coordinates=[dimY2, coord],
                        mbarId=nextSlot,
                        predicate=p,
                    )

                    # Step 4.5. Change the phaseParity
                    p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                    phaseParity = arith.select(
                        p,
                        arith.xori(phaseParity, arith.constant(T.bool(), 1)),
                        phaseParity,
                    )

                    # Step 4.6. Yield
                    scf.yield_([new_acc, phaseParity])

                # Step 5. Wait All WGMMA groups
                nvvm.WgmmaWaitGroupSyncOp(0)
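                # Waiting on wgmma group 0 blocks until every outstanding
                # WGMMA has completed, so the accumulator registers are safe
                # to read in the epilogue below.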
                # Step 6. Epilogue (registers --> shared memory)
                acc_smem_ty = ir.MemRefType.get(
                    (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
                )
                acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
                debug_print("Storing", predicate=primaryThread)
                nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                gpu.barrier()

                # GPU Step 7. Epilogue (shared memory --> global memory)
                fd = ir.MemRefType.get(
                    [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
                )
                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
                rty = ir.MemRefType.get(
                    (BLOCK_M, BLOCK_N),
                    c_elem_ty,
                    ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
                )
                c_device_per_block = memref.SubViewOp(
                    rty,
                    c_device,
                    [dimX, dimY],
                    [],
                    [],
                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
                    [BLOCK_M, BLOCK_N],
                    [1, 1],
                )
                vlen = 1
                for_op = scf.ForOp(
                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE)
                )
                with ir.InsertionPoint(for_op.body):
                    x = arith.divui(for_op.induction_variable, c(BLOCK_M))
                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
                    vdata = vector.load(
                        ir.VectorType.get((vlen,), c_elem_ty),
                        collapsed_smem,
                        [for_op.induction_variable],
                    )
                    vector.store(vdata, c_device_per_block, [x, y])
                    scf.yield_([])

                gpu.terminator()

            # Step 4. Copy back to host
            t8 = gpu.wait(token_ty, [launch_op])
            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
            gpu.dealloc(token_ty, [t8], a_device)
            gpu.dealloc(token_ty, [t8], b_device)
            gpu.wait(token_ty, [t9])
            gpu.dealloc(token_ty, [t8], c_device)
            func.ReturnOp([])

    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    module.operation.verify()
    return module
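
# A minimal usage sketch (an illustrative assumption, not part of the original
# test driver): both builders only construct and verify an MLIR module, so the
# IR can be printed without a GPU. Compiling and running it additionally needs
# an sm_90a device and an NVVM-enabled pass pipeline, which a separate driver
# script is expected to provide. This assumes the bundled MLIR Python bindings
# register all dialects when the Context is created.
if __name__ == "__main__":
    with ir.Context() as ctx, ir.Location.unknown():
        print(generate_matmul_ws(M=256, N=256, K=256))
        print(generate_matmul_multistage(M=256, N=256, K=256))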