# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
#  Chapter 3 : GEMM 128x128x64 with Tensor Core
# ===----------------------------------------------------------------------===//
#
# This program demonstrates a GEMM operation with 128x128x64 matrix multiplication
#
# This chapter demonstrates:
# 1. Execute TMA Load for two input matrices
# 2. Performs Tensor Core GEMM 128x128x64 by warpgroup
# 3. Stores fragmented registers to global memory by warpgroup
#
# ===----------------------------------------------------------------------===//


from mlir import ir
from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
from tools.nvdsl import *
from mlir.extras import types as T
import numpy as np


def tma_load(
    mbar_group: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    p,
):
    """
    Issue the TMA loads that bring both input tiles into dynamic shared memory.

    The A tile is loaded with a single TMA load. The B tile is loaded as two
    tiles of `b_tma`'s shape, placed back to back right after A:

       - a_tma.load -> offset 0,                       coordinate [0, 0]
       - b_tma.load -> offset size(A),                 coordinate [0, 0]
       - b_tma.load -> offset size(A) + size(B tile),  coordinate [64, 0]

    With the descriptors created by the caller in this file (A tile 128x64,
    B tile 64x64) this is 128x64xf16 + 2 * 64x64xf16.

    Before the loads are issued, barrier 0 of `mbar_group` is told how much
    data to expect: ta_count = size(A tile) + 2 * size(B tile).

    Args:
        mbar_group: Barrier group; only `mbar_group[0]` is used.
        a_tma: TMA object describing the A tile.
        b_tma: TMA object describing one B tile.
        p: Predicate forwarded to the barrier arrive and to every TMA load
            (the caller passes "thread 0 only").
    """

    # NOTE(review): get_type_size comes from tools.nvdsl; it is presumably the
    # tile size in bytes (the caller sums the same quantity into
    # `smem_size_in_bytes`) -- confirm against the helper.
    size_tma_a = get_type_size(a_tma.tma_memref)
    size_tma_b = get_type_size(b_tma.tma_memref)
    ta_count = size_tma_a + (size_tma_b * 2)

    # Shared-memory layout: [ A | B tile 0 | B tile 1 ]
    off_b = size_tma_a
    off_b2 = off_b + size_tma_b
    a_elem_ty = a_tma.tma_memref.element_type
    b_elem_ty = b_tma.tma_memref.element_type
    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty)
    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)

    # Announce the expected transaction size first, then issue the loads that
    # signal the same barrier.
    mbar_group[0].arrive(ta_count, predicate=p)

    a_tma.load(a, mbar_group[0], coords=[0, 0], predicate=p)
    b_tma.load(b1, mbar_group[0], coords=[0, 0], predicate=p)
    b_tma.load(b2, mbar_group[0], coords=[64, 0], predicate=p)


@NVDSL.mlir_func
def gemm_128_128_64(a, b, d):
    """
    Build and run the host + device program computing `d += a @ b`.

    Host side: allocates device buffers for a, b and d, copies a and b to the
    device, creates the TMA descriptors, launches `gemm_tma_kernel` on a
    single block of 128 threads, and copies the result back into `d`.

    Args:
        a: Left-hand input matrix (the driver below passes M x K float16).
        b: Right-hand input matrix (the driver below passes K x N float16).
        d: Output matrix (the driver below passes M x N float32); it receives
            the contents of the device result buffer.
    """
    # Chain of async tokens: allocations first, then host-to-device copies of
    # the two inputs. The output buffer is allocated but not copied to.
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
    t7 = gpu.wait(token_ty, [t6])

    # TMA descriptors: A is fetched as one 128x64 tile, B as 64x64 tiles
    # (tma_load issues two of them), both with 128-byte swizzling.
    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
    a_tma = TMA([128, 64], a.type, swizzle=sw)
    b_tma = TMA([64, 64], b.type, swizzle=sw)
    a_tma.create_descriptor(a_dev)
    b_tma.create_descriptor(b_dev)

    # Dynamic shared memory must hold both input matrices.
    a_size = get_type_size(a.type)
    b_size = get_type_size(b.type)
    smem_size_in_bytes = a_size + b_size

    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=smem_size_in_bytes)
    def gemm_tma_kernel():
        """Device kernel: TMA load, wait, warpgroup GEMM, store result."""
        tidx = gpu.thread_id(gpu.Dimension.x)

        mbar_group = Mbarriers(number_of_barriers=1)
        isThread0 = tidx == 0

        # Only thread 0 initializes the barrier and prefetches the descriptors.
        mbar_group[0].init(1, predicate=isThread0)
        a_tma.prefetch(predicate=isThread0)
        b_tma.prefetch(predicate=isThread0)

        # Views of the dynamic shared memory: A at offset 0, B right after A.
        # M, N and K are module-level constants defined below; they are
        # looked up when the kernel body runs.
        a_smem = get_dynamic_shared_memory((M, K), T.f16())
        b_smem = get_dynamic_shared_memory((K, N), T.f16(), offset=a_size)

        # 1. TMA Load for two input matrices
        tma_load(mbar_group, a_tma, b_tma, isThread0)

        # 2. All threads wait TMA load completion
        mbar_group[0].try_wait()

        # 3. Performs Tensor Core GEMM 128x128x64 by warpgroup
        A = WGMMAMatrix(WGMMAType.Descriptor, [M, K], desc=a_tma, smem=a_smem)
        B = WGMMAMatrix(WGMMAType.Descriptor, [K, N], desc=b_tma, smem=b_smem)
        D = WGMMAMatrix(WGMMAType.Accumulator, shape=[M, N], ty=T.f32())

        # Matrix Multiply
        D += A @ B

        # 4. Stores fragmented registers to global memory by warpgroup
        D.store_accumulator(d_dev)

    gemm_tma_kernel()

    # Copy the result back to the host buffer and wait for it to finish.
    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
    gpu.wait(None, [t8])


# Python passes arguments to MLIR
M = 128
N = 128
K = 64
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)
gemm_128_128_64(a, b, d)

# The kernel accumulates into an f32 matrix, so build the reference in float32
# as well instead of rounding it to float16 (the inputs are already float16,
# so casting them to float16 again was a no-op).
ref_d = a.astype(np.float32) @ b.astype(np.float32)
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
print("PASS")
# CHECK-NOT: Mismatched elements
# CHECK: PASS