// Source: llvm-project/mlir/test/Dialect/Tensor/canonicalize.mlir (revision 9f6a1ddb43133328c90edfa29ccd4c714b289cb6)
// RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s

// An expand_shape with an identity reassociation folds away completely.
// CHECK-LABEL: expand_shape_identity_fold
// CHECK-NEXT: return
func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
  %0 = tensor.expand_shape %arg0 [[0]] output_shape [5] : tensor<5xf32> into tensor<5xf32>
  return %0 : tensor<5xf32>
}

// -----

// CHECK-LABEL: expand_shape_rank0_identity_fold
// CHECK-NEXT: return
func.func @expand_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
  %0 = tensor.expand_shape %arg0 [] output_shape [] : tensor<f32> into tensor<f32>
  return %0 : tensor<f32>
}

// -----

// A collapse_shape with an identity reassociation folds away completely.
// CHECK-LABEL: collapse_shape_identity_fold
// CHECK-NEXT: return
func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf32> {
  %0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<5x4xf32> into tensor<5x4xf32>
  return %0 : tensor<5x4xf32>
}

// -----

// CHECK-LABEL: collapse_shape_rank0_identity_fold
// CHECK-NEXT: return
func.func @collapse_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
  %0 = tensor.collapse_shape %arg0 [] : tensor<f32> into tensor<f32>
  return %0 : tensor<f32>
}

// -----
39
// A chain of bitcasts collapses into a single bitcast.
// CHECK-LABEL: @tensor_bitcast_chain_ok
// CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {
  // CHECK-NEXT: %[[RES:.*]] = tensor.bitcast %[[IN]] : tensor<2xi32> to tensor<2xf32>
  %0 = tensor.bitcast %input : tensor<2xi32> to tensor<2xui32>
  %1 = tensor.bitcast %0 : tensor<2xui32> to tensor<2xf32>
  // CHECK-NEXT: return %[[RES]]
  return %1 : tensor<2xf32>
}

// -----

// A bitcast chain that returns to the original type folds to a no-op.
// CHECK-LABEL: @tensor_bitcast_chain_nop
// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
func.func @tensor_bitcast_chain_nop(%input: tensor<4xi32>) -> tensor<4xi32> {
  %0 = tensor.bitcast %input : tensor<4xi32> to tensor<4xui32>
  %1 = tensor.bitcast %0 : tensor<4xui32> to tensor<4xi32>
  // CHECK-NEXT: return %[[IN]]
  return %1 : tensor<4xi32>
}

// -----
62
// Checks that NOP casts are removed.
// CHECK-LABEL: cast_values
func.func @cast_values(%arg0: tensor<*xi32>) -> tensor<2xi32> {
  // NOP cast
  %0 = tensor.cast %arg0 : tensor<*xi32> to tensor<*xi32>
  // CHECK-NEXT: %[[RET:.*]] = tensor.cast %arg0 : tensor<*xi32> to tensor<2xi32>
  %2 = tensor.cast %0 : tensor<*xi32> to tensor<2xi32>
  // NOP cast
  %4 = tensor.cast %2 : tensor<2xi32> to tensor<2xi32>
  // CHECK-NEXT: return %[[RET]] : tensor<2xi32>
  return %4 : tensor<2xi32>
}

// -----

// A chain of casts that only refines the shape collapses into one cast.
// CHECK-LABEL: @tensor.cast_chain_ok
// CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
func.func @tensor.cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
  // CHECK-NEXT: %[[RES:.*]] = tensor.cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
  %0 = tensor.cast %input : tensor<*xi32> to tensor<4x?xi32>
  %1 = tensor.cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
  // CHECK-NEXT: return %[[RES]]
  return %1 : tensor<4x8xi32>
}

// -----

// Casting to dynamic and back to the original static type folds to a no-op.
// CHECK-LABEL: @tensor.cast_chain_regain
// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
func.func @tensor.cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
  %0 = tensor.cast %input : tensor<4xi32> to tensor<?xi32>
  %1 = tensor.cast %0 : tensor<?xi32> to tensor<4xi32>
  // CHECK-NEXT: return %[[IN]]
  return %1 : tensor<4xi32>
}

// -----

// Neither cast subsumes the other (each refines a different dim), so both stay.
// CHECK-LABEL: @tensor.cast_chain_keep
// CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
func.func @tensor.cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
  // CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]]
  %0 = tensor.cast %input : tensor<?x?xi32> to tensor<4x?xi32>
  // CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]]
  %1 = tensor.cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
  // CHECK-NEXT: return %[[C2]]
  return %1 : tensor<?x8xi32>
}

// -----

// Folding 4x8 -> ? -> 8x4 into one cast would be invalid, so both casts stay.
// CHECK-LABEL: @tensor.cast_chain_invalid
// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
func.func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
  // CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]]
  %0 = tensor.cast %input : tensor<4x8xi32> to tensor<?x?xi32>
  // CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]]
  %1 = tensor.cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
  // CHECK-NEXT: return %[[C2]]
  return %1 : tensor<8x4xi32>
}

// -----
126
// A single-operand concat folds to its operand (with a cast when the result
// type is more static than the operand type).
// CHECK-LABEL: fold_concat
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x2x?xi32>
func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1x2x?xi32>) {
  %0 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x3xi32>
  // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<1x2x?xi32> to tensor<1x2x3xi32>
  %1 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x?xi32>
  // CHECK-NEXT: return %[[CAST]], %[[ARG0]] : tensor<1x2x3xi32>, tensor<1x2x?xi32>
  return %0, %1 : tensor<1x2x3xi32>, tensor<1x2x?xi32>
}

// -----
138
// CHECK-LABEL: func @fold_extract
func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
  %const_0 = arith.constant 0 : index
  %const_1 = arith.constant 1 : index
  %const_3 = arith.constant 3 : index
  // CHECK-DAG: [[C64:%.+]] = arith.constant 64 : i32
  // CHECK-DAG: [[C0:%.+]] = arith.constant 0.{{0*}}e+00 : f16
  // CHECK-DAG: [[CM2:%.+]] = arith.constant -2.{{0*}}e+00 : f16

  // Fold an extract into a splat.
  // CHECK-DAG: [[C4:%.+]] = arith.constant 4.{{0*}}e+00 : f32
  %0 = arith.constant dense<4.0> : tensor<4xf32>
  %ext_1 = tensor.extract %0[%arg0] : tensor<4xf32>

  // Fold an extract into a sparse with a sparse index.
  %1 = arith.constant sparse<[[0, 0, 0], [1, 1, 1]],  [-5.0, -2.0]> : tensor<4x4x4xf16>
  %ext_2 = tensor.extract %1[%const_1, %const_1, %const_1] : tensor<4x4x4xf16>

  // Fold an extract into a sparse with a non sparse index.
  %2 = arith.constant sparse<[[1, 1, 1]],  [-2.0]> : tensor<2x2x2xf16>
  %ext_3 = tensor.extract %2[%const_0, %const_0, %const_0] : tensor<2x2x2xf16>

  // Fold an extract into a dense tensor.
  %3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
  %ext_4 = tensor.extract %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>

  // Fold an extract into a complex constant.
  // CHECK-DAG: [[C5:%.+]] = complex.constant [1.200000e+00 : f32, 2.300000e+00 : f32] : complex<f32>
  %4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
  %ext_5 = tensor.extract %4[] : tensor<complex<f32>>

  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
}

// -----

// Ensure extracting from a dense resource constant does not crash (and does
// not fold, since the resource contents are opaque).

// CHECK-LABEL: func @extract_dense_resource_nofold
func.func @extract_dense_resource_nofold() -> i64 {
  // CHECK:      %[[EXT:.+]] = tensor.extract
  // CHECK-NEXT:   return %[[EXT]]
  %c0 = arith.constant 0 : index
  %cst = arith.constant dense_resource<__elided__> : tensor<1xi64>
  %extracted = tensor.extract %cst[%c0] : tensor<1xi64>
  return %extracted : i64
}

// -----
189
// CHECK-LABEL: func @fold_insert
func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
  // Fold an insert into a splat.
  // CHECK-DAG: %[[C4:.+]] = arith.constant dense<4.{{0*}}e+00> : tensor<4xf32>
  %0 = arith.constant dense<4.0> : tensor<4xf32>
  %1 = arith.constant 4.0 : f32
  %ins_1 = tensor.insert %1 into %0[%arg0] : tensor<4xf32>
  // CHECK-NEXT: return %[[C4]]
  return %ins_1 : tensor<4xf32>
}

// -----

// An extract looks through a tensor.cast of its source.
// CHECK-LABEL: func @extract_from_tensor.cast
// CHECK-SAME: %[[TENSOR:.*]]: tensor<9xf32>
func.func @extract_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 {
  // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
  %c0 = arith.constant 0 : index
  // CHECK-NOT: tensor.cast
  %casted = tensor.cast %tensor : tensor<9xf32> to tensor<?xf32>
  // CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]]
  %result = tensor.extract %casted[%c0] : tensor<?xf32>
  return %result : f32
}

// -----

// An extract from from_elements folds to the corresponding element.
// CHECK-LABEL: func @extract_from_tensor.from_elements
func.func @extract_from_tensor.from_elements(%element : index) -> index {
  // CHECK-SAME: ([[ARG:%.*]]: index)
  %c0 = arith.constant 0 : index
  %tensor = tensor.from_elements %element : tensor<1xindex>
  %extracted_element = tensor.extract %tensor[%c0] : tensor<1xindex>
  // CHECK: [[ARG]] : index
  return %extracted_element : index
}

// -----

// Same fold for the rank-0 case.
// CHECK-LABEL: func @extract_from_tensor.from_elements_0d
func.func @extract_from_tensor.from_elements_0d(%element : index) -> index {
  // CHECK-SAME: ([[ARG:%.*]]: index)
  %c0 = arith.constant 0 : index
  %tensor = tensor.from_elements %element : tensor<index>
  %extracted_element = tensor.extract %tensor[] : tensor<index>
  // CHECK: [[ARG]] : index
  return %extracted_element : index
}

// -----
240
// CHECK-LABEL: func @extract_from_tensor.from_elements_3d
func.func @extract_from_tensor.from_elements_3d()
    -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
  %f0 = arith.constant 0.0 : f32
  %f1 = arith.constant 1.0 : f32
  %f2 = arith.constant 2.0 : f32
  %f3 = arith.constant 3.0 : f32
  %f4 = arith.constant 4.0 : f32
  %f5 = arith.constant 5.0 : f32
  %f6 = arith.constant 6.0 : f32
  %f7 = arith.constant 7.0 : f32
  %f8 = arith.constant 8.0 : f32
  %f9 = arith.constant 9.0 : f32
  %f10 = arith.constant 10.0 : f32
  %f11 = arith.constant 11.0 : f32

  %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
         : tensor<3x2x2xf32>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index

  %r0 = tensor.extract %tensor[%c0, %c0, %c0] : tensor<3x2x2xf32>
  %r1 = tensor.extract %tensor[%c0, %c0, %c1] : tensor<3x2x2xf32>
  %r2 = tensor.extract %tensor[%c0, %c1, %c0] : tensor<3x2x2xf32>
  %r3 = tensor.extract %tensor[%c0, %c1, %c1] : tensor<3x2x2xf32>
  %r4 = tensor.extract %tensor[%c1, %c0, %c0] : tensor<3x2x2xf32>
  %r5 = tensor.extract %tensor[%c1, %c0, %c1] : tensor<3x2x2xf32>
  %r6 = tensor.extract %tensor[%c1, %c1, %c0] : tensor<3x2x2xf32>
  %r7 = tensor.extract %tensor[%c1, %c1, %c1] : tensor<3x2x2xf32>
  %r8 = tensor.extract %tensor[%c2, %c0, %c0] : tensor<3x2x2xf32>
  %r9 = tensor.extract %tensor[%c2, %c0, %c1] : tensor<3x2x2xf32>
  %r10 = tensor.extract %tensor[%c2, %c1, %c0] : tensor<3x2x2xf32>
  %r11 = tensor.extract %tensor[%c2, %c1, %c1] : tensor<3x2x2xf32>
  return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
         : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
}
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0
// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0
// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0
// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0
// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0
// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0
// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01

// CHECK: return %[[F0]], %[[F1]], %[[F2]], %[[F3]], %[[F4]], %[[F5]],
// CHECK-SAME:   %[[F6]], %[[F7]], %[[F8]], %[[F9]], %[[F10]], %[[F11]]

// -----
295
// CHECK-LABEL: func @extract_from_tensor.from_elements_variable_3d
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_2:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_3:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_4:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_5:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_6:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_7:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_8:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_9:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_10:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_11:[a-zA-Z0-9_]+]]: f32
func.func @extract_from_tensor.from_elements_variable_3d(
    %f0: f32, %f1: f32, %f2: f32, %f3: f32, %f4: f32, %f5: f32,
    %f6: f32, %f7: f32, %f8: f32, %f9: f32, %f10: f32, %f11: f32)
    -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {

  %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
         : tensor<3x2x2xf32>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index

  %r0 = tensor.extract %tensor[%c0, %c0, %c0] : tensor<3x2x2xf32>
  %r1 = tensor.extract %tensor[%c0, %c0, %c1] : tensor<3x2x2xf32>
  %r2 = tensor.extract %tensor[%c0, %c1, %c0] : tensor<3x2x2xf32>
  %r3 = tensor.extract %tensor[%c0, %c1, %c1] : tensor<3x2x2xf32>
  %r4 = tensor.extract %tensor[%c1, %c0, %c0] : tensor<3x2x2xf32>
  %r5 = tensor.extract %tensor[%c1, %c0, %c1] : tensor<3x2x2xf32>
  %r6 = tensor.extract %tensor[%c1, %c1, %c0] : tensor<3x2x2xf32>
  %r7 = tensor.extract %tensor[%c1, %c1, %c1] : tensor<3x2x2xf32>
  %r8 = tensor.extract %tensor[%c2, %c0, %c0] : tensor<3x2x2xf32>
  %r9 = tensor.extract %tensor[%c2, %c0, %c1] : tensor<3x2x2xf32>
  %r10 = tensor.extract %tensor[%c2, %c1, %c0] : tensor<3x2x2xf32>
  %r11 = tensor.extract %tensor[%c2, %c1, %c1] : tensor<3x2x2xf32>
  return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
         : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
}
// CHECK: return %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]],
// CHECK-SAME: %[[ARG_6]], %[[ARG_7]], %[[ARG_8]], %[[ARG_9]], %[[ARG_10]], %[[ARG_11]]

// -----
339
// CHECK-LABEL: func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
// CHECK-NEXT:  %cst = arith.constant dense<[(1,2), (3,2), (1,2)]> : tensor<3xcomplex<i32>>
// CHECK-NEXT:  return %cst : tensor<3xcomplex<i32>>
func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
  %c1 = arith.constant dense<(1, 2)> : tensor<complex<i32>>
  %complex1 = tensor.extract %c1[] : tensor<complex<i32>>
  %c2 = arith.constant dense<(3, 2)> : tensor<complex<i32>>
  %complex2 = tensor.extract %c2[] : tensor<complex<i32>>
  %tensor = tensor.from_elements %complex1, %complex2, %complex1 : tensor<3xcomplex<i32>>
  return %tensor : tensor<3xcomplex<i32>>
}

// -----

// CHECK-LABEL:  func.func @extract_from_elements_complex_f() -> tensor<3xcomplex<f32>> {
// CHECK-NEXT:   %cst = arith.constant dense<[(1.200000e+00,2.300000e+00), (3.200000e+00,2.100000e+00), (1.200000e+00,2.300000e+00)]> : tensor<3xcomplex<f32>>
// CHECK-NEXT:   return %cst : tensor<3xcomplex<f32>>
func.func @extract_from_elements_complex_f() -> tensor<3xcomplex<f32>> {
  %c1 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
  %complex1 = tensor.extract %c1[] : tensor<complex<f32>>
  %c2 = arith.constant dense<(3.2, 2.1)> : tensor<complex<f32>>
  %complex2 = tensor.extract %c2[] : tensor<complex<f32>>
  %tensor = tensor.from_elements %complex1, %complex2, %complex1 : tensor<3xcomplex<f32>>
  return %tensor : tensor<3xcomplex<f32>>
}

// -----
367
// Ensure the optimization doesn't segfault from bad constants
// CHECK-LABEL: func @extract_negative_from_tensor.from_elements
func.func @extract_negative_from_tensor.from_elements(%element : index) -> index {
  // CHECK-SAME: ([[ARG:%.*]]: index)
  %c-1 = arith.constant -1 : index
  %tensor = tensor.from_elements %element : tensor<1xindex>
  %extracted_element = tensor.extract %tensor[%c-1] : tensor<1xindex>
  // CHECK: tensor.from_elements
  // CHECK: %[[RESULT:.*]] = tensor.extract
  // CHECK: return %[[RESULT]]
  return %extracted_element : index
}

// -----

// Ensure the optimization doesn't segfault from bad constants
// CHECK-LABEL: func @extract_oob_from_tensor.from_elements
func.func @extract_oob_from_tensor.from_elements(%element : index) -> index {
  // CHECK-SAME: ([[ARG:%.*]]: index)
  %c1 = arith.constant 1 : index
  %tensor = tensor.from_elements %element : tensor<1xindex>
  %extracted_element = tensor.extract %tensor[%c1] : tensor<1xindex>
  // CHECK: tensor.from_elements
  // CHECK: %[[RESULT:.*]] = tensor.extract
  // CHECK: return %[[RESULT]]
  return %extracted_element : index
}

// -----

// Ensure the optimization doesn't segfault from bad constants
// CHECK-LABEL: func @extract_oob_from_tensor.from_elements
func.func @extract_oob_from_tensor.from_elements(%element : index) -> index {
  // CHECK-SAME: ([[ARG:%.*]]: index)
  %c2 = arith.constant 2 : index
  %tensor = tensor.from_elements %element : tensor<1xindex>
  %extracted_element = tensor.extract %tensor[%c2] : tensor<1xindex>
  // CHECK: tensor.from_elements
  // CHECK: %[[RESULT:.*]] = tensor.extract
  // CHECK: return %[[RESULT]]
  return %extracted_element : index
}

// -----
412
// An extract from tensor.generate folds to the inlined body computation.
// CHECK-LABEL: func @extract_from_tensor.generate
// CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
func.func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index {
  %size = tensor.rank %tensor : tensor<*xf32>
  // CHECK-NEXT: %[[RES:.*]] = tensor.dim %[[TENSOR]], %[[IDX]]
  %0 = tensor.generate %size {
    ^bb0(%arg0: index):
    %1 = tensor.dim %tensor, %arg0 : tensor<*xf32>
    tensor.yield %1 : index
  } : tensor<?xindex>
  %1 = tensor.extract %0[%idx] : tensor<?xindex>
  // CHECK-NEXT: return %[[RES]]
  return %1 : index
}

// -----

// CHECK-LABEL: func @extract_from_tensor.generate_2d
// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
func.func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index {
  %size = tensor.rank %tensor : tensor<*xf32>
  // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[TENSOR]], %[[IDX0]]
  // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[TENSOR]], %[[IDX1]]
  // CHECK-NEXT: %[[RES:.*]] = arith.addi %[[DIM0]], %[[DIM1]]
  %0 = tensor.generate %size, %size {
    ^bb0(%arg0: index, %arg1: index):
    %1 = tensor.dim %tensor, %arg0 : tensor<*xf32>
    %2 = tensor.dim %tensor, %arg1 : tensor<*xf32>
    %3 = arith.addi %1, %2 : index
    tensor.yield %3 : index
  } : tensor<?x?xindex>
  %4 = tensor.extract %0[%idx0, %idx1] : tensor<?x?xindex>
  // CHECK-NEXT: return %[[RES]]
  return %4 : index
}

// -----

// No folding when the generate body has side effects (the memref.store).
// CHECK-LABEL: func @extract_from_tensor.generate_sideeffects
// CHECK-SAME: %[[IDX:.*]]: index
func.func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>, %mem: memref<?xindex>) -> index {
  %size = tensor.rank %tensor : tensor<*xf32>
  // CHECK: %[[DTENSOR:.*]] = tensor.generate
  %0 = tensor.generate %size {
    ^bb0(%arg0: index):
    %1 = tensor.dim %tensor, %arg0 : tensor<*xf32>
    memref.store %1, %mem[%arg0] : memref<?xindex>
    tensor.yield %1 : index
  } : tensor<?xindex>
  // CHECK: %[[RES:.*]] = tensor.extract %[[DTENSOR]][%[[IDX]]]
  %1 = tensor.extract %0[%idx] : tensor<?xindex>
  // CHECK-NEXT: return %[[RES]]
  return %1 : index
}

// -----
469
// A constant dynamic extent of tensor.generate is promoted into the result
// type, with a cast back to the original type.
// CHECK-LABEL: @static_tensor.generate
// CHECK-SAME: %[[SIZE1:.*]]: index, %[[SIZE4:.*]]: index)
func.func @static_tensor.generate(%size1: index, %size4: index) -> tensor<3x?x?x7x?xindex> {
  %c5 = arith.constant 5 : index
  // CHECK: tensor.generate %[[SIZE1]], %[[SIZE4]]
  %0 = tensor.generate %size1, %c5, %size4 {
    ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
    %1 = arith.constant 32 : index
    tensor.yield %1 : index
  // CHECK: : tensor<3x?x5x7x?xindex>
  } : tensor<3x?x?x7x?xindex>
  // CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
  return %0 : tensor<3x?x?x7x?xindex>
}

// -----

// from_elements of constants folds to a constant tensor.
// CHECK-LABEL: @from_elements.constant
func.func @from_elements.constant() -> tensor<3xindex> {
  // CHECK: %[[CST:.*]] = arith.constant dense<[1, 2, 1]> : tensor<3xindex>
  // CHECK: return %[[CST]]
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %tensor = tensor.from_elements %c1, %c2, %c1 : tensor<3xindex>
  return %tensor : tensor<3xindex>
}

// -----
498
// Constant offsets/sizes/strides of extract_slice are folded into the static
// result type, with a cast back to the original dynamic type.
func.func @slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
    %arg2 : index) -> tensor<?x?x?xf32>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @slice_canonicalize
//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
//       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
//  CHECK-SAME:      [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
//  CHECK-SAME:      : tensor<?x?x?xf32> to tensor<4x1x?xf32>
//       CHECK:   %[[RESULT:.+]] = tensor.cast %[[SLICE]]
//       CHECK:   return %[[RESULT]]

// -----

func.func @rank_reducing_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
    %arg2 : index) -> tensor<?x?xf32>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: func @rank_reducing_slice_canonicalize
//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
//       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
//  CHECK-SAME:      [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
//  CHECK-SAME:      : tensor<?x?x?xf32> to tensor<4x?xf32>
//       CHECK:   %[[RESULT:.+]] = tensor.cast %[[SLICE]]
//       CHECK:   return %[[RESULT]]

// -----
536
// CHECK-LABEL: func @trivial_slice
//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
//   CHECK-NOT:   tensor.extract_slice
//       CHECK:   return %[[ARG0]] :  tensor<4x6x16x32xi8>
func.func @trivial_slice(%arg0 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
  %0 = tensor.extract_slice %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<4x6x16x32xi8>
  return %0 : tensor<4x6x16x32xi8>
}

// -----

// CHECK-LABEL: func @trivial_insert_slice
//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
//   CHECK-NOT:   tensor.insert_slice
//       CHECK:   return %[[ARG0]] :  tensor<4x6x16x32xi8>
func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
  %0 = tensor.insert_slice %arg0 into %arg1[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> into tensor<4x6x16x32xi8>
  return %0 : tensor<4x6x16x32xi8>
}

// -----

// CHECK-LABEL: func @empty_insert_slice
//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8>
//  CHECK-SAME:   %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>
//   CHECK-NOT:   tensor.insert_slice
//       CHECK:   return %[[ARG1]] :  tensor<3x3xi8>
func.func @empty_insert_slice(%arg0 : tensor<0x2xi8>, %arg1 : tensor<3x3xi8>) -> tensor<3x3xi8> {
  %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [0, 2] [1, 1] : tensor<0x2xi8> into tensor<3x3xi8>
  return %0 : tensor<3x3xi8>
}

// -----
570
// CHECK-LABEL: func @rank_reducing_tensor_of_cast
//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
//       CHECK:   %[[S:.+]] = tensor.extract_slice %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
// Tensor cast is moved after slice and then gets canonicalized away.
//   CHECK-NOT:   tensor.cast
//       CHECK:   return %[[S]] : tensor<16x32xi8>
func.func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32xi8> {
  %0 = tensor.cast %arg : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
  %1 = tensor.extract_slice %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<?x?x16x32xi8> to tensor<16x32xi8>
  return %1 : tensor<16x32xi8>
}

// -----

// CHECK-LABEL: func @rank_reducing_insert_slice_of_cast
//  CHECK-SAME:   %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
//  CHECK-SAME:   %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
//       CHECK:   %[[S:.+]] = tensor.insert_slice %[[A]] into %[[B]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<16x32xi8> into tensor<4x6x16x32xi8>
// Tensor cast is folded away.
//   CHECK-NOT:   tensor.cast
//       CHECK:   return %[[S]] : tensor<4x6x16x32xi8>
func.func @rank_reducing_insert_slice_of_cast(%a : tensor<16x32xi8>, %b : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
  %c0 = arith.constant 0: index
  %cast = tensor.cast %a : tensor<16x32xi8> to tensor<?x32xi8>
  %sz = tensor.dim %cast, %c0: tensor<?x32xi8>
  %res = tensor.insert_slice %cast into %b[0, 1, 0, 0] [1, 1, %sz, 32] [1, 1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
  return %res : tensor<4x6x16x32xi8>
}

// -----
601
// Constant offsets/sizes/strides of insert_slice are folded by casting the
// source to a more static type.
func.func @insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = tensor.insert_slice %arg0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @insert_slice_canonicalize
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
//       CHECK:   %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<4x1x?xf32>
//       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
//  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
//  CHECK-SAME:      : tensor<4x1x?xf32> into tensor<?x?x?xf32>
//       CHECK:   return %[[RESULT]]

// -----

// Do not insert a cast for the following example. The new source type wouldn't be "more static" than the old one.
func.func @insert_slice_canonicalize_encoding(%arg0 : tensor<2x2xf32, "foo">,
                                              %arg1 : tensor<4x4xf32, "foo">) -> tensor<4x4xf32, "foo">
{
  %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [2, 2] [1, 1] : tensor<2x2xf32, "foo"> into tensor<4x4xf32, "foo">
  return %0 : tensor<4x4xf32, "foo">
}
// CHECK-LABEL: func @insert_slice_canonicalize_encoding
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x2xf32, "foo">
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: tensor<4x4xf32, "foo">
//       CHECK-NOT: tensor.cast
//       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]] into %[[ARG1]]
//  CHECK-SAME:      [0, 0] [2, 2] [1, 1]
//  CHECK-SAME:      : tensor<2x2xf32, "foo"> into tensor<4x4xf32, "foo">
//       CHECK:   return %[[RESULT]]

// -----
638
func.func @slice_to_insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
  %1 = tensor.insert_slice %0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @slice_to_insert_slice_canonicalize
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
//       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
//  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
//  CHECK-SAME:      : tensor<?x?x?xf32> to tensor<4x1x?xf32>
//       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[SLICE]]
//  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
//  CHECK-SAME:      : tensor<4x1x?xf32> into tensor<?x?x?xf32>
//       CHECK:   return %[[RESULT]]

// -----

func.func @rank_reducing_insert_slice_canonicalize(%arg0 : tensor<?x?xf32>, %arg1 : index,
    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = tensor.insert_slice %arg0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @rank_reducing_insert_slice_canonicalize
//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xf32>
//       CHECK:   %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?xf32> to tensor<4x?xf32>
//       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
//  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
//  CHECK-SAME:      : tensor<4x?xf32> into tensor<?x?x?xf32>
//       CHECK:   return %[[RESULT]]

// -----

func.func @rank_reducing_slice_to_insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
  %1 = tensor.insert_slice %0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @rank_reducing_slice_to_insert_slice_canonicalize
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
//       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
//  CHECK-SAME:     [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
//  CHECK-SAME:     : tensor<?x?x?xf32> to tensor<4x?xf32>
//       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ARG3]]
//  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
//  CHECK-SAME:      : tensor<4x?xf32> into tensor<?x?x?xf32>
//       CHECK:   return %[[RESULT]]

// -----
703
// Test case: the static size 8 from the tensor.generate operand is propagated
// into the insert_slice dest type (tensor<?x8xi32>), with a tensor.cast back
// to the original tensor<?x?xi32> result type.
// NOTE(review): %c0 and %c2 below are dead; the canonicalizer DCEs them.
func.func @insert_slice_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
    %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c8 = arith.constant 8 : index
  %0 = tensor.dim %arg0, %c1 : tensor<2x?xi32>
  %1 = tensor.extract %arg1[] : tensor<i32>
  %2 = tensor.generate %arg2, %c8 {
  ^bb0(%arg4: index, %arg5: index):
    tensor.yield %1 : i32
  } : tensor<?x?xi32>
  %3 = tensor.insert_slice %arg0 into %2[0, %arg3] [2, %0] [1, 1] : tensor<2x?xi32> into tensor<?x?xi32>
  return %3 : tensor<?x?xi32>
}
// CHECK-LABEL: func @insert_slice_propagate_dest_cast
//       CHECK:   %[[UPDATED:.+]] = tensor.insert_slice %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
//  CHECK-SAME:     tensor<2x?xi32> into tensor<?x8xi32>
//       CHECK:   %[[CAST:.+]] = tensor.cast %[[UPDATED]]
//       CHECK:   return %[[CAST]]
724
725// -----
726
// Test case: the trailing tensor.cast is folded away by inferring the static
// generate/insert_slice type from the constant sizes %c3/%c9, so the chain
// produces tensor<3x9xi32> directly and returns the insert_slice result.
// Fix: the ARG0 capture used the regex class [a-zA-z0-9_], where the A-z range
// also matches the characters [ \ ] ^ _ ` ; tightened to [a-zA-Z0-9_] to match
// the ARG1 capture below.
func.func @insert_slice_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor<i32>) -> tensor<3x9xi32> {
  %c9 = arith.constant 9 : index
  %c3 = arith.constant 3 : index
  %2 = tensor.extract %arg1[] : tensor<i32>
  %4 = tensor.generate %c3, %c9 {
  ^bb0(%arg2: index, %arg3: index):
    tensor.yield %2 : i32
  } : tensor<?x?xi32>
  %5 = tensor.insert_slice %arg0 into %4[0, 1] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<?x?xi32>
  %6 = tensor.cast %5 : tensor<?x?xi32> to tensor<3x9xi32>
  return %6 : tensor<3x9xi32>
}
// CHECK-LABEL: func @insert_slice_output_dest_canonicalize
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3xi32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<i32>
//       CHECK:   %[[PAD:.+]] = tensor.extract %[[ARG1]]
//       CHECK:   %[[GENERATE:.+]] = tensor.generate
//       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]] into %[[GENERATE]]
//       CHECK:   return %[[RESULT]]
746
747// -----
748
// Test case: Folding of tensor.dim(tensor.generate %idx) -> %idx
// The generate region takes one index block-arg per result dimension (rank 5
// here); querying dimension 3 (the second dynamic extent) folds directly to
// the corresponding dynamic-size operand %arg1.
// CHECK-LABEL: func @dim_of_tensor.generate(
//  CHECK-SAME:     %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
//   CHECK-NOT:   tensor.dim
//       CHECK:   return %[[IDX1]] : index
func.func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index {
  %c3 = arith.constant 3 : index
  %0 = tensor.generate %arg0, %arg1 {
  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
    tensor.yield %c3 : index
  } : tensor<2x?x4x?x5xindex>
  %1 = tensor.dim %0, %c3 : tensor<2x?x4x?x5xindex>
  return %1 : index
}
763
764// -----
765
// Test case: Folding tensor.dim(tensor.cast %0, %idx) -> tensor.dim %0, %idx
// Dim 0 is statically 4 in the cast source and folds to a constant; dim 1
// stays a tensor.dim on the uncast source.
// Fix: the ARG0 capture regex had a stray leading `.` (%[[ARG0:.[a-z0-9A-Z_]+]])
// letting the first character match anything; tightened to [a-zA-Z0-9_]+.
// CHECK-LABEL: func @fold_dim_of_tensor.cast
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x?xf32>
//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//   CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
//       CHECK:   %[[T0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
//  CHECK-NEXT:   return %[[C4]], %[[T0]]
func.func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
  %2 = tensor.dim %0, %c1 : tensor<?x?xf32>
  return %1, %2: index, index
}
781
782// -----
783
// Test case: a tensor.cast on the insert_slice source folds into the op when
// the cast only erases static info (1x? -> ?x?); the static size 1 reappears
// in the insert_slice sizes.
// CHECK-LABEL: func @insert_slice_cast
func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
  // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>
  %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[ARG0]]
  // CHECK-SAME: [{{.*}}, {{.*}}] [1, {{.*}}] [{{.*}}, {{.*}}]
  // CHECK-SAME: : tensor<1x?xf32> into tensor<?x?xf32>
  %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
  // CHECK: return %[[RES]] : tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
795
796// -----
797
// Test case (negative): here the cast adds static info (1x? -> ?x5) that the
// insert_slice type depends on, so the cast must be kept and the slice still
// consumes the cast result.
// CHECK-LABEL: func @insert_slice_cast_no_fold
func.func @insert_slice_cast_no_fold(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
  %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x5xf32>
  // CHECK: %[[CAST:.*]] = tensor.cast
  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
  // CHECK-SAME: [{{.*}}, {{.*}}] [{{.*}}, 5] [{{.*}}, {{.*}}]
  // CHECK-SAME: : tensor<?x5xf32> into tensor<?x?xf32>
  %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, 5] [%arg6, %arg7] : tensor<?x5xf32> into tensor<?x?xf32>
  // CHECK: return %[[RES]] : tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
809
810// -----
811
// Test case: constant sizes (%c64) are folded into the insert_slice as static
// sizes by casting the source to the inferred static type tensor<64x5x64xf32>.
// NOTE(review): %sz0 and %sz2 are unused parameters of this test function.
// CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src(
// CHECK-SAME:      %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
//      CHECK:    %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x64xf32>
//      CHECK:    %[[r:.*]] =  tensor.insert_slice %[[cast]] into %[[arg1]][0, 1, 2] [64, 5, 64] [1, 1, 1] : tensor<64x5x64xf32> into tensor<?x?x?xf32>
//      CHECK:    return %[[r]]
func.func @insert_tensor_cast_on_insert_slice_src(
    %arg0 : tensor<?x5x?xf32>,  %arg1 : tensor<?x?x?xf32>, %sz0: index, %sz2: index) -> tensor<?x?x?xf32> {
  %c64 = arith.constant 64: index
  %r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [%c64, 5, %c64] [1, 1, 1]
    : tensor<?x5x?xf32> into tensor<?x?x?xf32>
  return %r : tensor<?x?x?xf32>
}
824
825// -----
826
// Test case: extract_slice of an insert_slice with identical offsets/sizes/
// strides folds to the inserted value itself.
// CHECK-LABEL: func @fold_extract_insert
//  CHECK-SAME: %{{.+}}: tensor<?x?x?xf32>, %[[SLICE:.+]]: tensor<4x?x8xf32>
func.func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<4x?x8xf32>) {
  %c0 = arith.constant 0: index
  %c1 = arith.constant 1: index
  %0 = tensor.insert_slice %slice into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
  %1 = tensor.extract_slice %0[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<?x?x?xf32> to tensor<4x?x8xf32>
  // CHECK: return %[[SLICE]]
  return %1 : tensor<4x?x8xf32>
}
837
838// -----
839
// Test case: gathering from a splat constant folds to a splat constant of the
// gather result shape, eliminating the tensor.gather entirely.
// CHECK-LABEL: func @fold_gather_constant_splat
//   CHECK-NOT: tensor.gather
//       CHECK: arith.constant dense<1.000000e-01> : tensor<1x2x1x1x1xf32>
func.func @fold_gather_constant_splat(%indices : tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> {
  %cst = arith.constant dense<1.000000e-01> : tensor<4x4x4xf32>
  %0 = tensor.gather %cst[%indices] gather_dims([0, 1, 2]) :
    (tensor<4x4x4xf32>, tensor<1x2x 3xindex>) -> tensor<1x2x 1x1x1xf32>
  return %0 : tensor<1x2x 1x1x1xf32>
}
849
850// -----
851
// Test case: reshaping a splat constant folds to a splat constant of the
// target shape; the dynamic shape operand is irrelevant for a splat.
// CHECK-LABEL: func @fold_reshape_constant_splat
//   CHECK-NOT: tensor.reshape
//       CHECK: arith.constant dense<1.000000e-01> : tensor<4xf32>
func.func @fold_reshape_constant_splat(%shape : tensor<1xi32>) -> tensor<4xf32> {
  %cst = arith.constant dense<1.000000e-01> : tensor<4x1xf32>
  %0 = tensor.reshape %cst(%shape)
             : (tensor<4x1xf32>, tensor<1xi32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}
861
862// -----
863
// Test case: a chain of tensor.reshape ops collapses to a single reshape using
// only the last shape operand; intermediate shapes are irrelevant.
// CHECK-LABEL: func @fold_reshape_chain
//  CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<*xf32>
//  CHECK-SAME: %[[SHAPE_0:[a-zA-Z0-9_]+]]: tensor<?xindex>
//  CHECK-SAME: %[[SHAPE_1:[a-zA-Z0-9_]+]]: tensor<?xindex>
//  CHECK-SAME: %[[SHAPE_2:[a-zA-Z0-9_]+]]: tensor<?xindex>
//       CHECK: %[[RESULT:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE_2]])
//       CHECK: return %[[RESULT]]
func.func @fold_reshape_chain(%input: tensor<*xf32>, %shape_0: tensor<?xindex>, %shape_1: tensor<?xindex>, %shape_2: tensor<?xindex>) -> tensor<*xf32> {
  %0 = tensor.reshape %input(%shape_0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  %1 = tensor.reshape %0(%shape_1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  %2 = tensor.reshape %1(%shape_2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  return %2 : tensor<*xf32>
}
877
878// -----
879
// Test case: a 1-D -> 1-D reshape with identical element count/type is a
// no-op and folds to the input.
// CHECK-LABEL: func @fold_reshape_1d
//  CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<?xf32>
//  CHECK-SAME: %[[SHAPE:[a-zA-Z0-9_]+]]: tensor<1xindex>
//       CHECK: return %[[INPUT]]
func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> tensor<?xf32> {
  %0 = tensor.reshape %input(%shape) : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}
888
889// -----
890
// Test case: extract_slice of a splat constant folds to a smaller splat
// constant of the slice shape.
// CHECK-LABEL: func @fold_extract_constant_splat
//   CHECK-NOT: tensor.extract_slice
//       CHECK: arith.constant dense<42> : tensor<4x4xi32>
func.func @fold_extract_constant_splat() -> (tensor<4x4xi32>) {
  %cst = arith.constant dense<42> : tensor<1024x1024xi32>
  %1 = tensor.extract_slice %cst[0,0] [4,4] [1, 1] : tensor<1024x1024xi32> to tensor<4x4xi32>
  return %1 : tensor<4x4xi32>
}
899
900// -----
901
// Test case: packing a splat constant (no padding needed: 64x128 divides
// evenly by tiles 8x32) folds to a splat constant of the packed shape.
// CHECK-LABEL: func @fold_pack_constant_splat
//   CHECK-NOT: tensor.pack
//       CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
  %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
  %0 = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
    inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
  return %0 : tensor<8x16x8x32xf32>
}
911
912// -----
913
// Test case: packing a splat constant with a padding value EQUAL to the splat
// value still folds to a splat constant — padding cannot introduce a
// different element.
// CHECK-LABEL: func @fold_padding_value_pack_constant_splat
//   CHECK-NOT: tensor.pack
//       CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
  %pad = arith.constant 1.000000e-01 : f32
  %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
  %0 = tensor.pack %cst
    padding_value(%pad : f32)
    outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
    inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
  return %0 : tensor<8x16x8x32xf32>
}
926
927
928// -----
929
// Test case (negative): the padding value (0.0) differs from the splat value
// (0.1), so the packed result is not a splat and the pack must be kept.
// CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
//       CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
//       CHECK: tensor.pack
func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
  %pad = arith.constant 0.0 : f32
  %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
  %0 = tensor.pack %cst
    padding_value(%pad : f32)
    outer_dims_perm = [1, 0]
    inner_dims_pos = [0, 1]
    inner_tiles = [8, 32]
    into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
  return %0 : tensor<8x16x8x32xf32>
}
944
945// -----
946
// Test case: the padding_value is dropped when the source dims divide evenly
// by the inner tiles (1200/1 and 500000/16), so no padding is ever written.
func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<31250x1200x16x1xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
  %pack = tensor.pack %arg0
    padding_value(%cst : f32)
    outer_dims_perm = [1, 0]
    inner_dims_pos = [1, 0]
    inner_tiles = [16, 1]
    into %0 : tensor<1200x500000xf32> -> tensor<31250x1200x16x1xf32>
  return %pack : tensor<31250x1200x16x1xf32>
}
// CHECK-LABEL: func @fold_padding_value_pack
// CHECK-NOT:     padding_value
960
961// -----
962
// Test case: static sizes from the pack dest are propagated back to the
// dynamic source via a tensor.cast. The dim at inner_dims_pos (dim 2) stays
// dynamic in the cast (40x20x?x30) because it may be padded.
func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x30x40x16xf32>) -> tensor<10x20x30x40x16xf32> {
  %cst = arith.constant 0.000000e+00 : f32
   %pack = tensor.pack %src
    padding_value(%cst : f32)
    outer_dims_perm = [2, 1, 3, 0]
    inner_dims_pos = [2]
    inner_tiles = [16]
    into %dest : tensor<?x?x?x?xf32> -> tensor<10x20x30x40x16xf32>
  return %pack : tensor<10x20x30x40x16xf32>
}
// CHECK-LABEL: func.func @infer_src_shape_pack
// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
// CHECK:         return %[[PACK]]
979
980// -----
981
// Test case: static sizes from the pack source are propagated to the dynamic
// dest via casts; the pack runs on the refined dest type and the result is
// cast back to the original dynamic type. The outer dim derived from the
// padded source dim (?) stays dynamic.
func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?x?x?x16xf32>) -> tensor<?x?x?x?x16xf32> {
  %cst = arith.constant 0.000000e+00 : f32
   %pack = tensor.pack %src
    padding_value(%cst : f32)
    outer_dims_perm = [2, 1, 3, 0]
    inner_dims_pos = [2]
    inner_tiles = [16]
    into %dest : tensor<30x20x?x10xf32> -> tensor<?x?x?x?x16xf32>
  return %pack : tensor<?x?x?x?x16xf32>
}
// CHECK-LABEL: func.func @infer_dest_shape_pack
// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
// CHECK:         %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<?x20x10x30x16xf32> to tensor<?x?x?x?x16xf32>
// CHECK:         return %[[CAST_PACK]]
999
1000// -----
1001
// Test case (negative): no shape can be inferred — the only dynamic dim comes
// from an index argument, so no tensor.cast is introduced.
func.func @no_infer_pack_shape(%arg0: tensor<?x32x100xf32>, %arg1: index) -> tensor<32x7x?x16x1xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty(%arg1) : tensor<32x7x?x16x1xf32>
  %pack = tensor.pack %arg0 padding_value(%cst : f32) outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<?x32x100xf32> -> tensor<32x7x?x16x1xf32>
  return %pack : tensor<32x7x?x16x1xf32>
}
// CHECK-LABEL: func.func @no_infer_pack_shape
// CHECK-NOT:     tensor.cast
1010
1011// -----
1012
// Test case (negative): 499999 is not divisible by the inner tile 16, so
// padding is actually needed and padding_value must be kept.
func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
  %pack = tensor.pack %arg0
    padding_value(%cst : f32)
    outer_dims_perm = [1, 0]
    inner_dims_pos = [1, 0]
    inner_tiles = [16, 1]
    into %0 : tensor<1200x499999xf32> -> tensor<31250x1200x16x1xf32>
  return %pack : tensor<31250x1200x16x1xf32>
}
// CHECK-LABEL: func @fold_padding_value_pack_negative1
// CHECK:         tensor.pack
// CHECK-SAME:      padding_value
1027
1028// -----
1029
// Test case (negative): the source dim being tiled is dynamic (1200x?), so
// divisibility cannot be proven and padding_value must be kept.
func.func @fold_padding_value_pack_negative2(%arg0: tensor<1200x?xf32>, %arg1: tensor<?x1200x16x1xf32>) -> tensor<?x1200x16x1xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %pack = tensor.pack %arg0
    padding_value(%cst : f32)
    outer_dims_perm = [1, 0]
    inner_dims_pos = [1, 0]
    inner_tiles = [16, 1]
    into %arg1 : tensor<1200x?xf32> -> tensor<?x1200x16x1xf32>
  return %pack : tensor<?x1200x16x1xf32>
}
// CHECK-LABEL: func @fold_padding_value_pack_negative2
// CHECK:         tensor.pack
// CHECK-SAME:      padding_value
1043
1044// -----
1045
// Test case (negative): the inner tile size is dynamic (%tile), so
// divisibility cannot be proven and padding_value must be kept.
func.func @fold_padding_value_pack_negative3(%arg0: tensor<1200x500000xf32>, %arg1: tensor<?x1200x?x1xf32>, %tile : index) -> tensor<?x1200x?x1xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %pack = tensor.pack %arg0
    padding_value(%cst : f32)
    outer_dims_perm = [1, 0]
    inner_dims_pos = [1, 0]
    inner_tiles = [%tile, 1]
    into %arg1 : tensor<1200x500000xf32> -> tensor<?x1200x?x1xf32>
  return %pack : tensor<?x1200x?x1xf32>
}
// CHECK-LABEL: func @fold_padding_value_pack_negative3
// CHECK:         tensor.pack
// CHECK-SAME:      padding_value
1059
1060// -----
1061
// Test case: unpacking a splat constant folds to a splat constant of the
// unpacked shape, eliminating the tensor.unpack.
// CHECK-LABEL: func @fold_unpack_constant_splat
//   CHECK-NOT: tensor.unpack
//       CHECK: arith.constant dense<1.000000e-01> : tensor<128x256xf32>
func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128x256xf32> {
  %cst = arith.constant dense<1.000000e-01> : tensor<16x8x8x32xf32>
  %0 = tensor.unpack %cst inner_dims_pos = [0, 1]
    inner_tiles = [8, 32] into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
  return %0 : tensor<128x256xf32>
}
1071
1072// -----
1073
// Test case: static sizes from the unpack source are propagated to the
// dynamic dest via casts; the dim at inner_dims_pos (dim 2) stays dynamic
// since unpacking may drop padding.
func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  %unpack = tensor.unpack %src
    outer_dims_perm = [2, 1, 3, 0]
    inner_dims_pos = [2]
    inner_tiles = [16]
    into %dest : tensor<10x20x30x40x16xf32> -> tensor<?x?x?x?xf32>
  return %unpack : tensor<?x?x?x?xf32>
}
// CHECK-LABEL: func.func @infer_dest_shape_unpack
// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
// CHECK:         %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<40x20x?x30xf32> to tensor<?x?x?x?xf32>
// CHECK:         return %[[CAST_UNPACK]]
1089
1090// -----
1091
// Test case: static sizes from the unpack dest are propagated back to the
// dynamic source via a tensor.cast; the outer dim tied to the tiled dest dim
// stays dynamic (?x20x10x30x16).
func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30x20x?x10xf32>) -> tensor<30x20x?x10xf32> {
  %unpack = tensor.unpack %src
    outer_dims_perm = [2, 1, 3, 0]
    inner_dims_pos = [2]
    inner_tiles = [16]
    into %dest : tensor<?x?x?x?x16xf32> -> tensor<30x20x?x10xf32>
  return %unpack : tensor<30x20x?x10xf32>
}
// CHECK-LABEL: func.func @infer_src_shape_unpack
// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
// CHECK:         return %[[UNPACK]]
1106
1107// -----
1108
// Test case (negative): no shape can be inferred for the unpack — the dynamic
// dim comes from an index argument — so no tensor.cast is introduced.
func.func @no_infer_unpack_shape(%arg1: tensor<32x7x?x16x1xf32>, %arg2: index) -> tensor<?x32x100xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty(%arg2) : tensor<?x32x100xf32>
  %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<32x7x?x16x1xf32> -> tensor<?x32x100xf32>
  return %unpack : tensor<?x32x100xf32>
}
// CHECK-LABEL: func.func @no_infer_unpack_shape
// CHECK-NOT:     tensor.cast
1117
1118// -----
1119
1120
// Test case: two insert_slice ops writing the exact same region — the second
// fully overwrites the first, so the first is dead and folds away.
// CHECK-LABEL: func @fold_overlapping_insert
//  CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
  %c0 = arith.constant 0: index
  %c1 = arith.constant 1: index
  %0 = tensor.insert_slice %slice1 into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
  // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE2]] into %[[INPUT]]
  %1 = tensor.insert_slice %slice2 into %0[0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
  // CHECK: return %[[INSERT]]
  return %1 : tensor<?x?x?xf32>
}
1132
1133// -----
1134
// Test case: two consecutive expand_shape ops compose into a single
// expand_shape with the combined reassociation and the final output_shape.
func.func @compose_expand_of_expand(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
    -> tensor<?x6x4x?x5xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
      : tensor<?x?xf32> into tensor<?x4x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4]] output_shape [%arg3, 6, 4, %arg4, 5] : tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
  return %1 : tensor<?x6x4x?x5xf32>
}
// CHECK-LABEL: compose_expand_of_expand
//       CHECK:   tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] output_shape [%arg3, 6, 4, %arg4, 5]
//   CHECK-NOT:   tensor.expand_shape
1145
1146// -----
1147
// Test case: expanding a rank-0 tensor twice composes into one expand_shape
// straight from tensor<f32> to the final tensor<1x1x1xf32>.
func.func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor<f32>)
    -> tensor<1x1x1xf32> {
  %0 = tensor.expand_shape %arg0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
  %1 = tensor.expand_shape %0 [[0, 1, 2]] output_shape [1, 1, 1]
      : tensor<1xf32> into tensor<1x1x1xf32>
  return %1 : tensor<1x1x1xf32>
}
// CHECK-LABEL: compose_expand_of_expand_of_zero_dim
//       CHECK:   tensor.expand_shape %{{.*}} [] output_shape [1, 1, 1]
//  CHECK-SAME:     tensor<f32> into tensor<1x1x1xf32>
1158
1159// -----
1160
// Test case: a collapse_shape of a shape-erasing cast is rewritten to collapse
// the static source directly (96x32) with a single cast to the dynamic result.
// CHECK-LABEL: func.func @collapse_of_cast(
// CHECK-SAME:         %[[IN:.*]]: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
// CHECK-NEXT:    %[[COLLAPSE:.*]] = tensor.collapse_shape %[[IN]] {{\[}}[0, 1], [2]] : tensor<8x12x32xf32> into tensor<96x32xf32>
// CHECK-NEXT:    %[[CAST:.*]] = tensor.cast %[[COLLAPSE]] : tensor<96x32xf32> to tensor<?x32xf32>
// CHECK-NEXT:    return %[[CAST]] : tensor<?x32xf32>
func.func @collapse_of_cast(%t: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
  %0 = tensor.cast %t : tensor<8x12x32xf32> to tensor<?x?x?xf32>
  %1 = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
  %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x32xf32>
  return %2 : tensor<?x32xf32>
}
1172
1173// -----
1174
// Test case: collapse_shape(expand_shape) with identical reassociation maps is
// the identity and folds away completely.
func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 4, 4]
      : tensor<12x4xf32> into tensor<3x4x4xf32>
  %1 = tensor.collapse_shape %0 [[0, 1], [2]]
      : tensor<3x4x4xf32> into tensor<12x4xf32>
  return %1 : tensor<12x4xf32>
}
// CHECK-LABEL: @fold_collapse_of_expand
//   CHECK-NOT:   tensor.{{.*}}_shape
1184
1185// -----
1186
// Test case: identity collapse-of-expand folds even with dynamic dims, since
// the same reassociation round-trips the original type.
func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index)
    -> tensor<?x?xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
      : tensor<?x?xf32> into tensor<?x4x?xf32>
  %1 = tensor.collapse_shape %0 [[0, 1], [2]]
      : tensor<?x4x?xf32> into tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: @fold_collapse_of_expand_dynamic
//   CHECK-NOT:   tensor.{{.*}}_shape
1197
1198// -----
1199
// Test case: identity collapse-of-expand folds even when every expanded dim is
// dynamic.
func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
    -> tensor<?x?xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  %1 = tensor.collapse_shape %0 [[0, 1], [2]]
      : tensor<?x?x?xf32> into tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: @fold_collapse_of_expand_fully_dynamic
//   CHECK-NOT:   tensor.{{.*}}_shape
1210
1211// -----
1212
// Test case (negative): the collapse groups ([0],[1],[2,3]) differ from the
// expand groups ([0,1],[2],[3]), so the pair is not an identity and both ops
// must be kept.
func.func @no_fold_parallel_collapse_of_expand_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
    -> tensor<?x?x?xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1], [2], [3]] output_shape [%arg1, %arg2, %arg3, %arg4]
      : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
  %1 = tensor.collapse_shape %0 [[0], [1], [2, 3]]
      : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}
// CHECK-LABEL: @no_fold_parallel_collapse_of_expand_dynamic
//       CHECK:   tensor.expand_shape
//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape
//       CHECK:   return %[[COLLAPSE]]
1225
1226// -----
1227
// Test case: expand_shape(collapse_shape) with identical reassociation maps
// and matching static types is the identity and folds away.
func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
      : tensor<3x4x4xf32> into tensor<12x4xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 4]
      : tensor<12x4xf32> into tensor<3x4x4xf32>
  return %1 : tensor<3x4x4xf32>
}
// CHECK-LABEL: @fold_expand_of_collapse
//   CHECK-NOT:   tensor.{{.*}}_shape
1237
1238// -----
1239
// Test case: identity expand-of-collapse folds with a dynamic dim because at
// most one dim per reassociation group is dynamic, so shapes round-trip.
func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
    -> tensor<?x4x?xf32> {
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
      : tensor<?x4x?xf32> into tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
      : tensor<?x?xf32> into tensor<?x4x?xf32>
  return %1 : tensor<?x4x?xf32>
}
// CHECK-LABEL: @fold_expand_of_collapse_dynamic
//   CHECK-NOT:   tensor.{{.*}}_shape
1250
1251// -----
1252
// Test case (negative): the group [0, 1] contains two dynamic dims, so the
// expand cannot be proven to recreate the original shape; both ops remain.
func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
    -> tensor<?x?x?xf32> {
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
      : tensor<?x?x?xf32> into tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}
// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
//       CHECK:   tensor.collapse_shape
//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape
//       CHECK:   return %[[EXPAND]]
1265
1266// -----
1267
// Test case: the collapse is kept, but tensor.dim of its result is rewritten
// in terms of the original source dim via affine.apply (dim0 * 64), so the
// dynamic output_shape of the expand is computed without the collapsed value.
func.func @compose_expand_of_collapse_last_two_dims(%arg0: tensor<?x64x1xf32>) -> tensor<?x384xf32> {
  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %collapsed, %c0 : tensor<?xf32>
  %c384= arith.constant 384 : index
  %div = arith.divui %dim, %c384 : index
  %expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [%div, 384] : tensor<?xf32> into tensor<?x384xf32>
  return %expanded : tensor<?x384xf32>
}
//       CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 64)>
// CHECK-LABEL: @compose_expand_of_collapse_last_two_dims
//  CHECK-SAME: %[[ARG0:.+]]: tensor<?x64x1xf32>
//       CHECK: %[[CONSTANT0:.+]] = arith.constant 0 : index
//       CHECK: %[[CONSTANT384:.+]] = arith.constant 384 : index
//       CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
//       CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[CONSTANT0]] : tensor<?x64x1xf32>
//       CHECK: %[[AFFAPPLY:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
//       CHECK: %[[DIVUI:.+]] = arith.divui %[[AFFAPPLY]], %[[CONSTANT384]] : index
//       CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1]] output_shape [%[[DIVUI]], 384] : tensor<?xf32> into tensor<?x384xf32>
//       CHECK: return %[[RESULT]]
1288
1289// -----
1290
// Test case: collapse-to-1D followed by expand composes into a single
// collapse_shape, since every expanded result dim is a product of a
// contiguous run of source dims (24=2*3*4, 5, 42=6*7, 8).
func.func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
    -> tensor<24x5x42x8xf32> {
  %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
      : tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
  %1 = tensor.expand_shape %0 [[0, 1, 2, 3]] output_shape [24, 5, 42, 8]
      : tensor<40320xf32> into tensor<24x5x42x8xf32>
  return %1 : tensor<24x5x42x8xf32>
}
//      CHECK: func @compose_expand_of_collapse
// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32>
//      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK-SAME:     [0, 1, 2], [3], [4, 5], [6]
//      CHECK:   return %[[RESULT]]
1304
1305// -----
1306
// Test case: the mirror of the previous test — when the result has MORE dims
// than the source, the pair composes into a single expand_shape with the
// derived reassociation.
func.func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
    -> tensor<2x3x4x5x6x7x8xf32> {
  %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3]]
      : tensor<24x5x42x8xf32> into tensor<40320xf32>
  %1 = tensor.expand_shape %0 [[0, 1, 2, 3, 4, 5, 6]] output_shape [2, 3, 4, 5, 6, 7, 8]
      : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
  return %1 : tensor<2x3x4x5x6x7x8xf32>
}
//      CHECK: func @compose_expand_of_collapse_7D
// CHECK-SAME:   %[[ARG0:.+]]: tensor<24x5x42x8xf32>
//      CHECK:   %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME:     [0, 1, 2], [3], [4, 5], [6]
//      CHECK:   return %[[RESULT]]
1320
1321// -----
1322
// Test case: collapse(expand) where the expand only appends a unit dim
// composes to a single collapse_shape on the original operand.
func.func @compose_collapse_of_expand(%arg : tensor<?x?x?xi64>, %arg1: index, %arg2: index, %arg3: index)
    -> tensor<?x?xi64> {
  %0 = tensor.expand_shape %arg [[0], [1], [2, 3]] output_shape [%arg1, %arg2, %arg3, 1]
    : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
  %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
    : tensor<?x?x?x1xi64> into tensor<?x?xi64>
  return %1 : tensor<?x?xi64>
}
// CHECK-LABEL: func @compose_collapse_of_expand
//       CHECK:   (%[[ARG:.*]]: tensor<?x?x?xi64>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
//  CHECK-NEXT: tensor.collapse_shape %[[ARG]]
//  CHECK-SAME:   [0, 1], [2]
//  CHECK-SAME:   : tensor<?x?x?xi64> into tensor<?x?xi64>
1336
1337// -----
1338
// Test case: collapse(expand) on a 1-D source simplifies to a single
// expand_shape straight to the final 4x512 type, dropping the unit dims.
func.func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
    -> tensor<4x512xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]] output_shape [1, 4, 1, 512]
    : tensor<2048xf32> into tensor<1x4x1x512xf32>
  %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
    : tensor<1x4x1x512xf32> into tensor<4x512xf32>
  return %1 : tensor<4x512xf32>
}
//       CHECK: func @compose_collapse_of_expand_1D
//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [4, 512]
//  CHECK-SAME:   tensor<2048xf32> into tensor<4x512xf32>
1350
1351// -----
1352
// Test case: collapse to rank-0 then expand to a HIGHER rank than the source
// composes into a single expand_shape from the original tensor<1x1x1xf32>.
func.func @compose_expand_of_collapse_0_rank_to_expand(%arg0 : tensor<1x1x1xf32>)
    -> tensor<1x1x1x1xf32> {
  %0 = tensor.collapse_shape %arg0 []
      : tensor<1x1x1xf32> into tensor<f32>
  %1 = tensor.expand_shape %0 [] output_shape [1, 1, 1, 1]
      : tensor<f32> into tensor<1x1x1x1xf32>
  return %1 : tensor<1x1x1x1xf32>
}
//      CHECK: func @compose_expand_of_collapse_0_rank_to_expand
// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x1x1xf32>
//      CHECK:   %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME:     {{\[}}[0], [1], [2, 3]] output_shape [1, 1, 1, 1]
//      CHECK:   return %[[RESULT]]
1366
1367// -----
1368
// Test case: collapse to rank-0 then expand to a LOWER rank than the source
// composes into a single collapse_shape from the original tensor<1x1x1x1xf32>.
func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1xf32>)
    -> tensor<1x1x1xf32> {
  %0 = tensor.collapse_shape %arg0 []
      : tensor<1x1x1x1xf32> into tensor<f32>
  %1 = tensor.expand_shape %0 [] output_shape [1, 1, 1]
      : tensor<f32> into tensor<1x1x1xf32>
  return %1 : tensor<1x1x1xf32>
}
//      CHECK: func @compose_expand_of_collapse_0_rank_to_collapse
// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x1x1x1xf32>
//      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK-SAME:     [0], [1], [2, 3]
//      CHECK:   return %[[RESULT]]
1382
1383// -----
1384
// Test case: collapse then partial re-expand composes into one collapse that
// only merges the trailing 64x2 pair (4 and 32 are recovered statically).
func.func @compose_expand_of_collapse_static(%arg0 : tensor<4x32x10x64x2xf16>) -> tensor<4x32x10x128xf16> {
  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x32x10x64x2xf16> into tensor<128x10x128xf16>
  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 32, 10, 128] : tensor<128x10x128xf16> into tensor<4x32x10x128xf16>
  return %expanded : tensor<4x32x10x128xf16>
}

// CHECK-LABEL: func @compose_expand_of_collapse_static
// CHECK-SAME:   %[[ARG0:.+]]: tensor<4x32x10x64x2xf16>
//      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK-SAME:     [0], [1], [2], [3, 4]
//      CHECK:   return %[[RESULT]]
1396
1397// -----
1398
// Test case: same composition as the static variant above, with the second
// source dim dynamic; the combined op still collapses only the 64x2 pair.
func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %arg1 : index) -> tensor<4x?x10x128xf16> {
  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x?x10x64x2xf16> into tensor<?x10x128xf16>
  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, %arg1,  10, 128] : tensor<?x10x128xf16> into tensor<4x?x10x128xf16>
  return %expanded : tensor<4x?x10x128xf16>
}

// CHECK-LABEL: func @compose_expand_of_collapse_dynamic
// CHECK-SAME:   %[[ARG0:.+]]: tensor<4x?x10x64x2xf16>
//      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK-SAME:     [0], [1], [2], [3, 4]
//      CHECK:   return %[[RESULT]]
1410
1411// -----
1412
// A chain of reshapes that starts and ends at the same rank-0 type folds away
// completely.
// CHECK-LABEL: func @zero_rank_reshape_multi
func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
  // CHECK: return %arg0
  %0 = tensor.expand_shape %arg0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
  %1 = tensor.expand_shape %0 [[0, 1]] output_shape [1, 1] : tensor<1xf32> into tensor<1x1xf32>
  %2 = tensor.collapse_shape %1 [] : tensor<1x1xf32> into tensor<f32>
  return %2 : tensor<f32>
}
1421
1422// -----
1423
// Two consecutive collapse_shape ops compose into one with merged
// reassociation groups.
func.func @compose_collapse_of_collapse(%arg0 : tensor<?x?x?x?x?xf32>)
    -> tensor<?x?xf32> {
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
      : tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
  %1 = tensor.collapse_shape %0 [[0, 1], [2]]
      : tensor<?x?x?xf32> into tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func @compose_collapse_of_collapse
//       CHECK:   tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
//   CHECK-NOT:   tensor.collapse_shape

// -----

// Composing collapses down to rank 0 leaves a single collapse_shape with an
// empty reassociation.
func.func @compose_collapse_of_collapse_zero_dim(%arg0 : tensor<1x1x1xf32>)
    -> tensor<f32> {
  %0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
      : tensor<1x1x1xf32> into tensor<1xf32>
  %1 = tensor.collapse_shape %0 [] : tensor<1xf32> into tensor<f32>
  return %1 : tensor<f32>
}
// CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
//       CHECK:   tensor.collapse_shape %{{.*}} []
//  CHECK-SAME:     tensor<1x1x1xf32> into tensor<f32>
1447//  CHECK-SAME:     tensor<1x1x1xf32> into tensor<f32>
1448
1449// -----
1450
// An expand followed by a full collapse folds into a single collapse_shape.
func.func @fold_collapse_of_expand_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]] output_shape [1, 4, 1, 512]
    : tensor<4x512xf32> into tensor<1x4x1x512xf32>
  %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
    : tensor<1x4x1x512xf32> into tensor<2048xf32>
  return %1 : tensor<2048xf32>
}
//       CHECK: func @fold_collapse_of_expand_1D
//       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1]]
//  CHECK-SAME:   tensor<4x512xf32> into tensor<2048xf32>

// -----

// When the collapse only removes unit dims introduced by the expand, the pair
// folds into a single (smaller) expand_shape.
func.func @fold_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x1xf32>)
    -> tensor<4x512x1x1xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]] output_shape [1, 4, 1, 512, 1, 1] : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
  %1 = tensor.collapse_shape %0 [[0, 1, 2], [3], [4], [5]]
    : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32>
  return %1 : tensor<4x512x1x1xf32>
}
//       CHECK: func @fold_collapse_of_expand_unit_dims
//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]] output_shape [4, 512, 1, 1]
//  CHECK-SAME:   tensor<2048x1x1xf32> into tensor<4x512x1x1xf32>

// -----

// Like the test above, but with unit dims scattered across several
// reassociation groups.
func.func @compose_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
    -> tensor<4x512x1x512x4xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]] output_shape [1, 4, 1, 512, 1, 1, 512, 1, 4] : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
  %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4], [5], [6, 7], [8]]
    : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
  return %1 : tensor<4x512x1x512x4xf32>
}
//       CHECK: func @compose_collapse_of_expand_unit_dims
//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]] output_shape [4, 512, 1, 512, 4]
//  CHECK-SAME:   tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>
1487
1488// -----
1489
// Trailing unit dims added by the expand and removed by the collapse fold
// into a single expand_shape.
func.func @compose_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
    -> tensor<2x1xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 1, 1]
      : tensor<2xf32> into tensor<2x1x1xf32>
  %1 = tensor.collapse_shape %0 [[0], [1, 2]]
      : tensor<2x1x1xf32> into tensor<2x1xf32>
  return %1 : tensor<2x1xf32>
}
//       CHECK: func @compose_collapse_of_expand_trailing_unit_dims
//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [2, 1]
//  CHECK-SAME:   tensor<2xf32> into tensor<2x1xf32>

// -----

// Composing two collapses with dynamic and unit dims merges the
// reassociation groups into one collapse_shape.
func.func @compose_collapse_of_collapse_unit_dims_dynamic(
    %arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>) -> tensor<?x?x?x?xf32> {
  %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]]
    : tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x1x1x?x?xf32>
  %1 = tensor.collapse_shape %0 [[0], [1], [2, 3, 4], [5]]
    : tensor<?x?x1x1x?x?xf32> into tensor<?x?x?x?xf32>
  return %1 : tensor<?x?x?x?xf32>
}
//       CHECK: func @compose_collapse_of_collapse_unit_dims_dynamic
//       CHECK: tensor.collapse_shape
//  CHECK-SAME:   [0], [1, 2], [3, 4, 5], [6, 7, 8]
//  CHECK-SAME:   tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x?x?xf32>
1516
1517// -----
1518
// NOTE(review): same IR and name pattern as @compose_collapse_of_expand_
// trailing_unit_dims above; kept as a separate split-input module.
func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
    -> tensor<2x1xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 1, 1] : tensor<2xf32> into tensor<2x1x1xf32>
  %1 = tensor.collapse_shape %0 [[0], [1, 2]]
      : tensor<2x1x1xf32> into tensor<2x1xf32>
  return %1 : tensor<2x1xf32>
}
//       CHECK: func @fold_collapse_of_expand_trailing_unit_dims
//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [2, 1]
//  CHECK-SAME:   tensor<2xf32> into tensor<2x1xf32>

// -----

// Two collapses over trailing unit dims merge into one full collapse.
func.func @fold_collapse_of_collapse_trailing_unit_dims_dynamic(
    %arg0: tensor<1x1x?x1x1x1xf32>) -> tensor<?xf32> {
  %0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]]
      : tensor<1x1x?x1x1x1xf32> into tensor<?x1x1x1xf32>
  %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
      : tensor<?x1x1x1xf32> into tensor<?xf32>
  return %1 : tensor<?xf32>
}
//       CHECK: func @fold_collapse_of_collapse_trailing_unit_dims_dynamic
//       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]]
//  CHECK-SAME:   tensor<1x1x?x1x1x1xf32> into tensor<?xf32>

// -----

// NOTE(review): this test reuses the name of the first test in this group;
// -split-input-file (see the RUN line) keeps them in separate modules.
func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
    -> tensor<12x42xf32> {
  %0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [12, 42, 1, 1, 1] : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32>
  %1 = tensor.collapse_shape %0 [[0], [1, 2, 3, 4]]
      : tensor<12x42x1x1x1xf32> into tensor<12x42xf32>
  return %1 : tensor<12x42xf32>
}
//       CHECK: func @fold_collapse_of_expand_trailing_unit_dims
//       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]]
//  CHECK-SAME:   tensor<12x42x1x1xf32> into tensor<12x42xf32>
1556
1557// -----
1558
// A unit dim inserted in the middle by the expand and removed by the collapse
// folds into a single collapse_shape.
func.func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>, %sz0: index, %sz1: index, %sz2: index)
    -> tensor<?x?xf32> {
  %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [%sz0, %sz1, 1, %sz2]
      : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
  %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]]
      : tensor<?x?x1x?xf32> into tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func @fold_collapse_of_expand_unit_dims_in_middle
//  CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>
//       CHECK: tensor.collapse_shape %[[ARG]] {{\[}}[0], [1, 2]]
//  CHECK-SAME:   tensor<?x?x?xf32> into tensor<?x?xf32>

// -----

// Negative test: the collapse groups cut across the expand groups, so the
// pair must not fold.
func.func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>)
    -> tensor<2x6x16xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]] output_shape [2, 2, 3, 2, 8]
      : tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32>
  %1 = tensor.collapse_shape %0 [[0], [1, 2], [3, 4]]
      : tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32>
  return %1 : tensor<2x6x16xf32>
}
// CHECK-LABEL: func @no_fold_collapse_of_expand_incompatible
//       CHECK:   tensor.expand_shape
//       CHECK:   tensor.collapse_shape
1585
1586// -----
1587
// Negative test: the collapse group [0, 1, 2] spans several expand groups, so
// the expand/collapse pair must survive canonicalization.
func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
    -> tensor<12x1xf32> {
  %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [3, 2, 2, 1]
      : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
  %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
      : tensor<3x2x2x1xf32> into tensor<12x1xf32>
  return %1 : tensor<12x1xf32>
}
//      CHECK: func @no_fold_collapse_of_expand_empty_expr
// CHECK-SAME:    %[[ARG0:.+]]: tensor<3x2x2xf32>
//      CHECK:    %[[RARG0:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME:      {{\[}}[0], [1], [2, 3]] output_shape [3, 2, 2, 1]
//      CHECK:    %[[RES:.+]] = tensor.collapse_shape %[[RARG0]]
// CHECK-SAME:      [0, 1, 2], [3]
//      CHECK:    return %[[RES]] : tensor<12x1xf32>
1603
1604// -----
1605
// expand_shape of a splat constant folds to a splat constant of the expanded
// type.
func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> {
  %c0 = arith.constant dense<42> : tensor<2x8xi32>
  %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
      : tensor<2x8xi32> into tensor<2x4x2xi32>
  return %0 : tensor<2x4x2xi32>
}
// CHECK-LABEL: @reshape_splat_constant_int32
//       CHECK:   %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xi32>
//   CHECK-NOT:   tensor.expand_shape
//       CHECK:   return %[[CST]]

// -----

// expand_shape of a tensor.splat folds into a splat of the expanded type.
func.func @expand_shape_splat(%arg : f32) -> tensor<2x2x2xf32> {
  %c0 = tensor.splat %arg : tensor<2x4xf32>
  %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 2, 2]
      : tensor<2x4xf32> into tensor<2x2x2xf32>
  return %0 : tensor<2x2x2xf32>
}
// CHECK-LABEL: @expand_shape_splat
// CHECK-SAME:    %[[ARG0:.+]]: f32
//       CHECK:   %[[CST:.*]] = tensor.splat %[[ARG0:.+]] : tensor<2x2x2xf32>
//   CHECK-NOT:   tensor.expand_shape
//       CHECK:   return %[[CST]]
1628
1629// -----
1630
// Negative test: a splat with a dynamic extent keeps its expand_shape (the
// splat cannot absorb the dynamic output shape).
// CHECK-LABEL: @expand_shape_splat_dynamic_no_fold
// CHECK-SAME: (%[[F:.+]]: f32, %[[M:.+]]: index, %[[SZ0:.+]]: index)
func.func @expand_shape_splat_dynamic_no_fold(%arg: f32, %m: index, %sz0: index) -> tensor<2x2x?xf32> {
  // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]] : tensor<2x?xf32>
  // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[SPLAT]]
  %c0 = tensor.splat %arg[%m] : tensor<2x?xf32>
  %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 2, %sz0] : tensor<2x?xf32> into tensor<2x2x?xf32>
  return %0 : tensor<2x2x?xf32>
}
1640
1641// -----
1642
// collapse_shape of a tensor.splat folds into a splat of the collapsed type.
func.func @collapse_shape_splat(%arg : f32) -> tensor<2x4xf32> {
  %c0 = tensor.splat %arg : tensor<2x2x2xf32>
  %0 = tensor.collapse_shape %c0 [[0], [1, 2]]
      : tensor<2x2x2xf32> into tensor<2x4xf32>
  return %0 : tensor<2x4xf32>
}
// CHECK-LABEL: @collapse_shape_splat
// CHECK-SAME:    %[[ARG0:.+]]: f32
//       CHECK:   %[[CST:.*]] = tensor.splat %[[ARG0:.+]] : tensor<2x4xf32>
//   CHECK-NOT:   tensor.collapse_shape
//       CHECK:   return %[[CST]]

// -----

// Negative test: a splat with a dynamic extent keeps its collapse_shape.
// CHECK-LABEL: @collapse_shape_splat_dynamic_no_fold
// CHECK-SAME: %[[F:.+]]: f32
// CHECK-SAME: %[[M:.+]]: index
func.func @collapse_shape_splat_dynamic_no_fold(%f: f32, %m: index) -> tensor<2x?xf32> {
  // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]]
  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[SPLAT]]
  %c0 = tensor.splat %f[%m] : tensor<2x2x?xf32>
  %0 = tensor.collapse_shape %c0 [[0], [1, 2]] : tensor<2x2x?xf32> into tensor<2x?xf32>
  return %0 : tensor<2x?xf32>
}
1667
1668// -----
1669
// Same splat-constant expand_shape fold as above, exercised for i16.
func.func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> {
  %c0 = arith.constant dense<42> : tensor<2x8xi16>
  %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
      : tensor<2x8xi16> into tensor<2x4x2xi16>
  return %0 : tensor<2x4x2xi16>
}
// CHECK-LABEL: @reshape_splat_constant_int16
//       CHECK:   %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xi16>
//   CHECK-NOT:   tensor.expand_shape
//       CHECK:   return %[[CST]]

// -----

// Same splat-constant expand_shape fold, exercised for f32.
func.func @reshape_splat_constant_float32() -> tensor<2x4x2xf32> {
  %c0 = arith.constant dense<42.0> : tensor<2x8xf32>
  %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
      : tensor<2x8xf32> into tensor<2x4x2xf32>
  return %0 : tensor<2x4x2xf32>
}
// CHECK-LABEL: @reshape_splat_constant_float32
//       CHECK:   %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xf32>
//   CHECK-NOT:   tensor.expand_shape
//       CHECK:   return %[[CST]]

// -----

// Same splat-constant expand_shape fold, exercised for f64.
func.func @reshape_splat_constant_float64() -> tensor<2x4x2xf64> {
  %c0 = arith.constant dense<42.0> : tensor<2x8xf64>
  %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
      : tensor<2x8xf64> into tensor<2x4x2xf64>
  return %0 : tensor<2x4x2xf64>
}
// CHECK-LABEL: @reshape_splat_constant_float64
//       CHECK:   %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xf64>
//   CHECK-NOT:   tensor.expand_shape
//       CHECK:   return %[[CST]]
1706
1707// -----
1708
// CHECK-LABEL: func @fold_rank
func.func @fold_rank() -> (index) {
  %const_0 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]>
    : tensor<2x1x4xi32>

  // Fold a rank into a constant
  // CHECK-NEXT: [[C3:%.+]] = arith.constant 3 : index
  %rank_0 = tensor.rank %const_0 : tensor<2x1x4xi32>

  // CHECK-NEXT: return [[C3]]
  return %rank_0 : index
}
1721
1722// -----
1723
// A pad whose source and result types are the same static shape is a no-op
// and folds to its source.
// CHECK-LABEL: func @pad_same_static_shape(
//  CHECK-SAME:   %[[ARG0:.*]]: tensor<5x6xf32>
//   CHECK-NOT:   tensor.pad
//       CHECK:   return %[[ARG0]]
func.func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
    -> tensor<5x6xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.pad %arg0 low[%a, 0] high[0, %a] {
        ^bb0(%arg1: index, %arg2: index):
          tensor.yield %cst : f32
  } : tensor<5x6xf32> to tensor<5x6xf32>
  return %0 : tensor<5x6xf32>
}
1737
1738// -----
1739
// Constant pad amounts fold into the pad op's static low/high lists, with a
// cast back to the original dynamic result type.
// CHECK-LABEL:   func @pad_fold_static(
// CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
// CHECK:           %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NOT:       arith.constant 4 : index
// CHECK:           %[[PADDED:.*]] = tensor.pad %[[INPUT]]
// CHECK-SAME:        low[0, 4, 1, 1] high[0, 4, 1, 1]  {
// CHECK:           ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
// CHECK:             tensor.yield %[[CST]] : f32
// CHECK:           } : tensor<?x64x?x?xf32> to tensor<?x72x?x?xf32>
// CHECK:           tensor.cast
func.func @pad_fold_static(%arg0: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %padding = arith.constant 4 : index
  %padded = tensor.pad %arg0 low[0, %padding, 1, 1] high[0, %padding, 1, 1]  {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
    tensor.yield %cst: f32
  } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
  return %padded : tensor<?x?x?x?xf32>
}
1760
1761// -----
1762
// Negative test: the nofold attribute prevents eliding the no-op pad.
// CHECK-LABEL: func @pad_nofold_same_static_shape(
//  CHECK-SAME:   %[[ARG0:.*]]: tensor<5x6xf32>
//       CHECK:   %[[PAD:.*]] = tensor.pad
//       CHECK:   return %[[PAD]]
func.func @pad_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
    -> tensor<5x6xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.pad %arg0 nofold low[%a, 0] high[0, %a] {
        ^bb0(%arg1: index, %arg2: index):
          tensor.yield %cst : f32
  } : tensor<5x6xf32> to tensor<5x6xf32>
  return %0 : tensor<5x6xf32>
}
1776
1777// -----
1778
// A cast feeding a pad folds away; the pad keeps the more-static source type
// and a cast restores the original dynamic result type.
// CHECK-LABEL:   func @pad_after_cast_different_shape(
// CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
// CHECK:           %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[PADDED:.*]] = tensor.pad %[[INPUT]]
// CHECK-SAME:        low[0, 0, 1, 1] high[0, 0, 1, 1]  {
// CHECK:           ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
// CHECK:             tensor.yield %[[CST]] : f32
// CHECK:           } : tensor<?x64x?x?xf32> to tensor<?x64x?x?xf32>
// CHECK:           %[[DYNAMIC:.*]] = tensor.cast %[[PADDED]] :
// CHECK-SAME:         tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
// CHECK:           return %[[DYNAMIC]] : tensor<?x?x?x?xf32>
// CHECK:         }
func.func @pad_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
    -> tensor<?x?x?x?xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
  %padded = tensor.pad %dynamic low[0, 0, 1, 1] high[0, 0, 1, 1]  {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
    tensor.yield %cst: f32
  } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
  return %padded: tensor<?x?x?x?xf32>
}
1801
1802// -----
1803
// A cast feeding a pad folds away even when the pad result type is unchanged
// (dynamic pad amounts); the pad reads the more-static source directly.
// CHECK-LABEL:   func @pad_after_cast_same_shape(
// CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>,
// CHECK-SAME:      %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> {
// CHECK:           %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[PADDED:.*]] = tensor.pad %[[INPUT]]
// CHECK-SAME:        low[0, %[[PADDING]], 1, 1] high[0, %[[PADDING]], 1, 1]  {
// CHECK:           ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
// CHECK:             tensor.yield %[[CST]] : f32
// CHECK:           } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
// CHECK:           return %[[PADDED]] : tensor<?x?x?x?xf32>
// CHECK:         }
func.func @pad_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
    -> tensor<?x?x?x?xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
  %padded = tensor.pad %dynamic low[0, %padding, 1, 1] high[0, %padding, 1, 1]  {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
    tensor.yield %cst: f32
  } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
  return %padded: tensor<?x?x?x?xf32>
}
1825
1826// -----
1827
// The cast into the pad folds away; the pad uses the more-static source type.
// CHECK-LABEL: func @pad_of_cast(
// CHECK-NOT:     tensor.cast
// CHECK:         tensor.pad
// CHECK:         tensor<8x?xf32> to tensor<8x32xf32>
func.func @pad_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.cast %t : tensor<8x?xf32> to tensor<?x?xf32>
  %1 = tensor.pad %0 low[%c0, %c0] high[%c0, %s]  {
  ^bb0(%arg9: index, %arg10: index):
    tensor.yield %cst : f32
  } : tensor<?x?xf32> to tensor<8x32xf32>
  return %1 : tensor<8x32xf32>
}
1842
1843// -----
1844
// A cast to a more-static type after a pad folds into the pad's result type.
// CHECK-LABEL: @cast_of_pad_more_static
func.func @cast_of_pad_more_static(%arg0: tensor<?x?xf32>, %padding: index) -> tensor<32x32xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  // CHECK: %[[PAD:.*]] = tensor.pad
  // CHECK: tensor<?x?xf32> to tensor<32x32xf32>
  %padded = tensor.pad %arg0 low[%padding, %padding] high[0, 0] {
  ^bb0(%arg1: index, %arg2: index):
    tensor.yield %cst : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>
  // CHECK-NOT: tensor.cast
  %casted = tensor.cast %padded : tensor<?x?xf32> to tensor<32x32xf32>
  // CHECK: return %[[PAD]]
  return %casted : tensor<32x32xf32>
}

// -----

// Negative test: the cast mixes more- and less-static dims, so it stays.
// CHECK-LABEL: @cast_of_pad_less_static
func.func @cast_of_pad_less_static(%arg0: tensor<32x?x?xf32>, %padding: index) -> tensor<?x32x32xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  // CHECK: tensor.pad
  %padded = tensor.pad %arg0 low[%padding, %padding, %padding] high[0, 0, 0] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index):
    tensor.yield %cst : f32
  } : tensor<32x?x?xf32> to tensor<32x?x?xf32>
  // CHECK: %[[CAST:.*]] = tensor.cast
  %casted = tensor.cast %padded : tensor<32x?x?xf32> to tensor<?x32x32xf32>
  // CHECK: return %[[CAST]]
  return %casted : tensor<?x32x32xf32>
}
1875
1876// -----
1877
// A zero-width pad of a cast back to the original static type folds to the
// original argument.
func.func @pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f32
  %0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32>
  %1 = tensor.pad %0 low[%c0, %c0] high[%c0, %c0]  {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %cst : f32
  } : tensor<?x?xf32> to tensor<4x4xf32>
  return %1 : tensor<4x4xf32>
}
// CHECK-LABEL: @pad_cast_fold
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>
// CHECK: return %[[ARG0]]
1891
1892// -----
1893
// The cast on the pad's source folds away, padding the more-static type.
// CHECK-LABEL: func @fold_pad_source_cast(
//  CHECK-SAME:                  %[[ARG0:.*]]: tensor<4x?xf32>
//   CHECK-NOT:   tensor.cast
//       CHECK:   %[[RESULT:.*]] = tensor.pad %[[ARG0]]
func.func @fold_pad_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
  %cst = arith.constant 0.0 : f32
  %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
  %1 = tensor.pad %0 low[0, 0] high[0, 1]  {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %cst : f32
  } : tensor<?x?xf32> to tensor<4x4xf32>
  return %1 : tensor<4x4xf32>
}
1907
1908// -----
1909
// A pad whose low/high amounts are all (constant) zero becomes a tensor.cast
// to the static result type.
// CHECK-LABEL: func @pad_static_zero_cast(
//  CHECK-SAME:                  %[[ARG0:.*]]: tensor<?x?x?xf32>
//   CHECK-NOT:   tensor.pad
//       CHECK:   %[[RESULT:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
//       CHECK:   return %[[RESULT]]
func.func @pad_static_zero_cast(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
  %c0 = arith.constant 0 : index
  %0 = tensor.pad %arg0 low[0, %c0, 0] high[0, 0, %c0] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index):
      tensor.yield %pad_value : f32
    } : tensor<?x?x?xf32> to tensor<2x3x4xf32>

  return %0 : tensor<2x3x4xf32>
}
1924
1925// -----
1926
// Negative test: nofold keeps the all-zero pad from becoming a cast.
// CHECK-LABEL: func @pad_nofold_static_zero(
//  CHECK-SAME:                  %[[ARG0:.*]]: tensor<?x?x?xf32>
//       CHECK:   %[[PAD:.*]] = tensor.pad
//       CHECK:   return %[[PAD]]
func.func @pad_nofold_static_zero(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
  %c0 = arith.constant 0 : index
  %0 = tensor.pad %arg0 nofold low[0, %c0, 0] high[0, 0, %c0] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index):
      tensor.yield %pad_value : f32
    } : tensor<?x?x?xf32> to tensor<2x3x4xf32>

  return %0 : tensor<2x3x4xf32>
}
1940
1941// -----
1942
// Two pad/extract_slice pairs acting on disjoint dimensions with the same pad
// value fold into a single extract_slice + pad.
// CHECK-LABEL: func @fold_orthogonal_pad_chains(
//  CHECK-SAME:   %[[ARG0:.*]]: tensor<64x64xf32>,
//  CHECK-SAME:   %[[SZ0:.*]]: index, %[[SZ1:.*]]: index, %[[PW0:.*]]: index, %[[PW1:.*]]: index
func.func @fold_orthogonal_pad_chains(%arg0: tensor<64x64xf32>,
                                      %sz0 : index, %sz1 : index,
                                      %pw0 : index, %pw1 : index) -> tensor<8x4xf32> {
  //       CHECK:   %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
  //  CHECK-SAME:                     [16, 4] [%[[SZ0]], %[[SZ1]]]
  //       CHECK:   %[[PAD:.*]] = tensor.pad %[[T0]] nofold
  //  CHECK-SAME:                     high[%[[PW0]], %[[PW1]]]
  //       CHECK:   return %[[PAD]]
  %pad_value = arith.constant 0.0 : f32
  %0 = tensor.extract_slice %arg0[16, 0] [%sz0, 64] [1, 1] : tensor<64x64xf32> to tensor<?x64xf32>
  %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %pad_value : f32
    } : tensor<?x64xf32> to tensor<8x64xf32>
  %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
  %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %pad_value : f32
    } : tensor<8x?xf32> to tensor<8x4xf32>
  func.return %3 : tensor<8x4xf32>
}
1967
1968// -----
1969
// Negative tests for the pad-chain fold: each case below violates one of the
// fold's preconditions, so the second extract_slice + pad must remain.
// CHECK-LABEL: func @dont_fold_pad_chains(
//  CHECK-SAME:   %[[ARG0:.*]]: tensor<64x64xf32>,
//  CHECK-SAME:   %[[SZ0:.*]]: index, %[[SZ1:.*]]: index, %[[PW0:.*]]: index, %[[PW1:.*]]: index
func.func @dont_fold_pad_chains(%arg0: tensor<64x64xf32>,
                                %sz0 : index, %sz1 : index,
                                %pw0 : index, %pw1 : index) -> (tensor<8x4xf32>, tensor<4x64xf32>, tensor<8x4xf32>, tensor<6x4xf32>) {
  //       CHECK:   %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
  //       CHECK:   %[[T1:.*]] = tensor.pad %[[T0]]
  %pad_value = arith.constant 0.0 : f32
  %0 = tensor.extract_slice %arg0[16, 0] [%sz0, 64] [1, 1] : tensor<64x64xf32> to tensor<?x64xf32>
  %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %pad_value : f32
    } : tensor<?x64xf32> to tensor<8x64xf32>

  // Don't fold if the padding values are different.
  //       CHECK:   %[[T2:.*]] = tensor.extract_slice %[[T1]]
  //  CHECK-SAME:                     [0, 4] [8, %[[SZ1]]]
  //       CHECK:   %[[PAD0:.*]] = tensor.pad %[[T2]]
  %different_value = arith.constant 1.0 : f32
  %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
  %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %different_value : f32
    } : tensor<8x?xf32> to tensor<8x4xf32>

  // Don't fold if the pad ops have common padding dimensions.
  //       CHECK:   %[[T3:.*]] = tensor.extract_slice %[[T1]]
  //  CHECK-SAME:                     [4, 0] [%[[SZ1]], 64]
  //       CHECK:   %[[PAD1:.*]] = tensor.pad %[[T3]]
  %4 = tensor.extract_slice %1[4, 0] [%sz1, 64] [1, 1] : tensor<8x64xf32> to tensor<?x64xf32>
  %5 = tensor.pad %4 nofold low[0, 0] high[%pw1, 0] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %pad_value : f32
    } : tensor<?x64xf32> to tensor<4x64xf32>

  // Don't fold if padded source tensor dimension is accessed at an offset.
  //       CHECK:   %[[T4:.*]] = tensor.extract_slice %[[T1]]
  //  CHECK-SAME:                     [%[[SZ0]], 4] [8, %[[SZ1]]
  //       CHECK:   %[[PAD2:.*]] = tensor.pad %[[T4]]
  %6 = tensor.extract_slice %1[%sz0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
  %7 = tensor.pad %6 nofold low[0, 0] high[0, %pw1] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %pad_value : f32
    } : tensor<8x?xf32> to tensor<8x4xf32>

  // Don't fold if a padded source tensor dimension is sliced.
  //       CHECK:   %[[T5:.*]] = tensor.extract_slice %[[T1]]
  //  CHECK-SAME:                     [0, 4] [6, %[[SZ1]]
  //       CHECK:   %[[PAD3:.*]] = tensor.pad %[[T5]]
  %8 = tensor.extract_slice %1[0, 4] [6, %sz1] [1, 1] : tensor<8x64xf32> to tensor<6x?xf32>
  %9 = tensor.pad %8 nofold low[0, 0] high[0, %pw1] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %pad_value : f32
    } : tensor<6x?xf32> to tensor<6x4xf32>

  //       CHECK:   return %[[PAD0]], %[[PAD1]], %[[PAD2]], %[[PAD3]]
  func.return %3, %5, %7, %9 : tensor<8x4xf32>, tensor<4x64xf32>, tensor<8x4xf32>, tensor<6x4xf32>
}
2029
2030// -----
2031
// Two pads with the same constant pad value merge into one pad whose low/high
// amounts are the sums of the originals.
// CHECK-LABEL: func @merge_constant_padding
//  CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<2x3xf32>
//  CHECK-SAME:   %[[PADVAL:[A-Za-z0-9]+]]: f32
//       CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]] low[1, 3] high[4, 2]
//       CHECK:     tensor.yield %[[PADVAL]]
//       CHECK:   return %[[PAD]]
func.func @merge_constant_padding(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
  %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
    ^bb0(%b0: index, %b1 : index):
      tensor.yield %pad_value : f32
    } : tensor<2x3xf32> to tensor<4x4xf32>
  %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
    ^bb0(%b2: index, %b3 : index):
      tensor.yield %pad_value : f32
    } : tensor<4x4xf32> to tensor<7x8xf32>
  return %pad1 : tensor<7x8xf32>
}
2049
2050// -----
2051
// Merging pads with a dynamic amount sums it via an affine.apply.
//       CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 + 1)>
// CHECK-LABEL: func @merge_constant_padding_dynamic
//  CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[IDX:[A-Za-z0-9]+]]: index
//  CHECK-SAME:   %[[PADVAL:[A-Za-z0-9]+]]: f32
//       CHECK:   %[[HIGH:.+]] = affine.apply #[[$MAP]]()[%[[IDX]]]
//       CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]] low[%[[IDX]], 3] high[%[[HIGH]], 2]
//       CHECK:     tensor.yield %[[PADVAL]]
//       CHECK:   return %[[PAD]]
func.func @merge_constant_padding_dynamic(%arg0: tensor<?x?xf32>, %idx: index, %pad_value: f32) -> tensor<?x?xf32> {
  %pad0 = tensor.pad %arg0 low[%idx, 1] high[1, 0] {
    ^bb0(%b0: index, %b1 : index):
      tensor.yield %pad_value : f32
    } : tensor<?x?xf32> to tensor<?x?xf32>
  %pad1 = tensor.pad %pad0 low[0, 2] high[%idx, 2] {
    ^bb0(%b2: index, %b3 : index):
      tensor.yield %pad_value : f32
    } : tensor<?x?xf32> to tensor<?x?xf32>
  return %pad1 : tensor<?x?xf32>
}
2072
2073// -----
2074
// Verify that folding does not happen if it would drop a nofold attribute
// CHECK-LABEL: func @dont_merge_constant_padding_nofold
//       CHECK:   tensor.pad {{.*}} nofold
//       CHECK:   tensor.pad
func.func @dont_merge_constant_padding_nofold(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
  %pad0 = tensor.pad %arg0 nofold low[1, 1] high[1, 0] {
    ^bb0(%b0: index, %b1 : index):
      tensor.yield %pad_value : f32
    } : tensor<2x3xf32> to tensor<4x4xf32>
  %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
    ^bb0(%b2: index, %b3 : index):
      tensor.yield %pad_value : f32
    } : tensor<4x4xf32> to tensor<7x8xf32>
  return %pad1 : tensor<7x8xf32>
}
2090
2091// -----
2092
// Verify that folding does not happen if the two pads use different pad values
// CHECK-LABEL: func @dont_merge_constant_padding_different_vals
//       CHECK:   tensor.pad
//       CHECK:   tensor.pad
func.func @dont_merge_constant_padding_different_vals(
    %arg0: tensor<2x3xf32>,
    %pad_value0: f32,
    %pad_value1: f32) -> tensor<7x8xf32> {
  %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
    ^bb0(%b0: index, %b1 : index):
      tensor.yield %pad_value0 : f32
    } : tensor<2x3xf32> to tensor<4x4xf32>
  %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
    ^bb0(%b2: index, %b3 : index):
      tensor.yield %pad_value1 : f32
    } : tensor<4x4xf32> to tensor<7x8xf32>
  return %pad1 : tensor<7x8xf32>
}
2111
2112// -----
2113
2114// CHECK-LABEL: func @fold_collapse_shape_from_elements
func.func @fold_collapse_shape_from_elements(%arg0: i32) -> tensor<i32> {
  // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<i32>
  // CHECK: return %[[FROM]] : tensor<i32>
  // Collapsing a 1-element from_elements folds to a 0-d from_elements.
  %0 = tensor.from_elements %arg0 : tensor<1xi32>
  %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
  return %1 : tensor<i32>
}
2122
2123// -----
2124
2125// CHECK-LABEL: func @fold_expand_shape_from_elements
func.func @fold_expand_shape_from_elements(%arg0: i32) -> tensor<1xi32> {
  // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<1xi32>
  // CHECK: return %[[FROM]] : tensor<1xi32>
  // Expanding a 0-d from_elements folds to a 1-d from_elements.
  %0 = tensor.from_elements %arg0 : tensor<i32>
  %1 = tensor.expand_shape %0 [] output_shape [1] : tensor<i32> into tensor<1xi32>
  return %1 : tensor<1xi32>
}
2133
2134// -----
2135
2136// CHECK-LABEL: func @propagate_index_cast
func.func @propagate_index_cast(%arg0: tensor<1xi32>) -> index {
  // CHECK: %[[IDX:.+]] = arith.constant 0
  // CHECK: %[[EXT:.+]] = tensor.extract %arg0[%[[IDX]]] : tensor<1xi32>
  // CHECK: %[[CAST:.+]] = arith.index_cast %[[EXT]]
  // CHECK: return %[[CAST]] : index
  // extract(index_cast(t)) is rewritten to index_cast(extract(t)),
  // scalarizing the cast.
  %c0 = arith.constant 0 : index
  %0 = arith.index_cast %arg0 : tensor<1xi32> to tensor<1xindex>
  %1 = tensor.extract %0[%c0] : tensor<1xindex>
  return %1 : index
}
2147
2148// -----
2149
2150// CHECK-LABEL: func @splat_fold
func.func @splat_fold() -> tensor<4xf32> {
  // A splat of a constant scalar folds to a dense splat constant.
  %c = arith.constant 1.0 : f32
  %t = tensor.splat %c : tensor<4xf32>
  return %t : tensor<4xf32>

  // CHECK-NEXT: [[T:%.*]] = arith.constant dense<1.000000e+00> : tensor<4xf32>
  // CHECK-NEXT: return [[T]] : tensor<4xf32>
}
2159
2160// -----
2161
2162// CHECK-LABEL: func @splat_dynamic_no_fold
2163// CHECK-SAME: %[[M:.+]]: index
func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> {
  // CHECK: %[[F:.+]] = arith.constant
  %f = arith.constant 1.0 : f32

  // A splat with a dynamic result dimension cannot fold to a dense constant.
  // CHECK: tensor.splat %[[F]][%[[M]]] : tensor<4x?xf32>
  %t = tensor.splat %f[%m] : tensor<4x?xf32>
  return %t : tensor<4x?xf32>
}
2172
2173// -----
2174
2175// CHECK-LABEL: func @cast_extract_slice
func.func @cast_extract_slice(%arg0 : tensor<128x512xf32>, %s : index, %o : index)
    -> tensor<16x512xf32> {
// The static size from the trailing cast is folded back into the
// extract_slice's size operands, eliminating the cast.
// CHECK: %[[E:.*]] = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [16, 512] [1, 1] : tensor<128x512xf32> to tensor<16x512xf32>
  %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] : tensor<128x512xf32> to tensor<?x512xf32>
  %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
// CHECK: return %[[E]] : tensor<16x512xf32>
  return %1 : tensor<16x512xf32>
}
2184
2185// -----
2186
2187// CHECK-LABEL: func @cast_extract_slice_rank_reduce
func.func @cast_extract_slice_rank_reduce(%arg0 : tensor<128x512xf32>, %s : index, %o : index)
    -> tensor<16xf32> {
// Same cast folding as above, but with a rank-reducing extract_slice
// (the unit dim is dropped from the result type).
// CHECK: %[[E:.*]]  = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [16, 1] [1, 1] : tensor<128x512xf32> to tensor<16xf32>
  %0 = tensor.extract_slice %arg0[%o, 0] [%s, 1] [1, 1] : tensor<128x512xf32> to tensor<?xf32>
  %1 = tensor.cast %0 : tensor<?xf32> to tensor<16xf32>
// CHECK: return %[[E]] : tensor<16xf32>
  return %1 : tensor<16xf32>
}
2196
2197// -----
2198
2199// CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices(
2200//  CHECK-SAME:     %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
2201//  CHECK-SAME:     %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>,
2202//  CHECK-SAME:     %[[num_threads:[0-9a-z]*]]: index
func.func @canonicalize_parallel_insert_slice_indices(
    %arg0 : tensor<1x5xf32>, %arg1: tensor<?x?xf32>,
    %num_threads : index) -> tensor<?x?xf32>
{
  %cst = arith.constant 4.200000e+01 : f32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  // Constant offsets/sizes/strides are folded into the static operand
  // segments of parallel_insert_slice, and the now-redundant cast of the
  // source disappears.
  //  CHECK-NOT: tensor.cast
  //      CHECK: scf.forall (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) shared_outs(%[[o:.*]] = %[[arg1]]) -> (tensor<?x?xf32>) {
  // CHECK-NEXT:   scf.forall.in_parallel {
  // CHECK-NEXT:     tensor.parallel_insert_slice %[[arg0]] into %[[o]][%[[tidx]], 0] [1, 5] [1, 1]
  %2 = scf.forall (%tidx) in (%num_threads) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
    %3 = tensor.cast %arg0 : tensor<1x5xf32> to tensor<?x5xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %3 into %o[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x5xf32> into tensor<?x?xf32>
    }
  }
  return %2 : tensor<?x?xf32>
}
2223
2224// -----
2225
2226// CHECK-LABEL: func.func @fold_insert_slice_after_extract_slice
2227//  CHECK-SAME: (%[[INPUT:.+]]: tensor<1x2x2x4xf32>)
func.func @fold_insert_slice_after_extract_slice(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {
  // Re-inserting a slice into the same tensor, at the same position and with
  // the same sizes/strides it was extracted with, is a no-op.
  %c0 = arith.constant 0 : index
  %0 = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
  %1 = tensor.insert_slice %0 into %input[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
  // CHECK: return %[[INPUT]]
  return %1: tensor<1x2x2x4xf32>
}
2235
2236// -----
2237
2238// CHECK-LABEL: func.func @dont_fold_mismatched_source_dst
func.func @dont_fold_mismatched_source_dst(%input0: tensor<1x2x2x4xf32>, %input1: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {
  // The slice is extracted from %input0 but inserted into %input1, so the
  // extract/insert pair must not be folded away.
  %c0 = arith.constant 0 : index
  // CHECK: tensor.extract_slice
  %0 = tensor.extract_slice %input0[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
  // CHECK: tensor.insert_slice
  %1 = tensor.insert_slice %0 into %input1[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
  return %1: tensor<1x2x2x4xf32>
}
2247
2248// -----
2249
2250// CHECK-LABEL: func.func @dont_fold_mismatched_parameters
func.func @dont_fold_mismatched_parameters(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {
  // The insert offset (dim 1 is 1) differs from the extract offset (0),
  // so the extract/insert pair must not be folded away.
  %c0 = arith.constant 0 : index
  // CHECK: tensor.extract_slice
  %0 = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
  // CHECK: tensor.insert_slice
  %1 = tensor.insert_slice %0 into %input[%c0, 1, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
  return %1: tensor<1x2x2x4xf32>
}
2259
2260// -----
2261
func.func @empty_canonicalize() -> (tensor<4x5x?xf32>) {
  // A constant dynamic size folds into the empty's type; a cast restores
  // the original dynamic result type.
  %c6 = arith.constant 6 : index
  %0 = tensor.empty(%c6) : tensor<4x5x?xf32>
  return %0 : tensor<4x5x?xf32>
}
2267// CHECK: func @empty_canonicalize
2268// CHECK:   %[[T0:.+]] = tensor.empty() : tensor<4x5x6xf32>
2269// CHECK:   %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
2270// CHECK:   return %[[T1]]
2271
2272// -----
2273
func.func @fold_empty_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> {
  // A cast to a more static type folds into the empty itself, dropping the
  // unused dynamic size operand.
  %0 = tensor.empty(%arg0) : tensor<?x12xf32>
  %1 = tensor.cast %0 : tensor<?x12xf32> to tensor<1x12xf32>
  return %1 : tensor<1x12xf32>
}
2279//      CHECK: func @fold_empty_tensor_with_cast(%[[ARG0:.+]]: index)
2280//      CHECK:   %[[T0:.+]] = tensor.empty() : tensor<1x12xf32>
2281//      CHECK:   return %[[T0]] : tensor<1x12xf32>
2282
2283// -----
2284
2285func.func private @some_use(%i : index, %j : index)
2286
2287// CHECK-LABEL: func @empty_tensor_canonicalize
2288//  CHECK-SAME:   %[[I:.*]]: index
func.func @empty_tensor_canonicalize(%i : index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  // The dims of an empty tensor fold to its size operand (%i) and its
  // static extent (42); the empty itself then becomes dead.
  // CHECK-NOT: tensor.empty
  %0 = tensor.empty(%i) : tensor<?x42xf32>

  // CHECK-NOT: tensor.dim
  %1 = tensor.dim %0, %c0: tensor<?x42xf32>
  %2 = tensor.dim %0, %c1: tensor<?x42xf32>

  // CHECK: %[[c42:.*]] = arith.constant 42 : index
  // CHECK: call @some_use(%[[I]], %[[c42]])
  call @some_use(%1, %2) : (index, index) -> ()

  return
}
2306
2307// -----
2308
2309//       CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
2310// CHECK-LABEL: func @dim_of_expand_shape(
2311//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
2312//       CHECK:   %[[c1:.*]] = arith.constant 1 : index
2313//       CHECK:   %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
2314//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
2315//       CHECK:   return %[[apply]]
func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
  // dim #2 of the expanded shape is recovered from the source's dim #1
  // divided by the product of the group's static extents (5 * 8 = 40).
  %c2 = arith.constant 2 : index
  %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]
      : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
  %1 = tensor.dim %0, %c2 : tensor<?x1x?x5x1x8xf32>
  return %1 : index
}
2323
2324// -----
2325
2326//       CHECK: #[[$map:.*]] = affine_map<()[s0, s1, s2] -> (((s0 * s1) * s2) * 7)>
2327// CHECK-LABEL: func @dim_of_collapse_shape(
2328//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x7x?xf32>
2329//   CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
2330//   CHECK-DAG:   %[[c2:.*]] = arith.constant 2 : index
2331//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
2332//   CHECK-DAG:   %[[dim1:.*]] = tensor.dim %[[t]], %[[c1]]
2333//   CHECK-DAG:   %[[dim2:.*]] = tensor.dim %[[t]], %[[c2]]
2334//   CHECK-DAG:   %[[dim4:.*]] = tensor.dim %[[t]], %[[c4]]
2335//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim1]], %[[dim2]], %[[dim4]]]
2336//       CHECK:   return %[[apply]]
func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
  // dim #1 of the collapsed shape is the product of the source dims in the
  // group [1, 2, 3, 4]; the static extent 7 is folded into the affine map.
  %c1 = arith.constant 1 : index
  %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]]
      : tensor<?x?x?x7x?xf32> into tensor<?x?xf32>
  %1 = tensor.dim %0, %c1 : tensor<?x?xf32>
  return %1 : index
}
2344
2345// -----
2346
// Can't fold when the dim index is out of bounds.
2348// CHECK-LABEL: func @out_of_bound_dim_of_collapse_shape(
2349//       CHECK:   %[[DIM:.*]] = tensor.dim
2350//       CHECK:   return %[[DIM]]
func.func @out_of_bound_dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
  // Index 5 exceeds the rank-2 collapsed result, so the dim op is kept as-is.
  %c5 = arith.constant 5 : index
  %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]]
      : tensor<?x?x?x7x?xf32> into tensor<?x?xf32>
  %1 = tensor.dim %0, %c5 : tensor<?x?xf32>
  return %1 : index
}
2358
2359// -----
2360
2361// CHECK-LABEL: func @collapse_expand_fold_to_cast(
2362//  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>
2363//       CHECK:   return %[[t]]
func.func @collapse_expand_fold_to_cast(%t: tensor<?xf32>, %sz0: index) -> (tensor<?xf32>)
{
  // Expand followed by the inverse collapse is the identity on %t.
  %0 = tensor.expand_shape %t [[0, 1]] output_shape [1, %sz0] : tensor<?xf32> into tensor<1x?xf32>
  %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
  return %1 : tensor<?xf32>
}
2370
2371// -----
2372
2373// Chain: NC -> NCnc -> NCnc -> NC
2374// CHECK: func.func @unpack_pack(
2375// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>)
2376// CHECK: return %[[T]] : tensor<128x128xf32>
func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
  // pack followed by a matching unpack (same dims/tiles) folds to %t.
  %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
  %packed = tensor.pack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
  %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
  return %unpacked : tensor<128x128xf32>
}
2384
2385// -----
2386
2387// Chain: NC -> NCcn -> NCnc -> NC
2388// CHECK: func.func @unpack_pack(
2389// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>)
2390// CHECK-NOT: return %[[T]] : tensor<128x128xf32>
func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
  // inner_dims_pos differs between pack ([1, 0]) and unpack ([0, 1]),
  // so the pair must not fold.
  %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
  %packed = tensor.pack %t inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
  %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor
<128x128xf32>
  return %unpacked : tensor<128x128xf32>
}
2399
2400// -----
2401
2402// Chain: NC -> CNcn -> NCnc -> NC
2403// CHECK: func.func @unpack_pack(
2404// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>)
2405// CHECK-NOT: return %[[T]] : tensor<128x128xf32>
func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
  // The pack permutes outer dims and inner dims while the unpack does not,
  // so the pair must not fold.
  %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
  %packed = tensor.pack %t outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
  %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor
<128x128xf32>
  return %unpacked : tensor<128x128xf32>
}
2414
2415// -----
2416
2417// Chain: NC -> NCnc -> NCnc -> NC
2418// CHECK: func.func @unpack_pack(
2419// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>,
2420// CHECK: return %[[T]] : tensor<128x128xf32>
func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) -> tensor<128x128xf32> {
  // Folding also applies when the (matching) inner tiles are dynamic values.
  %tensor_empty = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
  %packed = tensor.pack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
  %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<16x16x?x?xf32> -> tensor
<128x128xf32>
  return %unpacked : tensor<128x128xf32>
}
2429
2430// -----
2431
2432// CHECK: func.func @unpack_pack_with_padding_no_canonicalization(
2433// CHECK:         tensor.pack
2434// CHECK:         tensor.unpack
func.func @unpack_pack_with_padding_no_canonicalization(%t: tensor<256x512xbf16>) -> tensor<224x512xbf16> {
  // The unpack result (224x512) is smaller than the pack source (256x512),
  // i.e. the round trip drops padded rows, so no folding may happen.
  %tensor_empty = tensor.empty() : tensor<4x16x64x32xbf16>
  %tensor_empty1 = tensor.empty() : tensor<224x512xbf16>
  %packed = tensor.pack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty : tensor<256x512xbf16> -> tensor<4x16x64x32xbf16>
  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty1 : tensor<4x16x64x32xbf16> -> tensor<224x512xbf16>
  return %unpacked : tensor<224x512xbf16>
}
2442
2443// -----
2444
2445// Chain NCnc -> NC -> NC -> NCnc
2446// CHECK: func.func @pack_unpack(
2447// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
2448// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
func.func @pack_unpack(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
  // unpack followed by a matching pack (same dims/tiles) folds to %t.
  %tensor_empty = tensor.empty() : tensor<128x128xf32>
  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
  %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
  return %packed : tensor<16x16x?x?xf32>
}
2456
2457// -----
2458
2459// Chain NCnc -> NC -> NC -> NCnc
2460// CHECK: func.func @pack_unpack(
2461// CHECK-SAME: %[[T:.+]]: tensor<16x16x8x8xf32>
2462// CHECK: return %[[T]] : tensor<16x16x8x8xf32>
func.func @pack_unpack(%t: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
  // Same unpack/pack round trip as above, with fully static tiles.
  %tensor_empty = tensor.empty() : tensor<128x128xf32>
  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
  %tensor_empty1 = tensor.empty() : tensor<16x16x8x8xf32>
  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
  return %packed : tensor<16x16x8x8xf32>
}
2470
2471// -----
2472
2473// CHECK: func.func @pack_unpack_same_tiles(
2474// CHECK-SAME:  %[[T:.+]]: tensor<?x?x?x?xf32>,
2475// CHECK: return %[[T]] : tensor<?x?x?x?xf32>
func.func @pack_unpack_same_tiles(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
                       %tile1: index, %tile2: index) -> tensor<?x?x?x?xf32> {
  // Fully dynamic shapes still fold when the tile operands match exactly.
  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
  %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
  return %packed : tensor<?x?x?x?xf32>
}
2484
2485// -----
2486
2487// CHECK: func.func @pack_unpack_different_tiles(
2488// CHECK-SAME:  %[[T:.+]]: tensor<?x?x?x?xf32>,
2489// CHECK-NOT: return %[[T]] : tensor<?x?x?x?xf32>
func.func @pack_unpack_different_tiles(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
                       %tile1: index, %tile2: index) -> tensor<?x?x?x?xf32> {
  // The pack's tiles are swapped (%tile2, %tile1) relative to the unpack,
  // so the pair must not fold.
  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
  %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile2, %tile1] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
  return %packed : tensor<?x?x?x?xf32>
}
2498
2499// -----
2500
2501// CHECK: func.func @pack_unpack_dynamic_with_padding(
2502// CHECK-SAME:  %[[T:.+]]: tensor<?x?x?x?xf32>,
2503// CHECK-NOT: return %[[T]] : tensor<?x?x?x?xf32>
func.func @pack_unpack_dynamic_with_padding(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
                       %tile1: index, %tile2: index, %pad: f32) -> tensor<?x?x?x?xf32> {
  // A padding_value on the pack with dynamic shapes may change contents,
  // so the unpack/pack pair must not fold.
  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
  %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
  %packed = tensor.pack %unpacked padding_value(%pad: f32) inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
  return %packed : tensor<?x?x?x?xf32>
}
2512
2513// -----
2514
2515// CHECK: func.func @pack_outer_dims_unpack_no_outer_dims(
2516// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
2517// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
func.func @pack_outer_dims_unpack_no_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
  // An explicit identity outer_dims_perm = [0, 1] is equivalent to having
  // no permutation, so the unpack/pack pair still folds.
  %tensor_empty = tensor.empty() : tensor<128x128xf32>
  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
  %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
  %packed = tensor.pack %unpacked outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
  return %packed : tensor<16x16x?x?xf32>
}
2525
2526// -----
2527
2528// CHECK: func.func @pack_no_outer_dims_unpack_outer_dims(
2529// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
2530// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
func.func @pack_no_outer_dims_unpack_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
  // Mirror of the previous test: the identity outer_dims_perm sits on the
  // unpack instead of the pack; the pair still folds.
  %tensor_empty = tensor.empty() : tensor<128x128xf32>
  %unpacked = tensor.unpack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
  %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
  return %packed : tensor<16x16x?x?xf32>
}
2538
2539// -----
2540
2541// CHECK: func.func @invalid_empty_negative_size
2542// CHECK: %[[IDX:.*]] = index.constant
2543// CHECK: %[[T:.*]] = tensor.empty(%[[IDX]]) : tensor<4x5x?xf32>
func.func @invalid_empty_negative_size() -> (tensor<4x5x?xf32>) {
  // index.sub yields 1 - 2 = -1; folding must not bake a negative extent
  // into the empty's static type, so the dynamic size operand is kept.
  %c1 = arith.constant 1 : index
  %cn2 = arith.constant 2 : index
  %0 = index.sub %c1, %cn2
  %1 = tensor.empty(%0) : tensor<4x5x?xf32>
  return %1 : tensor<4x5x?xf32>
}
2551
2552// -----
2553
2554// Fold DstStyleOp -> tensor.unpack operations.
func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init : tensor<?x?xf32>) -> tensor<?x?xf32> {
  // The unpack fully overwrites its destination, so the fill feeding it is
  // dead and the unpack can use %init directly.
  %cst = arith.constant 0.0 : f32
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
  %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %fill : tensor<?x?x16x64xf32> -> tensor<?x?xf32>
  return %unpack : tensor<?x?xf32>
}
2561// CHECK-LABEL: func @fold_dst_style_ops_into_unpack
2562//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x16x64xf32>
2563//  CHECK-SAME:     %[[INIT:.+]]: tensor<?x?xf32>
2564//       CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
2565//  CHECK-SAME:       into %[[INIT]]
2566//       CHECK:   return %[[UNPACK]]
2567
2568// -----
2569
// The IR in this test case is invalid. This test checks that the canonicalizer
// does not crash.
2572
2573// CHECK-LABEL: func @invalid_slice_ops(
2574//       CHECK:   %[[c:.*]] = arith.constant -5 : index
2575//       CHECK:   tensor.extract_slice {{.*}}%[[c]]
2576//       CHECK:   tensor.insert_slice {{.*}}%[[c]]
func.func @invalid_slice_ops(%t: tensor<?xf32>, %t2: tensor<?xf32>) -> tensor<?xf32> {
  // The negative slice size is invalid IR; the canonicalizer must leave it
  // alone rather than crash or fold it into a static type.
  %c = arith.constant -5 : index
  %0 = tensor.extract_slice %t[0][%c][1] : tensor<?xf32> to tensor<?xf32>
  %1 = tensor.insert_slice %0 into %t2[2][%c][1] : tensor<?xf32> into tensor<?xf32>
  return %1 : tensor<?xf32>
}
2583
2584// -----
2585
2586// CHECK-LABEL: func @generate_negative_size_verifies(
2587//       CHECK:   %[[c:.*]] = arith.constant -8 : index
2588//       CHECK:   tensor.generate %[[c]]
2589//       CHECK:   : tensor<?x8xi32>
func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
  // affine.max folds to -8; the generate must keep its dynamic extent
  // instead of folding a negative size into the result type.
  %cst = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %size = affine.max affine_map<(d0) -> (d0 mod 64 - 8)>(%c0)
  %tensor = tensor.generate %size {
  ^bb0(%arg0: index, %arg1: index):
    tensor.yield %cst : i32
  } : tensor<?x8xi32>
  return %tensor : tensor<?x8xi32>
}
2600
2601// -----
2602
func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> tensor<10x20x4x4xf32> {
  // The constant dims and intermediate cast are resolved first; the
  // remaining unpack/pack pair with identical tiles then folds to %t.
  %dim1 = arith.constant 40 : index
  %dim2 = arith.constant 80 : index
  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty : tensor<10x20x4x4xf32> -> tensor<?x?xf32>
  %cast = tensor.cast %unpacked : tensor<?x?xf32> to tensor<40x80xf32>
  %tensor_empty1 = tensor.empty() : tensor<10x20x4x4xf32>
  %packed = tensor.pack %cast inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty1 : tensor<40x80xf32> -> tensor<10x20x4x4xf32>
  return %packed : tensor<10x20x4x4xf32>
}
2613// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
2614// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
2615// CHECK:         return %[[SRC]]
2616
2617// -----
2618
2619// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
2620// CHECK-LABEL: func @dim_of_reshape(
2621//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
2622//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: tensor<?xindex>
2623//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
2624//  CHECK-NEXT:   %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
2625//   CHECK-NOT:   tensor.store
2626//   CHECK-NOT:   tensor.dim
2627//   CHECK-NOT: tensor.reshape
2628//       CHECK:   return %[[DIM]] : index
func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
    -> index {
  // dim(reshape(%v, %shp), i) folds to extract %shp[i]; the extract must
  // be placed at the reshape, before the insert below mutates the view.
  %c3 = arith.constant 3 : index
  %0 = tensor.reshape %arg0(%arg1)
      : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  // Update the shape to test that the load ends up in the right place.
  tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
  %1 = tensor.dim %0, %c3 : tensor<*xf32>
  return %1 : index
}
2639
2640// -----
2641
2642// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
2643// CHECK-LABEL: func @dim_of_reshape_i32(
2644//       CHECK:  tensor.extract
2645//  CHECK-NEXT:  %[[CAST:.*]] = arith.index_cast
2646//   CHECK-NOT:  tensor.dim
2647//   CHECK-NOT:  tensor.reshape
2648//       CHECK:  return %[[CAST]] : index
func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
    -> index {
    // An i32 shape tensor additionally needs an index_cast on the
    // extracted element.
    %c3 = arith.constant 3 : index
    %0 = tensor.reshape %arg0(%arg1)
        : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
    %1 = tensor.dim %0, %c3 : tensor<*xf32>
    return %1 : index
}
2657
2658// -----
2659
2660// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
2661// CHECK-LABEL: func @dim_of_reshape_for(
2662//       CHECK: scf.for
2663//  CHECK-NEXT: tensor.extract
2664//   CHECK-NOT: tensor.dim
2665//   CHECK-NOT: tensor.reshape
func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c4 = arith.constant 4 : index

    %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>

    // The fold also applies when the dim index is a loop induction variable;
    // the extract is materialized inside the loop body.
    %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
      %2 = tensor.dim %0, %arg2 : tensor<*xf32>
      %3 = arith.muli %arg3, %2 : index
      scf.yield %3 : index
    }
    return %1 : index
}
2680
2681// -----
2682
2683// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
2684// CHECK-LABEL: func @dim_of_reshape_undominated(
2685//       CHECK: arith.muli
2686//  CHECK-NEXT: tensor.extract
2687//   CHECK-NOT: tensor.dim
2688//   CHECK-NOT: tensor.reshape
func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
    // The index is computed after the reshape, so the replacement extract
    // must be inserted after %0 (at the dim), not at the reshape.
    %c4 = arith.constant 4 : index
    %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
    %0 = arith.muli %arg2, %c4 : index
    %dim = tensor.dim %reshape, %0 : tensor<*xf32>
    return %dim : index
  }
2696
2697// -----
2698
2699// CHECK-LABEL: @reshape_fold_2d
2700// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
func.func @reshape_fold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
  // Reshaping to a shape built from the source's own dims (in order) is
  // the identity, so the reshape folds away.
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
  %ds = tensor.from_elements %d0, %d1 : tensor<2xindex>
  %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
  // CHECK: return %[[ARG0]]
  return %reshape : tensor<?x?xi32>
}
2711
2712// -----
2713
2714// CHECK-LABEL: @reshape_nofold_2d
2715// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
func.func @reshape_nofold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
  // The requested shape swaps the source dims (%d1, %d0), so this is not
  // an identity reshape and must be kept.
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
  %ds = tensor.from_elements %d1, %d0 : tensor<2xindex>
  // CHECK: tensor.reshape
  %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
  return %reshape : tensor<?x?xi32>
}
2726
2727// -----
2728
2729// CHECK-LABEL: @reshape_nofold_2d_ins
func.func @reshape_nofold_2d_ins(%arg0 : tensor<?x?xi32>, %arg1: index, %arg2: index) -> tensor<?x?xi32> {
  // The shape operands are unrelated function arguments, not dims of the
  // source, so the reshape cannot be proven an identity.
  %ds = tensor.from_elements %arg1, %arg2 : tensor<2xindex>
  // CHECK: tensor.reshape
  %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
  return %reshape : tensor<?x?xi32>
}
2736
2737// -----
2738
2739// CHECK-LABEL: @reshape_fold_3d_cst
2740// CHECK-SAME: %[[ARG0:.+]]: tensor<5x?x?xi32>
func.func @reshape_fold_3d_cst(%arg0 : tensor<5x?x?xi32>) -> tensor<5x?x?xi32> {
  // The identity fold also matches when a leading extent is supplied as a
  // constant (5) equal to the source's static dim.
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %d0 = arith.constant 5 : index
  %d1 = tensor.dim %arg0, %c1 : tensor<5x?x?xi32>
  %d2 = tensor.dim %arg0, %c2 : tensor<5x?x?xi32>
  %ds = tensor.from_elements %d0, %d1, %d2 : tensor<3xindex>
  %reshape = tensor.reshape %arg0(%ds) : (tensor<5x?x?xi32>, tensor<3xindex>) -> tensor<5x?x?xi32>
  // CHECK: return %[[ARG0]]
  return %reshape : tensor<5x?x?xi32>
}
2752
2753// -----
2754
// Test case: This test fails to fold because the index of tensor.dim is out of bounds.
2756// CHECK-LABEL: func @dim_out_of_bounds(
2757//       CHECK: %[[IDX:.*]] = index.constant 28
2758//  CHECK-NEXT: bufferization.alloc_tensor
2759//  CHECK-NEXT: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[IDX]]
2760//  CHECK-NEXT: memref.alloc
2761//  CHECK-NEXT: memref.cast
2762//  CHECK-NEXT: affine.vector_load %{{.*}}[{{.*}}, {{.*}}, symbol(%[[DIM]])]
2763//  CHECK-NEXT: return
func.func @dim_out_of_bounds() -> vector<7xi32> {
    // Index 28 exceeds the rank-1 alloc_tensor, so the dim must not fold
    // (and the canonicalizer must not crash on it).
    %c1 = arith.constant 1 : index
    %idx28 = index.constant 28
    %c29 = arith.constant 29 : index
    %3 = bufferization.alloc_tensor(%c29) : tensor<?xi16>
    %dim = tensor.dim %3, %idx28 : tensor<?xi16>
    %alloc_21 = memref.alloc(%c29) : memref<?x26x2xi32>
    %16 = affine.vector_load %alloc_21[%c1, %c1, %dim] : memref<?x26x2xi32>, vector<7xi32>
    return %16 : vector<7xi32>
}
2774
2775// -----
2776
2777// CHECK-LABEL:   func.func @fold_cast_multiple_results(
2778// CHECK-SAME:         %[[ARG1:.*]]: tensor<2x2xf32>,
2779// CHECK-SAME:         %[[ARG2:.*]]: tensor<2x2xf32>) -> index {
2780// CHECK:           %[[RES:.*]]:2 = test.destination_style_op ins(%[[ARG1]] : tensor<2x2xf32>)
2781// CHECK-SAME:      outs(%[[ARG2]] : tensor<2x2xf32>) -> tensor<2x2xf32>, index
2782// CHECK:           return %[[RES]]#1 : index
func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
  // Casts feeding a destination-style op fold into its operands even when
  // the op has multiple (mixed tensor/non-tensor) results.
  %cast = tensor.cast %arg0 : tensor<2x2xf32> to tensor<?x2xf32>
  %cast_0 = tensor.cast %arg1 : tensor<2x2xf32> to tensor<?x2xf32>
  %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
  return %0#1 : index
}
2789
2790// -----
2791
// Casting the destination of a tensor.pack to a more dynamic shape and then
// casting the result back should fold away: the pack is rewritten to use the
// static destination directly, the constant inner tile size %c8 becomes the
// static tile 8, and the discardable {test_attr} attribute is preserved.
// CHECK-LABEL:   func.func @fold_cast_pack_dynamic_tile_size
// CHECK-SAME:      %[[DEST:.*]]: tensor<1x1x8x1xi32>,
// CHECK-SAME:      %[[SRC:.*]]: tensor<7x?xi32>,
// CHECK-SAME:      %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
// CHECK:           %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32)
// CHECK-SAME:        inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]]
// CHECK-SAME:        test_attr
// CHECK-SAME:        : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
// CHECK:           return %[[PACK]] : tensor<1x1x8x1xi32>
func.func @fold_cast_pack_dynamic_tile_size(
  %dest: tensor<1x1x8x1xi32>,
  %src: tensor<7x?xi32>,
  %pad: i32) -> tensor<1x1x8x1xi32> {

    %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
    %c8 = arith.constant 8 : index
    %pack = tensor.pack %src padding_value(%pad : i32)
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8, 1]
      into %cast {test_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
    %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
    return %res : tensor<1x1x8x1xi32>
}
2815
2816// -----
2817
// Mirror of the pack test above for tensor.unpack: the tensor.cast on the
// source folds into the unpack, the constant tile size %c8 becomes the static
// tile 8, and the discardable {test_attr} attribute survives.
// CHECK-LABEL:   func.func @fold_cast_unpack_dynamic_tile_size(
// CHECK-SAME:      %[[SRC:.*]]: tensor<1x1x8x1xi32>,
// CHECK-SAME:      %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
// CHECK:           %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {test_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
// CHECK:           return %[[RES]] : tensor<7x?xi32>
func.func @fold_cast_unpack_dynamic_tile_size(
  %src: tensor<1x1x8x1xi32>,
  %res: tensor<7x?xi32>) -> tensor<7x?xi32> {

    %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
    %c8 = arith.constant 8 : index
    %unpack = tensor.unpack %cast
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8, 1]
      into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
    return %unpack : tensor<7x?xi32>
}
2835
2836// -----
2837
// Regression test: canonicalization of tensor.pack must not drop discardable
// attributes ({test_attr}) attached to the op.
// CHECK-LABEL:   func.func @pack_dont_drop_attributes(
// CHECK: tensor.pack {{.*}}  {test_attr}
func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
  %c32_i64 = arith.constant 32 : i64
  %cst = arith.constant 0.000000e+00 : f16
  %pack = tensor.pack %arg0 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %arg1 {test_attr} : tensor<?x?x?xf16> -> tensor<128x?x100x16x1xf16>
  return %pack : tensor<128x?x100x16x1xf16>
}
2846
2847// -----
2848
// A cast-to-dynamic feeding an expand_shape that is then cast back to a fully
// static type should collapse into a single static expand_shape: the constant
// output_shape operands (%c10, %c1, %c10) become the static sizes
// [10, 1, 10] and both casts disappear.
func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
    -> tensor<10x1x10xf32> {
  %c1 = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  %2 = tensor.cast %1 : tensor<?x?x?xf32> to tensor<10x1x10xf32>
  return %2 : tensor<10x1x10xf32>
}
// CHECK-LABEL:  func.func @fold_expand_of_cast
//       CHECK:   %[[RES:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
//       CHECK:   return %[[RES]]
2862
2863// -----
2864
// When only the function result is dynamic, the cast is sunk below the
// expand_shape: the expand operates on the original tensor<?x10xf32> (so the
// static 10 appears as the last output size) and a new cast to the fully
// dynamic result type is inserted after it. The dynamic output_shape operands
// %c10 and %c1 remain SSA values in the rewritten op.
func.func @sink_expand_of_cast(%arg0 : tensor<?x10xf32>)
    -> tensor<?x?x?xf32> {
  %c1 = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %0 = tensor.cast %arg0 : tensor<?x10xf32> to tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}
// CHECK-LABEL:  func.func @sink_expand_of_cast
//   CHECK-DAG:   %[[C10:.*]] = arith.constant 10
//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1
//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
//  CHECK-SAME:     output_shape [%[[C10]], %[[C1]], 10]
//       CHECK:   %[[RES:.+]] = tensor.cast %[[EXPAND]]
//       CHECK:   return %[[RES]]
2881
2882// -----
2883
// Partial sinking: the first source dim expands into two sizes that are only
// known at runtime (%arg1, %arg2), so that dim must stay dynamic. The input
// cast is narrowed to tensor<?x10xf32> (keeping the second source dim static),
// the expand produces tensor<?x?x10xf32> with the static trailing 10, and a
// trailing cast restores the fully dynamic result type.
func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index, %arg2 : index)
    -> tensor<?x?x?xf32> {
  %c10 = arith.constant 10 : index
  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %c10]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}
// CHECK-LABEL:  func.func @partial_sink_expand_of_cast
//       CHECK:   %[[CAST:.+]] = tensor.cast
//  CHECK-SAME:     tensor<10x10xf32> to tensor<?x10xf32>
//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
//  CHECK-SAME:     output_shape [%{{.*}}, %{{.*}}, 10]
//       CHECK:   %[[RES:.+]] = tensor.cast %[[EXPAND]]
//  CHECK-SAME:     tensor<?x?x10xf32> to tensor<?x?x?xf32>
//       CHECK:   return %[[RES]]
2900