// xref: /llvm-project/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir (revision 13bd41096286305ee603428f6adf161f52981827)
// RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-elementwise-to-linalg))" -split-input-file %s | FileCheck %s

// In-depth checking of the linalg.generic op for a very trivial case.
// Both operands and the result share the same type, so the first operand is
// expected to be reused as the `outs` init of the generic op.
// CHECK: #[[$MAP:.*]] = affine_map<() -> ()>
// CHECK-LABEL: func @addf_rank0
//  CHECK-SAME:   %[[ARG0:[0-9a-zA-Z]*]]: tensor<f32>
//  CHECK-SAME:   %[[ARG1:[0-9a-zA-Z]*]]: tensor<f32>
func.func @addf_rank0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
  //      CHECK: %{{.*}} = linalg.generic
  // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
  // CHECK-SAME: iterator_types = []
  // CHECK-SAME:  ins(%[[ARG0]], %[[ARG1]]
  // CHECK-SAME: outs(%[[ARG0]]
  //      CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32):
  //      CHECK:   %[[YIELD:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
  //      CHECK:   linalg.yield %[[YIELD]] : f32
  //      CHECK: } -> tensor<f32>
  %0 = arith.addf %arg0, %arg1 : tensor<f32>
  return %0 : tensor<f32>
}

// -----

// Check indexing maps and iterator types for the rank > 0 case: every operand
// and the result are accessed through the 1-D identity map, and the single
// loop is parallel.
// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @addf_rank1
//  CHECK-SAME:   %[[ARG0:[0-9a-zA-Z]*]]: tensor<?xf32>
//  CHECK-SAME:   %[[ARG1:[0-9a-zA-Z]*]]: tensor<?xf32>
func.func @addf_rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  //      CHECK: linalg.generic
  // CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP1]]]
  // CHECK-SAME: iterator_types = ["parallel"]
  // CHECK-SAME:  ins(%[[ARG0]], %[[ARG1]]
  // CHECK-SAME: outs(%[[ARG0]]
  %0 = arith.addf %arg0, %arg1 : tensor<?xf32>
  return %0 : tensor<?xf32>
}

// -----

// Check a unary op: the generic op takes a single tensor input, and that same
// operand is reused as the `outs` init since its type matches the result.
// CHECK-LABEL: func @exp
//  CHECK-SAME:   %[[ARG0:[0-9a-zA-Z]*]]: tensor<f32>
func.func @exp(%arg0: tensor<f32>) -> tensor<f32> {
  // CHECK: linalg.generic
  // CHECK-SAME:  ins(%[[ARG0]]
  // CHECK-SAME: outs(%[[ARG0]]
  // CHECK: ^bb0(%[[SCALAR:.*]]: f32, %{{.*}}: f32):
  // CHECK:   %[[YIELD:.*]] = math.exp %[[SCALAR]] : f32
  // CHECK:   linalg.yield %[[YIELD]] : f32
  %0 = math.exp %arg0 : tensor<f32>
  return %0 : tensor<f32>
}

// -----

// Check a case with varying operand types: an i1 predicate selects between
// two i32 values. The first operand whose type matches the i32 result
// (%arg1) is expected to be reused as the `outs` init.
// CHECK-LABEL: func @select
//  CHECK-SAME:   %[[ARG0:[0-9a-zA-Z]*]]: tensor<i1>
//  CHECK-SAME:   %[[ARG1:[0-9a-zA-Z]*]]: tensor<i32>
//  CHECK-SAME:   %[[ARG2:[0-9a-zA-Z]*]]: tensor<i32>
func.func @select(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
  // CHECK: linalg.generic
  // CHECK-SAME:  ins(%[[ARG0]], %[[ARG1]], %[[ARG2]]
  // CHECK-SAME: outs(%[[ARG1]]
  // CHECK: ^bb0(%[[PRED:.*]]: i1, %[[TRUE_VAL:.*]]: i32, %[[FALSE_VAL:.*]]: i32, %{{.*}}: i32):
  // CHECK:   arith.select %[[PRED]], %[[TRUE_VAL]], %[[FALSE_VAL]] : i32
  %0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<i32>
  return %0 : tensor<i32>
}

// -----

// Spot-check an op that requires copying attributes (here the `olt`
// predicate) properly to the created scalar op.
// Also checks proper tensor.empty usage: no operand has the i1 result type,
// so a fresh init tensor is created for the `outs` operand.
// CHECK-LABEL: func @cmpf(
//  CHECK-SAME:   %[[ARG0:[0-9a-zA-Z]*]]: tensor<f32>
//  CHECK-SAME:   %[[ARG1:[0-9a-zA-Z]*]]: tensor<f32>
func.func @cmpf(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<i1> {
  // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<i1>
  // CHECK: linalg.generic
  // CHECK-SAME:  ins(%[[ARG0]], %[[ARG1]]
  // CHECK-SAME: outs(%[[INIT]]
  // CHECK: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: i1):
  // CHECK: arith.cmpf olt, %{{.*}}, %{{.*}} : f32
  %0 = arith.cmpf olt, %arg0, %arg1 : tensor<f32>
  return %0 : tensor<i1>
}

// -----

// Check proper tensor.empty usage in a mixed static/dynamic shape case: only
// the dynamic dimensions (1, 2 and 5) are queried with tensor.dim on the first
// operand and passed to tensor.empty; static sizes stay in the result type.
// CHECK-LABEL: func @cmpf(
//  CHECK-SAME:   %[[ARG0:[0-9a-zA-Z]*]]: tensor<4x?x?x8x2x?xf32>
//  CHECK-SAME:   %[[ARG1:[0-9a-zA-Z]*]]: tensor<4x?x?x8x2x?xf32>
func.func @cmpf(%arg0: tensor<4x?x?x8x2x?xf32>, %arg1: tensor<4x?x?x8x2x?xf32>) -> tensor<4x?x?x8x2x?xi1> {
  // CHECK: %[[C1:.*]] = arith.constant 1 : index
  // CHECK: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<4x?x?x8x2x?xf32>
  // CHECK: %[[C2:.*]] = arith.constant 2 : index
  // CHECK: %[[D2:.*]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<4x?x?x8x2x?xf32>
  // CHECK: %[[C5:.*]] = arith.constant 5 : index
  // CHECK: %[[D5:.*]] = tensor.dim %[[ARG0]], %[[C5]] : tensor<4x?x?x8x2x?xf32>
  // CHECK: %[[INIT:.*]] = tensor.empty(%[[D1]], %[[D2]], %[[D5]]) : tensor<4x?x?x8x2x?xi1>
  // CHECK: linalg.generic
  // CHECK-SAME:  ins(%[[ARG0]], %[[ARG1]]
  // CHECK-SAME: outs(%[[INIT]]
  // CHECK: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: i1):
  // CHECK: arith.cmpf olt, %{{.*}}, %{{.*}} : f32
  %0 = arith.cmpf olt, %arg0, %arg1 : tensor<4x?x?x8x2x?xf32>
  return %0 : tensor<4x?x?x8x2x?xi1>
}