// Source: llvm-project mlir/test/Dialect/SCF/canonicalize.mlir (revision 9d8e634e85ca46fbec07733d3e69d34c0d7814ac)
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' -split-input-file | FileCheck %s
// Every test case below is canonicalized in isolation (-split-input-file) and
// the printed result is verified with FileCheck.

// A 3-D scf.parallel whose first (0 to 1 step 1) and third (7 to 10 step 3)
// dimensions each run exactly once: those dimensions are dropped, their lower
// bounds are substituted into the body, and only the middle dimension remains.
func.func @single_iteration_some(%A: memref<?x?x?xi32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c6 = arith.constant 6 : index
  %c7 = arith.constant 7 : index
  %c10 = arith.constant 10 : index
  scf.parallel (%i0, %i1, %i2) = (%c0, %c3, %c7) to (%c1, %c6, %c10) step (%c1, %c2, %c3) {
    %c42 = arith.constant 42 : i32
    memref.store %c42, %A[%i0, %i1, %i2] : memref<?x?x?xi32>
    scf.reduce
  }
  return
}

// CHECK-LABEL:   func @single_iteration_some(
// CHECK-SAME:                        [[ARG0:%.*]]: memref<?x?x?xi32>) {
// CHECK-DAG:           [[C42:%.*]] = arith.constant 42 : i32
// CHECK-DAG:           [[C7:%.*]] = arith.constant 7 : index
// CHECK-DAG:           [[C6:%.*]] = arith.constant 6 : index
// CHECK-DAG:           [[C3:%.*]] = arith.constant 3 : index
// CHECK-DAG:           [[C2:%.*]] = arith.constant 2 : index
// CHECK-DAG:           [[C0:%.*]] = arith.constant 0 : index
// CHECK:           scf.parallel ([[V0:%.*]]) = ([[C3]]) to ([[C6]]) step ([[C2]]) {
// CHECK:             memref.store [[C42]], [[ARG0]]{{\[}}[[C0]], [[V0]], [[C7]]] : memref<?x?x?xi32>
// CHECK:             scf.reduce
// CHECK:           }
// CHECK:           return

// -----

// Every dimension of the scf.parallel runs exactly once, so the loop is removed
// and its body is inlined with the lower bounds (0, 3, 7) as indices.
func.func @single_iteration_all(%A: memref<?x?x?xi32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c3 = arith.constant 3 : index
  %c6 = arith.constant 6 : index
  %c7 = arith.constant 7 : index
  %c10 = arith.constant 10 : index
  scf.parallel (%i0, %i1, %i2) = (%c0, %c3, %c7) to (%c1, %c6, %c10) step (%c1, %c3, %c3) {
    %c42 = arith.constant 42 : i32
    memref.store %c42, %A[%i0, %i1, %i2] : memref<?x?x?xi32>
    scf.reduce
  }
  return
}

// CHECK-LABEL:   func @single_iteration_all(
// CHECK-SAME:                        [[ARG0:%.*]]: memref<?x?x?xi32>) {
// CHECK-DAG:           [[C42:%.*]] = arith.constant 42 : i32
// CHECK-DAG:           [[C7:%.*]] = arith.constant 7 : index
// CHECK-DAG:           [[C3:%.*]] = arith.constant 3 : index
// CHECK-DAG:           [[C0:%.*]] = arith.constant 0 : index
// CHECK-NOT:           scf.parallel
// CHECK:               memref.store [[C42]], [[ARG0]]{{\[}}[[C0]], [[C3]], [[C7]]] : memref<?x?x?xi32>
// CHECK-NOT:           scf.reduce
// CHECK:               return

// -----

// A single-iteration scf.parallel with two reductions folds to the reduction
// bodies applied directly to the init values: addi(%A, 1) and muli(%B, 3).
func.func @single_iteration_reduce(%A: index, %B: index) -> (index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c6 = arith.constant 6 : index
  %0:2 = scf.parallel (%i0, %i1) = (%c1, %c3) to (%c2, %c6) step (%c1, %c3) init(%A, %B) -> (index, index) {
    scf.reduce(%i0, %i1 : index, index)  {
    ^bb0(%lhs: index, %rhs: index):
      %1 = arith.addi %lhs, %rhs : index
      scf.reduce.return %1 : index
    }, {
    ^bb0(%lhs: index, %rhs: index):
      %2 = arith.muli %lhs, %rhs : index
      scf.reduce.return %2 : index
    }
  }
  return %0#0, %0#1 : index, index
}

// CHECK-LABEL:   func @single_iteration_reduce(
// CHECK-SAME:                        [[ARG0:%.*]]: index, [[ARG1:%.*]]: index)
// CHECK-DAG:           [[C3:%.*]] = arith.constant 3 : index
// CHECK-DAG:           [[C1:%.*]] = arith.constant 1 : index
// CHECK-NOT:           scf.parallel
// CHECK-NOT:           scf.reduce
// CHECK-NOT:           scf.reduce.return
// CHECK-NOT:           scf.yield
// CHECK:               [[V0:%.*]] = arith.addi [[ARG0]], [[C1]]
// CHECK:               [[V1:%.*]] = arith.muli [[ARG1]], [[C3]]
// CHECK:               return [[V0]], [[V1]]

// -----

// Three perfectly nested 1-D scf.parallel loops are collapsed into a single
// 3-D scf.parallel.
func.func @nested_parallel(%0: memref<?x?x?xf64>) -> memref<?x?x?xf64> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %1 = memref.dim %0, %c0 : memref<?x?x?xf64>
  %2 = memref.dim %0, %c1 : memref<?x?x?xf64>
  %3 = memref.dim %0, %c2 : memref<?x?x?xf64>
  %4 = memref.alloc(%1, %2, %3) : memref<?x?x?xf64>
  scf.parallel (%arg1) = (%c0) to (%1) step (%c1) {
    scf.parallel (%arg2) = (%c0) to (%2) step (%c1) {
      scf.parallel (%arg3) = (%c0) to (%3) step (%c1) {
        %5 = memref.load %0[%arg1, %arg2, %arg3] : memref<?x?x?xf64>
        memref.store %5, %4[%arg1, %arg2, %arg3] : memref<?x?x?xf64>
        scf.reduce
      }
      scf.reduce
    }
    scf.reduce
  }
  return %4 : memref<?x?x?xf64>
}

// CHECK-LABEL:   func @nested_parallel(
// CHECK-DAG:       [[C0:%.*]] = arith.constant 0 : index
// CHECK-DAG:       [[C1:%.*]] = arith.constant 1 : index
// CHECK-DAG:       [[C2:%.*]] = arith.constant 2 : index
// CHECK:           [[B0:%.*]] = memref.dim {{.*}}, [[C0]]
// CHECK:           [[B1:%.*]] = memref.dim {{.*}}, [[C1]]
// CHECK:           [[B2:%.*]] = memref.dim {{.*}}, [[C2]]
// CHECK:           scf.parallel ([[V0:%.*]], [[V1:%.*]], [[V2:%.*]]) = ([[C0]], [[C0]], [[C0]]) to ([[B0]], [[B1]], [[B2]]) step ([[C1]], [[C1]], [[C1]])
// CHECK:           memref.load {{.*}}{{\[}}[[V0]], [[V1]], [[V2]]]
// CHECK:           memref.store {{.*}}{{\[}}[[V0]], [[V1]], [[V2]]]

// -----

// Only the second result of the scf.if is used: the first result is dropped,
// each branch yields only the used value, and the call to @side_effect stays.
func.func private @side_effect()
func.func @one_unused(%cond: i1) -> (index) {
  %0, %1 = scf.if %cond -> (index, index) {
    func.call @side_effect() : () -> ()
    %c0 = "test.value0"() : () -> (index)
    %c1 = "test.value1"() : () -> (index)
    scf.yield %c0, %c1 : index, index
  } else {
    %c2 = "test.value2"() : () -> (index)
    %c3 = "test.value3"() : () -> (index)
    scf.yield %c2, %c3 : index, index
  }
  return %1 : index
}

// CHECK-LABEL:   func @one_unused
// CHECK:           [[V0:%.*]] = scf.if %{{.*}} -> (index) {
// CHECK:             call @side_effect() : () -> ()
// CHECK:             [[C1:%.*]] = "test.value1"
// CHECK:             scf.yield [[C1]] : index
// CHECK:           } else
// CHECK:             [[C3:%.*]] = "test.value3"
// CHECK:             scf.yield [[C3]] : index
// CHECK:           }
// CHECK:           return [[V0]] : index

// -----

// The unused first result is dropped both from the outer scf.if and from the
// nested scf.if whose results the outer then-branch forwards.
func.func private @side_effect()
func.func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
  %0, %1 = scf.if %cond1 -> (index, index) {
    %2, %3 = scf.if %cond2 -> (index, index) {
      func.call @side_effect() : () -> ()
      %c0 = "test.value0"() : () -> (index)
      %c1 = "test.value1"() : () -> (index)
      scf.yield %c0, %c1 : index, index
    } else {
      %c2 = "test.value2"() : () -> (index)
      %c3 = "test.value3"() : () -> (index)
      scf.yield %c2, %c3 : index, index
    }
    scf.yield %2, %3 : index, index
  } else {
    %c0 = "test.value0_2"() : () -> (index)
    %c1 = "test.value1_2"() : () -> (index)
    scf.yield %c0, %c1 : index, index
  }
  return %1 : index
}

// CHECK-LABEL:   func @nested_unused
// CHECK:           [[V0:%.*]] = scf.if {{.*}} -> (index) {
// CHECK:             [[V1:%.*]] = scf.if {{.*}} -> (index) {
// CHECK:               call @side_effect() : () -> ()
// CHECK:               [[C1:%.*]] = "test.value1"
// CHECK:               scf.yield [[C1]] : index
// CHECK:             } else
// CHECK:               [[C3:%.*]] = "test.value3"
// CHECK:               scf.yield [[C3]] : index
// CHECK:             }
// CHECK:             scf.yield [[V1]] : index
// CHECK:           } else
// CHECK:             [[C1_2:%.*]] = "test.value1_2"
// CHECK:             scf.yield [[C1_2]] : index
// CHECK:           }
// CHECK:           return [[V0]] : index

// -----

// No result of the scf.if is used: the results are removed, but the
// side-effecting calls in both branches are kept.
func.func private @side_effect()
func.func @all_unused(%cond: i1) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0, %1 = scf.if %cond -> (index, index) {
    func.call @side_effect() : () -> ()
    scf.yield %c0, %c1 : index, index
  } else {
    func.call @side_effect() : () -> ()
    scf.yield %c0, %c1 : index, index
  }
  return
}

// CHECK-LABEL:   func @all_unused
// CHECK:           scf.if %{{.*}} {
// CHECK:             call @side_effect() : () -> ()
// CHECK:           } else
// CHECK:             call @side_effect() : () -> ()
// CHECK:           }
// CHECK:           return

// -----

// An scf.if with an empty then-region and no else-region is erased.
func.func @empty_if1(%cond: i1) {
  scf.if %cond {
    scf.yield
  }
  return
}

// CHECK-LABEL:   func @empty_if1
// CHECK-NOT:       scf.if
// CHECK:           return

// -----

// An scf.if whose then- and else-regions are both empty is erased.
func.func @empty_if2(%cond: i1) {
  scf.if %cond {
    scf.yield
  } else {
    scf.yield
  }
  return
}

// CHECK-LABEL:   func @empty_if2
// CHECK-NOT:       scf.if
// CHECK:           return

// -----

// An empty else-region is removed while the non-empty then-region is kept.
func.func @empty_else(%cond: i1, %v : memref<i1>) {
  scf.if %cond {
    memref.store %cond, %v[] : memref<i1>
  } else {
  }
  return
}

// CHECK-LABEL: func @empty_else
// CHECK:         scf.if
// CHECK-NOT:     else

// -----

// An scf.if whose regions only yield values is rewritten to arith.select.
func.func @to_select1(%cond: i1) -> index {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = scf.if %cond -> index {
    scf.yield %c0 : index
  } else {
    scf.yield %c1 : index
  }
  return %0 : index
}

// CHECK-LABEL:   func @to_select1
// CHECK-DAG:       [[C0:%.*]] = arith.constant 0 : index
// CHECK-DAG:       [[C1:%.*]] = arith.constant 1 : index
// CHECK:           [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
// CHECK:           return [[V0]] : index

// -----

// A result that yields the same value in both branches is forwarded directly;
// only the differing result becomes an arith.select.
func.func @to_select_same_val(%cond: i1) -> (index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0, %1 = scf.if %cond -> (index, index) {
    scf.yield %c0, %c1 : index, index
  } else {
    scf.yield %c1, %c1 : index, index
  }
  return %0, %1 : index, index
}

// CHECK-LABEL:   func @to_select_same_val
// CHECK-DAG:       [[C0:%.*]] = arith.constant 0 : index
// CHECK-DAG:       [[C1:%.*]] = arith.constant 1 : index
// CHECK:           [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
// CHECK:           return [[V0]], [[C1]] : index, index

// -----

// The yielded values become an arith.select even when the then-region holds
// other ops; those ops remain in a result-less scf.if.
func.func @to_select_with_body(%cond: i1) -> index {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = scf.if %cond -> index {
    "test.op"() : () -> ()
    scf.yield %c0 : index
  } else {
    scf.yield %c1 : index
  }
  return %0 : index
}

// CHECK-LABEL:   func @to_select_with_body
// CHECK-DAG:       [[C0:%.*]] = arith.constant 0 : index
// CHECK-DAG:       [[C1:%.*]] = arith.constant 1 : index
// CHECK:           [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
// CHECK:           scf.if {{.*}} {
// CHECK:             "test.op"() : () -> ()
// CHECK:           }
// CHECK:           return [[V0]] : index

// -----

// Each result of a two-result scf.if becomes its own arith.select. The return
// is matched with its full type list (two results).
func.func @to_select2(%cond: i1) -> (index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %0, %1 = scf.if %cond -> (index, index) {
    scf.yield %c0, %c1 : index, index
  } else {
    scf.yield %c2, %c3 : index, index
  }
  return %0, %1 : index, index
}

// CHECK-LABEL:   func @to_select2
// CHECK-DAG:       [[C0:%.*]] = arith.constant 0 : index
// CHECK-DAG:       [[C1:%.*]] = arith.constant 1 : index
// CHECK-DAG:       [[C2:%.*]] = arith.constant 2 : index
// CHECK-DAG:       [[C3:%.*]] = arith.constant 3 : index
// CHECK:           [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C2]]
// CHECK:           [[V1:%.*]] = arith.select {{.*}}, [[C1]], [[C3]]
// CHECK:           return [[V0]], [[V1]] : index, index

// -----

// An scf.for whose body only yields its iter_arg unchanged is folded away; the
// loop result is replaced by the init value.
func.func private @make_i32() -> i32

func.func @for_yields_2(%lb : index, %ub : index, %step : index) -> i32 {
  %a = call @make_i32() : () -> (i32)
  %b = scf.for %i = %lb to %ub step %step iter_args(%0 = %a) -> i32 {
    scf.yield %0 : i32
  }
  return %b : i32
}

// CHECK-LABEL:   func @for_yields_2
//  CHECK-NEXT:     %[[R:.*]] = call @make_i32() : () -> i32
//  CHECK-NEXT:     return %[[R]] : i32

// -----

// Iter_args that are yielded unchanged are removed from the loop and their
// results replaced by the init values; the remaining iter_arg and the loop's
// discardable attribute are preserved.
func.func private @make_i32() -> i32

func.func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) {
  %a = call @make_i32() : () -> (i32)
  %b = call @make_i32() : () -> (i32)
  %r:3 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %a, %2 = %b) -> (i32, i32, i32) {
    %c = func.call @make_i32() : () -> (i32)
    scf.yield %0, %c, %2 : i32, i32, i32
  } {some_attr}
  return %r#0, %r#1, %r#2 : i32, i32, i32
}

// CHECK-LABEL:   func @for_yields_3
//  CHECK-NEXT:     %[[a:.*]] = call @make_i32() : () -> i32
//  CHECK-NEXT:     %[[b:.*]] = call @make_i32() : () -> i32
//  CHECK-NEXT:     %[[r1:.*]] = scf.for {{.*}} iter_args(%arg4 = %[[a]]) -> (i32) {
//  CHECK-NEXT:       %[[c:.*]] = func.call @make_i32() : () -> i32
//  CHECK-NEXT:       scf.yield %[[c]] : i32
//  CHECK-NEXT:     } {some_attr}
//  CHECK-NEXT:     return %[[a]], %[[r1]], %[[b]] : i32, i32, i32

// -----

// Test that an empty loop which iterates at least once (here 0 to 2 step 1) and
// only yields values defined outside of the loop is folded away.
func.func @for_yields_4() -> i32 {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %a = arith.constant 3 : i32
  %b = arith.constant 4 : i32
  %r = scf.for %i = %c0 to %c2 step %c1 iter_args(%0 = %a) -> i32 {
    scf.yield %b : i32
  }
  return %r : i32
}

// CHECK-LABEL:   func @for_yields_4
//  CHECK-NEXT:     %[[b:.*]] = arith.constant 4 : i32
//  CHECK-NEXT:     return %[[b]] : i32

// -----

// An iter_arg whose init value and yielded value are the same constant is
// removed; its uses inside the body are replaced by that constant.
// CHECK-LABEL: @constant_iter_arg
func.func @constant_iter_arg(%arg0: index, %arg1: index, %arg2: index) {
  %c0_i32 = arith.constant 0 : i32
  // CHECK: scf.for %arg3 = %arg0 to %arg1 step %arg2 {
  %0 = scf.for %i = %arg0 to %arg1 step %arg2 iter_args(%arg3 = %c0_i32) -> i32 {
    // CHECK-NEXT: "test.use"(%c0_i32)
    "test.use"(%arg3) : (i32) -> ()
    scf.yield %c0_i32 : i32
  }
  return
}

// -----

// An scf.if on a constant true condition is replaced by its inlined
// then-region.
// CHECK-LABEL: @replace_true_if
func.func @replace_true_if() {
  %true = arith.constant true
  // CHECK-NOT: scf.if
  // CHECK: "test.op"
  scf.if %true {
    "test.op"() : () -> ()
    scf.yield
  }
  return
}

// -----

// An scf.if on a constant false condition with no else-region is erased
// together with its then-region.
// CHECK-LABEL: @remove_false_if
func.func @remove_false_if() {
  %false = arith.constant false
  // CHECK-NOT: scf.if
  // CHECK-NOT: "test.op"
  scf.if %false {
    "test.op"() : () -> ()
    scf.yield
  }
  return
}

// -----

// With a constant true condition, the scf.if is replaced by its then-region and
// its result by the value that region yields.
// CHECK-LABEL: @replace_true_if_with_values
func.func @replace_true_if_with_values() {
  %true = arith.constant true
  // CHECK-NOT: scf.if
  // CHECK: %[[VAL:.*]] = "test.op"
  %0 = scf.if %true -> (i32) {
    %1 = "test.op"() : () -> i32
    scf.yield %1 : i32
  } else {
    %2 = "test.other_op"() : () -> i32
    scf.yield %2 : i32
  }
  // CHECK: "test.consume"(%[[VAL]])
  "test.consume"(%0) : (i32) -> ()
  return
}

// -----

// With a constant false condition, the scf.if is replaced by its else-region
// and its result by the value that region yields.
// CHECK-LABEL: @replace_false_if_with_values
func.func @replace_false_if_with_values() {
  %false = arith.constant false
  // CHECK-NOT: scf.if
  // CHECK: %[[VAL:.*]] = "test.other_op"
  %0 = scf.if %false -> (i32) {
    %1 = "test.op"() : () -> i32
    scf.yield %1 : i32
  } else {
    %2 = "test.other_op"() : () -> i32
    scf.yield %2 : i32
  }
  // CHECK: "test.consume"(%[[VAL]])
  "test.consume"(%0) : (i32) -> ()
  return
}

// -----

// Two directly nested result-less scf.if ops are merged into one scf.if on the
// arith.andi of both conditions.
// CHECK-LABEL: @merge_nested_if
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
func.func @merge_nested_if(%arg0: i1, %arg1: i1) {
// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
// CHECK: scf.if %[[COND]] {
// CHECK-NEXT: "test.op"()
  scf.if %arg0 {
    scf.if %arg1 {
      "test.op"() : () -> ()
      scf.yield
    }
    scf.yield
  }
  return
}

// -----

// Nested scf.if ops are merged even when they yield values. Results that are
// identical on every path (%0 and %3) are forwarded directly. The merged else
// yields (%1, %2), which both the outer-false path and the inner-false path
// produce.
// CHECK-LABEL: @merge_yielding_nested_if
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
func.func @merge_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8) {
// CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> f32
// CHECK: %[[PRE2:.*]] = "test.op2"() : () -> i32
// CHECK: %[[PRE3:.*]] = "test.op3"() : () -> i8
// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
// CHECK: %[[RES:.*]]:2 = scf.if %[[COND]] -> (f32, i32)
// CHECK:   %[[IN0:.*]] = "test.inop"() : () -> i32
// CHECK:   %[[IN1:.*]] = "test.inop1"() : () -> f32
// CHECK:   scf.yield %[[IN1]], %[[IN0]] : f32, i32
// CHECK: } else {
// CHECK:   scf.yield %[[PRE1]], %[[PRE2]] : f32, i32
// CHECK: }
// CHECK: return %[[PRE0]], %[[RES]]#0, %[[RES]]#1, %[[PRE3]] : i32, f32, i32, i8
  %0 = "test.op"() : () -> (i32)
  %1 = "test.op1"() : () -> (f32)
  %2 = "test.op2"() : () -> (i32)
  %3 = "test.op3"() : () -> (i8)
  %r:4 = scf.if %arg0 -> (i32, f32, i32, i8) {
    %a:2 = scf.if %arg1 -> (i32, f32) {
      %i = "test.inop"() : () -> (i32)
      %i1 = "test.inop1"() : () -> (f32)
      scf.yield %i, %i1 : i32, f32
    } else {
      scf.yield %2, %1 : i32, f32
    }
    scf.yield %0, %a#1, %a#0, %3 : i32, f32, i32, i8
  } else {
    scf.yield %0, %1, %2, %3 : i32, f32, i32, i8
  }
  return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
}

// -----

// The merge also applies when the outer scf.if has no results and the inner
// scf.if's results are unused.
// CHECK-LABEL: @merge_yielding_nested_if_nv1
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
func.func @merge_yielding_nested_if_nv1(%arg0: i1, %arg1: i1) {
// CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> f32
// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
// CHECK: scf.if %[[COND]]
// CHECK:   %[[IN0:.*]] = "test.inop"() : () -> i32
// CHECK:   %[[IN1:.*]] = "test.inop1"() : () -> f32
// CHECK: }
  %0 = "test.op"() : () -> (i32)
  %1 = "test.op1"() : () -> (f32)
  scf.if %arg0 {
    %a:2 = scf.if %arg1 -> (i32, f32) {
      %i = "test.inop"() : () -> (i32)
      %i1 = "test.inop1"() : () -> (f32)
      scf.yield %i, %i1 : i32, f32
    } else {
      scf.yield %0, %1 : i32, f32
    }
  }
  return
}

// -----

// The outer scf.if's result becomes an arith.select on the outer condition, and
// the nested result-less scf.if is merged into one scf.if on the combined
// (arith.andi) condition.
// CHECK-LABEL: @merge_yielding_nested_if_nv2
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
func.func @merge_yielding_nested_if_nv2(%arg0: i1, %arg1: i1) -> i32 {
// CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> i32
// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
// CHECK: %[[RES:.*]] = arith.select %[[ARG0]], %[[PRE0]], %[[PRE1]]
// CHECK: scf.if %[[COND]]
// CHECK:   "test.run"() : () -> ()
// CHECK: }
// CHECK: return %[[RES]]
  %0 = "test.op"() : () -> (i32)
  %1 = "test.op1"() : () -> (i32)
  %r = scf.if %arg0 -> i32 {
    scf.if %arg1 {
      "test.run"() : () -> ()
    }
    scf.yield %0 : i32
  } else {
    scf.yield %1 : i32
  }
  return %r : i32
}

// -----

// Negative test: on the outer-true/inner-false path the third result is %0,
// but on the outer-false path it is %2. Because the two paths disagree, the
// nested ifs must not be merged, so no arith.andi may appear.
// CHECK-LABEL: @merge_fail_yielding_nested_if
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
func.func @merge_fail_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8) {
// CHECK-NOT: andi
  %0 = "test.op"() : () -> (i32)
  %1 = "test.op1"() : () -> (f32)
  %2 = "test.op2"() : () -> (i32)
  %3 = "test.op3"() : () -> (i8)
  %r:4 = scf.if %arg0 -> (i32, f32, i32, i8) {
    %a:2 = scf.if %arg1 -> (i32, f32) {
      %i = "test.inop"() : () -> (i32)
      %i1 = "test.inop1"() : () -> (f32)
      scf.yield %i, %i1 : i32, f32
    } else {
      scf.yield %0, %1 : i32, f32
    }
    scf.yield %0, %a#1, %a#0, %3 : i32, f32, i32, i8
  } else {
    scf.yield %0, %1, %2, %3 : i32, f32, i32, i8
  }
  return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
}

// -----

// An scf.if on the negation of a condition (xori with true) is rewritten to
// branch on the original condition, with the then- and else-regions swapped.
// CHECK-LABEL:   func @if_condition_swap
// CHECK-NEXT:     %{{.*}} = scf.if %arg0 -> (index) {
// CHECK-NEXT:       %[[i1:.+]] = "test.origFalse"() : () -> index
// CHECK-NEXT:       scf.yield %[[i1]] : index
// CHECK-NEXT:     } else {
// CHECK-NEXT:       %[[i2:.+]] = "test.origTrue"() : () -> index
// CHECK-NEXT:       scf.yield %[[i2]] : index
// CHECK-NEXT:     }
func.func @if_condition_swap(%cond: i1) -> index {
  %true = arith.constant true
  %not = arith.xori %cond, %true : i1
  %0 = scf.if %not -> (index) {
    %1 = "test.origTrue"() : () -> index
    scf.yield %1 : index
  } else {
    %1 = "test.origFalse"() : () -> index
    scf.yield %1 : index
  }
  return %0 : index
}

// -----

// A loop whose lower bound (42) exceeds its upper bound (1) never runs; it is
// removed and its result is replaced by the init value.
// CHECK-LABEL: @remove_zero_iteration_loop
func.func @remove_zero_iteration_loop() {
  %c42 = arith.constant 42 : index
  %c1 = arith.constant 1 : index
  // CHECK: %[[INIT:.*]] = "test.init"
  %init = "test.init"() : () -> i32
  // CHECK-NOT: scf.for
  %0 = scf.for %i = %c42 to %c1 step %c1 iter_args(%arg = %init) -> (i32) {
    %1 = "test.op"(%i, %arg) : (index, i32) -> i32
    scf.yield %1 : i32
  }
  // CHECK: "test.consume"(%[[INIT]])
  "test.consume"(%0) : (i32) -> ()
  return
}

// -----

// A loop whose lower and upper bounds are the same (non-constant) value never
// runs; it is removed and its result is replaced by the init value.
// CHECK-LABEL: @remove_zero_iteration_loop_vals
func.func @remove_zero_iteration_loop_vals(%arg0: index) {
  %c2 = arith.constant 2 : index
  // CHECK: %[[INIT:.*]] = "test.init"
  %init = "test.init"() : () -> i32
  // CHECK-NOT: scf.for
  // CHECK-NOT: test.op
  %0 = scf.for %i = %arg0 to %arg0 step %c2 iter_args(%arg = %init) -> (i32) {
    %1 = "test.op"(%i, %arg) : (index, i32) -> i32
    scf.yield %1 : i32
  }
  // CHECK: "test.consume"(%[[INIT]])
  "test.consume"(%0) : (i32) -> ()
  return
}

// -----

// A loop with exactly one iteration (42 to 43 step 1) is replaced by its body,
// with the induction variable replaced by the lower bound.
// CHECK-LABEL: @replace_single_iteration_loop_1
func.func @replace_single_iteration_loop_1() {
  // CHECK: %[[LB:.*]] = arith.constant 42
  %c42 = arith.constant 42 : index
  %c43 = arith.constant 43 : index
  %c1 = arith.constant 1 : index
  // CHECK: %[[INIT:.*]] = "test.init"
  %init = "test.init"() : () -> i32
  // CHECK-NOT: scf.for
  // CHECK: %[[VAL:.*]] = "test.op"(%[[LB]], %[[INIT]])
  %0 = scf.for %i = %c42 to %c43 step %c1 iter_args(%arg = %init) -> (i32) {
    %1 = "test.op"(%i, %arg) : (index, i32) -> i32
    scf.yield %1 : i32
  }
  // CHECK: "test.consume"(%[[VAL]])
  "test.consume"(%0) : (i32) -> ()
  return
}

// -----

// A loop from 5 to 11 with step 6 runs exactly once and is replaced by its
// body, with the induction variable replaced by the lower bound.
// CHECK-LABEL: @replace_single_iteration_loop_2
func.func @replace_single_iteration_loop_2() {
  // CHECK: %[[LB:.*]] = arith.constant 5
  %c5 = arith.constant 5 : index
  %c6 = arith.constant 6 : index
  %c11 = arith.constant 11 : index
  // CHECK: %[[INIT:.*]] = "test.init"
  %init = "test.init"() : () -> i32
  // CHECK-NOT: scf.for
  // CHECK: %[[VAL:.*]] = "test.op"(%[[LB]], %[[INIT]])
  %0 = scf.for %i = %c5 to %c11 step %c6 iter_args(%arg = %init) -> (i32) {
    %1 = "test.op"(%i, %arg) : (index, i32) -> i32
    scf.yield %1 : i32
  }
  // CHECK: "test.consume"(%[[VAL]])
  "test.consume"(%0) : (i32) -> ()
  return
}

// -----

// A loop from 42 to 47 with step 5 runs exactly once and is replaced by its
// body, with the induction variable replaced by the lower bound.
// CHECK-LABEL: @replace_single_iteration_loop_non_unit_step
func.func @replace_single_iteration_loop_non_unit_step() {
  // CHECK: %[[LB:.*]] = arith.constant 42
  %c42 = arith.constant 42 : index
  %c47 = arith.constant 47 : index
  %c5 = arith.constant 5 : index
  // CHECK: %[[INIT:.*]] = "test.init"
  %init = "test.init"() : () -> i32
  // CHECK-NOT: scf.for
  // CHECK: %[[VAL:.*]] = "test.op"(%[[LB]], %[[INIT]])
  %0 = scf.for %i = %c42 to %c47 step %c5 iter_args(%arg = %init) -> (i32) {
    %1 = "test.op"(%i, %arg) : (index, i32) -> i32
    scf.yield %1 : i32
  }
  // CHECK: "test.consume"(%[[VAL]])
  "test.consume"(%0) : (i32) -> ()
  return
}


// -----

// The loop runs exactly once because its upper bound is the (non-constant)
// lower bound plus the step, so it is replaced by its body with the induction
// variable replaced by the lower bound.
// CHECK-LABEL: func @replace_single_iteration_const_diff(
//  CHECK-SAME: %[[A0:.*]]: index)
func.func @replace_single_iteration_const_diff(%arg0 : index) {
  // CHECK-NEXT: %[[CST:.*]] = arith.constant 2
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %5 = arith.addi %arg0, %c1 : index
  // CHECK-NOT: scf.for
  scf.for %arg2 = %arg0 to %5 step %c1 {
    // CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[A0]], %[[CST]]
    %7 = arith.muli %c2, %arg2 : index
    // CHECK-NEXT: "test.consume"(%[[MUL]])
    "test.consume"(%7) : (index) -> ()
  }
  return
}

// -----

// The middle dimension of the scf.parallel is empty (%ub to %ub), so the loop
// never runs; it is removed and its result is replaced by the init value.
// CHECK-LABEL: @remove_empty_parallel_loop
func.func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) {
  // CHECK: %[[INIT:.*]] = "test.init"
  %init = "test.init"() : () -> f32
  // CHECK-NOT: scf.parallel
  // CHECK-NOT: test.produce
  // CHECK-NOT: test.transform
  %0 = scf.parallel (%i, %j, %k) = (%lb, %ub, %lb) to (%ub, %ub, %ub) step (%s, %s, %s) init(%init) -> f32 {
    %1 = "test.produce"() : () -> f32
    scf.reduce(%1 : f32) {
    ^bb0(%lhs: f32, %rhs: f32):
      %2 = "test.transform"(%lhs, %rhs) : (f32, f32) -> f32
      scf.reduce.return %2 : f32
    }
  }
  // CHECK: "test.consume"(%[[INIT]])
  "test.consume"(%0) : (f32) -> ()
  return
}

// -----

// The second iter_arg is never used in the body and is both initialized with
// and yielded as %cst. It is removed, and its loop result becomes %cst.
// CHECK-LABEL: fold_away_iter_with_no_use_and_yielded_input
//  CHECK-SAME:   %[[A0:[0-9a-z]*]]: i32
func.func @fold_away_iter_with_no_use_and_yielded_input(%arg0 : i32,
                    %ub : index, %lb : index, %step : index) -> (i32, i32) {
  // CHECK-NEXT: %[[C32:.*]] = arith.constant 32 : i32
  %cst = arith.constant 32 : i32
  // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args({{.*}} = %[[A0]]) -> (i32) {
  %0:2 = scf.for %arg1 = %lb to %ub step %step iter_args(%arg2 = %arg0, %arg3 = %cst)
    -> (i32, i32) {
    %1 = arith.addi %arg2, %cst : i32
    scf.yield %1, %cst : i32, i32
  }

  // CHECK: return %[[FOR_RES]], %[[C32]] : i32, i32
  return %0#0, %0#1 : i32, i32
}

// -----

// The second iter_arg is never used in the body and its loop result is never
// used, so it is removed from the loop entirely.
// CHECK-LABEL: fold_away_iter_and_result_with_no_use
//  CHECK-SAME:   %[[A0:[0-9a-z]*]]: i32
func.func @fold_away_iter_and_result_with_no_use(%arg0 : i32,
                    %ub : index, %lb : index, %step : index) -> (i32) {
  %cst = arith.constant 32 : i32
  // CHECK: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args({{.*}} = %[[A0]]) -> (i32) {
  %0:2 = scf.for %arg1 = %lb to %ub step %step iter_args(%arg2 = %arg0, %arg3 = %cst)
    -> (i32, i32) {
    %1 = arith.addi %arg2, %cst : i32
    scf.yield %1, %1 : i32, i32
  }

  // CHECK: return %[[FOR_RES]] : i32
  return %0#0 : i32
}

// -----

// A tensor.cast feeding an iter_arg is folded into the loop. The loop carries
// the static tensor type, casts are inserted around the body's use and yield,
// and a final cast restores the dynamic type after the loop. The loop's
// discardable attribute is preserved.
func.func private @do(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>

func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c32 = arith.constant 32 : index
  %c1024 = arith.constant 1024 : index
  %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
  %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
    %2 = func.call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
    scf.yield %2 : tensor<?x?xf32>
  } {some_attr}
  return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: matmul_on_tensors
//  CHECK-SAME:   %[[T0:[0-9a-z]*]]: tensor<32x1024xf32>

//   CHECK-NOT: tensor.cast
//       CHECK: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args(%[[ITER_T0:.*]] = %[[T0]]) -> (tensor<32x1024xf32>) {
//       CHECK:   %[[CAST:.*]] = tensor.cast %[[ITER_T0]] : tensor<32x1024xf32> to tensor<?x?xf32>
//       CHECK:   %[[DONE:.*]] = func.call @do(%[[CAST]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
//       CHECK:   %[[UNCAST:.*]] = tensor.cast %[[DONE]] : tensor<?x?xf32> to tensor<32x1024xf32>
//       CHECK:   scf.yield %[[UNCAST]] : tensor<32x1024xf32>
//       CHECK: } {some_attr}
//       CHECK: %[[RES:.*]] = tensor.cast
//       CHECK: return %[[RES]] : tensor<?x?xf32>

// -----

// Inner scf.if ops that test the same condition as the enclosing scf.if are
// resolved. The then-region keeps the inner then-value (get_some_value1) and
// the else-region keeps the inner else-value (get_some_value4).
// CHECK-LABEL: @cond_prop
func.func @cond_prop(%arg0 : i1) -> index {
  %res = scf.if %arg0 -> index {
    %res1 = scf.if %arg0 -> index {
      %v1 = "test.get_some_value1"() : () -> index
      scf.yield %v1 : index
    } else {
      %v2 = "test.get_some_value2"() : () -> index
      scf.yield %v2 : index
    }
    scf.yield %res1 : index
  } else {
    %res2 = scf.if %arg0 -> index {
      %v3 = "test.get_some_value3"() : () -> index
      scf.yield %v3 : index
    } else {
      %v4 = "test.get_some_value4"() : () -> index
      scf.yield %v4 : index
    }
    scf.yield %res2 : index
  }
  return %res : index
}
// CHECK-NEXT:  %[[if:.+]] = scf.if %arg0 -> (index) {
// CHECK-NEXT:    %[[c1:.+]] = "test.get_some_value1"() : () -> index
// CHECK-NEXT:    scf.yield %[[c1]] : index
// CHECK-NEXT:  } else {
// CHECK-NEXT:    %[[c4:.+]] = "test.get_some_value4"() : () -> index
// CHECK-NEXT:    scf.yield %[[c4]] : index
// CHECK-NEXT:  }
// CHECK-NEXT:  return %[[if]] : index
// CHECK-NEXT:}

// -----

// A result that yields true in the then-region and false in the else-region is
// replaced by the condition itself.
// CHECK-LABEL: @replace_if_with_cond1
func.func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
  %true = arith.constant true
  %false = arith.constant false
  %res:2 = scf.if %arg0 -> (i32, i1) {
    %v = "test.get_some_value"() : () -> i32
    scf.yield %v, %true : i32, i1
  } else {
    %v2 = "test.get_some_value"() : () -> i32
    scf.yield %v2, %false : i32, i1
  }
  return %res#0, %res#1 : i32, i1
}
// CHECK-NEXT:    %[[if:.+]] = scf.if %arg0 -> (i32) {
// CHECK-NEXT:      %[[sv1:.+]] = "test.get_some_value"() : () -> i32
// CHECK-NEXT:      scf.yield %[[sv1]] : i32
// CHECK-NEXT:    } else {
// CHECK-NEXT:      %[[sv2:.+]] = "test.get_some_value"() : () -> i32
// CHECK-NEXT:      scf.yield %[[sv2]] : i32
// CHECK-NEXT:    }
// CHECK-NEXT:    return %[[if]], %arg0 : i32, i1

// -----

// A result that yields false in the then-region and true in the else-region is
// replaced by the negated condition (xori with true).
// CHECK-LABEL: @replace_if_with_cond2
func.func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
  %true = arith.constant true
  %false = arith.constant false
  %res:2 = scf.if %arg0 -> (i32, i1) {
    %v = "test.get_some_value"() : () -> i32
    scf.yield %v, %false : i32, i1
  } else {
    %v2 = "test.get_some_value"() : () -> i32
    scf.yield %v2, %true : i32, i1
  }
  return %res#0, %res#1 : i32, i1
}
// CHECK-NEXT:     %true = arith.constant true
// CHECK-NEXT:     %[[toret:.+]] = arith.xori %arg0, %true : i1
// CHECK-NEXT:     %[[if:.+]] = scf.if %arg0 -> (i32) {
// CHECK-NEXT:       %[[sv1:.+]] = "test.get_some_value"() : () -> i32
// CHECK-NEXT:       scf.yield %[[sv1]] : i32
// CHECK-NEXT:     } else {
// CHECK-NEXT:       %[[sv2:.+]] = "test.get_some_value"() : () -> i32
// CHECK-NEXT:       scf.yield %[[sv2]] : i32
// CHECK-NEXT:     }
// CHECK-NEXT:     return %[[if]], %[[toret]] : i32, i1

// -----

937// CHECK-LABEL: @replace_if_with_cond3
// Input: both branches yield the same value %arg2 as the i64 result, so that
// result is replaced by %arg2 and dropped from the scf.if. The expected output
// refers to it as %arg1 because function arguments are printed by position:
// the second parameter (named %arg2 here) prints as %arg1.
938func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
939  %res:2 = scf.if %arg0 -> (i32, i64) {
940    %v = "test.get_some_value"() : () -> i32
941    scf.yield %v, %arg2 : i32, i64
942  } else {
943    %v2 = "test.get_some_value"() : () -> i32
944    scf.yield %v2, %arg2 : i32, i64
945  }
946  return %res#0, %res#1 : i32, i64
947}
948// CHECK-NEXT:     %[[if:.+]] = scf.if %arg0 -> (i32) {
949// CHECK-NEXT:       %[[sv1:.+]] = "test.get_some_value"() : () -> i32
950// CHECK-NEXT:       scf.yield %[[sv1]] : i32
951// CHECK-NEXT:     } else {
952// CHECK-NEXT:       %[[sv2:.+]] = "test.get_some_value"() : () -> i32
953// CHECK-NEXT:       scf.yield %[[sv2]] : i32
954// CHECK-NEXT:     }
955// CHECK-NEXT:     return %[[if]], %arg1 : i32, i64
956
957// -----
958
959// CHECK-LABEL: @while_cond_true
// Input: the loop condition is also forwarded to the do-region. Inside the
// do-region that condition must have been true, so the expected output below
// passes a `true` constant to "test.use" instead of the block argument.
960func.func @while_cond_true() -> i1 {
961  %0 = scf.while () : () -> i1 {
962    %condition = "test.condition"() : () -> i1
963    scf.condition(%condition) %condition : i1
964  } do {
965  ^bb0(%arg0: i1):
966    "test.use"(%arg0) : (i1) -> ()
967    scf.yield
968  }
969  return %0 : i1
970}
971// CHECK-NEXT:         %[[true:.+]] = arith.constant true
972// CHECK-NEXT:         %{{.+}} = scf.while : () -> i1 {
973// CHECK-NEXT:           %[[cmp:.+]] = "test.condition"() : () -> i1
974// CHECK-NEXT:           scf.condition(%[[cmp]]) %[[cmp]] : i1
975// CHECK-NEXT:         } do {
976// CHECK-NEXT:         ^bb0(%arg0: i1):
977// CHECK-NEXT:           "test.use"(%[[true]]) : (i1) -> ()
978// CHECK-NEXT:           scf.yield
979// CHECK-NEXT:         }
980
981// -----
982
983// CHECK-LABEL: @invariant_loop_args_in_same_order
984// CHECK-SAME: (%[[FUNC_ARG0:.*]]: tensor<i32>)
// Input: the do-region yields %arg1 and %arg4 back unchanged, so both are
// loop invariant. The expected output below removes them from the loop and
// uses their init values (%f_arg0 and %cst_0) directly, both inside the loop
// body and in the returned values. %arg2/%arg3 both receive %2 but stay
// loop-carried.
985func.func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
986  %cst_0 = arith.constant dense<0> : tensor<i32>
987  %cst_1 = arith.constant dense<1> : tensor<i32>
988  %cst_42 = arith.constant dense<42> : tensor<i32>
989
990  %0:5 = scf.while (%arg0 = %cst_0, %arg1 = %f_arg0, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
991    %1 = arith.cmpi slt, %arg0, %cst_42 : tensor<i32>
992    %2 = tensor.extract %1[] : tensor<i1>
993    scf.condition(%2) %arg0, %arg1, %arg2, %arg3, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
994  } do {
995  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors
996    // %arg1 here will get replaced by %f_arg0, its loop-invariant init value
997    %1 = arith.addi %arg0, %arg1 : tensor<i32>
998    %2 = arith.addi %arg2, %arg3 : tensor<i32>
999    scf.yield %1, %arg1, %2, %2, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
1000  }
1001  return %0#0, %0#1, %0#2, %0#3, %0#4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
1002}
1003// CHECK:    %[[ZERO:.*]] = arith.constant dense<0>
1004// CHECK:    %[[ONE:.*]] = arith.constant dense<1>
1005// CHECK:    %[[CST42:.*]] = arith.constant dense<42>
1006// CHECK:    %[[WHILE:.*]]:3 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]], %[[ARG3:.*]] = %[[ONE]])
1007// CHECK:       arith.cmpi slt, %[[ARG0]], %{{.*}}
1008// CHECK:       tensor.extract %{{.*}}[]
1009// CHECK:       scf.condition(%{{.*}}) %[[ARG0]], %[[ARG2]], %[[ARG3]]
1010// CHECK:    } do {
1011// CHECK:     ^{{.*}}(%[[ARG0:.*]]: tensor<i32>, %[[ARG2:.*]]: tensor<i32>, %[[ARG3:.*]]: tensor<i32>):
1012// CHECK:       %[[VAL0:.*]] = arith.addi %[[ARG0]], %[[FUNC_ARG0]]
1013// CHECK:       %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG3]]
1014// CHECK:       scf.yield %[[VAL0]], %[[VAL1]], %[[VAL1]]
1015// CHECK:    }
1016// CHECK:    return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#2, %[[ZERO]]
1017
1018// CHECK-LABEL: @while_loop_invariant_argument_different_order
// Input: the condition forwards the before-region arguments permuted, with
// %arg0 forwarded twice, and the do-region yields them back in yet another
// order. The expected output below keeps only two loop-carried values; the
// remaining results are replaced by the constants %cst_0 / %cst_1, and the
// comparison uses %cst_0 in place of the invariant %arg0 (printed as `sgt`
// with swapped operands). Note: %1 in the do-region is unused.
1019func.func @while_loop_invariant_argument_different_order(%arg : tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
1020  %cst_0 = arith.constant dense<0> : tensor<i32>
1021  %cst_1 = arith.constant dense<1> : tensor<i32>
1022  %cst_42 = arith.constant dense<42> : tensor<i32>
1023
1024  %0:6 = scf.while (%arg0 = %cst_0, %arg1 = %cst_1, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
1025    %1 = arith.cmpi slt, %arg0, %arg : tensor<i32>
1026    %2 = tensor.extract %1[] : tensor<i1>
1027    scf.condition(%2) %arg1, %arg0, %arg2, %arg0, %arg3, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
1028  } do {
1029  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<i32>): // no predecessors
1030    %1 = arith.addi %arg0, %cst_1 : tensor<i32>
1031    %2 = arith.addi %arg2, %arg3 : tensor<i32>
1032    scf.yield %arg3, %arg1, %2, %2, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
1033  }
1034  return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
1035}
1036// CHECK-SAME: (%[[ARG:.+]]: tensor<i32>)
1037// CHECK:    %[[ZERO:.*]] = arith.constant dense<0>
1038// CHECK:    %[[ONE:.*]] = arith.constant dense<1>
1039// CHECK:    %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]])
1040// CHECK:       arith.cmpi sgt, %[[ARG]], %[[ZERO]]
1041// CHECK:       tensor.extract %{{.*}}[]
1042// CHECK:       scf.condition(%{{.*}}) %[[ARG1]], %[[ARG4]]
1043// CHECK:    } do {
1044// CHECK:     ^{{.*}}(%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
1045// CHECK:       scf.yield %[[ZERO]], %[[ONE]]
1046// CHECK:    }
1047// CHECK:    return %[[WHILE]]#0, %[[ZERO]], %[[ONE]], %[[ZERO]], %[[ONE]], %[[WHILE]]#1
1048
1049// -----
1050
1051// CHECK-LABEL: @while_unused_result
// Input: the i64 value forwarded by scf.condition is neither returned nor used
// in the do-region. The expected output below drops it from the loop results
// and from the do-region arguments, while the op producing it stays in place.
1052func.func @while_unused_result() -> i32 {
1053  %0:2 = scf.while () : () -> (i32, i64) {
1054    %condition = "test.condition"() : () -> i1
1055    %v1 = "test.get_some_value"() : () -> i32
1056    %v2 = "test.get_some_value"() : () -> i64
1057    scf.condition(%condition) %v1, %v2 : i32, i64
1058  } do {
1059  ^bb0(%arg0: i32, %arg1: i64):
1060    "test.use"(%arg0) : (i32) -> ()
1061    scf.yield
1062  }
1063  return %0#0 : i32
1064}
1065// CHECK-NEXT:         %[[res:.*]] = scf.while : () -> i32 {
1066// CHECK-NEXT:           %[[cmp:.*]] = "test.condition"() : () -> i1
1067// CHECK-NEXT:           %[[val:.*]] = "test.get_some_value"() : () -> i32
1068// CHECK-NEXT:           %{{.*}} = "test.get_some_value"() : () -> i64
1069// CHECK-NEXT:           scf.condition(%[[cmp]]) %[[val]] : i32
1070// CHECK-NEXT:         } do {
1071// CHECK-NEXT:         ^bb0(%[[arg:.*]]: i32):
1072// CHECK-NEXT:           "test.use"(%[[arg]]) : (i32) -> ()
1073// CHECK-NEXT:           scf.yield
1074// CHECK-NEXT:         }
1075// CHECK-NEXT:         return %[[res]] : i32
1076
1077// -----
1078
1079// CHECK-LABEL: @while_cmp_lhs
// Input: the loop continues while `%val != %arg0`. Inside the do-region the
// same `ne` comparison on the forwarded value is therefore known to be true
// and its `eq` counterpart false; the expected output below passes `true` and
// `false` constants to "test.use" instead.
1080func.func @while_cmp_lhs(%arg0 : i32) {
1081  %0 = scf.while () : () -> i32 {
1082    %val = "test.val"() : () -> i32
1083    %condition = arith.cmpi ne, %val, %arg0 : i32
1084    scf.condition(%condition) %val : i32
1085  } do {
1086  ^bb0(%val2: i32):
1087    %condition2 = arith.cmpi ne, %val2, %arg0 : i32
1088    %negcondition2 = arith.cmpi eq, %val2, %arg0 : i32
1089    "test.use"(%condition2, %negcondition2, %val2) : (i1, i1, i32) -> ()
1090    scf.yield
1091  }
1092  return
1093}
1094// CHECK-DAG:         %[[true:.+]] = arith.constant true
1095// CHECK-DAG:         %[[false:.+]] = arith.constant false
1096// CHECK-DAG:         %{{.+}} = scf.while : () -> i32 {
1097// CHECK-NEXT:         %[[val:.+]] = "test.val"
1098// CHECK-NEXT:         %[[cmp:.+]] = arith.cmpi ne, %[[val]], %arg0 : i32
1099// CHECK-NEXT:           scf.condition(%[[cmp]]) %[[val]] : i32
1100// CHECK-NEXT:         } do {
1101// CHECK-NEXT:         ^bb0(%arg1: i32):
1102// CHECK-NEXT:           "test.use"(%[[true]], %[[false]], %arg1) : (i1, i1, i32) -> ()
1103// CHECK-NEXT:           scf.yield
1104// CHECK-NEXT:         }
1105
1106// -----
1107
1108// CHECK-LABEL: @while_cmp_rhs
// Input: the loop continues while `%arg0 != %val`, with the function argument
// as the left-hand operand. Inside the do-region the same `ne` comparison on
// the forwarded value is known to be true and its `eq` counterpart false; the
// expected output below passes `true` and `false` constants to "test.use".
1109func.func @while_cmp_rhs(%arg0 : i32) {
1110  %0 = scf.while () : () -> i32 {
1111    %val = "test.val"() : () -> i32
1112    %condition = arith.cmpi ne, %arg0, %val : i32
1113    scf.condition(%condition) %val : i32
1114  } do {
1115  ^bb0(%val2: i32):
1116    %condition2 = arith.cmpi ne, %arg0, %val2 : i32
1117    %negcondition2 = arith.cmpi eq, %arg0, %val2 : i32
1118    "test.use"(%condition2, %negcondition2, %val2) : (i1, i1, i32) -> ()
1119    scf.yield
1120  }
1121  return
1122}
1123// CHECK-DAG:         %[[true:.+]] = arith.constant true
1124// CHECK-DAG:         %[[false:.+]] = arith.constant false
1125// CHECK-DAG:         %{{.+}} = scf.while : () -> i32 {
1126// CHECK-NEXT:         %[[val:.+]] = "test.val"
1127// CHECK-NEXT:         %[[cmp:.+]] = arith.cmpi ne, %arg0, %[[val]] : i32
1128// CHECK-NEXT:           scf.condition(%[[cmp]]) %[[val]] : i32
1129// CHECK-NEXT:         } do {
1130// CHECK-NEXT:         ^bb0(%arg1: i32):
1131// CHECK-NEXT:           "test.use"(%[[true]], %[[false]], %arg1) : (i1, i1, i32) -> ()
1132// CHECK-NEXT:           scf.yield
1133// CHECK-NEXT:         }
1134
1135// -----
1136
1137// CHECK-LABEL: @while_duplicated_res
// Input: scf.condition forwards the same value %val twice. The expected output
// below forwards it once and uses the single block argument / loop result for
// both original uses.
1138func.func @while_duplicated_res() -> (i32, i32) {
1139  %0:2 = scf.while () : () -> (i32, i32) {
1140    %val = "test.val"() : () -> i32
1141    %condition = "test.condition"() : () -> i1
1142    scf.condition(%condition) %val, %val : i32, i32
1143  } do {
1144  ^bb0(%val2: i32, %val3: i32):
1145    "test.use"(%val2, %val3) : (i32, i32) -> ()
1146    scf.yield
1147  }
1148  return %0#0, %0#1: i32, i32
1149}
1150// CHECK:         %[[RES:.*]] = scf.while : () -> i32 {
1151// CHECK:         %[[VAL:.*]] = "test.val"() : () -> i32
1152// CHECK:         %[[COND:.*]] = "test.condition"() : () -> i1
1153// CHECK:           scf.condition(%[[COND]]) %[[VAL]] : i32
1154// CHECK:         } do {
1155// CHECK:         ^bb0(%[[ARG:.*]]: i32):
1156// CHECK:           "test.use"(%[[ARG]], %[[ARG]]) : (i32, i32) -> ()
1157// CHECK:           scf.yield
1158// CHECK:         }
1159// CHECK:         return %[[RES]], %[[RES]] : i32, i32
1160
1161
1162// -----
1163
1164// CHECK-LABEL: @while_unused_arg1
// Input: the f64 iter-arg %arg2 is never used in the before-region, and the
// do-region always yields the outer %y for it. The expected output below
// removes it, leaving a single i32 iter-arg.
1165func.func @while_unused_arg1(%x : i32, %y : f64) -> i32 {
1166  %0 = scf.while (%arg1 = %x, %arg2 = %y) : (i32, f64) -> (i32) {
1167    %condition = "test.condition"(%arg1) : (i32) -> i1
1168    scf.condition(%condition) %arg1 : i32
1169  } do {
1170  ^bb0(%arg1: i32):
1171    %next = "test.use"(%arg1) : (i32) -> (i32)
1172    scf.yield %next, %y : i32, f64
1173  }
1174  return %0 : i32
1175}
1176// CHECK-NEXT:         %[[res:.*]] = scf.while (%[[arg2:.*]] = %{{.*}}) : (i32) -> i32 {
1177// CHECK-NEXT:           %[[cmp:.*]] = "test.condition"(%[[arg2]]) : (i32) -> i1
1178// CHECK-NEXT:           scf.condition(%[[cmp]]) %[[arg2]] : i32
1179// CHECK-NEXT:         } do {
1180// CHECK-NEXT:         ^bb0(%[[post:.*]]: i32):
1181// CHECK-NEXT:           %[[next:.*]] = "test.use"(%[[post]]) : (i32) -> i32
1182// CHECK-NEXT:           scf.yield %[[next]] : i32
1183// CHECK-NEXT:         }
1184// CHECK-NEXT:         return %[[res]] : i32
1185
1186
1187// -----
1188
1189// CHECK-LABEL: @while_unused_arg2
// Input: the before-region argument %val1 is never used, so the expected
// output below has no iter-args at all (`scf.while : () -> i32`). Note: the
// do-region defines its own, unrelated %val1 that only feeds its scf.yield.
1190func.func @while_unused_arg2(%val0: i32) -> i32 {
1191  %0 = scf.while (%val1 = %val0) : (i32) -> i32 {
1192    %val = "test.val"() : () -> i32
1193    %condition = "test.condition"() : () -> i1
1194    scf.condition(%condition) %val: i32
1195  } do {
1196  ^bb0(%val2: i32):
1197    "test.use"(%val2) : (i32) -> ()
1198    %val1 = "test.val1"() : () -> i32
1199    scf.yield %val1 : i32
1200  }
1201  return %0 : i32
1202}
1203// CHECK:         %[[RES:.*]] = scf.while : () -> i32 {
1204// CHECK:         %[[VAL:.*]] = "test.val"() : () -> i32
1205// CHECK:         %[[COND:.*]] = "test.condition"() : () -> i1
1206// CHECK:           scf.condition(%[[COND]]) %[[VAL]] : i32
1207// CHECK:         } do {
1208// CHECK:         ^bb0(%[[ARG:.*]]: i32):
1209// CHECK:           "test.use"(%[[ARG]]) : (i32) -> ()
1210// CHECK:           scf.yield
1211// CHECK:         }
1212// CHECK:         return %[[RES]] : i32
1213
1214
1215// -----
1216
1217// CHECK-LABEL: func @test_align_args
1218//       CHECK:  %[[RES:.*]]:3 = scf.while (%[[ARG0:.*]] = %{{.*}}, %[[ARG1:.*]] = %{{.*}}, %[[ARG2:.*]] = %{{.*}}) : (f32, i32, i64) -> (f32, i32, i64) {
1219//       CHECK:  scf.condition(%{{.*}}) %[[ARG0]], %[[ARG1]], %[[ARG2]] : f32, i32, i64
1220//       CHECK:  ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i64):
1221//       CHECK:  %[[R1:.*]] = "test.test"(%[[ARG5]]) : (i64) -> f32
1222//       CHECK:  %[[R2:.*]] = "test.test"(%[[ARG3]]) : (f32) -> i32
1223//       CHECK:  %[[R3:.*]] = "test.test"(%[[ARG4]]) : (i32) -> i64
1224//       CHECK:  scf.yield %[[R1]], %[[R2]], %[[R3]] : f32, i32, i64
1225//       CHECK:  return %[[RES]]#2, %[[RES]]#0, %[[RES]]#1
// Input: scf.condition forwards the before-region arguments rotated
// (%arg2, %arg0, %arg1), so the loop results are (i64, f32, i32) while the
// iter-args are (f32, i32, i64). The expected output (listed above this
// function) realigns the results with the iter-args and permutes the returned
// values (#2, #0, #1) accordingly.
1226func.func @test_align_args() -> (i64, f32, i32) {
1227  %0 = "test.test"() : () -> (f32)
1228  %1 = "test.test"() : () -> (i32)
1229  %2 = "test.test"() : () -> (i64)
1230  %3:3 = scf.while (%arg0 = %0, %arg1 = %1, %arg2 = %2) : (f32, i32, i64) -> (i64, f32, i32) {
1231    %cond = "test.test"() : () -> (i1)
1232    scf.condition(%cond) %arg2, %arg0, %arg1 : i64, f32, i32
1233  } do {
1234  ^bb0(%arg3: i64, %arg4: f32, %arg5: i32):
1235    %4 = "test.test"(%arg3) : (i64) -> (f32)
1236    %5 = "test.test"(%arg4) : (f32) -> (i32)
1237    %6 = "test.test"(%arg5) : (i32) -> (i64)
1238    scf.yield %4, %5, %6 : f32, i32, i64
1239  }
1240  return %3#0, %3#1, %3#2 : i64, f32, i32
1241}
1242
1243
1244// -----
1245
1246// CHECK-LABEL: @combineIfs
// Input: two consecutive scf.if ops on the same condition %arg0, each with one
// i32 result and an else-branch. The expected output below merges them into a
// single scf.if with two results. (%arg2 is unused.)
1247func.func @combineIfs(%arg0 : i1, %arg2: i64) -> (i32, i32) {
1248  %res = scf.if %arg0 -> i32 {
1249    %v = "test.firstCodeTrue"() : () -> i32
1250    scf.yield %v : i32
1251  } else {
1252    %v2 = "test.firstCodeFalse"() : () -> i32
1253    scf.yield %v2 : i32
1254  }
1255  %res2 = scf.if %arg0 -> i32 {
1256    %v = "test.secondCodeTrue"() : () -> i32
1257    scf.yield %v : i32
1258  } else {
1259    %v2 = "test.secondCodeFalse"() : () -> i32
1260    scf.yield %v2 : i32
1261  }
1262  return %res, %res2 : i32, i32
1263}
1264// CHECK-NEXT:     %[[res:.+]]:2 = scf.if %arg0 -> (i32, i32) {
1265// CHECK-NEXT:       %[[tval0:.+]] = "test.firstCodeTrue"() : () -> i32
1266// CHECK-NEXT:       %[[tval:.+]] = "test.secondCodeTrue"() : () -> i32
1267// CHECK-NEXT:       scf.yield %[[tval0]], %[[tval]] : i32, i32
1268// CHECK-NEXT:     } else {
1269// CHECK-NEXT:       %[[fval0:.+]] = "test.firstCodeFalse"() : () -> i32
1270// CHECK-NEXT:       %[[fval:.+]] = "test.secondCodeFalse"() : () -> i32
1271// CHECK-NEXT:       scf.yield %[[fval0]], %[[fval]] : i32, i32
1272// CHECK-NEXT:     }
1273// CHECK-NEXT:     return %[[res]]#0, %[[res]]#1 : i32, i32
1274
1275// -----
1276
1277// CHECK-LABEL: @combineIfs2
// Input: a result-less scf.if without an else-branch, followed by an scf.if on
// the same condition that produces a result. The expected output below moves
// the first body to the start of the second then-branch. (%arg2 is unused.)
1278func.func @combineIfs2(%arg0 : i1, %arg2: i64) -> i32 {
1279  scf.if %arg0 {
1280    "test.firstCodeTrue"() : () -> ()
1281    scf.yield
1282  }
1283  %res = scf.if %arg0 -> i32 {
1284    %v = "test.secondCodeTrue"() : () -> i32
1285    scf.yield %v : i32
1286  } else {
1287    %v2 = "test.secondCodeFalse"() : () -> i32
1288    scf.yield %v2 : i32
1289  }
1290  return %res : i32
1291}
1292// CHECK-NEXT:     %[[res:.+]] = scf.if %arg0 -> (i32) {
1293// CHECK-NEXT:       "test.firstCodeTrue"() : () -> ()
1294// CHECK-NEXT:       %[[tval:.+]] = "test.secondCodeTrue"() : () -> i32
1295// CHECK-NEXT:       scf.yield %[[tval]] : i32
1296// CHECK-NEXT:     } else {
1297// CHECK-NEXT:       %[[fval:.+]] = "test.secondCodeFalse"() : () -> i32
1298// CHECK-NEXT:       scf.yield %[[fval]] : i32
1299// CHECK-NEXT:     }
1300// CHECK-NEXT:     return %[[res]] : i32
1301
1302// -----
1303
1304// CHECK-LABEL: @combineIfs3
// Input: an scf.if with one result, followed by a result-less scf.if (no
// else-branch) on the same condition. The expected output below appends the
// second body to the first then-branch, before its scf.yield.
// (%arg2 is unused.)
1305func.func @combineIfs3(%arg0 : i1, %arg2: i64) -> i32 {
1306  %res = scf.if %arg0 -> i32 {
1307    %v = "test.firstCodeTrue"() : () -> i32
1308    scf.yield %v : i32
1309  } else {
1310    %v2 = "test.firstCodeFalse"() : () -> i32
1311    scf.yield %v2 : i32
1312  }
1313  scf.if %arg0 {
1314    "test.secondCodeTrue"() : () -> ()
1315    scf.yield
1316  }
1317  return %res : i32
1318}
1319// CHECK-NEXT:     %[[res:.+]] = scf.if %arg0 -> (i32) {
1320// CHECK-NEXT:       %[[tval:.+]] = "test.firstCodeTrue"() : () -> i32
1321// CHECK-NEXT:       "test.secondCodeTrue"() : () -> ()
1322// CHECK-NEXT:       scf.yield %[[tval]] : i32
1323// CHECK-NEXT:     } else {
1324// CHECK-NEXT:       %[[fval:.+]] = "test.firstCodeFalse"() : () -> i32
1325// CHECK-NEXT:       scf.yield %[[fval]] : i32
1326// CHECK-NEXT:     }
1327// CHECK-NEXT:     return %[[res]] : i32
1328
1329// -----
1330
1331// CHECK-LABEL: @combineIfs4
// Input: two result-less scf.if ops on the same condition, neither with an
// else-branch. The expected output below merges both bodies into one scf.if.
// (%arg2 is unused.)
1332func.func @combineIfs4(%arg0 : i1, %arg2: i64) {
1333  scf.if %arg0 {
1334    "test.firstCodeTrue"() : () -> ()
1335    scf.yield
1336  }
1337  scf.if %arg0 {
1338    "test.secondCodeTrue"() : () -> ()
1339    scf.yield
1340  }
1341  return
1342}
1343
1344// CHECK-NEXT:     scf.if %arg0 {
1345// CHECK-NEXT:       "test.firstCodeTrue"() : () -> ()
1346// CHECK-NEXT:       "test.secondCodeTrue"() : () -> ()
1347// CHECK-NEXT:     }
1348
1349// -----
1350
1351// CHECK-LABEL: @combineIfsUsed
1352// CHECK-SAME: %[[arg0:.+]]: i1
// Input: the second scf.if consumes the result of the first. After merging,
// the expected output below feeds each branch's own value (from
// "test.firstCodeTrue" / "test.firstCodeFalse") directly into the second
// computation of the same branch. (%arg2 is unused.)
1353func.func @combineIfsUsed(%arg0 : i1, %arg2: i64) -> (i32, i32) {
1354  %res = scf.if %arg0 -> i32 {
1355    %v = "test.firstCodeTrue"() : () -> i32
1356    scf.yield %v : i32
1357  } else {
1358    %v2 = "test.firstCodeFalse"() : () -> i32
1359    scf.yield %v2 : i32
1360  }
1361  %res2 = scf.if %arg0 -> i32 {
1362    %v = "test.secondCodeTrue"(%res) : (i32) -> i32
1363    scf.yield %v : i32
1364  } else {
1365    %v2 = "test.secondCodeFalse"(%res) : (i32) -> i32
1366    scf.yield %v2 : i32
1367  }
1368  return %res, %res2 : i32, i32
1369}
1370// CHECK-NEXT:     %[[res:.+]]:2 = scf.if %[[arg0]] -> (i32, i32) {
1371// CHECK-NEXT:       %[[tval0:.+]] = "test.firstCodeTrue"() : () -> i32
1372// CHECK-NEXT:       %[[tval:.+]] = "test.secondCodeTrue"(%[[tval0]]) : (i32) -> i32
1373// CHECK-NEXT:       scf.yield %[[tval0]], %[[tval]] : i32, i32
1374// CHECK-NEXT:     } else {
1375// CHECK-NEXT:       %[[fval0:.+]] = "test.firstCodeFalse"() : () -> i32
1376// CHECK-NEXT:       %[[fval:.+]] = "test.secondCodeFalse"(%[[fval0]]) : (i32) -> i32
1377// CHECK-NEXT:       scf.yield %[[fval0]], %[[fval]] : i32, i32
1378// CHECK-NEXT:     }
1379// CHECK-NEXT:     return %[[res]]#0, %[[res]]#1 : i32, i32
1380
1381// -----
1382
1383// CHECK-LABEL: @combineIfsNot
1384// CHECK-SAME: %[[arg0:.+]]: i1
// Input: the second scf.if is guarded by `%arg0 xor true`, the negation of the
// first condition. The expected output below merges them into one scf.if on
// %arg0, with the second body in the else-branch. (%arg2 is unused.)
1385func.func @combineIfsNot(%arg0 : i1, %arg2: i64) {
1386  %true = arith.constant true
1387  %not = arith.xori %arg0, %true : i1
1388  scf.if %arg0 {
1389    "test.firstCodeTrue"() : () -> ()
1390    scf.yield
1391  }
1392  scf.if %not {
1393    "test.secondCodeTrue"() : () -> ()
1394    scf.yield
1395  }
1396  return
1397}
1398
1399// CHECK-NEXT:     scf.if %[[arg0]] {
1400// CHECK-NEXT:       "test.firstCodeTrue"() : () -> ()
1401// CHECK-NEXT:     } else {
1402// CHECK-NEXT:       "test.secondCodeTrue"() : () -> ()
1403// CHECK-NEXT:     }
1404
1405// -----
1406
1407// CHECK-LABEL: @combineIfsNot2
1408// CHECK-SAME: %[[arg0:.+]]: i1
// Input: the first scf.if is guarded by the negation of %arg0 and the second by
// %arg0 itself. The expected output below merges them into one scf.if on
// %arg0, with the negated body moved to the else-branch. (%arg2 is unused.)
1409func.func @combineIfsNot2(%arg0 : i1, %arg2: i64) {
1410  %true = arith.constant true
1411  %not = arith.xori %arg0, %true : i1
1412  scf.if %not {
1413    "test.firstCodeTrue"() : () -> ()
1414    scf.yield
1415  }
1416  scf.if %arg0 {
1417    "test.secondCodeTrue"() : () -> ()
1418    scf.yield
1419  }
1420  return
1421}
1422
1423// CHECK-NEXT:     scf.if %[[arg0]] {
1424// CHECK-NEXT:       "test.secondCodeTrue"() : () -> ()
1425// CHECK-NEXT:     } else {
1426// CHECK-NEXT:       "test.firstCodeTrue"() : () -> ()
1427// CHECK-NEXT:     }
1428
1429// -----
1430
1431// CHECK-LABEL: func @propagate_into_execute_region
// Input: %cond is the constant false (`0 : i1`), so the cf.cond_br inside the
// scf.execute_region always takes ^bb2, which forwards the constant 2. The
// expectations embedded in the loop body require "test.bar" to consume that
// constant directly, right after "test.foo".
1432func.func @propagate_into_execute_region() {
1433  %cond = arith.constant 0 : i1
1434  affine.for %i = 0 to 100 {
1435    "test.foo"() : () -> ()
1436    %v = scf.execute_region -> i64 {
1437      cf.cond_br %cond, ^bb1, ^bb2
1438
1439    ^bb1:
1440      %c1 = arith.constant 1 : i64
1441      cf.br ^bb3(%c1 : i64)
1442
1443    ^bb2:
1444      %c2 = arith.constant 2 : i64
1445      cf.br ^bb3(%c2 : i64)
1446
1447    ^bb3(%x : i64):
1448      scf.yield %x : i64
1449    }
1450    "test.bar"(%v) : (i64) -> ()
1451    // CHECK:      %[[C2:.*]] = arith.constant 2 : i64
1452    // CHECK: "test.foo"
1453    // CHECK-NEXT: "test.bar"(%[[C2]]) : (i64) -> ()
1454  }
1455  return
1456}
1457
1458// -----
1459
1460// CHECK-LABEL: func @execute_region_elim
// Input: a single-block scf.execute_region nested in an affine.for. The
// expected output below inlines its body, so the "test.val" result feeds
// "test.bar" directly.
1461func.func @execute_region_elim() {
1462  affine.for %i = 0 to 100 {
1463    "test.foo"() : () -> ()
1464    %v = scf.execute_region -> i64 {
1465      %x = "test.val"() : () -> i64
1466      scf.yield %x : i64
1467    }
1468    "test.bar"(%v) : (i64) -> ()
1469  }
1470  return
1471}
1472
1473// CHECK-NEXT:     affine.for %arg0 = 0 to 100 {
1474// CHECK-NEXT:       "test.foo"() : () -> ()
1475// CHECK-NEXT:       %[[VAL:.*]] = "test.val"() : () -> i64
1476// CHECK-NEXT:       "test.bar"(%[[VAL]]) : (i64) -> ()
1477// CHECK-NEXT:     }
1478
1479// -----
1480
1481// CHECK-LABEL: func @func_execute_region_elim
// Input: a multi-block scf.execute_region directly in a function body. The
// expected output below inlines its blocks into the function's CFG: the
// region's merge block becomes a function block whose argument feeds
// "test.bar". The else path's cf.br must forward the value produced by
// "test.val2", so the expectations reuse the `y` capture instead of defining
// a fresh one (which would accept any value).
1482func.func @func_execute_region_elim() {
1483    "test.foo"() : () -> ()
1484    %v = scf.execute_region -> i64 {
1485      %c = "test.cmp"() : () -> i1
1486      cf.cond_br %c, ^bb2, ^bb3
1487    ^bb2:
1488      %x = "test.val1"() : () -> i64
1489      cf.br ^bb4(%x : i64)
1490    ^bb3:
1491      %y = "test.val2"() : () -> i64
1492      cf.br ^bb4(%y : i64)
1493    ^bb4(%z : i64):
1494      scf.yield %z : i64
1495    }
1496    "test.bar"(%v) : (i64) -> ()
1497  return
1498}
1499
1500// CHECK-NOT: execute_region
1501// CHECK:     "test.foo"
1502// CHECK:     %[[cmp:.+]] = "test.cmp"
1503// CHECK:     cf.cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
1504// CHECK:   ^[[bb1]]:
1505// CHECK:     %[[x:.+]] = "test.val1"
1506// CHECK:     cf.br ^[[bb3:.+]](%[[x]] : i64)
1507// CHECK:   ^[[bb2]]:
1508// CHECK:     %[[y:.+]] = "test.val2"
1509// CHECK:     cf.br ^[[bb3]](%[[y]] : i64)
1510// CHECK:   ^[[bb3]](%[[z:.+]]: i64):
1511// CHECK:     "test.bar"(%[[z]])
1512// CHECK:     return
1513
1514// -----
1515
1516// CHECK-LABEL: func @func_execute_region_elim_multi_yield
// Input: a multi-block scf.execute_region whose two exit blocks each end in
// their own scf.yield. The expected output below inlines it into the
// function's CFG, turning each yield into a cf.br to a new merge block whose
// argument feeds "test.bar". The else path's cf.br must forward the value
// produced by "test.val2", so the expectations reuse the `y` capture instead
// of defining a fresh one (which would accept any value).
1517func.func @func_execute_region_elim_multi_yield() {
1518    "test.foo"() : () -> ()
1519    %v = scf.execute_region -> i64 {
1520      %c = "test.cmp"() : () -> i1
1521      cf.cond_br %c, ^bb2, ^bb3
1522    ^bb2:
1523      %x = "test.val1"() : () -> i64
1524      scf.yield %x : i64
1525    ^bb3:
1526      %y = "test.val2"() : () -> i64
1527      scf.yield %y : i64
1528    }
1529    "test.bar"(%v) : (i64) -> ()
1530  return
1531}
1532
1533// CHECK-NOT: execute_region
1534// CHECK:     "test.foo"
1535// CHECK:     %[[cmp:.+]] = "test.cmp"
1536// CHECK:     cf.cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
1537// CHECK:   ^[[bb1]]:
1538// CHECK:     %[[x:.+]] = "test.val1"
1539// CHECK:     cf.br ^[[bb3:.+]](%[[x]] : i64)
1540// CHECK:   ^[[bb2]]:
1541// CHECK:     %[[y:.+]] = "test.val2"
1542// CHECK:     cf.br ^[[bb3]](%[[y]] : i64)
1543// CHECK:   ^[[bb3]](%[[z:.+]]: i64):
1544// CHECK:     "test.bar"(%[[z]])
1545// CHECK:     return
1546
1547// -----
1548
1549// CHECK-LABEL: func @canonicalize_parallel_insert_slice_indices(
1550//  CHECK-SAME:     %[[arg0:.*]]: tensor<1x5xf32>, %[[arg1:.*]]: tensor<?x?xf32>
// Input: tensor.dim is queried on the result of an scf.forall whose shared_outs
// operand is %arg1. The embedded expectations require the dim to be read from
// %arg1 directly and returned.
1551func.func @canonicalize_parallel_insert_slice_indices(
1552    %arg0 : tensor<1x5xf32>, %arg1: tensor<?x?xf32>, %num_threads : index) -> index
1553{
1554  // CHECK: %[[c1:.*]] = arith.constant 1 : index
1555  %c1 = arith.constant 1 : index
1556
1557  %2 = scf.forall (%tidx) in (%num_threads) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
1558    scf.forall.in_parallel {
1559      tensor.parallel_insert_slice %arg0 into %o[%tidx, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<?x?xf32>
1560    }
1561  }
1562
1563  // CHECK: %[[dim:.*]] = tensor.dim %[[arg1]], %[[c1]]
1564  %dim = tensor.dim %2, %c1 : tensor<?x?xf32>
1565  // CHECK: return %[[dim]]
1566  return %dim : index
1567}
1568
1569// -----
1570
1571// CHECK-LABEL: func @forall_fold_control_operands
// Input: the loop's second upper bound %dim1 is the size of the static
// dimension (10) of tensor<?x10xf32>. The expected output below shows that
// bound folded to the constant 10. The loop is printed in the `in (...)` form,
// presumably because all lower bounds are 0 and all steps are 1.
1572func.func @forall_fold_control_operands(
1573    %arg0 : tensor<?x10xf32>, %arg1: tensor<?x10xf32>) -> tensor<?x10xf32> {
1574  %c0 = arith.constant 0 : index
1575  %c1 = arith.constant 1 : index
1576  %dim0 = tensor.dim %arg0, %c0 : tensor<?x10xf32>
1577  %dim1 = tensor.dim %arg0, %c1 : tensor<?x10xf32>
1578
1579  %result = scf.forall (%i, %j) = (%c0, %c0) to (%dim0, %dim1)
1580      step (%c1, %c1) shared_outs(%o = %arg1) -> (tensor<?x10xf32>) {
1581    %slice = tensor.extract_slice %arg1[%i, %j] [1, 1] [1, 1]
1582      : tensor<?x10xf32> to tensor<1x1xf32>
1583
1584    scf.forall.in_parallel {
1585      tensor.parallel_insert_slice %slice into %o[%i, %j] [1, 1] [1, 1]
1586        : tensor<1x1xf32> into tensor<?x10xf32>
1587    }
1588  }
1589
1590  return %result : tensor<?x10xf32>
1591}
1592// CHECK: forall (%{{.*}}, %{{.*}}) in (%{{.*}}, 10)
1593
1594// -----
1595
// Input: an scf.forall from (0, 0) to (1, 1) with step (8, 8) runs exactly one
// iteration and has no mapping attribute. The expected output below inlines
// the body: no scf.forall remains, and the parallel_insert_slice becomes a
// tensor.insert_slice into the tensor.empty init.
1596func.func @inline_forall_loop(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
1597  %c8 = arith.constant 8 : index
1598  %c0 = arith.constant 0 : index
1599  %c1 = arith.constant 1 : index
1600  %cst = arith.constant 0.000000e+00 : f32
1601  %0 = tensor.empty() : tensor<8x8xf32>
1602  %1 = scf.forall (%i, %j) = (%c0, %c0) to (%c1, %c1)
1603        step (%c8, %c8) shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
1604    %slice = tensor.extract_slice %out_[%i, %j] [2, 3] [1, 1]
1605      : tensor<8x8xf32> to tensor<2x3xf32>
1606    %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<2x3xf32>)
1607          -> tensor<2x3xf32>
1608    scf.forall.in_parallel {
1609      tensor.parallel_insert_slice %fill into %out_[%i, %j] [2, 3] [1, 1]
1610        : tensor<2x3xf32> into tensor<8x8xf32>
1611    }
1612  }
1613  return %1 : tensor<8x8xf32>
1614}
1615// CHECK-LABEL: @inline_forall_loop
1616// CHECK-NOT:     scf.forall
1617// CHECK:         %[[OUT:.*]] = tensor.empty
1618
1619// CHECK-NEXT:    %[[SLICE:.*]] = tensor.extract_slice %[[OUT]]
1620// CHECK-SAME:      : tensor<8x8xf32> to tensor<2x3xf32>
1621
1622// CHECK-NEXT:    %[[FILL:.*]] = linalg.fill
1623// CHECK-SAME:      outs(%[[SLICE]]
1624
1625// CHECK-NEXT:    tensor.insert_slice %[[FILL]]
1626// CHECK-SAME:      : tensor<2x3xf32> into tensor<8x8xf32>
1627
1628// -----
1629
// Input: a single-iteration scf.forall (from (0, 4) to (1, 5), step (8, 8))
// that carries a #gpu.thread mapping. The expected output below keeps the
// scf.forall, but the slice offsets use the constant induction values 0 and 4.
1630func.func @do_not_inline_distributed_forall_loop(
1631    %in: tensor<8x8xf32>) -> tensor<8x8xf32> {
1632  %cst = arith.constant 0.000000e+00 : f32
1633  %0 = tensor.empty() : tensor<8x8xf32>
1634  %1 = scf.forall (%i, %j) = (0, 4) to (1, 5) step (8, 8)
1635      shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
1636    %slice = tensor.extract_slice %out_[%i, %j] [2, 3] [1, 1]
1637      : tensor<8x8xf32> to tensor<2x3xf32>
1638    %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<2x3xf32>)
1639          -> tensor<2x3xf32>
1640    scf.forall.in_parallel {
1641      tensor.parallel_insert_slice %fill into %out_[%i, %j] [2, 3] [1, 1]
1642        : tensor<2x3xf32> into tensor<8x8xf32>
1643    }
1644  }{ mapping = [#gpu.thread<y>, #gpu.thread<x>] }
1645  return %1 : tensor<8x8xf32>
1646}
1647// CHECK-LABEL: @do_not_inline_distributed_forall_loop
1648// CHECK: scf.forall
1649// CHECK:   tensor.extract_slice %{{.*}}[0, 4] [2, 3] [1, 1]
1650// CHECK:   tensor.parallel_insert_slice %{{.*}}[0, 4] [2, 3] [1, 1]
1651
1652// -----
1653
// Input: a zero-dimensional scf.forall (`() in ()`) with an explicitly empty
// mapping list. The expected output below contains no scf.forall.
1654func.func @inline_empty_loop_with_empty_mapping(
1655    %in: tensor<16xf32>) -> tensor<16xf32> {
1656  %cst = arith.constant 0.000000e+00 : f32
1657  %0 = tensor.empty() : tensor<16xf32>
1658  %1 = scf.forall () in () shared_outs (%out_ = %0) -> (tensor<16xf32>) {
1659    %slice = tensor.extract_slice %out_[0] [16] [1]
1660      : tensor<16xf32> to tensor<16xf32>
1661    %generic = linalg.generic {
1662        indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
1663        iterator_types = ["parallel"]}
1664        ins(%slice : tensor<16xf32>) outs(%0 : tensor<16xf32>) {
1665      ^bb0(%b0 : f32, %b1 : f32):
1666        %2 = arith.addf %b0, %b0 : f32
1667        linalg.yield %2 : f32
1668    } -> tensor<16xf32>
1669    scf.forall.in_parallel {
1670      tensor.parallel_insert_slice %generic into %out_[0] [16] [1]
1671        : tensor<16xf32> into tensor<16xf32>
1672    }
1673  }{ mapping = [] }
1674  return %1 : tensor<16xf32>
1675}
1676// CHECK-LABEL: func @inline_empty_loop_with_empty_mapping
1677//   CHECK-NOT:   scf.forall
1678
1679// -----
1680
// Input: the first loop dimension (0 to 1, step 8) runs a single iteration,
// while the second (0 to 16, step 8) runs more than one. The expected output
// below drops the first dimension and keeps a 1-D scf.forall from 0 to 16 with
// step 8.
1681func.func @collapse_one_dim_parallel(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
1682  %c8 = arith.constant 8 : index
1683  %c0 = arith.constant 0 : index
1684  %c1 = arith.constant 1 : index
1685  %c16 = arith.constant 16 : index
1686  %cst = arith.constant 0.000000e+00 : f32
1687  %0 = tensor.empty() : tensor<8x8xf32>
1688  %1 = scf.forall (%i, %j) = (0, %c0) to (1, %c16)
1689        step (8, %c8) shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
1690    %fill = linalg.fill ins(%cst : f32) outs(%out_ : tensor<8x8xf32>)
1691          -> tensor<8x8xf32>
1692    scf.forall.in_parallel {
1693      tensor.parallel_insert_slice %fill into %out_[%i, %j] [8, 8] [1, 1]
1694        : tensor<8x8xf32> into tensor<8x8xf32>
1695    }
1696  }
1697  return %1 : tensor<8x8xf32>
1698}
1699// CHECK-LABEL: @collapse_one_dim_parallel
1700// CHECK:         scf.forall (%[[ARG:.*]]) = (0) to (16) step (8)
1701// CHECK:           linalg.fill
1702// CHECK:           tensor.parallel_insert_slice
1703
1704// -----
1705
// Input: the second loop dimension runs from 16 to 16, so the scf.forall has
// zero iterations. The expected output below removes the loop and returns its
// shared_outs init (the tensor.empty) directly.
1706func.func @remove_empty_forall(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
1707  %c8 = arith.constant 8 : index
1708  %c0 = arith.constant 0 : index
1709  %c1 = arith.constant 1 : index
1710  %c16 = arith.constant 16 : index
1711  %cst = arith.constant 0.000000e+00 : f32
1712  %0 = tensor.empty() : tensor<8x8xf32>
1713  %1 = scf.forall (%i, %j) = (%c0, %c16) to (%c1, %c16)
1714        step (%c8, %c8) shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
1715    %fill = linalg.fill ins(%cst : f32) outs(%out_ : tensor<8x8xf32>)
1716          -> tensor<8x8xf32>
1717    scf.forall.in_parallel {
1718      tensor.parallel_insert_slice %fill into %out_[%i, %j] [8, 8] [1, 1]
1719        : tensor<8x8xf32> into tensor<8x8xf32>
1720    }
1721  }
1722  return %1 : tensor<8x8xf32>
1723}
1724// CHECK-LABEL: @remove_empty_forall
1725// CHECK-NOT:   scf.forall
1726// CHECK:       %[[EMPTY:.*]] = tensor.empty
1727// CHECK:       return %[[EMPTY]]
1728
1729// -----
1730
// Input: the shared_outs init is cast from static tensor<2xi32> to
// tensor<?xi32>, and the loop result is cast back to tensor<2xi32>. The
// expected output below contains no tensor.cast, and the parallel_insert_slice
// writes into tensor<2xi32>.
1731func.func @fold_tensor_cast_into_forall(
1732    %in: tensor<2xi32>, %out: tensor<2xi32>) -> tensor<2xi32> {
1733  %cst = arith.constant dense<[100500]> : tensor<1xi32>
1734
1735
1736  %out_cast = tensor.cast %out : tensor<2xi32> to tensor<?xi32>
1737  %result = scf.forall (%i) = (0) to (2) step (1)
1738      shared_outs (%out_ = %out_cast) -> tensor<?xi32> {
1739
1740    scf.forall.in_parallel {
1741      tensor.parallel_insert_slice %cst into %out_[%i] [1] [1]
1742        : tensor<1xi32> into tensor<?xi32>
1743    }
1744  }
1745  %result_cast = tensor.cast %result : tensor<?xi32> to tensor<2xi32>
1746  func.return %result_cast : tensor<2xi32>
1747}
1748// CHECK-LABEL: @fold_tensor_cast_into_forall
1749// CHECK-NOT:     tensor.cast
1750// CHECK:         parallel_insert_slice
1751// CHECK-SAME:      : tensor<1xi32> into tensor<2xi32>
1752// CHECK-NOT:     tensor.cast
1753
1754// -----
1755
// Input: the opposite direction to a static-to-dynamic fold. The shared_outs
// init is cast from dynamic tensor<?xi32> to static tensor<2xi32>. The
// expected output below keeps both tensor.cast ops around the loop.
1756func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall(
1757    %in: tensor<?xi32>, %out: tensor<?xi32>) -> tensor<?xi32> {
1758  %cst = arith.constant dense<[100500]> : tensor<1xi32>
1759
1760
1761  %out_cast = tensor.cast %out : tensor<?xi32> to tensor<2xi32>
1762  %result = scf.forall (%i) = (0) to (2) step (1)
1763      shared_outs (%out_ = %out_cast) -> tensor<2xi32> {
1764
1765    scf.forall.in_parallel {
1766      tensor.parallel_insert_slice %cst into %out_[%i] [1] [1]
1767        : tensor<1xi32> into tensor<2xi32>
1768    }
1769  }
1770  %result_cast = tensor.cast %result : tensor<2xi32> to tensor<?xi32>
1771  func.return %result_cast : tensor<?xi32>
1772}
1773// CHECK-LABEL: @do_not_fold_tensor_cast_
1774// CHECK:         tensor.cast
1775// CHECK:         parallel_insert_slice
1776// CHECK-SAME:      : tensor<1xi32> into tensor<2xi32>
1777// CHECK:         tensor.cast
1778
1779// -----
1780
// A shared_outs block argument that is only read inside the loop is expected
// to be removed from the scf.forall. Here %arg4 (tied to %arg1) is sliced but
// is never the destination of a tensor.parallel_insert_slice. In the expected
// output:
//   - its in-loop uses read %arg1 directly;
//   - the corresponding result (%2#0) is replaced by %arg1;
//   - the forall keeps a single shared_outs operand (tied to %arg2).
// The affine maps tile a 1-D dimension by %arg0:
//   - #map is the trip count ceildiv(dim, %arg0);
//   - #map1 is the tile offset iv * %arg0;
//   - #map2 is the tile size min(dim - iv * %arg0, %arg0).
1781#map = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
1782#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
1783#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
1784module {
1785  func.func @fold_iter_args_not_being_modified_within_scfforall(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
1786    %c0 = arith.constant 0 : index
1787    %cst = arith.constant 4.200000e+01 : f32
    // %0 has no uses in this function.
1788    %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
1789    %dim = tensor.dim %arg1, %c0 : tensor<?xf32>
1790    %1 = affine.apply #map()[%dim, %arg0]
1791    %2:2 = scf.forall (%arg3) in (%1) shared_outs(%arg4 = %arg1, %arg5 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
1792      %3 = affine.apply #map1(%arg3)[%arg0]
1793      %4 = affine.min #map2(%arg3)[%dim, %arg0]
      // %arg4 is only read here; it is never written in the terminator below.
1794      %extracted_slice0 = tensor.extract_slice %arg4[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
1795      %extracted_slice1 = tensor.extract_slice %arg5[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
1796      %5 = linalg.elemwise_unary ins(%extracted_slice0 : tensor<?xf32>) outs(%extracted_slice1 : tensor<?xf32>) -> tensor<?xf32>
1797      scf.forall.in_parallel {
        // Only %arg5 is written.
1798        tensor.parallel_insert_slice %5 into %arg5[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
1799      }
1800    }
1801    return %2#0, %2#1 : tensor<?xf32>, tensor<?xf32>
1802  }
1803}
1804// CHECK-LABEL: @fold_iter_args_not_being_modified_within_scfforall
1805//  CHECK-SAME:   (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
1806//       CHECK:    %[[RESULT:.*]] = scf.forall
1807//  CHECK-SAME:                       shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
1808//       CHECK:      %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
1809//       CHECK:      %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]]
1810//       CHECK:      %[[ELEM:.*]] = linalg.elemwise_unary ins(%[[OPERAND0]] : tensor<?xf32>) outs(%[[OPERAND1]] : tensor<?xf32>) -> tensor<?xf32>
1811//       CHECK:      scf.forall.in_parallel {
1812//  CHECK-NEXT:         tensor.parallel_insert_slice %[[ELEM]] into %[[ITER_ARG_5]]
1813//  CHECK-NEXT:      }
1814//  CHECK-NEXT:    }
1815//  CHECK-NEXT:    return %[[ARG1]], %[[RESULT]]
1816
1817// -----
1818
// Of the three scf.forall results, only %2#1 (tied to %arg6 = %arg2) is used.
// The shared_outs whose results are unused are expected to be dropped even
// though the loop body writes to them: %arg5 (= %arg1) and %arg7 (= %arg3).
// In the expected output:
//   - in-loop reads of %arg5 and %arg7 become slices of %arg1 and %arg3;
//   - the parallel_insert_slice ops targeting the dropped arguments disappear;
//   - the single remaining insert writes the elemwise_unary result into the
//     surviving shared_outs argument.
// The affine maps tile the dimension by %arg0:
//   - trip count ceildiv(dim, %arg0);
//   - offset iv * %arg0;
//   - size min(dim - iv * %arg0, %arg0).
1819#map = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
1820#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
1821#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
1822module {
1823  func.func @fold_iter_args_with_no_use_of_result_scfforall(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>) -> tensor<?xf32> {
1824    %cst = arith.constant 4.200000e+01 : f32
1825    %c0 = arith.constant 0 : index
1826    %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
1827    %dim = tensor.dim %arg1, %c0 : tensor<?xf32>
1828    %1 = affine.apply #map()[%dim, %arg0]
1829    %2:3 = scf.forall (%arg4) in (%1) shared_outs(%arg5 = %arg1, %arg6 = %arg2, %arg7 = %arg3) -> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
1830      %3 = affine.apply #map1(%arg4)[%arg0]
1831      %4 = affine.min #map2(%arg4)[%dim, %arg0]
1832      %extracted_slice = tensor.extract_slice %arg5[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
1833      %extracted_slice_0 = tensor.extract_slice %arg6[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
1834      %extracted_slice_1 = tensor.extract_slice %arg7[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
      // %extracted_slice_2 is never used; it only reads the (otherwise unused) %0.
1835      %extracted_slice_2 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
1836      %5 = linalg.elemwise_unary ins(%extracted_slice : tensor<?xf32>) outs(%extracted_slice_1 : tensor<?xf32>) -> tensor<?xf32>
1837      scf.forall.in_parallel {
        // The only insert whose target's result (%2#1) is used.
1838        tensor.parallel_insert_slice %5 into %arg6[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
1839        tensor.parallel_insert_slice %extracted_slice into %arg5[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
1840        tensor.parallel_insert_slice %extracted_slice_0 into %arg7[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
        // Offsets and sizes are swapped relative to the inserts above. This write
        // targets %arg7, whose result is unused, so it is expected to be removed.
1841        tensor.parallel_insert_slice %5 into %arg7[%4] [%3] [1] : tensor<?xf32> into tensor<?xf32>
1842      }
1843    }
1844    return %2#1 : tensor<?xf32>
1845  }
1846}
1847// CHECK-LABEL: @fold_iter_args_with_no_use_of_result_scfforall
1848//  CHECK-SAME:   (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>, %[[ARG3:.*]]: tensor<?xf32>) -> tensor<?xf32> {
1849//       CHECK:    %[[RESULT:.*]] = scf.forall
1850//  CHECK-SAME:                       shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
1851//       CHECK:      %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
1852//       CHECK:      %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]]
1853//       CHECK:      %[[ELEM:.*]] = linalg.elemwise_unary ins(%[[OPERAND0]] : tensor<?xf32>) outs(%[[OPERAND1]] : tensor<?xf32>) -> tensor<?xf32>
1854//       CHECK:      scf.forall.in_parallel {
1855//  CHECK-NEXT:         tensor.parallel_insert_slice %[[ELEM]] into %[[ITER_ARG_6]]
1856//  CHECK-NEXT:      }
1857//  CHECK-NEXT:    }
1858//  CHECK-NEXT:    return %[[RESULT]]
1859
1860// -----
1861
// scf.index_switch on a constant selector is expected to fold to the region it
// selects:
//   - selector 1 matches `case 1` and yields 1.0;
//   - selector 2 matches no case, falls through to `default`, and yields 42.0.
// The expected output is just the two f32 constants followed by the return.
1862func.func @index_switch_fold() -> (f32, f32) {
1863  %switch_cst = arith.constant 1: index
  // Selector 1: the `case 1` region is taken.
1864  %0 = scf.index_switch %switch_cst -> f32
1865  case 1 {
1866    %y = arith.constant 1.0 : f32
1867    scf.yield %y : f32
1868  }
1869  default {
1870    %y = arith.constant 42.0 : f32
1871    scf.yield %y : f32
1872  }
1873
1874  %switch_cst_2 = arith.constant 2: index
  // Selector 2: no matching case, so the `default` region is taken.
1875  %1 = scf.index_switch %switch_cst_2 -> f32
1876  case 0 {
1877    %y = arith.constant 0.0 : f32
1878    scf.yield %y : f32
1879  }
1880  default {
1881    %y = arith.constant 42.0 : f32
1882    scf.yield %y : f32
1883  }
1884
1885  return %0, %1 : f32, f32
1886}
1887
1888// CHECK-LABEL: func.func @index_switch_fold()
1889//  CHECK-NEXT:   %[[c1:.*]] = arith.constant 1.000000e+00 : f32
1890//  CHECK-NEXT:   %[[c42:.*]] = arith.constant 4.200000e+01 : f32
1891//  CHECK-NEXT:   return %[[c1]], %[[c42]] : f32, f32
1892
1893// -----
1894
// Result-less variant of @index_switch_fold. The constant selector 1 matches
// no case (only `case 0` exists), so the `default` region is expected to be
// inlined. In the expected output, "test.op" is the first op of the function
// body.
1895func.func @index_switch_fold_no_res() {
1896  %c1 = arith.constant 1 : index
1897  scf.index_switch %c1
1898  case 0 {
1899    scf.yield
1900  }
1901  default {
1902    "test.op"() : () -> ()
1903    scf.yield
1904  }
1905  return
1906}
1907
1908// CHECK-LABEL: func.func @index_switch_fold_no_res()
1909//  CHECK-NEXT: "test.op"() : () -> ()
1910