Lines Matching full:tensor

6 func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
7 %0 = tensor.expand_shape %arg0 [[0]] output_shape [5] : tensor<5xf32> into tensor<5xf32>
8 return %0 : tensor<5xf32>
15 func.func @expand_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
16 %0 = tensor.expand_shape %arg0 [] output_shape [] : tensor<f32> into tensor<f32>
17 return %0 : tensor<f32>
24 func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf32> {
25 %0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<5x4xf32> into tensor<5x4xf32>
26 return %0 : tensor<5x4xf32>
33 func.func @collapse_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
34 %0 = tensor.collapse_shape %arg0 [] : tensor<f32> into tensor<f32>
35 return %0 : tensor<f32>
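// A minimal sketch of the canonical form the identity-fold tests above are
// expected to produce: a reshape whose source and result types already match
// folds to its operand (hypothetical function name, shown for illustration).
func.func @expand_shape_identity_fold_sketch(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
  // No expand_shape/collapse_shape remains after canonicalization.
  return %arg0 : tensor<5xf32>
}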
41 // CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
42 func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {
43 // CHECK-NEXT: %[[RES:.*]] = tensor.bitcast %[[IN]] : tensor<2xi32> to tensor<2xf32>
44 %0 = tensor.bitcast %input : tensor<2xi32> to tensor<2xui32>
45 %1 = tensor.bitcast %0 : tensor<2xui32> to tensor<2xf32>
47 return %1 : tensor<2xf32>
53 // CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
54 func.func @tensor_bitcast_chain_nop(%input: tensor<4xi32>) -> tensor<4xi32> {
55 %0 = tensor.bitcast %input : tensor<4xi32> to tensor<4xui32>
56 %1 = tensor.bitcast %0 : tensor<4xui32> to tensor<4xi32>
58 return %1 : tensor<4xi32>
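// A minimal sketch of the expected fold for the bitcast chain above: chained
// tensor.bitcast ops compose, and a chain that returns to the original
// element type is a no-op (hypothetical function name).
func.func @tensor_bitcast_chain_nop_sketch(%input: tensor<4xi32>) -> tensor<4xi32> {
  // i32 -> ui32 -> i32 cancels out entirely.
  return %input : tensor<4xi32>
}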
65 func.func @cast_values(%arg0: tensor<*xi32>) -> tensor<2xi32> {
67 %0 = tensor.cast %arg0 : tensor<*xi32> to tensor<*xi32>
68 // CHECK-NEXT: %[[RET:.*]] = tensor.cast %arg0 : tensor<*xi32> to tensor<2xi32>
69 %2 = tensor.cast %0 : tensor<*xi32> to tensor<2xi32>
71 %4 = tensor.cast %2 : tensor<2xi32> to tensor<2xi32>
72 // CHECK-NEXT: return %[[RET]] : tensor<2xi32>
73 return %4 : tensor<2xi32>
78 // CHECK-LABEL: @tensor.cast_chain_ok
79 // CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
80 func.func @tensor.cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
81 // CHECK-NEXT: %[[RES:.*]] = tensor.cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
82 %0 = tensor.cast %input : tensor<*xi32> to tensor<4x?xi32>
83 %1 = tensor.cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
85 return %1 : tensor<4x8xi32>
90 // CHECK-LABEL: @tensor.cast_chain_regain
91 // CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
92 func.func @tensor.cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
93 %0 = tensor.cast %input : tensor<4xi32> to tensor<?xi32>
94 %1 = tensor.cast %0 : tensor<?xi32> to tensor<4xi32>
96 return %1 : tensor<4xi32>
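// A minimal sketch of the expected fold above: casting tensor<4xi32> to a
// dynamic shape and back regains the original static type, so the chain is
// replaced by the block argument (hypothetical function name).
func.func @cast_chain_regain_sketch(%input: tensor<4xi32>) -> tensor<4xi32> {
  return %input : tensor<4xi32>
}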
101 // CHECK-LABEL: @tensor.cast_chain_keep
102 // CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
103 func.func @tensor.cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
104 // CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]]
105 %0 = tensor.cast %input : tensor<?x?xi32> to tensor<4x?xi32>
106 // CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]]
107 %1 = tensor.cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
109 return %1 : tensor<?x8xi32>
114 // CHECK-LABEL: @tensor.cast_chain_invalid
115 // CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
116 func.func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
117 // CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]]
118 %0 = tensor.cast %input : tensor<4x8xi32> to tensor<?x?xi32>
119 // CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]]
120 %1 = tensor.cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
122 return %1 : tensor<8x4xi32>
128 // CHECK-SAME: %[[ARG0:.*]]: tensor<1x2x?xi32>
129 func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1x2x?xi32>) {
130 %0 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x3xi32>
131 // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<1x2x?xi32> to tensor<1x2x3xi32>
132 %1 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x?xi32>
133 // CHECK-NEXT: return %[[CAST]], %[[ARG0]] : tensor<1x2x3xi32>, tensor<1x2x?xi32>
134 return %0, %1 : tensor<1x2x3xi32>, tensor<1x2x?xi32>
150 %0 = arith.constant dense<4.0> : tensor<4xf32>
151 %ext_1 = tensor.extract %0[%arg0] : tensor<4xf32>
154 %1 = arith.constant sparse<[[0, 0, 0], [1, 1, 1]], [-5.0, -2.0]> : tensor<4x4x4xf16>
155 %ext_2 = tensor.extract %1[%const_1, %const_1, %const_1] : tensor<4x4x4xf16>
158 %2 = arith.constant sparse<[[1, 1, 1]], [-2.0]> : tensor<2x2x2xf16>
159 %ext_3 = tensor.extract %2[%const_0, %const_0, %const_0] : tensor<2x2x2xf16>
161 // Fold an extract into a dense tensor.
162 %3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
163 %ext_4 = tensor.extract %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>
167 %4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
168 %ext_5 = tensor.extract %4[] : tensor<complex<f32>>
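// A minimal sketch of the extract-of-constant fold exercised above, assuming
// the usual tensor.extract folder: extracting from a splat constant yields the
// splat value as a scalar constant, independent of the index value
// (hypothetical function name).
func.func @extract_splat_sketch() -> f32 {
  %cst = arith.constant 4.000000e+00 : f32
  return %cst : f32
}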
180 // CHECK: %[[EXT:.+]] = tensor.extract
183 %cst = arith.constant dense_resource<__elided__> : tensor<1xi64>
184 %extracted = tensor.extract %cst[%c0] : tensor<1xi64>
191 func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
193 // CHECK-DAG: %[[C4:.+]] = arith.constant dense<4.{{0*}}e+00> : tensor<4xf32>
194 %0 = arith.constant dense<4.0> : tensor<4xf32>
196 %ins_1 = tensor.insert %1 into %0[%arg0] : tensor<4xf32>
198 return %ins_1 : tensor<4xf32>
204 // CHECK-SAME: %[[TENSOR:.*]]: tensor<9xf32>
205 func.func @extract_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 {
208 // CHECK-NOT: tensor.cast
209 %casted = tensor.cast %tensor : tensor<9xf32> to tensor<?xf32>
210 // CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]]
211 %result = tensor.extract %casted[%c0] : tensor<?xf32>
221 %tensor = tensor.from_elements %element : tensor<1xindex>
222 %extracted_element = tensor.extract %tensor[%c0] : tensor<1xindex>
233 %tensor = tensor.from_elements %element : tensor<index>
234 %extracted_element = tensor.extract %tensor[] : tensor<index>
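// A minimal sketch of the fold exercised above: extracting the only element of
// a tensor.from_elements result forwards the original scalar, so no tensor ops
// are expected to remain (hypothetical function name).
func.func @extract_from_from_elements_sketch(%element : index) -> index {
  return %element : index
}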
257 %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
258 : tensor<3x2x2xf32>
263 %r0 = tensor.extract %tensor[%c0, %c0, %c0] : tensor<3x2x2xf32>
264 %r1 = tensor.extract %tensor[%c0, %c0, %c1] : tensor<3x2x2xf32>
265 %r2 = tensor.extract %tensor[%c0, %c1, %c0] : tensor<3x2x2xf32>
266 %r3 = tensor.extract %tensor[%c0, %c1, %c1] : tensor<3x2x2xf32>
267 %r4 = tensor.extract %tensor[%c1, %c0, %c0] : tensor<3x2x2xf32>
268 %r5 = tensor.extract %tensor[%c1, %c0, %c1] : tensor<3x2x2xf32>
269 %r6 = tensor.extract %tensor[%c1, %c1, %c0] : tensor<3x2x2xf32>
270 %r7 = tensor.extract %tensor[%c1, %c1, %c1] : tensor<3x2x2xf32>
271 %r8 = tensor.extract %tensor[%c2, %c0, %c0] : tensor<3x2x2xf32>
272 %r9 = tensor.extract %tensor[%c2, %c0, %c1] : tensor<3x2x2xf32>
273 %r10 = tensor.extract %tensor[%c2, %c1, %c0] : tensor<3x2x2xf32>
274 %r11 = tensor.extract %tensor[%c2, %c1, %c1] : tensor<3x2x2xf32>
314 %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
315 : tensor<3x2x2xf32>
320 %r0 = tensor.extract %tensor[%c0, %c0, %c0] : tensor<3x2x2xf32>
321 %r1 = tensor.extract %tensor[%c0, %c0, %c1] : tensor<3x2x2xf32>
322 %r2 = tensor.extract %tensor[%c0, %c1, %c0] : tensor<3x2x2xf32>
323 %r3 = tensor.extract %tensor[%c0, %c1, %c1] : tensor<3x2x2xf32>
324 %r4 = tensor.extract %tensor[%c1, %c0, %c0] : tensor<3x2x2xf32>
325 %r5 = tensor.extract %tensor[%c1, %c0, %c1] : tensor<3x2x2xf32>
326 %r6 = tensor.extract %tensor[%c1, %c1, %c0] : tensor<3x2x2xf32>
327 %r7 = tensor.extract %tensor[%c1, %c1, %c1] : tensor<3x2x2xf32>
328 %r8 = tensor.extract %tensor[%c2, %c0, %c0] : tensor<3x2x2xf32>
329 %r9 = tensor.extract %tensor[%c2, %c0, %c1] : tensor<3x2x2xf32>
330 %r10 = tensor.extract %tensor[%c2, %c1, %c0] : tensor<3x2x2xf32>
331 %r11 = tensor.extract %tensor[%c2, %c1, %c1] : tensor<3x2x2xf32>
340 // CHECK-LABEL: func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
341 // CHECK-NEXT: %cst = arith.constant dense<[(1,2), (3,2), (1,2)]> : tensor<3xcomplex<i32>>
342 // CHECK-NEXT: return %cst : tensor<3xcomplex<i32>>
343 func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
344 %c1 = arith.constant dense<(1, 2)> : tensor<complex<i32>>
345 %complex1 = tensor.extract %c1[] : tensor<complex<i32>>
346 %c2 = arith.constant dense<(3, 2)> : tensor<complex<i32>>
347 %complex2 = tensor.extract %c2[] : tensor<complex<i32>>
348 %tensor = tensor.from_elements %complex1, %complex2, %complex1 : tensor<3xcomplex<i32>>
349 return %tensor : tensor<3xcomplex<i32>>
354 // CHECK-LABEL: func.func @extract_from_elements_complex_f() -> tensor<3xcomplex<f32>> {
355 // CHECK-NEXT: %cst = arith.constant dense<[(1.200000e+00,2.300000e+00), (3.200000e+00,2.100000e+00), (1.200000e+00,2.300000e+00)]> : tensor<3xcomplex<f32>>
356 // CHECK-NEXT: return %cst : tensor<3xcomplex<f32>>
357 func.func @extract_from_elements_complex_f() -> tensor<3xcomplex<f32>> {
358 %c1 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
359 %complex1 = tensor.extract %c1[] : tensor<complex<f32>>
360 %c2 = arith.constant dense<(3.2, 2.1)> : tensor<complex<f32>>
361 %complex2 = tensor.extract %c2[] : tensor<complex<f32>>
362 %tensor = tensor.from_elements %complex1, %complex2, %complex1 : tensor<3xcomplex<f32>>
363 return %tensor : tensor<3xcomplex<f32>>
373 %tensor = tensor.from_elements %element : tensor<1xindex>
374 %extracted_element = tensor.extract %tensor[%c-1] : tensor<1xindex>
375 // CHECK: tensor.from_elements
376 // CHECK: %[[RESULT:.*]] = tensor.extract
388 %tensor = tensor.from_elements %element : tensor<1xindex>
389 %extracted_element = tensor.extract %tensor[%c1] : tensor<1xindex>
390 // CHECK: tensor.from_elements
391 // CHECK: %[[RESULT:.*]] = tensor.extract
403 %tensor = tensor.from_elements %element : tensor<1xindex>
404 %extracted_element = tensor.extract %tensor[%c2] : tensor<1xindex>
405 // CHECK: tensor.from_elements
406 // CHECK: %[[RESULT:.*]] = tensor.extract
414 // CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
415 func.func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index {
416 %size = tensor.rank %tensor : tensor<*xf32>
417 // CHECK-NEXT: %[[RES:.*]] = tensor.dim %[[TENSOR]], %[[IDX]]
418 %0 = tensor.generate %size {
420 %1 = tensor.dim %tensor, %arg0 : tensor<*xf32>
421 tensor.yield %1 : index
422 } : tensor<?xindex>
423 %1 = tensor.extract %0[%idx] : tensor<?xindex>
431 // CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
432 func.func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index {
433 %size = tensor.rank %tensor : tensor<*xf32>
434 // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[TENSOR]], %[[IDX0]]
435 // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[TENSOR]], %[[IDX1]]
437 %0 = tensor.generate %size, %size {
439 %1 = tensor.dim %tensor, %arg0 : tensor<*xf32>
440 %2 = tensor.dim %tensor, %arg1 : tensor<*xf32>
442 tensor.yield %3 : index
443 } : tensor<?x?xindex>
444 %4 = tensor.extract %0[%idx0, %idx1] : tensor<?x?xindex>
453 func.func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>, %mem: memref<?xindex>) -> index {
454 %size = tensor.rank %tensor : tensor<*xf32>
455 // CHECK: %[[DTENSOR:.*]] = tensor.generate
456 %0 = tensor.generate %size {
458 %1 = tensor.dim %tensor, %arg0 : tensor<*xf32>
460 tensor.yield %1 : index
461 } : tensor<?xindex>
462 // CHECK: %[[RES:.*]] = tensor.extract %[[DTENSOR]][%[[IDX]]]
463 %1 = tensor.extract %0[%idx] : tensor<?xindex>
472 func.func @static_tensor.generate(%size1: index, %size4: index) -> tensor<3x?x?x7x?xindex> {
474 // CHECK: tensor.generate %[[SIZE1]], %[[SIZE4]]
475 %0 = tensor.generate %size1, %c5, %size4 {
478 tensor.yield %1 : index
479 // CHECK: : tensor<3x?x5x7x?xindex>
480 } : tensor<3x?x?x7x?xindex>
481 // CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
482 return %0 : tensor<3x?x?x7x?xindex>
488 func.func @from_elements.constant() -> tensor<3xindex> {
489 // CHECK: %[[CST:.*]] = arith.constant dense<[1, 2, 1]> : tensor<3xindex>
493 %tensor = tensor.from_elements %c1, %c2, %c1 : tensor<3xindex>
494 return %tensor : tensor<3xindex>
499 func.func @slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
500 %arg2 : index) -> tensor<?x?x?xf32>
505 %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
506 return %0 : tensor<?x?x?xf32>
509 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
510 // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
512 // CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x1x?xf32>
513 // CHECK: %[[RESULT:.+]] = tensor.cast %[[SLICE]]
518 func.func @rank_reducing_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
519 %arg2 : index) -> tensor<?x?xf32>
524 %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
525 return %0 : tensor<?x?xf32>
528 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
529 // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
531 // CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x?xf32>
532 // CHECK: %[[RESULT:.+]] = tensor.cast %[[SLICE]]
538 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
539 // CHECK-NOT: tensor.extract_slice
540 // CHECK: return %[[ARG0]] : tensor<4x6x16x32xi8>
541 func.func @trivial_slice(%arg0 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
542 %0 = tensor.extract_slice %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<4x6x16x32xi8>
543 return %0 : tensor<4x6x16x32xi8>
549 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
550 // CHECK-NOT: tensor.extract_slice
551 // CHECK: return %[[ARG0]] : tensor<4x6x16x32xi8>
552 func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
553 %0 = tensor.insert_slice %arg0 into %arg1[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> into tensor<4x6x16x32xi8>
554 return %0 : tensor<4x6x16x32xi8>
560 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8>
561 // CHECK-SAME: %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>
562 // CHECK-NOT: tensor.extract_slice
563 // CHECK: return %[[ARG1]] : tensor<3x3xi8>
564 func.func @empty_insert_slice(%arg0 : tensor<0x2xi8>, %arg1 : tensor<3x3xi8>) -> tensor<3x3xi8> {
565 %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [0, 2] [1, 1] : tensor<0x2xi8> into tensor<3x3xi8>
566 return %0 : tensor<3x3xi8>
572 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
573 // CHECK: %[[S:.+]] = tensor.extract_slice %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
574 // Tensor cast is moved after slice and then gets canonicalized away.
575 // CHECK-NOT: tensor.cast
576 // CHECK: return %[[S]] : tensor<16x32xi8>
577 func.func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32xi8> {
578 %0 = tensor.cast %arg : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
579 %1 = tensor.extract_slice %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<?x?x16x32xi8> to tensor<16x32xi8>
580 return %1 : tensor<16x32xi8>
586 // CHECK-SAME: %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
587 // CHECK-SAME: %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
588 // CHECK: %[[S:.+]] = tensor.insert_slice %[[A]] into %[[B]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<16x32xi8> into tensor<4x6x16x32xi8>
589 // Tensor cast is folded away.
590 // CHECK-NOT: tensor.cast
591 // CHECK: return %[[S]] : tensor<4x6x16x32xi8>
592 func.func @rank_reducing_insert_slice_of_cast(%a : tensor<16x32xi8>, %b : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
594 %cast = tensor.cast %a : tensor<16x32xi8> to tensor<?x32xi8>
595 %sz = tensor.dim %cast, %c0: tensor<?x32xi8>
596 %res = tensor.insert_slice %cast into %b[0, 1, 0, 0] [1, 1, %sz, 32] [1, 1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
597 return %res : tensor<4x6x16x32xi8>
602 func.func @insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
603 %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
608 %0 = tensor.insert_slice %arg0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
609 return %0 : tensor<?x?x?xf32>
612 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
613 // CHECK: %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<4x1x?xf32>
614 // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
616 // CHECK-SAME: : tensor<4x1x?xf32> into tensor<?x?x?xf32>
622 func.func @insert_slice_canonicalize_encoding(%arg0 : tensor<2x2xf32, "foo">,
623 %arg1 : tensor<4x4xf32, "foo">) -> tensor<4x4xf32, "foo">
625 %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [2, 2] [1, 1] : tensor<2x2xf32, "foo"> into tensor<4x4xf32, "foo">
626 return %0 : tensor<4x4xf32, "foo">
629 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x2xf32, "foo">
630 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<4x4xf32, "foo">
631 // CHECK-NOT: tensor.cast
632 // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]] into %[[ARG1]]
634 // CHECK-SAME: : tensor<2x2xf32, "foo"> into tensor<4x4xf32, "foo">
639 func.func @slice_to_insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
640 %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
645 %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
646 %1 = tensor.insert_slice %0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
647 return %1 : tensor<?x?x?xf32>
650 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
651 // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
652 // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
654 // CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x1x?xf32>
655 // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[SLICE]]
657 // CHECK-SAME: : tensor<4x1x?xf32> into tensor<?x?x?xf32>
662 func.func @rank_reducing_insert_slice_canonicalize(%arg0 : tensor<?x?xf32>, %arg1 : index,
663 %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
668 %0 = tensor.insert_slice %arg0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
669 return %0 : tensor<?x?x?xf32>
672 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
673 // CHECK: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?xf32> to tensor<4x?xf32>
674 // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
676 // CHECK-SAME: : tensor<4x?xf32> into tensor<?x?x?xf32>
681 func.func @rank_reducing_slice_to_insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
682 %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
687 %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
688 %1 = tensor.insert_slice %0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
689 return %1 : tensor<?x?x?xf32>
692 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
693 // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
694 // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
696 // CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x?xf32>
697 // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ARG3]]
699 // CHECK-SAME: : tensor<4x?xf32> into tensor<?x?x?xf32>
704 func.func @insert_slice_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
705 %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
710 %0 = tensor.dim %arg0, %c1 : tensor<2x?xi32>
711 %1 = tensor.extract %arg1[] : tensor<i32>
712 %2 = tensor.generate %arg2, %c8 {
714 tensor.yield %1 : i32
715 } : tensor<?x?xi32>
716 %3 = tensor.insert_slice %arg0 into %2[0, %arg3] [2, %0] [1, 1] : tensor<2x?xi32> into tensor<?x?xi32>
717 return %3 : tensor<?x?xi32>
720 // CHECK: %[[UPDATED:.+]] = tensor.insert_slice %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
721 // CHECK-SAME: tensor<2x?xi32> into tensor<?x8xi32>
722 // CHECK: %[[CAST:.+]] = tensor.cast %[[UPDATED]]
727 func.func @insert_slice_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor<i32>) -> tensor<3x9xi32> {
730 %2 = tensor.extract %arg1[] : tensor<i32>
731 %4 = tensor.generate %c3, %c9 {
733 tensor.yield %2 : i32
734 } : tensor<?x?xi32>
735 %5 = tensor.insert_slice %arg0 into %4[0, 1] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<?x?xi32>
736 %6 = tensor.cast %5 : tensor<?x?xi32> to tensor<3x9xi32>
737 return %6 : tensor<3x9xi32>
740 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3xi32>
741 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<i32>
742 // CHECK: %[[PAD:.+]] = tensor.extract %[[ARG1]]
743 // CHECK: %[[GENERATE:.+]] = tensor.generate
744 // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]] into %[[GENERATE]]
749 // Test case: Folding of tensor.dim(tensor.generate %idx) -> %idx
752 // CHECK-NOT: tensor.dim
756 %0 = tensor.generate %arg0, %arg1 {
758 tensor.yield %c3 : index
759 } : tensor<2x?x4x?x5xindex>
760 %1 = tensor.dim %0, %c3 : tensor<2x?x4x?x5xindex>
766 // Test case: Folding tensor.dim(tensor.cast(%0), %idx) -> tensor.dim(%0, %idx)
768 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
771 // CHECK: %[[T0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
773 func.func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
776 %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
777 %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
778 %2 = tensor.dim %0, %c1 : tensor<?x?xf32>
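// A minimal sketch of the expected result above: tensor.dim of a tensor.cast
// folds onto the cast source, and a dimension that is static in the source
// type becomes a constant (hypothetical function name).
func.func @fold_dim_of_cast_sketch(%arg0 : tensor<4x?xf32>) -> (index, index) {
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %d1 = tensor.dim %arg0, %c1 : tensor<4x?xf32>
  return %c4, %d1 : index, index
}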
785 func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
786 // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>
787 %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
788 // CHECK: %[[RES:.*]] = tensor.insert_slice %[[ARG0]]
790 // CHECK-SAME: : tensor<1x?xf32> into tensor<?x?xf32>
791 %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
792 // CHECK: return %[[RES]] : tensor<?x?xf32>
793 return %1 : tensor<?x?xf32>
799 func.func @insert_slice_cast_no_fold(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
800 %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x5xf32>
801 // CHECK: %[[CAST:.*]] = tensor.cast
802 // CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
804 // CHECK-SAME: : tensor<?x5xf32> into tensor<?x?xf32>
805 %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, 5] [%arg6, %arg7] : tensor<?x5xf32> into tensor<?x?xf32>
806 // CHECK: return %[[RES]] : tensor<?x?xf32>
807 return %1 : tensor<?x?xf32>
813 // CHECK-SAME: %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
814 // CHECK: %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x64xf32>
815 // CHECK: %[[r:.*]] = tensor.insert_slice %[[cast]] into %[[arg1]][0, 1, 2] [64, 5, 64] [1, 1, 1] : tensor<64x5x64xf32> into tensor<?x?x?xf32>
818 %arg0 : tensor<?x5x?xf32>, %arg1 : tensor<?x?x?xf32>, %sz0: index, %sz2: index) -> tensor<?x?x?xf32> {
820 %r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [%c64, 5, %c64] [1, 1, 1]
821 : tensor<?x5x?xf32> into tensor<?x?x?xf32>
822 return %r : tensor<?x?x?xf32>
828 // CHECK-SAME: %{{.+}}: tensor<?x?x?xf32>, %[[SLICE:.+]]: tensor<4x?x8xf32>
829 func.func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<4x?x8xf32>) {
832 %0 = tensor.insert_slice %slice into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
833 %1 = tensor.extract_slice %0[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<?x?x?xf32> to tensor<4x?x8xf32>
835 return %1 : tensor<4x?x8xf32>
841 // CHECK-NOT: tensor.gather
842 // CHECK: arith.constant dense<1.000000e-01> : tensor<1x2x1x1x1xf32>
843 func.func @fold_gather_constant_splat(%indices : tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> {
844 %cst = arith.constant dense<1.000000e-01> : tensor<4x4x4xf32>
845 %0 = tensor.gather %cst[%indices] gather_dims([0, 1, 2]) :
846 (tensor<4x4x4xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
847 return %0 : tensor<1x2x1x1x1xf32>
853 // CHECK-NOT: tensor.reshape
854 // CHECK: arith.constant dense<1.000000e-01> : tensor<4xf32>
855 func.func @fold_reshape_constant_splat(%shape : tensor<1xi32>) -> tensor<4xf32> {
856 %cst = arith.constant dense<1.000000e-01> : tensor<4x1xf32>
857 %0 = tensor.reshape %cst(%shape)
858 : (tensor<4x1xf32>, tensor<1xi32>) -> tensor<4xf32>
859 return %0 : tensor<4xf32>
865 // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<*xf32>
866 // CHECK-SAME: %[[SHAPE_0:[a-zA-Z0-9_]+]]: tensor<?xindex>
867 // CHECK-SAME: %[[SHAPE_1:[a-zA-Z0-9_]+]]: tensor<?xindex>
868 // CHECK-SAME: %[[SHAPE_2:[a-zA-Z0-9_]+]]: tensor<?xindex>
869 // CHECK: %[[RESULT:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE_2]])
871 func.func @fold_reshape_chain(%input: tensor<*xf32>, %shape_0: tensor<?xindex>, %shape_1: tensor<?xindex>, %shape_2: tensor<?xindex>) -> tensor<*xf32> {
872 %0 = tensor.reshape %input(%shape_0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
873 %1 = tensor.reshape %0(%shape_1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
874 %2 = tensor.reshape %1(%shape_2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
875 return %2 : tensor<*xf32>
881 // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<?xf32>
882 // CHECK-SAME: %[[SHAPE:[a-zA-Z0-9_]+]]: tensor<1xindex>
884 func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> tensor<?xf32> {
885 %0 = tensor.reshape %input(%shape) : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
886 return %0 : tensor<?xf32>
892 // CHECK-NOT: tensor.extract_slice
893 // CHECK: arith.constant dense<42> : tensor<4x4xi32>
894 func.func @fold_extract_constant_splat() -> (tensor<4x4xi32>) {
895 %cst = arith.constant dense<42> : tensor<1024x1024xi32>
896 %1 = tensor.extract_slice %cst[0,0] [4,4] [1, 1] : tensor<1024x1024xi32> to tensor<4x4xi32>
897 return %1 : tensor<4x4xi32>
903 // CHECK-NOT: tensor.pack
904 // CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
905 func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
906 %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
907 %0 = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
908 inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
909 return %0 : tensor<8x16x8x32xf32>
915 // CHECK-NOT: tensor.pack
916 // CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
917 func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
919 %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
920 %0 = tensor.pack %cst
923 inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
924 return %0 : tensor<8x16x8x32xf32>
931 // CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
932 // CHECK: tensor.pack
933 func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
935 %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
936 %0 = tensor.pack %cst
941 into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
942 return %0 : tensor<8x16x8x32xf32>
947 func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<31250x1200x16x1xf32> {
949 %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
950 %pack = tensor.pack %arg0
955 into %0 : tensor<1200x500000xf32> -> tensor<31250x1200x16x1xf32>
956 return %pack : tensor<31250x1200x16x1xf32>
963 func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x30x40x16xf32>) -> tensor<10x20x30x40x16xf32> {
965 %pack = tensor.pack %src
970 into %dest : tensor<?x?x?x?xf32> -> tensor<10x20x30x40x16xf32>
971 return %pack : tensor<10x20x30x40x16xf32>
976 // CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
977 // CHECK: %[[PACK:.+]] = tensor.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
982 func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?x?x?x16xf32>) -> tensor<?x?x?x?x16xf32> {
984 %pack = tensor.pack %src
989 into %dest : tensor<30x20x?x10xf32> -> tensor<?x?x?x?x16xf32>
990 return %pack : tensor<?x?x?x?x16xf32>
995 // CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
996 // CHECK: %[[PACK:.+]] = tensor.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
997 // CHECK: %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<?x20x10x30x16xf32> to tensor<?x?x?x?x16xf32>
1002 func.func @no_infer_pack_shape(%arg0: tensor<?x32x100xf32>, %arg1: index) -> tensor<32x7x?x16x1xf32> {
1004 %0 = tensor.empty(%arg1) : tensor<32x7x?x16x1xf32>
1005 %pack = tensor.pack %arg0 padding_value(%cst : f32) outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<?x32x100xf32> -> tensor<32x7x?x16x1xf32>
1006 return %pack : tensor<32x7x?x16x1xf32>
1009 // CHECK-NOT: tensor.cast
1013 func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
1015 %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
1016 %pack = tensor.pack %arg0
1021 into %0 : tensor<1200x499999xf32> -> tensor<31250x1200x16x1xf32>
1022 return %pack : tensor<31250x1200x16x1xf32>
1025 // CHECK: tensor.pack
1030 func.func @fold_padding_value_pack_negative2(%arg0: tensor<1200x?xf32>, %arg1: tensor<?x1200x16x1xf32>) -> tensor<?x1200x16x1xf32> {
1032 %pack = tensor.pack %arg0
1037 into %arg1 : tensor<1200x?xf32> -> tensor<?x1200x16x1xf32>
1038 return %pack : tensor<?x1200x16x1xf32>
1041 // CHECK: tensor.pack
1046 func.func @fold_padding_value_pack_negative3(%arg0: tensor<1200x500000xf32>, %arg1: tensor<?x1200x?x1xf32>, %tile : index) -> tensor<?x1200x?x1xf32> {
1048 %pack = tensor.pack %arg0
1053 into %arg1 : tensor<1200x500000xf32> -> tensor<?x1200x?x1xf32>
1054 return %pack : tensor<?x1200x?x1xf32>
1057 // CHECK: tensor.pack
1063 // CHECK-NOT: tensor.unpack
1064 // CHECK: arith.constant dense<1.000000e-01> : tensor<128x256xf32>
1065 func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128x256xf32> {
1066 %cst = arith.constant dense<1.000000e-01> : tensor<16x8x8x32xf32>
1067 %0 = tensor.unpack %cst inner_dims_pos = [0, 1]
1068 inner_tiles = [8, 32] into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
1069 return %0 : tensor<128x256xf32>
1074 func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
1075 %unpack = tensor.unpack %src
1079 into %dest : tensor<10x20x30x40x16xf32> -> tensor<?x?x?x?xf32>
1080 return %unpack : tensor<?x?x?x?xf32>
1085 // CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
1086 // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
1087 // CHECK: %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<40x20x?x30xf32> to tensor<?x?x?x?xf32>
1092 func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30x20x?x10xf32>) -> tensor<30x20x?x10xf32> {
1093 %unpack = tensor.unpack %src
1097 into %dest : tensor<?x?x?x?x16xf32> -> tensor<30x20x?x10xf32>
1098 return %unpack : tensor<30x20x?x10xf32>
1103 // CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
1104 // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
1109 func.func @no_infer_unpack_shape(%arg1: tensor<32x7x?x16x1xf32>, %arg2: index) -> tensor<?x32x100xf32> {
1111 %0 = tensor.empty(%arg2) : tensor<?x32x100xf32>
1112 %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<32x7x?x16x1xf32> -> tensor<?x32x100xf32>
1113 return %unpack : tensor<?x32x100xf32>
1116 // CHECK-NOT: tensor.cast
1122 // CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
1123 func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
1126 %0 = tensor.insert_slice %slice1 into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
1127 // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE2]] into %[[INPUT]]
1128 %1 = tensor.insert_slice %slice2 into %0[0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
1130 return %1 : tensor<?x?x?xf32>
1135 func.func @compose_expand_of_expand(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
1136 -> tensor<?x6x4x?x5xf32> {
1137 %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
1138 : tensor<?x?xf32> into tensor<?x4x?xf32>
1139 %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4]] output_shape [%arg3, 6, 4, %arg4, 5] : tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
1140 return %1 : tensor<?x6x4x?x5xf32>
1143 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] output_shape [%arg3, 6, 4, %arg4, 5]
1144 // CHECK-NOT: tensor.expand_shape
1148 func.func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor<f32>)
1149 -> tensor<1x1x1xf32> {
1150 %0 = tensor.expand_shape %arg0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
1151 %1 = tensor.expand_shape %0 [[0, 1, 2]] output_shape [1, 1, 1]
1152 : tensor<1xf32> into tensor<1x1x1xf32>
1153 return %1 : tensor<1x1x1xf32>
1156 // CHECK: tensor.expand_shape %{{.*}} [] output_shape [1, 1, 1]
1157 // CHECK-SAME: tensor<f32> into tensor<1x1x1xf32>
1162 // CHECK-SAME: %[[IN:.*]]: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
1163 // CHECK-NEXT: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[IN]] {{\[}}[0, 1], [2]] : tensor<8x12x32xf32> into tensor<96x32xf32>
1164 // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[COLLAPSE]] : tensor<96x32xf32> to tensor<?x32xf32>
1165 // CHECK-NEXT: return %[[CAST]] : tensor<?x32xf32>
1166 func.func @collapse_of_cast(%t: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
1167 %0 = tensor.cast %t : tensor<8x12x32xf32> to tensor<?x?x?xf32>
1168 %1 = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
1169 %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x32xf32>
1170 return %2 : tensor<?x32xf32>
1175 func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
1176 %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 4, 4]
1177 : tensor<12x4xf32> into tensor<3x4x4xf32>
1178 %1 = tensor.collapse_shape %0 [[0, 1], [2]]
1179 : tensor<3x4x4xf32> into tensor<12x4xf32>
1180 return %1 : tensor<12x4xf32>
1183 // CHECK-NOT: tensor.{{.*}}_shape
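// A minimal sketch of the expected canonical form above: an expand_shape that
// is immediately collapsed back to the original type composes to a no-op,
// consistent with the CHECK-NOT (hypothetical function name).
func.func @fold_collapse_of_expand_sketch(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
  return %arg0 : tensor<12x4xf32>
}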
1187 func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index)
1188 -> tensor<?x?xf32> {
1189 %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
1190 : tensor<?x?xf32> into tensor<?x4x?xf32>
1191 %1 = tensor.collapse_shape %0 [[0, 1], [2]]
1192 : tensor<?x4x?xf32> into tensor<?x?xf32>
1193 return %1 : tensor<?x?xf32>
1196 // CHECK-NOT: tensor.{{.*}}_shape
1200 func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
1201 -> tensor<?x?xf32> {
1202 %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
1203 : tensor<?x?xf32> into tensor<?x?x?xf32>
1204 %1 = tensor.collapse_shape %0 [[0, 1], [2]]
1205 : tensor<?x?x?xf32> into tensor<?x?xf32>
1206 return %1 : tensor<?x?xf32>
1209 // CHECK-NOT: tensor.{{.*}}_shape
1213 func.func @no_fold_parallel_collapse_of_expand_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
1214 -> tensor<?x?x?xf32> {
1215 %0 = tensor.expand_shape %arg0 [[0, 1], [2], [3]] output_shape [%arg1, %arg2, %arg3, %arg4]
1216 : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
1217 %1 = tensor.collapse_shape %0 [[0], [1], [2, 3]]
1218 : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
1219 return %1 : tensor<?x?x?xf32>
1222 // CHECK: tensor.expand_shape
1223 // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
1228 func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
1229 %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
1230 : tensor<3x4x4xf32> into tensor<12x4xf32>
1231 %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 4]
1232 : tensor<12x4xf32> into tensor<3x4x4xf32>
1233 return %1 : tensor<3x4x4xf32>
1236 // CHECK-NOT: tensor.{{.*}}_shape
1240 func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
1241 -> tensor<?x4x?xf32> {
1242 %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
1243 : tensor<?x4x?xf32> into tensor<?x?xf32>
1244 %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
1245 : tensor<?x?xf32> into tensor<?x4x?xf32>
1246 return %1 : tensor<?x4x?xf32>
1249 // CHECK-NOT: tensor.{{.*}}_shape
1253 func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
1254 -> tensor<?x?x?xf32> {
1255 %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
1256 : tensor<?x?x?xf32> into tensor<?x?xf32>
1257 %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
1258 : tensor<?x?xf32> into tensor<?x?x?xf32>
1259 return %1 : tensor<?x?x?xf32>
1262 // CHECK: tensor.collapse_shape
1263 // CHECK: %[[EXPAND:.+]] = tensor.expand_shape
1268 func.func @compose_expand_of_collapse_last_two_dims(%arg0: tensor<?x64x1xf32>) -> tensor<?x384xf32> {
1269 %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
1271 %dim = tensor.dim %collapsed, %c0 : tensor<?xf32>
1274 %expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [%div, 384] : tensor<?xf32> into tensor<?x384xf32>
1275 return %expanded : tensor<?x384xf32>
1279 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x64x1xf32>
1282 // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
1283 // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[CONSTANT0]] : tensor<?x64x1xf32>
1286 // CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1]] output_shape [%[[DIVUI]], 384] : tensor<?xf32> into tensor<?x384xf32>
1291 func.func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
1292 -> tensor<24x5x42x8xf32> {
1293 %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
1294 : tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
1295 %1 = tensor.expand_shape %0 [[0, 1, 2, 3]] output_shape [24, 5, 42, 8]
1296 : tensor<40320xf32> into tensor<24x5x42x8xf32>
1297 return %1 : tensor<24x5x42x8xf32>
1300 // CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32>
1301 // CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
1307 func.func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
1308 -> tensor<2x3x4x5x6x7x8xf32> {
1309 %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3]]
1310 : tensor<24x5x42x8xf32> into tensor<40320xf32>
1311 %1 = tensor.expand_shape %0 [[0, 1, 2, 3, 4, 5, 6]] output_shape [2, 3, 4, 5, 6, 7, 8]
1312 : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
1313 return %1 : tensor<2x3x4x5x6x7x8xf32>
1316 // CHECK-SAME: %[[ARG0:.+]]: tensor<24x5x42x8xf32>
1317 // CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
1323 func.func @compose_collapse_of_expand(%arg : tensor<?x?x?xi64>, %arg1: index, %arg2: index, %arg3: index)
1324 -> tensor<?x?xi64> {
1325 %0 = tensor.expand_shape %arg [[0], [1], [2, 3]] output_shape [%arg1, %arg2, %arg3, 1]
1326 : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
1327 %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
1328 : tensor<?x?x?x1xi64> into tensor<?x?xi64>
1329 return %1 : tensor<?x?xi64>
1332 // CHECK: (%[[ARG:.*]]: tensor<?x?x?xi64>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
1333 // CHECK-NEXT: tensor.collapse_shape %[[ARG]]
1335 // CHECK-SAME: : tensor<?x?x?xi64> into tensor<?x?xi64>
1339 func.func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
1340 -> tensor<4x512xf32> {
1341 %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]] output_shape [1, 4, 1, 512]
1342 : tensor<2048xf32> into tensor<1x4x1x512xf32>
1343 %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
1344 : tensor<1x4x1x512xf32> into tensor<4x512xf32>
1345 return %1 : tensor<4x512xf32>
1348 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [4, 512]
1349 // CHECK-SAME: tensor<2048xf32> into tensor<4x512xf32>
1353 func.func @compose_expand_of_collapse_0_rank_to_expand(%arg0 : tensor<1x1x1xf32>)
1354 -> tensor<1x1x1x1xf32> {
1355 %0 = tensor.collapse_shape %arg0 []
1356 : tensor<1x1x1xf32> into tensor<f32>
1357 %1 = tensor.expand_shape %0 [] output_shape [1, 1, 1, 1]
1358 : tensor<f32> into tensor<1x1x1x1xf32>
1359 return %1 : tensor<1x1x1x1xf32>
1362 // CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x1xf32>
1363 // CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
1369 func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1xf32>)
1370 -> tensor<1x1x1xf32> {
1371 %0 = tensor.collapse_shape %arg0 []
1372 : tensor<1x1x1x1xf32> into tensor<f32>
1373 %1 = tensor.expand_shape %0 [] output_shape [1, 1, 1]
1374 : tensor<f32> into tensor<1x1x1xf32>
1375 return %1 : tensor<1x1x1xf32>
1378 // CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x1x1xf32>
1379 // CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
1385 func.func @compose_expand_of_collapse_static(%arg0 : tensor<4x32x10x64x2xf16>) -> tensor<4x32x10x128xf16> {
1386 %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x32x10x64x2xf16> into tensor<128x10x128xf16>
1387 %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 32, 10, 128] : tensor<128x10x128xf16> into tensor<4x32x10x128xf16>
1388 return %expanded : tensor<4x32x10x128xf16>
1392 // CHECK-SAME: %[[ARG0:.+]]: tensor<4x32x10x64x2xf16>
1393 // CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
1399 func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %arg1 : index) -> tensor<4x?x10x128xf16> {
1400 %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x?x10x64x2xf16> into tensor<?x10x128xf16>
1401 %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, %arg1, 10, 128] : tensor<?x10x128xf16> into tensor<4x?x10x128xf16>
1402 return %expanded : tensor<4x?x10x128xf16>
1406 // CHECK-SAME: %[[ARG0:.+]]: tensor<4x?x10x64x2xf16>
1407 // CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
1414 func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
1416 %0 = tensor.expand_shape %arg0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
1417 %1 = tensor.expand_shape %0 [[0, 1]] output_shape [1, 1] : tensor<1xf32> into tensor<1x1xf32>
1418 %2 = tensor.collapse_shape %1 [] : tensor<1x1xf32> into tensor<f32>
1419 return %2 : tensor<f32>
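// A minimal sketch of the expected result above: the expand/expand/collapse
// chain starts and ends at tensor<f32>, so the whole sequence is expected to
// fold away (hypothetical function name).
func.func @zero_rank_reshape_multi_sketch(%arg0: tensor<f32>) -> tensor<f32> {
  return %arg0 : tensor<f32>
}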
1424 func.func @compose_collapse_of_collapse(%arg0 : tensor<?x?x?x?x?xf32>)
1425 -> tensor<?x?xf32> {
1426 %0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
1427 : tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
1428 %1 = tensor.collapse_shape %0 [[0, 1], [2]]
1429 : tensor<?x?x?xf32> into tensor<?x?xf32>
1430 return %1 : tensor<?x?xf32>
1433 // CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
1434 // CHECK-NOT: tensor.collapse_shape
1438 func.func @compose_collapse_of_collapse_zero_dim(%arg0 : tensor<1x1x1xf32>)
1439 -> tensor<f32> {
1440 %0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
1441 : tensor<1x1x1xf32> into tensor<1xf32>
1442 %1 = tensor.collapse_shape %0 [] : tensor<1xf32> into tensor<f32>
1443 return %1 : tensor<f32>
1446 // CHECK: tensor.collapse_shape %{{.*}} []
1447 // CHECK-SAME: tensor<1x1x1xf32> into tensor<f32>
1451 func.func @fold_collapse_of_expand_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
1452 %0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]] output_shape [1, 4, 1, 512]
1453 : tensor<4x512xf32> into tensor<1x4x1x512xf32>
1454 %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
1455 : tensor<1x4x1x512xf32> into tensor<2048xf32>
1456 return %1 : tensor<2048xf32>
1459 // CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1]]
1460 // CHECK-SAME: tensor<4x512xf32> into tensor<2048xf32>
1464 func.func @fold_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x1xf32>)
1465 -> tensor<4x512x1x1xf32> {
1466 %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]] output_shape [1, 4, 1, 512, 1, 1] : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
1467 %1 = tensor.collapse_shape %0 [[0, 1, 2], [3], [4], [5]]
1468 : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32>
1469 return %1 : tensor<4x512x1x1xf32>
1472 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]] output_shape [4, 512, 1, 1]
1473 // CHECK-SAME: tensor<2048x1x1xf32> into tensor<4x512x1x1xf32>
1477 func.func @compose_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
1478 -> tensor<4x512x1x512x4xf32> {
1479 %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]] output_shape [1, 4, 1, 512, 1, 1, 512, 1, 4] : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
1480 %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4], [5], [6, 7], [8]]
1481 : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
1482 return %1 : tensor<4x512x1x512x4xf32>
1485 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]] output_shape [4, 512, 1, 512, 4]
1486 // CHECK-SAME: tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>
1490 func.func @compose_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
1491 -> tensor<2x1xf32> {
1492 %0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 1, 1]
1493 : tensor<2xf32> into tensor<2x1x1xf32>
1494 %1 = tensor.collapse_shape %0 [[0], [1, 2]]
1495 : tensor<2x1x1xf32> into tensor<2x1xf32>
1496 return %1 : tensor<2x1xf32>
1499 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [2, 1]
1500 // CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
1505 %arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>) -> tensor<?x?x?x?xf32> {
1506 %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]]
1507 : tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x1x1x?x?xf32>
1508 %1 = tensor.collapse_shape %0 [[0], [1], [2, 3, 4], [5]]
1509 : tensor<?x?x1x1x?x?xf32> into tensor<?x?x?x?xf32>
1510 return %1 : tensor<?x?x?x?xf32>
1513 // CHECK: tensor.collapse_shape
1515 // CHECK-SAME: tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x?x?xf32>
1519 func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
1520 -> tensor<2x1xf32> {
1521 %0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 1, 1] : tensor<2xf32> into tensor<2x1x1xf32>
1522 %1 = tensor.collapse_shape %0 [[0], [1, 2]]
1523 : tensor<2x1x1xf32> into tensor<2x1xf32>
1524 return %1 : tensor<2x1xf32>
1527 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [2, 1]
1528 // CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
1533 %arg0: tensor<1x1x?x1x1x1xf32>) -> tensor<?xf32> {
1534 %0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]]
1535 : tensor<1x1x?x1x1x1xf32> into tensor<?x1x1x1xf32>
1536 %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
1537 : tensor<?x1x1x1xf32> into tensor<?xf32>
1538 return %1 : tensor<?xf32>
1541 // CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]]
1542 // CHECK-SAME: tensor<1x1x?x1x1x1xf32> into tensor<?xf32>
1546 func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
1547 -> tensor<12x42xf32> {
1548 %0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [12, 42, 1, 1, 1] : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32>
1549 %1 = tensor.collapse_shape %0 [[0], [1, 2, 3, 4]]
1550 : tensor<12x42x1x1x1xf32> into tensor<12x42xf32>
1551 return %1 : tensor<12x42xf32>
1554 // CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]]
1555 // CHECK-SAME: tensor<12x42x1x1xf32> into tensor<12x42xf32>
1559 func.func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>, %sz0: index, %sz1: index, %sz2: index)
1560 -> tensor<?x?xf32> {
1561 %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [%sz0, %sz1, 1, %sz2]
1562 : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
1563 %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]]
1564 : tensor<?x?x1x?xf32> into tensor<?x?xf32>
1565 return %1 : tensor<?x?xf32>
1568 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>
1569 // CHECK: tensor.collapse_shape %[[ARG]] {{\[}}[0], [1, 2]]
1570 // CHECK-SAME: tensor<?x?x?xf32> into tensor<?x?xf32>
1574 func.func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>)
1575 -> tensor<2x6x16xf32> {
1576 %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]] output_shape [2, 2, 3, 2, 8]
1577 : tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32>
1578 %1 = tensor.collapse_shape %0 [[0], [1, 2], [3, 4]]
1579 : tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32>
1580 return %1 : tensor<2x6x16xf32>
1583 // CHECK: tensor.expand_shape
1584 // CHECK: tensor.collapse_shape
1588 func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
1589 -> tensor<12x1xf32> {
1590 %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [3, 2, 2, 1]
1591 : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
1592 %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
1593 : tensor<3x2x2x1xf32> into tensor<12x1xf32>
1594 return %1 : tensor<12x1xf32>
1597 // CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32>
1598 // CHECK: %[[RARG0:.+]] = tensor.expand_shape %[[ARG0]]
1600 // CHECK: %[[RES:.+]] = tensor.collapse_shape %[[RARG0]]
1602 // CHECK: return %[[RES:.+]] : tensor<12x1xf32>
1606 func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> {
1607 %c0 = arith.constant dense<42> : tensor<2x8xi32>
1608 %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
1609 : tensor<2x8xi32> into tensor<2x4x2xi32>
1610 return %0 : tensor<2x4x2xi32>
1613 // CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xi32>
1614 // CHECK-NOT: tensor.expand_shape
1617 func.func @expand_shape_splat(%arg : f32) -> tensor<2x2x2xf32> {
1618 %c0 = tensor.splat %arg : tensor<2x4xf32>
1619 %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 2, 2]
1620 : tensor<2x4xf32> into tensor<2x2x2xf32>
1621 return %0 : tensor<2x2x2xf32>
1625 // CHECK: %[[CST:.*]] = tensor.splat %[[ARG0:.+]] : tensor<2x2x2xf32>
1626 // CHECK-NOT: tensor.expand_shape
1633 func.func @expand_shape_splat_dynamic_no_fold(%arg: f32, %m: index, %sz0: index) -> tensor<2x2x?xf32> {
1634 // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]] : tensor<2x?xf32>
1635 // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[SPLAT]]
1636 %c0 = tensor.splat %arg[%m] : tensor<2x?xf32>
1637 %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 2, %sz0] : tensor<2x?xf32> into tensor<2x2x?xf32>
1638 return %0 : tensor<2x2x?xf32>
1643 func.func @collapse_shape_splat(%arg : f32) -> tensor<2x4xf32> {
1644 %c0 = tensor.splat %arg : tensor<2x2x2xf32>
1645 %0 = tensor.collapse_shape %c0 [[0], [1, 2]]
1646 : tensor<2x2x2xf32> into tensor<2x4xf32>
1647 return %0 : tensor<2x4xf32>
1651 // CHECK: %[[CST:.*]] = tensor.splat %[[ARG0:.+]] : tensor<2x4xf32>
1652 // CHECK-NOT: tensor.collapse_shape
1660 func.func @collapse_shape_splat_dynamic_no_fold(%f: f32, %m: index) -> tensor<2x?xf32> {
1661 // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]]
1662 // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[SPLAT]]
1663 %c0 = tensor.splat %f[%m] : tensor<2x2x?xf32>
1664 %0 = tensor.collapse_shape %c0 [[0], [1, 2]] : tensor<2x2x?xf32> into tensor<2x?xf32>
1665 return %0 : tensor<2x?xf32>
1670 func.func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> {
1671 %c0 = arith.constant dense<42> : tensor<2x8xi16>
1672 %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
1673 : tensor<2x8xi16> into tensor<2x4x2xi16>
1674 return %0 : tensor<2x4x2xi16>
1677 // CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xi16>
1678 // CHECK-NOT: tensor.expand_shape
1683 func.func @reshape_splat_constant_float32() -> tensor<2x4x2xf32> {
1684 %c0 = arith.constant dense<42.0> : tensor<2x8xf32>
1685 %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
1686 : tensor<2x8xf32> into tensor<2x4x2xf32>
1687 return %0 : tensor<2x4x2xf32>
1690 // CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xf32>
1691 // CHECK-NOT: tensor.expand_shape
1696 func.func @reshape_splat_constant_float64() -> tensor<2x4x2xf64> {
1697 %c0 = arith.constant dense<42.0> : tensor<2x8xf64>
1698 %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
1699 : tensor<2x8xf64> into tensor<2x4x2xf64>
1700 return %0 : tensor<2x4x2xf64>
1703 // CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xf64>
1704 // CHECK-NOT: tensor.expand_shape
1712 : tensor<2x1x4xi32>
1716 %rank_0 = tensor.rank %const_0 : tensor<2x1x4xi32>
1725 // CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
1726 // CHECK-NOT: tensor.pad
1728 func.func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
1729 -> tensor<5x6xf32> {
1731 %0 = tensor.pad %arg0 low[%a, 0] high[0, %a] {
1733 tensor.yield %cst : f32
1734 } : tensor<5x6xf32> to tensor<5x6xf32>
1735 return %0 : tensor<5x6xf32>
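// A minimal sketch of the expected fold above: since the padded result type
// equals the source type tensor<5x6xf32>, the low/high padding must be zero
// and the pad is removed, matching the CHECK-NOT (hypothetical function name).
func.func @pad_same_static_shape_sketch(%arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
  return %arg0 : tensor<5x6xf32>
}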
1741 // CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
1744 // CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]]
1747 // CHECK: tensor.yield %[[CST]] : f32
1748 // CHECK: } : tensor<?x64x?x?xf32> to tensor<?x72x?x?xf32>
1749 // CHECK: tensor.cast
1750 func.func @pad_fold_static(%arg0: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
1754 %padded = tensor.pad %arg0 low[0, %padding, 1, 1] high[0, %padding, 1, 1] {
1756 tensor.yield %cst: f32
1757 } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
1758 return %padded : tensor<?x?x?x?xf32>
1764 // CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
1765 // CHECK: %[[PAD:.*]] = tensor.pad
1767 func.func @pad_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
1768 -> tensor<5x6xf32> {
1770 %0 = tensor.pad %arg0 nofold low[%a, 0] high[0, %a] {
1772 tensor.yield %cst : f32
1773 } : tensor<5x6xf32> to tensor<5x6xf32>
1774 return %0 : tensor<5x6xf32>
1780 // CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
1782 // CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]]
1785 // CHECK: tensor.yield %[[CST]] : f32
1786 // CHECK: } : tensor<?x64x?x?xf32> to tensor<?x64x?x?xf32>
1787 // CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[PADDED:.*]] :
1788 // CHECK-SAME: tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
1789 // CHECK: return %[[DYNAMIC]] : tensor<?x?x?x?xf32>
1791 func.func @pad_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
1792 -> tensor<?x?x?x?xf32> {
1794 %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
1795 %padded = tensor.pad %dynamic low[0, 0, 1, 1] high[0, 0, 1, 1] {
1797 tensor.yield %cst: f32
1798 } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
1799 return %padded: tensor<?x?x?x?xf32>
1805 // CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>,
1806 // CHECK-SAME: %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> {
1808 // CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]]
1811 // CHECK: tensor.yield %[[CST]] : f32
1812 // CHECK: } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
1813 // CHECK: return %[[PADDED:.*]] : tensor<?x?x?x?xf32>
1815 func.func @pad_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
1816 -> tensor<?x?x?x?xf32> {
1818 %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
1819 %padded = tensor.pad %dynamic low[0, %padding, 1, 1] high[0, %padding, 1, 1] {
1821 tensor.yield %cst: f32
1822 } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
1823 return %padded: tensor<?x?x?x?xf32>
1829 // CHECK-NOT: tensor.cast
1830 // CHECK: tensor.pad
1831 // CHECK: tensor<8x?xf32> to tensor<8x32xf32>
1832 func.func @pad_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
1835 %0 = tensor.cast %t : tensor<8x?xf32> to tensor<?x?xf32>
1836 %1 = tensor.pad %0 low[%c0, %c0] high[%c0, %s] {
1838 tensor.yield %cst : f32
1839 } : tensor<?x?xf32> to tensor<8x32xf32>
1840 return %1 : tensor<8x32xf32>
1846 func.func @cast_of_pad_more_static(%arg0: tensor<?x?xf32>, %padding: index) -> tensor<32x32xf32> {
1848 // CHECK: %[[PAD:.*]] = tensor.pad
1849 // CHECK: tensor<?x?xf32> to tensor<32x32xf32>
1850 %padded = tensor.pad %arg0 low[%padding, %padding] high[0, 0] {
1852 tensor.yield %cst : f32
1853 } : tensor<?x?xf32> to tensor<?x?xf32>
1854 // CHECK-NOT: tensor.cast
1855 %casted = tensor.cast %padded : tensor<?x?xf32> to tensor<32x32xf32>
1857 return %casted : tensor<32x32xf32>
1863 func.func @cast_of_pad_less_static(%arg0: tensor<32x?x?xf32>, %padding: index) -> tensor<?x32x32xf32> {
1865 // CHECK: tensor.pad
1866 %padded = tensor.pad %arg0 low[%padding, %padding, %padding] high[0, 0, 0] {
1868 tensor.yield %cst : f32
1869 } : tensor<32x?x?xf32> to tensor<32x?x?xf32>
1870 // CHECK: %[[CAST:.*]] = tensor.cast
1871 %casted = tensor.cast %padded : tensor<32x?x?xf32> to tensor<?x32x32xf32>
1873 return %casted : tensor<?x32x32xf32>
1878 func.func @pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
1881 %0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32>
1882 %1 = tensor.pad %0 low[%c0, %c0] high[%c0, %c0] {
1884 tensor.yield %cst : f32
1885 } : tensor<?x?xf32> to tensor<4x4xf32>
1886 return %1 : tensor<4x4xf32>
1889 // CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>
1895 // CHECK-SAME: %[[ARG0:.*]]: tensor<4x?xf32>
1896 // CHECK-NOT: tensor.cast
1897 // CHECK: %[[RESULT:.*]] = tensor.pad %[[ARG0]]
1898 func.func @fold_pad_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
1900 %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
1901 %1 = tensor.pad %0 low[0, 0] high[0, 1] {
1903 tensor.yield %cst : f32
1904 } : tensor<?x?xf32> to tensor<4x4xf32>
1905 return %1 : tensor<4x4xf32>
1911 // CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?xf32>
1912 // CHECK-NOT: tensor.pad
1913 // CHECK: %[[RESULT:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
1915 func.func @pad_static_zero_cast(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
1917 %0 = tensor.pad %arg0 low[0, %c0, 0] high[0, 0, %c0] {
1919 tensor.yield %pad_value : f32
1920 } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
1922 return %0 : tensor<2x3x4xf32>
1928 // CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?xf32>
1929 // CHECK: %[[PAD:.*]] = tensor.pad
1931 func.func @pad_nofold_static_zero(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
1933 %0 = tensor.pad %arg0 nofold low[0, %c0, 0] high[0, 0, %c0] {
1935 tensor.yield %pad_value : f32
1936 } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
1938 return %0 : tensor<2x3x4xf32>
1944 // CHECK-SAME: %[[ARG0:.*]]: tensor<64x64xf32>,
1946 func.func @fold_orthogonal_pad_chains(%arg0: tensor<64x64xf32>,
1948 %pw0 : index, %pw1 : index) -> tensor<8x4xf32> {
1949 // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
1951 // CHECK: %[[PAD:.*]] = tensor.pad %[[T0]] nofold
1955 %0 = tensor.extract_slice %arg0[16, 0] [%sz0, 64] [1, 1] : tensor<64x64xf32> to tensor<?x64xf32>
1956 %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] {
1958 tensor.yield %pad_value : f32
1959 } : tensor<?x64xf32> to tensor<8x64xf32>
1960 %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
1961 %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] {
1963 tensor.yield %pad_value : f32
1964 } : tensor<8x?xf32> to tensor<8x4xf32>
1965 func.return %3 : tensor<8x4xf32>
1971 // CHECK-SAME: %[[ARG0:.*]]: tensor<64x64xf32>,
1973 func.func @dont_fold_pad_chains(%arg0: tensor<64x64xf32>,
1975 %pw0 : index, %pw1 : index) -> (tensor<8x4xf32>, tensor<4x64xf32>, tensor<8x4xf32>, tensor<6x4xf32>) {
1976 // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
1977 // CHECK: %[[T1:.*]] = tensor.pad %[[T0]]
1979 %0 = tensor.extract_slice %arg0[16, 0] [%sz0, 64] [1, 1] : tensor<64x64xf32> to tensor<?x64xf32>
1980 %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] {
1982 tensor.yield %pad_value : f32
1983 } : tensor<?x64xf32> to tensor<8x64xf32>
1986 // CHECK: %[[T2:.*]] = tensor.extract_slice %[[T1]]
1988 // CHECK: %[[PAD0:.*]] = tensor.pad %[[T2]]
1990 %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
1991 %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] {
1993 tensor.yield %different_value : f32
1994 } : tensor<8x?xf32> to tensor<8x4xf32>
1997 // CHECK: %[[T3:.*]] = tensor.extract_slice %[[T1]]
1999 // CHECK: %[[PAD1:.*]] = tensor.pad %[[T3]]
2000 %4 = tensor.extract_slice %1[4, 0] [%sz1, 64] [1, 1] : tensor<8x64xf32> to tensor<?x64xf32>
2001 %5 = tensor.pad %4 nofold low[0, 0] high[%pw1, 0] {
2003 tensor.yield %pad_value : f32
2004 } : tensor<?x64xf32> to tensor<4x64xf32>
2006   // Don't fold if a padded source tensor dimension is accessed at an offset.
2007 // CHECK: %[[T4:.*]] = tensor.extract_slice %[[T1]]
2009 // CHECK: %[[PAD2:.*]] = tensor.pad %[[T4]]
2010 %6 = tensor.extract_slice %1[%sz0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
2011 %7 = tensor.pad %6 nofold low[0, 0] high[0, %pw1] {
2013 tensor.yield %pad_value : f32
2014 } : tensor<8x?xf32> to tensor<8x4xf32>
2016 // Don't fold if a padded source tensor dimension is sliced.
2017 // CHECK: %[[T5:.*]] = tensor.extract_slice %[[T1]]
2019 // CHECK: %[[PAD3:.*]] = tensor.pad %[[T5]]
2020 %8 = tensor.extract_slice %1[0, 4] [6, %sz1] [1, 1] : tensor<8x64xf32> to tensor<6x?xf32>
2021 %9 = tensor.pad %8 nofold low[0, 0] high[0, %pw1] {
2023 tensor.yield %pad_value : f32
2024 } : tensor<6x?xf32> to tensor<6x4xf32>
2027 func.return %3, %5, %7, %9 : tensor<8x4xf32>, tensor<4x64xf32>, tensor<8x4xf32>, tensor<6x4xf32>
2033 // CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<2x3xf32>
2035 // CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[1, 3] high[4, 2]
2036 // CHECK: tensor.yield %[[PADVAL]]
2038 func.func @merge_constant_padding(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
2039 %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
2041 tensor.yield %pad_value : f32
2042 } : tensor<2x3xf32> to tensor<4x4xf32>
2043 %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
2045 tensor.yield %pad_value : f32
2046 } : tensor<4x4xf32> to tensor<7x8xf32>
2047 return %pad1 : tensor<7x8xf32>
2054 // CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<?x?xf32>
2058 // CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[%[[IDX]], 3] high[%[[HIGH]], 2]
2059 // CHECK: tensor.yield %[[PADVAL]]
2061 func.func @merge_constant_padding_dynamic(%arg0: tensor<?x?xf32>, %idx: index, %pad_value: f32) -> tensor<?x?xf32> {
2062 %pad0 = tensor.pad %arg0 low[%idx, 1] high[1, 0] {
2064 tensor.yield %pad_value : f32
2065 } : tensor<?x?xf32> to tensor<?x?xf32>
2066 %pad1 = tensor.pad %pad0 low[0, 2] high[%idx, 2] {
2068 tensor.yield %pad_value : f32
2069 } : tensor<?x?xf32> to tensor<?x?xf32>
2070 return %pad1 : tensor<?x?xf32>
2077 // CHECK: tensor.pad {{.*}} nofold
2078 // CHECK: tensor.pad
2079 func.func @dont_merge_constant_padding_nofold(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
2080 %pad0 = tensor.pad %arg0 nofold low[1, 1] high[1, 0] {
2082 tensor.yield %pad_value : f32
2083 } : tensor<2x3xf32> to tensor<4x4xf32>
2084 %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
2086 tensor.yield %pad_value : f32
2087 } : tensor<4x4xf32> to tensor<7x8xf32>
2088 return %pad1 : tensor<7x8xf32>
2095 // CHECK: tensor.pad
2096 // CHECK: tensor.pad
2098 %arg0: tensor<2x3xf32>,
2100 %pad_value1: f32) -> tensor<7x8xf32> {
2101 %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
2103 tensor.yield %pad_value0 : f32
2104 } : tensor<2x3xf32> to tensor<4x4xf32>
2105 %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
2107 tensor.yield %pad_value1 : f32
2108 } : tensor<4x4xf32> to tensor<7x8xf32>
2109 return %pad1 : tensor<7x8xf32>
2115 func.func @fold_collapse_shape_from_elements(%arg0: i32) -> tensor<i32> {
2116 // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<i32>
2117 // CHECK: return %[[FROM]] : tensor<i32>
2118 %0 = tensor.from_elements %arg0 : tensor<1xi32>
2119 %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
2120 return %1 : tensor<i32>
2126 func.func @fold_expand_shape_from_elements(%arg0: i32) -> tensor<1xi32> {
2127 // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<1xi32>
2128 // CHECK: return %[[FROM]] : tensor<1xi32>
2129 %0 = tensor.from_elements %arg0 : tensor<i32>
2130 %1 = tensor.expand_shape %0 [] output_shape [1] : tensor<i32> into tensor<1xi32>
2131 return %1 : tensor<1xi32>
2137 func.func @propagate_index_cast(%arg0: tensor<1xi32>) -> index {
2139 // CHECK: %[[EXT:.+]] = tensor.extract %arg0[%[[IDX]]] : tensor<1xi32>
2143 %0 = arith.index_cast %arg0 : tensor<1xi32> to tensor<1xindex>
2144 %1 = tensor.extract %0[%c0] : tensor<1xindex>
2151 func.func @splat_fold() -> tensor<4xf32> {
2153 %t = tensor.splat %c : tensor<4xf32>
2154 return %t : tensor<4xf32>
2156 // CHECK-NEXT: [[T:%.*]] = arith.constant dense<1.000000e+00> : tensor<4xf32>
2157 // CHECK-NEXT: return [[T]] : tensor<4xf32>
2164 func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> {
2168 // CHECK: tensor.splat %[[F]][%[[M]]] : tensor<4x?xf32>
2169 %t = tensor.splat %f[%m] : tensor<4x?xf32>
2170 return %t : tensor<4x?xf32>
2176 func.func @cast_extract_slice(%arg0 : tensor<128x512xf32>, %s : index, %o : index)
2177 -> tensor<16x512xf32> {
2178 // CHECK: %[[E:.*]] = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [16, 512] [1, 1] : tensor<128x512xf32> to tensor<16x512xf32>
2179 %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] : tensor<128x512xf32> to tensor<?x512xf32>
2180 %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
2181 // CHECK: return %[[E]] : tensor<16x512xf32>
2182 return %1 : tensor<16x512xf32>
2188 func.func @cast_extract_slice_rank_reduce(%arg0 : tensor<128x512xf32>, %s : index, %o : index)
2189 -> tensor<16xf32> {
2190 // CHECK: %[[E:.*]] = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [16, 1] [1, 1] : tensor<128x512xf32> to tensor<16xf32>
2191 %0 = tensor.extract_slice %arg0[%o, 0] [%s, 1] [1, 1] : tensor<128x512xf32> to tensor<?xf32>
2192 %1 = tensor.cast %0 : tensor<?xf32> to tensor<16xf32>
2193 // CHECK: return %[[E]] : tensor<16xf32>
2194 return %1 : tensor<16xf32>
2200 // CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
2201 // CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>,
2204 %arg0 : tensor<1x5xf32>, %arg1: tensor<?x?xf32>,
2205 %num_threads : index) -> tensor<?x?xf32>
2211 // CHECK-NOT: tensor.cast
2212 // CHECK: scf.forall (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) shared_outs(%[[o:.*]] = %[[arg1]]) -> (tensor<?x?xf32>) {
2214 // CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[o]][%[[tidx]], 0] [1, 5] [1, 1]
2215 %2 = scf.forall (%tidx) in (%num_threads) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
2216 %3 = tensor.cast %arg0 : tensor<1x5xf32> to tensor<?x5xf32>
2218 tensor.parallel_insert_slice %3 into %o[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x5xf32> into tensor<?x?xf32>
2221 return %2 : tensor<?x?xf32>
2227 // CHECK-SAME: (%[[INPUT:.+]]: tensor<1x2x2x4xf32>)
2228 func.func @fold_insert_slice_after_extract_slice(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {
2230 %0 = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
2231 %1 = tensor.insert_slice %0 into %input[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
2233 return %1: tensor<1x2x2x4xf32>
2239 func.func @dont_fold_mismatched_source_dst(%input0: tensor<1x2x2x4xf32>, %input1: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {
2241 // CHECK: tensor.extract_slice
2242 %0 = tensor.extract_slice %input0[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
2243 // CHECK: tensor.insert_slice
2244 %1 = tensor.insert_slice %0 into %input1[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
2245 return %1: tensor<1x2x2x4xf32>
2251 func.func @dont_fold_mismatched_parameters(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {
2253 // CHECK: tensor.extract_slice
2254 %0 = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
2255 // CHECK: tensor.insert_slice
2256 %1 = tensor.insert_slice %0 into %input[%c0, 1, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
2257 return %1: tensor<1x2x2x4xf32>
2262 func.func @empty_canonicalize() -> (tensor<4x5x?xf32>) {
2264 %0 = tensor.empty(%c6) : tensor<4x5x?xf32>
2265 return %0 : tensor<4x5x?xf32>
2268 // CHECK: %[[T0:.+]] = tensor.empty() : tensor<4x5x6xf32>
2269 // CHECK: %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
2274 func.func @fold_empty_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> {
2275 %0 = tensor.empty(%arg0) : tensor<?x12xf32>
2276 %1 = tensor.cast %0 : tensor<?x12xf32> to tensor<1x12xf32>
2277 return %1 : tensor<1x12xf32>
2280 // CHECK: %[[T0:.+]] = tensor.empty() : tensor<1x12xf32>
2281 // CHECK: return %[[T0]] : tensor<1x12xf32>
2293 // CHECK-NOT: tensor.empty
2294 %0 = tensor.empty(%i) : tensor<?x42xf32>
2296 // CHECK-NOT: tensor.dim
2297 %1 = tensor.dim %0, %c0: tensor<?x42xf32>
2298 %2 = tensor.dim %0, %c1: tensor<?x42xf32>
2311 // CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
2313 // CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
2316 func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
2318 %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]
2319 : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
2320 %1 = tensor.dim %0, %c2 : tensor<?x1x?x5x1x8xf32>
2328 // CHECK-SAME: %[[t:.*]]: tensor<?x?x?x7x?xf32>
2332 // CHECK-DAG: %[[dim1:.*]] = tensor.dim %[[t]], %[[c1]]
2333 // CHECK-DAG: %[[dim2:.*]] = tensor.dim %[[t]], %[[c2]]
2334 // CHECK-DAG: %[[dim4:.*]] = tensor.dim %[[t]], %[[c4]]
2337 func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
2339 %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]]
2340 : tensor<?x?x?x7x?xf32> into tensor<?x?xf32>
2341 %1 = tensor.dim %0, %c1 : tensor<?x?xf32>
2349 // CHECK: %[[DIM:.*]] = tensor.dim
2351 func.func @out_of_bound_dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
2353 %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]]
2354 : tensor<?x?x?x7x?xf32> into tensor<?x?xf32>
2355 %1 = tensor.dim %0, %c5 : tensor<?x?xf32>
2362 // CHECK-SAME: %[[t:.*]]: tensor<?xf32>
2364 func.func @collapse_expand_fold_to_cast(%t: tensor<?xf32>, %sz0: index) -> (tensor<?xf32>)
2366 %0 = tensor.expand_shape %t [[0, 1]] output_shape [1, %sz0] : tensor<?xf32> into tensor<1x?xf32>
2367 %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
2368 return %1 : tensor<?xf32>
2375 // CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>)
2376 // CHECK: return %[[T]] : tensor<128x128xf32>
2377 func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
2378 %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
2379 %packed = tensor.pack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
2380 %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
2381 %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
2382 return %unpacked : tensor<128x128xf32>
2389 // CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>)
2390 // CHECK-NOT: return %[[T]] : tensor<128x128xf32>
2391 func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
2392 %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
2393 %packed = tensor.pack %t inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
2394 %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
2395   %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
2397 return %unpacked : tensor<128x128xf32>
2404 // CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>)
2405 // CHECK-NOT: return %[[T]] : tensor<128x128xf32>
2406 func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
2407 %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
2408 %packed = tensor.pack %t outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
2409 %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
2410   %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
2412 return %unpacked : tensor<128x128xf32>
2419 // CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>,
2420 // CHECK: return %[[T]] : tensor<128x128xf32>
2421 func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) -> tensor<128x128xf32> {
2422 %tensor_empty = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
2423 %packed = tensor.pack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
2424 %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
2425   %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
2427 return %unpacked : tensor<128x128xf32>
2433 // CHECK: tensor.pack
2434 // CHECK: tensor.unpack
2435 func.func @unpack_pack_with_padding_no_canonicalization(%t: tensor<256x512xbf16>) -> tensor<224x512xbf16> {
2436 %tensor_empty = tensor.empty() : tensor<4x16x64x32xbf16>
2437 %tensor_empty1 = tensor.empty() : tensor<224x512xbf16>
2438 %packed = tensor.pack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty : tensor<256x512xbf16> -> tensor<4x16x64x32xbf16>
2439 %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty1 : tensor<4x16x64x32xbf16> -> tensor<224x512xbf16>
2440 return %unpacked : tensor<224x512xbf16>
2447 // CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
2448 // CHECK: return %[[T]] : tensor<16x16x?x?xf32>
2449 func.func @pack_unpack(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
2450 %tensor_empty = tensor.empty() : tensor<128x128xf32>
2451 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
2452 %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
2453 %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
2454 return %packed : tensor<16x16x?x?xf32>
2461 // CHECK-SAME: %[[T:.+]]: tensor<16x16x8x8xf32>
2462 // CHECK: return %[[T]] : tensor<16x16x8x8xf32>
2463 func.func @pack_unpack(%t: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
2464 %tensor_empty = tensor.empty() : tensor<128x128xf32>
2465 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
2466 %tensor_empty1 = tensor.empty() : tensor<16x16x8x8xf32>
2467 %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
2468 return %packed : tensor<16x16x8x8xf32>
2474 // CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>,
2475 // CHECK: return %[[T]] : tensor<?x?x?x?xf32>
2476 func.func @pack_unpack_same_tiles(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
2477 %tile1: index, %tile2: index) -> tensor<?x?x?x?xf32> {
2478 %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
2479 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
2480 %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
2481 %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
2482 return %packed : tensor<?x?x?x?xf32>
2488 // CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>,
2489 // CHECK-NOT: return %[[T]] : tensor<?x?x?x?xf32>
2490 func.func @pack_unpack_different_tiles(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
2491 %tile1: index, %tile2: index) -> tensor<?x?x?x?xf32> {
2492 %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
2493 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
2494 %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
2495 %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile2, %tile1] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
2496 return %packed : tensor<?x?x?x?xf32>
2502 // CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>,
2503 // CHECK-NOT: return %[[T]] : tensor<?x?x?x?xf32>
2504 func.func @pack_unpack_dynamic_with_padding(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
2505 %tile1: index, %tile2: index, %pad: f32) -> tensor<?x?x?x?xf32> {
2506 %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
2507 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
2508 %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
2509 %packed = tensor.pack %unpacked padding_value(%pad: f32) inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
2510 return %packed : tensor<?x?x?x?xf32>
2516 // CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
2517 // CHECK: return %[[T]] : tensor<16x16x?x?xf32>
2518 func.func @pack_outer_dims_unpack_no_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
2519 %tensor_empty = tensor.empty() : tensor<128x128xf32>
2520 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
2521 %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
2522 %packed = tensor.pack %unpacked outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
2523 return %packed : tensor<16x16x?x?xf32>
2529 // CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
2530 // CHECK: return %[[T]] : tensor<16x16x?x?xf32>
2531 func.func @pack_no_outer_dims_unpack_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
2532 %tensor_empty = tensor.empty() : tensor<128x128xf32>
2533 %unpacked = tensor.unpack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
2534 %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
2535 %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
2536 return %packed : tensor<16x16x?x?xf32>
2543 // CHECK: %[[T:.*]] = tensor.empty(%[[IDX]]) : tensor<4x5x?xf32>
2544 func.func @invalid_empty_negative_size() -> (tensor<4x5x?xf32>) {
2548 %1 = tensor.empty(%0) : tensor<4x5x?xf32>
2549 return %1 : tensor<4x5x?xf32>
2554 // Fold DstStyleOp -> tensor.unpack operations.
2555 func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init : tensor<?x?xf32>) -> tensor<?x?xf32> {
2557 %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
2558 %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %fill : tensor<?x?x16x64xf32> -> tensor<?x?xf32>
2559 return %unpack : tensor<?x?xf32>
2562 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x16x64xf32>
2563 // CHECK-SAME: %[[INIT:.+]]: tensor<?x?xf32>
2564 // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
2575 // CHECK: tensor.extract_slice {{.*}}%[[c]]
2576 // CHECK: tensor.insert_slice {{.*}}%[[c]]
2577 func.func @invalid_slice_ops(%t: tensor<?xf32>, %t2: tensor<?xf32>) -> tensor<?xf32> {
2579 %0 = tensor.extract_slice %t[0][%c][1] : tensor<?xf32> to tensor<?xf32>
2580 %1 = tensor.insert_slice %0 into %t2[2][%c][1] : tensor<?xf32> into tensor<?xf32>
2581 return %1 : tensor<?xf32>
2588 // CHECK: tensor.generate %[[c]]
2589 // CHECK: : tensor<?x8xi32>
2590 func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
2594 %tensor = tensor.generate %size {
2596 tensor.yield %cst : i32
2597 } : tensor<?x8xi32>
2598 return %tensor : tensor<?x8xi32>
2603 func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> tensor<10x20x4x4xf32> {
2606 %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
2607 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty : tensor<10x20x4x4xf32> -> tensor<?x?xf32>
2608 %cast = tensor.cast %unpacked : tensor<?x?xf32> to tensor<40x80xf32>
2609 %tensor_empty1 = tensor.empty() : tensor<10x20x4x4xf32>
2610 %packed = tensor.pack %cast inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty1 : tensor<40x80xf32> -> tensor<10x20x4x4xf32>
2611 return %packed : tensor<10x20x4x4xf32>
2619 // Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
2621 // CHECK-SAME: %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
2622 // CHECK-SAME: %[[SHP:[0-9a-z]+]]: tensor<?xindex>
2624 // CHECK-NEXT: %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
2625 // CHECK-NOT: tensor.store
2626 // CHECK-NOT: tensor.dim
2627 // CHECK-NOT: tensor.reshape
2629 func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
2632 %0 = tensor.reshape %arg0(%arg1)
2633 : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
2635 tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
2636 %1 = tensor.dim %0, %c3 : tensor<*xf32>
2642 // Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
2644 // CHECK: tensor.extract
2646 // CHECK-NOT: tensor.dim
2647 // CHECK-NOT: tensor.reshape
2649 func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
2652 %0 = tensor.reshape %arg0(%arg1)
2653 : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
2654 %1 = tensor.dim %0, %c3 : tensor<*xf32>
2660 // Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
2663 // CHECK-NEXT: tensor.extract
2664 // CHECK-NOT: tensor.dim
2665 // CHECK-NOT: tensor.reshape
2666 func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
2671 %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
2674 %2 = tensor.dim %0, %arg2 : tensor<*xf32>
2683 // Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
2686 // CHECK-NEXT: tensor.extract
2687 // CHECK-NOT: tensor.dim
2688 // CHECK-NOT: tensor.reshape
2689 func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
2691 %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
2693 %dim = tensor.dim %reshape, %0 : tensor<*xf32>
2700 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
2701 func.func @reshape_fold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
2704 %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
2705 %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
2706 %ds = tensor.from_elements %d0, %d1 : tensor<2xindex>
2707 %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
2709 return %reshape : tensor<?x?xi32>
2715 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
2716 func.func @reshape_nofold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
2719 %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
2720 %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
2721 %ds = tensor.from_elements %d1, %d0 : tensor<2xindex>
2722 // CHECK: tensor.reshape
2723 %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
2724 return %reshape : tensor<?x?xi32>
2730 func.func @reshape_nofold_2d_ins(%arg0 : tensor<?x?xi32>, %arg1: index, %arg2: index) -> tensor<?x?xi32> {
2731 %ds = tensor.from_elements %arg1, %arg2 : tensor<2xindex>
2732 // CHECK: tensor.reshape
2733 %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
2734 return %reshape : tensor<?x?xi32>
2740 // CHECK-SAME: %[[ARG0:.+]]: tensor<5x?x?xi32>
2741 func.func @reshape_fold_3d_cst(%arg0 : tensor<5x?x?xi32>) -> tensor<5x?x?xi32> {
2745 %d1 = tensor.dim %arg0, %c1 : tensor<5x?x?xi32>
2746 %d2 = tensor.dim %arg0, %c2 : tensor<5x?x?xi32>
2747 %ds = tensor.from_elements %d0, %d1, %d2 : tensor<3xindex>
2748 %reshape = tensor.reshape %arg0(%ds) : (tensor<5x?x?xi32>, tensor<3xindex>) -> tensor<5x?x?xi32>
2750 return %reshape : tensor<5x?x?xi32>
2755 // Test case: This test fails to fold because the index of tensor.dim is out of bounds.
2759 // CHECK-NEXT: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[IDX]]
2768 %3 = bufferization.alloc_tensor(%c29) : tensor<?xi16>
2769 %dim = tensor.dim %3, %idx28 : tensor<?xi16>
2778 // CHECK-SAME: %[[ARG1:.*]]: tensor<2x2xf32>,
2779 // CHECK-SAME: %[[ARG2:.*]]: tensor<2x2xf32>) -> index {
2780 // CHECK: %[[RES:.*]]:2 = test.destination_style_op ins(%[[ARG1]] : tensor<2x2xf32>)
2781 // CHECK-SAME: outs(%[[ARG2]] : tensor<2x2xf32>) -> tensor<2x2xf32>, index
2783 func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
2784 %cast = tensor.cast %arg0 : tensor<2x2xf32> to tensor<?x2xf32>
2785 %cast_0 = tensor.cast %arg1 : tensor<2x2xf32> to tensor<?x2xf32>
2786 %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
2793 // CHECK-SAME: %[[DEST:.*]]: tensor<1x1x8x1xi32>,
2794 // CHECK-SAME: %[[SRC:.*]]: tensor<7x?xi32>,
2795 // CHECK-SAME: %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
2796 // CHECK: %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32)
2799 // CHECK-SAME: : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
2800 // CHECK: return %[[PACK]] : tensor<1x1x8x1xi32>
2802 %dest: tensor<1x1x8x1xi32>,
2803 %src: tensor<7x?xi32>,
2804 %pad: i32) -> tensor<1x1x8x1xi32> {
2806 %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
2808 %pack = tensor.pack %src padding_value(%pad : i32)
2811 into %cast {test_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
2812 %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
2813 return %res : tensor<1x1x8x1xi32>
2819 // CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>,
2820 // CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
2821 // CHECK: %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {test_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
2822 // CHECK: return %[[RES]] : tensor<7x?xi32>
2824 %src: tensor<1x1x8x1xi32>,
2825 %res: tensor<7x?xi32>) -> tensor<7x?xi32> {
2827 %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
2829 %unpack = tensor.unpack %cast
2832 into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
2833 return %unpack : tensor<7x?xi32>
2839 // CHECK: tensor.pack {{.*}} {test_attr}
2840 func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
2843 %pack = tensor.pack %arg0 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %arg1 {test_attr} : tensor<?x?x?xf16> -> tensor<128x?x100x16x1xf16>
2844 return %pack : tensor<128x?x100x16x1xf16>
2849 func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
2850 -> tensor<10x1x10xf32> {
2853 %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
2854 %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
2855 : tensor<?x?xf32> into tensor<?x?x?xf32>
2856 %2 = tensor.cast %1 : tensor<?x?x?xf32> to tensor<10x1x10xf32>
2857 return %2 : tensor<10x1x10xf32>
2860 // CHECK: %[[RES:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
2865 func.func @sink_expand_of_cast(%arg0 : tensor<?x10xf32>)
2866 -> tensor<?x?x?xf32> {
2869 %0 = tensor.cast %arg0 : tensor<?x10xf32> to tensor<?x?xf32>
2870 %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
2871 : tensor<?x?xf32> into tensor<?x?x?xf32>
2872 return %1 : tensor<?x?x?xf32>
2877 // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
2879 // CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
2884 func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index, %arg2 : index)
2885 -> tensor<?x?x?xf32> {
2887 %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
2888 %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %c10]
2889 : tensor<?x?xf32> into tensor<?x?x?xf32>
2890 return %1 : tensor<?x?x?xf32>
2893 // CHECK: %[[CAST:.+]] = tensor.cast
2894 // CHECK-SAME: tensor<10x10xf32> to tensor<?x10xf32>
2895 // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
2897 // CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
2898 // CHECK-SAME: tensor<?x?x10xf32> to tensor<?x?x?xf32>