# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
#  Chapter 1 : 2D Saxpy
# ===----------------------------------------------------------------------===//
#
# This program demonstrates 2D Saxpy:
#  1. Use the GPU dialect to allocate memory and copy it from host to GPU
#     and back
#  2. Compute the 2D SAXPY kernel using operator overloading
#  3. Pass numpy arrays to MLIR as memref arguments
#  4. Verify the MLIR program against a reference computation in Python
#
# ===----------------------------------------------------------------------===//


from mlir import ir
from mlir.dialects import gpu, memref
from tools.nvdsl import *
import numpy as np


@NVDSL.mlir_func
def saxpy(x, y, alpha):
    """Compute ``y += alpha * x`` on the GPU and write the result back to ``y``.

    ``x`` and ``y`` are MLIR memref values with the same 2D type. Each
    operand is copied to a device buffer, and the kernel updates the device
    copy of ``y``. That buffer is then copied back into the host-side ``y``.
    The kernel uses one block per row and one thread per column. The row
    length must therefore fit within a single thread block.
    """
    # Take the launch geometry from the operand type rather than module
    # globals, so the function works for any 2D shape and does not depend on
    # M/N being defined before it runs.
    rows, cols = ir.MemRefType(x.type).shape

    # 1. Use MLIR GPU dialect to allocate and copy memory. Each async op
    #    depends on exactly one preceding token, forming a single chain.
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
    t6 = gpu.wait(token_ty, [t5])

    # 2. Compute 2D SAXPY kernel: block index selects the row, thread index
    #    selects the column.
    @NVDSL.mlir_gpu_launch(grid=(rows, 1, 1), block=(cols, 1, 1))
    def saxpy_kernel():
        bidx = gpu.block_id(gpu.Dimension.x)
        tidx = gpu.thread_id(gpu.Dimension.x)
        x_val = memref.load(x_dev, [bidx, tidx])
        y_val = memref.load(y_dev, [bidx, tidx])

        # SAXPY: y[i] += a * x[i];
        y_val += x_val * alpha

        memref.store(y_val, y_dev, [bidx, tidx])

    saxpy_kernel()

    # Copy the result back to the host, then release both device buffers.
    # NOTE(review): t7 is chained on t6, not on the launch. This assumes
    # saxpy_kernel() launches synchronously; confirm against
    # NVDSL.mlir_gpu_launch.
    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
    t8 = gpu.dealloc(token_ty, [t7], x_dev)
    t9 = gpu.dealloc(token_ty, [t8], y_dev)
    # Without a result type, gpu.wait blocks the host until its dependencies
    # finish. The async form (with a token result) does not block, so the
    # host could otherwise read `y` before the copy-back completes.
    gpu.wait(None, [t9])


# 3. Pass numpy arrays to MLIR
M = 256
N = 32
alpha = 2.0
# Fixed seed so any mismatch is reproducible.
rng = np.random.default_rng(0)
x = rng.standard_normal((M, N)).astype(np.float32)
y = np.ones((M, N), np.float32)
saxpy(x, y, alpha)

# 4. Verify MLIR with reference computation
ref = np.ones((M, N), np.float32)
ref += x * alpha
np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
print("PASS")
# CHECK-NOT: Mismatched elements
# CHECK: PASS