// Source: llvm-project/mlir/test/Conversion/SCFToSPIRV/if.mlir (revision 80d5400d924e543c5420f4e924f5818313605e99)
// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
} {

// A single-armed scf.if lowers to a spirv.mlir.selection whose conditional
// branch goes to the "then" block on true and straight to the merge block on
// false.
// CHECK-LABEL: @kernel_simple_selection
func.func @kernel_simple_selection(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : i1) {
  %value = arith.constant 0.0 : f32
  %i = arith.constant 0 : index

  // CHECK:       spirv.mlir.selection {
  // CHECK-NEXT:    spirv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[MERGE:\^.*]]
  // CHECK-NEXT:  [[TRUE]]:
  // CHECK:         spirv.Branch [[MERGE]]
  // CHECK-NEXT:  [[MERGE]]:
  // CHECK-NEXT:    spirv.mlir.merge
  // CHECK-NEXT:  }
  // CHECK-NEXT:  spirv.Return

  scf.if %arg3 {
    memref.store %value, %arg2[%i] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
  }
  return
}

// scf.if ops nested in both arms of an outer scf.if lower to nested
// spirv.mlir.selection regions. After each inner selection, control branches
// to the outer selection's merge block.
// CHECK-LABEL: @kernel_nested_selection
func.func @kernel_nested_selection(%arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg4 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg5 : i1, %arg6 : i1) {
  %i = arith.constant 0 : index
  %j = arith.constant 9 : index

  // CHECK:       spirv.mlir.selection {
  // CHECK-NEXT:    spirv.BranchConditional {{%.*}}, [[TRUE_TOP:\^.*]], [[FALSE_TOP:\^.*]]
  // CHECK-NEXT:  [[TRUE_TOP]]:
  // CHECK-NEXT:    spirv.mlir.selection {
  // CHECK-NEXT:      spirv.BranchConditional {{%.*}}, [[TRUE_NESTED_TRUE_PATH:\^.*]], [[FALSE_NESTED_TRUE_PATH:\^.*]]
  // CHECK-NEXT:    [[TRUE_NESTED_TRUE_PATH]]:
  // CHECK:           spirv.Branch [[MERGE_NESTED_TRUE_PATH:\^.*]]
  // CHECK-NEXT:    [[FALSE_NESTED_TRUE_PATH]]:
  // CHECK:           spirv.Branch [[MERGE_NESTED_TRUE_PATH]]
  // CHECK-NEXT:    [[MERGE_NESTED_TRUE_PATH]]:
  // CHECK-NEXT:      spirv.mlir.merge
  // CHECK-NEXT:    }
  // CHECK-NEXT:    spirv.Branch [[MERGE_TOP:\^.*]]
  // CHECK-NEXT:  [[FALSE_TOP]]:
  // CHECK-NEXT:    spirv.mlir.selection {
  // CHECK-NEXT:      spirv.BranchConditional {{%.*}}, [[TRUE_NESTED_FALSE_PATH:\^.*]], [[FALSE_NESTED_FALSE_PATH:\^.*]]
  // CHECK-NEXT:    [[TRUE_NESTED_FALSE_PATH]]:
  // CHECK:           spirv.Branch [[MERGE_NESTED_FALSE_PATH:\^.*]]
  // CHECK-NEXT:    [[FALSE_NESTED_FALSE_PATH]]:
  // CHECK:           spirv.Branch [[MERGE_NESTED_FALSE_PATH]]
  // CHECK-NEXT:    [[MERGE_NESTED_FALSE_PATH]]:
  // CHECK-NEXT:      spirv.mlir.merge
  // CHECK-NEXT:    }
  // CHECK-NEXT:    spirv.Branch [[MERGE_TOP]]
  // CHECK-NEXT:  [[MERGE_TOP]]:
  // CHECK-NEXT:    spirv.mlir.merge
  // CHECK-NEXT:  }
  // CHECK-NEXT:  spirv.Return

  scf.if %arg5 {
    scf.if %arg6 {
      %value = memref.load %arg3[%i] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
      memref.store %value, %arg4[%i] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
    } else {
      %value = memref.load %arg4[%i] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
      memref.store %value, %arg3[%i] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
    }
  } else {
    scf.if %arg6 {
      %value = memref.load %arg3[%j] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
      memref.store %value, %arg4[%j] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
    } else {
      %value = memref.load %arg4[%j] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
      memref.store %value, %arg3[%j] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
    }
  }
  return
}

// An scf.if that yields values works as follows:
// - each result gets a Function-storage spirv.Variable created before the
//   selection;
// - each arm stores its yielded values into those variables;
// - the results are loaded back after the selection.
// The else arm yields its constants in swapped order, so the checks also pin
// that each store targets the variable for the correct result index.
// CHECK-LABEL: @simple_if_yield
func.func @simple_if_yield(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : i1) {
  // CHECK: %[[VAR1:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
  // CHECK: %[[VAR2:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
  // CHECK:       spirv.mlir.selection {
  // CHECK-NEXT:    spirv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
  // CHECK-NEXT:  [[TRUE]]:
  // CHECK:         %[[RET1TRUE:.*]] = spirv.Constant 0.000000e+00 : f32
  // CHECK:         %[[RET2TRUE:.*]] = spirv.Constant 1.000000e+00 : f32
  // CHECK-DAG:     spirv.Store "Function" %[[VAR1]], %[[RET1TRUE]] : f32
  // CHECK-DAG:     spirv.Store "Function" %[[VAR2]], %[[RET2TRUE]] : f32
  // CHECK:         spirv.Branch ^[[MERGE:.*]]
  // CHECK-NEXT:  [[FALSE]]:
  // CHECK:         %[[RET2FALSE:.*]] = spirv.Constant 2.000000e+00 : f32
  // CHECK:         %[[RET1FALSE:.*]] = spirv.Constant 3.000000e+00 : f32
  // CHECK-DAG:     spirv.Store "Function" %[[VAR1]], %[[RET1FALSE]] : f32
  // CHECK-DAG:     spirv.Store "Function" %[[VAR2]], %[[RET2FALSE]] : f32
  // CHECK:         spirv.Branch ^[[MERGE]]
  // CHECK-NEXT:  ^[[MERGE]]:
  // CHECK:         spirv.mlir.merge
  // CHECK-NEXT:  }
  // CHECK-DAG:   %[[OUT1:.*]] = spirv.Load "Function" %[[VAR1]] : f32
  // CHECK-DAG:   %[[OUT2:.*]] = spirv.Load "Function" %[[VAR2]] : f32
  // CHECK:       spirv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
  // CHECK:       spirv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
  // CHECK:       spirv.Return
  %0:2 = scf.if %arg3 -> (f32, f32) {
    %c0 = arith.constant 0.0 : f32
    %c1 = arith.constant 1.0 : f32
    scf.yield %c0, %c1 : f32, f32
  } else {
    %c0 = arith.constant 2.0 : f32
    %c1 = arith.constant 3.0 : f32
    scf.yield %c1, %c0 : f32, f32
  }
  %i = arith.constant 0 : index
  %j = arith.constant 1 : index
  memref.store %0#0, %arg2[%i] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
  memref.store %0#1, %arg2[%j] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
  return
}

// TODO: The transformation should only be legal if VariablePointer capability
// is supported. This test is still useful to make sure we can handle scf op
// result with type change.
// Here the yielded memref is converted to a SPIR-V pointer, so the Function
// variable holding the result stores a pointer to the StorageBuffer struct.
// CHECK-LABEL: @simple_if_yield_type_change
func.func @simple_if_yield_type_change(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg4 : i1) {
  // CHECK:       %[[VAR:.*]] = spirv.Variable : !spirv.ptr<!spirv.ptr<!spirv.struct<(!spirv.array<10 x f32, stride=4> [0])>, StorageBuffer>, Function>
  // CHECK:       spirv.mlir.selection {
  // CHECK-NEXT:    spirv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
  // CHECK-NEXT:  [[TRUE]]:
  // CHECK:         spirv.Store "Function" %[[VAR]], {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.array<10 x f32, stride=4> [0])>, StorageBuffer>
  // CHECK:         spirv.Branch ^[[MERGE:.*]]
  // CHECK-NEXT:  [[FALSE]]:
  // CHECK:         spirv.Store "Function" %[[VAR]], {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.array<10 x f32, stride=4> [0])>, StorageBuffer>
  // CHECK:         spirv.Branch ^[[MERGE]]
  // CHECK-NEXT:  ^[[MERGE]]:
  // CHECK:         spirv.mlir.merge
  // CHECK-NEXT:  }
  // CHECK:       %[[OUT:.*]] = spirv.Load "Function" %[[VAR]] : !spirv.ptr<!spirv.struct<(!spirv.array<10 x f32, stride=4> [0])>, StorageBuffer>
  // CHECK:       %[[ADD:.*]] = spirv.AccessChain %[[OUT]][{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<10 x f32, stride=4> [0])>, StorageBuffer>
  // CHECK:       spirv.Store "StorageBuffer" %[[ADD]], {{%.*}} : f32
  // CHECK:       spirv.Return
  %i = arith.constant 0 : index
  %value = arith.constant 0.0 : f32
  %0 = scf.if %arg4 -> (memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
    scf.yield %arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>
  } else {
    scf.yield %arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>
  }
  memref.store %value, %0[%i] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
  return
}

// Memrefs without a SPIR-V storage class are not supported. The conversion
// should leave the `scf.if` in place and not crash.
// CHECK-LABEL: @unsupported_yield_type
func.func @unsupported_yield_type(%arg0 : memref<8xi32>, %arg1 : memref<8xi32>, %c : i1) {
  // CHECK-NEXT:    scf.if
  // CHECK:         spirv.Return
  %r = scf.if %c -> (memref<8xi32>) {
    scf.yield %arg0 : memref<8xi32>
  } else {
    scf.yield %arg1 : memref<8xi32>
  }
  return
}

} // end module
