# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN: %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
#  Chapter 2 : 2D Saxpy with TMA
# ===----------------------------------------------------------------------===//
#
# This program demonstrates 2D SAXPY. It is the same as Chapter 1, but it
# loads data using TMA (Tensor Memory Accelerator).
#
# This chapter demonstrates:
#  1. Computing 2D SAXPY in the same way as Ch1.py, but loading data with TMA
#  2. Creating and initializing one asynchronous transactional barrier (mbarrier)
#  3. Thread 0 of each thread block issuing the TMA load requests
#  4. Each thread block loading a <1x32xf32> tile for x and for y
#  5. Waiting for completion of the TMA loads on the mbarrier
#
# ===----------------------------------------------------------------------===//

from mlir import ir
from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
from tools.nvdsl import *
from mlir import runtime as rt
from mlir.extras import types as T
import numpy as np


@NVDSL.mlir_func
def saxpy(x, y, alpha):
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
    t6 = gpu.wait(token_ty, [t5])

    # Build TMA descriptors that copy one <1xN> tile of x and of y per request.
    x_tma = TMA([1, N], x.type)
    y_tma = TMA([1, N], y.type)
    x_tma.create_descriptor(x_dev)
    y_tma.create_descriptor(y_dev)
    sz_x = get_type_size(x_tma.tma_memref)
    sz_y = get_type_size(y_tma.tma_memref)
    sz = sz_x + sz_y

    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=sz)
    def saxpy_tma_kernel():
        bidx = gpu.block_id(gpu.Dimension.x)
        tidx = gpu.thread_id(gpu.Dimension.x)
        isThread0 = tidx == 0

        # 1. Create and initialize an asynchronous transactional barrier
        #    (mbarrier). Only thread 0 performs the initialization.
        mbar_group = Mbarriers(number_of_barriers=1)
        mbar_group[0].init(1, predicate=isThread0)

        # 2. Execute the Tensor Memory Accelerator (TMA) loads. Thread 0 issues
        #    one load per operand; coords=[0, bidx] selects this block's row.
        #    The transaction count (txcount) is the total number of bytes both
        #    loads deposit into shared memory.
        x_smem = get_dynamic_shared_memory([1, N], T.f32())
        y_smem = get_dynamic_shared_memory([1, N], T.f32(), offset=sz_x)
        x_tma.load(x_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
        y_tma.load(y_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
        mbar_group[0].arrive(txcount=sz, predicate=isThread0)

        # 3. Wait on the mbarrier for completion of the TMA loads.
        mbar_group[0].try_wait()

        x_val = memref.load(x_smem, [const(0), tidx])
        y_val = memref.load(y_smem, [const(0), tidx])

        # SAXPY: y[i] += a * x[i];
        y_val += x_val * alpha

        memref.store(y_val, y_dev, [bidx, tidx])

    saxpy_tma_kernel()

    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
    gpu.wait(token_ty, [t7])


# 3. Pass numpy arrays to MLIR
M = 256
N = 32
alpha = 2.0
x = np.random.randn(M, N).astype(np.float32)
y = np.ones((M, N), np.float32)
saxpy(x, y, alpha)

# 4. Verify MLIR with reference computation
ref = np.ones((M, N), np.float32)
ref += x * alpha
np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
print("PASS")
# CHECK-NOT: Mismatched elements
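
# ------------------------------------------------------------------------- #
# A minimal host-side NumPy sketch of the tiling the kernel above uses: block
# `bidx` loads the <1x32xf32> row tiles x[bidx, :] and y[bidx, :] via TMA
# (coords=[0, bidx]) and each of its N threads updates one element. This is
# not part of the GPU path; the function and variable names below are
# illustrative only.
def saxpy_tiled_reference(x_host, y_host, a):
    out = y_host.copy()
    for block_row in range(x_host.shape[0]):  # one iteration per thread block
        x_tile = x_host[block_row : block_row + 1, :]  # <1xN> tile TMA loads for x
        y_tile = out[block_row : block_row + 1, :]  # <1xN> tile TMA loads for y
        y_tile += a * x_tile  # N threads, one element each
    return out


# The tiled host sketch agrees with the whole-array reference used above.
assert np.allclose(saxpy_tiled_reference(x, np.ones((M, N), np.float32), alpha), ref)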