xref: /llvm-project/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir (revision a79a0c52885c3a60d6afdda3b125866b8ed75fce)
1// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-simplify-pack-unpack-patterns" %s | FileCheck %s
2
// Positive test: a 1-D pack with one inner tile (32) and no padding value is
// expected to be replaced by a tensor.expand_shape (256 -> 8x32).
3// CHECK-LABEL: func.func @single_dim_packing(
4// CHECK-SAME:    %[[ARG0:.+]]: tensor<256xf32>)
5// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [8, 32] : tensor<256xf32> into tensor<8x32xf32>
6// CHECK:         return %[[EXPANDED]] : tensor<8x32xf32>
7func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
8  %empty = tensor.empty() : tensor<8x32xf32>
9  %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256xf32> -> tensor<8x32xf32>
10  return %0 : tensor<8x32xf32>
11}
12
13// -----
14
// Negative test: the pack carries a padding_value (255 elements packed into
// 8x32), so it must remain a tensor.pack and no tensor.expand_shape may appear.
15// CHECK-LABEL: func.func @single_dim_packing_with_padding(
16// CHECK-SAME:    %[[ARG0:.+]]: tensor<255xf32>)
17// CHECK-NOT:     tensor.expand_shape
18// CHECK:         tensor.pack
19func.func @single_dim_packing_with_padding(%arg0: tensor<255xf32>) -> tensor<8x32xf32> {
20  %empty = tensor.empty() : tensor<8x32xf32>
21  %cst = arith.constant 0.000000e+00 : f32
22  %0 = tensor.pack %arg0 padding_value(%cst : f32) inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<255xf32> -> tensor<8x32xf32>
23  return %0 : tensor<8x32xf32>
24}
25
26// -----
27
// Positive test: packing only the last source dimension (inner_dims_pos = [1]
// on a 2-D tensor) is expected to become a tensor.expand_shape (5x256 -> 5x8x32).
28// CHECK-LABEL: func.func @single_last_inner_dim_packing(
29// CHECK-SAME:    %[[ARG0:.+]]: tensor<5x256xf32>)
30// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] output_shape [5, 8, 32] : tensor<5x256xf32> into tensor<5x8x32xf32>
31// CHECK:         return %[[EXPANDED]] : tensor<5x8x32xf32>
32func.func @single_last_inner_dim_packing(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
33  %empty = tensor.empty() : tensor<5x8x32xf32>
34  %0 = tensor.pack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x256xf32> -> tensor<5x8x32xf32>
35  return %0 : tensor<5x8x32xf32>
36}
37
38// -----
39
// Positive test: a 1-D pack that spells out outer_dims_perm = [0] is still
// expected to be replaced by a tensor.expand_shape (64 -> 2x32).
40// CHECK-LABEL: func.func @pack_1d_with_outer_dims_perm(
41// CHECK-SAME:    %[[ARG0:.+]]: tensor<64xf32>)
42// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [2, 32] : tensor<64xf32> into tensor<2x32xf32>
43// CHECK:         return %[[EXPANDED]] : tensor<2x32xf32>
44func.func @pack_1d_with_outer_dims_perm(%arg0: tensor<64xf32>) -> tensor<2x32xf32> {
45  %empty = tensor.empty() :  tensor<2x32xf32>
46  %pack = tensor.pack %arg0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<64xf32> -> tensor<2x32xf32>
47  return %pack : tensor<2x32xf32>
48}
49
50// -----
51
// Positive test: same last-dimension pack as @single_last_inner_dim_packing but
// with an explicit identity outer_dims_perm = [0, 1]; a tensor.expand_shape is
// still expected.
52// CHECK-LABEL: func.func @single_last_inner_dim_packing_with_identity_outer_dims_perm(
53// CHECK-SAME:    %[[ARG0:.+]]: tensor<5x256xf32>)
54// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] output_shape [5, 8, 32] : tensor<5x256xf32> into tensor<5x8x32xf32>
55// CHECK:         return %[[EXPANDED]] : tensor<5x8x32xf32>
56func.func @single_last_inner_dim_packing_with_identity_outer_dims_perm(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
57  %empty = tensor.empty() : tensor<5x8x32xf32>
58  %0 = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x256xf32> -> tensor<5x8x32xf32>
59  return %0 : tensor<5x8x32xf32>
60}
61
62// -----
63
// Negative test: with a non-identity outer_dims_perm = [1, 0] the op must stay
// a tensor.pack; no tensor.expand_shape may be produced.
64// CHECK-LABEL: func.func @packing_with_outer_dims_perm(
65// CHECK-NOT:     tensor.expand_shape
66// CHECK:         tensor.pack
67func.func @packing_with_outer_dims_perm(%arg0: tensor<5x256xf32>) -> tensor<8x5x32xf32> {
68  %empty = tensor.empty() : tensor<8x5x32xf32>
69  %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x256xf32> -> tensor<8x5x32xf32>
70  return %0 : tensor<8x5x32xf32>
71}
72
73// -----
74
// Negative test: packing the first (non-innermost) dimension of a 2-D tensor
// (inner_dims_pos = [0]) must stay a tensor.pack; no tensor.expand_shape.
75// CHECK-LABEL: func.func @single_first_inner_dim_packing(
76// CHECK-NOT:     tensor.expand_shape
77// CHECK:         tensor.pack
78func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x5x32xf32> {
79  %empty = tensor.empty() : tensor<8x5x32xf32>
80  %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256x5xf32> -> tensor<8x5x32xf32>
81  return %0 : tensor<8x5x32xf32>
82}
83
84// -----
85
// Positive test: packing both dimensions with unit tiles [1, 1] is expected to
// become a tensor.expand_shape (1x32 -> 1x32x1x1).
86// CHECK-LABEL: func.func @pack_1x32_to_1x32x1x1
87// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
88// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]] output_shape [1, 32, 1, 1]
89// CHECK:         return %[[EXPANDED]]
90func.func @pack_1x32_to_1x32x1x1(%arg0 : tensor<1x32xf32>) -> tensor<1x32x1x1xf32> {
91  %empty = tensor.empty() : tensor<1x32x1x1xf32>
92  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %empty
93    : tensor<1x32xf32> -> tensor<1x32x1x1xf32>
94  return %pack : tensor<1x32x1x1xf32>
95}
96
97// -----
98
// Positive test: packing a 1x32 tensor with tiles [1, 2] (unit tile on the
// size-1 dimension) is expected to become a tensor.expand_shape to 1x16x1x2.
99// CHECK-LABEL: func.func @pack_1x32_to_1x16x1x2
100// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
101// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]] output_shape [1, 16, 1, 2]
102// CHECK:         return %[[EXPANDED]]
103func.func @pack_1x32_to_1x16x1x2(%arg0 : tensor<1x32xf32>) -> tensor<1x16x1x2xf32> {
104  %empty = tensor.empty() : tensor<1x16x1x2xf32>
105  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [1, 2] into %empty
106    : tensor<1x32xf32> -> tensor<1x16x1x2xf32>
107  return %pack : tensor<1x16x1x2xf32>
108}
109
110// -----
111
// Positive test: a 32x1 pack with tiles [2, 1] and outer_dims_perm = [1, 0] is
// expected to become a tensor.expand_shape.
// Note: because of the outer permutation the result type is 1x16x2x1, not the
// 16x1x2x1 suggested by the function name.
112// CHECK-LABEL: func.func @pack_32x1_to_16x1x2x1
113// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
114// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 16, 2, 1]
115// CHECK:         return %[[EXPANDED]]
116func.func @pack_32x1_to_16x1x2x1(%arg0 : tensor<32x1xf32>) -> tensor<1x16x2x1xf32> {
117  %empty = tensor.empty() : tensor<1x16x2x1xf32>
118  %pack = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 1] into %empty
119    : tensor<32x1xf32> -> tensor<1x16x2x1xf32>
120  return %pack : tensor<1x16x2x1xf32>
121}
122
123// -----
124
// Negative test: inner_dims_pos = [1, 0] swaps the order of the inner tiles,
// so the op must stay a tensor.pack; no tensor.expand_shape may be produced.
125// CHECK-LABEL: func.func @pack_32x1_to_16x1x1x2
126// CHECK-NOT:     tensor.expand_shape
127// CHECK:         tensor.pack
128func.func @pack_32x1_to_16x1x1x2(%arg0 : tensor<32x1xf32>) -> tensor<16x1x1x2xf32> {
129  %empty = tensor.empty() : tensor<16x1x1x2xf32>
130  %pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [1, 2] into %empty
131    : tensor<32x1xf32> -> tensor<16x1x1x2xf32>
132  return %pack : tensor<16x1x1x2xf32>
133}
134
135// -----
136
// Positive test: an unpack of 8x32 into a full 256-element 1-D tensor is
// expected to be replaced by a tensor.collapse_shape.
137// CHECK-LABEL: func.func @unpack_1d_to_collapse
138// CHECK-SAME:    %[[ARG0:.+]]: tensor<8x32xf32>)
139// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<8x32xf32> into tensor<256xf32>
140// CHECK:         return %[[COLLAPSED]]
141func.func @unpack_1d_to_collapse(%arg0: tensor<8x32xf32>) -> tensor<256xf32> {
142  %empty = tensor.empty() : tensor<256xf32>
143  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<256xf32>
144  return %0 : tensor<256xf32>
145}
146
147// -----
148
// Negative test: the destination has 255 elements, fewer than the 8*32 = 256
// held by the source, so the op must stay a tensor.unpack and must not be
// turned into a collapse.
149// CHECK-LABEL: func.func @unpack_to_partial_slice
150// CHECK-NOT:     tensor.collapse
151// CHECK:         tensor.unpack
152func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
153  %empty = tensor.empty() : tensor<255xf32>
154  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<255xf32>
155  return %0 : tensor<255xf32>
156}
157
158// -----
159
// Negative test: source and destination have a dynamic dimension, so the op
// must stay a tensor.unpack and must not be turned into a collapse.
160// CHECK-LABEL: func.func @unpack_dynamic
161// CHECK-NOT:     tensor.collapse
162// CHECK:         tensor.unpack
163func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
164  %c32 = arith.constant 32 : index
165  %c0 = arith.constant 0 : index
166  %d0 = tensor.dim %arg0, %c0 : tensor<?x32xf32>
  // Destination size is dim0 * 32, i.e. exactly the number of source elements.
167  %size = arith.muli %d0, %c32 : index
168  %empty = tensor.empty(%size) : tensor<?xf32>
169  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<?x32xf32> -> tensor<?xf32>
170  return %0 : tensor<?xf32>
171}
172
173// -----
174
// Positive test: unpacking only the last destination dimension
// (inner_dims_pos = [1]) is expected to become a tensor.collapse_shape
// (5x8x32 -> 5x256).
175// CHECK-LABEL: func.func @single_last_inner_dim_unpacking(
176// CHECK-SAME:    %[[ARG0:.+]]: tensor<5x8x32xf32>)
177// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x8x32xf32> into tensor<5x256xf32>
178// CHECK:         return %[[COLLAPSED]] : tensor<5x256xf32>
179func.func @single_last_inner_dim_unpacking(%arg0: tensor<5x8x32xf32>) -> tensor<5x256xf32> {
180  %empty = tensor.empty() : tensor<5x256xf32>
181  %0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x8x32xf32> -> tensor<5x256xf32>
182  return %0 : tensor<5x256xf32>
183}
184
185// -----
186
// Positive test: same last-dimension unpack as @single_last_inner_dim_unpacking
// but with an explicit identity outer_dims_perm = [0, 1]; a
// tensor.collapse_shape is still expected.
187// CHECK-LABEL: func.func @single_last_inner_dim_unpacking_with_identity_outer_dims_perm(
188// CHECK-SAME:    %[[ARG0:.+]]: tensor<5x8x32xf32>)
189// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x8x32xf32> into tensor<5x256xf32>
190// CHECK:         return %[[COLLAPSED]] : tensor<5x256xf32>
191func.func @single_last_inner_dim_unpacking_with_identity_outer_dims_perm(%arg0: tensor<5x8x32xf32>) -> tensor<5x256xf32> {
192  %empty = tensor.empty() : tensor<5x256xf32>
193  %0 = tensor.unpack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x8x32xf32> -> tensor<5x256xf32>
194  return %0 : tensor<5x256xf32>
195}
196
197// -----
198
// Negative test: with a non-identity outer_dims_perm = [1, 0] the op must stay
// a tensor.unpack; no tensor.collapse_shape may be produced.
199// CHECK-LABEL: func.func @unpacking_with_outer_dims_perm(
200// CHECK-NOT:     tensor.collapse_shape
201// CHECK:         tensor.unpack
202func.func @unpacking_with_outer_dims_perm(%arg0: tensor<8x5x32xf32>) -> tensor<5x256xf32> {
203  %empty = tensor.empty() : tensor<5x256xf32>
204  %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<5x256xf32>
205  return %0 : tensor<5x256xf32>
206}
207
208// -----
209
// Negative test: unpacking into the first (non-innermost) destination
// dimension (inner_dims_pos = [0] on a 2-D result) must stay a tensor.unpack;
// no tensor.collapse_shape may be produced.
210// CHECK-LABEL: func.func @single_first_inner_dim_unpacking(
211// CHECK-NOT:     tensor.collapse_shape
212// CHECK:         tensor.unpack
213func.func @single_first_inner_dim_unpacking(%arg0: tensor<8x5x32xf32>) -> tensor<256x5xf32> {
214  %empty = tensor.empty() : tensor<256x5xf32>
215  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<256x5xf32>
216  return %0 : tensor<256x5xf32>
217}
218
219// -----
220
// Positive test: unpacking with unit tiles [1, 1] is expected to become a
// tensor.collapse_shape (1x32x1x1 -> 1x32).
221// CHECK-LABEL: func.func @unpack_1x32x1x1_to_1x32
222// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
223// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
224// CHECK:         return %[[COLLAPSED]]
225func.func @unpack_1x32x1x1_to_1x32(%arg0 : tensor<1x32x1x1xf32>) -> tensor<1x32xf32> {
226  %empty = tensor.empty() : tensor<1x32xf32>
227  %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %empty
228    : tensor<1x32x1x1xf32> -> tensor<1x32xf32>
229  return %unpack : tensor<1x32xf32>
230}
231
232// -----
233
// Positive test: unpacking 1x2x1x16 with tiles [1, 16] and an identity
// outer_dims_perm = [0, 1] is expected to become a tensor.collapse_shape to
// 1x32.
234// CHECK-LABEL: func.func @unpack_1x2x1x16_to_1x32
235// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
236// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
237// CHECK:         return %[[COLLAPSED]]
238func.func @unpack_1x2x1x16_to_1x32(%arg0 : tensor<1x2x1x16xf32>) -> tensor<1x32xf32> {
239  %empty = tensor.empty() : tensor<1x32xf32>
240  %unpack = tensor.unpack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [1, 16] into %empty
241    : tensor<1x2x1x16xf32> -> tensor<1x32xf32>
242  return %unpack : tensor<1x32xf32>
243}
244
245// -----
246
// Positive test: an unpack with tiles [2, 1] and outer_dims_perm = [1, 0] is
// expected to become a tensor.collapse_shape to 32x1.
// Note: the source type is 1x16x2x1 (outer dims permuted), not the 16x1x2x1
// suggested by the function name.
247// CHECK-LABEL: func.func @unpack_16x1x2x1_to_32x1
248// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
249// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
250// CHECK:         return %[[COLLAPSED]]
251func.func @unpack_16x1x2x1_to_32x1(%arg0 : tensor<1x16x2x1xf32>) -> tensor<32x1xf32> {
252  %empty = tensor.empty() : tensor<32x1xf32>
253  %unpack = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 1] into %empty
254    : tensor<1x16x2x1xf32> -> tensor<32x1xf32>
255  return %unpack : tensor<32x1xf32>
256}
257
258// -----
259
// Negative test: inner_dims_pos = [1, 0] swaps the order of the inner tiles,
// so the op must stay a tensor.unpack; no tensor.collapse_shape may be produced.
260// CHECK-LABEL: func.func @unpack_16x1x1x2_to_32x1
261// CHECK-NOT:     tensor.collapse_shape
262// CHECK:         tensor.unpack
263func.func @unpack_16x1x1x2_to_32x1(%arg0 : tensor<16x1x1x2xf32>) -> tensor<32x1xf32> {
264  %empty = tensor.empty() : tensor<32x1xf32>
265  %unpack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [1, 2] into %empty
266    : tensor<16x1x1x2xf32> -> tensor<32x1xf32>
267  return %unpack : tensor<32x1xf32>
268}
269
270// -----
271
// Positive test: the inner tiles [32, 64] equal the full source shape, so the
// pack only adds leading unit dimensions; a tensor.expand_shape to 1x1x32x64
// is expected.
272// CHECK-LABEL: func.func @pad_like_pack(
273// CHECK-SAME:    %[[ARG0:.+]]: tensor<32x64xf32>)
274// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 1, 32, 64] : tensor<32x64xf32> into tensor<1x1x32x64xf32>
275// CHECK:         return %[[EXPANDED]] : tensor<1x1x32x64xf32>
276func.func @pad_like_pack(%arg0: tensor<32x64xf32>) -> tensor<1x1x32x64xf32> {
277  %empty = tensor.empty() : tensor<1x1x32x64xf32>
278  %0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<32x64xf32> -> tensor<1x1x32x64xf32>
279  return %0 : tensor<1x1x32x64xf32>
280}
281
282// -----
283
// Positive test: same full-tile pack as @pad_like_pack but with
// outer_dims_perm = [1, 0]; both outer dimensions are of size 1 and a
// tensor.expand_shape is still expected.
284// CHECK-LABEL: func.func @pad_like_pack_with_outer_dims_perm(
285// CHECK-SAME:    %[[ARG0:.+]]: tensor<32x64xf32>)
286// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 1, 32, 64] : tensor<32x64xf32> into tensor<1x1x32x64xf32>
287// CHECK:         return %[[EXPANDED]] : tensor<1x1x32x64xf32>
288func.func @pad_like_pack_with_outer_dims_perm(%arg0: tensor<32x64xf32>) -> tensor<1x1x32x64xf32> {
289  %empty = tensor.empty() : tensor<1x1x32x64xf32>
290  %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<32x64xf32> -> tensor<1x1x32x64xf32>
291  return %0 : tensor<1x1x32x64xf32>
292}
293
294// -----
295
// Positive test: only the last dimension is packed and its tile (64) equals the
// full dimension size; a tensor.expand_shape to 32x1x64 is expected.
296// CHECK-LABEL: func.func @inner_pad_like_pack(
297// CHECK-SAME:    %[[ARG0:.+]]: tensor<32x64xf32>)
298// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] output_shape [32, 1, 64] : tensor<32x64xf32> into tensor<32x1x64xf32>
299// CHECK:         return %[[EXPANDED]] : tensor<32x1x64xf32>
300func.func @inner_pad_like_pack(%arg0: tensor<32x64xf32>) -> tensor<32x1x64xf32> {
301  %empty = tensor.empty() : tensor<32x1x64xf32>
302  %0 = tensor.pack %arg0 inner_dims_pos = [1] inner_tiles = [64] into %empty : tensor<32x64xf32> -> tensor<32x1x64xf32>
303  return %0 : tensor<32x1x64xf32>
304}
305
306// -----
307
308// Do not simplify pack with inner dimension shuffling.
// Negative test: the tiles cover the full source shape, but
// inner_dims_pos = [1, 0] swaps them (result 1x1x64x32), so the tensor.empty
// and tensor.pack are expected to be left unchanged.
309// CHECK-LABEL: func.func @pad_and_inner_dim_shuffle_pack(
310// CHECK-SAME:    %[[ARG0:.+]]: tensor<32x64xf32>)
311// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x64x32xf32>
312// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1, 0] inner_tiles = [64, 32] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<1x1x64x32xf32>
313// CHECK:         return %[[PACK]] : tensor<1x1x64x32xf32>
314func.func @pad_and_inner_dim_shuffle_pack(%arg0: tensor<32x64xf32>) -> tensor<1x1x64x32xf32> {
315  %empty = tensor.empty() : tensor<1x1x64x32xf32>
316  %0 = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [64, 32] into %empty : tensor<32x64xf32> -> tensor<1x1x64x32xf32>
317  return %0 : tensor<1x1x64x32xf32>
318}
319
320// -----
321
322// Do not simplify pack with inner dimension transpose.
// Negative test: the middle dimension (64) is packed with a full-size tile,
// which moves it behind the trailing 16 dimension (result 32x1x16x64); the
// tensor.empty and tensor.pack are expected to be left unchanged.
323// CHECK-LABEL: func.func @pad_like_pack_with_transpose(
324// CHECK-SAME:    %[[ARG0:.+]]: tensor<32x64x16xf32>)
325// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<32x1x16x64xf32>
326// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [64] into %[[EMPTY]] : tensor<32x64x16xf32> -> tensor<32x1x16x64xf32>
327// CHECK:         return %[[PACK]] : tensor<32x1x16x64xf32>
328func.func @pad_like_pack_with_transpose(%arg0: tensor<32x64x16xf32>) -> tensor<32x1x16x64xf32> {
329  %empty = tensor.empty() : tensor<32x1x16x64xf32>
330  %0 = tensor.pack %arg0 inner_dims_pos = [1] inner_tiles = [64] into %empty : tensor<32x64x16xf32> -> tensor<32x1x16x64xf32>
331  return %0 : tensor<32x1x16x64xf32>
332}
333
334// -----
335
// Positive test: the inner tiles [32, 64] equal the full destination shape, so
// the unpack only drops leading unit dimensions; a tensor.collapse_shape to
// 32x64 is expected.
336// CHECK-LABEL: func.func @unpad_like_unpack(
337// CHECK-SAME:    %[[ARG0:.+]]: tensor<1x1x32x64xf32>)
338// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x32x64xf32> into tensor<32x64xf32>
339// CHECK:         return %[[COLLAPSED]] : tensor<32x64xf32>
340func.func @unpad_like_unpack(%arg0: tensor<1x1x32x64xf32>) -> tensor<32x64xf32> {
341  %empty = tensor.empty() : tensor<32x64xf32>
342  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<1x1x32x64xf32> -> tensor<32x64xf32>
343  return %0 : tensor<32x64xf32>
344}
345
346// -----
347
// Positive test: same full-tile unpack as @unpad_like_unpack but with
// outer_dims_perm = [1, 0]; both outer dimensions are of size 1 and a
// tensor.collapse_shape is still expected.
348// CHECK-LABEL: func.func @unpad_like_unpack_with_outer_dims_perm(
349// CHECK-SAME:    %[[ARG0:.+]]: tensor<1x1x32x64xf32>)
350// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x32x64xf32> into tensor<32x64xf32>
351// CHECK:         return %[[COLLAPSED]] : tensor<32x64xf32>
352func.func @unpad_like_unpack_with_outer_dims_perm(%arg0: tensor<1x1x32x64xf32>) -> tensor<32x64xf32> {
353  %empty = tensor.empty() : tensor<32x64xf32>
354  %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<1x1x32x64xf32> -> tensor<32x64xf32>
355  return %0 : tensor<32x64xf32>
356}
357
358// -----
359
// Positive test: only the last dimension is unpacked and its tile (64) equals
// the full dimension size; a tensor.collapse_shape to 32x64 is expected.
360// CHECK-LABEL: func.func @inner_unpad_like_unpack(
361// CHECK-SAME:    %[[ARG0:.+]]: tensor<32x1x64xf32>)
362// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<32x1x64xf32> into tensor<32x64xf32>
363// CHECK:         return %[[COLLAPSED]] : tensor<32x64xf32>
364func.func @inner_unpad_like_unpack(%arg0: tensor<32x1x64xf32>) -> tensor<32x64xf32> {
365  %empty = tensor.empty() : tensor<32x64xf32>
366  %0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [64] into %empty : tensor<32x1x64xf32> -> tensor<32x64xf32>
367  return %0 : tensor<32x64xf32>
368}
369
370// -----
371
372// Do not simplify unpack with inner dimension shuffling.
// Negative test (the op under test is a tensor.unpack despite the "_pack"
// suffix in the name): inner_dims_pos = [1, 0] swaps the tiles (result 64x32),
// so the tensor.empty and tensor.unpack are expected to be left unchanged.
373// CHECK-LABEL: func.func @unpad_and_inner_dim_shuffle_pack(
374// CHECK-SAME:    %[[ARG0:.+]]: tensor<1x1x32x64xf32>)
375// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<64x32xf32>
376// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]] inner_dims_pos = [1, 0] inner_tiles = [32, 64] into %[[EMPTY]] : tensor<1x1x32x64xf32> -> tensor<64x32xf32>
377// CHECK:         return %[[UNPACK]] : tensor<64x32xf32>
378func.func @unpad_and_inner_dim_shuffle_pack(%arg0: tensor<1x1x32x64xf32>) -> tensor<64x32xf32> {
379  %empty = tensor.empty() : tensor<64x32xf32>
380  %0 = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 64] into %empty : tensor<1x1x32x64xf32> -> tensor<64x32xf32>
381  return %0 : tensor<64x32xf32>
382}
383
384// -----
385
386// Do not simplify unpack with inner dimension transpose.
// Negative test: the trailing 64 tile is unpacked into the middle destination
// dimension, ahead of the 16 dimension (result 32x64x16); the tensor.empty and
// tensor.unpack are expected to be left unchanged.
387// CHECK-LABEL: func.func @unpad_like_unpack_with_transpose(
388// CHECK-SAME:    %[[ARG0:.+]]: tensor<32x1x16x64xf32>)
389// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<32x64x16xf32>
390// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [64] into %[[EMPTY]] : tensor<32x1x16x64xf32> -> tensor<32x64x16xf32>
391// CHECK:         return %[[UNPACK]] : tensor<32x64x16xf32>
392func.func @unpad_like_unpack_with_transpose(%arg0: tensor<32x1x16x64xf32>) -> tensor<32x64x16xf32> {
393  %empty = tensor.empty() : tensor<32x64x16xf32>
394  %0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [64] into %empty : tensor<32x1x16x64xf32> -> tensor<32x64x16xf32>
395  return %0 : tensor<32x64x16xf32>
396}
397