xref: /llvm-project/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir (revision 92d38adb83f4e4e8257d092adeffba9132aa4830)
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-swap-subtensor-padtensor -canonicalize -split-input-file | FileCheck %s

// The window (rows 1-2, column 2) lies entirely inside the 4x5 source, so no
// padding survives: the slice reads straight from the source tensor.
// CHECK-LABEL: @static_data_only(
//  CHECK-SAME:     %[[ARG0:.*]]: tensor<4x5xf32>
//   CHECK-NOT:   tensor.pad
//       CHECK:   %[[RESULT:.*]] = tensor.extract_slice %[[ARG0]][1, 2] [2, 1] [1, 1] : tensor<4x5xf32> to tensor<2x1xf32>
//       CHECK:   return %[[RESULT]]
func.func @static_data_only(%arg0 : tensor<4x5xf32>, %pad : f32)
    -> tensor<2x1xf32> {
  %0 = tensor.pad %arg0 low[0, 0] high[7, 8] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %pad : f32
    } : tensor<4x5xf32> to tensor<11x13xf32>
  %1 = tensor.extract_slice %0[1, 2] [2, 1] [1, 1] : tensor<11x13xf32> to tensor<2x1xf32>
  return %1 : tensor<2x1xf32>
}

// -----

// The window (rows 4-5, columns 5-8) starts past the end of the 4x5 source in
// both dimensions, so it consists purely of the padding value.
// CHECK-LABEL: @static_high_pad_only
//  CHECK-SAME:   %[[SRC:.*]]: tensor<4x5xf32>, %[[PADVAL:.*]]: f32
//   CHECK-NOT:   tensor.pad
//   CHECK-NOT:   tensor.extract_slice
//       CHECK:   %[[GENERATED:.*]] = tensor.generate
//       CHECK:     tensor.yield %[[PADVAL]]
//       CHECK:   return %[[GENERATED]] : tensor<2x4xf32>
func.func @static_high_pad_only(%arg0 : tensor<4x5xf32>, %pad : f32)
    -> tensor<2x4xf32> {
  %padded = tensor.pad %arg0 low[0, 0] high[7, 8] {
    ^bb0(%i: index, %j: index):
      tensor.yield %pad : f32
    } : tensor<4x5xf32> to tensor<11x13xf32>
  %slice = tensor.extract_slice %padded[4, 5] [2, 4] [1, 1] : tensor<11x13xf32> to tensor<2x4xf32>
  return %slice : tensor<2x4xf32>
}

// -----

// Low padding is 3 rows / 7 columns. The window (rows 1-2, columns 3-5) lies
// entirely inside it, so only the padding value is materialized.
// The label is anchored with "(" so it cannot also match
// @static_low_pad_only_2.
// CHECK-LABEL: @static_low_pad_only(
//  CHECK-SAME:   %[[ARG0:.*]]: tensor<4x5xf32>, %[[PAD:.*]]: f32
//   CHECK-NOT:   tensor.pad
//   CHECK-NOT:   tensor.extract_slice
//       CHECK:   %[[RESULT:.*]] = tensor.generate
//       CHECK:     tensor.yield %[[PAD]]
//       CHECK:   return %[[RESULT]] : tensor<2x3xf32>
func.func @static_low_pad_only(%arg0 : tensor<4x5xf32>, %pad : f32)
    -> tensor<2x3xf32> {
  %0 = tensor.pad %arg0 low[3, 7] high[7, 8] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %pad : f32
    } : tensor<4x5xf32> to tensor<14x20xf32>
  %1 = tensor.extract_slice %0[1, 3] [2, 3] [1, 1] : tensor<14x20xf32> to tensor<2x3xf32>
  return %1 : tensor<2x3xf32>
}

// -----

// Low padding is 3 rows / 7 columns. The single-row window (row 1, columns 3-5)
// lies entirely inside it, so only the padding value is materialized.
// CHECK-LABEL: @static_low_pad_only_2
//  CHECK-SAME:   %[[SRC:.*]]: tensor<4x5xf32>, %[[PADVAL:.*]]: f32
//   CHECK-NOT:   tensor.pad
//   CHECK-NOT:   tensor.extract_slice
//       CHECK:   %[[GENERATED:.*]] = tensor.generate
//       CHECK:     tensor.yield %[[PADVAL]]
//       CHECK:   return %[[GENERATED]] : tensor<1x3xf32>
func.func @static_low_pad_only_2(%arg0 : tensor<4x5xf32>, %pad : f32)
    -> tensor<1x3xf32> {
  %padded = tensor.pad %arg0 low[3, 7] high[7, 8] {
    ^bb0(%i: index, %j: index):
      tensor.yield %pad : f32
    } : tensor<4x5xf32> to tensor<14x20xf32>
  %slice = tensor.extract_slice %padded[1, 3] [1, 3] [1, 1] : tensor<14x20xf32> to tensor<1x3xf32>
  return %slice : tensor<1x3xf32>
}

// -----

// Window rows 2-4 / columns 4-7 of the padded tensor. Source rows 2-3 and
// column 4 supply data. The remaining row and three columns come from the
// high padding, so the result is a small source slice re-padded on the high side.
// CHECK-LABEL: @static_mixed_data_high_pad
//  CHECK-SAME:   %[[SRC:.*]]: tensor<4x5xf32>, %[[PADVAL:.*]]: f32
//   CHECK-NOT:   tensor.pad
//       CHECK:   %[[SLICE:.*]] = tensor.extract_slice %[[SRC]][2, 4] [2, 1] [1, 1] : tensor<4x5xf32> to tensor<2x1xf32>
//       CHECK:   %[[PADDED:.*]] = tensor.pad %[[SLICE]] low[0, 0] high[1, 3]
//       CHECK:     tensor.yield %[[PADVAL]]
//       CHECK:   return %[[PADDED]] : tensor<3x4xf32>
func.func @static_mixed_data_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
    -> tensor<3x4xf32> {
  %padded = tensor.pad %arg0 low[0, 0] high[7, 8] {
    ^bb0(%i: index, %j: index):
      tensor.yield %pad : f32
    } : tensor<4x5xf32> to tensor<11x13xf32>
  %slice = tensor.extract_slice %padded[2, 4] [3, 4] [1, 1] : tensor<11x13xf32> to tensor<3x4xf32>
  return %slice : tensor<3x4xf32>
}

// -----

// Low padding is 3 rows / 7 columns. Window rows 2-4 / columns 4-7 hit one
// row and three columns of low padding followed by source rows 0-1 and
// column 0, so the result re-pads a corner slice of the source on the low side.
// CHECK-LABEL: @static_mixed_data_low_pad
//  CHECK-SAME:   %[[SRC:.*]]: tensor<4x5xf32>, %[[PADVAL:.*]]: f32
//   CHECK-NOT:   tensor.pad
//       CHECK:   %[[SLICE:.*]] = tensor.extract_slice %[[SRC]][0, 0] [2, 1] [1, 1] : tensor<4x5xf32> to tensor<2x1xf32>
//       CHECK:   %[[PADDED:.*]] = tensor.pad %[[SLICE]] low[1, 3] high[0, 0]
//       CHECK:     tensor.yield %[[PADVAL]]
//       CHECK:   return %[[PADDED]] : tensor<3x4xf32>
func.func @static_mixed_data_low_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
    -> tensor<3x4xf32> {
  %padded = tensor.pad %arg0 low[3, 7] high[7, 8] {
    ^bb0(%i: index, %j: index):
      tensor.yield %pad : f32
    } : tensor<4x5xf32> to tensor<14x20xf32>
  %slice = tensor.extract_slice %padded[2, 4] [3, 4] [1, 1] : tensor<14x20xf32> to tensor<3x4xf32>
  return %slice : tensor<3x4xf32>
}

// -----

// Padding is low[2, 3] / high[7, 8]. Window rows 1-7 / columns 2-10 cover the
// whole 4x5 source plus one row/column of low padding and two rows / three
// columns of high padding, so the new pad wraps the source directly.
// CHECK-LABEL: @static_mixed_data_low_high_pad
//  CHECK-SAME:   %[[SRC:.*]]: tensor<4x5xf32>, %[[PADVAL:.*]]: f32
//   CHECK-NOT:   tensor.pad
//       CHECK:   %[[PADDED:.*]] = tensor.pad %[[SRC]] low[1, 1] high[2, 3]
//       CHECK:     tensor.yield %[[PADVAL]]
//       CHECK:   return %[[PADDED]] : tensor<7x9xf32>
func.func @static_mixed_data_low_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
    -> tensor<7x9xf32> {
  %padded = tensor.pad %arg0 low[2, 3] high[7, 8] {
    ^bb0(%i: index, %j: index):
      tensor.yield %pad : f32
    } : tensor<4x5xf32> to tensor<13x16xf32>
  %slice = tensor.extract_slice %padded[1, 2] [7, 9] [1, 1] : tensor<13x16xf32> to tensor<7x9xf32>
  return %slice : tensor<7x9xf32>
}

// -----

// Dimension 0 of the source and its high padding are both dynamic, so the
// rewrite emits an scf.if. One branch yields a tensor.generate of the padding
// value (presumably taken when the window is all padding). The other branch
// slices the source and re-pads it.
// CHECK-LABEL: @dynamic_high_pad
//  CHECK-SAME:     %[[SRC:.*]]: tensor<?x5xf32>
//   CHECK-NOT:   tensor.pad
//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : index
//       CHECK:   tensor.dim %[[SRC]], %[[ZERO]]
//       CHECK:   %[[IFRES:.*]] = scf.if %{{.*}} -> (tensor<3x4xf32>) {
//       CHECK:     %[[FILL:.*]] = tensor.generate
//       CHECK:     scf.yield %[[FILL]]
//       CHECK:   } else {
//       CHECK:     %[[SLICE:.*]] = tensor.extract_slice %[[SRC]][%{{.*}}, 4] [%{{.*}}, 1] [1, 1] : tensor<?x5xf32> to tensor<?x1xf32>
//       CHECK:     %[[REPAD:.*]] = tensor.pad %[[SLICE]] low[0, 0] high[%{{.*}}, 3]
//       CHECK:     scf.yield %[[REPAD]]
//       CHECK:   }
//       CHECK:   return %[[IFRES]]
func.func @dynamic_high_pad(%arg0 : tensor<?x5xf32>, %h1: index, %pad : f32) -> tensor<3x4xf32> {
  %padded = tensor.pad %arg0 low[0, 0] high[%h1, 8] {
    ^bb0(%i: index, %j: index):
      tensor.yield %pad : f32
    } : tensor<?x5xf32> to tensor<?x13xf32>
  %slice = tensor.extract_slice %padded[2, 4] [3, 4] [1, 1] : tensor<?x13xf32> to tensor<3x4xf32>
  return %slice : tensor<3x4xf32>
}

// -----

// The extracted size along dimension 0 is the dynamic %s1. The generate branch
// therefore receives it as its dynamic extent. The other branch slices the
// dynamically sized source and re-pads it.
// CHECK-LABEL: @dynamic_extract_size
//  CHECK-SAME:     %[[SRC:.*]]: tensor<?x5xf32>, %[[SIZE:.*]]: index
//   CHECK-NOT:   tensor.pad
//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : index
//       CHECK:   tensor.dim %[[SRC]], %[[ZERO]]
//       CHECK:   %[[IFRES:.*]] = scf.if %{{.*}} -> (tensor<?x4xf32>) {
//       CHECK:     %[[FILL:.*]] = tensor.generate %[[SIZE]]
//       CHECK:     scf.yield %[[FILL]]
//       CHECK:   } else {
//       CHECK:     %[[SLICE:.*]] = tensor.extract_slice %[[SRC]][%{{.*}}, 4] [%{{.*}}, 1] [1, 1] : tensor<?x5xf32> to tensor<?x1xf32>
//       CHECK:     %[[REPAD:.*]] = tensor.pad %[[SLICE]] low[0, 0] high[%{{.*}}, 3]
//       CHECK:     scf.yield %[[REPAD]]
//       CHECK:   }
//       CHECK:   return %[[IFRES]]
func.func @dynamic_extract_size(%arg0 : tensor<?x5xf32>, %s1: index, %pad : f32) -> tensor<?x4xf32> {
  %padded = tensor.pad %arg0 low[0, 0] high[7, 8] {
    ^bb0(%i: index, %j: index):
      tensor.yield %pad : f32
    } : tensor<?x5xf32> to tensor<?x13xf32>
  %slice = tensor.extract_slice %padded[2, 4] [%s1, 4] [1, 1] : tensor<?x13xf32> to tensor<?x4xf32>
  return %slice : tensor<?x4xf32>
}

// -----

// Offsets and sizes are fully dynamic, and the original pad has zero low
// padding. The re-created pad in the else branch therefore keeps low[0, 0].
// CHECK-LABEL: @dynamic_zero_low_padding
//       CHECK:   scf.if
//       CHECK:     tensor.generate
//       CHECK:   else
//       CHECK:     %[[SUB:.*]] = tensor.extract_slice
//       CHECK:     tensor.pad %[[SUB]] low[0, 0]
func.func @dynamic_zero_low_padding(%arg0 : tensor<?x?xf32>, %pad : f32,
                               %o1 : index, %o2 : index,
                               %s1 : index, %s2 : index)
    -> tensor<?x?xf32> {
  %padded = tensor.pad %arg0 low[0, 0] high[7, 8] {
    ^bb0(%i: index, %j: index):
      tensor.yield %pad : f32
    } : tensor<?x?xf32> to tensor<?x?xf32>
  %slice = tensor.extract_slice %padded[%o1, %o2] [%s1, %s2] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  return %slice : tensor<?x?xf32>
}

// -----

// Offsets and sizes are fully dynamic, and the original pad has zero high
// padding. The re-created pad in the else branch therefore keeps high[0, 0],
// while its low amounts are computed at runtime.
// CHECK-LABEL: @dynamic_zero_high_padding
//       CHECK:   scf.if
//       CHECK:     tensor.generate
//       CHECK:   else
//       CHECK:     %[[SUB:.*]] = tensor.extract_slice
//       CHECK:     tensor.pad %[[SUB]] low[%{{.*}}, %{{.*}}] high[0, 0]
func.func @dynamic_zero_high_padding(%arg0 : tensor<?x?xf32>, %pad : f32,
                                %o1 : index, %o2 : index,
                                %s1 : index, %s2 : index)
    -> tensor<?x?xf32> {
  %padded = tensor.pad %arg0 low[7, 8] high[0, 0] {
    ^bb0(%i: index, %j: index):
      tensor.yield %pad : f32
    } : tensor<?x?xf32> to tensor<?x?xf32>
  %slice = tensor.extract_slice %padded[%o1, %o2] [%s1, %s2] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  return %slice : tensor<?x?xf32>
}
219