xref: /llvm-project/mlir/test/Examples/NVGPU/Ch2.py (revision f8ff9094711b74d3f695f7571f6390f8a481fc52)
14d330820SGuray Ozen# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
24d330820SGuray Ozen# RUN:   %PYTHON %s | FileCheck %s
34d330820SGuray Ozen
44d330820SGuray Ozen# ===----------------------------------------------------------------------===//
54d330820SGuray Ozen#  Chapter 2 : 2D Saxpy with TMA
64d330820SGuray Ozen# ===----------------------------------------------------------------------===//
74d330820SGuray Ozen#
84d330820SGuray Ozen# This program demonstrates 2D Saxpy. It is same as Chapter 1,
94d330820SGuray Ozen# but it loads data using TMA (Tensor Memory Accelerator)
104d330820SGuray Ozen#
114d330820SGuray Ozen# This chapter introduces demonstrates:
124d330820SGuray Ozen#  1. Computes 2D SAXPY in the same way as Ch1.py but loads data using TMA
134d330820SGuray Ozen#  2. Create and initialize 1 asynchronous transactional barrier (mbarrier)
144d330820SGuray Ozen#  3. Thread-0 Load request data load from TMA for each thread block
154d330820SGuray Ozen#  4. Each thread block loads <1x32xf32> for x and y.
164d330820SGuray Ozen#  5. Wait for completion of TMA load with mbarrier
174d330820SGuray Ozen#
184d330820SGuray Ozen# ===----------------------------------------------------------------------===//
194d330820SGuray Ozen
204d330820SGuray Ozenfrom mlir import ir
214d330820SGuray Ozenfrom mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
224d330820SGuray Ozenfrom tools.nvdsl import *
234d330820SGuray Ozenfrom mlir import runtime as rt
244d330820SGuray Ozenfrom mlir.extras import types as T
254d330820SGuray Ozenimport numpy as np
264d330820SGuray Ozen
274d330820SGuray Ozen
284d330820SGuray Ozen@NVDSL.mlir_func
294d330820SGuray Ozendef saxpy(x, y, alpha):
30*f8ff9094SGuray Ozen    token_ty = gpu.AsyncTokenType.get()
314d330820SGuray Ozen    t1 = gpu.wait(token_ty, [])
324d330820SGuray Ozen    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
334d330820SGuray Ozen    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
344d330820SGuray Ozen    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
354d330820SGuray Ozen    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
364d330820SGuray Ozen    t6 = gpu.wait(token_ty, [t5])
374d330820SGuray Ozen
384d330820SGuray Ozen    x_tma = TMA([1, N], x.type)
394d330820SGuray Ozen    y_tma = TMA([1, N], y.type)
404d330820SGuray Ozen    x_tma.create_descriptor(x_dev)
414d330820SGuray Ozen    y_tma.create_descriptor(y_dev)
424d330820SGuray Ozen    sz_x = get_type_size(x_tma.tma_memref)
434d330820SGuray Ozen    sz_y = get_type_size(x_tma.tma_memref)
444d330820SGuray Ozen    sz = sz_x + sz_y
454d330820SGuray Ozen
464d330820SGuray Ozen    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=sz)
474d330820SGuray Ozen    def saxpy_tma_kernel():
484d330820SGuray Ozen        bidx = gpu.block_id(gpu.Dimension.x)
494d330820SGuray Ozen        tidx = gpu.thread_id(gpu.Dimension.x)
504d330820SGuray Ozen        isThread0 = tidx == 0
514d330820SGuray Ozen
524d330820SGuray Ozen        # 1. Create and initialize asynchronous transactional barrier (mbarrier)
534d330820SGuray Ozen        mbar_group = Mbarriers(number_of_barriers=1)
544d330820SGuray Ozen        mbar_group[0].init(1, predicate=isThread0)
554d330820SGuray Ozen
564d330820SGuray Ozen        # 2. Execute Tensor Memory Accelerator (TMA) Load
574d330820SGuray Ozen        x_smem = get_dynamic_shared_memory([1, N], T.f32())
584d330820SGuray Ozen        y_smem = get_dynamic_shared_memory([1, N], T.f32(), offset=sz_x)
594d330820SGuray Ozen        x_tma.load(x_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
604d330820SGuray Ozen        y_tma.load(y_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
614d330820SGuray Ozen        mbar_group[0].arrive(txcount=sz, predicate=isThread0)
624d330820SGuray Ozen
634d330820SGuray Ozen        # 3. Wait for completion of TMA load with mbarrier
644d330820SGuray Ozen        mbar_group[0].try_wait()
654d330820SGuray Ozen
664d330820SGuray Ozen        x_val = memref.load(x_smem, [const(0), tidx])
674d330820SGuray Ozen        y_val = memref.load(y_smem, [const(0), tidx])
684d330820SGuray Ozen
694d330820SGuray Ozen        # SAXPY: y[i] += a * x[i];
704d330820SGuray Ozen        y_val += x_val * alpha
714d330820SGuray Ozen
724d330820SGuray Ozen        memref.store(y_val, y_dev, [bidx, tidx])
734d330820SGuray Ozen
744d330820SGuray Ozen    saxpy_tma_kernel()
754d330820SGuray Ozen
764d330820SGuray Ozen    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
774d330820SGuray Ozen    gpu.wait(token_ty, [t7])
784d330820SGuray Ozen
794d330820SGuray Ozen
804d330820SGuray Ozen# 3. Pass numpy arrays to MLIR
814d330820SGuray OzenM = 256
824d330820SGuray OzenN = 32
834d330820SGuray Ozenalpha = 2.0
844d330820SGuray Ozenx = np.random.randn(M, N).astype(np.float32)
854d330820SGuray Ozeny = np.ones((M, N), np.float32)
864d330820SGuray Ozensaxpy(x, y, alpha)
874d330820SGuray Ozen
884d330820SGuray Ozen#  4. Verify MLIR with reference computation
894d330820SGuray Ozenref = np.ones((M, N), np.float32)
904d330820SGuray Ozenref += x * alpha
914d330820SGuray Ozennp.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
924d330820SGuray Ozenprint("PASS")
934d330820SGuray Ozen# CHECK-NOT: Mismatched elements
94