// Source: llvm-project/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir (revision f566b079f171f28366a66b8afa4a975bc4005529)
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-drop-redundant-insert-slice-rank-expansion %s | FileCheck %s

// An insert_slice that only adds leading unit dims, followed by an
// extract_slice that drops those unit dims again and reads only from the
// inserted region, folds into a single extract_slice of the original source.
func.func @test_drop_rank_expansion(%src: tensor<128x480xf32>, %dest: tensor<1x1x128x480xf32>) -> tensor<123x456xf32> {
  %inserted_slice = tensor.insert_slice %src into %dest[0, 0, 0, 0] [1, 1, 128, 480] [1, 1, 1, 1] : tensor<128x480xf32> into tensor<1x1x128x480xf32>
  %extracted_slice = tensor.extract_slice %inserted_slice[0, 0, 0, 0] [1, 1, 123, 456] [1, 1, 1, 1] : tensor<1x1x128x480xf32> to tensor<123x456xf32>
  return %extracted_slice : tensor<123x456xf32>
}
// CHECK-LABEL: func @test_drop_rank_expansion(
//  CHECK-SAME:     %[[SRC:.*]]: tensor<128x480xf32>,
//       CHECK:   %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0] [123, 456] [1, 1] : tensor<128x480xf32> to tensor<123x456xf32>
//       CHECK:   return %[[EXTRACT]]

// -----

// A rank-reducing extract_slice feeding an insert_slice that only re-adds a
// unit dim and overwrites the entire destination (zero offsets, full sizes,
// unit strides) folds into one extract_slice producing the destination type.
func.func @fold_casting_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<8x1x8xf32>) -> tensor<8x1x8xf32> {
  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [8, 1, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<8x1x8xf32>
  return %inserted_slice : tensor<8x1x8xf32>
}
// CHECK-LABEL: func.func @fold_casting_insert_slice_of_extract_slice(
// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x2x8xf32>
// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1]
// CHECK-SAME:      : tensor<?x8x2x8xf32> to tensor<8x1x8xf32>
// CHECK:         return %[[EXTRACTED_SLICE]] : tensor<8x1x8xf32>

// -----

// Same fold as above, but the producing extract_slice is strided; the folded
// extract_slice must keep the original strides.
func.func @fold_casting_insert_slice_of_strided_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x8xf32>) -> tensor<1x4x8xf32> {
  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1] : tensor<?x8x2x8xf32> to tensor<4x8xf32>
  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 4, 8] [1, 1, 1] : tensor<4x8xf32> into tensor<1x4x8xf32>
  return %inserted_slice : tensor<1x4x8xf32>
}
// CHECK-LABEL: func.func @fold_casting_insert_slice_of_strided_extract_slice(
// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x2x8xf32>
// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1]
// CHECK-SAME:      : tensor<?x8x2x8xf32> to tensor<1x4x8xf32>
// CHECK:         return %[[EXTRACTED_SLICE]] : tensor<1x4x8xf32>

// -----

// Negative test: the insert_slice adds two unit dims while the extract_slice
// dropped only one, so the rank-4 result cannot be produced by a single
// extract_slice of the rank-3 source. Both ops must remain.
func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(%in : tensor<?x8x8xf32>, %dest : tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32> {
  %extracted_slice = tensor.extract_slice %in[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<?x8x8xf32> to tensor<8x8xf32>
  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0, 0] [1, 1, 8, 8] [1, 1, 1, 1] : tensor<8x8xf32> into tensor<1x1x8x8xf32>
  return %inserted_slice : tensor<1x1x8x8xf32>
}
// CHECK-LABEL: func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(
// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x8xf32>
// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
// CHECK:         return %[[INSERTED_SLICE]] : tensor<1x1x8x8xf32>

// -----

// Negative test: the insert_slice is strided, so it is not a pure
// rank-expanding cast of its source and must not be folded away.
// The destination is 1x16x16 so that 8 elements written at stride 2 from
// offset 0 (last index 14) stay in bounds; a 4-wide destination would make
// the insert_slice out of bounds.
func.func @no_fold_strided_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x16x16xf32>) -> tensor<1x16x16xf32> {
  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 2, 2] : tensor<8x8xf32> into tensor<1x16x16xf32>
  return %inserted_slice : tensor<1x16x16xf32>
}
// CHECK-LABEL: func.func @no_fold_strided_insert_slice_of_extract_slice(
// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x2x8xf32>
// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
// CHECK:         return %[[INSERTED_SLICE]] : tensor<1x16x16xf32>

// -----

// Negative test: the destination's leading dim is 2, so the insert_slice
// (leading size 1) writes only part of it rather than acting as a
// rank-expanding cast. Both ops must remain.
func.func @no_fold_non_casting_insert_slice_of_extract_slice(%in : tensor<1x1x1x8x8xf32>, %dest : tensor<2x8x8xf32>) -> tensor<2x8x8xf32> {
  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0, 0] [1, 1, 1, 8, 8] [1, 1, 1, 1, 1] : tensor<1x1x1x8x8xf32> to tensor<8x8xf32>
  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<2x8x8xf32>
  return %inserted_slice : tensor<2x8x8xf32>
}
// CHECK-LABEL: func.func @no_fold_non_casting_insert_slice_of_extract_slice(
// CHECK-SAME:      %[[ARG0:.*]]: tensor<1x1x1x8x8xf32>
// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
// CHECK:         return %[[INSERTED_SLICE]] : tensor<2x8x8xf32>
