import numpy as np
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import gpu
from mlir.dialects import memref
from mlir.dialects import nvgpu
from mlir.dialects import nvvm
from mlir.dialects import llvm
from mlir.dialects import builtin
from mlir.dialects import scf
from mlir.dialects import vector
from mlir.extras import types as T

TMA_LAST_DIM_F16 = 64  # 128B float16
WARP_SIZE = 32
WARP_GROUP_SIZE = WARP_SIZE * 4

PRODUCER_REGISTER_SIZE = 40
CONSUMER_REGISTER_SIZE = 232

PRODUCER_PRIMARY_THREAD = 128
CONSUMER_PRIMARY_THREAD = 0

# C++ uses this value to understand whether it's dynamic or not.
MLIR_DYNAMIC = -9223372036854775808

DEBUG = False


class TmaDescriptorBuilder:
    """A class that builds a TMA descriptor."""

    def __init__(self, swizzle, l2promo, oob, interleave, tma_box_shape, memref_ty):
        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
        self.tma_box_shape = tma_box_shape
        self.memref_ty = memref_ty  # MemRefType

    @property
    def tensormap_descriptor_ty(self):
        """Returns a tensormap descriptor type."""
        tensorMemrefType = ir.MemRefType.get(
            self.tma_box_shape,
            self.memref_ty.element_type,
            memory_space=ir.Attribute.parse("3"),
        )
        return nvgpu.TensorMapDescriptorType.get(
            tensorMemrefType,
            self.swizzle,
            self.l2promo,
            self.oob,
            self.interleave,
        )

    def tma_descriptor_op(self, device_ptr):
        """Returns a tensormap descriptor op."""
        tma_descriptor_ty = self.tensormap_descriptor_ty
        device_unranked_memref = memref.CastOp(
            ir.UnrankedMemRefType.get(
                self.memref_ty.element_type, self.memref_ty.memory_space
            ),
            device_ptr,
        )
        tma_descriptor_op = nvgpu.TmaCreateDescriptorOp(
            tma_descriptor_ty, device_unranked_memref, map(c, self.tma_box_shape)
        )
        return tma_descriptor_op.result


def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
    if not DEBUG and not forcePrint:
        return
    type_formats = []
    for arg in args:
        ty_format = None
        if ir.IndexType.isinstance(arg.type):
            ty_format = "%llu"
        if ir.IntegerType.isinstance(arg.type):
            width = ir.IntegerType(arg.type).width
            if width == 64:
                ty_format = "%llu"
            elif width == 32:
                ty_format = "%d"
            elif width == 1:
                ty_format = "%i"
        if ir.F32Type.isinstance(arg.type):
            ty_format = "%f"
        if ty_format is None:
            raise NotImplementedError(arg.type)
        type_formats.append(ty_format)
    if threadNumber != -1:
        tidx = gpu.thread_id(gpu.Dimension.x)
        predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
        scf.yield_([])
    if_op = scf.IfOp(predicate)
    with ir.InsertionPoint(if_op.then_block):
        gpu.printf(fmt.format(*type_formats) + "\n", args)
        scf.yield_([])


def get_type_size(ty):
    if ir.FloatType.isinstance(ty):
        return ir.FloatType(ty).width // 8
    if ir.IntegerType.isinstance(ty):
        return ir.IntegerType(ty).width // 8
    raise NotImplementedError(ty)


def get_mlir_ty(dtype):
    if dtype == np.float16:
        return T.f16()
    if dtype == np.float32:
        return T.f32()
    if dtype == np.float64:
        return T.f64()
    if dtype == np.int32:
        return T.i32()
    if dtype == np.int64:
        return T.i64()
    raise NotImplementedError(dtype)

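
# The two generators below size dynamic shared memory from the tile shapes. The
# following is a minimal sketch of that arithmetic, assuming the default
# configuration used in this file (BLOCK_M = BLOCK_N = 128, BLOCK_K = 64, f16
# inputs, f32 accumulator, num_stages = 3). The helper name is illustrative
# only and is not called by the generators themselves.
def _example_smem_size_bytes(block_m=128, block_n=128, block_k=64, num_stages=3):
    """Worked example of the shared-memory budget computed inside the kernels."""
    elem_in = 2  # bytes per f16 input element
    elem_out = 4  # bytes per f32 output element
    lhs_tile_bytes = block_m * block_k * elem_in  # 16 KiB A tile per stage
    rhs_tile_bytes = block_n * block_k * elem_in  # 16 KiB B tile per stage
    smem_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages  # 96 KiB staged inputs
    smem_output = block_m * block_n * elem_out  # 64 KiB C tile for the epilogue
    return max(smem_input, smem_output)  # 98304 bytes for the defaults
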
def c(value, ty=None):
    ty = T.index() if ty is None else ty
    return arith.constant(ty, value)


def make_kernel_name(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=128,
    num_stages=3,
    use_warp_specialization=False,
):
    kernelName = "warpspecialized" if use_warp_specialization else "multistage"
    return (
        kernelName
        + "_"
        + str(M)
        + "x"
        + str(N)
        + "x"
        + str(K)
        + "_"
        + str(BLOCK_M)
        + "x"
        + str(BLOCK_N)
        + "x"
        + str(BLOCK_K)
        + "_"
        + str(num_stages)
    )


def generate_matmul_ws(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    num_stages=3,
):
    # Limitations for now
    assert input_type == np.float16
    assert output_type == np.float32
    assert BLOCK_M == 128
    assert BLOCK_N == 128
    assert BLOCK_K == 64
    assert M % BLOCK_M == 0
    assert N % BLOCK_N == 0
    assert K % BLOCK_K == 0

    module = ir.Module.create()
    token_ty = gpu.AsyncTokenType.get()
    a_elem_ty = get_mlir_ty(input_type)
    b_elem_ty = get_mlir_ty(input_type)
    c_elem_ty = get_mlir_ty(output_type)
    a_ty = ir.MemRefType.get([M, K], a_elem_ty)
    b_ty = ir.MemRefType.get((K, N), b_elem_ty)
    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
    b_tile_shape = (BLOCK_K, BLOCK_N)
    txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(a_elem_ty)) + (
        a_tile_shape[0] * a_tile_shape[1] * get_type_size(b_elem_ty)
    )
    smem_space_str = "#gpu.address_space<workgroup>"
    smem_space = ir.Attribute.parse(smem_space_str)
    mbar_ty = ir.Type.parse(
        "!nvgpu.mbarrier.group<memorySpace = "
        + str(smem_space)
        + ", num_barriers = "
        + str(num_stages)
        + ">"
    )
    acc_ty = ir.Type.parse(
        "!nvgpu.warpgroup.accumulator<fragmented=vector<"
        + str(BLOCK_M)
        + "x"
        + str(BLOCK_N)
        + "x"
        + str(c_elem_ty)
        + ">>"
    )
    a_wgmma_ty = ir.Type.parse(
        "!nvgpu.warpgroup.descriptor<tensor=memref<"
        + str(BLOCK_M)
        + "x"
        + str(BLOCK_K)
        + "x"
        + str(a_elem_ty)
        + ", "
        + smem_space_str
        + ">>"
    )
    b_wgmma_ty = ir.Type.parse(
        "!nvgpu.warpgroup.descriptor<tensor=memref<"
        + str(BLOCK_K)
        + "x"
        + str(BLOCK_N)
        + "x"
        + str(a_elem_ty)
        + ", "
        + smem_space_str
        + ">>"
    )
    kernelName = make_kernel_name(
        input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_stages, True
    )
    with ir.InsertionPoint(module.body):
        fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
        with ir.InsertionPoint(fop.add_entry_block()):
            a_host = fop.arguments[0]
            b_host = fop.arguments[1]
            c_host = fop.arguments[2]
            lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
            rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
            smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
            smem_size = max(smem_size_input, smem_size_output)

            # Step 1. Allocate device memory and memcpy
            t1 = gpu.wait(token_ty, [])
            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
            t7 = gpu.wait(token_ty, [t6])

            # Step 2. Create TMA Descriptors
            a_tma_desc = TmaDescriptorBuilder(
                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
                nvgpu.TensorMapOOBKind.OOB_ZERO,
                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
                a_tma_shape,
                a_ty,
            )

            b_tma_desc = TmaDescriptorBuilder(
                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
                nvgpu.TensorMapOOBKind.OOB_ZERO,
                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
                b_tma_shape,
                b_ty,
            )

            a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
            b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)

            # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
            cta_m = M // BLOCK_M
            cta_n = N // BLOCK_N
            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
            grid = (cta_m, cta_n, 1)
            block = (WARP_GROUP_SIZE * 2, 1, 1)
            launch_op = gpu.LaunchOp(
                token_ty,
                [t7],
                *map(c, grid),
                *map(c, block),
                dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
            )
            launch_op.body.blocks.append(*([T.index()] * 12))
            with ir.InsertionPoint(launch_op.body.blocks[0]):
                # GPU Step 0. This is needed for vectorized ld/st
                memref.assume_alignment(c_device, 16)
                dynamic_smem = gpu.dynamic_shared_memory(
                    ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
                )
                ticks = c(10000000)

                # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups, etc.
                tidx = gpu.thread_id(gpu.Dimension.x)
                wgPrimaryThread = arith.cmpi(
                    arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0)
                )
                warp_id = arith.divui(tidx, c(32))
                warpgroup_id = arith.divui(warp_id, c(4))
                is_producer = arith.cmpi(
                    arith.CmpIPredicate.eq,
                    warpgroup_id,
                    c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0),
                )
                is_consumer = arith.cmpi(
                    arith.CmpIPredicate.eq,
                    warpgroup_id,
                    c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1),
                )
                producerPrimaryThread = arith.cmpi(
                    arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD)
                )
                consumerPrimaryThread = arith.cmpi(
                    arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD)
                )
                bidx = gpu.block_id(gpu.Dimension.x)
                bidy = gpu.block_id(gpu.Dimension.y)
                dimX = arith.muli(bidx, c(BLOCK_M))
                dimY = arith.muli(bidy, c(BLOCK_N))

                # GPU Step 2. Initialize mbarrier groups
                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                mbarDONE = nvgpu.mbarrier_create(mbar_ty)
                for i in range(num_stages):
                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
                    nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
                gpu.barrier()

                # GPU Step 3. Prefetch TMA descriptors
                nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=wgPrimaryThread)
                nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=wgPrimaryThread)

                ns = num_stages if num_stages == 1 else num_stages - 1
                # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
                with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
                    # Step 5.1. Reduce register size
                    nvvm.setmaxregister(
                        PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease
                    )

                    # Step 5.2. TMA Main Loop
                    for_op = scf.ForOp(
                        c(0), c(K // BLOCK_K), c(1), [arith.constant(T.bool(), 1)]
                    )
                    with ir.InsertionPoint(for_op.body):
                        phaseParity = for_op.inner_iter_args[0]
                        iv = for_op.induction_variable
                        stage = arith.remui(iv, c(num_stages))

                        # Step 5.2.1. Wait mbarDONE
                        debug_print(
                            "[prod] iv={} | mbarDONE[{}] try_wait phase={}",
                            iv,
                            stage,
                            phaseParity,
                            predicate=producerPrimaryThread,
                        )
                        nvgpu.MBarrierTryWaitParityOp(
                            mbarDONE, phaseParity, ticks, mbarId=stage
                        )
                        debug_print(
                            "[prod] iv={} | mbarDONE[{}] try_wait phase={} [done]",
                            iv,
                            stage,
                            phaseParity,
                            predicate=producerPrimaryThread,
                        )
                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                        phaseParity = arith.select(
                            p,
                            arith.xori(phaseParity, arith.constant(T.bool(), 1)),
                            phaseParity,
                        )

                        # Step 5.2.2. Load TMA
                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
                        a_tma_slice = memref.view(
                            ir.MemRefType.get(
                                a_tma_shape, a_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            a_offset,
                            [],
                        )
                        b_offset = arith.addi(
                            arith.muli(stage, c(rhs_tile_bytes)),
                            c(lhs_tile_bytes * num_stages),
                        )
                        b_tma_slice_1 = memref.view(
                            ir.MemRefType.get(
                                b_tma_shape, b_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            b_offset,
                            [],
                        )
                        b_offset2 = arith.addi(
                            b_offset,
                            c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                        )
                        b_tma_slice_2 = memref.view(
                            ir.MemRefType.get(
                                b_tma_shape, b_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            b_offset2,
                            [],
                        )
                        debug_print(
                            "[prod] a_offset={} b_offset={} b_offset2={}",
                            a_offset,
                            b_offset,
                            b_offset2,
                            predicate=producerPrimaryThread,
                        )
                        coord = arith.muli(c(64), iv)
                        nvgpu.TmaAsyncLoadOp(
                            a_tma_slice,
                            mbarTMA,
                            a_tma_desc_op,
                            coordinates=[coord, dimX],
                            mbarId=stage,
                            predicate=producerPrimaryThread,
                        )
                        nvgpu.TmaAsyncLoadOp(
                            b_tma_slice_1,
                            mbarTMA,
                            b_tma_desc_op,
                            coordinates=[dimY, coord],
                            mbarId=stage,
                            predicate=producerPrimaryThread,
                        )
                        dimY2 = arith.addi(dimY, c(64))
                        nvgpu.TmaAsyncLoadOp(
                            b_tma_slice_2,
                            mbarTMA,
                            b_tma_desc_op,
                            coordinates=[dimY2, coord],
                            mbarId=stage,
                            predicate=producerPrimaryThread,
                        )

                        # Step 5.2.3. Arrive mbarTMA
                        debug_print(
                            "[prod] iv={} | mbarTMA[{}] arrive",
                            iv,
                            stage,
                            predicate=producerPrimaryThread,
                        )
                        nvgpu.mbarrier_arrive_expect_tx(
                            mbarTMA, c(txcount), stage, predicate=producerPrimaryThread
                        )
                        debug_print(
                            "[prod] iv={} | mbarTMA[{}] arrive [done]",
                            iv,
                            stage,
                            predicate=producerPrimaryThread,
                        )
                        scf.yield_([phaseParity])
                    scf.yield_([])

                # GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
                if_op = scf.IfOp(is_consumer)
                with ir.InsertionPoint(if_op.then_block):
                    # Step 6.1. Increase register size
                    nvvm.setmaxregister(
                        CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase
                    )

                    # Step 6.2. Initialize MMA registers
                    acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)

                    # Step 6.3. MMA Main Loop
                    for_op = scf.ForOp(
                        c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
                    )
                    with ir.InsertionPoint(for_op.body):
                        # Step 6.3.1. Wait mbarTMA
                        phaseParity = for_op.inner_iter_args[1]
                        iv = for_op.induction_variable
                        stage = arith.remui(iv, c(num_stages))
                        debug_print(
                            "[cons] iv={} | mbarTMA[{}] try_wait phase={}",
                            iv,
                            stage,
                            phaseParity,
                            predicate=consumerPrimaryThread,
                        )
                        nvgpu.MBarrierTryWaitParityOp(
                            mbarTMA, phaseParity, ticks, mbarId=stage
                        )
                        debug_print(
                            "[cons] iv={} | mbarTMA[{}] try_wait phase={} [done]",
                            iv,
                            stage,
                            phaseParity,
                            predicate=consumerPrimaryThread,
                        )

                        # Step 6.3.2. Create WGMMA Descriptors
                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
                        a_tile_slice = memref.view(
                            ir.MemRefType.get(
                                a_tile_shape, a_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            a_offset,
                            [],
                        )
                        b_offset = arith.addi(
                            arith.muli(stage, c(rhs_tile_bytes)),
                            c(lhs_tile_bytes * num_stages),
                        )
                        b_tile_slice = memref.view(
                            ir.MemRefType.get(
                                b_tile_shape, b_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            b_offset,
                            [],
                        )
                        debug_print(
                            "[cons] a_offset={} b_offset={}",
                            a_offset,
                            b_offset,
                            predicate=consumerPrimaryThread,
                        )
                        da = nvgpu.WarpgroupGenerateDescriptorOp(
                            a_wgmma_ty, a_tile_slice, a_tma_desc_op
                        )
                        db = nvgpu.WarpgroupGenerateDescriptorOp(
                            b_wgmma_ty, b_tile_slice, b_tma_desc_op
                        )

                        # Step 6.3.3. MMA
                        carry_acc = for_op.inner_iter_args[0]
                        new_acc = nvgpu.WarpgroupMmaOp(
                            acc.type, da, db, carry_acc, transposeB=True
                        )

                        # Step 6.3.4. Arrive mbarDONE
                        if num_stages == 1:
                            p_arrive = consumerPrimaryThread
                        else:
                            p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
                            p_arrive = arith.andi(consumerPrimaryThread, p1)
                        with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
                            p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
                            barId = arith.select(
                                p, c(num_stages - 1), arith.subi(stage, c(1))
                            )
                            debug_print(
                                "[cons] iv={} | mbarDONE[{}] arrive ",
                                iv,
                                barId,
                                predicate=consumerPrimaryThread,
                            )
                            nvgpu.mbarrier_arrive(mbarDONE, barId)
                            debug_print(
                                "[cons] iv={} | mbarDONE[{}] arrive [done]",
                                iv,
                                barId,
                                predicate=consumerPrimaryThread,
                            )
                            scf.yield_([])

                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                        phaseParity = arith.select(
                            p,
                            arith.xori(phaseParity, arith.constant(T.bool(), 1)),
                            phaseParity,
                        )

                        # Step 6.3.5. Yield
                        scf.yield_([new_acc, phaseParity])

                    with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
                        barId = c((K // BLOCK_K) % num_stages)
                        nvgpu.mbarrier_arrive(mbarDONE, barId)
                        scf.yield_([])

                    # Step 6.4. Epilogue (registers --> shared memory)
                    acc_smem_ty = ir.MemRefType.get(
                        (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
                    )
                    acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
                    debug_print("[cons] | Storing", predicate=consumerPrimaryThread)
                    nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                    scf.yield_([])
                gpu.barrier()

                # GPU Step 9. Epilogue (shared memory --> global memory)
                fd = ir.MemRefType.get(
                    [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
                )
                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
                rty = ir.MemRefType.get(
                    (BLOCK_M, BLOCK_N),
                    c_elem_ty,
                    ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
                )
                c_device_per_block = memref.SubViewOp(
                    rty,
                    c_device,
                    [dimX, dimY],
                    [],
                    [],
                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
                    [BLOCK_M, BLOCK_N],
                    [1, 1],
                )
                vlen = 1
                for_op = scf.ForOp(
                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2)
                )
                with ir.InsertionPoint(for_op.body):
                    x = arith.divui(for_op.induction_variable, c(BLOCK_M))
                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
                    vdata = vector.load(
                        ir.VectorType.get((vlen,), c_elem_ty),
                        collapsed_smem,
                        [for_op.induction_variable],
                    )
                    vector.store(vdata, c_device_per_block, [x, y])
                    scf.yield_([])

                gpu.terminator()

            # Step 4. Copy back to host
            t8 = gpu.wait(token_ty, [launch_op])
            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
            gpu.dealloc(token_ty, [t8], a_device)
            gpu.dealloc(token_ty, [t8], b_device)
            gpu.wait(token_ty, [t9])
            gpu.dealloc(token_ty, [t8], c_device)
            func.ReturnOp([])

        fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    module.operation.verify()
    return module


def generate_matmul_multistage(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    num_stages=3,
):
    # Limitations for now
    assert input_type == np.float16
    assert output_type == np.float32
    assert BLOCK_M == 128
    assert BLOCK_N == 128
    assert BLOCK_K == 64
    assert M % BLOCK_M == 0
    assert N % BLOCK_N == 0
    assert K % BLOCK_K == 0

    module = ir.Module.create()
    token_ty = gpu.AsyncTokenType.get()
    a_elem_ty = get_mlir_ty(input_type)
    b_elem_ty = get_mlir_ty(input_type)
    c_elem_ty = get_mlir_ty(output_type)
    a_ty = ir.MemRefType.get([M, K], a_elem_ty)
    b_ty = ir.MemRefType.get((K, N), b_elem_ty)
    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
    b_tile_shape = (BLOCK_K, BLOCK_N)
    txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(a_elem_ty)) + (
        a_tile_shape[0] * a_tile_shape[1] * get_type_size(b_elem_ty)
    )
    smem_space_str = "#gpu.address_space<workgroup>"
    smem_space = ir.Attribute.parse(smem_space_str)
    mbar_ty = ir.Type.parse(
        "!nvgpu.mbarrier.group<memorySpace = "
        + str(smem_space)
        + ", num_barriers = "
        + str(num_stages)
        + ">"
    )
    acc_ty = ir.Type.parse(
        "!nvgpu.warpgroup.accumulator<fragmented=vector<"
        + str(BLOCK_M)
        + "x"
        + str(BLOCK_N)
        + "x"
        + str(c_elem_ty)
        + ">>"
    )
    a_wgmma_ty = ir.Type.parse(
        "!nvgpu.warpgroup.descriptor<tensor=memref<"
        + str(BLOCK_M)
        + "x"
        + str(BLOCK_K)
        + "x"
        + str(a_elem_ty)
        + ", "
        + smem_space_str
        + ">>"
    )
    b_wgmma_ty = ir.Type.parse(
        "!nvgpu.warpgroup.descriptor<tensor=memref<"
        + str(BLOCK_K)
        + "x"
        + str(BLOCK_N)
        + "x"
        + str(a_elem_ty)
        + ", "
        + smem_space_str
        + ">>"
    )

    with ir.InsertionPoint(module.body):
        kernelName = make_kernel_name(
            input_type,
            output_type,
            M,
            N,
            K,
            BLOCK_M,
            BLOCK_N,
            BLOCK_K,
            num_stages,
            False,
        )
        fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
        with ir.InsertionPoint(fop.add_entry_block()):
            a_host = fop.arguments[0]
            b_host = fop.arguments[1]
            c_host = fop.arguments[2]
            lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
            rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
            smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
            smem_size = max(smem_size_input, smem_size_output)

            # Step 1. Allocate device memory and memcpy
            t1 = gpu.wait(token_ty, [])
            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
            t7 = gpu.wait(token_ty, [t6])

            # Step 2. Create TMA Descriptors
            a_tma_desc = TmaDescriptorBuilder(
                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
                nvgpu.TensorMapOOBKind.OOB_ZERO,
                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
                a_tma_shape,
                a_ty,
            )

            b_tma_desc = TmaDescriptorBuilder(
                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
                nvgpu.TensorMapOOBKind.OOB_ZERO,
                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
                b_tma_shape,
                b_ty,
            )

            a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
            b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)

            # Step 3. Launch Kernel with 1 Warpgroup
            cta_m = M // BLOCK_M
            cta_n = N // BLOCK_N
            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
            grid = (cta_m, cta_n, 1)
            block = (WARP_GROUP_SIZE, 1, 1)
            launch_op = gpu.LaunchOp(
                token_ty,
                [t7],
                *map(c, grid),
                *map(c, block),
                dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
            )
            launch_op.body.blocks.append(*([T.index()] * 12))
            with ir.InsertionPoint(launch_op.body.blocks[0]):
                # GPU Step 0. Bootstrapping
                memref.assume_alignment(c_device, 16)
                dynamic_smem = gpu.dynamic_shared_memory(
                    ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
                )
                ticks = c(10000000)
                tidx = gpu.thread_id(gpu.Dimension.x)
                primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
                warpId = arith.divui(tidx, c(32))
                bidx = gpu.block_id(gpu.Dimension.x)
                bidy = gpu.block_id(gpu.Dimension.y)
                dimX = arith.muli(bidx, c(BLOCK_M))
                dimY = arith.muli(bidy, c(BLOCK_N))

                # GPU Step 1. Initialize mbarrier groups
                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                for i in range(num_stages):
                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
                gpu.barrier()

                # GPU Step 2. Prefetch TMA descriptors
                nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=primaryThread)
                nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=primaryThread)

                # GPU Step 3. Prologue (global memory --> shared memory)
                ns = num_stages if num_stages == 1 else num_stages - 1
                for_op = scf.ForOp(c(0), c(ns), c(1))
                with ir.InsertionPoint(for_op.body):
                    iv = for_op.induction_variable

                    # Step 3.1. Calculate offsets
                    a_offset = arith.muli(iv, c(lhs_tile_bytes))
                    a_tma_slice = memref.view(
                        ir.MemRefType.get(
                            a_tma_shape, a_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        a_offset,
                        [],
                    )
                    b_offset = arith.addi(
                        arith.muli(iv, c(rhs_tile_bytes)),
                        c(lhs_tile_bytes * num_stages),
                    )
                    b_tma_slice_1 = memref.view(
                        ir.MemRefType.get(
                            b_tma_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset,
                        [],
                    )
                    b_offset2 = arith.addi(
                        b_offset,
                        c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                    )
                    b_tma_slice_2 = memref.view(
                        ir.MemRefType.get(
                            b_tma_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset2,
                        [],
                    )

                    # Step 3.2. TMA Load
                    coord = arith.muli(c(64), iv)
                    dimY2 = arith.addi(dimY, c(64))
                    debug_print(
                        "[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
                        a_offset,
                        b_offset,
                        b_offset2,
                        coord,
                        dimX,
                        dimY,
                        coord,
                        predicate=primaryThread,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        a_tma_slice,
                        mbarTMA,
                        a_tma_desc_op,
                        coordinates=[coord, dimX],
                        mbarId=iv,
                        predicate=primaryThread,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        b_tma_slice_1,
                        mbarTMA,
                        b_tma_desc_op,
                        coordinates=[dimY, coord],
                        mbarId=iv,
                        predicate=primaryThread,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        b_tma_slice_2,
                        mbarTMA,
                        b_tma_desc_op,
                        coordinates=[dimY2, coord],
                        mbarId=iv,
                        predicate=primaryThread,
                    )

                    # Step 3.3. mbarTMA arrive
                    debug_print(
                        "[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread
                    )
                    nvgpu.mbarrier_arrive_expect_tx(
                        mbarTMA, c(txcount), iv, predicate=primaryThread
                    )
                    debug_print(
                        "[Prologue] mbarTMA[{}] arrive [done]",
                        iv,
                        predicate=primaryThread,
                    )
                    scf.yield_([])

                # GPU Step 4. Main Loop
                acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
                for_op = scf.ForOp(
                    c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
                )
                with ir.InsertionPoint(for_op.body):
                    # Step 4.1. Wait mbarTMA
                    phaseParity = for_op.inner_iter_args[1]
                    iv = for_op.induction_variable
                    stage = arith.remui(iv, c(num_stages))
                    debug_print(
                        "[MainLoop] mbarTMA[{}] try_wait phase={}",
                        stage,
                        phaseParity,
                        predicate=primaryThread,
                    )
                    nvgpu.MBarrierTryWaitParityOp(
                        mbarTMA, phaseParity, ticks, mbarId=stage
                    )
                    debug_print(
                        "[MainLoop] mbarTMA[{}] try_wait phase={} [done]",
                        stage,
                        phaseParity,
                        predicate=primaryThread,
                    )

                    # Step 4.2. Create WGMMA Descriptors
                    a_offset = arith.muli(stage, c(lhs_tile_bytes))
                    a_tile_slice = memref.view(
                        ir.MemRefType.get(
                            a_tile_shape, a_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        a_offset,
                        [],
                    )
                    b_offset = arith.addi(
                        arith.muli(stage, c(rhs_tile_bytes)),
                        c(lhs_tile_bytes * num_stages),
                    )
                    b_tile_slice = memref.view(
                        ir.MemRefType.get(
                            b_tile_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset,
                        [],
                    )
                    debug_print(
                        "[MainLoop] iv={} MMA a_offset={} b_offset={}",
                        iv,
                        a_offset,
                        b_offset,
                        predicate=primaryThread,
                    )
                    da = nvgpu.WarpgroupGenerateDescriptorOp(
                        a_wgmma_ty, a_tile_slice, a_tma_desc_op
                    )
                    db = nvgpu.WarpgroupGenerateDescriptorOp(
                        b_wgmma_ty, b_tile_slice, b_tma_desc_op
                    )

                    # Step 4.3. MMA
                    carry_acc = for_op.inner_iter_args[0]
                    new_acc = nvgpu.WarpgroupMmaOp(
                        acc.type, da, db, carry_acc, transposeB=True
                    )
                    if num_stages == 1:
                        nvvm.WgmmaWaitGroupSyncOp(0)

                    # Step 4.4. Load TMA for next stage
                    p1 = arith.cmpi(
                        arith.CmpIPredicate.ult,
                        arith.addi(iv, c(ns)),
                        c(K // BLOCK_K),
                    )
                    p = arith.andi(primaryThread, p1)
                    nextStage = arith.addi(iv, c(ns))
                    nextSlot = arith.remui(nextStage, c(num_stages))
                    a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))

                    debug_print(
                        "[MainLoop] mbarTMA[{}] arrive",
                        nextSlot,
                        predicate=p,
                    )
                    nvgpu.mbarrier_arrive_expect_tx(
                        mbarTMA, c(txcount), nextSlot, predicate=p
                    )
                    debug_print(
                        "[MainLoop] mbarTMA[{}] arrive [done]",
                        nextSlot,
                        predicate=p,
                    )

                    a_tma_slice = memref.view(
                        ir.MemRefType.get(
                            a_tma_shape, a_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        a_offset,
                        [],
                    )
                    b_offset = arith.addi(
                        arith.muli(nextSlot, c(rhs_tile_bytes)),
                        c(lhs_tile_bytes * num_stages),
                    )
                    b_tma_slice_1 = memref.view(
                        ir.MemRefType.get(
                            b_tma_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset,
                        [],
                    )
                    b_offset2 = arith.addi(
                        b_offset,
                        c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                    )
                    b_tma_slice_2 = memref.view(
                        ir.MemRefType.get(
                            b_tma_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset2,
                        [],
                    )

                    coord = arith.muli(c(64), nextStage)
                    debug_print(
                        "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
                        iv,
                        a_offset,
                        b_offset,
                        b_offset2,
                        coord,
                        dimX,
                        dimY,
                        coord,
                        predicate=p,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        a_tma_slice,
                        mbarTMA,
                        a_tma_desc_op,
                        coordinates=[coord, dimX],
                        mbarId=nextSlot,
                        predicate=p,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        b_tma_slice_1,
                        mbarTMA,
                        b_tma_desc_op,
                        coordinates=[dimY, coord],
                        mbarId=nextSlot,
                        predicate=p,
                    )
                    dimY2 = arith.addi(dimY, c(64))
                    nvgpu.TmaAsyncLoadOp(
                        b_tma_slice_2,
                        mbarTMA,
                        b_tma_desc_op,
                        coordinates=[dimY2, coord],
                        mbarId=nextSlot,
                        predicate=p,
                    )
                    # Step 4.5. Change the phaseParity
                    p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                    phaseParity = arith.select(
                        p,
                        arith.xori(phaseParity, arith.constant(T.bool(), 1)),
                        phaseParity,
                    )

                    # Step 4.6. Yield
                    scf.yield_([new_acc, phaseParity])

                # Step 5. Wait for all WGMMA groups
                nvvm.WgmmaWaitGroupSyncOp(0)

                # Step 6. Epilogue (registers --> shared memory)
                acc_smem_ty = ir.MemRefType.get(
                    (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
                )
                acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
                debug_print("Storing", predicate=primaryThread)
                nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                gpu.barrier()

                # GPU Step 7. Epilogue (shared memory --> global memory)
                fd = ir.MemRefType.get(
                    [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
                )
                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
                rty = ir.MemRefType.get(
                    (BLOCK_M, BLOCK_N),
                    c_elem_ty,
                    ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
                )
                c_device_per_block = memref.SubViewOp(
                    rty,
                    c_device,
                    [dimX, dimY],
                    [],
                    [],
                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
                    [BLOCK_M, BLOCK_N],
                    [1, 1],
                )
                vlen = 1
                for_op = scf.ForOp(
                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE)
                )
                with ir.InsertionPoint(for_op.body):
                    x = arith.divui(for_op.induction_variable, c(BLOCK_M))
                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
                    vdata = vector.load(
                        ir.VectorType.get((vlen,), c_elem_ty),
                        collapsed_smem,
                        [for_op.induction_variable],
                    )
                    vector.store(vdata, c_device_per_block, [x, y])
                    scf.yield_([])

                gpu.terminator()

            # Step 4. Copy back to host
            t8 = gpu.wait(token_ty, [launch_op])
            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
            gpu.dealloc(token_ty, [t8], a_device)
            gpu.dealloc(token_ty, [t8], b_device)
            gpu.wait(token_ty, [t9])
            gpu.dealloc(token_ty, [t8], c_device)
            func.ReturnOp([])

        fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    module.operation.verify()
    return module
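

# Minimal usage sketch (an assumption about the surrounding test harness, not
# part of this builder module): construct one of the modules above inside an
# MLIR context and print the generated IR. Actually running the kernel further
# requires lowering through the NVGPU/NVVM pipelines and an sm_90a-capable GPU.
if __name__ == "__main__":
    with ir.Context(), ir.Location.unknown():
        # Small shapes that satisfy the asserts: M, N multiples of 128, K a
        # multiple of 64.
        example_module = generate_matmul_ws(M=128, N=128, K=64)
        print(example_module)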