// Source: llvm-project/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir (revision 0a1569a400491e264060b8a6ff7b7f64e1865496)
// RUN: mlir-opt %s -split-input-file -allow-unregistered-dialect -pass-pipeline="builtin.module(func.func(linalg-detensorize))" | FileCheck %s

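// TODO: Detensoring breaks if %arg0 or %arg1 are passed into the function
// directly as tensors, so this test takes scalars and wraps them with
// tensor.from_elements instead. Fix that and pass the tensors directly.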
func.func @if_true_test(%arg0: i1, %arg1: i32) -> tensor<i32> attributes {} {
  %arg0_t = tensor.from_elements %arg0 : tensor<i1>
  %arg1_t = tensor.from_elements %arg1 : tensor<i32>

  %cst = arith.constant dense<10> : tensor<i32>
  // Widen the condition to i8 inside a 0-d linalg.generic and narrow it back
  // to i1. After detensoring, this round-trip is expected to fold away so the
  // branch uses %arg0 directly (see the cond_br CHECK below).
  %2 = tensor.empty() : tensor<i8>
  %3 = linalg.generic
    {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []}
    ins(%arg0_t : tensor<i1>)
    outs(%2 : tensor<i8>) {
  ^bb0(%arg2: i1, %arg3: i8):
    %10 = arith.extui %arg2 : i1 to i8
    linalg.yield %10 : i8
  } -> tensor<i8>
  %4 = tensor.extract %3[] : tensor<i8>
  %5 = arith.trunci %4 : i8 to i1
  // Both edges into ^bb2 carry a 0-d tensor<i32>. Detensoring should rewrite
  // these branch operands, and the ^bb2 block argument, to plain i32.
  cf.cond_br %5, ^bb1, ^bb2(%arg1_t : tensor<i32>)
^bb1:
  %6 = tensor.empty() : tensor<i32>
  %7 = linalg.generic
    {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []}
    ins(%arg1_t, %cst : tensor<i32>, tensor<i32>)
    outs(%6 : tensor<i32>) {
  ^bb0(%arg2: i32, %arg3: i32, %arg4: i32):
    %10 = arith.addi %arg2, %arg3 : i32
    linalg.yield %10 : i32
  } -> tensor<i32>
  cf.br ^bb2(%7 : tensor<i32>)
^bb2(%8: tensor<i32>):
  // The function still returns tensor<i32>, so a tensor.from_elements is
  // expected to be rebuilt from the scalar block argument right before return.
  return %8 : tensor<i32>
}

// The operands of the addi and of the final tensor.from_elements are pinned,
// so the test checks which scalars flow where, not just which ops appear.
// CHECK-LABEL:  func @if_true_test
// CHECK-SAME:     (%[[arg0:.*]]: i1, %[[arg1:.*]]: i32)
// CHECK-NEXT:     %[[cst:.*]] = arith.constant 10 : i32
// CHECK-NEXT:     cf.cond_br %[[arg0]], ^[[bb1:.*]], ^[[bb2:.*]](%[[arg1]] : i32)
// CHECK-NEXT:   ^[[bb1]]:
// CHECK-NEXT:     %[[add_res:.*]] = arith.addi %[[arg1]], %[[cst]] : i32
// CHECK-NEXT:     cf.br ^[[bb2]](%[[add_res]] : i32)
// CHECK-NEXT:   ^[[bb2]](%[[bb2_arg:.*]]: i32):
// CHECK-NEXT:     %[[func_res:.*]] = tensor.from_elements %[[bb2_arg]] : tensor<i32>
// CHECK-NEXT:     return %[[func_res]]