# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN: %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
# Chapter 2 : 2D Saxpy with TMA
# ===----------------------------------------------------------------------===//
#
# This program demonstrates 2D Saxpy. It is the same as Chapter 1,
# but it loads data using TMA (Tensor Memory Accelerator).
#
# This chapter demonstrates how to:
#  1. Compute 2D SAXPY in the same way as Ch1.py, but load data using TMA
#  2. Create and initialize 1 asynchronous transactional barrier (mbarrier)
#  3. Have thread 0 issue the TMA load requests for each thread block
#  4. Load one <1x32xf32> row of x and of y per thread block
#  5. Wait for completion of the TMA loads with the mbarrier
#
# ===----------------------------------------------------------------------===//

from mlir import ir
from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
from tools.nvdsl import *
from mlir import runtime as rt
from mlir.extras import types as T
import numpy as np


@NVDSL.mlir_func
def saxpy(x, y, alpha):
    """Compute y += alpha * x on the GPU, loading operands through TMA.

    x and y are host buffers of shape (M, N); the result is copied back into
    y. M and N are read from module-level globals that are defined below this
    function, before it is called.
    """
    token_ty = gpu.AsyncTokenType.get()

    # Allocate device buffers and copy both operands to the device, chaining
    # async tokens so every step waits for the previous one.
    t1 = gpu.wait(token_ty, [])
    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
    t6 = gpu.wait(token_ty, [t5])

    # One TMA descriptor per operand, each describing a [1, N] tile (one row).
    x_tma = TMA([1, N], x.type)
    y_tma = TMA([1, N], y.type)
    x_tma.create_descriptor(x_dev)
    y_tma.create_descriptor(y_dev)

    # Dynamic shared memory holds the x tile followed by the y tile.
    # Each size must come from its own descriptor.
    sz_x = get_type_size(x_tma.tma_memref)
    sz_y = get_type_size(y_tma.tma_memref)
    sz = sz_x + sz_y

    # One thread block per row (M blocks), one thread per column (N threads).
    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=sz)
    def saxpy_tma_kernel():
        bidx = gpu.block_id(gpu.Dimension.x)
        tidx = gpu.thread_id(gpu.Dimension.x)
        isThread0 = tidx == 0

        # 1. Create and initialize the asynchronous transactional barrier
        #    (mbarrier); only thread 0 performs the initialization.
        mbar_group = Mbarriers(number_of_barriers=1)
        mbar_group[0].init(1, predicate=isThread0)

        # 2. Execute the Tensor Memory Accelerator (TMA) loads. The y tile
        #    lives right after the x tile, at byte offset sz_x. Coordinates
        #    appear to be given innermost dimension first (column 0, row
        #    bidx), matching the [bidx, tidx] store below; confirm against the
        #    nvgpu TMA docs if changing the tile shape.
        x_smem = get_dynamic_shared_memory([1, N], T.f32())
        y_smem = get_dynamic_shared_memory([1, N], T.f32(), offset=sz_x)
        x_tma.load(x_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
        y_tma.load(y_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
        # Thread 0 arrives and registers sz as the expected transaction count
        # covering both tiles.
        mbar_group[0].arrive(txcount=sz, predicate=isThread0)

        # 3. Every thread waits on the mbarrier until the TMA loads complete.
        mbar_group[0].try_wait()

        x_val = memref.load(x_smem, [const(0), tidx])
        y_val = memref.load(y_smem, [const(0), tidx])

        # SAXPY: y[i] += a * x[i];
        y_val += x_val * alpha

        memref.store(y_val, y_dev, [bidx, tidx])

    saxpy_tma_kernel()

    # Copy the result back into the host buffer y and wait for completion.
    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
    gpu.wait(token_ty, [t7])


# Pass numpy arrays to MLIR
M = 256
N = 32
alpha = 2.0
x = np.random.randn(M, N).astype(np.float32)
y = np.ones((M, N), np.float32)
saxpy(x, y, alpha)

# Verify MLIR with reference computation
ref = np.ones((M, N), np.float32)
ref += x * alpha
np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
print("PASS")
# CHECK-NOT: Mismatched elements