xref: /llvm-project/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir (revision 97069a86193a617a9e4cf742a29db6116b2bf449)
1// RUN: mlir-opt -test-linalg-elementwise-fusion-patterns=control-fusion-by-expansion %s -split-input-file | FileCheck %s
2
// Producer-side reshape that must NOT be fused. The expectations below require
// the tensor.collapse_shape to survive and the linalg.generic to keep consuming
// the collapsed 2-D tensor (no expansion to 3-D).
// NOTE(review): presumably the test control function declines this fusion
// because the collapse source is a function argument rather than a linalg op;
// confirm against the test pass implementation.
func.func @control_producer_reshape_fusion(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?x?xf32> {
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  // Merge the two leading dimensions of the 3-D input into one.
  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]]
      : tensor<?x?x?xf32> into tensor<?x?xf32>
  // Size the destination from the collapsed tensor's extents.
  %rows = tensor.dim %collapsed, %zero : tensor<?x?xf32>
  %cols = tensor.dim %collapsed, %one : tensor<?x?xf32>
  %empty = tensor.empty(%rows, %cols) : tensor<?x?xf32>
  // Elementwise add: the 1-D operand is indexed by d1 only, i.e. it is
  // broadcast along d0.
  %sum = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%collapsed, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
      outs(%empty : tensor<?x?xf32>) {
    ^bb0(%lhs : f32, %rhs : f32, %out : f32):
      %add = arith.addf %lhs, %rhs : f32
      linalg.yield %add : f32
  } -> tensor<?x?xf32>
  return %sum : tensor<?x?xf32>
}
// Expected output: the collapse_shape is kept and the generic is unchanged
// (2-D identity and broadcast maps, consuming the collapsed tensor).
//  CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
//  CHECK-DAG: #[[BCAST:.+]] = affine_map<(d0, d1) -> (d1)>
//      CHECK: func @control_producer_reshape_fusion
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?xf32>
//  CHECK-DAG:   %[[ZERO:.+]] = arith.constant 0 : index
//  CHECK-DAG:   %[[ONE:.+]] = arith.constant 1 : index
//      CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK-SAME:       {{\[}}[0, 1], [2]{{\]}} : tensor<?x?x?xf32> into tensor<?x?xf32>
//      CHECK:   %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME:       indexing_maps = [#[[IDENTITY]], #[[BCAST]], #[[IDENTITY]]]
// CHECK-SAME:       ins(%[[COLLAPSED]], %[[ARG1]] : tensor<?x?xf32>, tensor<?xf32>)
//      CHECK:   return %[[GENERIC]]
34
35// -----
36
// Consumer-side reshape that SHOULD be fused. A 2-D fill-like linalg.generic
// feeds a tensor.expand_shape whose only use is the `outs` init of a
// linalg.batch_matmul. The expectations below require the generic to be
// rewritten to produce the 3-D tensor directly (3-D identity map), and that
// result to feed the batch_matmul init with no intermediate value in between.
// NOTE(review): presumably the test control function allows this fusion
// because the expand_shape's sole user is a linalg init operand; confirm
// against the test pass implementation.
func.func @control_consumer_reshape_fusion(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %cst = arith.constant 0.0 : f32
  // Trailing extents of the batch_matmul init: dim 1 of the LHS and dim 2 of
  // the RHS.
  %d0 = tensor.dim %arg0, %c1 : tensor<1x?x?xf32>
  %d1 = tensor.dim %arg1, %c2 : tensor<1x?x?xf32>
  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  // Input-less 2-D generic that writes 0.0 to every element of %init.
  %fill = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      outs(%init : tensor<?x?xf32>) {
      ^bb0(%arg2: f32):
        linalg.yield %cst : f32
      } -> tensor<?x?xf32>
  // Prepend the unit batch dimension; this is the reshape expected to be fused
  // into %fill.
  %0 = tensor.expand_shape %fill [[0, 1], [2]] output_shape [1, %d0, %d1] : tensor<?x?xf32> into tensor<1x?x?xf32>
  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
      outs(%0 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
  return %1 : tensor<1x?x?xf32>
}
// Expected output: the fill generic now has a 3-D identity map and a 3-D
// init, and its result is used directly as the batch_matmul init.
//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
//      CHECK: func @control_consumer_reshape_fusion
//      CHECK:   %[[FILL:.+]] = linalg.generic
// CHECK-SAME:       indexing_maps = [#[[MAP]]]
// CHECK-SAME:       outs(%{{.+}} : tensor<1x?x?xf32>)
//      CHECK:   linalg.batch_matmul
// CHECK-SAME:       outs(%[[FILL]] : tensor<1x?x?xf32>)
63