xref: /llvm-project/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-binary.mlir (revision 3efac5c68ac3117e8488a7fa247e45951e52936f)
// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s

// Identity 2-D indexing map, used for both inputs and the output of every
// linalg.generic in this file.
#map = affine_map<(d0, d1) -> (d0, d1)>
// A linalg.generic with identity maps, all-parallel iterators and a body that
// yields `arith.addf` of its two inputs is expected to become linalg.add with
// the same ins/outs operands.
func.func @specialize_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.addf %in, %in_0 : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_add
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.add ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>

// A linalg.generic whose body yields `arith.subf %in, %in_0` (first input minus
// second input) is expected to become linalg.sub with the inputs in the same
// order.
func.func @specialize_sub(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.subf %in, %in_0 : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// The label includes `@` and `(` so it identifies only this function and not
// @specialize_sub_swapped_operands, whose name shares the same prefix.
// CHECK-LABEL: @specialize_sub(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.sub ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>

// The body subtracts in the opposite order (`arith.subf %in_0, %in`), so the
// specialized linalg.sub must list the generic's inputs swapped.
func.func @specialize_sub_swapped_operands(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.subf %in_0, %in : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_sub_swapped_operands
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.sub ins(%[[ARG1]], %[[ARG0]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>

// A linalg.generic whose body yields `arith.mulf` of its two inputs is expected
// to become linalg.mul with the same ins/outs operands.
func.func @specialize_mul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.mulf %in, %in_0 : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_mul
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>

// A linalg.generic whose body yields `arith.divf %in, %in_0` (first input
// divided by second input) is expected to become linalg.div with the inputs in
// the same order.
func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.divf %in, %in_0 : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_div
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>

// Transform script run by --transform-interpreter. It collects every payload
// op implementing the LinalgOp interface and applies
// transform.structured.specialize to that handle.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %linalg_ops = transform.structured.match interface{LinalgOp} in %root : (!transform.any_op) -> !transform.any_op
    // The handle to the specialized ops is not used further in this script.
    %specialized = transform.structured.specialize %linalg_ops : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
