xref: /llvm-project/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir (revision 1f5335c1db5d54b4465677c224b48e0ffc78e6d9)
1// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-collapsing -split-input-file | FileCheck %s
2
3// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
4// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
5
6// CHECK-LABEL: func @reshape
7// CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>, %[[SZ0:.*]]: index)
8//      CHECK: %[[C112:.*]] = arith.constant 112 : index
9//      CHECK: %[[C0:.*]] = arith.constant 0 : index
10//      CHECK: %[[RI:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] : tensor<?x112x16xf32> into tensor<?x16xf32>
11//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
12// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
13// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x16xf32>)
14//      CHECK: %[[DIM:.*]] = tensor.dim %[[R]], %[[C0]] : tensor<?x16xf32>
15//      CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C112]] : index
16//      CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[\[}}0, 1], [2]] output_shape [%[[VAL_1]], 112, 16] : tensor<?x16xf32> into tensor<?x112x16xf32>
17//      CHECK: return %[[RR]] : tensor<?x112x16xf32>
// Test: a producer tensor.expand_shape feeding an elementwise linalg.generic
// is pushed below the generic. Per the CHECK lines above, the 3-D generic is
// rewritten to operate on the collapsed 2-D operands, and a single
// tensor.expand_shape is emitted on the result, with the dynamic output size
// recomputed as tensor.dim(%R, 0) divided by the static factor 112.
18func.func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>, %sz0: index) -> tensor<?x112x16xf32> {
  // Producer reshape (2-D -> 3-D, dynamic leading size) that the fusion
  // pattern should move past the consumer below.
19  %0 = tensor.expand_shape %A [[0, 1], [2]] output_shape [%sz0, 112, 16]
20      : tensor<?x16xf32> into tensor<?x112x16xf32>
  // Elementwise subtraction; %B is indexed only by d2, i.e. broadcast across
  // d0 and d1, so collapsing [d0, d1] is legal here.
21  %2 = linalg.generic {indexing_maps = [
22    affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
23    affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
24    iterator_types = ["parallel", "parallel", "parallel"]}
25  ins(%0, %B : tensor<?x112x16xf32>, tensor<16xf32>)
26  outs(%init : tensor<?x112x16xf32>) {
27  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
28    %s = arith.subf %arg1, %arg2 : f32
29    linalg.yield %s : f32
30  } -> tensor<?x112x16xf32>
31  return %2 : tensor<?x112x16xf32>
32}
33
34// -----
35
36// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
37// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
38
39// CHECK-LABEL: func @reshape_multiple
40// CHECK-SAME: (%[[A:.*]]: tensor<12544x16xf32>, %[[B:.*]]: tensor<12544x16xf32>, %[[C:.*]]: tensor<16xf32>)
41//      CHECK: %[[I:.*]] = tensor.empty() : tensor<112x112x16xf32>
42//      CHECK: %[[RI:.*]] = tensor.collapse_shape %[[I]] {{\[}}[0, 1], [2]] : tensor<112x112x16xf32> into tensor<12544x16xf32>
43//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
44// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
45// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] : tensor<12544x16xf32>, tensor<12544x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<12544x16xf32>)
46//      CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[}}[0, 1], [2]] output_shape [112, 112, 16] : tensor<12544x16xf32> into tensor<112x112x16xf32>
47//      CHECK: return %[[RR]] : tensor<112x112x16xf32>
// Test: two tensor.expand_shape producers of the same generic (with identical
// reassociation) are both folded away. Per the CHECK lines above, the generic
// is collapsed to 2-D once and a single tensor.expand_shape is emitted on the
// result; all shapes are static, so output_shape is constant.
48func.func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>,
49  %C: tensor<16xf32>) -> tensor<112x112x16xf32> {
  // Two producer reshapes sharing the [[0, 1], [2]] reassociation; both
  // should be fused into the consumer.
50  %0 = tensor.expand_shape %A [[0, 1], [2]] output_shape [112, 112, 16]
51      : tensor<12544x16xf32> into tensor<112x112x16xf32>
52  %1 = tensor.expand_shape %B [[0, 1], [2]] output_shape [112, 112, 16]
53      : tensor<12544x16xf32> into tensor<112x112x16xf32>
54  %2 = tensor.empty() : tensor<112x112x16xf32>
  // Computes (A - B) * C elementwise; %C is indexed only by d2 (broadcast
  // across d0 and d1), so merging [d0, d1] remains legal.
55  %3 = linalg.generic {indexing_maps = [
56    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
57    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
58    affine_map<(d0, d1, d2) -> (d2)>,
59    affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
60    iterator_types = ["parallel", "parallel", "parallel"]}
61  ins(%0, %1, %C : tensor<112x112x16xf32>, tensor<112x112x16xf32>, tensor<16xf32>)
62  outs(%2 : tensor<112x112x16xf32>) {
63  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
64    %s = arith.subf %arg1, %arg2 : f32
65    %m = arith.mulf %s, %arg3 : f32
66    linalg.yield %m : f32
67  } -> tensor<112x112x16xf32>
68  return %3 : tensor<112x112x16xf32>
69}
70
71// -----
72
73// Negative test: since the second source is broadcast using only d1, the
74// d0 and d1 dimensions cannot be merged.
75// CHECK-LABEL: func @reshape_negative
76// CHECK: tensor.expand_shape {{.*}} {{\[\[}}0, 1], [2]] output_shape [112, 112, 16] : tensor<12544x16xf32> into tensor<112x112x16xf32>
77// CHECK: linalg.generic
78// CHECK: } -> tensor<112x112x16xf32>
79func.func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<112x112x16xf32> {
  // This expand_shape must survive: per the CHECK lines above, the generic
  // stays 3-D and the reshape is not pushed below it.
80  %20 = tensor.expand_shape %A [[0, 1], [2]] output_shape [112, 112, 16]
81      : tensor<12544x16xf32> into tensor<112x112x16xf32>
82  %21 = tensor.empty() : tensor<112x112x16xf32>
  // %B's indexing map reads d1 alone, so d1 carries real (non-broadcast)
  // indexing for this operand and [d0, d1] cannot be collapsed.
83  %22 = linalg.generic {indexing_maps = [
84    affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1)>,
85    affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
86    iterator_types = ["parallel", "parallel", "parallel"]}
87  ins(%20, %B : tensor<112x112x16xf32>, tensor<112xf32>)
88  outs(%21 : tensor<112x112x16xf32>) {
89  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
90    %s = arith.subf %arg1, %arg2 : f32
91    linalg.yield %s : f32
92  } -> tensor<112x112x16xf32>
93  return %22 : tensor<112x112x16xf32>
94}
95
96// -----
97
// Test: fusion preserves distinct element types. The generic mixes an i32
// input with f32 inputs/output; per the CHECK lines below, after fusion the
// collapsed generic takes tensor<6x5xi32> (not f32) ins and produces
// tensor<6x5xf32>, which is then expanded back to tensor<2x3x5xf32>.
98func.func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
99    %arg2 : tensor<5xf32>) -> tensor<2x3x5xf32> {
100  %cst_6 = arith.constant 1.000000e+00 : f32
101  %cst_7 = arith.constant 7.000000e+00 : f32
102  %cst_8 = arith.constant 1.1920929E-7 : f32
  // Producer reshape on the i32 operand to be pushed below the generic.
103  %25 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 3, 5]
104      : tensor<6x5xi32> into tensor<2x3x5xi32>
105  %26 = tensor.empty() : tensor<2x3x5xf32>
  // %arg1 and %arg2 are indexed only by d2 (broadcast across d0, d1), so
  // collapsing [d0, d1] is legal; the body converts i32 -> f32 via sitofp.
106  %28 = linalg.generic {
107      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
108                       affine_map<(d0, d1, d2) -> (d2)>,
109                       affine_map<(d0, d1, d2) -> (d2)>,
110                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
111      iterator_types = ["parallel", "parallel", "parallel"]}
112      ins(%25, %arg1, %arg2 : tensor<2x3x5xi32>, tensor<5xf32>, tensor<5xf32>)
113      outs(%26 : tensor<2x3x5xf32>) {
114      ^bb0(%arg6: i32, %arg7: f32, %arg8: f32, %arg9: f32):
115        %29 = arith.sitofp %arg6 : i32 to f32
116        %30 = arith.addf %arg7, %cst_8 : f32
117        %31 = arith.divf %cst_7, %30 : f32
118        %32 = arith.divf %cst_6, %31 : f32
119        %33 = arith.mulf %29, %32 : f32
120        %34 = arith.addf %33, %arg8 : f32
121        linalg.yield %34 : f32
122      } -> tensor<2x3x5xf32>
123  return %28 : tensor<2x3x5xf32>
124}
125// CHECK-LABEL: func @type_correctness
126//       CHECK:   %[[OP:.+]] = linalg.generic
127//  CHECK-SAME:   ins(%{{.+}}, %{{.+}}, %{{.+}} : tensor<6x5xi32>, tensor<5xf32>, tensor<5xf32>)
128//  CHECK-SAME:   outs(%{{.+}} : tensor<6x5xf32>)
129//       CHECK:   tensor.expand_shape %[[OP]]
130//  CHECK-SAME:   tensor<6x5xf32> into tensor<2x3x5xf32>
131