// xref: /llvm-project/mlir/test/Dialect/SCF/ops.mlir (revision 10056c821a56a19cef732129e4e0c5883ae1ee49)
// Verify the custom form parses and prints as expected.
// RUN: mlir-opt %s | FileCheck %s
// Verify that the printed output can be parsed again.
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// Verify that the generic form can be parsed.
// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s

// Nested index-typed scf.for loops; the innermost loop's bounds are values
// computed inside the enclosing loop body (a min/max of the two outer IVs).
func.func @std_for(%arg0 : index, %arg1 : index, %arg2 : index) {
  scf.for %i0 = %arg0 to %arg1 step %arg2 {
    scf.for %i1 = %arg0 to %arg1 step %arg2 {
      %min_cmp = arith.cmpi slt, %i0, %i1 : index
      %min = arith.select %min_cmp, %i0, %i1 : index
      %max_cmp = arith.cmpi sge, %i0, %i1 : index
      %max = arith.select %max_cmp, %i0, %i1 : index
      scf.for %i2 = %min to %max step %i1 {
      }
    }
  }
  return
}
// CHECK-LABEL: func @std_for(
//  CHECK-NEXT:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
//  CHECK-NEXT:     scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
//  CHECK-NEXT:       %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : index
//  CHECK-NEXT:       %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : index
//  CHECK-NEXT:       %{{.*}} = arith.cmpi sge, %{{.*}}, %{{.*}} : index
//  CHECK-NEXT:       %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : index
//  CHECK-NEXT:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {

// scf.for over i32 bounds: the printed form carries an explicit `: i32`
// type suffix after the step operand.
func.func @std_for_i32(%arg0 : i32, %arg1 : i32, %arg2 : i32) {
  scf.for %i0 = %arg0 to %arg1 step %arg2 : i32 {
    scf.for %i1 = %arg0 to %arg1 step %arg2 : i32 {
    }
  }
  return
}
// CHECK-LABEL: func @std_for_i32(
//  CHECK-NEXT:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} : i32 {
//  CHECK-NEXT:     scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} : i32 {

// i64-typed induction variable combined with an i64 loop-carried value
// (iter_args); the loop body accumulates the IV into the carried value.
func.func @scf_for_i64_iter(%arg1: i64, %arg2: i64) {
  %c1_i64 = arith.constant 1 : i64
  %c0_i64 = arith.constant 0 : i64
  %0 = scf.for %arg3 = %arg1 to %arg2 step %c1_i64 iter_args(%arg4 = %c0_i64) -> (i64) : i64 {
    %1 = arith.addi %arg4, %arg3 : i64
    scf.yield %1 : i64
  }
  return
}
// CHECK-LABEL: scf_for_i64_iter
// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (i64) : i64 {

// scf.if with only a then-region and no results.
func.func @std_if(%arg0: i1, %arg1: f32) {
  scf.if %arg0 {
    %0 = arith.addf %arg1, %arg1 : f32
  }
  return
}
// CHECK-LABEL: func @std_if(
//  CHECK-NEXT:   scf.if %{{.*}} {
//  CHECK-NEXT:     %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f32

// scf.if with both a then-region and an else-region and no results.
func.func @std_if_else(%arg0: i1, %arg1: f32) {
  scf.if %arg0 {
    %0 = arith.addf %arg1, %arg1 : f32
  } else {
    %1 = arith.addf %arg1, %arg1 : f32
  }
  return
}
// CHECK-LABEL: func @std_if_else(
//  CHECK-NEXT:   scf.if %{{.*}} {
//  CHECK-NEXT:     %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f32
//  CHECK-NEXT:   } else {
//  CHECK-NEXT:     %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f32

// 2-D scf.parallel whose body holds a nested 1-D scf.parallel with two init
// values (f32 and i32). Each init value is combined by its own scf.reduce
// region: arith.addf for the f32 value and arith.muli for the i32 value.
func.func @std_parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
                        %arg3 : index, %arg4 : index) {
  %step = arith.constant 1 : index
  scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
                                          step (%arg4, %step) {
    %min_cmp = arith.cmpi slt, %i0, %i1 : index
    %min = arith.select %min_cmp, %i0, %i1 : index
    %max_cmp = arith.cmpi sge, %i0, %i1 : index
    %max = arith.select %max_cmp, %i0, %i1 : index
    %zero = arith.constant 0.0 : f32
    %int_zero = arith.constant 0 : i32
    %red:2 = scf.parallel (%i2) = (%min) to (%max) step (%i1)
                                      init (%zero, %int_zero) -> (f32, i32) {
      %one = arith.constant 1.0 : f32
      %int_one = arith.constant 1 : i32
      scf.reduce(%one, %int_one : f32, i32) {
        ^bb0(%lhs : f32, %rhs: f32):
          %res = arith.addf %lhs, %rhs : f32
          scf.reduce.return %res : f32
      }, {
        ^bb0(%lhs : i32, %rhs: i32):
          %res = arith.muli %lhs, %rhs : i32
          scf.reduce.return %res : i32
      }
    }
    scf.reduce
  }
  return
}
// CHECK-LABEL: func @std_parallel_loop(
//  CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]:
//  CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]:
//  CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]:
//  CHECK-SAME: %[[ARG3:[A-Za-z0-9]+]]:
//  CHECK-SAME: %[[ARG4:[A-Za-z0-9]+]]:
//       CHECK:   %[[STEP:.*]] = arith.constant 1 : index
//  CHECK-NEXT:   scf.parallel (%[[I0:.*]], %[[I1:.*]]) = (%[[ARG0]], %[[ARG1]]) to
//  CHECK-SAME:   (%[[ARG2]], %[[ARG3]]) step (%[[ARG4]], %[[STEP]]) {
//  CHECK-NEXT:     %[[MIN_CMP:.*]] = arith.cmpi slt, %[[I0]], %[[I1]] : index
//  CHECK-NEXT:     %[[MIN:.*]] = arith.select %[[MIN_CMP]], %[[I0]], %[[I1]] : index
//  CHECK-NEXT:     %[[MAX_CMP:.*]] = arith.cmpi sge, %[[I0]], %[[I1]] : index
//  CHECK-NEXT:     %[[MAX:.*]] = arith.select %[[MAX_CMP]], %[[I0]], %[[I1]] : index
//  CHECK-NEXT:     %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
//  CHECK-NEXT:     %[[INT_ZERO:.*]] = arith.constant 0 : i32
//  CHECK-NEXT:     scf.parallel (%{{.*}}) = (%[[MIN]]) to (%[[MAX]])
//  CHECK-SAME:          step (%[[I1]])
//  CHECK-SAME:          init (%[[ZERO]], %[[INT_ZERO]]) -> (f32, i32) {
//  CHECK-NEXT:       %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
//  CHECK-NEXT:       %[[INT_ONE:.*]] = arith.constant 1 : i32
//  CHECK-NEXT:       scf.reduce(%[[ONE]], %[[INT_ONE]] : f32, i32) {
//  CHECK-NEXT:       ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
//  CHECK-NEXT:         %[[RES:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
//  CHECK-NEXT:         scf.reduce.return %[[RES]] : f32
//  CHECK-NEXT:       }, {
//  CHECK-NEXT:       ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
//  CHECK-NEXT:         %[[RES:.*]] = arith.muli %[[LHS]], %[[RHS]] : i32
//  CHECK-NEXT:         scf.reduce.return %[[RES]] : i32
//  CHECK-NEXT:       }
//  CHECK-NEXT:     }
//  CHECK-NEXT:     scf.reduce

// scf.parallel whose body holds only an explicit, operand-less scf.reduce
// terminator; the terminator is expected to be printed back verbatim.
func.func @parallel_explicit_yield(
    %arg0: index, %arg1: index, %arg2: index) {
  scf.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
    scf.reduce
  }
  return
}

// CHECK-LABEL: func @parallel_explicit_yield(
//  CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]:
//  CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]:
//  CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]:
//  CHECK-NEXT: scf.parallel (%{{.*}}) = (%[[ARG0]]) to (%[[ARG1]]) step (%[[ARG2]])
//  CHECK-NEXT: scf.reduce
//  CHECK-NEXT: }
//  CHECK-NEXT: return
//  CHECK-NEXT: }

// scf.if producing two f32 results; each branch yields its own pair of
// values, and the two results are printed as a single `:2` result group.
func.func @std_if_yield(%arg0: i1, %arg1: f32) {
  %x, %y = scf.if %arg0 -> (f32, f32) {
    %0 = arith.addf %arg1, %arg1 : f32
    %1 = arith.subf %arg1, %arg1 : f32
    scf.yield %0, %1 : f32, f32
  } else {
    %0 = arith.subf %arg1, %arg1 : f32
    %1 = arith.addf %arg1, %arg1 : f32
    scf.yield %0, %1 : f32, f32
  }
  return
}
// CHECK-LABEL: func @std_if_yield(
//  CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]:
//  CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]:
//  CHECK-NEXT: %{{.*}}:2 = scf.if %[[ARG0]] -> (f32, f32) {
//  CHECK-NEXT: %[[T1:.*]] = arith.addf %[[ARG1]], %[[ARG1]]
//  CHECK-NEXT: %[[T2:.*]] = arith.subf %[[ARG1]], %[[ARG1]]
//  CHECK-NEXT: scf.yield %[[T1]], %[[T2]] : f32, f32
//  CHECK-NEXT: } else {
//  CHECK-NEXT: %[[T3:.*]] = arith.subf %[[ARG1]], %[[ARG1]]
//  CHECK-NEXT: %[[T4:.*]] = arith.addf %[[ARG1]], %[[ARG1]]
//  CHECK-NEXT: scf.yield %[[T3]], %[[T4]] : f32, f32
//  CHECK-NEXT: }

// scf.for with a single f32 loop-carried value that the body doubles on
// each iteration.
func.func @std_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) {
  %s0 = arith.constant 0.0 : f32
  %result = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0) -> (f32) {
    %sn = arith.addf %si, %si : f32
    scf.yield %sn : f32
  }
  return
}
// CHECK-LABEL: func @std_for_yield(
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]:
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]:
// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]:
// CHECK-NEXT: %[[INIT:.*]] = arith.constant
// CHECK-NEXT: %{{.*}} = scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
// CHECK-SAME: iter_args(%[[ITER:.*]] = %[[INIT]]) -> (f32) {
// CHECK-NEXT: %[[NEXT:.*]] = arith.addf %[[ITER]], %[[ITER]] : f32
// CHECK-NEXT: scf.yield %[[NEXT]] : f32
// CHECK-NEXT: }


// scf.for with three loop-carried values of mixed types (f32, i32, f32);
// each carried value is updated independently and yielded in order.
func.func @std_for_yield_multi(%arg0 : index, %arg1 : index, %arg2 : index) {
  %s0 = arith.constant 0.0 : f32
  %t0 = arith.constant 1 : i32
  %u0 = arith.constant 1.0 : f32
  %result1:3 = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %ti = %t0, %ui = %u0) -> (f32, i32, f32) {
    %sn = arith.addf %si, %si : f32
    %tn = arith.addi %ti, %ti : i32
    %un = arith.subf %ui, %ui : f32
    scf.yield %sn, %tn, %un : f32, i32, f32
  }
  return
}
// CHECK-LABEL: func @std_for_yield_multi(
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]:
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]:
// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]:
// CHECK-NEXT: %[[INIT1:.*]] = arith.constant
// CHECK-NEXT: %[[INIT2:.*]] = arith.constant
// CHECK-NEXT: %[[INIT3:.*]] = arith.constant
// CHECK-NEXT: %{{.*}}:3 = scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
// CHECK-SAME: iter_args(%[[ITER1:.*]] = %[[INIT1]], %[[ITER2:.*]] = %[[INIT2]], %[[ITER3:.*]] = %[[INIT3]]) -> (f32, i32, f32) {
// CHECK-NEXT: %[[NEXT1:.*]] = arith.addf %[[ITER1]], %[[ITER1]] : f32
// CHECK-NEXT: %[[NEXT2:.*]] = arith.addi %[[ITER2]], %[[ITER2]] : i32
// CHECK-NEXT: %[[NEXT3:.*]] = arith.subf %[[ITER3]], %[[ITER3]] : f32
// CHECK-NEXT: scf.yield %[[NEXT1]], %[[NEXT2]], %[[NEXT3]] : f32, i32, f32


// Sum reduction over a memref in which each iteration updates the running
// sum only when the loaded element compares `ugt` against 0.0. The nested
// scf.if yields the new or unchanged sum back into the enclosing scf.for.
func.func @conditional_reduce(%buffer: memref<1024xf32>, %lb: index, %ub: index, %step: index) -> (f32) {
  %sum_0 = arith.constant 0.0 : f32
  %c0 = arith.constant 0.0 : f32
  %sum = scf.for %iv = %lb to %ub step %step iter_args(%sum_iter = %sum_0) -> (f32) {
    %t = memref.load %buffer[%iv] : memref<1024xf32>
    %cond = arith.cmpf ugt, %t, %c0 : f32
    %sum_next = scf.if %cond -> (f32) {
      %new_sum = arith.addf %sum_iter, %t : f32
      scf.yield %new_sum : f32
    } else {
      scf.yield %sum_iter : f32
    }
    scf.yield %sum_next : f32
  }
  return %sum : f32
}
// CHECK-LABEL: func @conditional_reduce(
//  CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]
//  CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]
//  CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]
//  CHECK-SAME: %[[ARG3:[A-Za-z0-9]+]]
//  CHECK-NEXT: %[[INIT:.*]] = arith.constant
//  CHECK-NEXT: %[[ZERO:.*]] = arith.constant
//  CHECK-NEXT: %[[RESULT:.*]] = scf.for %[[IV:.*]] = %[[ARG1]] to %[[ARG2]] step %[[ARG3]]
//  CHECK-SAME: iter_args(%[[ITER:.*]] = %[[INIT]]) -> (f32) {
//  CHECK-NEXT: %[[T:.*]] = memref.load %[[ARG0]][%[[IV]]]
//  CHECK-NEXT: %[[COND:.*]] = arith.cmpf ugt, %[[T]], %[[ZERO]]
//  CHECK-NEXT: %[[IFRES:.*]] = scf.if %[[COND]] -> (f32) {
//  CHECK-NEXT: %[[THENRES:.*]] = arith.addf %[[ITER]], %[[T]]
//  CHECK-NEXT: scf.yield %[[THENRES]] : f32
//  CHECK-NEXT: } else {
//  CHECK-NEXT: scf.yield %[[ITER]] : f32
//  CHECK-NEXT: }
//  CHECK-NEXT: scf.yield %[[IFRES]] : f32
//  CHECK-NEXT: }
//  CHECK-NEXT: return %[[RESULT]]

// scf.while whose before-region forwards (i64, f64) values through
// scf.condition and whose after-region yields values of the loop-carried
// types (i32, f32); the trailing attribute dictionary is printed back too.
// CHECK-LABEL: @while
func.func @while() {
  %0 = "test.get_some_value"() : () -> i32
  %1 = "test.get_some_value"() : () -> f32

  // CHECK: = scf.while (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) : (i32, f32) -> (i64, f64) {
  %2:2 = scf.while (%arg0 = %0, %arg1 = %1) : (i32, f32) -> (i64, f64) {
    %3:2 = "test.some_operation"(%arg0, %arg1) : (i32, f32) -> (i64, f64)
    %4 = "test.some_condition"(%arg0, %arg1) : (i32, f32) -> i1
    // CHECK: scf.condition(%{{.*}}) %{{.*}}, %{{.*}} : i64, f64
    scf.condition(%4) %3#0, %3#1 : i64, f64
  // CHECK: } do {
  } do {
  // CHECK: ^{{.*}}(%{{.*}}: i64, %{{.*}}: f64):
  ^bb0(%arg2: i64, %arg3: f64):
    %5:2 = "test.some_operation"(%arg2, %arg3): (i64, f64) -> (i32, f32)
    // CHECK: scf.yield %{{.*}}, %{{.*}} : i32, f32
    scf.yield %5#0, %5#1 : i32, f32
  // CHECK: attributes {foo = "bar"}
  } attributes {foo="bar"}
  return
}

// scf.while with no loop-carried values and a constant-true condition.
// CHECK-LABEL: @infinite_while
func.func @infinite_while() {
  %true = arith.constant true

  // CHECK: scf.while  : () -> () {
  scf.while : () -> () {
    // CHECK: scf.condition(%{{.*}})
    scf.condition(%true)
  // CHECK: } do {
  } do {
    // CHECK: scf.yield
    scf.yield
  }
  return
}

// scf.execute_region in three shapes: one result, two results, and no
// result with a multi-block body (the last one written in generic form).
// CHECK-LABEL: func @execute_region
func.func @execute_region() -> i64 {
  // CHECK:      scf.execute_region -> i64 {
  // CHECK-NEXT:   arith.constant
  // CHECK-NEXT:   scf.yield
  // CHECK-NEXT: }
  %res = scf.execute_region -> i64 {
    %c1 = arith.constant 1 : i64
    scf.yield %c1 : i64
  }

  // CHECK:      scf.execute_region -> (i64, i64) {
  %res2:2 = scf.execute_region -> (i64, i64) {
    %c1 = arith.constant 1 : i64
    scf.yield %c1, %c1 : i64, i64
  }

  // CHECK:       scf.execute_region {
  // CHECK-NEXT:    cf.br ^bb1
  // CHECK-NEXT:  ^bb1:
  // CHECK-NEXT:    scf.yield
  // CHECK-NEXT:  }
  "scf.execute_region"() ({
  ^bb0:
    cf.br ^bb1
  ^bb1:
    scf.yield
  }) : () -> ()
  return %res : i64
}

// scf.forall in normalized form (`in (%ub)`) with a shared tensor output;
// each thread extracts a 1-element slice and writes it back through
// tensor.parallel_insert_slice inside scf.forall.in_parallel.
// CHECK-LABEL: func.func @normalized_forall
func.func @normalized_forall(%in: tensor<100xf32>, %out: tensor<100xf32>) {
  %c1 = arith.constant 1 : index
  %num_threads = arith.constant 100 : index

  //      CHECK:    scf.forall
  // CHECK-NEXT:  tensor.extract_slice
  // CHECK-NEXT:  scf.forall.in_parallel
  // CHECK-NEXT:  tensor.parallel_insert_slice
  // CHECK-NEXT:  }
  // CHECK-NEXT:  }
  // CHECK-NEXT:  return
  %result = scf.forall (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> tensor<100xf32> {
      %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
          tensor<1xf32> into tensor<100xf32>
      }
  }
  return
}

// scf.forall written with explicit lower bound, upper bound and step
// (`= (%lb) to (%ub) step (%s)`) and a shared tensor output updated via
// tensor.parallel_insert_slice inside scf.forall.in_parallel.
// CHECK-LABEL: func.func @explicit_loop_bounds_forall
func.func @explicit_loop_bounds_forall(%in: tensor<100xf32>,
    %out: tensor<100xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %num_threads = arith.constant 100 : index

  //      CHECK:    scf.forall
  // CHECK-NEXT:  tensor.extract_slice
  // CHECK-NEXT:  scf.forall.in_parallel
  // CHECK-NEXT:  tensor.parallel_insert_slice
  // CHECK-NEXT:  }
  // CHECK-NEXT:  }
  // CHECK-NEXT:  return
  %result = scf.forall (%thread_idx) = (%c0) to (%num_threads) step (%c1) shared_outs(%o = %out) -> tensor<100xf32> {
      %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
          tensor<1xf32> into tensor<100xf32>
      }
  }
  return
}

// Normalized scf.forall with no results and an empty scf.forall.in_parallel
// terminator; the printed form is expected to omit the terminator and
// keep the `mapping` attribute.
// CHECK-LABEL: func.func @normalized_forall_elide_terminator
func.func @normalized_forall_elide_terminator() -> () {
  %num_threads = arith.constant 100 : index

  //      CHECK:    scf.forall
  // CHECK-NEXT:  } {mapping = [#gpu.thread<x>]}
  // CHECK-NEXT:  return
  scf.forall (%thread_idx) in (%num_threads) {
    scf.forall.in_parallel {
    }
  } {mapping = [#gpu.thread<x>]}
  return
}

// scf.forall with explicit bounds/step, no results, and an empty
// scf.forall.in_parallel terminator; the printed form is expected to omit
// the terminator and keep the `mapping` attribute.
// CHECK-LABEL: func.func @explicit_loop_bounds_forall_elide_terminator
func.func @explicit_loop_bounds_forall_elide_terminator() -> () {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %num_threads = arith.constant 100 : index

  //      CHECK:    scf.forall
  // CHECK-NEXT:  } {mapping = [#gpu.thread<x>]}
  // CHECK-NEXT:  return
  scf.forall (%thread_idx) = (%c0) to (%num_threads) step (%c1) {
    scf.forall.in_parallel {
    }
  } {mapping = [#gpu.thread<x>]}
  return
}

// scf.index_switch in two shapes: with an i32 result (cases 2 and 5 plus a
// default region), and with no results (default region only).
// CHECK-LABEL: @switch
func.func @switch(%arg0: index) -> i32 {
  // CHECK: %{{.*}} = scf.index_switch %arg0 -> i32
  %0 = scf.index_switch %arg0 -> i32
  // CHECK-NEXT: case 2 {
  case 2 {
    // CHECK-NEXT: arith.constant
    %c10_i32 = arith.constant 10 : i32
    // CHECK-NEXT: scf.yield %{{.*}} : i32
    scf.yield %c10_i32 : i32
    // CHECK-NEXT: }
  }
  // CHECK-NEXT: case 5 {
  case 5 {
    %c20_i32 = arith.constant 20 : i32
    scf.yield %c20_i32 : i32
  }
  // CHECK: default {
  default {
    %c30_i32 = arith.constant 30 : i32
    scf.yield %c30_i32 : i32
  }

  // CHECK: scf.index_switch %arg0
  scf.index_switch %arg0
  // CHECK-NEXT: default {
  default {
    scf.yield
  }

  return %0 : i32
}
