xref: /llvm-project/mlir/test/Conversion/SCFToSPIRV/while.mlir (revision 0c21dfdf1263bad26e09fe4232fcea01a69693ab)
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-spirv %s -o - | FileCheck %s

module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader, Int64], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
} {

// All test functions live in the single target-environment module above.
// They are therefore not separated by split markers: the RUN line does not
// pass -split-input-file, and splitting inside this module would cut it apart.

// Basic lowering of an scf.while whose "before" region forwards its own i32
// argument. The expected IR is a spirv.mlir.loop with header/body/merge
// blocks. The forwarded value is stored into a function-scope spirv.Variable
// in the header and loaded back after the loop to produce the result.
// CHECK-LABEL: @while_loop1
func.func @while_loop1(%arg0: i32, %arg1: i32) -> i32 {
  // CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
  // CHECK: %[[C2:.*]] = spirv.Constant 2 : i32
  // CHECK: %[[VAR1:.*]] = spirv.Variable : !spirv.ptr<i32, Function>
  // CHECK: spirv.mlir.loop {
  // CHECK:   spirv.Branch ^[[HEADER:.*]](%[[ARG0]] : i32)
  // CHECK: ^[[HEADER]](%[[INDVAR1:.*]]: i32):
  // CHECK:   %[[CMP:.*]] = spirv.SLessThan %[[INDVAR1]], %[[ARG1]] : i32
  // CHECK:   spirv.Store "Function" %[[VAR1]], %[[INDVAR1]] : i32
  // CHECK:   spirv.BranchConditional %[[CMP]], ^[[BODY:.*]](%[[INDVAR1]] : i32), ^[[MERGE:.*]]
  // CHECK: ^[[BODY]](%[[INDVAR2:.*]]: i32):
  // CHECK:   %[[UPDATED:.*]] = spirv.IMul %[[INDVAR2]], %[[C2]] : i32
  // CHECK:   spirv.Branch ^[[HEADER]](%[[UPDATED]] : i32)
  // CHECK: ^[[MERGE]]:
  // CHECK:   spirv.mlir.merge
  // CHECK: }
  %c2_i32 = arith.constant 2 : i32
  %0 = scf.while (%iv = %arg0) : (i32) -> (i32) {
    %cond = arith.cmpi slt, %iv, %arg1 : i32
    scf.condition(%cond) %iv : i32
  } do {
  ^bb0(%cur: i32):
    %next = arith.muli %cur, %c2_i32 : i32
    scf.yield %next : i32
  }
  // CHECK: %[[OUT:.*]] = spirv.Load "Function" %[[VAR1]] : i32
  // CHECK: spirv.ReturnValue %[[OUT]] : i32
  return %0 : i32
}

// The "before" region forwards a value (i64) whose type differs from the
// loop-carried argument (f32). The function-scope variable therefore holds
// the forwarded i64 type. Ops from the unregistered `foo` dialect pass through
// unchanged (the RUN line passes -allow-unregistered-dialect).
// CHECK-LABEL: @while_loop2
func.func @while_loop2(%arg0: f32) -> i64 {
  // CHECK-SAME: (%[[ARG:.*]]: f32)
  // CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<i64, Function>
  // CHECK: spirv.mlir.loop {
  // CHECK:   spirv.Branch ^[[HEADER:.*]](%[[ARG]] : f32)
  // CHECK: ^[[HEADER]](%[[INDVAR1:.*]]: f32):
  // CHECK:   %[[SHARED:.*]] = "foo.shared_compute"(%[[INDVAR1]]) : (f32) -> i64
  // CHECK:   %[[CMP:.*]] = "foo.evaluate_condition"(%[[INDVAR1]], %[[SHARED]]) : (f32, i64) -> i1
  // CHECK:   spirv.Store "Function" %[[VAR]], %[[SHARED]] : i64
  // CHECK:   spirv.BranchConditional %[[CMP]], ^[[BODY:.*]](%[[SHARED]] : i64), ^[[MERGE:.*]]
  // CHECK: ^[[BODY]](%[[INDVAR2:.*]]: i64):
  // CHECK:   %[[UPDATED:.*]] = "foo.payload"(%[[INDVAR2]]) : (i64) -> f32
  // CHECK:   spirv.Branch ^[[HEADER]](%[[UPDATED]] : f32)
  // CHECK: ^[[MERGE]]:
  // CHECK:   spirv.mlir.merge
  // CHECK: }
  %res = scf.while (%arg1 = %arg0) : (f32) -> i64 {
    %shared = "foo.shared_compute"(%arg1) : (f32) -> i64
    %condition = "foo.evaluate_condition"(%arg1, %shared) : (f32, i64) -> i1
    scf.condition(%condition) %shared : i64
  } do {
  ^bb0(%arg2: i64):
    %next = "foo.payload"(%arg2) : (i64) -> f32
    scf.yield %next : f32
  }
  // CHECK: %[[OUT:.*]] = spirv.Load "Function" %[[VAR]] : i64
  // CHECK: spirv.ReturnValue %[[OUT]] : i64
  return %res : i64
}

// The loop-carried argument has type `index`. The expected IR carries it
// through the header and back-edge as i32, while the forwarded i64 value keeps
// its type. Only the types on the branch edges are pinned here.
// CHECK-LABEL: @while_loop_before_typeconv
func.func @while_loop_before_typeconv(%arg0: index) -> i64 {
  // CHECK-SAME: (%[[ARG:.*]]: i32)
  // CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<i64, Function>
  // CHECK: spirv.mlir.loop {
  // CHECK:   spirv.Branch ^[[HEADER:.*]](%[[ARG]] : i32)
  // CHECK: ^[[HEADER]](%[[INDVAR1:.*]]: i32):
  // CHECK:   spirv.BranchConditional %{{.*}}, ^[[BODY:.*]](%{{.*}} : i64), ^[[MERGE:.*]]
  // CHECK: ^[[BODY]](%[[INDVAR2:.*]]: i64):
  // CHECK:   spirv.Branch ^[[HEADER]](%{{.*}} : i32)
  // CHECK: ^[[MERGE]]:
  // CHECK:   spirv.mlir.merge
  // CHECK: }
  %res = scf.while (%arg1 = %arg0) : (index) -> i64 {
    %shared = "foo.shared_compute"(%arg1) : (index) -> i64
    %condition = "foo.evaluate_condition"(%arg1, %shared) : (index, i64) -> i1
    scf.condition(%condition) %shared : i64
  } do {
  ^bb0(%arg2: i64):
    %next = "foo.payload"(%arg2) : (i64) -> index
    scf.yield %next : index
  }
  // CHECK: %[[OUT:.*]] = spirv.Load "Function" %[[VAR]] : i64
  // CHECK: spirv.ReturnValue %[[OUT]] : i64
  return %res : i64
}

// The value forwarded by the "before" region (and thus the function result)
// has type `index`. The expected IR stores, passes, loads and returns it as
// i32, while the f32 loop-carried argument keeps its type.
// CHECK-LABEL: @while_loop_after_typeconv
func.func @while_loop_after_typeconv(%arg0: f32) -> index {
  // CHECK-SAME: (%[[ARG:.*]]: f32)
  // CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<i32, Function>
  // CHECK: spirv.mlir.loop {
  // CHECK:   spirv.Branch ^[[HEADER:.*]](%[[ARG]] : f32)
  // CHECK: ^[[HEADER]](%[[INDVAR1:.*]]: f32):
  // CHECK:   spirv.BranchConditional %{{.*}}, ^[[BODY:.*]](%{{.*}} : i32), ^[[MERGE:.*]]
  // CHECK: ^[[BODY]](%[[INDVAR2:.*]]: i32):
  // CHECK:   spirv.Branch ^[[HEADER]](%{{.*}} : f32)
  // CHECK: ^[[MERGE]]:
  // CHECK:   spirv.mlir.merge
  // CHECK: }
  %res = scf.while (%arg1 = %arg0) : (f32) -> index {
    %shared = "foo.shared_compute"(%arg1) : (f32) -> index
    %condition = "foo.evaluate_condition"(%arg1, %shared) : (f32, index) -> i1
    scf.condition(%condition) %shared : index
  } do {
  ^bb0(%arg2: index):
    %next = "foo.payload"(%arg2) : (index) -> f32
    scf.yield %next : f32
  }
  // CHECK: %[[OUT:.*]] = spirv.Load "Function" %[[VAR]] : i32
  // CHECK: spirv.ReturnValue %[[OUT]] : i32
  return %res : index
}

} // end module
