// RUN: mlir-opt %s -test-linalg-data-layout-propagation -split-input-file | FileCheck %s

#map0 = affine_map<(d0, d1) -> (d0, d1)>
// Test: a pack on a dynamically-shaped elementwise generic is bubbled up so the
// input is packed first; outer dims are computed with ceildiv of the tile sizes.
func.func @dynamic_elem_pack(%arg0: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
  %3 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<?x?xf32>)
      outs(%2 : tensor<?x?xf32>) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = arith.addf %arg3, %arg3 : f32
      linalg.yield %4 : f32
  } -> tensor<?x?xf32>
  %4 = tensor.pack %3
    inner_dims_pos = [0, 1]
    inner_tiles = [8, 2]
    into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
  return %4 : tensor<?x?x8x2xf32>
}
// CHECK-DAG:  #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
// CHECK-DAG:  #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL:  func.func @dynamic_elem_pack
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[DEST:[a-zA-Z0-9]+]]
// CHECK-DAG:      %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:      %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:      %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG:      %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG:      %[[OUTER_D0:.+]] = affine.apply #[[$MAP0]]()[%[[D0]]]
// CHECK-DAG:      %[[OUTER_D1:.+]] = affine.apply #[[$MAP1]]()[%[[D1]]]
// CHECK:          %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]], %[[OUTER_D1]]) : tensor<?x?x8x2xf32>
// CHECK:          %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:       inner_dims_pos = [0, 1] inner_tiles = [8, 2]
// CHECK-SAME:       into %[[ARG0_EMPTY]]
// CHECK:          %[[ELEM:.+]] = linalg.generic
// CHECK-SAME:       indexing_maps = [#[[$MAP2]], #[[$MAP2]]]
// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME:       ins(%[[PACK_ARG0]]
// CHECK-SAME:       outs(%[[DEST]]
// CHECK:          return %[[ELEM]] : tensor<?x?x8x2xf32>

// -----
48
#map0 = affine_map<(d0, d1) -> (d0, d1)>
// Test: pack with transposed inner dims (inner_dims_pos = [1, 0]) is propagated
// through an elementwise generic; the input gets packed with the same layout.
func.func @elem_pack_transpose_inner_dims(%arg0: tensor<128x256xi32>, %dest: tensor<4x16x16x32xi32>) -> tensor<4x16x16x32xi32>{
  %init = tensor.empty() : tensor<128x256xi32>
  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<128x256xi32>)
      outs(%init : tensor<128x256xi32>) {
    ^bb0(%arg3: i32, %arg4: i32):
      %4 = arith.addi %arg3, %arg3 : i32
      linalg.yield %4 : i32
  } -> tensor<128x256xi32>
  %pack = tensor.pack %elem
    inner_dims_pos = [1, 0]
    inner_tiles = [16, 32]
    into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32>
  return %pack : tensor<4x16x16x32xi32>
}
// CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @elem_pack_transpose_inner_dims
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [1, 0] inner_tiles = [16, 32]
// CHECK-SAME:      into %[[ARG0_EMPTY]]
// CHECK:         %[[ELEM:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME:      ins(%[[PACK_ARG0]]
// CHECK-SAME:      outs(%[[DEST]]
// CHECK:         return %[[ELEM]] : tensor<4x16x16x32xi32>

// -----
81
#map0 = affine_map<(d0, d1) -> (d0, d1)>
// Test: pack with an outer_dims_perm is propagated through an elementwise
// generic; the input pack keeps the same outer permutation.
func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x32x16xi32>) -> tensor<16x4x32x16xi32>{
  %init = tensor.empty() : tensor<128x256xi32>
  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<128x256xi32>)
      outs(%init : tensor<128x256xi32>) {
    ^bb0(%arg3: i32, %arg4: i32):
      %4 = arith.addi %arg3, %arg3 : i32
      linalg.yield %4 : i32
  } -> tensor<128x256xi32>
  %pack = tensor.pack %elem
    outer_dims_perm = [1, 0]
    inner_dims_pos = [0, 1]
    inner_tiles = [32, 16]
    into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
  return %pack : tensor<16x4x32x16xi32>
}
// CHECK-DAG:  #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @elem_pack_transpose_outer_dims
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:      into %[[ARG0_EMPTY]] : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
// CHECK:         %[[ELEM:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME:      ins(%[[PACK_ARG0]]
// CHECK-SAME:      outs(%[[DEST]]
// CHECK:         return %[[ELEM]] : tensor<16x4x32x16xi32>

// -----
115
#map0 = affine_map<(d0, d1) -> (d0, d1)>
// Test: pack with both outer_dims_perm and transposed inner dims is propagated
// through an elementwise generic unchanged.
func.func @elem_pack_transpose_inner_and_outer_dims(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x16x32xi32>) -> tensor<16x4x16x32xi32>{
  %init = tensor.empty() : tensor<128x256xi32>
  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<128x256xi32>)
      outs(%init : tensor<128x256xi32>) {
    ^bb0(%arg3: i32, %arg4: i32):
      %4 = arith.addi %arg3, %arg3 : i32
      linalg.yield %4 : i32
  } -> tensor<128x256xi32>
  %pack = tensor.pack %elem
    outer_dims_perm = [1, 0]
    inner_dims_pos = [1, 0]
    inner_tiles = [16, 32]
    into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32>
  return %pack : tensor<16x4x16x32xi32>
}
// CHECK-DAG:  #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @elem_pack_transpose_inner_and_outer_dims
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x16x32xi32>
// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32]
// CHECK-SAME:      into %[[ARG0_EMPTY]]
// CHECK:         %[[ELEM:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME:      ins(%[[PACK_ARG0]]
// CHECK-SAME:      outs(%[[DEST]]
// CHECK:         return %[[ELEM]] : tensor<16x4x16x32xi32>

// -----
149
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
#map2 = affine_map<(d0, d1) -> (d1)>
// Test: pack is propagated through a broadcasting generic with two 1-D dynamic
// inputs; each input is packed along its own broadcast dimension.
func.func @dynamic_broadcast_pack(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %dest: tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
{
  %c0 = arith.constant 0 : index
  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
  %1 = tensor.dim %arg1, %c0 : tensor<?xf32>
  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
  %3 = linalg.generic {indexing_maps = [#map1, #map2, #map0], iterator_types = ["parallel", "parallel"]}
      ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
      outs(%2 : tensor<?x?xf32>) {
    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
      %4 = arith.addf %arg3, %arg4 : f32
      linalg.yield %4 : f32
  } -> tensor<?x?xf32>
  %4 = tensor.pack %3
    inner_dims_pos = [0, 1]
    inner_tiles = [8, 2]
    into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
  return %4 : tensor<?x?x8x2xf32>
}
// CHECK-DAG:  #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
// CHECK-DAG:  #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
// CHECK-DAG:  #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
// CHECK-DAG:  #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @dynamic_broadcast_pack
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG:     %[[OUTER_D0:.+]] = affine.apply #[[$MAP0]]()[%[[D0]]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]]) : tensor<?x8xf32>
// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [8]
// CHECK-SAME:      into %[[ARG0_EMPTY]]
// CHECK-DAG:     %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C0]]
// CHECK-DAG:     %[[OUTER_D1:.+]] = affine.apply #[[$MAP1]]()[%[[D1]]]
// CHECK:         %[[ARG1_EMPTY:.+]] = tensor.empty(%[[OUTER_D1]]) : tensor<?x2xf32>
// CHECK:         %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]]
// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [2]
// CHECK-SAME:      into %[[ARG1_EMPTY]]
// CHECK:         %[[ELEM:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP4]]]
// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME:      ins(%[[PACK_ARG0]], %[[PACK_ARG1]]
// CHECK-SAME:      outs(%[[DEST]]
// CHECK:         return %[[ELEM]] : tensor<?x?x8x2xf32>

// -----
202
#map = affine_map<(d0, d1, d2, d3) -> (d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// Test: pack with outer_dims_perm on a broadcast (1-D input) generic; only the
// packed dimension of the 1-D input materializes in its pack.
func.func @elem_pack_transpose_inner_and_outer_dims2(%arg0: tensor<64xf32>, %dest: tensor<1x2x56x57x32xf32>) -> tensor<1x2x56x57x32xf32> {
  %0 = tensor.empty() : tensor<1x56x57x64xf32>
  %1 = linalg.generic {
      indexing_maps = [#map, #map1],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%arg0 : tensor<64xf32>)
    outs(%0 : tensor<1x56x57x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
  } -> tensor<1x56x57x64xf32>
  %2 = tensor.pack %1 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %dest : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32>
  return %2 : tensor<1x2x56x57x32xf32>
}
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func.func @elem_pack_transpose_inner_and_outer_dims2
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[DEST:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<2x32xf32>
// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [32]
// CHECK-SAME:    into %[[ARG0_EMPTY]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
// CHECK-SAME:      ins(%[[PACKED_ARG0]]
// CHECK-SAME:      outs(%[[DEST]]

// -----
233
// Test: pack propagated through a transposing generic with extra 1-D inputs;
// inner dims are remapped through the transpose, and only inputs indexed by a
// packed dimension (%arg2 via d1) get their own pack.
func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<100x200x4x16x16x32xi32>) -> tensor<100x200x4x16x16x32xi32>
{
  %init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
  %transpose = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0)>,
                       affine_map<(d0, d1, d2, d3) -> (d1)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100xi32>, tensor<128xi32>)
      outs(%init_transpose : tensor<100x200x128x256xi32>) {
    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
      %0 = arith.addi %b0, %b1 : i32
      %1 = arith.addi %0, %b2 : i32
      linalg.yield %1 : i32
    } -> tensor<100x200x128x256xi32>
  %4 = tensor.pack %transpose
    inner_dims_pos = [3, 2]
    inner_tiles = [16, 32]
    into %dest : tensor<100x200x128x256xi32> -> tensor<100x200x4x16x16x32xi32>
  return %4 : tensor<100x200x4x16x16x32xi32>
}
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d5)>
// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3, d4, d5)>
// CHECK-LABEL: func.func @transpose_pack
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[DEST:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<100x4x200x16x16x32xi32>
// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [3, 1] inner_tiles = [16, 32]
// CHECK-SAME:    into %[[ARG0_EMPTY]]
// CHECK:         %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
// CHECK:         %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [32]
// CHECK-SAME:    into %[[ARG2_EMPTY]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
// CHECK-SAME:      ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
// CHECK-SAME:      outs(%[[DEST]]

// -----
279
// Test: propagation handles indexing maps containing constant (0) expressions;
// unit dims stay as 0 exprs and only the packed dim (d1) gains a tile dim (d5).
func.func @affine_constant_expr_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100x1x1x1xi32>, %arg2: tensor<1x128x1x1xi32>, %dest: tensor<100x200x4x16x16x32xi32>) -> tensor<100x200x4x16x16x32xi32>
{
  %init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
  %transpose = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, 0)>,
                       affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100x1x1x1xi32>, tensor<1x128x1x1xi32>)
      outs(%init_transpose : tensor<100x200x128x256xi32>) {
    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
      %0 = arith.addi %b0, %b1 : i32
      %1 = arith.addi %0, %b2 : i32
      linalg.yield %1 : i32
    } -> tensor<100x200x128x256xi32>
  %4 = tensor.pack %transpose
    inner_dims_pos = [3, 2]
    inner_tiles = [16, 32]
    into %dest : tensor<100x200x128x256xi32> -> tensor<100x200x4x16x16x32xi32>
  return %4 : tensor<100x200x4x16x16x32xi32>
}
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, 0, 0, 0)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (0, d1, 0, 0, d5)>
// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3, d4, d5)>
// CHECK-LABEL: func.func @affine_constant_expr_pack
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[DEST:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<100x4x200x16x16x32xi32>
// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [3, 1] inner_tiles = [16, 32]
// CHECK-SAME:    into %[[ARG0_EMPTY]]
// CHECK:         %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<1x4x1x1x32xi32>
// CHECK:         %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
// CHECK-SAME:      inner_dims_pos = [1] inner_tiles = [32]
// CHECK-SAME:    into %[[ARG2_EMPTY]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
// CHECK-SAME:      ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
// CHECK-SAME:      outs(%[[DEST]]

// -----
325
// Test: like @transpose_pack but with an outer_dims_perm on the pack; the perm
// is composed with the generic's transpose when packing the 4-D input.
func.func @transpose_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<200x4x16x100x16x32xi32>) -> tensor<200x4x16x100x16x32xi32>
{
  %init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
  %transpose = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0)>,
                       affine_map<(d0, d1, d2, d3) -> (d1)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100xi32>, tensor<128xi32>)
      outs(%init_transpose : tensor<100x200x128x256xi32>) {
    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
      %0 = arith.addi %b0, %b1 : i32
      %1 = arith.addi %0, %b2 : i32
      linalg.yield %1 : i32
    } -> tensor<100x200x128x256xi32>
  %4 = tensor.pack %transpose
    outer_dims_perm = [1, 2, 3, 0]
    inner_dims_pos = [3, 2]
    inner_tiles = [16, 32]
    into %dest : tensor<100x200x128x256xi32> -> tensor<200x4x16x100x16x32xi32>
  return %4 : tensor<200x4x16x100x16x32xi32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d5)>
// CHECK-LABEL: func.func @transpose_pack_with_outer_dims
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[DEST:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<200x4x16x100x16x32xi32>
// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [3, 1] inner_tiles = [16, 32]
// CHECK-SAME:      into %[[ARG0_EMPTY]]
// CHECK:         %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
// CHECK:         %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG2_EMPTY]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP]]]
// CHECK-SAME:      ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
// CHECK-SAME:      outs(%[[DEST]]

// -----
372
#map0 = affine_map<(d0, d1) -> (d0, d1)>
// Test: when the generic's init is a function argument (not tensor.empty), both
// the input and the init operand get packed before the generic.
func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: tensor<128x256xi32>) -> tensor<16x4x32x16xi32>{
  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<128x256xi32>)
      outs(%init : tensor<128x256xi32>) {
    ^bb0(%arg3: i32, %arg4: i32):
      %4 = arith.addi %arg3, %arg4 : i32
      linalg.yield %4 : i32
  } -> tensor<128x256xi32>
  %empty = tensor.empty() : tensor<16x4x32x16xi32>
  %pack = tensor.pack %elem
    outer_dims_perm = [1, 0]
    inner_dims_pos = [0, 1]
    inner_tiles = [32, 16]
    into %empty : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
  return %pack : tensor<16x4x32x16xi32>
}

// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @elem_pack_transpose_outer_dims
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
// CHECK:         %[[PACKED_ARG1:.+]] = tensor.pack %[[ARG1]]
// CHECK-SAME:      outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:      into %[[ARG1_EMPTY]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:      into %[[ARG0_EMPTY]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME:      ins(%[[PACKED_ARG0]]
// CHECK-SAME:      outs(%[[PACKED_ARG1]]

// -----
409
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

// Test: an unpack feeding a generic's output operand is pushed below the
// generic: compute in the packed layout, then unpack the result.
func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
  %0 = tensor.empty() : tensor<12x56x56x64xf32>
  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
  %2 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%1 : tensor<12x56x56x64xf32>) {
    ^bb0(%out: f32):
      %3 = arith.addf %out, %out : f32
      linalg.yield %3 : f32
  } -> tensor<12x56x56x64xf32>
  return %2 : tensor<12x56x56x64xf32>
}

// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func.func @unpack_on_output
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_EMPTY_UNPACK:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG0_EMPTY_UNPACK]]
// CHECK:         %[[ARG0_EMPTY_PACK:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG0_EMPTY_PACK]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP]]]
// CHECK-SAME:      outs(%[[PACKED_ARG0]]
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[UNPACKED_ARG0]]

// -----
442
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

// Test: an unpack feeding a generic's input is pushed below the generic; the
// init operand is packed, the generic runs packed, and the result is unpacked.
func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf32>) -> tensor<12x56x56x64xf32> {
  %0 = tensor.empty() : tensor<12x56x56x64xf32>
  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
  %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      %3 = arith.addf %in, %out : f32
      linalg.yield %3 : f32
  } -> tensor<12x56x56x64xf32>
  return %2 : tensor<12x56x56x64xf32>
}

// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func.func @unpack_on_input
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
// CHECK:         %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK:         %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG1_PACK_EMPTY]]
// CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK:         %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME:      ins(%[[ARG0_PACK]]
// CHECK-SAME:      outs(%[[ARG1_PACK]]
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG1]]

// -----
481
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

// Test: propagation still applies when the generic changes element type
// (f32 -> f16); the packed init/result use f16 while the input stays f32.
func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
  %0 = tensor.empty() : tensor<12x56x56x64xf32>
  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
  %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf16>) {
    ^bb0(%in: f32, %out: f16):
      %3 = arith.truncf %in : f32 to f16
      linalg.yield %3 : f16
  } -> tensor<12x56x56x64xf16>
  return %2 : tensor<12x56x56x64xf16>
}

// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func.func @unpack_element_type_change
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
// CHECK:         %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
// CHECK:         %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG1_PACK_EMPTY]]
// CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK:         %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME:      ins(%[[ARG0_PACK]]
// CHECK-SAME:      outs(%[[ARG1_PACK]]
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG1]]

// -----
520
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

// Test: when the generic's init is a tensor.empty, no pack of the init is
// needed — a packed empty is used directly as the generic's destination.
func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
  %init = tensor.empty() : tensor<12x56x56x64xf32>
  %0 = tensor.empty() : tensor<12x56x56x64xf32>
  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
  %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      %3 = arith.addf %in, %in : f32
      linalg.yield %3 : f32
  } -> tensor<12x56x56x64xf32>
  return %2 : tensor<12x56x56x64xf32>
}

// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func.func @forward_tensor_empty
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
// CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
// CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME:      ins(%[[PACKED_ARG0]]
// CHECK-SAME:      outs(%[[DEST]]
// CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[FINAL_RES]]

// -----
557
// Test: a pad that only touches non-packed dims is pushed above the unpack, so
// the pad happens on the packed tensor (with 0 padding on tiled/outer dims).
func.func @pad_valid_unpack_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58x58x64xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<1x56x56x64xf32>
  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
  %padded = tensor.pad %1 low[0, 1, 1, 0] high[0, 1, 1, 0] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
    tensor.yield %cst : f32
  } : tensor<1x56x56x64xf32> to tensor<1x58x58x64xf32>
  return %padded : tensor<1x58x58x64xf32>
}

// CHECK-LABEL: func.func @pad_valid_unpack_propagation(
// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>

// -----
579
// Test: padding the leading (non-tiled) dim is also valid to propagate above
// the unpack; the packed pad grows the corresponding outer dim.
func.func @pad_valid_unpack_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<2x58x58x64xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<1x56x56x64xf32>
  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
  %padded = tensor.pad %1 low[1, 1, 1, 0] high[0, 1, 1, 0] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
    tensor.yield %cst : f32
  } : tensor<1x56x56x64xf32> to tensor<2x58x58x64xf32>
  return %padded : tensor<2x58x58x64xf32>
}

// CHECK-LABEL: func.func @pad_valid_unpack_propagation(
// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[1, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<2x58x58x64xf32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[EMPTY]] : tensor<2x2x58x58x32xf32> -> tensor<2x58x58x64xf32>

// -----
601
// Pad touches dim 3, which is the dim being unpacked (inner_dims_pos = [3]),
// so the pad cannot be rewritten in the packed layout: the unpack stays first
// and the pad remains on the unpacked tensor.
func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58x58x66xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<1x56x56x64xf32>
  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
  %padded = tensor.pad %1 low[0, 1, 1, 1] high[0, 1, 1, 1] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
    tensor.yield %cst : f32
  } : tensor<1x56x56x64xf32> to tensor<1x58x58x66xf32>
  return %padded : tensor<1x58x58x66xf32>
}

// CHECK-LABEL: func.func @pad_along_unpacked_dim(
// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME:      into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
// CHECK:         %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]
621
622// -----
623
// Pad on dims 2 and 3 only; the packed dim (1, tiled by 32) is untouched, so
// the pack is hoisted above the pad and the pad is applied in the packed
// layout ([0, 0, 1, 1, 0]).
func.func @pad_valid_pack_propagation(%arg0: tensor<1x64x56x56xf32>) -> tensor<1x2x58x58x32xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
    tensor.yield %cst : f32
  } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32>
  %0 = tensor.empty() : tensor<1x2x58x58x32xf32>
  %1 = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>
  return %1 : tensor<1x2x58x58x32xf32>
}

// CHECK-LABEL: func.func @pad_valid_pack_propagation(
// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x64x56x56xf32>)
// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x56x56x32xf32>
// CHECK:         %[[PACKED:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [32]
// CHECK-SAME:      into %[[EMPTY]] : tensor<1x64x56x56xf32> -> tensor<1x2x56x56x32xf32>
// CHECK:         %[[PADDED:.+]] = tensor.pad %[[PACKED]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
// CHECK:         return %[[PADDED]]
643
644// -----
645
// Same as above but with outer_dims_perm = [0, 3, 2, 1]: the padded dims
// (2 and 3) are permuted to positions 1 and 2 of the packed layout, so the
// propagated pad uses low/high [0, 1, 1, 0, 0].
func.func @pad_valid_outer_dims_pack_propagation(%arg0: tensor<1x64x56x56xf32>) -> tensor<1x58x58x2x32xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
    tensor.yield %cst : f32
  } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32>
  %0 = tensor.empty() : tensor<1x58x58x2x32xf32>
  %1 = tensor.pack %padded outer_dims_perm = [0, 3, 2, 1] inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x58x58x2x32xf32>
  return %1 : tensor<1x58x58x2x32xf32>
}

// CHECK-LABEL: func.func @pad_valid_outer_dims_pack_propagation(
// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x64x56x56xf32>)
// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x2x32xf32>
// CHECK:         %[[PACKED:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 2, 1] inner_dims_pos = [1] inner_tiles = [32]
// CHECK-SAME:      into %[[EMPTY]] : tensor<1x64x56x56xf32> -> tensor<1x56x56x2x32xf32>
// CHECK:         %[[PADDED:.+]] = tensor.pad %[[PACKED]] low[0, 1, 1, 0, 0] high[0, 1, 1, 0, 0]
// CHECK:         return %[[PADDED]]
666
667// -----
668
// Pad includes dim 1, which is the dim being packed (tiled by 32), so the
// pack cannot be hoisted above the pad: the IR keeps pad-then-pack order.
func.func @pad_along_packed_dim(%arg0: tensor<1x60x56x56xf32>) -> tensor<1x2x58x58x32xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %padded = tensor.pad %arg0 low[0, 2, 1, 1] high[0, 2, 1, 1] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
    tensor.yield %cst : f32
  } : tensor<1x60x56x56xf32> to tensor<1x64x58x58xf32>
  %0 = tensor.empty() : tensor<1x2x58x58x32xf32>
  %1 = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>
  return %1 : tensor<1x2x58x58x32xf32>
}

// CHECK-LABEL: func.func @pad_along_packed_dim(
// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x60x56x56xf32>)
// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 2, 1, 1] high[0, 2, 1, 1]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x58x58x32xf32>
// CHECK:         tensor.pack %[[PADDED]] inner_dims_pos = [1] inner_tiles = [32]
// CHECK-SAME:      into %[[EMPTY]] : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>
687
688// -----
689
// The padded value has two uses (returned directly and packed). The pack is
// still propagated above the pad; the original unpacked padded tensor is then
// reconstructed by unpacking the padded-packed value for the other use.
func.func @multi_use_pad_pack_propagation(%arg0: tensor<1x64x56x56xf32>) -> (tensor<1x64x58x58xf32>, tensor<1x2x58x58x32xf32>) {
  %cst = arith.constant 0.000000e+00 : f32
  %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
    tensor.yield %cst : f32
  } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32>
  %0 = tensor.empty() : tensor<1x2x58x58x32xf32>
  %1 = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>
  return %padded, %1 : tensor<1x64x58x58xf32>, tensor<1x2x58x58x32xf32>
}

// CHECK-LABEL: func.func @multi_use_pad_pack_propagation(
// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x64x56x56xf32>)
// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x56x56x32xf32>
// CHECK:         %[[PACKED:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [32]
// CHECK-SAME:      into %[[EMPTY]] : tensor<1x64x56x56xf32> -> tensor<1x2x56x56x32xf32>
// CHECK:         %[[PADDED:.+]] = tensor.pad %[[PACKED]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
// CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[PADDED]] inner_dims_pos = [1] inner_tiles = [32]
// CHECK:         return %[[UNPACKED]], %[[PADDED]]
710
711// -----
712
// The pack's destination (bufferization.alloc_tensor) is defined after the
// generic op; pushing the pack above the generic would make it use a value
// that does not dominate it, so no propagation happens and the pack stays
// after the generic.
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{
  %init = tensor.empty() : tensor<128x256xi32>
  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<128x256xi32>)
      outs(%init : tensor<128x256xi32>) {
    ^bb0(%arg3: i32, %arg4: i32):
      %4 = arith.addi %arg3, %arg3 : i32
      linalg.yield %4 : i32
  } -> tensor<128x256xi32>
  %dest = bufferization.alloc_tensor() : tensor<4x16x16x32xi32>
  %pack = tensor.pack %elem
    inner_dims_pos = [1, 0]
    inner_tiles = [16, 32]
    into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32>
  return %pack : tensor<4x16x16x32xi32>
}

// CHECK-LABEL: func.func @would_break_dominance(
// CHECK-SAME:     %[[ARG0:.+]]: tensor<128x256xi32>)
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<128x256xi32>
// CHECK-NEXT:    %[[GEN:.+]] = linalg.generic
// CHECK-SAME:      ins(%[[ARG0]]
// CHECK-SAME:      outs(%[[EMPTY]]
// CHECK:         %[[ALLOC:.+]] = bufferization.alloc_tensor() : tensor<4x16x16x32xi32>
// CHECK-NEXT:    %{{.+}} = tensor.pack %[[GEN]]
// CHECK-SAME:      inner_dims_pos = [1, 0] inner_tiles = [16, 32]
// CHECK-SAME:      into %[[ALLOC]]
741
742// -----
743
// Broadcast from a rank-0 tensor (indexing map to ()): the scalar operand
// needs no packing, and the generic is rewritten directly in the packed 5-D
// iteration space, writing into the packed empty tensor.
#map0 = affine_map<(d0, d1, d2, d3) -> ()>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

func.func @scalar_tensor(%arg0 : tensor<f32>) -> tensor<1x32x7x7x32xf32> {
  %empty_gen = tensor.empty() : tensor<1x7x7x1024xf32>
  %gen = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<f32>) outs(%empty_gen : tensor<1x7x7x1024xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<1x7x7x1024xf32>
  %empty_pack = tensor.empty() : tensor<1x32x7x7x32xf32>
  %pack = tensor.pack %gen outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %empty_pack : tensor<1x7x7x1024xf32> -> tensor<1x32x7x7x32xf32>
  return %pack : tensor<1x32x7x7x32xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func.func @scalar_tensor
// CHECK-SAME:     %[[ARG0:.+]]: tensor<f32>)
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x32x7x7x32xf32>
// CHECK:         linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP1]]]
// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME:      ins(%[[ARG0]]
// CHECK-SAME:      outs(%[[EMPTY]]
768
769// -----
770
// Degenerate unpack with empty inner_dims_pos/inner_tiles (a pure outer-dim
// transpose): propagation through the elementwise generic still applies,
// producing the usual unpack -> pack -> generic -> unpack structure.
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x56x56x64xf32> {
  %init = tensor.empty() : tensor<12x56x56x64xf32>
  %0 = tensor.empty() : tensor<12x56x56x64xf32>
  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] into %0 : tensor<12x64x56x56xf32> -> tensor<12x56x56x64xf32>
  %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      %3 = arith.addf %in, %in : f32
      linalg.yield %3 : f32
  } -> tensor<12x56x56x64xf32>
  return %2 : tensor<12x56x56x64xf32>
}

// CHECK-LABEL: func.func @unpack_empty_inner_dims
// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      ins(%[[PACKED_ARG0]]
// CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
793
794// -----
795
// Pack (with transposed inner dims [1, 0]) consumed after a reduction: both
// the reduced input and the init are packed, and the generic is rewritten in
// the 5-D packed domain with the reduction dim (d2) preserved.
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
      %arg1: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{
  %elem = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%arg0 : tensor<128x256x32xi32>)
      outs(%arg1 : tensor<128x256xi32>) {
    ^bb0(%arg3: i32, %arg4: i32):
      %4 = arith.addi %arg3, %arg4 : i32
      linalg.yield %4 : i32
  } -> tensor<128x256xi32>
  %dest = tensor.empty() : tensor<4x16x16x32xi32>
  %pack = tensor.pack %elem
    inner_dims_pos = [1, 0]
    inner_tiles = [16, 32]
    into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32>
  return %pack : tensor<4x16x16x32xi32>
}
// CHECK-DAG:  #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
// CHECK-LABEL: func.func @reduction_pack_transpose_inner_dims
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
// CHECK:         %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]]
// CHECK-SAME:     inner_dims_pos = [1, 0] inner_tiles = [16, 32]
// CHECK-SAME:     into %[[ARG1_EMPTY]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x32x16x32xi32>
// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [1, 0] inner_tiles = [16, 32]
// CHECK-SAME:      into %[[ARG0_EMPTY]]
// CHECK:         %[[RED:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
// CHECK-SAME:      iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel"]
// CHECK-SAME:      ins(%[[PACK_ARG0]]
// CHECK-SAME:      outs(%[[PACK_ARG1]]
// CHECK:         return %[[RED]] : tensor<4x16x16x32xi32>
833
834// -----
835
// Pack with outer_dims_perm after a multi-input reduction: each operand is
// packed according to its own indexing map (%arg0 gets both tiles, %arg2 only
// the 32-tile), while %arg1 — indexed only by the untiled d0 — stays unpacked
// in the rewritten generic.
func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>,
  %arg2: tensor<128xi32>, %init_reduction: tensor<100x128x256xi32>) -> tensor<4x16x100x16x32xi32>
{
  %reduction = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0)>,
                       affine_map<(d0, d1, d2, d3) -> (d1)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>],
      iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
      ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100xi32>, tensor<128xi32>)
      outs(%init_reduction : tensor<100x128x256xi32>) {
    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
      %0 = arith.addi %b0, %b1 : i32
      %1 = arith.addi %0, %b2 : i32
      %2 = arith.addi %1, %b3 : i32
      linalg.yield %2 : i32
    } -> tensor<100x128x256xi32>
  %init_pack = tensor.empty() : tensor<4x16x100x16x32xi32>
  %4 = tensor.pack %reduction
    outer_dims_perm = [1, 2, 0]
    inner_dims_pos = [2, 1]
    inner_tiles = [16, 32]
    into %init_pack : tensor<100x128x256xi32> -> tensor<4x16x100x16x32xi32>
  return %4 : tensor<4x16x100x16x32xi32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d5)>
// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d5)>
// CHECK-LABEL: func.func @reduction_pack_with_outer_dims
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]
// CHECK:         %[[ARG3_EMPTY:.+]] = tensor.empty() : tensor<4x16x100x16x32xi32>
// CHECK:         %[[PACKED_ARG3:.+]] = tensor.pack %[[ARG3]]
// CHECK-SAME:      outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 32]
// CHECK-SAME:      into %[[ARG3_EMPTY]]
// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x200x100x16x32xi32>
// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      outer_dims_perm = [1, 3, 2, 0] inner_dims_pos = [3, 1] inner_tiles = [16, 32]
// CHECK-SAME:      into %[[ARG0_EMPTY]]
// CHECK:         %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
// CHECK:         %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [32]
// CHECK-SAME:      into %[[ARG2_EMPTY]]
// CHECK:         %[[RES:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
// CHECK-SAME:      ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
// CHECK-SAME:      outs(%[[PACKED_ARG3]]
887
888// -----
889
// Unpack feeding a pooling-style generic whose result shape differs from the
// unpack destination: the pack is pushed above the generic, which gains a
// trailing parallel tile dim (d6), and the result is unpacked into the
// original 16x540x960 destination.
#map0 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d3)>
func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>,
    %filter: tensor<2x2xi32>) -> tensor<16x540x960xi32>{
  %init = tensor.empty() : tensor<16x540x960xi32>
  %empty = tensor.empty() : tensor<1x16x1080x1920xi32>
  %unpack = tensor.unpack %arg0
      inner_dims_pos = [1]
      inner_tiles = [16]
      into %empty : tensor<1x1x1080x1920x16xi32> -> tensor<1x16x1080x1920xi32>
  %pool = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
      ins(%unpack, %filter : tensor<1x16x1080x1920xi32>, tensor<2x2xi32>)
      outs(%init : tensor<16x540x960xi32>) {
    ^bb0(%in: i32, %in_1: i32, %out: i32):
      %max = arith.maxui %in, %in_1 : i32
      linalg.yield %max : i32
  } -> tensor<16x540x960xi32>
  return %pool : tensor<16x540x960xi32>
}
// CHECK-DAG:  #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5, d6)>
// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5)>
// CHECK-DAG:  #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d3, d6)>
// CHECK-LABEL: func.func @unpack_different_destination_shape
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
// CHECK:         %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
// CHECK:         %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack
// CHECK-SAME:      inner_dims_pos = [1] inner_tiles = [16]
// CHECK-SAME:      into %[[PACK_EMPTY]]
// CHECK:         %[[POOL:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
// CHECK-SAME:      ins(%[[PACK_ARG0]], %[[ARG1]]
// CHECK-SAME:      outs(%[[INIT]]
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[POOL]]
// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [16]
// CHECK-SAME:      into %[[FINAL_RES]]
// CHECK:         return %[[UNPACK]] : tensor<16x540x960xi32>
931
932// -----
933
// Pack bubbled above a collapse_shape: the tiled dims are remapped onto the
// pre-collapse dims ([0, 1] -> [1, 2]) and the collapse is reapplied on the
// packed tensor.
func.func @bubble_up_pack_through_collapse(%1: tensor<?x16x4xf32>, %dim : index) -> tensor<?x4x8x1xf32> {
  %collapsed = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<?x16x4xf32> into tensor<?x4xf32>
  %2 = tensor.empty(%dim) : tensor<?x4x8x1xf32>
  %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
  func.return %pack : tensor<?x4x8x1xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_through_collapse
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:      %[[ARG1:[a-zA-Z0-9]+]]
// CHECK:         %[[C0:.+]] = arith.constant 0 : index
// CHECK:         %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x2x4x8x1xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
// CHECK:         return %[[COLLAPSED]] : tensor<?x4x8x1xf32>
949
950// -----
951
// Same as the collapse case above but with no outer_dims_perm on the pack;
// the bubbled pack also carries no outer_dims_perm.
func.func @bubble_up_pack_through_collapse_empty_outer_dims_perm(%1: tensor<?x16x4xf32>, %dim : index) -> tensor<?x4x8x1xf32> {
  %collapsed = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<?x16x4xf32> into tensor<?x4xf32>
  %2 = tensor.empty(%dim) : tensor<?x4x8x1xf32>
  %pack = tensor.pack %collapsed inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
  func.return %pack : tensor<?x4x8x1xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_through_collapse_empty_outer_dims_perm
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:      %[[ARG1:[a-zA-Z0-9]+]]
// CHECK:         %[[C0:.+]] = arith.constant 0 : index
// CHECK:         %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x2x4x8x1xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
// CHECK:         return %[[COLLAPSED]] : tensor<?x4x8x1xf32>
967
968// -----
969
// Pack with a non-identity outer_dims_perm bubbled above the collapse_shape;
// outer_dims_perm/inner_dims_pos are remapped onto the pre-collapse dims.
func.func @bubble_up_permuted_pack_through_collapse(%1: tensor<4x192x16x256xf32>) -> tensor<4x32x3072x8x1xf32> {
  %collapsed = tensor.collapse_shape %1 [[0], [1, 2], [3]] : tensor<4x192x16x256xf32> into tensor<4x3072x256xf32>
  %2 = tensor.empty() : tensor<4x32x3072x8x1xf32>
  %pack = tensor.pack %collapsed outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 1] into %2 : tensor<4x3072x256xf32> -> tensor<4x32x3072x8x1xf32>
  func.return %pack : tensor<4x32x3072x8x1xf32>
}
// CHECK-LABEL: func.func @bubble_up_permuted_pack_through_collapse
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<4x32x192x16x8x1xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<4x192x16x256xf32> -> tensor<4x32x192x16x8x1xf32>
// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x192x16x8x1xf32> into tensor<4x32x3072x8x1xf32>
// CHECK:         return %[[COLLAPSED]] : tensor<4x32x3072x8x1xf32>
982
983// -----
984
// Collapse that folds away unit dims (1x64x1 -> 64): the pack still bubbles
// up, tiling the corresponding pre-collapse dims [1, 3].
func.func @bubble_up_pack_through_unit_collapse(%1: tensor<1x64x1x4xf32>) -> tensor<8x4x8x1xf32> {
  %collapsed = tensor.collapse_shape %1 [[0, 1, 2], [3]] : tensor<1x64x1x4xf32> into tensor<64x4xf32>
  %2 = tensor.empty() : tensor<8x4x8x1xf32>
  %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<64x4xf32> -> tensor<8x4x8x1xf32>
  func.return %pack : tensor<8x4x8x1xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_through_unit_collapse
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x8x1x4x8x1xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 3] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<1x64x1x4xf32> -> tensor<1x8x1x4x8x1xf32>
// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1, 2], [3], [4], [5]] : tensor<1x8x1x4x8x1xf32> into tensor<8x4x8x1xf32>
// CHECK:         return %[[COLLAPSED]] : tensor<8x4x8x1xf32>
997
998// -----
999
// Only the non-collapsed dim (1, size 4) is tiled; the pack bubbles up with
// the tile moved to pre-collapse dim 2, and the collapse is reapplied.
func.func @bubble_up_pack_through_collapse_on_outer_dims(%1: tensor<?x16x4xf32>, %dim : index) -> tensor<?x1x4xf32> {
  %collapsed = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<?x16x4xf32> into tensor<?x4xf32>
  %2 = tensor.empty(%dim) : tensor<?x1x4xf32>
  %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [1] inner_tiles = [4] into %2 : tensor<?x4xf32> -> tensor<?x1x4xf32>
  func.return %pack : tensor<?x1x4xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_through_collapse_on_outer_dims
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:      %[[ARG1:[a-zA-Z0-9]+]]
// CHECK:         %[[C0:.+]] = arith.constant 0 : index
// CHECK:         %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x16x1x4xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [2] inner_tiles = [4] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x16x1x4xf32>
// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3]] : tensor<?x16x1x4xf32> into tensor<?x1x4xf32>
// CHECK:         return %[[COLLAPSED]] : tensor<?x1x4xf32>
1015
1016// -----
1017
// Negative test: the tile size 8 on the collapsed dim (256 = 64 * 4) does not
// divide the innermost pre-collapse size 4, so the pack is not bubbled up and
// the collapse-then-pack order is preserved.
func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4xf32>) -> tensor<384x32x8x8xf32> {
  %collapsed = tensor.collapse_shape %1 [[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
  %2 = tensor.empty() : tensor<384x32x8x8xf32>
  %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %2 : tensor<3072x256xf32> -> tensor<384x32x8x8xf32>
  func.return %pack : tensor<384x32x8x8xf32>
}
// CHECK-LABEL: func.func @no_bubble_up_pack_through_non_divisible_collapse
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[COLLAPSED]]
// CHECK:         return %[[PACK]] : tensor<384x32x8x8xf32>
1029
1030// -----
1031
// Pack tiling the outer part of an expanded dim (32 -> 4x8, tile on the 8):
// the pack bubbles above the expand_shape, tiling original dim 0, and the
// expand is reapplied on the packed result.
func.func @bubble_up_pack_outer_expanded_through_expand(%arg0: tensor<32x64xf32>) -> tensor<4x2x64x4xf32> {
  %empty = tensor.empty() : tensor<4x2x64x4xf32>
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [1] inner_tiles = [4] into %empty : tensor<4x8x64xf32> -> tensor<4x2x64x4xf32>
  return %pack : tensor<4x2x64x4xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_outer_expanded_through_expand(
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x64x4xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<8x64x4xf32>
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3]]
// CHECK-SAME:      output_shape [4, 2, 64, 4] : tensor<8x64x4xf32> into tensor<4x2x64x4xf32>
// CHECK:         return %[[EXPANDED]] : tensor<4x2x64x4xf32>
1046
1047// -----
1048
// Pack tiling the innermost part of an expanded dim (64 -> 4x16, tile on the
// 16): the pack bubbles above the expand_shape to tile original dim 1.
func.func @bubble_up_pack_inner_expanded_through_expand(%arg0: tensor<32x64xf32>) -> tensor<32x4x4x4xf32> {
  %empty = tensor.empty() : tensor<32x4x4x4xf32>
  %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [2] inner_tiles = [4] into %empty : tensor<32x4x16xf32> -> tensor<32x4x4x4xf32>
  return %pack : tensor<32x4x4x4xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_inner_expanded_through_expand(
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<32x16x4xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [1] inner_tiles = [4] into %[[EMPTY]]
// CHECK-SAME:      : tensor<32x64xf32> -> tensor<32x16x4xf32>
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]]
// CHECK-SAME:      output_shape [32, 4, 4, 4] : tensor<32x16x4xf32> into tensor<32x4x4x4xf32>
// CHECK:         return %[[EXPANDED]] : tensor<32x4x4x4xf32>
1064
1065// -----
1066
// The tiled dim (0) is not part of any expanded group, so the pack bubbles up
// unchanged and the expand_shape is recreated on the packed tensor.
func.func @bubble_up_pack_non_expanded_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> tensor<8x2x32x16x4xf32> {
  %empty = tensor.empty() : tensor<8x2x32x16x4xf32>
  %expanded = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [32, 2, 32, 16] : tensor<32x64x16xf32> into tensor<32x2x32x16xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [0] inner_tiles = [4] into %empty : tensor<32x2x32x16xf32> -> tensor<8x2x32x16x4xf32>
  return %pack : tensor<8x2x32x16x4xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_non_expanded_dims_through_expand(
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x64x16x4xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack
// CHECK-SAME:      %[[ARG0]] inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]]
// CHECK-SAME:      : tensor<32x64x16xf32> -> tensor<8x64x16x4xf32>
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3], [4]]
// CHECK-SAME:      output_shape [8, 2, 32, 16, 4] : tensor<8x64x16x4xf32> into tensor<8x2x32x16x4xf32>
// CHECK:         return %[[EXPANDED]] : tensor<8x2x32x16x4xf32>
1082
1083// -----
1084
// Dynamic-shape variant: the pack bubbles above the expand_shape; the dynamic
// output_shape of the recreated expand is taken from a tensor.dim on the
// packed value.
func.func @bubble_up_pack_through_expand_dynamic(%arg0: tensor<?x64xf32>) -> tensor<?x4x2x8xf32> {
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg0, %c0 : tensor<?x64xf32>
  %empty = tensor.empty(%dim) : tensor<?x4x2x8xf32>
  %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [%dim, 4, 16] : tensor<?x64xf32> into tensor<?x4x16xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [2] inner_tiles = [8] into %empty : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
  return %pack : tensor<?x4x2x8xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_through_expand_dynamic(
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK:         %[[DIM_INPUT:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x64xf32>
// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM_INPUT]]) : tensor<?x8x8xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [1] inner_tiles = [8] into %[[EMPTY]]
// CHECK-SAME:      : tensor<?x64xf32> -> tensor<?x8x8xf32>
// CHECK:         %[[DIM_PACK:.+]] = tensor.dim %[[PACK]], %[[C0]] : tensor<?x8x8xf32>
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]]
// CHECK-SAME:      output_shape [%[[DIM_PACK]], 4, 2, 8] : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
// CHECK:         return %[[EXPANDED]] : tensor<?x4x2x8xf32>
1105
1106// -----
1107
// Pack with a padding_value where the padded dim (60, padded up to 64 by the
// 8-tile) is not an expanded dim: the pack still bubbles above the
// expand_shape, keeping the padding_value.
// Note: output_shape of the expand was corrected to [4, 8, 60] to agree with
// the operand tensor<32x60xf32> and the static result type tensor<4x8x60xf32>
// (it previously read [4, 8, 64]).
func.func @bubble_up_pack_non_expanded_padding_through_expand(%arg0: tensor<32x60xf32>) -> tensor<4x2x8x4x8xf32> {
  %cst = arith.constant 3.000000e+00 : f32
  %empty = tensor.empty() : tensor<4x2x8x4x8xf32>
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 60] : tensor<32x60xf32> into tensor<4x8x60xf32>
  %pack = tensor.pack %expanded padding_value(%cst : f32) inner_dims_pos = [1, 2] inner_tiles = [4, 8] into %empty : tensor<4x8x60xf32> -> tensor<4x2x8x4x8xf32>
  return %pack : tensor<4x2x8x4x8xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_non_expanded_padding_through_expand(
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[CST:.+]] = arith.constant 3.000000e+00 : f32
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] padding_value(%[[CST]] : f32)
// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %[[EMPTY]]
// CHECK-SAME:      : tensor<32x60xf32> -> tensor<8x8x4x8xf32>
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]]
// CHECK-SAME:      output_shape [4, 2, 8, 4, 8] : tensor<8x8x4x8xf32> into tensor<4x2x8x4x8xf32>
// CHECK:         return %[[EXPANDED]] : tensor<4x2x8x4x8xf32>
1125
1126// -----
1127
// An identity outer_dims_perm ([0, 1, 2]) is as good as no permutation, so the
// pack still bubbles up through the expand_shape; the propagated pack on the
// collapsed 32x64 source carries no outer_dims_perm at all.
func.func @bubble_up_pack_outer_dims_perm_identity_through_expand(%arg0: tensor<32x64xf32>) -> tensor<4x2x32x4x2xf32> {
  %empty = tensor.empty() : tensor<4x2x32x4x2xf32>
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
  %pack = tensor.pack %expanded outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %empty : tensor<4x8x64xf32> -> tensor<4x2x32x4x2xf32>
  return %pack : tensor<4x2x32x4x2xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_outer_dims_perm_identity_through_expand(
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x32x4x2xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [4, 2] into %[[EMPTY]]
// CHECK-SAME:      : tensor<32x64xf32> -> tensor<8x32x4x2xf32>
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]]
// CHECK-SAME:      output_shape [4, 2, 32, 4, 2] : tensor<8x32x4x2xf32> into tensor<4x2x32x4x2xf32>
// CHECK:         return %[[EXPANDED]] : tensor<4x2x32x4x2xf32>
1143
1144// -----
1145
// Packing three dims of the expanded tensor, one of which (dim 2) is the
// innermost result of the [1, 2] split of 64: inner_dims_pos [0, 2, 3]
// projects back to [0, 1, 2] on the collapsed 32x64x16 source with the same
// inner_tiles, and the expand_shape is rebuilt on the packed type.
func.func @bubble_up_pack_multiple_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> tensor<8x2x4x8x4x8x2xf32> {
  %empty = tensor.empty() : tensor<8x2x4x8x4x8x2xf32>
  %expanded = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [32, 2, 32, 16] : tensor<32x64x16xf32> into tensor<32x2x32x16xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [0, 2, 3] inner_tiles = [4, 8, 2] into %empty : tensor<32x2x32x16xf32> -> tensor<8x2x4x8x4x8x2xf32>
  return %pack : tensor<8x2x4x8x4x8x2xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_multiple_dims_through_expand(
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x8x4x8x2xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [0, 1, 2] inner_tiles = [4, 8, 2] into %[[EMPTY]]
// CHECK-SAME:      : tensor<32x64x16xf32> -> tensor<8x8x8x4x8x2xf32>
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3], [4], [5], [6]]
// CHECK-SAME:      output_shape [8, 2, 4, 8, 4, 8, 2] : tensor<8x8x8x4x8x2xf32> into tensor<8x2x4x8x4x8x2xf32>
// CHECK:         return %[[EXPANDED]] : tensor<8x2x4x8x4x8x2xf32>
1161
1162// -----
1163
// inner_dims_pos need not be sorted: [2, 1] on the expanded tensor projects to
// [1, 0] on the collapsed 32x64 source, and the relative order of the
// inner_tiles [16, 4] is preserved in the bubbled-up pack.
func.func @bubble_up_pack_inner_dims_reorder_through_expand(%arg0: tensor<32x64xf32>) -> tensor<4x2x4x16x4xf32> {
  %empty = tensor.empty() : tensor<4x2x4x16x4xf32>
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [2, 1] inner_tiles = [16, 4] into %empty : tensor<4x8x64xf32> -> tensor<4x2x4x16x4xf32>
  return %pack : tensor<4x2x4x16x4xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_inner_dims_reorder_through_expand(
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x4xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %[[EMPTY]]
// CHECK-SAME:      : tensor<32x64xf32> -> tensor<8x4x16x4xf32>
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]]
// CHECK-SAME:      output_shape [4, 2, 4, 16, 4] : tensor<8x4x16x4xf32> into tensor<4x2x4x16x4xf32>
// CHECK:         return %[[EXPANDED]] : tensor<4x2x4x16x4xf32>
1179
1180// -----
1181
// Two different source dims (32 and 64) are each expanded, and the pack tiles
// the innermost result of each split (dims 1 and 3). Both project cleanly onto
// source dims [0, 1], so the pack bubbles up through the expand_shape.
func.func @bubble_up_pack_multiple_different_expanded_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> tensor<4x2x2x8x16x4x4xf32> {
  %empty = tensor.empty() : tensor<4x2x2x8x16x4x4xf32>
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]] output_shape [4, 8, 2, 32, 16] : tensor<32x64x16xf32> into tensor<4x8x2x32x16xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [1, 3] inner_tiles = [4, 4] into %empty : tensor<4x8x2x32x16xf32> -> tensor<4x2x2x8x16x4x4xf32>
  return %pack : tensor<4x2x2x8x16x4x4xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_multiple_different_expanded_dims_through_expand(
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x16x16x4x4xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %[[EMPTY]]
// CHECK-SAME:      : tensor<32x64x16xf32> -> tensor<8x16x16x4x4xf32>
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2, 3], [4], [5], [6]]
// CHECK-SAME:      output_shape [4, 2, 2, 8, 16, 4, 4] : tensor<8x16x16x4x4xf32> into tensor<4x2x2x8x16x4x4xf32>
// CHECK:         return %[[EXPANDED]] : tensor<4x2x2x8x16x4x4xf32>
1197
1198// -----
1199
// Negative test: a non-identity outer_dims_perm ([2, 0, 1]) cannot be carried
// across the expand_shape, so the pack must stay below the expand; the CHECK
// lines verify the IR is left unchanged.
func.func @no_bubble_up_pack_outer_dims_permutation_through_expand(%arg0: tensor<32x64xf32>) -> tensor<32x4x2x4x2xf32> {
  %empty = tensor.empty() : tensor<32x4x2x4x2xf32>
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
  %pack = tensor.pack %expanded outer_dims_perm = [2, 0, 1] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %empty : tensor<4x8x64xf32> -> tensor<32x4x2x4x2xf32>
  return %pack : tensor<32x4x2x4x2xf32>
}
// CHECK-LABEL: func.func @no_bubble_up_pack_outer_dims_permutation_through_expand(
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<32x4x2x4x2xf32>
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]]
// CHECK-SAME:      output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[EXPANDED]]
// CHECK-SAME:      outer_dims_perm = [2, 0, 1] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %[[EMPTY]]
// CHECK-SAME:      : tensor<4x8x64xf32> -> tensor<32x4x2x4x2xf32>
// CHECK:         return %[[PACK]] : tensor<32x4x2x4x2xf32>
1215
1216// -----
1217
// Negative test: both packed dims (0 and 1) come from the same reassociation
// group [0, 1] (the split of 32), so there is no single source dim the two
// tiles could be projected onto; the pack is not bubbled up and the IR is
// left unchanged.
func.func @no_bubble_up_pack_multiple_same_expanded_dim_through_expand(%arg0: tensor<32x64xf32>) -> tensor<2x2x64x2x4xf32> {
  %empty = tensor.empty() : tensor<2x2x64x2x4xf32>
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %empty : tensor<4x8x64xf32> -> tensor<2x2x64x2x4xf32>
  return %pack : tensor<2x2x64x2x4xf32>
}
// CHECK-LABEL: func.func @no_bubble_up_pack_multiple_same_expanded_dim_through_expand(
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<2x2x64x2x4xf32>
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]]
// CHECK-SAME:      output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[EXPANDED]]
// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %[[EMPTY]]
// CHECK-SAME:      : tensor<4x8x64xf32> -> tensor<2x2x64x2x4xf32>
// CHECK:         return %[[PACK]] : tensor<2x2x64x2x4xf32>
1233
1234// -----
1235
// Negative test: the packed dim 0 is the *outer* half of the [0, 1] split of
// 32, not the innermost result of the expansion, so the tiling cannot be
// re-expressed on the collapsed source; the IR is left unchanged.
func.func @no_bubble_up_pack_non_innermost_expanded_dim_through_expand(%arg0: tensor<32x64xf32>) -> tensor<2x8x64x2xf32> {
  %empty = tensor.empty() : tensor<2x8x64x2xf32>
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [0] inner_tiles = [2] into %empty : tensor<4x8x64xf32> -> tensor<2x8x64x2xf32>
  return %pack : tensor<2x8x64x2xf32>
}
// CHECK-LABEL: func.func @no_bubble_up_pack_non_innermost_expanded_dim_through_expand(
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<2x8x64x2xf32>
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]]
// CHECK-SAME:      output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[EXPANDED]]
// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [2] into %[[EMPTY]]
// CHECK-SAME:      : tensor<4x8x64xf32> -> tensor<2x8x64x2xf32>
// CHECK:         return %[[PACK]] : tensor<2x8x64x2xf32>
1251
1252// -----
1253
// Negative test: the padded dim (dim 1, size 10 tiled by 8) is itself a result
// of expanding 30, so the padding cannot be reassociated onto the collapsed
// source; the pack stays below the expand and the IR is left unchanged.
func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(%arg0: tensor<30x60xf32>) -> tensor<3x2x60x8xf32> {
  %cst = arith.constant 3.000000e+00 : f32
  %empty = tensor.empty() : tensor<3x2x60x8xf32>
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 10, 60] : tensor<30x60xf32> into tensor<3x10x60xf32>
  %pack = tensor.pack %expanded padding_value(%cst : f32) inner_dims_pos = [1] inner_tiles = [8] into %empty : tensor<3x10x60xf32> -> tensor<3x2x60x8xf32>
  return %pack : tensor<3x2x60x8xf32>
}
// CHECK-LABEL: func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[CST:.+]] = arith.constant 3.000000e+00 : f32
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<3x2x60x8xf32>
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]]
// CHECK-SAME:      output_shape [3, 10, 60] : tensor<30x60xf32> into tensor<3x10x60xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[EXPANDED]] padding_value(%[[CST]] : f32)
// CHECK-SAME:      inner_dims_pos = [1] inner_tiles = [8] into %[[EMPTY]]
// CHECK-SAME:      : tensor<3x10x60xf32> -> tensor<3x2x60x8xf32>
// CHECK:         return %[[PACK]] : tensor<3x2x60x8xf32>
1271
1272// -----
1273
// Negative test: the pack destination's outer dim 0 is 8, larger than the
// minimal 32/8 = 4, i.e. the packed dim is extended beyond the source extent.
// Such an extension cannot be reassociated through the expand_shape, so the
// IR is left unchanged.
func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> {
  %empty = tensor.empty() : tensor<8x4x16x8xf32>
  %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
  %pack = tensor.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
  return %pack : tensor<8x4x16x8xf32>
}
// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32>
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]]
// CHECK-SAME:      output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[EXPANDED]]
// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]]
// CHECK-SAME:      : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
// CHECK:         return %[[PACK]] : tensor<8x4x16x8xf32>
1289
1290// -----
1291
// Pushing an unpack below an expand_shape on a dynamic outer dim: the expand
// is hoisted onto the packed operand (splitting the dynamic outer dim, with
// %sz0 recomputed as dim0 / 32) and the unpack is re-emitted on the expanded
// packed tensor, producing the 3-D result directly.
func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
  %6 = tensor.empty(%dim) : tensor<?x256xf32>
  %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [%sz0, 256, 256] : tensor<?x256xf32> into tensor<?x256x256xf32>
  func.return %expanded : tensor<?x256x256xf32>
}
// CHECK-LABEL: func.func @push_down_unpack_through_expand
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:      %[[ARG1:[a-zA-Z0-9]+]]
// CHECK:         %[[C32:.+]] = arith.constant 32 : index
// CHECK:         %[[C0:.+]] = arith.constant 0 : index
// CHECK:         %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x32x8x8xf32>
// CHECK:         %[[SZ0:.+]] = arith.divsi %[[DIM0]], %[[C32]] : index
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape [%[[SZ0]], 32, 32, 8, 8] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
// CHECK:         %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32>
// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
// Use the earlier [[EXPANDED]] binding (not a fresh `.+` capture) so the test
// actually checks that the unpack consumes the new expand_shape's result.
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
// CHECK:         return %[[UNPACK]] : tensor<?x256x256xf32>
1310
1311// -----
1312
// Same as @push_down_unpack_through_expand but the source unpack has no
// outer_dims_perm; the pushed-down unpack must likewise carry none.
func.func @push_down_unpack_through_expand_empty_outer_dims_perm(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
  %6 = tensor.empty(%dim) : tensor<?x256xf32>
  %unpack = tensor.unpack %5 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [%sz0, 256, 256] : tensor<?x256xf32> into tensor<?x256x256xf32>
  func.return %expanded : tensor<?x256x256xf32>
}
// CHECK-LABEL: func.func @push_down_unpack_through_expand_empty_outer_dims_perm
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:      %[[ARG1:[a-zA-Z0-9]+]]
// CHECK:         %[[C32:.+]] = arith.constant 32 : index
// CHECK:         %[[C0:.+]] = arith.constant 0 : index
// CHECK:         %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x32x8x8xf32>
// CHECK:         %[[SZ0:.+]] = arith.divsi %[[DIM0]], %[[C32]] : index
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape [%[[SZ0]], 32, 32, 8, 8] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
// CHECK:         %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32>
// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
// Use the earlier [[EXPANDED]] binding (not a fresh `.+` capture) so the test
// actually checks that the unpack consumes the new expand_shape's result.
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
// CHECK:         return %[[UNPACK]] : tensor<?x256x256xf32>
1331
1332// -----
1333
// An unpack with a non-trivial outer_dims_perm is still pushed through the
// expand_shape: the expand splits the packed outer dim (384 -> 12x32) and the
// new unpack's outer_dims_perm/inner_dims_pos are remapped to the expanded
// dim numbering ([0, 3, 1, 2] / [3, 2]).
func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>) -> tensor<4x12x256x256xf32> {
  %6 = tensor.empty() : tensor<4x3072x256xf32>
  %unpack = tensor.unpack %5 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 8] into %6 : tensor<4x32x384x8x8xf32> -> tensor<4x3072x256xf32>
  %expanded = tensor.expand_shape %unpack [[0], [1, 2], [3]] output_shape [4, 12, 256, 256] : tensor<4x3072x256xf32> into tensor<4x12x256x256xf32>
  func.return %expanded : tensor<4x12x256x256xf32>
}
// CHECK-LABEL: @push_down_permuted_unpack_through_expand
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3], [4], [5]] output_shape [4, 32, 12, 32, 8, 8] : tensor<4x32x384x8x8xf32> into tensor<4x32x12x32x8x8xf32>
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<4x12x256x256xf32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<4x32x12x32x8x8xf32> -> tensor<4x12x256x256xf32>
// CHECK:         return %[[UNPACK]] : tensor<4x12x256x256xf32>
1346
1347// -----
1348
// Unit dims introduced by the expand (48 -> 3x16x1) do not block pushing the
// unpack down: the hoisted expand splits the packed outer dim (6 -> 3x2x1)
// and the new unpack tiles dims 1 and 3 of the expanded packed tensor.
func.func @push_down_unpack_through_unit_expand(%5: tensor<6x32x8x8xf32>) -> tensor<3x16x1x256xf32> {
  %6 = tensor.empty() : tensor<48x256xf32>
  %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<6x32x8x8xf32> -> tensor<48x256xf32>
  %expanded = tensor.expand_shape %unpack [[0, 1, 2], [3]] output_shape [3, 16, 1, 256] : tensor<48x256xf32> into tensor<3x16x1x256xf32>
  func.return %expanded : tensor<3x16x1x256xf32>
}
// CHECK-LABEL: func.func @push_down_unpack_through_unit_expand
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4], [5]] output_shape [3, 2, 1, 32, 8, 8] : tensor<6x32x8x8xf32> into tensor<3x2x1x32x8x8xf32>
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<3x16x1x256xf32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 3] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<3x2x1x32x8x8xf32> -> tensor<3x16x1x256xf32>
// CHECK:         return %[[UNPACK]] : tensor<3x16x1x256xf32>
1361
1362// -----
1363
// The expand_shape only splits an outer (untiled) dim of the unpack result.
// The expand is hoisted onto the packed operand (dynamic dim recomputed as
// dim0 / 256) and the unpack is re-emitted with inner_dims_pos shifted to the
// expanded numbering.
func.func @push_down_unpack_through_expand_on_outer_dims(%5: tensor<?x32x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
  %6 = tensor.empty(%dim) : tensor<?x256xf32>
  %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [1] inner_tiles = [8] into %6 : tensor<?x32x8xf32> -> tensor<?x256xf32>
  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [%sz0, 256, 256] : tensor<?x256xf32> into tensor<?x256x256xf32>
  func.return %expanded : tensor<?x256x256xf32>
}
// CHECK-LABEL: func.func @push_down_unpack_through_expand_on_outer_dims
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:      %[[ARG1:[a-zA-Z0-9]+]]
// CHECK:         %[[C256:.+]] = arith.constant 256 : index
// CHECK:         %[[C0:.+]] = arith.constant 0 : index
// CHECK:         %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x32x8xf32>
// CHECK:         %[[SZ0:.+]] = arith.divsi %[[DIM0]], %[[C256]] : index
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]] output_shape [%[[SZ0]], 256, 32, 8] : tensor<?x32x8xf32> into tensor<?x256x32x8xf32>
// CHECK:         %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x256x32x8xf32>
// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
// Use the earlier [[EXPANDED]] binding (not a fresh `.+` capture) so the test
// actually checks that the unpack consumes the new expand_shape's result.
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [2] inner_tiles = [8] into %[[EMPTY]] : tensor<?x256x32x8xf32> -> tensor<?x256x256xf32>
// CHECK:         return %[[UNPACK]] : tensor<?x256x256xf32>
1382
1383// -----
1384
// Negative test: the expand splits the unpacked dim 3072 into 256x12, and the
// innermost factor 12 is not divisible by the inner tile 8, so the expand
// cannot be hoisted onto the packed operand; the unpack stays where it is.
func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x8xf32>) -> tensor<256x12x256xf32> {
  %6 = tensor.empty() : tensor<3072x256xf32>
  %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<384x32x8x8xf32> -> tensor<3072x256xf32>
  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
  func.return %expanded : tensor<256x12x256xf32>
}
// CHECK-LABEL: func.func @no_push_down_unpack_through_non_divisible_expand
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
// CHECK:         return %[[EXPANDED]] : tensor<256x12x256xf32>
1396