// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s

// All tests live in one module so they share a single SPIR-V target
// environment (v1.0, Shader capability, storage-buffer storage class extension).
module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
} {

// `scf.if` with no else region and no results: the false edge of the
// conditional branch is expected to target the merge block directly.
// CHECK-LABEL: @kernel_simple_selection
func.func @kernel_simple_selection(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : i1) {
  %value = arith.constant 0.0 : f32
  %i = arith.constant 0 : index

  // CHECK:       spirv.mlir.selection {
  // CHECK-NEXT:    spirv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[MERGE:\^.*]]
  // CHECK-NEXT:  [[TRUE]]:
  // CHECK:         spirv.Branch [[MERGE]]
  // CHECK-NEXT:  [[MERGE]]:
  // CHECK-NEXT:    spirv.mlir.merge
  // CHECK-NEXT:  }
  // CHECK-NEXT:  spirv.Return

  scf.if %arg3 {
    memref.store %value, %arg2[%i] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
  }
  return
}

// An `scf.if` nested in each region of an outer `scf.if`: every inner op is
// expected to become its own selection, placed inside the matching block of
// the outer selection. The true and false paths are checked symmetrically.
// CHECK-LABEL: @kernel_nested_selection
func.func @kernel_nested_selection(%arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg4 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg5 : i1, %arg6 : i1) {
  %i = arith.constant 0 : index
  %j = arith.constant 9 : index

  // CHECK:       spirv.mlir.selection {
  // CHECK-NEXT:    spirv.BranchConditional {{%.*}}, [[TRUE_TOP:\^.*]], [[FALSE_TOP:\^.*]]
  // CHECK-NEXT:  [[TRUE_TOP]]:
  // CHECK-NEXT:    spirv.mlir.selection {
  // CHECK-NEXT:      spirv.BranchConditional {{%.*}}, [[TRUE_NESTED_TRUE_PATH:\^.*]], [[FALSE_NESTED_TRUE_PATH:\^.*]]
  // CHECK-NEXT:    [[TRUE_NESTED_TRUE_PATH]]:
  // CHECK:           spirv.Branch [[MERGE_NESTED_TRUE_PATH:\^.*]]
  // CHECK-NEXT:    [[FALSE_NESTED_TRUE_PATH]]:
  // CHECK:           spirv.Branch [[MERGE_NESTED_TRUE_PATH]]
  // CHECK-NEXT:    [[MERGE_NESTED_TRUE_PATH]]:
  // CHECK-NEXT:      spirv.mlir.merge
  // CHECK-NEXT:    }
  // CHECK-NEXT:    spirv.Branch [[MERGE_TOP:\^.*]]
  // CHECK-NEXT:  [[FALSE_TOP]]:
  // CHECK-NEXT:    spirv.mlir.selection {
  // CHECK-NEXT:      spirv.BranchConditional {{%.*}}, [[TRUE_NESTED_FALSE_PATH:\^.*]], [[FALSE_NESTED_FALSE_PATH:\^.*]]
  // CHECK-NEXT:    [[TRUE_NESTED_FALSE_PATH]]:
  // CHECK:           spirv.Branch [[MERGE_NESTED_FALSE_PATH:\^.*]]
  // CHECK-NEXT:    [[FALSE_NESTED_FALSE_PATH]]:
  // CHECK:           spirv.Branch [[MERGE_NESTED_FALSE_PATH]]
  // CHECK-NEXT:    [[MERGE_NESTED_FALSE_PATH]]:
  // CHECK-NEXT:      spirv.mlir.merge
  // CHECK-NEXT:    }
  // CHECK-NEXT:    spirv.Branch [[MERGE_TOP]]
  // CHECK-NEXT:  [[MERGE_TOP]]:
  // CHECK-NEXT:    spirv.mlir.merge
  // CHECK-NEXT:  }
  // CHECK-NEXT:  spirv.Return

  scf.if %arg5 {
    scf.if %arg6 {
      %value = memref.load %arg3[%i] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
      memref.store %value, %arg4[%i] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
    } else {
      %value = memref.load %arg4[%i] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
      memref.store %value, %arg3[%i] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
    }
  } else {
    scf.if %arg6 {
      %value = memref.load %arg3[%j] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
      memref.store %value, %arg4[%j] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
    } else {
      %value = memref.load %arg4[%j] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
      memref.store %value, %arg3[%j] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
    }
  }
  return
}

// `scf.if` yielding two f32 results: each result is expected to be carried by
// its own Function-storage variable that is stored in both branches and loaded
// after the selection. The else region yields (%c1, %c0), i.e. (3.0, 2.0), so
// the first variable receives 3.0 and the second 2.0 on that path.
// CHECK-LABEL: @simple_if_yield
func.func @simple_if_yield(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : i1) {
  // CHECK:       %[[VAR1:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
  // CHECK:       %[[VAR2:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
  // CHECK:       spirv.mlir.selection {
  // CHECK-NEXT:    spirv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
  // CHECK-NEXT:  [[TRUE]]:
  // CHECK:         %[[RET1TRUE:.*]] = spirv.Constant 0.000000e+00 : f32
  // CHECK:         %[[RET2TRUE:.*]] = spirv.Constant 1.000000e+00 : f32
  // CHECK-DAG:     spirv.Store "Function" %[[VAR1]], %[[RET1TRUE]] : f32
  // CHECK-DAG:     spirv.Store "Function" %[[VAR2]], %[[RET2TRUE]] : f32
  // CHECK:         spirv.Branch ^[[MERGE:.*]]
  // CHECK-NEXT:  [[FALSE]]:
  // CHECK:         %[[RET2FALSE:.*]] = spirv.Constant 2.000000e+00 : f32
  // CHECK:         %[[RET1FALSE:.*]] = spirv.Constant 3.000000e+00 : f32
  // CHECK-DAG:     spirv.Store "Function" %[[VAR1]], %[[RET1FALSE]] : f32
  // CHECK-DAG:     spirv.Store "Function" %[[VAR2]], %[[RET2FALSE]] : f32
  // CHECK:         spirv.Branch ^[[MERGE]]
  // CHECK-NEXT:  ^[[MERGE]]:
  // CHECK:         spirv.mlir.merge
  // CHECK-NEXT:  }
  // CHECK-DAG:   %[[OUT1:.*]] = spirv.Load "Function" %[[VAR1]] : f32
  // CHECK-DAG:   %[[OUT2:.*]] = spirv.Load "Function" %[[VAR2]] : f32
  // CHECK:       spirv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
  // CHECK:       spirv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
  // CHECK:       spirv.Return
  %0:2 = scf.if %arg3 -> (f32, f32) {
    %c0 = arith.constant 0.0 : f32
    %c1 = arith.constant 1.0 : f32
    scf.yield %c0, %c1 : f32, f32
  } else {
    %c0 = arith.constant 2.0 : f32
    %c1 = arith.constant 3.0 : f32
    scf.yield %c1, %c0 : f32, f32
  }
  %i = arith.constant 0 : index
  %j = arith.constant 1 : index
  memref.store %0#0, %arg2[%i] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
  memref.store %0#1, %arg2[%j] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
  return
}

// TODO: The transformation should only be legal if VariablePointer capability
// is supported. This test is still useful to make sure we can handle scf op
// result with type change.
func.func @simple_if_yield_type_change(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg4 : i1) {
  // The `scf.if` result is a memref, so its converted type differs from the
  // source type: the expected Function variable holds a StorageBuffer pointer
  // to a struct wrapping the 10-element array. Both branches store a pointer
  // into it; the pointer loaded after the selection feeds the access chain
  // used by the final store.
  // CHECK-LABEL: @simple_if_yield_type_change
  // CHECK:       %[[VAR:.*]] = spirv.Variable : !spirv.ptr<!spirv.ptr<!spirv.struct<(!spirv.array<10 x f32, stride=4> [0])>, StorageBuffer>, Function>
  // CHECK:       spirv.mlir.selection {
  // CHECK-NEXT:    spirv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
  // CHECK-NEXT:  [[TRUE]]:
  // CHECK:         spirv.Store "Function" %[[VAR]], {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.array<10 x f32, stride=4> [0])>, StorageBuffer>
  // CHECK:         spirv.Branch ^[[MERGE:.*]]
  // CHECK-NEXT:  [[FALSE]]:
  // CHECK:         spirv.Store "Function" %[[VAR]], {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.array<10 x f32, stride=4> [0])>, StorageBuffer>
  // CHECK:         spirv.Branch ^[[MERGE]]
  // CHECK-NEXT:  ^[[MERGE]]:
  // CHECK:         spirv.mlir.merge
  // CHECK-NEXT:  }
  // CHECK:       %[[OUT:.*]] = spirv.Load "Function" %[[VAR]] : !spirv.ptr<!spirv.struct<(!spirv.array<10 x f32, stride=4> [0])>, StorageBuffer>
  // CHECK:       %[[ADD:.*]] = spirv.AccessChain %[[OUT]][{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<10 x f32, stride=4> [0])>, StorageBuffer>
  // CHECK:       spirv.Store "StorageBuffer" %[[ADD]], {{%.*}} : f32
  // CHECK:       spirv.Return
  %i = arith.constant 0 : index
  %value = arith.constant 0.0 : f32
  %0 = scf.if %arg4 -> (memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
    scf.yield %arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>
  } else {
    scf.yield %arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>
  }
  memref.store %value, %0[%i] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
  return
}

// Memrefs without a spirv storage class are not supported. The conversion
// should preserve the `scf.if` and not crash.
func.func @unsupported_yield_type(%arg0 : memref<8xi32>, %arg1 : memref<8xi32>, %c : i1) {
// The yielded memref type carries no SPIR-V storage class. The expectations
// below only require that the `scf.if` is still present immediately after the
// function header and that the function's return was converted. The result %r
// is not used by any other op.
// CHECK-LABEL: @unsupported_yield_type
// CHECK-NEXT:    scf.if
// CHECK:         spirv.Return
  %r = scf.if %c -> (memref<8xi32>) {
    scf.yield %arg0 : memref<8xi32>
  } else {
    scf.yield %arg1 : memref<8xi32>
  }
  return
}

} // end module