xref: /llvm-project/mlir/test/Examples/NVGPU/Ch2.py (revision f8ff9094711b74d3f695f7571f6390f8a481fc52)
1# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
2# RUN:   %PYTHON %s | FileCheck %s
3
4# ===----------------------------------------------------------------------===//
5#  Chapter 2 : 2D Saxpy with TMA
6# ===----------------------------------------------------------------------===//
7#
# This program demonstrates 2D Saxpy. It is the same as Chapter 1,
# but it loads data using TMA (Tensor Memory Accelerator).
10#
# This chapter demonstrates:
12#  1. Computes 2D SAXPY in the same way as Ch1.py but loads data using TMA
13#  2. Create and initialize 1 asynchronous transactional barrier (mbarrier)
#  3. Thread 0 of each thread block issues the TMA load requests
15#  4. Each thread block loads <1x32xf32> for x and y.
16#  5. Wait for completion of TMA load with mbarrier
17#
18# ===----------------------------------------------------------------------===//
19
20from mlir import ir
21from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
22from tools.nvdsl import *
23from mlir import runtime as rt
24from mlir.extras import types as T
25import numpy as np
26
27
@NVDSL.mlir_func
def saxpy(x, y, alpha):
    """Build the host and device code for a 2D SAXPY that loads tiles via TMA.

    The host side allocates device buffers for ``x`` and ``y``, copies the
    inputs into them, creates one TMA descriptor per buffer, launches the
    kernel, and finally copies the device ``y`` buffer back into ``y``.

    Args:
        x: First operand; its ``.type`` is used for the device allocation
            and for the TMA descriptor.
        y: Second operand and the destination of the final copy-back.
        alpha: Scale factor multiplied with every loaded ``x`` value.

    Note:
        ``M`` and ``N`` are module-level globals that are read when this
        function is invoked, not when it is defined.
    """
    token_ty = gpu.AsyncTokenType.get()

    # Allocate the device buffers and copy the inputs into them. Each op takes
    # the async token of the previous op as its dependency.
    t1 = gpu.wait(token_ty, [])
    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
    t6 = gpu.wait(token_ty, [t5])

    # One TMA descriptor per operand, each built for a [1, N] tile.
    x_tma = TMA([1, N], x.type)
    y_tma = TMA([1, N], y.type)
    x_tma.create_descriptor(x_dev)
    y_tma.create_descriptor(y_dev)

    # Size of each tile (presumably in bytes -- confirm against get_type_size).
    # sz_y is derived from y_tma's own tile type rather than x_tma's, so the
    # shared-memory request and the barrier txcount stay correct even if the
    # two operand types ever differ.
    sz_x = get_type_size(x_tma.tma_memref)
    sz_y = get_type_size(y_tma.tma_memref)
    sz = sz_x + sz_y

    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=sz)
    def saxpy_tma_kernel():
        bidx = gpu.block_id(gpu.Dimension.x)
        tidx = gpu.thread_id(gpu.Dimension.x)
        is_thread0 = tidx == 0

        # 1. Create and initialize asynchronous transactional barrier (mbarrier)
        mbar_group = Mbarriers(number_of_barriers=1)
        mbar_group[0].init(1, predicate=is_thread0)

        # 2. Execute Tensor Memory Accelerator (TMA) Load
        #    The x tile sits at the start of dynamic shared memory and the y
        #    tile follows it at offset sz_x. Only thread 0 issues the loads and
        #    the arrive carrying the total transaction count.
        x_smem = get_dynamic_shared_memory([1, N], T.f32())
        y_smem = get_dynamic_shared_memory([1, N], T.f32(), offset=sz_x)
        x_tma.load(x_smem, mbar_group[0], coords=[0, bidx], predicate=is_thread0)
        y_tma.load(y_smem, mbar_group[0], coords=[0, bidx], predicate=is_thread0)
        mbar_group[0].arrive(txcount=sz, predicate=is_thread0)

        # 3. Wait for completion of TMA load with mbarrier
        mbar_group[0].try_wait()

        # Each thread reads its own element from the two shared-memory tiles.
        x_val = memref.load(x_smem, [const(0), tidx])
        y_val = memref.load(y_smem, [const(0), tidx])

        # SAXPY: y[i] += a * x[i];
        y_val += x_val * alpha

        # The result goes straight to the device buffer at [bidx, tidx].
        memref.store(y_val, y_dev, [bidx, tidx])

    saxpy_tma_kernel()

    # Copy the result back into y and wait for the copy to finish.
    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
    gpu.wait(token_ty, [t7])
78
79
# Problem size: the kernel is launched with M thread blocks of N threads each.
M, N = 256, 32
alpha = 2.0

# Pass numpy arrays to MLIR; y is overwritten with the computed result.
x = np.random.randn(M, N).astype(np.float32)
y = np.ones((M, N), dtype=np.float32)
saxpy(x, y, alpha)

# Verify MLIR with reference computation: expected = 1 + alpha * x.
expected = alpha * x + np.ones((M, N), dtype=np.float32)
np.testing.assert_allclose(y, expected, rtol=5e-03, atol=1e-01)
print("PASS")
93# CHECK-NOT: Mismatched elements
94