xref: /llvm-project/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir (revision 25ae1a266d50f24a8fffc57152d7f3c3fcb65517)
1// RUN: mlir-opt -test-convert-to-spirv="convert-gpu-modules=true run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
2
module attributes {
  gpu.container_module,
  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
} {
  // Host side: two kernels in two separate gpu.modules (one lower-case and one
  // upper-case symbol name) are launched. The module and function symbols are
  // captured from the launches and reused below to match both the generated
  // spirv.module contents and the gpu.module that remains after it.
  // CHECK-LABEL: func.func @main
  // CHECK:       %[[C1:.*]] = arith.constant 1 : index
  // CHECK:       gpu.launch_func  @[[$KERNELS_1:.*]]::@[[$BUILTIN_WG_ID_X:.*]] blocks in (%[[C1]], %[[C1]], %[[C1]]) threads in (%[[C1]], %[[C1]], %[[C1]])
  // CHECK:       gpu.launch_func  @[[$KERNELS_2:.*]]::@[[$BUILTIN_WG_ID_Y:.*]] blocks in (%[[C1]], %[[C1]], %[[C1]]) threads in (%[[C1]], %[[C1]], %[[C1]])
  func.func @main() {
    %c1 = arith.constant 1 : index
    gpu.launch_func @kernels_1::@builtin_workgroup_id_x
        blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
    gpu.launch_func @KERNELS_2::@builtin_workgroup_id_y
        blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
    return
  }

  // Kernel reading the x block id. In the generated spirv.func, gpu.block_id
  // is expected to lower to spirv.mlir.addressof, an "Input" spirv.Load and a
  // spirv.CompositeExtract.
  // CHECK-LABEL:  spirv.module @{{.*}} Logical GLSL450
  // CHECK:        spirv.func @[[$BUILTIN_WG_ID_X]]
  // CHECK:        spirv.mlir.addressof
  // CHECK:        spirv.Load "Input"
  // CHECK:        spirv.CompositeExtract
  gpu.module @kernels_1 {
    gpu.func @builtin_workgroup_id_x() kernel
      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
      %0 = gpu.block_id x
      gpu.return
    }
  }
  // The original gpu.module, including its unconverted gpu.block_id op, is
  // expected to follow the generated spirv.module.
  // CHECK:  gpu.module @[[$KERNELS_1]]
  // CHECK:  gpu.func @[[$BUILTIN_WG_ID_X]]
  // CHECK:  gpu.block_id x
  // CHECK:  gpu.return

  // Same as above for the y block id in the upper-case-named module.
  // CHECK-LABEL:  spirv.module @{{.*}} Logical GLSL450
  // CHECK:        spirv.func @[[$BUILTIN_WG_ID_Y]]
  // CHECK:        spirv.mlir.addressof
  // CHECK:        spirv.Load "Input"
  // CHECK:        spirv.CompositeExtract
  gpu.module @KERNELS_2 {
    gpu.func @builtin_workgroup_id_y() kernel
      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
      %0 = gpu.block_id y
      gpu.return
    }
  }
  // CHECK:  gpu.module @[[$KERNELS_2]]
  // CHECK:  gpu.func @[[$BUILTIN_WG_ID_Y]]
  // CHECK:  gpu.block_id y
  // CHECK:  gpu.return
}
54
55// -----
56
module attributes {
  gpu.container_module,
  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
} {
  // Host side: a kernel taking two memref arguments is launched. The launch,
  // including its memref operands, is expected to appear unchanged.
  // CHECK-LABEL: func.func @main
  // CHECK-SAME:  %[[ARG0:.*]]: memref<2xi32>, %[[ARG1:.*]]: memref<4xi32>
  // CHECK:       %[[C1:.*]] = arith.constant 1 : index
  // CHECK:       gpu.launch_func  @[[$KERNEL_MODULE:.*]]::@[[$KERNEL_FUNC:.*]] blocks in (%[[C1]], %[[C1]], %[[C1]]) threads in (%[[C1]], %[[C1]], %[[C1]]) args(%[[ARG0]] : memref<2xi32>, %[[ARG1]] : memref<4xi32>)
  func.func @main(%arg0 : memref<2xi32>, %arg1 : memref<4xi32>) {
    %c1 = arith.constant 1 : index
    gpu.launch_func @kernels::@kernel_foo
        blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
        args(%arg0 : memref<2xi32>, %arg1 : memref<4xi32>)
    return
  }

  // In the generated spirv.func, each memref argument is expected to become a
  // StorageBuffer pointer to a struct wrapping a stride-4 i32 array, tagged
  // with interface_var_abi (0, 0) and (0, 1) respectively.
  // CHECK-LABEL: spirv.module @{{.*}} Logical GLSL450
  // CHECK:       spirv.func @[[$KERNEL_FUNC]]
  // CHECK-SAME:  %{{.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<2 x i32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
  // CHECK-SAME:  %{{.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
  gpu.module @kernels {
    gpu.func @kernel_foo(%arg0 : memref<2xi32>, %arg1 : memref<4xi32>)
      kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
      // CHECK: spirv.Constant
      // CHECK: spirv.Constant dense<0>
      %idx0 = arith.constant 0 : index
      %vec0 = arith.constant dense<[0, 0]> : vector<2xi32>
      // CHECK: spirv.AccessChain
      // CHECK: spirv.Load "StorageBuffer"
      %val = memref.load %arg0[%idx0] : memref<2xi32>
      // CHECK: spirv.CompositeInsert
      %vec = vector.insertelement %val, %vec0[%idx0 : index] : vector<2xi32>
      // The shuffle indexes into the two concatenated 2-element operands, so
      // its result is a vector<4xi32>.
      // CHECK: spirv.VectorShuffle
      %shuffle = vector.shuffle %vec, %vec[3, 2, 1, 0] : vector<2xi32>, vector<2xi32>
      // CHECK: spirv.CompositeExtract
      %res = vector.extractelement %shuffle[%idx0 : index] : vector<4xi32>
      // CHECK: spirv.AccessChain
      // CHECK: spirv.Store "StorageBuffer"
      memref.store %res, %arg1[%idx0]: memref<4xi32>
      // CHECK: spirv.Return
      gpu.return
    }
  }
  // The original gpu.module is expected to follow the spirv.module with its
  // arith/memref/vector ops left unconverted.
  // CHECK:      gpu.module @[[$KERNEL_MODULE]]
  // CHECK:      gpu.func @[[$KERNEL_FUNC]]
  // CHECK-SAME: %{{.*}}: memref<2xi32>, %{{.*}}: memref<4xi32>
  // CHECK:      arith.constant
  // CHECK:      memref.load
  // CHECK:      vector.insertelement
  // CHECK:      vector.shuffle
  // CHECK:      vector.extractelement
  // CHECK:      memref.store
  // CHECK:      gpu.return
}
111