# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
#  Chapter 3 : GEMM 128x128x64 with Tensor Core
# ===----------------------------------------------------------------------===//
#
# This program demonstrates a GEMM operation with a 128x128x64 matrix multiplication.
#
# This chapter demonstrates how to:
#  1. Execute a TMA load for the two input matrices
#  2. Perform a Tensor Core GEMM 128x128x64 with a warpgroup
#  3. Store the fragmented accumulator registers to global memory with a warpgroup
#
# ===----------------------------------------------------------------------===//


from mlir import ir
from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
from tools.nvdsl import *
from mlir.extras import types as T
import numpy as np


def tma_load(
    mbar_group: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    p,
):
    """
    TMA loads the two input matrices from global memory to shared memory. It
    performs the following operations:

       - tma.load a_shared_memory[0] at coordinate [0, 0]  (loads 128x64)
       - tma.load b_shared_memory[0] at coordinate [0, 0]  (loads 64x64)
       - tma.load b_shared_memory[0] at coordinate [64, 0] (loads 64x64)

       mbarrier.arrive ta_count = 128x64xf16 + 64x128xf16
    """

    size_tma_a = get_type_size(a_tma.tma_memref)
    size_tma_b = get_type_size(b_tma.tma_memref)
    # Expected transaction count: all bytes of A plus both halves of B.
    ta_count = size_tma_a + (size_tma_b * 2)

    # Shared-memory layout: the A tile at offset 0, then the two halves of B.
    off_b = size_tma_a
    off_b2 = off_b + size_tma_b
    a_elem_ty = a_tma.tma_memref.element_type
    b_elem_ty = b_tma.tma_memref.element_type
    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty)
    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)

    mbar_group[0].arrive(ta_count, predicate=p)

    a_tma.load(a, mbar_group[0], coords=[0, 0], predicate=p)
    b_tma.load(b1, mbar_group[0], coords=[0, 0], predicate=p)
    b_tma.load(b2, mbar_group[0], coords=[64, 0], predicate=p)


@NVDSL.mlir_func
def gemm_128_128_64(a, b, d):
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
    t7 = gpu.wait(token_ty, [t6])

    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
    a_tma = TMA([128, 64], a.type, swizzle=sw)
    b_tma = TMA([64, 64], b.type, swizzle=sw)
    a_tma.create_descriptor(a_dev)
    b_tma.create_descriptor(b_dev)
    a_size = get_type_size(a.type)
    b_size = get_type_size(b.type)
    smem_size_in_bytes = a_size + b_size

    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=smem_size_in_bytes)
    def gemm_tma_kernel():
        tidx = gpu.thread_id(gpu.Dimension.x)

        mbar_group = Mbarriers(number_of_barriers=1)
        isThread0 = tidx == 0

        mbar_group[0].init(1, predicate=isThread0)
        a_tma.prefetch(predicate=isThread0)
        b_tma.prefetch(predicate=isThread0)

        a_smem = get_dynamic_shared_memory((M, K), T.f16())
        b_smem = get_dynamic_shared_memory((K, N), T.f16(), offset=a_size)

        # 1. TMA load for the two input matrices
        tma_load(mbar_group, a_tma, b_tma, isThread0)

        # 2. All threads wait for the TMA load to complete
        mbar_group[0].try_wait()

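        # The 128-thread block launched above is exactly one warpgroup (4 warps
        # of 32 threads). In the step below, A and B are wgmma descriptor
        # operands that refer to the tiles already resident in shared memory,
        # while D is an accumulator held in registers distributed ("fragmented")
        # across the warpgroup's threads; the overloaded `D += A @ B` expression
        # is what emits the warpgroup-level MMA.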
        # 3. Perform the Tensor Core GEMM 128x128x64 with the warpgroup
        A = WGMMAMatrix(WGMMAType.Descriptor, [M, K], desc=a_tma, smem=a_smem)
        B = WGMMAMatrix(WGMMAType.Descriptor, [K, N], desc=b_tma, smem=b_smem)
        D = WGMMAMatrix(WGMMAType.Accumulator, shape=[M, N], ty=T.f32())

        # Matrix multiply
        D += A @ B

        # 4. Store the fragmented accumulator registers to global memory with the warpgroup
        D.store_accumulator(d_dev)

    gemm_tma_kernel()

    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
    gpu.wait(None, [t8])


# Pass Python arguments to MLIR
M = 128
N = 128
K = 64
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)
gemm_128_128_64(a, b, d)

ref_d = a.astype(np.float16) @ b.astype(np.float16)
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
print("PASS")
# CHECK-NOT: Mismatched elements
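
# Illustrative sanity check (not exercised by FileCheck): the byte arithmetic
# that tma_load feeds to mbarrier.arrive. The A tile (128x64xf16) sits at
# offset 0 of dynamic shared memory and the two 64x64xf16 halves of B follow
# it, so the barrier's expected transaction count is the total byte size of
# both input tiles. The names below are local to this sketch.
bytes_per_f16 = np.dtype(np.float16).itemsize
size_tile_a = M * K * bytes_per_f16  # 128x64xf16 tile of A   -> 16384 bytes
size_tile_b = 64 * 64 * bytes_per_f16  # one 64x64xf16 tile of B ->  8192 bytes
offset_b1 = size_tile_a  # first half of B starts right after A
offset_b2 = offset_b1 + size_tile_b  # second half of B follows the first
expected_ta_count = size_tile_a + 2 * size_tile_b  # bytes tracked by mbar_group[0]
assert expected_ta_count == (M * K + K * N) * bytes_per_f16  # 128x64xf16 + 64x128xf16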