xref: /llvm-project/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir (revision 441b672bbdc68ad88036f3e258759854c8283adb)
1// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(linalg-detensorize{aggressive-mode}))" | FileCheck %s -check-prefix=DET-ALL
2// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(linalg-detensorize))" | FileCheck %s -check-prefix=DET-CF
3
// Indexing maps shared by the linalg.generic ops below.
#map0 = affine_map<() -> ()>    // 0-d operand, 0-d iteration space
#map1 = affine_map<(i) -> ()>   // 0-d operand, 1-d iteration space
#map2 = affine_map<(i) -> (i)>  // identity over a 1-d iteration space

// Trait for an op over three 0-d operands with no loops.
#attrs = {indexing_maps = [#map0, #map0, #map0], iterator_types = []}

// Trait for reducing a 1-d input into a 0-d output.
#sum_reduction_attrs = {indexing_maps = [#map2, #map1], iterator_types = ["reduction"]}

// Trait for broadcasting a 0-d input into a 1-d output.
#broadcast_attrs = {indexing_maps = [#map1, #map2], iterator_types = ["parallel"]}
23
// While-style loop expressed with unstructured (cf) control flow.
// The header block sums the 10-element vector into a 0-d tensor and compares
// the sum against %limit. While sum < limit, the latch block broadcasts the
// sum back into a 10-element vector and branches to the header again; once
// the comparison fails, the exit block returns the sum.
func.func @main(%init_vec: tensor<10xi32>, %limit: tensor<i32>) -> tensor<i32> {
  cf.br ^header(%init_vec : tensor<10xi32>)

^header(%vec: tensor<10xi32>):  // 2 preds: entry block, ^latch
  // sum = reduce_add(vec).
  // NOTE(review): the accumulator comes from tensor.empty(), so its starting
  // value is unspecified; the test only inspects the IR after detensoring.
  %sum_init = tensor.empty() : tensor<i32>
  %sum = linalg.generic #sum_reduction_attrs
      ins(%vec : tensor<10xi32>)
      outs(%sum_init : tensor<i32>) {
    ^bb0(%elem: i32, %acc: i32):
      %next_acc = arith.addi %acc, %elem : i32
      linalg.yield %next_acc : i32
  } -> tensor<i32>

  // keep_going = sum < limit, computed as a 0-d tensor<i1> and then extracted.
  %cond_init = tensor.empty() : tensor<i1>
  %cond_tensor = linalg.generic #attrs
      ins(%sum, %limit : tensor<i32>, tensor<i32>)
      outs(%cond_init : tensor<i1>) {
    ^bb0(%lhs: i32, %rhs: i32, %unused: i1):
      %lt = arith.cmpi slt, %lhs, %rhs : i32
      linalg.yield %lt : i1
  } -> tensor<i1>
  %keep_going = tensor.extract %cond_tensor[] : tensor<i1>
  cf.cond_br %keep_going, ^latch(%sum : tensor<i32>), ^exit(%sum : tensor<i32>)

^latch(%latch_sum: tensor<i32>):  // pred: ^header
  // next_vec = broadcast(sum) to 10 elements, then loop back.
  %bcast_init = tensor.empty() : tensor<10xi32>
  %next_vec = linalg.generic #broadcast_attrs
      ins(%latch_sum : tensor<i32>)
      outs(%bcast_init : tensor<10xi32>) {
    ^bb0(%in: i32, %out: i32):
      linalg.yield %in : i32
  } -> tensor<10xi32>
  cf.br ^header(%next_vec : tensor<10xi32>)

^exit(%result: tensor<i32>):  // pred: ^header
  return %result : tensor<i32>
}
62
// Test aggressively detensoring all detensorable ops.
64//
65// DET-ALL-LABEL: func @main
66// DET-ALL-SAME:    (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor<i32>)
67// DET-ALL:         cf.br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>)
68// DET-ALL:       ^[[bb1]](%{{.*}}: tensor<10xi32>)
69// DET-ALL:         tensor.empty() : tensor<i32>
70// DET-ALL:         linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
71// DET-ALL:         ^bb0(%{{.*}}: i32, %{{.*}}: i32):
72// DET-ALL:           %{{.*}} = arith.addi %{{.*}}, %{{.*}}
73// DET-ALL:           linalg.yield %{{.*}} : i32
74// DET-ALL:         } -> tensor<i32>
75// DET-ALL:         tensor.extract %{{.*}}[] : tensor<i32>
76// DET-ALL:         cmpi slt, %{{.*}}, %{{.*}} : i32
77// DET-ALL:         cf.cond_br %{{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
78// DET-ALL:       ^[[bb2]]:
79// DET-ALL:         tensor.from_elements %{{.*}} : tensor<i32>
80// DET-ALL:         tensor.empty() : tensor<10xi32>
81// DET-ALL:         linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
82// DET-ALL:         ^bb0(%{{.*}}: i32, %{{.*}}: i32):
83// DET-ALL:           linalg.yield %{{.*}} : i32
84// DET-ALL:         } -> tensor<10xi32>
85// DET-ALL:         cf.br ^[[bb1]](%{{.*}} : tensor<10xi32>)
86// DET-ALL:       ^[[bb3]]
87// DET-ALL:         tensor.from_elements %{{.*}} : tensor<i32>
88// DET-ALL:         return %{{.*}} : tensor<i32>
89// DET-ALL:       }
90
// Test detensoring only ops involved in control flow.
//
// DET-CF-LABEL: func @main
92// DET-CF-SAME:    (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor<i32>)
93// DET-CF:         cf.br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>)
// DET-CF:       ^[[bb1]](%{{.*}}: tensor<10xi32>)
// DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
// DET-CF:         tensor.extract %{{.*}}[] : tensor<i32>
// DET-CF:         cmpi slt, %{{.*}}, %{{.*}} : i32
// DET-CF:         cf.cond_br %{{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
// DET-CF:       ^[[bb2]]:
// DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
// DET-CF:         cf.br ^[[bb1]](%{{.*}} : tensor<10xi32>)
// DET-CF:       ^[[bb3]]:
103// DET-CF:         return %{{.*}} : tensor<i32>
104// DET-CF:       }
105