# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
#  Chapter 0 : Hello World
# ===----------------------------------------------------------------------===//
#
# This program demonstrates Hello World:
#   1. Build MLIR function with arguments
#   2. Build MLIR GPU kernel
#   3. Print from a GPU thread
#   4. Pass arguments, JIT compile and run the MLIR function
#
# ===----------------------------------------------------------------------===//


from mlir.dialects import gpu
from tools.nvdsl import *


# 1. The decorator generates a MLIR func.func.
# Everything inside the Python function becomes the body of the func.
# The decorator also translates `alpha` to an `index` type.
@NVDSL.mlir_func
def main(alpha):
    """Define the hello-world GPU kernel and call it.

    Args:
        alpha: Value added to each thread's x index before printing.
    """

    # 2. The decorator generates a MLIR gpu.launch.
    # Everything inside the Python function becomes the body of the gpu.launch.
    # This allows for late outlining of the GPU kernel, enabling optimizations
    # like constant folding from host to device.
    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(4, 1, 1))
    def kernel():
        """Print the thread's x index and `alpha` plus that index."""
        tidx = gpu.thread_id(gpu.Dimension.x)
        # + operator generates arith.addi
        myValue = alpha + tidx
        # Print from a GPU thread
        gpu.printf("GPU thread %llu has %llu\n", [tidx, myValue])

    # 3. Call the GPU kernel
    kernel()


alpha = 100
# 4. The `mlir_func` decorator JIT compiles the IR and executes the MLIR function.
main(alpha)


# CHECK: GPU thread 0 has 100
# CHECK: GPU thread 1 has 101
# CHECK: GPU thread 2 has 102
# CHECK: GPU thread 3 has 103