// RUN: mlir-opt %s -transform-interpreter -cse -verify-diagnostics -split-input-file | FileCheck %s

// CHECK-LABEL: func.func @pack(
func.func @pack(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
  %cst_0 = arith.constant 0.0 : f32

  // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
  //      CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
  //      CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
  //      CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3], [4], [5]]
  // CHECK-SAME:   : tensor<136x64x16x16xf32> into tensor<17x8x2x32x16x16xf32>
  //      CHECK: linalg.transpose
  // CHECK-SAME:   ins(%{{.*}} : tensor<17x8x2x32x16x16xf32>)
  // CHECK-SAME:   outs(%{{.*}} : tensor<17x2x16x16x32x8xf32>)
  // CHECK-SAME:   permutation = [0, 2, 4, 5, 3, 1]
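  // Tile-size arithmetic, derived from the op below: dim 0 (129) is tiled by 8,
  // so ceil(129 / 8) = 17 outer elements and padding up to 136; dim 1 (47) is
  // tiled by 32, so ceil(47 / 32) = 2 and padding up to 64. The expand_shape
  // splits each padded dim into an (outer, inner) pair, and the transpose moves
  // the inner tile dims to the back in inner_dims_pos order ([1, 0] -> tiles 32, 8).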
  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
    : tensor<129x47x16x16xf32> -> tensor<17x2x16x16x32x8xf32>
  return %pack : tensor<17x2x16x16x32x8xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.pack">
    transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
    transform.yield
  }
}

// -----

// CHECK-LABEL: func.func @pack(
func.func @pack(%arg0: tensor<128x8xf32>, %arg1: tensor<8x8x16x1xf32>) -> tensor<8x8x16x1xf32> {

  // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
  //      CHECK: tensor.pad {{.*}} low[0, 0]
  //      CHECK:   : tensor<128x8xf32> to tensor<128x8xf32>
  //      CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3]]
  // CHECK-SAME:   : tensor<128x8xf32> into tensor<8x16x8x1xf32>
  //      CHECK: linalg.transpose
  // CHECK-SAME:   ins(%{{.*}} : tensor<8x16x8x1xf32>)
  // CHECK-SAME:   outs(%{{.*}} : tensor<8x8x16x1xf32>)
  // CHECK-SAME:   permutation = [0, 2, 1, 3]
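  // Note: 128 and 8 are already multiples of the inner tiles 16 and 1, so the
  // generated tensor.pad adds no padding (its result type equals the source type).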

  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %arg1
    : tensor<128x8xf32> -> tensor<8x8x16x1xf32>

  return %pack : tensor<8x8x16x1xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.pack">
    transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
    transform.yield
  }
}

// -----

// CHECK-LABEL: func.func @pack_as_pad(
// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
  %cst_0 = arith.constant 0.0 : f32

  // tensor.pack is lowered to tensor.pad + tensor.insert_slice
  //      CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
  //      CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
  //      CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
  // offsets.
  // CHECK-SAME:   [0, 0, 0, 0, 0, 0, 0, 0]
  // sizes.
  // CHECK-SAME:   [1, 1, 1, 1, 136, 64, 16, 16]
  // strides multipliers.
  // CHECK-SAME:   [1, 1, 1, 1, 1, 1, 1, 1]
  // CHECK-SAME:   : tensor<136x64x16x16xf32> into tensor<1x1x1x1x136x64x16x16xf32>
  //      CHECK: return %[[RES]]
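  // This pack is "pad-like": every source dim is tiled by a tile that covers its
  // whole padded extent (136, 64, 16, 16), so all outer dims are 1 and no data
  // movement beyond the pad is needed. The lowering therefore degenerates to
  // tensor.pad + tensor.insert_slice instead of expand_shape + transpose.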
  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
    : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
  return %pack : tensor<1x1x1x1x136x64x16x16xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.pack">
    transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
    transform.yield
  }
}

// -----

// This is the same as pack_as_pad, but since we explicitly set
// {lowerPadLikeWithInsertSlice = false}, it must not be lowered to insert_slice.
// CHECK-LABEL: func.func @pack_as_pad_disabled_insert_slice(
func.func @pack_as_pad_disabled_insert_slice(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
  %cst_0 = arith.constant 0.0 : f32
  // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<129x47x16x16xf32>
  //  CHECK-DAG: %[[PAD:.*]] = tensor.pad %[[ARG0]]
  //  CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]]
  //      CHECK: %[[PAD_EXPANDED:.*]] = tensor.expand_shape %[[PAD]]
  //  CHECK-DAG: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]]
  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
    : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
  return %pack : tensor<1x1x1x1x136x64x16x16xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.pack">
    transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false} : (!transform.op<"tensor.pack">)
      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
    transform.yield
  }
}

// -----

// Check that we don't lower the following pack as a pad: although all the
// outermost dimensions of the result are 1, some of the source dimensions are
// not listed in inner_dims_pos, so a transpose is still required.
// CHECK-LABEL: func.func @pack_not_a_pad(
func.func @pack_not_a_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x16x16x136x64xf32>) -> tensor<1x1x16x16x136x64xf32> {
  %cst_0 = arith.constant 0.0 : f32

  //      CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
  //      CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
  //      CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3], [4], [5]]
  // CHECK-SAME:   : tensor<136x64x16x16xf32> into tensor<1x136x1x64x16x16xf32>
  //      CHECK: linalg.transpose
  // CHECK-SAME:   ins(%{{.*}} : tensor<1x136x1x64x16x16xf32>)
  // CHECK-SAME:   outs(%{{.*}} : tensor<1x1x16x16x136x64xf32>)
  // CHECK-SAME:   permutation = [0, 2, 4, 5, 1, 3]
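  // Here only dims 0 and 1 are tiled (by 136 and 64, covering the padded
  // extents), but the untiled 16x16 dims have to move in front of the inner
  // tiles to match the destination layout, hence the non-trivial permutation.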

  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1] inner_tiles = [136, 64] into %arg1
    : tensor<129x47x16x16xf32> -> tensor<1x1x16x16x136x64xf32>
  return %pack : tensor<1x1x16x16x136x64xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.pack">
    transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
    transform.yield
  }
}

// -----
// CHECK-LABEL: func.func @unpack(
func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
  %cst_0 = arith.constant 0.0 : f32
  // CHECK-SAME: %[[ARG0:.*]]: tensor<17x2x16x16x32x8xf32>, %[[ARG1:.*]]: tensor<129x47x16x16xf32>
  //      CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<17x8x2x32x16x16xf32>
  //      CHECK: %[[TRAN:.*]] = linalg.transpose
  // CHECK-SAME:    ins(%[[ARG0]] : tensor<17x2x16x16x32x8xf32>)
  // CHECK-SAME:   outs(%[[EMPTY]] : tensor<17x8x2x32x16x16xf32>)
  // CHECK-SAME:   permutation = [0, 5, 1, 4, 2, 3]
  //      CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3], [4], [5]]
  // CHECK-SAME:   : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
  //      CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
  // CHECK-SAME:   : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
  //      CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
  // CHECK-SAME:        outs(%[[ARG1]] : tensor<129x47x16x16xf32>)
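  // tensor.unpack is lowered as the inverse recipe: transpose each inner tile
  // next to its outer dim (17x8 and 2x32), collapse those pairs back to
  // 136x64x16x16, extract_slice away the 7 and 17 elements of padding to get
  // 129x47x16x16, and copy the result into the destination operand.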
  %unpack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
    : tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
  return %unpack : tensor<129x47x16x16xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.unpack">
    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
      -> (!transform.op<"tensor.empty">,
          !transform.op<"linalg.transpose">,
          !transform.op<"tensor.collapse_shape">,
          !transform.op<"tensor.extract_slice">)
    transform.yield
  }
}

// -----

// CHECK-LABEL: func.func @unpack_with_identity_outer_dims_perm(
func.func @unpack_with_identity_outer_dims_perm(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
  %cst_0 = arith.constant 0.0 : f32
  // CHECK-SAME: %[[ARG0:.*]]: tensor<17x2x16x16x32x8xf32>, %[[ARG1:.*]]: tensor<129x47x16x16xf32>
  //      CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<17x8x2x32x16x16xf32>
  //      CHECK: %[[TRAN:.*]] = linalg.transpose
  // CHECK-SAME:    ins(%[[ARG0]] : tensor<17x2x16x16x32x8xf32>)
  // CHECK-SAME:   outs(%[[EMPTY]] : tensor<17x8x2x32x16x16xf32>)
  // CHECK-SAME:   permutation = [0, 5, 1, 4, 2, 3]
  //      CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3], [4], [5]]
  // CHECK-SAME:   : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
  //      CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
  // CHECK-SAME:   : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
  //      CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
  // CHECK-SAME:        outs(%[[ARG1]] : tensor<129x47x16x16xf32>)
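  // An identity outer_dims_perm = [0, 1, 2, 3] is a no-op, so this lowers to
  // exactly the same sequence as @unpack above, which specifies no permutation.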
  %unpack = tensor.unpack %arg0 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
    : tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
  return %unpack : tensor<129x47x16x16xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.unpack">
    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
      -> (!transform.op<"tensor.empty">,
          !transform.op<"linalg.transpose">,
          !transform.op<"tensor.collapse_shape">,
          !transform.op<"tensor.extract_slice">)
    transform.yield
  }
}

// -----

// When an unpack is a plain 'unpad', lower it to a simple extract_slice.
// CHECK-LABEL: func.func @unpack_as_pad(
func.func @unpack_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
  %cst_0 = arith.constant 0.0 : f32

  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
  //      CHECK: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
  // offsets.
  // CHECK-SAME:   [0, 0, 0, 0, 0, 0, 0, 0]
  // sizes.
  // CHECK-SAME:   [1, 1, 1, 1, 129, 47, 16, 16]
  // strides multipliers.
  // CHECK-SAME:   [1, 1, 1, 1, 1, 1, 1, 1]
  // CHECK-SAME:   : tensor<1x1x1x1x136x64x16x16xf32> to tensor<129x47x16x16xf32>
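  // Every source dim is tiled and all outer dims are 1, so unpacking reduces to
  // stripping the padding: a single extract_slice reads the leading
  // 129x47x16x16 elements out of the 136x64x16x16 tile.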
  %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
    : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
  return %unpack : tensor<129x47x16x16xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.unpack">
    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
      -> (!transform.op<"tensor.empty">,
          !transform.op<"linalg.transpose">,
          !transform.op<"tensor.collapse_shape">,
          !transform.op<"tensor.extract_slice">)
    transform.yield
  }
}

// -----

// This is the same as unpack_as_pad, but since we explicitly set
// {lowerUnpadLikeWithExtractSlice = false}, it must not be lowered to a direct
// extract_slice of the source.
// CHECK-LABEL: func.func @unpack_as_pad_disabled_extract_slice(
func.func @unpack_as_pad_disabled_extract_slice(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
  %cst_0 = arith.constant 0.0 : f32

  // tensor.unpack is lowered to linalg.transpose + tensor.collapse_shape + tensor.extract_slice
  // CHECK-DAG: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
  // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
  //     CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]]
  //     CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]]
  // CHECK-DAG: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]]
  %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
    : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
  return %unpack : tensor<129x47x16x16xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.unpack">
    transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false} : (!transform.op<"tensor.unpack">)
      -> (!transform.op<"tensor.empty">,
          !transform.op<"linalg.transpose">,
          !transform.op<"tensor.collapse_shape">,
          !transform.op<"tensor.extract_slice">)
    transform.yield
  }
}

// -----

// CHECK-LABEL: func.func @pack_with_outer_dims_perm(
func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
                                     %dest: tensor<200x4x16x100x16x32xi32>)
    -> tensor<200x4x16x100x16x32xi32> {
  //      CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
  //      CHECK:   : tensor<100x200x128x256xi32> to tensor<100x200x128x256xi32>
  //      CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]]
  // CHECK-SAME:   : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32>
  //      CHECK: linalg.transpose
  // CHECK-SAME:   ins(%{{.*}} : tensor<100x200x4x32x16x16xi32>)
  // CHECK-SAME:   outs(%{{.*}} : tensor<200x4x16x100x16x32xi32>)
  // CHECK-SAME:   permutation = [1, 2, 4, 0, 5, 3]
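  // In the expanded 100x200x4x32x16x16 tensor the outer dims sit at indices
  // [0, 1, 2, 4] and the inner tiles at [3, 5]. outer_dims_perm = [1, 2, 3, 0]
  // reorders the outer indices to [1, 2, 4, 0], and the inner tiles follow in
  // inner_dims_pos order ([3, 2] -> indices [5, 3]), giving [1, 2, 4, 0, 5, 3].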
  %0 = tensor.pack %src
    outer_dims_perm = [1, 2, 3, 0]
    inner_dims_pos = [3, 2]
    inner_tiles = [16, 32]
    into %dest : tensor<100x200x128x256xi32> -> tensor<200x4x16x100x16x32xi32>
  return %0 : tensor<200x4x16x100x16x32xi32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.pack">
    transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
    transform.yield
  }
}

// -----

// CHECK-LABEL: func.func @pack_with_pad(
func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x16x16x1xf32>)
    -> tensor<265x16x16x1xf32> {
  //      CHECK: tensor.pad {{.*}} low[0, 0]
  //      CHECK:   : tensor<4225x12xf32> to tensor<4240x16xf32>
  //      CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]]
  // CHECK-SAME:   : tensor<4240x16xf32> into tensor<265x16x16x1xf32>
  //      CHECK: linalg.transpose
  // CHECK-SAME:   ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
  // CHECK-SAME:   outs(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
  // CHECK-SAME:   permutation = [0, 2, 1, 3]
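  // The padded extent is the destination's outer size times its tile:
  // 265 * 16 = 4240 for dim 0 (so 4225 pads up by 15) and 16 * 1 = 16 for
  // dim 1 (so 12 pads up by 4).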
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.pack %src
    padding_value(%cst : f32)
    inner_dims_pos = [0, 1]
    inner_tiles = [16, 1] into %dest
    : tensor<4225x12xf32> -> tensor<265x16x16x1xf32>
  return %0 : tensor<265x16x16x1xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.pack">
    transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
    transform.yield
  }
}

// -----

// CHECK-LABEL: func.func @pack_with_pad_and_outer_dims_perm(
func.func @pack_with_pad_and_outer_dims_perm(%src: tensor<100x200x127x255xi32>,
                                             %dest: tensor<200x4x16x100x16x32xi32>)
    -> tensor<200x4x16x100x16x32xi32> {
  //      CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
  //      CHECK:   : tensor<100x200x127x255xi32> to tensor<100x200x128x256xi32>
  //      CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]]
  // CHECK-SAME:   : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32>
  //      CHECK: linalg.transpose
  // CHECK-SAME:   ins(%{{.*}} : tensor<100x200x4x32x16x16xi32>)
  // CHECK-SAME:   outs(%{{.*}} : tensor<200x4x16x100x16x32xi32>)
  // CHECK-SAME:   permutation = [1, 2, 4, 0, 5, 3]
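  // Same permutation as in pack_with_outer_dims_perm; the only difference is
  // the padding: dim 2 (127) pads to 128 for its tile of 32 and dim 3 (255)
  // pads to 256 for its tile of 16.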
  %cst_0 = arith.constant 0 : i32
  %0 = tensor.pack %src
    padding_value(%cst_0 : i32)
    outer_dims_perm = [1, 2, 3, 0]
    inner_dims_pos = [3, 2]
    inner_tiles = [16, 32]
    into %dest : tensor<100x200x127x255xi32> -> tensor<200x4x16x100x16x32xi32>
  return %0 : tensor<200x4x16x100x16x32xi32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.pack">
    transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
    transform.yield
  }
}

// -----

// CHECK-DAG:   #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * 16 - s1)>
// CHECK-DAG:   #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 32 - s1)>
// CHECK:       func.func @dynamic_pack_pad_transpose_inner_and_outer_dims(
// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
func.func @dynamic_pack_pad_transpose_inner_and_outer_dims(%source: tensor<?x?xf32>) -> tensor<?x?x16x32xf32> {
  // CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
  // CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
  // CHECK-DAG:   %[[C16:.+]] = arith.constant 16 : index
  // CHECK-DAG:   %[[C32:.+]] = arith.constant 32 : index
  // CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[SRC]], %[[C0]]
  // CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[SRC]], %[[C1]]
  // CHECK-DAG:   %[[OUT_D0:.+]] = arith.ceildivui %[[D1]], %[[C16]] : index
  // CHECK-DAG:   %[[OUT_D1:.+]] = arith.ceildivui %[[D0]], %[[C32]] : index
  // CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty(%[[OUT_D0]], %[[OUT_D1]]) : tensor<?x?x16x32xf32>
  // CHECK-DAG:   %[[DEST_D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
  // CHECK-DAG:   %[[DEST_D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
  // CHECK-DAG:   %[[H1:.+]] = affine.apply #[[MAP0]]()[%[[DEST_D0]], %[[D1]]]
  // CHECK-DAG:   %[[H0:.+]] = affine.apply #[[MAP1]]()[%[[DEST_D1]], %[[D0]]]
  // CHECK:       %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high[%[[H0]], %[[H1]]]
  // CHECK:         : tensor<?x?xf32> to tensor<?x?xf32>
  // CHECK:       %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0, 1], [2, 3]]
  // CHECK-SAME:   : tensor<?x?xf32> into tensor<?x32x?x16xf32>
  // CHECK:       %[[TRANSP:.+]] = linalg.transpose
  // CHECK-SAME:    ins(%[[EXPAND]] : tensor<?x32x?x16xf32>)
  // CHECK-SAME:    outs(%[[EMPTY]] : tensor<?x?x16x32xf32>)
  // CHECK-SAME:    permutation = [2, 0, 3, 1]
  // CHECK:       return %[[TRANSP]]
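  // Dynamic padding: the high pad amount is outer * tile - dim, which is what
  // the affine maps above encode (s0 * 16 - s1 and s0 * 32 - s1); the outer
  // sizes themselves come from arith.ceildivui of the source dims by the tiles.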
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %source, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %source, %c1 : tensor<?x?xf32>
  %padding_value = arith.constant 0.0 : f32

  %c16 = arith.constant 16 : index
  %c32 = arith.constant 32 : index
  %tiled_d0 = arith.ceildivui %d0, %c32 : index
  %tiled_d1 = arith.ceildivui %d1, %c16 : index
  %init_pack = tensor.empty(%tiled_d1, %tiled_d0) : tensor<?x?x16x32xf32>
  %pack = tensor.pack %source padding_value(%padding_value : f32)
      outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %init_pack
      : tensor<?x?xf32> -> tensor<?x?x16x32xf32>
  return %pack : tensor<?x?x16x32xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.pack">
    transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
    transform.yield
  }
}

// -----

// CHECK-LABEL: func.func @pack_as_pad_with_outer_dims_perm(
// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
func.func @pack_as_pad_with_outer_dims_perm(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
  %cst_0 = arith.constant 0.0 : f32

  // tensor.pack is lowered to tensor.pad + tensor.insert_slice
  //      CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
  //      CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
  //      CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
  // offsets.
  // CHECK-SAME:   [0, 0, 0, 0, 0, 0, 0, 0]
  // sizes.
  // CHECK-SAME:   [1, 1, 1, 1, 136, 64, 16, 16]
  // strides multipliers.
  // CHECK-SAME:   [1, 1, 1, 1, 1, 1, 1, 1]
  // CHECK-SAME:   : tensor<136x64x16x16xf32> into tensor<1x1x1x1x136x64x16x16xf32>
  //      CHECK: return %[[RES]]
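  // Even with outer_dims_perm = [1, 2, 3, 0], this stays pad-like: permuting
  // the all-ones outer dims is a no-op, so the insert_slice lowering still applies.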
  %pack = tensor.pack %arg0
    padding_value(%cst_0 : f32)
    outer_dims_perm = [1, 2, 3, 0]
    inner_dims_pos = [0, 1, 2, 3]
    inner_tiles = [136, 64, 16, 16]
    into %arg1 : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
  return %pack : tensor<1x1x1x1x136x64x16x16xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.pack">
    transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
    transform.yield
  }
}

// -----

// CHECK-LABEL: func.func @pack_as_pad_with_unit_dims(
// CHECK: %[[SRC:.+]]: tensor<3x1x1x1xf32>,
// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x8x1xf32>)
func.func @pack_as_pad_with_unit_dims(%arg0: tensor<3x1x1x1xf32>, %arg1: tensor<1x1x1x1x8x1xf32>) -> (tensor<1x1x1x1x8x1xf32>) {
  %zero = arith.constant 0.0 : f32

  // CHECK:      %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[5, 0, 0, 0] {
  // CHECK:        : tensor<3x1x1x1xf32> to tensor<8x1x1x1xf32>
  // CHECK:      %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] [{{.*}}[0, 1], [2, 3], [4], [5]]
  // CHECK-SAME:   tensor<8x1x1x1xf32> into tensor<1x8x1x1x1x1xf32>
  // CHECK:      %[[TRANSPOSED:.+]] = linalg.transpose
  // CHECK-SAME:   ins(%[[EXPAND]] : tensor<1x8x1x1x1x1xf32>)
  // CHECK-SAME:   outs(%[[OUT]] : tensor<1x1x1x1x8x1xf32>)
  // CHECK-SAME:   permutation = [0, 2, 4, 5, 1, 3]
  // CHECK:      return %[[TRANSPOSED]] : tensor<1x1x1x1x8x1xf32>
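  // All result outer dims are 1, but only dims 0 and 1 are tiled; with untiled
  // source dims present this is not treated as pad-like, so the general
  // expand_shape + transpose lowering is used (all moved dims happen to be unit).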
  %pack = tensor.pack %arg0
      padding_value(%zero : f32)
      inner_dims_pos = [0, 1]
      inner_tiles = [8, 1] into %arg1 : tensor<3x1x1x1xf32> -> tensor<1x1x1x1x8x1xf32>

  return %pack : tensor<1x1x1x1x8x1xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.pack">
    transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
    transform.yield
  }
}

// -----

// Check that we can lower unpack with dynamic dimensions in the destination.
// CHECK-LABEL: func.func @unpack_with_dynamic_dest(
// CHECK-SAME: %[[ARG0:.*]]: tensor<32x2x49x16x16xf32>, %[[ARG1:.*]]: tensor<32x?x?xf32>)
//      CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<32x2x16x49x16xf32>
//      CHECK: %[[TRAN:.*]] = linalg.transpose
// CHECK-SAME:    ins(%[[ARG0]] : tensor<32x2x49x16x16xf32>)
// CHECK-SAME:   outs(%[[EMPTY]] : tensor<32x2x16x49x16xf32>)
// CHECK-SAME:   permutation = [0, 1, 3, 2, 4]
//      CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0], [1, 2], [3, 4]]
// CHECK-SAME:   : tensor<32x2x16x49x16xf32> into tensor<32x32x784xf32>
//      CHECK: %[[C1:.*]] = arith.constant 1 : index
//      CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<32x?x?xf32>
//      CHECK: %[[C2:.*]] = arith.constant 2 : index
//      CHECK: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<32x?x?xf32>
//      CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0] [32, %[[DIM1]], %[[DIM2]]] [1, 1, 1]
// CHECK-SAME:   : tensor<32x32x784xf32> to tensor<32x?x?xf32>
//      CHECK: linalg.copy ins(%[[SLICE]] : tensor<32x?x?xf32>)
// CHECK-SAME:        outs(%[[ARG1]] : tensor<32x?x?xf32>)
func.func @unpack_with_dynamic_dest(%arg0: tensor<32x2x49x16x16xf32>, %arg1: tensor<32x?x?xf32>) -> tensor<32x?x?xf32> {
  %unpack = tensor.unpack %arg0 inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %arg1
    : tensor<32x2x49x16x16xf32> -> tensor<32x?x?xf32>
  return %unpack : tensor<32x?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.unpack">
    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
      -> (!transform.op<"tensor.empty">,
          !transform.op<"linalg.transpose">,
          !transform.op<"tensor.collapse_shape">,
          !transform.op<"tensor.extract_slice">)
    transform.yield
  }
}

// -----

// Check that we can lower unpack with dynamic dimensions in the input and destination.
// CHECK-LABEL: func.func @unpack_with_dynamic_input_dest(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x8x16xf32>, %[[ARG1:.*]]: tensor<?x?xf32>)
//  CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
//  CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
//  CHECK-DAG: %[[DIM00:.*]] = tensor.dim %[[ARG0]], %[[C0]]
//  CHECK-DAG: %[[DIM01:.*]] = tensor.dim %[[ARG0]], %[[C1]]
//      CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM00]], %[[DIM01]]) : tensor<?x8x?x16xf32>
//      CHECK: %[[TRAN:.*]] = linalg.transpose
// CHECK-SAME:    ins(%[[ARG0]] : tensor<?x?x8x16xf32>)
// CHECK-SAME:   outs(%[[EMPTY]] : tensor<?x8x?x16xf32>)
// CHECK-SAME:   permutation = [0, 2, 1, 3]
//      CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
// CHECK-SAME:   : tensor<?x8x?x16xf32> into tensor<?x?xf32>
//      CHECK: %[[DIM10:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
//      CHECK: %[[DIM11:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
//      CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [%[[DIM10]], %[[DIM11]]] [1, 1]
// CHECK-SAME:   : tensor<?x?xf32> to tensor<?x?xf32>
//      CHECK: linalg.copy ins(%[[SLICE]] : tensor<?x?xf32>)
// CHECK-SAME:        outs(%[[ARG1]] : tensor<?x?xf32>)
func.func @unpack_with_dynamic_input_dest(%arg0: tensor<?x?x8x16xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 16] into %arg1 : tensor<?x?x8x16xf32> -> tensor<?x?xf32>
  return %unpack : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.unpack">
    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
      -> (!transform.op<"tensor.empty">,
          !transform.op<"linalg.transpose">,
          !transform.op<"tensor.collapse_shape">,
          !transform.op<"tensor.extract_slice">)
    transform.yield
  }
}

// -----

// Check that we can lower unpack with dynamic dimensions in the input, the
// destination, and the inner_tiles.
// CHECK-LABEL: func.func @unpack_fully_dynamic(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
//  CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
//  CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
//  CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
//  CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
//  CHECK-DAG: %[[DIM00:.*]] = tensor.dim %[[ARG0]], %[[C0]]
//  CHECK-DAG: %[[DIM01:.*]] = tensor.dim %[[ARG0]], %[[C1]]
//  CHECK-DAG: %[[DIM02:.*]] = tensor.dim %[[ARG0]], %[[C2]]
//  CHECK-DAG: %[[DIM03:.*]] = tensor.dim %[[ARG0]], %[[C3]]
//      CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM00]], %[[DIM02]], %[[DIM01]], %[[DIM03]]) : tensor<?x?x?x?xf32>
//      CHECK: %[[TRAN:.*]] = linalg.transpose
// CHECK-SAME:    ins(%[[ARG0]] : tensor<?x?x?x?xf32>)
// CHECK-SAME:   outs(%[[EMPTY]] : tensor<?x?x?x?xf32>)
// CHECK-SAME:   permutation = [0, 2, 1, 3]
//      CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
// CHECK-SAME:   : tensor<?x?x?x?xf32> into tensor<?x?xf32>
//      CHECK: %[[DIM10:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
//      CHECK: %[[DIM11:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
//      CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [%[[DIM10]], %[[DIM11]]] [1, 1]
// CHECK-SAME:   : tensor<?x?xf32> to tensor<?x?xf32>
//      CHECK: linalg.copy ins(%[[SLICE]] : tensor<?x?xf32>)
// CHECK-SAME:        outs(%[[ARG1]] : tensor<?x?xf32>)
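// With dynamic inner_tiles (the two index operands), every shape in the lowered
// sequence is dynamic: the tensor.empty sizes are the source dims in transposed
// order (d0, d2, d1, d3), and the final slice sizes are read back from the
// destination.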
func.func @unpack_fully_dynamic(%source: tensor<?x?x?x?xf32>, %dest: tensor<?x?xf32>, %tile_n : index, %tile_m : index) -> tensor<?x?xf32> {
  %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.unpack">
    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
      -> (!transform.op<"tensor.empty">,
          !transform.op<"linalg.transpose">,
          !transform.op<"tensor.collapse_shape">,
          !transform.op<"tensor.extract_slice">)
    transform.yield
  }
}

// -----

// Check that we can lower unpack "as unpad" with dynamic dims.
// CHECK-LABEL: func.func @unpack_as_pad_dynamic(
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1x1x1x136x64x16x16xf32>, %[[ARG1:.*]]: tensor<?x?x?x?xf32>
//  CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
//  CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
//  CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
//  CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
//  CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG1]], %[[C0]]
//  CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]]
//  CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
//  CHECK-DAG: %[[DIM3:.*]] = tensor.dim %[[ARG1]], %[[C3]]
//      CHECK: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
// offsets.
// CHECK-SAME:   [0, 0, 0, 0, 0, 0, 0, 0]
// sizes.
// CHECK-SAME:   [1, 1, 1, 1, %[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]]
// strides multipliers.
// CHECK-SAME:   [1, 1, 1, 1, 1, 1, 1, 1]
// CHECK-SAME:   : tensor<1x1x1x1x136x64x16x16xf32> to tensor<?x?x?x?xf32>
func.func @unpack_as_pad_dynamic(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
    : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<?x?x?x?xf32>
  return %unpack : tensor<?x?x?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.unpack">
    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
      -> (!transform.op<"tensor.empty">,
          !transform.op<"linalg.transpose">,
          !transform.op<"tensor.collapse_shape">,
          !transform.op<"tensor.extract_slice">)
    transform.yield
  }
}

// -----

// CHECK-LABEL: @unpack_with_outer_dims_perm
//  CHECK-SAME: %[[ARG0:.*]]: tensor<32x64xf32>, %[[ARG1:.*]]: tensor<2x4x32x8xf32>
//       CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<4x8x2x32xf32>
//       CHECK: %[[TRAN:.*]] = linalg.transpose
//  CHECK-SAME:   ins(%[[ARG1]] : tensor<2x4x32x8xf32>)
//  CHECK-SAME:   outs(%[[EMPTY]] : tensor<4x8x2x32xf32>)
//  CHECK-SAME:   permutation = [1, 3, 0, 2]
//       CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
//  CHECK-SAME:   : tensor<4x8x2x32xf32> into tensor<32x64xf32>
//       CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [32, 64] [1, 1]
//  CHECK-SAME:   : tensor<32x64xf32> to tensor<32x64xf32>
//       CHECK: linalg.copy ins(%[[SLICE]]
//  CHECK-SAME:   : tensor<32x64xf32>) outs(%[[ARG0]] : tensor<32x64xf32>) -> tensor<32x64xf32>
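// The transpose reassembles each (outer, inner) pair in destination order:
// packed dims (1, 3) form dest dim 0 (4 * 8 = 32) and packed dims (0, 2) form
// dest dim 1 (2 * 32 = 64), undoing outer_dims_perm = [1, 0] along the way.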
func.func @unpack_with_outer_dims_perm(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
  %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0]
    inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg0 : tensor<2x4x32x8xf32> -> tensor<32x64xf32>
  return %unpack : tensor<32x64xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.unpack">
    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
      -> (!transform.op<"tensor.empty">,
          !transform.op<"linalg.transpose">,
          !transform.op<"tensor.collapse_shape">,
          !transform.op<"tensor.extract_slice">)
    transform.yield
  }
}