# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
#  Chapter 0 : Hello World
# ===----------------------------------------------------------------------===//
#
# This program demonstrates Hello World:
#   1. Build an MLIR function with arguments
#   2. Build an MLIR GPU kernel
#   3. Print from a GPU thread
#   4. Pass arguments, JIT-compile, and run the MLIR function
#
# ===----------------------------------------------------------------------===//


from mlir.dialects import gpu
from tools.nvdsl import *


# 1. The decorator generates an MLIR func.func.
# Everything inside the Python function becomes the body of the func.
# The decorator also translates `alpha` to an `index` type.
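# The generated IR is roughly the sketch below (an illustration only; the
# actual symbol and SSA value names are chosen by the decorator):
#
#   func.func @main(%alpha: index) {
#     ...
#   }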
@NVDSL.mlir_func
def main(alpha):
    # 2. The decorator generates an MLIR gpu.launch.
    # Everything inside the Python function becomes the body of the gpu.launch.
    # This allows late outlining of the GPU kernel, enabling optimizations
    # such as folding host constants into the device code.
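    # The launch below produces an op roughly of this shape (a sketch; SSA
    # names are illustrative, %c1/%c4 are index constants for the sizes):
    #
    #   gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
    #              threads(%tx, %ty, %tz) in (%bw = %c4, %bh = %c1, %bd = %c1) {
    #     ...
    #     gpu.terminator
    #   }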
    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(4, 1, 1))
    def kernel():
        tidx = gpu.thread_id(gpu.Dimension.x)
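        # Roughly: %tidx = gpu.thread_id x (a sketch; SSA name illustrative).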
        # The `+` operator generates arith.addi.
        myValue = alpha + tidx
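        # Roughly: %myValue = arith.addi %alpha, %tidx : index (a sketch).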
        # Print from a GPU thread.
        gpu.printf("GPU thread %llu has %llu\n", [tidx, myValue])
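        # Roughly: gpu.printf "GPU thread %llu has %llu\n" %tidx, %myValue : index, index
        # (a sketch; the exact printf assembly format may vary across MLIR versions).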

    # 3. Call the GPU kernel.
    kernel()


alpha = 100
# 4. The `mlir_func` decorator JIT-compiles the IR and executes the MLIR function.
main(alpha)


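# With block=(4, 1, 1), threads 0 through 3 each print alpha + tidx,
# so the expected output is 100 through 103.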
# CHECK: GPU thread 0 has 100
# CHECK: GPU thread 1 has 101
# CHECK: GPU thread 2 has 102
# CHECK: GPU thread 3 has 103