// RUN: mlir-opt %s \
// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
// RUN: | mlir-runner \
// RUN:   --shared-libs=%mlir_cuda_runtime \
// RUN:   --shared-libs=%mlir_runner_utils \
// RUN:   --entry-point-result=void \
// RUN: | FileCheck %s

// Row-wise signed-maximum reduction on the GPU.
// A 2x6 i32 matrix is filled on the host. One block is launched per row and
// one thread per column. Each block reduces its row with `gpu.all_reduce maxsi`
// and writes the result into a 2-element output buffer, which is then printed.
func.func @main() {
  // Host-side buffers: the input matrix and one result slot per row.
  %input = memref.alloc() : memref<2x6xi32>
  %result = memref.alloc() : memref<2xi32>

  // Index constants, used for subscripts and for the launch dimensions.
  %idx0 = arith.constant 0 : index
  %idx1 = arith.constant 1 : index
  %idx2 = arith.constant 2 : index
  %idx3 = arith.constant 3 : index
  %idx4 = arith.constant 4 : index
  %idx5 = arith.constant 5 : index
  %idx6 = arith.constant 6 : index

  // Element values stored into the matrix.
  %c0_i32 = arith.constant 0 : i32
  %c1_i32 = arith.constant 1 : i32
  %c2_i32 = arith.constant 2 : i32
  %c3_i32 = arith.constant 3 : i32
  %c4_i32 = arith.constant 4 : i32
  %c6_i32 = arith.constant 6 : i32
  %c7_i32 = arith.constant 7 : i32
  %c8_i32 = arith.constant 8 : i32
  %c10_i32 = arith.constant 10 : i32
  %c11_i32 = arith.constant 11 : i32
  %c16_i32 = arith.constant 16 : i32

  // Register both buffers with the GPU runtime (via their unranked views)
  // before the kernel below reads %input and writes %result.
  %input_unranked = memref.cast %input : memref<2x6xi32> to memref<*xi32>
  %result_unranked = memref.cast %result : memref<2xi32> to memref<*xi32>
  gpu.host_register %input_unranked : memref<*xi32>
  gpu.host_register %result_unranked : memref<*xi32>

  // Row 0 = [0, 1, 2, 4, 8, 16]  -> max 16
  memref.store %c0_i32, %input[%idx0, %idx0] : memref<2x6xi32>
  memref.store %c1_i32, %input[%idx0, %idx1] : memref<2x6xi32>
  memref.store %c2_i32, %input[%idx0, %idx2] : memref<2x6xi32>
  memref.store %c4_i32, %input[%idx0, %idx3] : memref<2x6xi32>
  memref.store %c8_i32, %input[%idx0, %idx4] : memref<2x6xi32>
  memref.store %c16_i32, %input[%idx0, %idx5] : memref<2x6xi32>

  // Row 1 = [2, 3, 6, 7, 10, 11] -> max 11
  memref.store %c2_i32, %input[%idx1, %idx0] : memref<2x6xi32>
  memref.store %c3_i32, %input[%idx1, %idx1] : memref<2x6xi32>
  memref.store %c6_i32, %input[%idx1, %idx2] : memref<2x6xi32>
  memref.store %c7_i32, %input[%idx1, %idx3] : memref<2x6xi32>
  memref.store %c10_i32, %input[%idx1, %idx4] : memref<2x6xi32>
  memref.store %c11_i32, %input[%idx1, %idx5] : memref<2x6xi32>

  // MAX: 2 blocks (block x-id selects the row) x 6 threads (thread x-id
  // selects the column). Every thread of a block stores the same reduced
  // value into that row's result slot.
  gpu.launch blocks(%blk_x, %blk_y, %blk_z) in (%nblk_x = %idx2, %nblk_y = %idx1, %nblk_z = %idx1)
             threads(%thr_x, %thr_y, %thr_z) in (%nthr_x = %idx6, %nthr_y = %idx1, %nthr_z = %idx1) {
    %elem = memref.load %input[%blk_x, %thr_x] : memref<2x6xi32>
    %row_max = gpu.all_reduce maxsi %elem uniform {} : (i32) -> (i32)
    memref.store %row_max, %result[%blk_x] : memref<2xi32>
    gpu.terminator
  }

  call @printMemrefI32(%result_unranked) : (memref<*xi32>) -> ()
  // CHECK: [16, 11]

  return
}

func.func private @printMemrefI32(memref<*xi32>)