xref: /llvm-project/mlir/test/Examples/NVGPU/Ch0.py (revision 4d3308202e52b213a05023c8b8b470b346151de6)
1*4d330820SGuray Ozen# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
2*4d330820SGuray Ozen# RUN:   %PYTHON %s | FileCheck %s
3*4d330820SGuray Ozen
4*4d330820SGuray Ozen# ===----------------------------------------------------------------------===//
5*4d330820SGuray Ozen#  Chapter 0 : Hello World
6*4d330820SGuray Ozen# ===----------------------------------------------------------------------===//
7*4d330820SGuray Ozen#
8*4d330820SGuray Ozen# This program demonstrates Hello World:
9*4d330820SGuray Ozen#   1. Build MLIR function with arguments
10*4d330820SGuray Ozen#   2. Build MLIR GPU kernel
11*4d330820SGuray Ozen#   3. Print from a GPU thread
12*4d330820SGuray Ozen#   4. Pass arguments, JIT compile and run the MLIR function
13*4d330820SGuray Ozen#
14*4d330820SGuray Ozen# ===----------------------------------------------------------------------===//
15*4d330820SGuray Ozen
16*4d330820SGuray Ozen
17*4d330820SGuray Ozenfrom mlir.dialects import gpu
18*4d330820SGuray Ozenfrom tools.nvdsl import *
19*4d330820SGuray Ozen
20*4d330820SGuray Ozen
# 1. The decorator generates an MLIR func.func.
# Everything inside the Python function becomes the body of the func.
# The decorator also translates `alpha` to an `index` type.
@NVDSL.mlir_func
def main(alpha):
    # Launch configuration: a single block of four threads along x.
    grid_dims = (1, 1, 1)
    block_dims = (4, 1, 1)

    # 2. `mlir_gpu_launch` turns the nested Python function below into the
    # body of an MLIR gpu.launch op. Because the GPU kernel is outlined late,
    # optimizations such as constant folding from host to device stay possible.
    @NVDSL.mlir_gpu_launch(grid=grid_dims, block=block_dims)
    def kernel():
        thread_x = gpu.thread_id(gpu.Dimension.x)
        # The `+` operator emits arith.addi; `alpha` is the host-side argument
        # captured from the enclosing function.
        shifted = alpha + thread_x
        # Each GPU thread prints its own id and the value it computed.
        gpu.printf("GPU thread %llu has %llu\n", [thread_x, shifted])

    # 3. Invoke the kernel so the launch is emitted into the function body.
    kernel()
40*4d330820SGuray Ozen
41*4d330820SGuray Ozen
# 4. Calling the function wrapped by `mlir_func` JIT compiles the generated IR
# and executes the MLIR function, with `alpha` passed as its argument.
alpha = 100
main(alpha)
45*4d330820SGuray Ozen
46*4d330820SGuray Ozen
47*4d330820SGuray Ozen# CHECK: GPU thread 0 has 100
48*4d330820SGuray Ozen# CHECK: GPU thread 1 has 101
49*4d330820SGuray Ozen# CHECK: GPU thread 2 has 102
50*4d330820SGuray Ozen# CHECK: GPU thread 3 has 103
51