xref: /llvm-project/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir (revision 8f4d5a32ace7f858881b6a59663ff6596b162dbc)
1// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-reassociative-reshape-folding %s | FileCheck %s
2
// Test: an expand_shape that consumes a rank-reducing extract_slice.
// The input slices a 4-D tensor down to tensor<?x1x5xf32> (one of the two
// unit dimensions is dropped) and then re-expands the result to
// tensor<?x1x1x5xf32> twice, using two different reassociations.
// The directives below expect each expand_shape to be replaced by an
// extract_slice that yields tensor<?x1x1x5xf32> directly. The two slices are
// matched with DAG directives, so their relative order is not constrained.
3// CHECK-LABEL: func @expand_shape_of_rank_reducing_extract(
4//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x?xf32>
5//   CHECK-DAG:   %[[extract1:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0]
6//   CHECK-SAME:    [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
7//   CHECK-DAG:   %[[extract2:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0]
8//   CHECK-SAME:    [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
9//       CHECK:   return %[[extract1]], %[[extract2]]
10func.func @expand_shape_of_rank_reducing_extract(
11    %t: tensor<?x?x?x?xf32>, %idx: index)
12  -> (tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>)
13{
  // Rank-reducing slice: sizes [%idx, 1, 1, 5] but a 3-D result type.
14  %0 = tensor.extract_slice %t[0, 0, 0, 0][%idx, 1, 1, 5][1, 1, 1, 1]
15      : tensor<?x?x?x?xf32> to tensor<?x1x5xf32>
16  %c0 = arith.constant 0 : index
  // Dynamic extent of the slice's leading dimension, reused as the
  // leading entry of both output_shape lists below.
17  %sz0 = tensor.dim %0, %c0 : tensor<?x1x5xf32>
  // First variant: source dim 1 is expanded into result dims 1 and 2.
18  %1 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%sz0, 1, 1, 5]
19      : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
  // Second variant: source dim 0 is expanded into result dims 0 and 1.
20  %2 = tensor.expand_shape %0 [[0, 1], [2], [3]] output_shape [%sz0, 1, 1, 5]
21      : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
22  return %1, %2 : tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>
23}
24
25// -----
26
// Test: a collapse_shape that only removes unit dimensions of an
// extract_slice result.
// The input slices a 4-D tensor to tensor<1x?x1x?xf32> and collapses it to
// tensor<?x?xf32> with reassociation [[0, 1], [2, 3]], i.e. each unit
// dimension is merged into its dynamic neighbour.
// The directives below expect a single rank-reducing extract_slice that
// produces tensor<?x?xf32> directly, with the same offsets and the same
// static unit sizes, and expect that slice to be returned.
27// CHECK-LABEL: func @unpadding_collapse_of_extract_slice(
28//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x?xf32>
29//  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
30//  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
31//       CHECK:   %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0]
32//  CHECK-SAME:     [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?xf32>
33//       CHECK:   return %[[extract]]
34func.func @unpadding_collapse_of_extract_slice(
35    %t: tensor<?x?x?x?xf32>, %x: index, %y: index)
36  -> tensor<?x?xf32> {
37  %c1 = arith.constant 1 : index
38  %c3 = arith.constant 3 : index
  // Extents of source dims 1 and 3; used as the dynamic slice sizes.
39  %sz0 = tensor.dim %t, %c1 : tensor<?x?x?x?xf32>
40  %sz1 = tensor.dim %t, %c3 : tensor<?x?x?x?xf32>
  // Slice with static size 1 in dims 0 and 2; the result keeps all 4 dims.
41  %0 = tensor.extract_slice %t[%x, %y, 0, 0] [1, %sz0, 1, %sz1] [1, 1, 1, 1]
42      : tensor<?x?x?x?xf32> to tensor<1x?x1x?xf32>
  // Collapse that merges each unit dim with the dynamic dim after it.
43  %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
44      : tensor<1x?x1x?xf32> into tensor<?x?xf32>
45  return %1 : tensor<?x?xf32>
46}
47
48// -----
49
// Negative test: a collapse_shape of an extract_slice that does more than
// remove unit dimensions.
// The slice is already rank-reducing (4-D source, tensor<?x?x?xf32> result)
// and the collapse merges result dims 1 and 2, neither of which is
// statically a unit dimension.
// The directives below expect both the extract_slice and the collapse_shape
// to appear in the output with their original types and reassociation.
50// CHECK-LABEL: func @non_unpadding_collapse_of_extract_slice(
51//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x?xf32>
52//  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
53//  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
54//  CHECK-SAME:     %[[sz:[a-zA-Z0-9_]+]]: index
55//       CHECK:   %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0]
56//  CHECK-SAME:     [%{{.*}}, %{{.*}}, %[[sz]], 1] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?xf32>
57//       CHECK:   %[[collapse:.*]] = tensor.collapse_shape %[[extract]] {{\[}}[0], [1, 2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
58//       CHECK:   return %[[collapse]]
59func.func @non_unpadding_collapse_of_extract_slice(
60    %t: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
61  -> tensor<?x?xf32> {
62  %c0 = arith.constant 0 : index
63  %c1 = arith.constant 1 : index
  // Extents of source dims 0 and 1; used as the first two slice sizes.
64  %sz0 = tensor.dim %t, %c0 : tensor<?x?x?x?xf32>
65  %sz1 = tensor.dim %t, %c1 : tensor<?x?x?x?xf32>
  // Rank-reducing slice: the trailing static size-1 dim is absent from the
  // 3-D result type.
66  %0 = tensor.extract_slice %t[%x, %y, 0, 0] [%sz0, %sz1, %sz, 1] [1, 1, 1, 1]
67      : tensor<?x?x?x?xf32> to tensor<?x?x?xf32>
  // Merges two dynamic dims of the slice into one.
68  %1 = tensor.collapse_shape %0 [[0], [1, 2]]
69      : tensor<?x?x?xf32> into tensor<?x?xf32>
70  return %1 : tensor<?x?xf32>
71}
72
73// -----
74
// Negative test: same extract_slice + collapse_shape pair as in
// @unpadding_collapse_of_extract_slice, except that the un-collapsed slice
// is also returned from the function.
// The directives below expect both ops to remain in the output and the
// function to return the original slice together with the collapse result.
// NOTE(review): presumably the fold is skipped because the slice has a
// second user (the return) -- confirm against the pattern implementation.
75// CHECK-LABEL: func @unpadding_collapse_of_extract_slice_with_multiple_users(
76//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x?xf32>
77//  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
78//  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
79//       CHECK:   %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0]
80//  CHECK-SAME:     [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<1x?x1x?xf32>
81//       CHECK:   %[[collapse:.*]] = tensor.collapse_shape %[[extract]] {{\[}}[0, 1], [2, 3]] : tensor<1x?x1x?xf32> into tensor<?x?xf32>
82//       CHECK:   return %[[extract]], %[[collapse]]
83func.func @unpadding_collapse_of_extract_slice_with_multiple_users(
84    %t: tensor<?x?x?x?xf32>, %x: index, %y: index)
85  -> (tensor<1x?x1x?xf32>, tensor<?x?xf32>) {
86  %c1 = arith.constant 1 : index
87  %c3 = arith.constant 3 : index
  // Extents of source dims 1 and 3; used as the dynamic slice sizes.
88  %sz0 = tensor.dim %t, %c1 : tensor<?x?x?x?xf32>
89  %sz1 = tensor.dim %t, %c3 : tensor<?x?x?x?xf32>
  // Slice with static size 1 in dims 0 and 2; the result keeps all 4 dims.
90  %0 = tensor.extract_slice %t[%x, %y, 0, 0] [1, %sz0, 1, %sz1] [1, 1, 1, 1]
91      : tensor<?x?x?x?xf32> to tensor<1x?x1x?xf32>
92  %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
93      : tensor<1x?x1x?xf32> into tensor<?x?xf32>
  // Both the slice and its collapsed form are returned.
94  return %0, %1 : tensor<1x?x1x?xf32>, tensor<?x?xf32>
95}
96
97// -----
98
// Test: a rank-reducing insert_slice whose source is a collapse_shape.
// The input collapses tensor<?x1x1x5xf32> to tensor<?x1x5xf32> and inserts
// the 3-D value into a 4-D destination with sizes [%sz, 1, 1, 5].
// The directives below expect the insert_slice to take the un-collapsed
// function argument as its source (type tensor<?x1x1x5xf32>) with the same
// offsets, sizes and strides, and expect that insert to be returned.
99// CHECK-LABEL: func @rank_reducing_insert_of_collapse_shape(
100//  CHECK-SAME:     %[[t:.*]]: tensor<?x1x1x5xf32>
101//       CHECK:   %[[insert:.*]] = tensor.insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0]
102//  CHECK-SAME:     [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
103//       CHECK:   return %[[insert]]
104func.func @rank_reducing_insert_of_collapse_shape(
105    %t: tensor<?x1x1x5xf32>, %d: tensor<?x?x?x?xf32>, %sz: index)
106  -> tensor<?x?x?x?xf32> {
  // Merges source dims 0 and 1, removing one unit dimension.
107  %0 = tensor.collapse_shape %t [[0, 1], [2], [3]]
108      : tensor<?x1x1x5xf32> into tensor<?x1x5xf32>
  // Rank-reducing insert: 3-D source, four sizes, 4-D destination.
109  %1 = tensor.insert_slice %0 into %d[0, 0, 0, 0][%sz, 1, 1, 5][1, 1, 1, 1]
110      : tensor<?x1x5xf32> into tensor<?x?x?x?xf32>
111  return %1 : tensor<?x?x?x?xf32>
112}
113
114// -----
115
// Test: a rank-reducing parallel_insert_slice whose source is a
// collapse_shape. This mirrors @rank_reducing_insert_of_collapse_shape but
// performs the insertion inside the in_parallel region of an scf.forall.
// The directives below expect the parallel_insert_slice to take the
// un-collapsed function argument as its source (type tensor<?x1x1x5xf32>).
// Only the parallel_insert_slice is matched; the loop and the return are not.
116// CHECK-LABEL: func @rank_reducing_parallel_insert_of_collapse_shape(
117//  CHECK-SAME:     %[[t:.*]]: tensor<?x1x1x5xf32>
118//       CHECK:   tensor.parallel_insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0]
119//  CHECK-SAME:     [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
120func.func @rank_reducing_parallel_insert_of_collapse_shape(
121    %t: tensor<?x1x1x5xf32>, %d: tensor<?x?x?x?xf32>, %sz: index, %thr: index)
122  -> tensor<?x?x?x?xf32> {
  // Merges source dims 0 and 1, removing one unit dimension. The collapse
  // is defined outside the loop and used inside it.
123  %0 = tensor.collapse_shape %t [[0, 1], [2], [3]]
124      : tensor<?x1x1x5xf32> into tensor<?x1x5xf32>
  // %d is passed in as the shared output %o; the induction variable %iv is
  // not used by the body.
125  %1 = scf.forall (%iv) in (%thr) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
126    scf.forall.in_parallel {
      // Rank-reducing insert: 3-D source, four sizes, 4-D destination.
127      tensor.parallel_insert_slice %0 into %o[0, 0, 0, 0][%sz, 1, 1, 5][1, 1, 1, 1]
128          : tensor<?x1x5xf32> into tensor<?x?x?x?xf32>
129    }
130  }
131  return %1 : tensor<?x?x?x?xf32>
132}
133
134// -----
135
// Test: an insert_slice whose source is an expand_shape that only adds unit
// dimensions.
// The input expands tensor<?x?xf32> to tensor<1x?x1x?xf32> and inserts the
// 4-D value into a 4-D destination with sizes [1, %sz0, 1, %sz1].
// The directives below expect a single rank-reducing insert_slice that takes
// the original 2-D function argument as its source, keeps the same offsets
// and static unit sizes, and is returned.
136// CHECK-LABEL: func @insert_of_padding_expand_shape(
137//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
138//  CHECK-SAME:     %[[d:.*]]: tensor<?x?x?x?xf32>
139//  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
140//  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
141//       CHECK:   %[[insert:.*]] = tensor.insert_slice %[[t]] into %[[d]][%[[x]], %[[y]], 0, 0]
142//  CHECK-SAME:     [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
143//       CHECK:   return %[[insert]]
144func.func @insert_of_padding_expand_shape(
145    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
146  -> tensor<?x?x?x?xf32> {
147  %c0 = arith.constant 0 : index
148  %c1 = arith.constant 1 : index
  // Extents of the 2-D source; reused as the dynamic entries of both the
  // expand output_shape and the insert sizes.
149  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
150  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
  // Adds a static unit dim in front of each source dim.
151  %0 = tensor.expand_shape %t [[0, 1], [2, 3]] output_shape [1, %sz0, 1, %sz1]
152      : tensor<?x?xf32> into tensor<1x?x1x?xf32>
153  %1 = tensor.insert_slice %0 into %d[%x, %y, 0, 0][1, %sz0, 1, %sz1][1, 1, 1, 1]
154      : tensor<1x?x1x?xf32> into tensor<?x?x?x?xf32>
155  return %1 : tensor<?x?x?x?xf32>
156}
157
158// -----
159
// Negative test: an insert_slice whose source is an expand_shape that does
// more than add unit dimensions.
// The expand splits source dim 0 into two dynamic result dims
// (output_shape [%sz, %sz0, %sz1]), and the insert_slice is rank-reducing
// (3-D source, sizes [%sz, 1, %sz0, %sz1]).
// The directives below expect both the expand_shape and the insert_slice to
// appear in the output, with the insert still consuming the expand result.
160// CHECK-LABEL: func @insert_of_non_padding_expand_shape(
161//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
162//  CHECK-SAME:     %[[d:.*]]: tensor<?x?x?x?xf32>
163//  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
164//  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
165//  CHECK-SAME:     %[[sz:[a-zA-Z0-9_]+]]: index
166//       CHECK:   %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]]
167//  CHECK-SAME:     output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
168//       CHECK:   %[[insert:.*]] = tensor.insert_slice %[[expand]] into %[[d]][%[[x]], %[[y]], 0, 0]
169//  CHECK-SAME:     [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
170//       CHECK:   return %[[insert]]
171func.func @insert_of_non_padding_expand_shape(
172    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
173  -> tensor<?x?x?x?xf32> {
174  %c0 = arith.constant 0 : index
175  %c1 = arith.constant 1 : index
  // Extents of the 2-D source.
176  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
177  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
  // Splits source dim 0 into result dims 0 and 1, both dynamic.
178  %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [%sz, %sz0, %sz1]
179      : tensor<?x?xf32> into tensor<?x?x?xf32>
  // Rank-reducing insert: 3-D source, four sizes, 4-D destination.
180  %1 = tensor.insert_slice %0 into %d[%x, %y, 0, 0][%sz, 1, %sz0, %sz1][1, 1, 1, 1]
181      : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
182  return %1 : tensor<?x?x?x?xf32>
183}
184
185// -----
186
// Test: a parallel_insert_slice whose source is an expand_shape that only
// adds unit dimensions. This mirrors @insert_of_padding_expand_shape but
// performs the insertion inside the in_parallel region of a 2-D scf.forall,
// using the loop induction variables as the leading offsets.
// The directives below expect the parallel_insert_slice to take the original
// 2-D function argument as its source (type tensor<?x?xf32>). The
// destination and the dynamic offsets are matched with wildcards only.
187// CHECK-LABEL: func @parallel_insert_of_padding_expand_shape(
188//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
189//  CHECK-SAME:     %[[d:.*]]: tensor<?x?x?x?xf32>
190//  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
191//  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
192//       CHECK:   tensor.parallel_insert_slice %[[t]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0]
193//  CHECK-SAME:     [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
194func.func @parallel_insert_of_padding_expand_shape(
195    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
196  -> tensor<?x?x?x?xf32> {
197  %c0 = arith.constant 0 : index
198  %c1 = arith.constant 1 : index
  // Extents of the 2-D source; reused as the dynamic entries of both the
  // expand output_shape and the insert sizes.
199  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
200  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
  // Adds a static unit dim in front of each source dim. Defined outside the
  // loop and used inside it.
201  %0 = tensor.expand_shape %t [[0, 1], [2, 3]] output_shape [1, %sz0, 1, %sz1]
202      : tensor<?x?xf32> into tensor<1x?x1x?xf32>
  // Here %x and %y are the loop upper bounds; %d is the shared output %o.
203  %1 = scf.forall (%i, %j) in (%x, %y) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
204    scf.forall.in_parallel {
205      tensor.parallel_insert_slice %0 into %o[%i, %j, 0, 0][1, %sz0, 1, %sz1][1, 1, 1, 1]
206          : tensor<1x?x1x?xf32> into tensor<?x?x?x?xf32>
207    }
208  }
209  return %1 : tensor<?x?x?x?xf32>
210}
211
212// -----
213
// Negative test: a parallel_insert_slice whose source is an expand_shape
// that does more than add unit dimensions. This mirrors
// @insert_of_non_padding_expand_shape but performs the insertion inside the
// in_parallel region of a 2-D scf.forall.
// The directives below expect both the expand_shape and the
// parallel_insert_slice to appear in the output, with the insert still
// consuming the expand result.
214// CHECK-LABEL: func @parallel_insert_of_non_padding_expand_shape(
215//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
216//  CHECK-SAME:     %[[d:.*]]: tensor<?x?x?x?xf32>
217//  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
218//  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
219//  CHECK-SAME:     %[[sz:[a-zA-Z0-9_]+]]: index
220//       CHECK:   %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]]
221//  CHECK-SAME:     output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
222//       CHECK:   tensor.parallel_insert_slice %[[expand]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0]
223//  CHECK-SAME:     [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
224func.func @parallel_insert_of_non_padding_expand_shape(
225    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
226  -> tensor<?x?x?x?xf32> {
227  %c0 = arith.constant 0 : index
228  %c1 = arith.constant 1 : index
  // Extents of the 2-D source.
229  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
230  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
  // Splits source dim 0 into result dims 0 and 1, both dynamic. Defined
  // outside the loop and used inside it.
231  %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [%sz, %sz0, %sz1]
232      : tensor<?x?xf32> into tensor<?x?x?xf32>
  // Here %x and %y are the loop upper bounds; %d is the shared output %o.
233  %1 = scf.forall (%i, %j) in (%x, %y) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
234    scf.forall.in_parallel {
      // Rank-reducing insert: 3-D source, four sizes, 4-D destination.
235      tensor.parallel_insert_slice %0 into %o[%i, %j, 0, 0][%sz, 1, %sz0, %sz1][1, 1, 1, 1]
236          : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
237    }
238  }
239  return %1 : tensor<?x?x?x?xf32>
240}
241