// Source: /llvm-project/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir (revision df40b056f1e956a25b8121174d0b42bf1b5c7732)
// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s

///----------------------------------------------------------------------------------------
/// [Pattern: PadOpVectorizationWithTransferReadPattern]
///----------------------------------------------------------------------------------------

// The pattern folds the tensor.pad into the vector.transfer_read: the read is
// rewritten to load directly from the un-padded source, with the pad value
// (%c5) replacing the original read padding (%c6).

// CHECK-LABEL: func @pad_and_transfer_read
//  CHECK-SAME:     %[[ARG0:.*]]: tensor<5x6xf32>
//   CHECK-NOT:   tensor.pad
//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
//   CHECK-DAG:   %[[C5:.*]] = arith.constant 5.0
//       CHECK:   %[[RESULT:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
//       CHECK:   return %[[RESULT]]
func.func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5.0 : f32
  %c6 = arith.constant 6.0 : f32
  %0 = tensor.pad %arg0 low[0, 0] high[5, 7] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %c5 : f32
  } : tensor<5x6xf32> to tensor<10x13xf32>
  %1 = vector.transfer_read %0[%c0, %c0], %c6
      : tensor<10x13xf32>, vector<7x9xf32>
  return %1 : vector<7x9xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">

    transform.apply_patterns to %func_op {
      transform.apply_patterns.linalg.pad_vectorization
    } : !transform.op<"func.func">
    transform.yield
  }
}

// -----
38
///----------------------------------------------------------------------------------------
/// [Pattern: PadOpVectorizationWithTransferWritePattern]
///----------------------------------------------------------------------------------------
func.func private @make_vector() -> vector<7x9xf32>

// The pattern folds the tensor.pad into the vector.transfer_write: the write
// targets the un-padded source directly and the trailing extract_slice that
// trimmed the padding away becomes redundant.

// CHECK-LABEL: func @pad_and_transfer_write_static_low_and_high
//  CHECK-SAME:     %[[ARG0:.*]]: tensor<5x6xf32>
//   CHECK-NOT:   tensor.pad
//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
//       CHECK:   %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32>
//       CHECK:   %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[ARG0]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<5x6xf32>
//       CHECK:   return %[[RESULT]]
func.func @pad_and_transfer_write_static_low_and_high(
    %arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5.0 : f32
  %0 = tensor.pad %arg0 low[0, 0] high[5, 7] {
    ^bb0(%arg2: index, %arg3: index):
      tensor.yield %c5 : f32
  } : tensor<5x6xf32> to tensor<10x13xf32>
  %1 = call @make_vector() : () -> vector<7x9xf32>
  %2 = vector.transfer_write %1, %0[%c0, %c0]
      : vector<7x9xf32>, tensor<10x13xf32>
  %3 = tensor.extract_slice %2[0, 0] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
  return %3 : tensor<5x6xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">

    transform.apply_patterns to %func_op {
      transform.apply_patterns.linalg.pad_vectorization
    } : !transform.op<"func.func">
    transform.yield
  }
}

// -----
78
func.func private @make_vector() -> vector<7x9xf32>

// Same as above, but with a dynamic high pad — the pattern still applies and
// the transfer_write is redirected to the (dynamically sized) un-padded slice.

// CHECK-LABEL: func @pad_and_transfer_write_static_low_dynamic_high
//  CHECK-SAME:     %[[ARG0:.*]]: tensor<?x?xf32>, %[[SIZE:.*]]: index, %[[PADDING:.*]]: index
//   CHECK-NOT:   tensor.pad
//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
//       CHECK:   %[[SUB:.*]] = tensor.extract_slice %[[ARG0]][0, 0] [%[[SIZE]], 6] [1, 1] : tensor<?x?xf32> to tensor<?x6xf32>
//       CHECK:   %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32>
//       CHECK:   %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[SUB]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<?x6xf32>
//       CHECK:   return %[[RESULT]]
func.func @pad_and_transfer_write_static_low_dynamic_high(
    %arg0: tensor<?x?xf32>, %size: index, %padding: index) -> tensor<?x6xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5.0 : f32
  %s = tensor.extract_slice %arg0[0, 0] [%size, 6] [1, 1]
      : tensor<?x?xf32> to tensor<?x6xf32>
  %0 = tensor.pad %s low[0, 0] high[%padding, 7] {
    ^bb0(%arg2: index, %arg3: index):
      tensor.yield %c5 : f32
  } : tensor<?x6xf32> to tensor<?x13xf32>
  %1 = call @make_vector() : () -> vector<7x9xf32>
  %2 = vector.transfer_write %1, %0[%c0, %c0]
      : vector<7x9xf32>, tensor<?x13xf32>
  %3 = tensor.extract_slice %2[0, 0] [%size, 6] [1, 1] : tensor<?x13xf32> to tensor<?x6xf32>
  return %3 : tensor<?x6xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">

    transform.apply_patterns to %func_op {
      transform.apply_patterns.linalg.pad_vectorization
    } : !transform.op<"func.func">
    transform.yield
  }
}

// -----
118
func.func private @make_vector() -> vector<7x9xf32>

// Negative test - low pad is non-zero, so the pattern must not fire and the
// tensor.pad must survive.

// CHECK-LABEL: func @pad_and_transfer_write_static_non_zero_low_pad
//   CHECK:   tensor.pad
func.func @pad_and_transfer_write_static_non_zero_low_pad(
    %arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5.0 : f32
  %0 = tensor.pad %arg0 low[0, 1] high[5, 6] {
    ^bb0(%arg2: index, %arg3: index):
      tensor.yield %c5 : f32
  } : tensor<5x6xf32> to tensor<10x13xf32>
  %1 = call @make_vector() : () -> vector<7x9xf32>
  %2 = vector.transfer_write %1, %0[%c0, %c0]
      : vector<7x9xf32>, tensor<10x13xf32>
  %3 = tensor.extract_slice %2[0, 0] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
  return %3 : tensor<5x6xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">

    transform.apply_patterns to %func_op {
      transform.apply_patterns.linalg.pad_vectorization
    } : !transform.op<"func.func">
    transform.yield
  }
}

// -----
152
// Negative test - TransferWriteOp result is not _directly_ consumed by an
// ExtractSliceOp (note the non-zero offset), so the pattern must not fire and
// the tensor.pad must survive.

func.func private @make_vector() -> vector<7x9xf32>

// CHECK-LABEL: func @pad_and_transfer_write_static_non_zero_offset
//   CHECK:   tensor.pad
func.func @pad_and_transfer_write_static_non_zero_offset(
    %arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5.0 : f32
  %0 = tensor.pad %arg0 low[0, 0] high[5, 7] {
    ^bb0(%arg2: index, %arg3: index):
      tensor.yield %c5 : f32
  } : tensor<5x6xf32> to tensor<10x13xf32>
  %1 = call @make_vector() : () -> vector<7x9xf32>
  %2 = vector.transfer_write %1, %0[%c0, %c0]
      : vector<7x9xf32>, tensor<10x13xf32>
  %3 = tensor.extract_slice %2[0, 1] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
  return %3 : tensor<5x6xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">

    transform.apply_patterns to %func_op {
      transform.apply_patterns.linalg.pad_vectorization
    } : !transform.op<"func.func">
    transform.yield
  }
}

// -----
187
///----------------------------------------------------------------------------------------
/// [Pattern: PadOpVectorizationWithInsertSlicePattern]
///----------------------------------------------------------------------------------------

func.func private @make_vector() -> tensor<12x13xf32>

// The pad + insert_slice pair is rewritten into a transfer_read of the
// un-padded source (with the pad value as read padding) followed by an
// in-bounds transfer_write into the insert_slice destination.

// CHECK-LABEL: func @pad_and_insert_slice_source
//  CHECK-SAME:     %[[ARG0:.*]]: tensor<5x6xf32>
//   CHECK-NOT:   tensor.pad
//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
//   CHECK-DAG:   %[[C5:.*]] = arith.constant 5.0
//       CHECK:   %[[VEC0:.*]] = call @make_vector() : () -> tensor<12x13xf32>
//       CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
//       CHECK:   %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[VEC0]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32>
//       CHECK:   return %[[WRITE]]
func.func @pad_and_insert_slice_source(
    %arg0: tensor<5x6xf32>) -> tensor<12x13xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5.0 : f32
  %0 = tensor.pad %arg0 low[0, 0] high[2, 3] {
    ^bb0(%arg2: index, %arg3: index):
      tensor.yield %c5 : f32
  } : tensor<5x6xf32> to tensor<7x9xf32>
  %1 = call @make_vector() : () -> tensor<12x13xf32>
  %r = tensor.insert_slice %0 into %1[0, 0][7, 9][1, 1] : tensor<7x9xf32> into tensor<12x13xf32>
  return %r : tensor<12x13xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">

    transform.apply_patterns to %func_op {
      transform.apply_patterns.linalg.pad_vectorization
    } : !transform.op<"func.func">
    transform.yield
  }
}


// -----
229
///----------------------------------------------------------------------------------------
/// tensor::PadOp -> tensor::EmptyOp + linalg::FillOp/tensor::GenerateOp + tensor::InsertSliceOp
/// [Pattern: GenericPadOpVectorizationPattern + InsertSliceVectorizePattern]
/// TODO: Split the test into two, one for each pattern.
///----------------------------------------------------------------------------------------

func.func private @make_vector() -> tensor<12x13xf32>

// Same as @pad_and_insert_slice_dest in vectorization-with-patterns.mlir, but
// over here linalg::fill is not vectorized (patterns for linalg.fill are not
// included here)
// CHECK-LABEL:   func.func @pad_and_insert_slice_dest(
// CHECK-SAME:      %[[ARG_0:.*]]: tensor<1x5x6xf32>) -> tensor<1x12x13xf32> {
//  CHECK-NOT:     tensor.pad
//  CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
//  CHECK-DAG:     %[[PAD:.*]] = arith.constant 5.000000e+00 : f32
//  CHECK-DAG:     %[[PAD_READ:.*]] = arith.constant 0.000000e+00 : f32
//      CHECK:     %[[EMPTY:.*]] = tensor.empty() : tensor<1x12x13xf32>
//      CHECK:     %[[FILL:.*]] = linalg.fill ins(%[[PAD]] : f32) outs(%[[EMPTY]] : tensor<1x12x13xf32>) -> tensor<1x12x13xf32>
//      CHECK:     %[[READ_1:.*]] = vector.transfer_read %[[ARG_0]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x5x6xf32>, vector<1x5x6xf32>
//      CHECK:     %[[WRITE_1:.*]] = vector.transfer_write %[[READ_1]], %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x5x6xf32>, tensor<1x12x13xf32>
//      CHECK:     %[[VEC:.*]] = call @make_vector() : () -> tensor<12x13xf32>
//      CHECK:     %[[READ_2:.*]] = vector.transfer_read %[[VEC]]{{\[}}%[[C0]], %[[C0]]], %[[PAD_READ]] {in_bounds = [true, true]} : tensor<12x13xf32>, vector<12x13xf32>
//      CHECK:     %[[RES:.*]] = vector.transfer_write %[[READ_2]], %[[WRITE_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<12x13xf32>, tensor<1x12x13xf32>
//      CHECK:     return %[[RES]] : tensor<1x12x13xf32>

func.func @pad_and_insert_slice_dest(
    %arg0: tensor<1x5x6xf32>) -> tensor<1x12x13xf32> {
  %c5 = arith.constant 5.0 : f32
  %0 = tensor.pad %arg0 low[0, 0, 0] high[0, 7, 7] {
    ^bb0(%arg2: index, %arg3: index, %arg4: index):
      tensor.yield %c5 : f32
  } : tensor<1x5x6xf32> to tensor<1x12x13xf32>
  %1 = call @make_vector() : () -> tensor<12x13xf32>
  %r = tensor.insert_slice %1 into %0[0, 0, 0][1, 12, 13][1, 1, 1] : tensor<12x13xf32> into tensor<1x12x13xf32>
  return %r : tensor<1x12x13xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">

    transform.apply_patterns to %func_op {
      // TODO: Split into two tests, one for each pattern
      transform.apply_patterns.linalg.decompose_pad
      transform.apply_patterns.linalg.pad_vectorization
    } : !transform.op<"func.func">
    transform.yield
  }
}