xref: /llvm-project/mlir/test/Examples/NVGPU/Ch1.py (revision f8ff9094711b74d3f695f7571f6390f8a481fc52)
1# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
2# RUN:   %PYTHON %s | FileCheck %s
3
4# ===----------------------------------------------------------------------===//
5#  Chapter 1 : 2D Saxpy
6# ===----------------------------------------------------------------------===//
7#
8# This program demonstrates 2D Saxpy:
#  1. Use the GPU dialect to allocate device memory and copy data from host to GPU and back
#  2. Compute a 2D SAXPY kernel using operator overloading
#  3. Pass numpy arrays to MLIR as memref arguments
#  4. Verify the MLIR program against a reference computation in Python
13#
14# ===----------------------------------------------------------------------===//
15
16
17from mlir import ir
18from mlir.dialects import gpu, memref
19from tools.nvdsl import *
20import numpy as np
21
22
@NVDSL.mlir_func
def saxpy(x, y, alpha):
    """Emit MLIR that computes ``y += alpha * x`` on the GPU.

    The statements below build GPU/memref dialect operations; they do not
    operate on host data directly.

    Args:
        x: 2D memref argument holding the input (the caller passes a numpy
            array).
        y: 2D memref argument with the same shape as ``x``; the result is
            copied back into it at the end.
        alpha: Scalar multiplier applied to ``x``.
    """
    # Derive the launch geometry from the argument type instead of relying on
    # module-level globals that merely happen to match the caller's arrays.
    M, N = ir.MemRefType(x.type).shape

    # 1. Use MLIR GPU dialect to allocate and copy memory
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
    t6 = gpu.wait(token_ty, [t5])

    # 2. Compute 2D SAXPY kernel: one block per row (block_id.x) and one
    # thread per column (thread_id.x), each updating a single element.
    # NOTE(review): N becomes the block size, so it must stay within the
    # device's threads-per-block limit.
    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1))
    def saxpy_kernel():
        bidx = gpu.block_id(gpu.Dimension.x)
        tidx = gpu.thread_id(gpu.Dimension.x)
        x_val = memref.load(x_dev, [bidx, tidx])
        y_val = memref.load(y_dev, [bidx, tidx])

        # SAXPY: y[i] += a * x[i];
        y_val += x_val * alpha

        memref.store(y_val, y_dev, [bidx, tidx])

    saxpy_kernel()

    # 3. Copy the result back to the host, then free both device buffers so
    # repeated calls do not leak device memory. Each dealloc is chained on the
    # previous token, so y_dev is released only after the copy-back reads it.
    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
    t8 = gpu.dealloc(token_ty, [t7], x_dev)
    t9 = gpu.dealloc(token_ty, [t8], y_dev)
    # NOTE(review): passing token_ty selects the async form of gpu.wait;
    # confirm the host is synchronized before the caller reads y.
    gpu.wait(token_ty, [t9])
51
52
# 3. Pass numpy arrays to MLIR
# x and y are M x N; saxpy launches one GPU block per row and one thread per
# column.
M = 256
N = 32
alpha = 2.0
# Seed the generator so that a failing run can be reproduced exactly.
rng = np.random.default_rng(seed=0)
x = rng.standard_normal((M, N), dtype=np.float32)
y = np.ones((M, N), np.float32)
saxpy(x, y, alpha)

# 4. Verify MLIR with reference computation
ref = np.ones((M, N), np.float32)
ref += x * alpha
np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
print("PASS")
66# CHECK-NOT: Mismatched elements
67