import numpy as np
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import gpu
from mlir.dialects import memref
from mlir.dialects import nvgpu
from mlir.dialects import nvvm
from mlir.dialects import llvm
from mlir.dialects import builtin
from mlir.dialects import scf
from mlir.dialects import vector
from mlir.extras import types as T

# Innermost extent of a TMA box for float16 data: 64 elements * 2 bytes = 128 bytes.
TMA_LAST_DIM_F16 = 64
WARP_SIZE = 32
# A warpgroup is four warps (128 threads).
WARP_GROUP_SIZE = WARP_SIZE * 4

# Register counts handed to nvvm.setmaxregister for the producer (decrease)
# and consumer (increase) warpgroups.
PRODUCER_REGISTER_SIZE = 40
CONSUMER_REGISTER_SIZE = 232

# Thread indices (threadIdx.x) that act as the primary thread of each warpgroup.
PRODUCER_PRIMARY_THREAD = 128
CONSUMER_PRIMARY_THREAD = 0

# C++ uses this value to understand whether it's dynamic or not.
MLIR_DYNAMIC = -9223372036854775808  # == -(2**63)

DEBUG = False


class TmaDescriptorBuilder:
    """A class that builds a TMA descriptor."""

    def __init__(self, swizzle, l2promo, oob, interleave, tma_box_shape, memref_ty):
        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
        self.tma_box_shape = tma_box_shape  # sequence of ints, one per box dimension
        self.memref_ty = memref_ty  # MemRefType

    @property
    def tensormap_descriptor_ty(self):
        """Returns a tensormap descriptor type.

        The descriptor wraps a memref type with the box shape and the element
        type of ``memref_ty``.
        """
        # NOTE(review): memory space "3" is presumably the shared-memory
        # address space -- confirm against the NVGPU dialect documentation.
        tensorMemrefType = ir.MemRefType.get(
            self.tma_box_shape,
            self.memref_ty.element_type,
            memory_space=ir.Attribute.parse("3"),
        )
        return nvgpu.TensorMapDescriptorType.get(
            tensorMemrefType,
            self.swizzle,
            self.l2promo,
            self.oob,
            self.interleave,
        )

    def tma_descriptor_op(self, device_ptr):
        """Returns a tensormap descriptor op.

        ``device_ptr`` is cast to an unranked memref (same element type and
        memory space as ``memref_ty``) and passed to
        ``nvgpu.TmaCreateDescriptorOp`` together with the box dimensions as
        index constants. The op's result value is returned.
        """
        tma_descriptor_ty = self.tensormap_descriptor_ty
        device_unranked_memref = memref.CastOp(
            ir.UnrankedMemRefType.get(
                self.memref_ty.element_type, self.memref_ty.memory_space
            ),
            device_ptr,
        )
        tma_descriptor_op = nvgpu.TmaCreateDescriptorOp(
            tma_descriptor_ty, device_unranked_memref, map(c, self.tma_box_shape)
        )
        return tma_descriptor_op.result


def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
    """Emit a ``gpu.printf`` at the current insertion point (debug aid).

    Does nothing unless the module-level ``DEBUG`` flag or ``forcePrint`` is set.

    Args:
        fmt: ``str.format`` template; every ``{}`` is replaced by the printf
            conversion chosen for the corresponding value in ``args``.
        *args: MLIR values to print. Supported types: index, i64, i32, i1, f32.
        predicate: i1 value guarding the print.
        threadNumber: when not -1, the print is guarded by
            ``threadIdx.x == threadNumber`` (this replaces ``predicate``).
        forcePrint: print even when ``DEBUG`` is False.

    Raises:
        NotImplementedError: for a value whose type has no known conversion.
    """
    if not DEBUG and not forcePrint:
        return
    type_formats = []
    for arg in args:
        ty_format = None
        if ir.IndexType.isinstance(arg.type):
            ty_format = "%llu"
        if ir.IntegerType.isinstance(arg.type):
            width = ir.IntegerType(arg.type).width
            if width == 64:
                ty_format = "%llu"
            elif width == 32:
                ty_format = "%d"
            elif width == 1:
                ty_format = "%i"
        if ir.F32Type.isinstance(arg.type):
            ty_format = "%f"
        if ty_format is None:
            raise NotImplementedError(arg.type)
        type_formats.append(ty_format)
    if threadNumber != -1:
        tidx = gpu.thread_id(gpu.Dimension.x)
        predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
        # No scf.yield here: we are still in the caller's block, and a
        # terminator in the middle of a block makes the IR invalid.
    message = fmt.format(*type_formats) + "\n"
    if predicate is None:
        # Nothing to guard on: print unguarded instead of building an
        # scf.if without a condition.
        gpu.printf(message, args)
        return
    if_op = scf.IfOp(predicate)
    with ir.InsertionPoint(if_op.then_block):
        gpu.printf(message, args)
        scf.yield_([])


def get_type_size(ty):
    """Return the size of a float or integer MLIR type in whole bytes.

    The bit width is floor-divided by 8, so sub-byte types (e.g. i1) yield 0.

    Raises:
        NotImplementedError: if ``ty`` is neither a float nor an integer type.
    """
    if ir.FloatType.isinstance(ty):
        return ir.FloatType(ty).width // 8
    if ir.IntegerType.isinstance(ty):
        return ir.IntegerType(ty).width // 8
    raise NotImplementedError(ty)


def get_mlir_ty(dtype):
    """Map a NumPy dtype to the matching MLIR element type.

    Supports float16, float32, float64, int32 and int64.

    Raises:
        NotImplementedError: for any other dtype.
    """
    if dtype == np.float16:
        return T.f16()
    if dtype == np.float32:
        return T.f32()
    if dtype == np.float64:
        return T.f64()
    if dtype == np.int32:
        return T.i32()
    if dtype == np.int64:
        return T.i64()
    raise NotImplementedError(dtype)


def c(value, ty=None):
    """Create an ``arith.constant`` of ``value``; the type defaults to index."""
    ty = T.index() if ty is None else ty
    return arith.constant(ty, value)


def make_kernel_name(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=128,
    num_stages=3,
    use_warp_specialization=False,
):
    """Build the kernel name ``<kind>_<M>x<N>x<K>_<BM>x<BN>x<BK>_<stages>``.

    ``<kind>`` is "warpspecialized" or "multistage". ``input_type`` and
    ``output_type`` are accepted but do not contribute to the name.
    """
    kernelName = "warpspecialized" if use_warp_specialization else "multistage"
    return (
        f"{kernelName}_{M}x{N}x{K}_{BLOCK_M}x{BLOCK_N}x{BLOCK_K}_{num_stages}"
    )


def generate_matmul_ws(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    num_stages=3,
):
    """Build an MLIR module with a warp-specialized matmul kernel.

    The generated function takes host memrefs A[M, K], B[K, N] and C[M, N].
    It allocates device buffers, copies A and B to the device, creates TMA
    descriptors, launches a kernel with two warpgroups per block (one producer
    issuing TMA loads, one consumer running warpgroup MMA), and copies C back.

    Args:
        input_type: NumPy dtype of A and B (only ``np.float16`` for now).
        output_type: NumPy dtype of C (only ``np.float32`` for now).
        M, N, K: problem sizes; must be multiples of the block sizes.
        BLOCK_M, BLOCK_N, BLOCK_K: tile sizes (only 128, 128, 64 for now).
        num_stages: number of shared-memory pipeline stages (mbarriers per group).

    Returns:
        The verified ``ir.Module``.

    Raises:
        AssertionError: if any of the limitations above is violated.
    """
    # Limitations for now
    assert input_type == np.float16
    assert output_type == np.float32
    assert BLOCK_M == 128
    assert BLOCK_N == 128
    assert BLOCK_K == 64
    assert M % BLOCK_M == 0
    assert N % BLOCK_N == 0
    assert K % BLOCK_K == 0

    module = ir.Module.create()
    token_ty = gpu.AsyncTokenType.get()
    a_elem_ty = get_mlir_ty(input_type)
    b_elem_ty = get_mlir_ty(input_type)
    c_elem_ty = get_mlir_ty(output_type)
    a_ty = ir.MemRefType.get([M, K], a_elem_ty)
    b_ty = ir.MemRefType.get((K, N), b_elem_ty)
    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
    b_tile_shape = (BLOCK_K, BLOCK_N)
    # Bytes expected on one mbarrier per stage: one B tile plus one A tile,
    # each sized with its own element type.
    txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(b_elem_ty)) + (
        a_tile_shape[0] * a_tile_shape[1] * get_type_size(a_elem_ty)
    )
    smem_space_str = "#gpu.address_space<workgroup>"
    smem_space = ir.Attribute.parse(smem_space_str)
    mbar_ty = ir.Type.parse(
        f"!nvgpu.mbarrier.group<memorySpace = {smem_space!s}"
        f", num_barriers = {num_stages}>"
    )
    acc_ty = ir.Type.parse(
        "!nvgpu.warpgroup.accumulator<fragmented=vector<"
        f"{BLOCK_M}x{BLOCK_N}x{c_elem_ty!s}>>"
    )
    a_wgmma_ty = ir.Type.parse(
        "!nvgpu.warpgroup.descriptor<tensor=memref<"
        f"{BLOCK_M}x{BLOCK_K}x{a_elem_ty!s}, {smem_space_str}>>"
    )
    b_wgmma_ty = ir.Type.parse(
        "!nvgpu.warpgroup.descriptor<tensor=memref<"
        f"{BLOCK_K}x{BLOCK_N}x{b_elem_ty!s}, {smem_space_str}>>"
    )
    kernelName = make_kernel_name(
        input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_stages, True
    )
    with ir.InsertionPoint(module.body):
        fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
        with ir.InsertionPoint(fop.add_entry_block()):
            a_host = fop.arguments[0]
            b_host = fop.arguments[1]
            c_host = fop.arguments[2]
            lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
            rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
            smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
            # The epilogue reuses the input staging buffers for the output tile,
            # so the allocation only has to fit the larger of the two.
            smem_size = max(smem_size_input, smem_size_output)

            # Step 1. Allocate device memory and memcpy
            t1 = gpu.wait(token_ty, [])
            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
            t7 = gpu.wait(token_ty, [t6])

            # Step 2. Create TMA Descriptors
            a_tma_desc = TmaDescriptorBuilder(
                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
                nvgpu.TensorMapOOBKind.OOB_ZERO,
                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
                a_tma_shape,
                a_ty,
            )

            b_tma_desc = TmaDescriptorBuilder(
                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
                nvgpu.TensorMapOOBKind.OOB_ZERO,
                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
                b_tma_shape,
                b_ty,
            )

            a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
            b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)

            # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
            # (divisibility of M and N is already asserted at function entry)
            cta_m = M // BLOCK_M
            cta_n = N // BLOCK_N
            grid = (cta_m, cta_n, 1)
            block = (WARP_GROUP_SIZE * 2, 1, 1)
            launch_op = gpu.LaunchOp(
                token_ty,
                [t7],
                *map(c, grid),
                *map(c, block),
                dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
            )
            launch_op.body.blocks.append(*([T.index()] * 12))
            with ir.InsertionPoint(launch_op.body.blocks[0]):
                # GPU Step 0. This is needed for vectorized ld/st
                memref.assume_alignment(c_device, 16)
                dynamic_smem = gpu.dynamic_shared_memory(
                    ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
                )
                # Tick budget passed to every mbarrier try_wait below.
                ticks = c(10000000)

                # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups and etc.
                tidx = gpu.thread_id(gpu.Dimension.x)
                wgPrimaryThread = arith.cmpi(
                    arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0)
                )
                warp_id = arith.divui(tidx, c(32))
                warpgroup_id = arith.divui(warp_id, c(4))
                is_producer = arith.cmpi(
                    arith.CmpIPredicate.eq,
                    warpgroup_id,
                    c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0),
                )
                is_consumer = arith.cmpi(
                    arith.CmpIPredicate.eq,
                    warpgroup_id,
                    c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1),
                )
                producerPrimaryThread = arith.cmpi(
                    arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD)
                )
                consumerPrimaryThread = arith.cmpi(
                    arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD)
                )
                bidx = gpu.block_id(gpu.Dimension.x)
                bidy = gpu.block_id(gpu.Dimension.y)
                # Origin of this block's output tile: row offset and column offset.
                dimX = arith.muli(bidx, c(BLOCK_M))
                dimY = arith.muli(bidy, c(BLOCK_N))

                # GPU Step 2. Initialize mbarrier groups
                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                mbarDONE = nvgpu.mbarrier_create(mbar_ty)
                for i in range(num_stages):
                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
                    nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
                gpu.barrier()

                # GPU Step 3. Prefetch TMA descriptors
                nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=wgPrimaryThread)
                nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=wgPrimaryThread)

                # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
                with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
                    # Step 5.1. Reduce register size
                    nvvm.setmaxregister(
                        PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease
                    )

                    # Step 5.2. TMA Main Loop
                    for_op = scf.ForOp(
                        c(0), c(K // BLOCK_K), c(1), [arith.constant(T.bool(), 1)]
                    )
                    with ir.InsertionPoint(for_op.body):
                        phaseParity = for_op.inner_iter_args[0]
                        iv = for_op.induction_variable
                        stage = arith.remui(iv, c(num_stages))

                        # Step 5.2.1. Wait mbarDONE
                        debug_print(
                            "[prod] iv={} | mbarDONE[{}] try_wait phase={}",
                            iv,
                            stage,
                            phaseParity,
                            predicate=producerPrimaryThread,
                        )
                        nvgpu.MBarrierTryWaitParityOp(
                            mbarDONE, phaseParity, ticks, mbarId=stage
                        )
                        debug_print(
                            "[prod] iv={} | mbarDONE[{}] try_wait phase={} [done]",
                            iv,
                            stage,
                            phaseParity,
                            predicate=producerPrimaryThread,
                        )
                        # Flip the parity after the last stage of each round.
                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                        phaseParity = arith.select(
                            p,
                            arith.xori(phaseParity, arith.constant(T.bool(), 1)),
                            phaseParity,
                        )

                        # Step 5.2.2. Load TMA
                        # Shared-memory layout: all A stages first, then all B stages.
                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
                        a_tma_slice = memref.view(
                            ir.MemRefType.get(
                                a_tma_shape, a_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            a_offset,
                            [],
                        )
                        b_offset = arith.addi(
                            arith.muli(stage, c(rhs_tile_bytes)),
                            c(lhs_tile_bytes * num_stages),
                        )
                        b_tma_slice_1 = memref.view(
                            ir.MemRefType.get(
                                b_tma_shape, b_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            b_offset,
                            [],
                        )
                        # The B tile is fetched as two TMA boxes; the second box
                        # sits right after the first in shared memory.
                        b_offset2 = arith.addi(
                            b_offset,
                            c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                        )
                        b_tma_slice_2 = memref.view(
                            ir.MemRefType.get(
                                b_tma_shape, b_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            b_offset2,
                            [],
                        )
                        debug_print(
                            "[prod] a_offset={} b_offset={} b_offset2={}",
                            a_offset,
                            b_offset,
                            b_offset2,
                            predicate=producerPrimaryThread,
                        )
                        # Offset along K for this iteration (BLOCK_K elements per step).
                        coord = arith.muli(c(BLOCK_K), iv)
                        nvgpu.TmaAsyncLoadOp(
                            a_tma_slice,
                            mbarTMA,
                            a_tma_desc_op,
                            coordinates=[coord, dimX],
                            mbarId=stage,
                            predicate=producerPrimaryThread,
                        )
                        nvgpu.TmaAsyncLoadOp(
                            b_tma_slice_1,
                            mbarTMA,
                            b_tma_desc_op,
                            coordinates=[dimY, coord],
                            mbarId=stage,
                            predicate=producerPrimaryThread,
                        )
                        # Second B box starts one box width further along N.
                        dimY2 = arith.addi(dimY, c(TMA_LAST_DIM_F16))
                        nvgpu.TmaAsyncLoadOp(
                            b_tma_slice_2,
                            mbarTMA,
                            b_tma_desc_op,
                            coordinates=[dimY2, coord],
                            mbarId=stage,
                            predicate=producerPrimaryThread,
                        )

                        # Step 5.2.3. Arrive mbarTMA
                        debug_print(
                            "[prod] iv={} | mbarTMA[{}] arrive",
                            iv,
                            stage,
                            predicate=producerPrimaryThread,
                        )
                        nvgpu.mbarrier_arrive_expect_tx(
                            mbarTMA, c(txcount), stage, predicate=producerPrimaryThread
                        )
                        debug_print(
                            "[prod] iv={} | mbarTMA[{}] arrive [done]",
                            iv,
                            stage,
                            predicate=producerPrimaryThread,
                        )
                        scf.yield_([phaseParity])
                    scf.yield_([])

                # GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
                if_op = scf.IfOp(is_consumer)
                with ir.InsertionPoint(if_op.then_block):
                    # Step 6.1. Increase register size
                    nvvm.setmaxregister(
                        CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase
                    )

                    # GPU Step 6.2. Initialize MMA registers
                    acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)

                    # Step 6.3. MMA Main Loop
                    for_op = scf.ForOp(
                        c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
                    )
                    with ir.InsertionPoint(for_op.body):
                        # Step 6.3.1. Wait mbar1
                        phaseParity = for_op.inner_iter_args[1]
                        iv = for_op.induction_variable
                        stage = arith.remui(iv, c(num_stages))
                        debug_print(
                            "[cons] iv={} | mbarTMA[{}] try_wait phase={}",
                            iv,
                            stage,
                            phaseParity,
                            predicate=consumerPrimaryThread,
                        )
                        nvgpu.MBarrierTryWaitParityOp(
                            mbarTMA, phaseParity, ticks, mbarId=stage
                        )
                        debug_print(
                            "[cons] iv={} | mbarTMA[{}] try_wait phase={} [done]",
                            iv,
                            stage,
                            phaseParity,
                            predicate=consumerPrimaryThread,
                        )

                        # Step 6.3.2. Create WGMMA Descriptors
                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
                        a_tile_slice = memref.view(
                            ir.MemRefType.get(
                                a_tile_shape, a_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            a_offset,
                            [],
                        )
                        b_offset = arith.addi(
                            arith.muli(stage, c(rhs_tile_bytes)),
                            c(lhs_tile_bytes * num_stages),
                        )
                        b_tile_slice = memref.view(
                            ir.MemRefType.get(
                                b_tile_shape, b_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            b_offset,
                            [],
                        )
                        debug_print(
                            "[cons] a_offset={} b_offset={}",
                            a_offset,
                            b_offset,
                            predicate=consumerPrimaryThread,
                        )
                        da = nvgpu.WarpgroupGenerateDescriptorOp(
                            a_wgmma_ty, a_tile_slice, a_tma_desc_op
                        )
                        db = nvgpu.WarpgroupGenerateDescriptorOp(
                            b_wgmma_ty, b_tile_slice, b_tma_desc_op
                        )

                        # Step 6.3.3. MMA
                        carry_acc = for_op.inner_iter_args[0]
                        new_acc = nvgpu.WarpgroupMmaOp(
                            acc.type, da, db, carry_acc, transposeB=True
                        )

                        # Step 6.3.4. Arrive mbarDONE
                        # With several stages the consumer releases the stage used
                        # by the previous iteration, so nothing is released at iv == 0.
                        if num_stages == 1:
                            p_arrive = consumerPrimaryThread
                        else:
                            p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
                            p_arrive = arith.andi(consumerPrimaryThread, p1)
                        with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
                            p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
                            # Previous stage index, wrapping from 0 to num_stages - 1.
                            barId = arith.select(
                                p, c(num_stages - 1), arith.subi(stage, c(1))
                            )
                            debug_print(
                                "[cons] iv={} | mbarDONE[{}] arrive ",
                                iv,
                                barId,
                                predicate=consumerPrimaryThread,
                            )
                            nvgpu.mbarrier_arrive(mbarDONE, barId)
                            debug_print(
                                "[cons] iv={} | mbarDONE[{}] arrive [done]",
                                iv,
                                barId,
                                predicate=consumerPrimaryThread,
                            )
                            scf.yield_([])

                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                        phaseParity = arith.select(
                            p,
                            arith.xori(phaseParity, arith.constant(T.bool(), 1)),
                            phaseParity,
                        )

                        # Step 6.3.5. Yield
                        scf.yield_([new_acc, phaseParity])

                    # One more arrive after the loop, issued by the consumer's
                    # primary thread.
                    with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
                        barId = c((K // BLOCK_K) % num_stages)
                        nvgpu.mbarrier_arrive(mbarDONE, barId)
                        scf.yield_([])

                    # Step 6.4. Epilogue (registers --> shared memory)
                    acc_smem_ty = ir.MemRefType.get(
                        (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
                    )
                    acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
                    debug_print("[cons] | Storing", predicate=consumerPrimaryThread)
                    nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                    scf.yield_([])
                gpu.barrier()

                # GPU Step 9. Epilogue (shared memory --> global memory)
                fd = ir.MemRefType.get(
                    [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
                )
                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
                rty = ir.MemRefType.get(
                    (BLOCK_M, BLOCK_N),
                    c_elem_ty,
                    ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
                )
                c_device_per_block = memref.SubViewOp(
                    rty,
                    c_device,
                    [dimX, dimY],
                    [],
                    [],
                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
                    [BLOCK_M, BLOCK_N],
                    [1, 1],
                )
                vlen = 1
                # Every thread of the block copies elements tidx, tidx + 256, ...
                for_op = scf.ForOp(
                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2)
                )
                with ir.InsertionPoint(for_op.body):
                    # The tile is row-major with BLOCK_N elements per row, so both
                    # the row and the column come from BLOCK_N.
                    x = arith.divui(for_op.induction_variable, c(BLOCK_N))
                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
                    vdata = vector.load(
                        ir.VectorType.get((vlen,), c_elem_ty),
                        collapsed_smem,
                        [for_op.induction_variable],
                    )
                    vector.store(vdata, c_device_per_block, [x, y])
                    scf.yield_([])

                gpu.terminator()

            # Step 4. Copy back to host
            t8 = gpu.wait(token_ty, [launch_op])
            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
            gpu.dealloc(token_ty, [t8], a_device)
            gpu.dealloc(token_ty, [t8], b_device)
            # C is freed only after the copy back to the host has completed.
            gpu.wait(token_ty, [t9])
            gpu.dealloc(token_ty, [t8], c_device)
            func.ReturnOp([])

    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    module.operation.verify()
    return module


def generate_matmul_multistage(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    num_stages=3,
):
    # Limitaitons for now
    assert input_type == np.float16
    assert output_type == np.float32
    assert BLOCK_M == 128
    assert BLOCK_N == 128
    assert BLOCK_K == 64
    assert M % BLOCK_M == 0
    assert N % BLOCK_N == 0
    assert K % BLOCK_K == 0

    module = ir.Module.create()
    token_ty = gpu.AsyncTokenType.get()
    a_elem_ty = get_mlir_ty(input_type)
680d95e6d02SGuray Ozen b_elem_ty = get_mlir_ty(input_type) 681d95e6d02SGuray Ozen c_elem_ty = get_mlir_ty(output_type) 682d95e6d02SGuray Ozen a_ty = ir.MemRefType.get([M, K], a_elem_ty) 683d95e6d02SGuray Ozen b_ty = ir.MemRefType.get((K, N), b_elem_ty) 684d95e6d02SGuray Ozen c_ty = ir.MemRefType.get((M, N), c_elem_ty) 685d95e6d02SGuray Ozen a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16) 686d95e6d02SGuray Ozen b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16) 687d95e6d02SGuray Ozen b_tile_shape = (BLOCK_K, BLOCK_N) 688d95e6d02SGuray Ozen txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(a_elem_ty)) + ( 689d95e6d02SGuray Ozen a_tile_shape[0] * a_tile_shape[1] * get_type_size(b_elem_ty) 690d95e6d02SGuray Ozen ) 691d95e6d02SGuray Ozen smem_space_str = "#gpu.address_space<workgroup>" 692d95e6d02SGuray Ozen smem_space = ir.Attribute.parse(smem_space_str) 693d95e6d02SGuray Ozen mbar_ty = ir.Type.parse( 694d95e6d02SGuray Ozen "!nvgpu.mbarrier.group<memorySpace = " 695d95e6d02SGuray Ozen + str(smem_space) 696d95e6d02SGuray Ozen + ", num_barriers = " 697d95e6d02SGuray Ozen + str(num_stages) 698d95e6d02SGuray Ozen + ">" 699d95e6d02SGuray Ozen ) 700d95e6d02SGuray Ozen acc_ty = ir.Type.parse( 701d95e6d02SGuray Ozen "!nvgpu.warpgroup.accumulator<fragmented=vector<" 702d95e6d02SGuray Ozen + str(BLOCK_M) 703d95e6d02SGuray Ozen + "x" 704d95e6d02SGuray Ozen + str(BLOCK_N) 705d95e6d02SGuray Ozen + "x" 706d95e6d02SGuray Ozen + str(c_elem_ty) 707d95e6d02SGuray Ozen + ">>" 708d95e6d02SGuray Ozen ) 709d95e6d02SGuray Ozen a_wgmma_ty = ir.Type.parse( 710d95e6d02SGuray Ozen "!nvgpu.warpgroup.descriptor<tensor=memref<" 711d95e6d02SGuray Ozen + str(BLOCK_M) 712d95e6d02SGuray Ozen + "x" 713d95e6d02SGuray Ozen + str(BLOCK_K) 714d95e6d02SGuray Ozen + "x" 715d95e6d02SGuray Ozen + str(a_elem_ty) 716d95e6d02SGuray Ozen + ", " 717d95e6d02SGuray Ozen + smem_space_str 718d95e6d02SGuray Ozen + ">>" 719d95e6d02SGuray Ozen ) 720d95e6d02SGuray Ozen b_wgmma_ty = ir.Type.parse( 
721d95e6d02SGuray Ozen "!nvgpu.warpgroup.descriptor<tensor=memref<" 722d95e6d02SGuray Ozen + str(BLOCK_K) 723d95e6d02SGuray Ozen + "x" 724d95e6d02SGuray Ozen + str(BLOCK_N) 725d95e6d02SGuray Ozen + "x" 726d95e6d02SGuray Ozen + str(a_elem_ty) 727d95e6d02SGuray Ozen + ", " 728d95e6d02SGuray Ozen + smem_space_str 729d95e6d02SGuray Ozen + ">>" 730d95e6d02SGuray Ozen ) 731d95e6d02SGuray Ozen 732d95e6d02SGuray Ozen with ir.InsertionPoint(module.body): 733d95e6d02SGuray Ozen kernelName = make_kernel_name( 734d95e6d02SGuray Ozen input_type, 735d95e6d02SGuray Ozen output_type, 736d95e6d02SGuray Ozen M, 737d95e6d02SGuray Ozen N, 738d95e6d02SGuray Ozen K, 739d95e6d02SGuray Ozen BLOCK_M, 740d95e6d02SGuray Ozen BLOCK_N, 741d95e6d02SGuray Ozen BLOCK_K, 742d95e6d02SGuray Ozen num_stages, 743d95e6d02SGuray Ozen False, 744d95e6d02SGuray Ozen ) 745d95e6d02SGuray Ozen fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], [])) 746d95e6d02SGuray Ozen with ir.InsertionPoint(fop.add_entry_block()): 747d95e6d02SGuray Ozen a_host = fop.arguments[0] 748d95e6d02SGuray Ozen b_host = fop.arguments[1] 749d95e6d02SGuray Ozen c_host = fop.arguments[2] 750d95e6d02SGuray Ozen lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty) 751d95e6d02SGuray Ozen rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty) 752d95e6d02SGuray Ozen smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages 753d95e6d02SGuray Ozen smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty) 754d95e6d02SGuray Ozen smem_size = max(smem_size_input, smem_size_output) 755d95e6d02SGuray Ozen 756d95e6d02SGuray Ozen # Step 1. 
Allocate device memory and memcpy 757d95e6d02SGuray Ozen t1 = gpu.wait(token_ty, []) 758d95e6d02SGuray Ozen a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], []) 759d95e6d02SGuray Ozen b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], []) 760d95e6d02SGuray Ozen c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], []) 761d95e6d02SGuray Ozen t5 = gpu.memcpy(token_ty, [t4], a_device, a_host) 762d95e6d02SGuray Ozen t6 = gpu.memcpy(token_ty, [t5], b_device, b_host) 763d95e6d02SGuray Ozen t7 = gpu.wait(token_ty, [t6]) 764d95e6d02SGuray Ozen 765d95e6d02SGuray Ozen # Step 2. Create TMA Descriptors 766c82f45f9SGuray Ozen a_tma_desc = TmaDescriptorBuilder( 767c82f45f9SGuray Ozen nvgpu.TensorMapSwizzleKind.SWIZZLE_128B, 768c82f45f9SGuray Ozen nvgpu.TensorMapL2PromoKind.L2PROMO_NONE, 769c82f45f9SGuray Ozen nvgpu.TensorMapOOBKind.OOB_ZERO, 770c82f45f9SGuray Ozen nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE, 771c82f45f9SGuray Ozen a_tma_shape, 772c82f45f9SGuray Ozen a_ty, 773d95e6d02SGuray Ozen ) 774c82f45f9SGuray Ozen 775c82f45f9SGuray Ozen b_tma_desc = TmaDescriptorBuilder( 776c82f45f9SGuray Ozen nvgpu.TensorMapSwizzleKind.SWIZZLE_128B, 777c82f45f9SGuray Ozen nvgpu.TensorMapL2PromoKind.L2PROMO_NONE, 778c82f45f9SGuray Ozen nvgpu.TensorMapOOBKind.OOB_ZERO, 779c82f45f9SGuray Ozen nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE, 780c82f45f9SGuray Ozen b_tma_shape, 781c82f45f9SGuray Ozen b_ty, 782d95e6d02SGuray Ozen ) 783c82f45f9SGuray Ozen 784c82f45f9SGuray Ozen a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device) 785c82f45f9SGuray Ozen b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device) 786d95e6d02SGuray Ozen 787d95e6d02SGuray Ozen # Step 3. 
Launch Kernel with 1 Warpgroup 788d95e6d02SGuray Ozen cta_m = M // BLOCK_M 789d95e6d02SGuray Ozen cta_n = N // BLOCK_N 790d95e6d02SGuray Ozen assert M % BLOCK_M == 0 and N % BLOCK_N == 0 791d95e6d02SGuray Ozen grid = (cta_m, cta_n, 1) 792d95e6d02SGuray Ozen block = (WARP_GROUP_SIZE, 1, 1) 793d95e6d02SGuray Ozen launch_op = gpu.LaunchOp( 794d95e6d02SGuray Ozen token_ty, 795d95e6d02SGuray Ozen [t7], 796d95e6d02SGuray Ozen *map(c, grid), 797d95e6d02SGuray Ozen *map(c, block), 798c82f45f9SGuray Ozen dynamicSharedMemorySize=c(smem_size, ty=T.i32()), 799d95e6d02SGuray Ozen ) 800d95e6d02SGuray Ozen launch_op.body.blocks.append(*([T.index()] * 12)) 801d95e6d02SGuray Ozen with ir.InsertionPoint(launch_op.body.blocks[0]): 802d95e6d02SGuray Ozen # GPU Step 0. Bootstrapping 803d95e6d02SGuray Ozen memref.assume_alignment(c_device, 16) 804d95e6d02SGuray Ozen dynamic_smem = gpu.dynamic_shared_memory( 805d95e6d02SGuray Ozen ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space) 806d95e6d02SGuray Ozen ) 807d95e6d02SGuray Ozen ticks = c(10000000) 808d95e6d02SGuray Ozen tidx = gpu.thread_id(gpu.Dimension.x) 809d95e6d02SGuray Ozen primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0)) 810d95e6d02SGuray Ozen warpId = arith.divui(tidx, c(32)) 811d95e6d02SGuray Ozen bidx = gpu.block_id(gpu.Dimension.x) 812d95e6d02SGuray Ozen bidy = gpu.block_id(gpu.Dimension.y) 813d95e6d02SGuray Ozen dimX = arith.muli(bidx, c(BLOCK_M)) 814d95e6d02SGuray Ozen dimY = arith.muli(bidy, c(BLOCK_N)) 815d95e6d02SGuray Ozen 816d95e6d02SGuray Ozen # GPU Step 1. Initialize mbarrier groups 817d95e6d02SGuray Ozen mbarTMA = nvgpu.mbarrier_create(mbar_ty) 818d95e6d02SGuray Ozen for i in range(num_stages): 819d95e6d02SGuray Ozen nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread) 820d95e6d02SGuray Ozen gpu.barrier() 821d95e6d02SGuray Ozen 822d95e6d02SGuray Ozen # GPU Step 2. 
Prefetch TMA descriptors 823c82f45f9SGuray Ozen nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=primaryThread) 824c82f45f9SGuray Ozen nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=primaryThread) 825d95e6d02SGuray Ozen 826d95e6d02SGuray Ozen # GPU Step 3. Prologue (global memory --> shared memory) 827d95e6d02SGuray Ozen ns = num_stages if num_stages == 1 else num_stages - 1 828d95e6d02SGuray Ozen for_op = scf.ForOp(c(0), c(ns), c(1)) 829d95e6d02SGuray Ozen with ir.InsertionPoint(for_op.body): 830d95e6d02SGuray Ozen iv = for_op.induction_variable 831d95e6d02SGuray Ozen 832d95e6d02SGuray Ozen # Step 3.1. Calculate offsets 833d95e6d02SGuray Ozen a_offset = arith.muli(iv, c(lhs_tile_bytes)) 834d95e6d02SGuray Ozen a_tma_slice = memref.view( 835d95e6d02SGuray Ozen ir.MemRefType.get( 836d95e6d02SGuray Ozen a_tma_shape, a_elem_ty, memory_space=smem_space 837d95e6d02SGuray Ozen ), 838d95e6d02SGuray Ozen dynamic_smem, 839d95e6d02SGuray Ozen a_offset, 840d95e6d02SGuray Ozen [], 841d95e6d02SGuray Ozen ) 842d95e6d02SGuray Ozen b_offset = arith.addi( 843d95e6d02SGuray Ozen arith.muli(iv, c(rhs_tile_bytes)), 844d95e6d02SGuray Ozen c(lhs_tile_bytes * num_stages), 845d95e6d02SGuray Ozen ) 846d95e6d02SGuray Ozen b_tma_slice_1 = memref.view( 847d95e6d02SGuray Ozen ir.MemRefType.get( 848d95e6d02SGuray Ozen b_tma_shape, b_elem_ty, memory_space=smem_space 849d95e6d02SGuray Ozen ), 850d95e6d02SGuray Ozen dynamic_smem, 851d95e6d02SGuray Ozen b_offset, 852d95e6d02SGuray Ozen [], 853d95e6d02SGuray Ozen ) 854d95e6d02SGuray Ozen b_offset2 = arith.addi( 855d95e6d02SGuray Ozen b_offset, 856d95e6d02SGuray Ozen c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)), 857d95e6d02SGuray Ozen ) 858d95e6d02SGuray Ozen b_tma_slice_2 = memref.view( 859d95e6d02SGuray Ozen ir.MemRefType.get( 860d95e6d02SGuray Ozen b_tma_shape, b_elem_ty, memory_space=smem_space 861d95e6d02SGuray Ozen ), 862d95e6d02SGuray Ozen dynamic_smem, 863d95e6d02SGuray Ozen b_offset2, 864d95e6d02SGuray Ozen [], 
865d95e6d02SGuray Ozen ) 866d95e6d02SGuray Ozen 867d95e6d02SGuray Ozen # Step 3.2. TMA Load 868d95e6d02SGuray Ozen coord = arith.muli(c(64), iv) 869d95e6d02SGuray Ozen dimY2 = arith.addi(dimY, c(64)) 870d95e6d02SGuray Ozen debug_print( 871d95e6d02SGuray Ozen "[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})", 872d95e6d02SGuray Ozen a_offset, 873d95e6d02SGuray Ozen b_offset, 874d95e6d02SGuray Ozen b_offset2, 875d95e6d02SGuray Ozen coord, 876d95e6d02SGuray Ozen dimX, 877d95e6d02SGuray Ozen dimY, 878d95e6d02SGuray Ozen coord, 879d95e6d02SGuray Ozen predicate=primaryThread, 880d95e6d02SGuray Ozen ) 881d95e6d02SGuray Ozen nvgpu.TmaAsyncLoadOp( 882d95e6d02SGuray Ozen a_tma_slice, 883d95e6d02SGuray Ozen mbarTMA, 884c82f45f9SGuray Ozen a_tma_desc_op, 885d95e6d02SGuray Ozen coordinates=[coord, dimX], 886d95e6d02SGuray Ozen mbarId=iv, 887d95e6d02SGuray Ozen predicate=primaryThread, 888d95e6d02SGuray Ozen ) 889d95e6d02SGuray Ozen nvgpu.TmaAsyncLoadOp( 890d95e6d02SGuray Ozen b_tma_slice_1, 891d95e6d02SGuray Ozen mbarTMA, 892c82f45f9SGuray Ozen b_tma_desc_op, 893d95e6d02SGuray Ozen coordinates=[dimY, coord], 894d95e6d02SGuray Ozen mbarId=iv, 895d95e6d02SGuray Ozen predicate=primaryThread, 896d95e6d02SGuray Ozen ) 897d95e6d02SGuray Ozen nvgpu.TmaAsyncLoadOp( 898d95e6d02SGuray Ozen b_tma_slice_2, 899d95e6d02SGuray Ozen mbarTMA, 900c82f45f9SGuray Ozen b_tma_desc_op, 901d95e6d02SGuray Ozen coordinates=[dimY2, coord], 902d95e6d02SGuray Ozen mbarId=iv, 903d95e6d02SGuray Ozen predicate=primaryThread, 904d95e6d02SGuray Ozen ) 905d95e6d02SGuray Ozen 906d95e6d02SGuray Ozen # Step 3.2. 
mbarTMA arrive 907d95e6d02SGuray Ozen debug_print( 908d95e6d02SGuray Ozen "[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread 909d95e6d02SGuray Ozen ) 910d95e6d02SGuray Ozen nvgpu.mbarrier_arrive_expect_tx( 911d95e6d02SGuray Ozen mbarTMA, c(txcount), iv, predicate=primaryThread 912d95e6d02SGuray Ozen ) 913d95e6d02SGuray Ozen debug_print( 914d95e6d02SGuray Ozen "[Prologue] mbarTMA[{}] arrive [done]", 915d95e6d02SGuray Ozen iv, 916d95e6d02SGuray Ozen predicate=primaryThread, 917d95e6d02SGuray Ozen ) 918d95e6d02SGuray Ozen scf.yield_([]) 919d95e6d02SGuray Ozen 920d95e6d02SGuray Ozen # GPU Step 4. Main Loop 921d95e6d02SGuray Ozen acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty) 922d95e6d02SGuray Ozen for_op = scf.ForOp( 923d95e6d02SGuray Ozen c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)] 924d95e6d02SGuray Ozen ) 925d95e6d02SGuray Ozen with ir.InsertionPoint(for_op.body): 926d95e6d02SGuray Ozen # Step 4.1. Wait mbarTMA 927d95e6d02SGuray Ozen phaseParity = for_op.inner_iter_args[1] 928d95e6d02SGuray Ozen iv = for_op.induction_variable 929d95e6d02SGuray Ozen stage = arith.remui(iv, c(num_stages)) 930d95e6d02SGuray Ozen debug_print( 931d95e6d02SGuray Ozen "[MainLoop] mbarTMA[{}] try_wait phase={}", 932d95e6d02SGuray Ozen stage, 933d95e6d02SGuray Ozen phaseParity, 934d95e6d02SGuray Ozen predicate=primaryThread, 935d95e6d02SGuray Ozen ) 936d95e6d02SGuray Ozen nvgpu.MBarrierTryWaitParityOp( 937d95e6d02SGuray Ozen mbarTMA, phaseParity, ticks, mbarId=stage 938d95e6d02SGuray Ozen ) 939d95e6d02SGuray Ozen debug_print( 940d95e6d02SGuray Ozen "[MainLoop] mbarTMA[{}] try_wait phase={} [done]", 941d95e6d02SGuray Ozen stage, 942d95e6d02SGuray Ozen phaseParity, 943d95e6d02SGuray Ozen predicate=primaryThread, 944d95e6d02SGuray Ozen ) 945d95e6d02SGuray Ozen 946d95e6d02SGuray Ozen # Step 4.2. 
Create WGMMA Descriptors 947d95e6d02SGuray Ozen a_offset = arith.muli(stage, c(lhs_tile_bytes)) 948d95e6d02SGuray Ozen a_tile_slice = memref.view( 949d95e6d02SGuray Ozen ir.MemRefType.get( 950d95e6d02SGuray Ozen a_tile_shape, a_elem_ty, memory_space=smem_space 951d95e6d02SGuray Ozen ), 952d95e6d02SGuray Ozen dynamic_smem, 953d95e6d02SGuray Ozen a_offset, 954d95e6d02SGuray Ozen [], 955d95e6d02SGuray Ozen ) 956d95e6d02SGuray Ozen b_offset = arith.addi( 957d95e6d02SGuray Ozen arith.muli(stage, c(rhs_tile_bytes)), 958d95e6d02SGuray Ozen c(lhs_tile_bytes * num_stages), 959d95e6d02SGuray Ozen ) 960d95e6d02SGuray Ozen b_tile_slice = memref.view( 961d95e6d02SGuray Ozen ir.MemRefType.get( 962d95e6d02SGuray Ozen b_tile_shape, b_elem_ty, memory_space=smem_space 963d95e6d02SGuray Ozen ), 964d95e6d02SGuray Ozen dynamic_smem, 965d95e6d02SGuray Ozen b_offset, 966d95e6d02SGuray Ozen [], 967d95e6d02SGuray Ozen ) 968d95e6d02SGuray Ozen debug_print( 969d95e6d02SGuray Ozen "[MainLoop] iv={} MMA a_offset={} b_offset={}", 970d95e6d02SGuray Ozen iv, 971d95e6d02SGuray Ozen a_offset, 972d95e6d02SGuray Ozen b_offset, 973d95e6d02SGuray Ozen predicate=primaryThread, 974d95e6d02SGuray Ozen ) 975d95e6d02SGuray Ozen da = nvgpu.WarpgroupGenerateDescriptorOp( 976c82f45f9SGuray Ozen a_wgmma_ty, a_tile_slice, a_tma_desc_op 977d95e6d02SGuray Ozen ) 978d95e6d02SGuray Ozen db = nvgpu.WarpgroupGenerateDescriptorOp( 979c82f45f9SGuray Ozen b_wgmma_ty, b_tile_slice, b_tma_desc_op 980d95e6d02SGuray Ozen ) 981d95e6d02SGuray Ozen 982d95e6d02SGuray Ozen # Step 4.3. MMA 983d95e6d02SGuray Ozen carry_acc = for_op.inner_iter_args[0] 984d95e6d02SGuray Ozen new_acc = nvgpu.WarpgroupMmaOp( 985d95e6d02SGuray Ozen acc.type, da, db, carry_acc, transposeB=True 986d95e6d02SGuray Ozen ) 987d95e6d02SGuray Ozen if num_stages == 1: 988d95e6d02SGuray Ozen nvvm.WgmmaWaitGroupSyncOp(0) 989d95e6d02SGuray Ozen 990d95e6d02SGuray Ozen # Step 4.4. 
Load TMA for next stage 991d95e6d02SGuray Ozen p1 = arith.cmpi( 992d95e6d02SGuray Ozen arith.CmpIPredicate.ult, 993d95e6d02SGuray Ozen arith.addi(iv, c(ns)), 994d95e6d02SGuray Ozen c(K // BLOCK_K), 995d95e6d02SGuray Ozen ) 996d95e6d02SGuray Ozen p = arith.andi(primaryThread, p1) 997d95e6d02SGuray Ozen nextStage = arith.addi(iv, c(ns)) 998d95e6d02SGuray Ozen nextSlot = arith.remui(nextStage, c(num_stages)) 999d95e6d02SGuray Ozen a_offset = arith.muli(nextSlot, c(lhs_tile_bytes)) 1000d95e6d02SGuray Ozen 1001d95e6d02SGuray Ozen debug_print( 1002d95e6d02SGuray Ozen "[MainLoop] mbarTMA[{}] arrive", 1003d95e6d02SGuray Ozen nextSlot, 1004d95e6d02SGuray Ozen predicate=p, 1005d95e6d02SGuray Ozen ) 1006d95e6d02SGuray Ozen nvgpu.mbarrier_arrive_expect_tx( 1007d95e6d02SGuray Ozen mbarTMA, c(txcount), nextSlot, predicate=p 1008d95e6d02SGuray Ozen ) 1009d95e6d02SGuray Ozen debug_print( 1010d95e6d02SGuray Ozen "[MainLoop] mbarTMA[{}] arrive [done]", 1011d95e6d02SGuray Ozen nextSlot, 1012d95e6d02SGuray Ozen predicate=p, 1013d95e6d02SGuray Ozen ) 1014d95e6d02SGuray Ozen 1015d95e6d02SGuray Ozen a_tma_slice = memref.view( 1016d95e6d02SGuray Ozen ir.MemRefType.get( 1017d95e6d02SGuray Ozen a_tma_shape, a_elem_ty, memory_space=smem_space 1018d95e6d02SGuray Ozen ), 1019d95e6d02SGuray Ozen dynamic_smem, 1020d95e6d02SGuray Ozen a_offset, 1021d95e6d02SGuray Ozen [], 1022d95e6d02SGuray Ozen ) 1023d95e6d02SGuray Ozen b_offset = arith.addi( 1024d95e6d02SGuray Ozen arith.muli(nextSlot, c(rhs_tile_bytes)), 1025d95e6d02SGuray Ozen c(lhs_tile_bytes * num_stages), 1026d95e6d02SGuray Ozen ) 1027d95e6d02SGuray Ozen b_tma_slice_1 = memref.view( 1028d95e6d02SGuray Ozen ir.MemRefType.get( 1029d95e6d02SGuray Ozen b_tma_shape, b_elem_ty, memory_space=smem_space 1030d95e6d02SGuray Ozen ), 1031d95e6d02SGuray Ozen dynamic_smem, 1032d95e6d02SGuray Ozen b_offset, 1033d95e6d02SGuray Ozen [], 1034d95e6d02SGuray Ozen ) 1035d95e6d02SGuray Ozen b_offset2 = arith.addi( 1036d95e6d02SGuray Ozen b_offset, 
1037d95e6d02SGuray Ozen c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)), 1038d95e6d02SGuray Ozen ) 1039d95e6d02SGuray Ozen b_tma_slice_2 = memref.view( 1040d95e6d02SGuray Ozen ir.MemRefType.get( 1041d95e6d02SGuray Ozen b_tma_shape, b_elem_ty, memory_space=smem_space 1042d95e6d02SGuray Ozen ), 1043d95e6d02SGuray Ozen dynamic_smem, 1044d95e6d02SGuray Ozen b_offset2, 1045d95e6d02SGuray Ozen [], 1046d95e6d02SGuray Ozen ) 1047d95e6d02SGuray Ozen 1048d95e6d02SGuray Ozen coord = arith.muli(c(64), nextStage) 1049d95e6d02SGuray Ozen debug_print( 1050d95e6d02SGuray Ozen "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})", 1051d95e6d02SGuray Ozen iv, 1052d95e6d02SGuray Ozen a_offset, 1053d95e6d02SGuray Ozen b_offset, 1054d95e6d02SGuray Ozen b_offset2, 1055d95e6d02SGuray Ozen coord, 1056d95e6d02SGuray Ozen dimX, 1057d95e6d02SGuray Ozen dimY, 1058d95e6d02SGuray Ozen coord, 1059d95e6d02SGuray Ozen predicate=p, 1060d95e6d02SGuray Ozen ) 1061d95e6d02SGuray Ozen nvgpu.TmaAsyncLoadOp( 1062d95e6d02SGuray Ozen a_tma_slice, 1063d95e6d02SGuray Ozen mbarTMA, 1064c82f45f9SGuray Ozen a_tma_desc_op, 1065d95e6d02SGuray Ozen coordinates=[coord, dimX], 1066d95e6d02SGuray Ozen mbarId=nextSlot, 1067d95e6d02SGuray Ozen predicate=p, 1068d95e6d02SGuray Ozen ) 1069d95e6d02SGuray Ozen nvgpu.TmaAsyncLoadOp( 1070d95e6d02SGuray Ozen b_tma_slice_1, 1071d95e6d02SGuray Ozen mbarTMA, 1072c82f45f9SGuray Ozen b_tma_desc_op, 1073d95e6d02SGuray Ozen coordinates=[dimY, coord], 1074d95e6d02SGuray Ozen mbarId=nextSlot, 1075d95e6d02SGuray Ozen predicate=p, 1076d95e6d02SGuray Ozen ) 1077d95e6d02SGuray Ozen dimY2 = arith.addi(dimY, c(64)) 1078d95e6d02SGuray Ozen nvgpu.TmaAsyncLoadOp( 1079d95e6d02SGuray Ozen b_tma_slice_2, 1080d95e6d02SGuray Ozen mbarTMA, 1081c82f45f9SGuray Ozen b_tma_desc_op, 1082d95e6d02SGuray Ozen coordinates=[dimY2, coord], 1083d95e6d02SGuray Ozen mbarId=nextSlot, 1084d95e6d02SGuray Ozen predicate=p, 1085d95e6d02SGuray Ozen ) 1086d95e6d02SGuray Ozen # 
Step 4.5. Change the phaseParity 1087d95e6d02SGuray Ozen p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1)) 1088d95e6d02SGuray Ozen phaseParity = arith.select( 1089d95e6d02SGuray Ozen p, 1090d95e6d02SGuray Ozen arith.xori(phaseParity, arith.constant(T.bool(), 1)), 1091d95e6d02SGuray Ozen phaseParity, 1092d95e6d02SGuray Ozen ) 1093d95e6d02SGuray Ozen 1094d95e6d02SGuray Ozen # Step 4.5. Yield 1095d95e6d02SGuray Ozen scf.yield_([new_acc, phaseParity]) 1096d95e6d02SGuray Ozen 1097d95e6d02SGuray Ozen # Step 5. Wait All WGMMA groups 1098d95e6d02SGuray Ozen nvvm.WgmmaWaitGroupSyncOp(0) 1099d95e6d02SGuray Ozen 1100d95e6d02SGuray Ozen # Step 6. Epilogue (registers --> shared memory) 1101d95e6d02SGuray Ozen acc_smem_ty = ir.MemRefType.get( 1102d95e6d02SGuray Ozen (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space 1103d95e6d02SGuray Ozen ) 1104d95e6d02SGuray Ozen acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), []) 1105d95e6d02SGuray Ozen debug_print("Storing", predicate=primaryThread) 1106d95e6d02SGuray Ozen nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem) 1107d95e6d02SGuray Ozen gpu.barrier() 1108d95e6d02SGuray Ozen 1109d95e6d02SGuray Ozen # GPU Step 7. 
Epilogue (shared memory --> global memory) 1110d95e6d02SGuray Ozen fd = ir.MemRefType.get( 1111d95e6d02SGuray Ozen [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space 1112d95e6d02SGuray Ozen ) 1113d95e6d02SGuray Ozen collapsed_smem = memref.view(fd, dynamic_smem, c(0), []) 1114d95e6d02SGuray Ozen rty = ir.MemRefType.get( 1115d95e6d02SGuray Ozen (BLOCK_M, BLOCK_N), 1116d95e6d02SGuray Ozen c_elem_ty, 1117d95e6d02SGuray Ozen ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"), 1118d95e6d02SGuray Ozen ) 1119d95e6d02SGuray Ozen c_device_per_block = memref.SubViewOp( 1120d95e6d02SGuray Ozen rty, 1121d95e6d02SGuray Ozen c_device, 1122d95e6d02SGuray Ozen [dimX, dimY], 1123d95e6d02SGuray Ozen [], 1124d95e6d02SGuray Ozen [], 1125d95e6d02SGuray Ozen [MLIR_DYNAMIC, MLIR_DYNAMIC], 1126d95e6d02SGuray Ozen [BLOCK_M, BLOCK_N], 1127d95e6d02SGuray Ozen [1, 1], 1128d95e6d02SGuray Ozen ) 1129d95e6d02SGuray Ozen vlen = 1 1130d95e6d02SGuray Ozen for_op = scf.ForOp( 1131d95e6d02SGuray Ozen tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE) 1132d95e6d02SGuray Ozen ) 1133d95e6d02SGuray Ozen with ir.InsertionPoint(for_op.body): 1134d95e6d02SGuray Ozen x = arith.divui(for_op.induction_variable, c(BLOCK_M)) 1135d95e6d02SGuray Ozen y = arith.remui(for_op.induction_variable, c(BLOCK_N)) 1136d95e6d02SGuray Ozen vdata = vector.load( 1137d95e6d02SGuray Ozen ir.VectorType.get((vlen,), c_elem_ty), 1138d95e6d02SGuray Ozen collapsed_smem, 1139d95e6d02SGuray Ozen [for_op.induction_variable], 1140d95e6d02SGuray Ozen ) 1141d95e6d02SGuray Ozen vector.store(vdata, c_device_per_block, [x, y]) 1142d95e6d02SGuray Ozen scf.yield_([]) 1143d95e6d02SGuray Ozen 1144d95e6d02SGuray Ozen gpu.terminator() 1145d95e6d02SGuray Ozen 1146d95e6d02SGuray Ozen # Step 4. 
Copy back to host 1147d95e6d02SGuray Ozen t8 = gpu.wait(token_ty, [launch_op]) 1148d95e6d02SGuray Ozen t9 = gpu.memcpy(token_ty, [t8], c_host, c_device) 1149d95e6d02SGuray Ozen gpu.dealloc(token_ty, [t8], a_device) 1150d95e6d02SGuray Ozen gpu.dealloc(token_ty, [t8], b_device) 1151d95e6d02SGuray Ozen gpu.wait(token_ty, [t9]) 1152d95e6d02SGuray Ozen gpu.dealloc(token_ty, [t8], c_device) 1153d95e6d02SGuray Ozen func.ReturnOp([]) 1154d95e6d02SGuray Ozen 1155d95e6d02SGuray Ozen fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() 1156d95e6d02SGuray Ozen module.operation.verify() 1157d95e6d02SGuray Ozen return module 1158