xref: /llvm-project/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir (revision 3a8f161a3401edeb58e018e2d389dd2413a6417f)
1// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-consecutive-insert-extract-slice -canonicalize -mlir-print-local-scope %s | FileCheck %s
2
3func.func @extract_slice_same_rank(
4    %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x16x32x?xf32> {
5  %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
6  %1 = tensor.extract_slice %0[7, 8, 9, %offset1] [8, 16, 32, %size1] [1, 1, 1, 1] : tensor<128x128x128x?xf32> to tensor<8x16x32x?xf32>
7  return %1: tensor<8x16x32x?xf32>
8}
9
10// CHECK-LABEL: func.func @extract_slice_same_rank
11//  CHECK-SAME: (%[[SOURCE:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
12//       CHECK:   %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET1]], %[[OFFSET0]]]
13//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][7, 9, 11, %[[OFFSET]]] [8, 16, 32, %[[SIZE1]]] [1, 1, 1, 1]
14//       CHECK:   return %[[EXTRACT]] : tensor<8x16x32x?xf32>
15
16// -----
17
18func.func @extract_slice_rank_reducing_consumer(
19    %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<16x?xf32> {
20  %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
21  %1 = tensor.extract_slice %0[7, 8, 9, %offset1] [1, 16, 1, %size1] [1, 1, 1, 1] : tensor<128x128x128x?xf32> to tensor<16x?xf32>
22  return %1: tensor<16x?xf32>
23}
24
25// CHECK-LABEL: func.func @extract_slice_rank_reducing_consumer
26//       CHECK:   tensor.extract_slice %{{.+}}[7, 9, 11, %{{.+}}] [1, 16, 1, %{{.+}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<16x?xf32>
27
28// -----
29
30func.func @extract_slice_rank_reducing_producer(
31    %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x?xf32> {
32  %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [1, 128, 1, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x?xf32>
33  %1 = tensor.extract_slice %0[7, %offset1] [8, %size1] [1, 1] : tensor<128x?xf32> to tensor<8x?xf32>
34  return %1: tensor<8x?xf32>
35}
36
37// CHECK-LABEL: func.func @extract_slice_rank_reducing_producer
38//  CHECK-SAME: (%[[SRC:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
39//       CHECK:   %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET1]], %[[OFFSET0]]]
40//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[SRC]][0, 8, 2, %[[OFFSET]]] [1, 8, 1, %[[SIZE1]]] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<8x?xf32>
41//       CHECK:   return %[[EXTRACT]] : tensor<8x?xf32>
42
43// -----
44
45func.func @extract_slice_non_one_stride(
46    %src: tensor<?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index, %stride0: index, %stride1: index) -> tensor<?xf32> {
47  %0 = tensor.extract_slice %src[%offset0] [%size0] [%stride0] : tensor<?xf32> to tensor<?xf32>
48  %1 = tensor.extract_slice %0[%offset1] [%size1] [%stride1] : tensor<?xf32> to tensor<?xf32>
49  return %1: tensor<?xf32>
50}
51
52// CHECK-LABEL: func.func @extract_slice_non_one_stride
53//  CHECK-SAME: (%[[SRC:.+]]: tensor<?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index, %[[STRIDE0:.+]]: index, %[[STRIDE1:.+]]: index)
54//       CHECK:   %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>()[%[[OFFSET1]], %[[STRIDE0]], %[[OFFSET0]]]
55//       CHECK:   %[[STRIDE:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%[[STRIDE1]], %[[STRIDE0]]]
56//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[SRC]][%[[OFFSET]]] [%[[SIZE1]]] [%[[STRIDE]]] : tensor<?xf32> to tensor<?xf32>
57//       CHECK:   return %[[EXTRACT]] : tensor<?xf32>
58
59// -----
60
61func.func @insert_slice_rank_reducing(
62    %dst: tensor<128x128x128x128xf32>, %mid: tensor<1x16x1xf32>, %src: tensor<16xf32>, %offset: index) -> tensor<128x128x128x128xf32> {
63  %0 = tensor.insert_slice %src into %mid[0, 0, 0] [1, 16, 1] [1, 1, 1] : tensor<16xf32> into tensor<1x16x1xf32>
64  %1 = tensor.insert_slice %0 into %dst[6, 7, 8, %offset] [1, 1, 16, 1] [1, 1, 1, 1] : tensor<1x16x1xf32> into tensor<128x128x128x128xf32>
65  return %1: tensor<128x128x128x128xf32>
66}
67
68// CHECK-LABEL: func.func @insert_slice_rank_reducing
69//  CHECK-SAME: (%[[DST:.+]]: tensor<128x128x128x128xf32>, %{{.+}}: tensor<1x16x1xf32>, %[[SRC:.+]]: tensor<16xf32>, %[[IDX:.+]]: index)
70//       CHECK:  %[[INSERT:.+]] = tensor.insert_slice %[[SRC]] into %[[DST]][6, 7, 8, %[[IDX]]] [1, 1, 16, 1] [1, 1, 1, 1]
71//       CHECK:  return %[[INSERT]]
72
73// -----
74
75func.func @insert_slice_rank_reducing_dynamic_shape(
76    %dst: tensor<128x128x128x128xf32>, %mid: tensor<1x?x1xf32>, %src: tensor<?xf32>, %offset: index, %size: index) -> tensor<128x128x128x128xf32> {
77  %0 = tensor.insert_slice %src into %mid[0, 0, 0] [1, %size, 1] [1, 1, 1] : tensor<?xf32> into tensor<1x?x1xf32>
78  %1 = tensor.insert_slice %0 into %dst[6, 7, 8, %offset] [1, 1, %size, 1] [1, 1, 1, 1] : tensor<1x?x1xf32> into tensor<128x128x128x128xf32>
79  return %1: tensor<128x128x128x128xf32>
80}
81
82//   CHECK-LABEL: func.func @insert_slice_rank_reducing_dynamic_shape
83// CHECK-COUNT-2:   tensor.insert_slice
84
85// -----
86
87// CHECK-LABEL: func.func @parallel_insert_slice
88//   CHECK-NOT:   tensor.insert_slice
89//       CHECK:   tensor.parallel_insert_slice %{{.*}} into %{{.*}}[0, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<1x2xf32>
90func.func @parallel_insert_slice(%t0: tensor<1x2xf32>, %t1: tensor<f32>, %t2: tensor<1x1xf32>) -> tensor<1x2xf32> {
91  %c1 = arith.constant 1 : index
92  %c2 = arith.constant 2 : index
93  %r = scf.forall (%arg2, %arg3) in (%c1, %c2) shared_outs(%arg4 = %t0) -> (tensor<1x2xf32>) {
94    %inserted_slice = tensor.insert_slice %t1 into %t2[0, 0] [1, 1] [1, 1] : tensor<f32> into tensor<1x1xf32>
95    scf.forall.in_parallel {
96      tensor.parallel_insert_slice %inserted_slice into %arg4[%arg2, %arg3] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<1x2xf32>
97    }
98  }
99  return %r : tensor<1x2xf32>
100}
101