// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(linalg-detensorize{aggressive-mode}))" | FileCheck %s -check-prefix=DET-ALL
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(linalg-detensorize))" | FileCheck %s -check-prefix=DET-CF
//
// The function below exercises linalg-detensorize on a loop built from
// unstructured control flow (cf.br / cf.cond_br) whose blocks contain
// linalg.generic ops. It is run twice: once with aggressive-mode enabled and
// once with the pass's default options. Each run has its own set of expected
// output patterns further down in this file.

// Indexing map for a rank-0 operand of a rank-0 (no-loop) op.
#map0 = affine_map<() -> ()>
// Maps a 1-D iteration space onto a rank-0 operand.
#map1 = affine_map<(i) -> ()>
// Identity map for a 1-D iteration space.
#map2 = affine_map<(i) -> (i)>

// Rank-0 op with three rank-0 operands (two inputs, one output); used below
// for the loop-condition comparison.
#attrs = {
  indexing_maps = [#map0, #map0, #map0],
  iterator_types = []
}

// Reduces a 1-D input into a rank-0 output.
#sum_reduction_attrs = {
  indexing_maps = [#map2, #map1],
  iterator_types = ["reduction"]
}


// Broadcasts a rank-0 input into a 1-D output.
#broadcast_attrs = {
  indexing_maps = [#map1, #map2],
  iterator_types = ["parallel"]
}

// Loop structure: ^bb1 sums the current tensor<10xi32> into a tensor<i32> and
// compares that sum against %farg1. While sum < %farg1 (signed), ^bb2
// broadcasts the sum back into a tensor<10xi32> and branches back to ^bb1;
// otherwise ^bb3 returns the sum.
func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attributes {} {
  cf.br ^bb1(%farg0 : tensor<10xi32>)

^bb1(%0: tensor<10xi32>):  // 2 preds: ^bb0, ^bb2
  // Sum reduction: %a is the current input element, %x the accumulator taken
  // from the output operand.
  // NOTE(review): the accumulator is seeded from tensor.empty(), whose contents
  // are unspecified; presumably acceptable because only the transformed IR is
  // checked here, not computed values.
  %1 = tensor.empty() : tensor<i32>
  %2 = linalg.generic #sum_reduction_attrs
    ins(%0: tensor<10xi32>)
    outs(%1: tensor<i32>) {
      ^bb(%a: i32, %x: i32):
        %b = arith.addi %x, %a : i32
        linalg.yield %b : i32
  } -> tensor<i32>

  // Loop condition: signed sum < %farg1, produced as a tensor<i1> and then
  // extracted to an i1 for the conditional branch.
  %3 = tensor.empty() : tensor<i1>
  %4 = linalg.generic #attrs
    ins(%2, %farg1 : tensor<i32>, tensor<i32>)
    outs(%3 : tensor<i1>) {
    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):
      %8 = arith.cmpi slt, %arg0, %arg1 : i32
      linalg.yield %8 : i1
  } -> tensor<i1>
  %5 = tensor.extract %4[] : tensor<i1>
  cf.cond_br %5, ^bb2(%2 : tensor<i32>), ^bb3(%2 : tensor<i32>)

^bb2(%6: tensor<i32>):  // pred: ^bb1
  // Loop body: broadcast the rank-0 sum into every element of a new
  // tensor<10xi32> and feed it back to the loop header.
  %7 = tensor.empty() : tensor<10xi32>
  %9 = linalg.generic #broadcast_attrs
       ins(%6: tensor<i32>)
       outs(%7: tensor<10xi32>) {
    ^bb(%a: i32, %b: i32) :
      linalg.yield %a : i32
  } -> tensor<10xi32>

  cf.br ^bb1(%9 : tensor<10xi32>)

^bb3(%10: tensor<i32>):  // pred: ^bb1
  // Loop exit: return the last computed sum.
  return %10 : tensor<i32>
}

// Test aggressively detensoring all detensorable ops.
//
// Expected output in aggressive mode: the comparison is performed on an i32
// extracted from the reduction result, and a tensor<i32> is rebuilt with
// tensor.from_elements at the start of ^bb2 (before the broadcast) and of ^bb3
// (before the return).
//
// DET-ALL-LABEL: func @main
// DET-ALL-SAME:    (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor<i32>)
// DET-ALL:         cf.br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>)
// DET-ALL:       ^[[bb1]](%{{.*}}: tensor<10xi32>)
// DET-ALL:         tensor.empty() : tensor<i32>
// DET-ALL:         linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
// DET-ALL:         ^bb0(%{{.*}}: i32, %{{.*}}: i32):
// DET-ALL:           %{{.*}} = arith.addi %{{.*}}, %{{.*}}
// DET-ALL:           linalg.yield %{{.*}} : i32
// DET-ALL:         } -> tensor<i32>
// DET-ALL:         tensor.extract %{{.*}}[] : tensor<i32>
// DET-ALL:         cmpi slt, %{{.*}}, %{{.*}} : i32
// DET-ALL:         cf.cond_br %{{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
// DET-ALL:       ^[[bb2]]:
// DET-ALL:         tensor.from_elements %{{.*}} : tensor<i32>
// DET-ALL:         tensor.empty() : tensor<10xi32>
// DET-ALL:         linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
// DET-ALL:         ^bb0(%{{.*}}: i32, %{{.*}}: i32):
// DET-ALL:           linalg.yield %{{.*}} : i32
// DET-ALL:         } -> tensor<10xi32>
// DET-ALL:         cf.br ^[[bb1]](%{{.*}} : tensor<10xi32>)
// DET-ALL:       ^[[bb3]]
// DET-ALL:         tensor.from_elements %{{.*}} : tensor<i32>
// DET-ALL:         return %{{.*}} : tensor<i32>
// DET-ALL:       }

// Expected output with the default options: only the loop-condition
// computation is detensored (the comparison becomes a scalar cmpi on an
// extracted i32). The sum-reduction and broadcast linalg.generic ops, the
// loop-carried tensor<10xi32> value and the returned tensor<i32> stay on
// tensors. Block labels are captured rather than hard-coded, so these checks
// do not depend on the printer's block numbering. The bb[0-9]+ captures still
// require the true-successor label to be followed directly by ", ^", i.e. no
// branch operands on it.
//
// DET-CF-LABEL: func @main
// DET-CF-SAME:    (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor<i32>)
// DET-CF:         cf.br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>)
// DET-CF:       ^[[bb1]](%{{.*}}: tensor<10xi32>)
// DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
// DET-CF:         tensor.extract %{{.*}}[] : tensor<i32>
// DET-CF:         cmpi slt, %{{.*}}, %{{.*}} : i32
// DET-CF:         cf.cond_br %{{.*}}, ^[[bb2:bb[0-9]+]], ^[[bb3:bb[0-9]+]]
// DET-CF:       ^[[bb2]]:
// DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
// DET-CF:         cf.br ^[[bb1]](%{{.*}} : tensor<10xi32>)
// DET-CF:       ^[[bb3]]:
// DET-CF:         return %{{.*}} : tensor<i32>
// DET-CF:       }