// RUN: mlir-opt %s --transform-interpreter --split-input-file -verify-diagnostics | FileCheck %s
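
// This test exercises `transform.structured.split`: splitting a structured
// (linalg) op along a chosen dimension at a static or dynamic split point,
// plus the diagnostics emitted when the request is invalid.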

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
    transform.yield
  }
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

// CHECK: #[[$ADD_42_MAP:.+]] = affine_map<()[s0] -> (s0 + 42)>

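// Splitting the static 1-D op at 42 is expected to yield two generics: one on
// the leading tensor<42xf32> slices and one on the trailing tensor<58xf32>
// slices, where the second generic offsets `linalg.index 0` by 42.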
// CHECK-LABEL: @one_d_static
// CHECK-SAME:  %[[IN:.+]]: tensor<100xf32>, %[[OUT:.+]]: tensor<100xf32>
func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [42] [1] : tensor<100xf32> to tensor<42xf32>
  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [42] [1] : tensor<100xf32> to tensor<42xf32>
  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
  // CHECK:   ins(%[[IN_SLICE_LOW]]
  // CHECK:   outs(%[[OUT_SLICE_LOW]]
  // CHECK:   linalg.index 0
  // CHECK:   func.call @elem
  // CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [42] [1]
  //
  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][42] [58] [1] : tensor<100xf32> to tensor<58xf32>
  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][42] [58] [1] : tensor<100xf32> to tensor<58xf32>
  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
  // CHECK:   ins(%[[IN_SLICE_HIGH]]
  // CHECK:   outs(%[[OUT_SLICE_HIGH]]
  // CHECK:   %[[IDX:.+]] = linalg.index 0
  // CHECK:   affine.apply #[[$ADD_42_MAP]]()[%[[IDX]]]
  // CHECK:   func.call @elem
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[RES_PARTIAL]][42] [58] [1]
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%0: f32, %1: f32):
    %i = linalg.index 0 : index
    %call_res = func.call @elem(%0, %i, %i) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<100xf32>

  // CHECK: return %[[RES]]
  return %0 : tensor<100xf32>
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
    transform.yield
  }
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

// CHECK-LABEL: @one_d_static_overflow
// CHECK-SAME:  %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
func.func @one_d_static_overflow(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> {
  // Folding is powerful enough to detect the static overflow (the split point
  // 42 exceeds the dimension size 10) and avoid splitting altogether.
  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
  // CHECK:   ins(%[[IN]]
  // CHECK:   outs(%[[OUT]]
  // CHECK:   linalg.index 0
  // CHECK:   func.call @elem
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<10xf32>) outs(%arg1: tensor<10xf32>) {
  ^bb0(%0: f32, %1: f32):
    %i = linalg.index 0 : index
    %call_res = func.call @elem(%0, %i, %i) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<10xf32>
  return %0 : tensor<10xf32>
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.match ops{["func.call"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.structured.split %0 after %1 { dimension = 0 } : !transform.any_op, !transform.any_op
    transform.yield
  }
}

func.func private @get_size() -> index

// CHECK: #[[$MAP_MIN_100:.+]] = affine_map<()[s0] -> (s0, 100)>
// CHECK: #[[$MAP_S_MINUS_100:.+]] = affine_map<()[s0] -> (-s0 + 100)>

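// The split point is dynamic here: it is the `index` result of the matched
// `func.call @get_size`, so the low part is clamped with `affine.min` against
// the static size 100 and the high part covers the remaining `100 - min` elements.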
// CHECK-LABEL: @dynamic
func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // CHECK: %[[SPLIT:.+]] = call @get_size
  // CHECK: %[[SPLIT_LOW:.+]] = affine.min #[[$MAP_MIN_100]]()[%[[SPLIT]]
  // CHECK: %[[SPLIT_HIGH_1:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
  // CHECK:   ins(%[[IN_SLICE_LOW]]
  // CHECK:   outs(%[[OUT_SLICE_LOW]]
  // CHECK: %[[PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [%[[SPLIT_LOW]]] [1]
  //
  // CHECK: %[[SPLIT_HIGH_2:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
  // CHECK: %[[SPLIT_HIGH_3:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN:.+]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_2]]] [1] : tensor<100xf32> to tensor<?xf32>
  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[PARTIAL:.+]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1] : tensor<100xf32> to tensor<?xf32>
  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
  // CHECK:   ins(%[[IN_SLICE_HIGH]]
  // CHECK:   outs(%[[OUT_SLICE_HIGH]]
  // CHECK: %[[SPLIT_HIGH_4:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
  // CHECK: tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[PARTIAL]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_4]]] [1]
  %0 = func.call @get_size() : () -> index
  %1 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%3: f32, %4: f32):
    %5 = arith.addf %3, %4 : f32
    linalg.yield %5 : f32
  } -> tensor<100xf32>
  return %1 : tensor<100xf32>
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %t = transform.structured.split %0 after 4 { dimension = 0 } : !transform.any_op
    %1:2 = transform.split_handle %t : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %2 = transform.structured.split %1#1 after 16 { dimension = 1 } : !transform.any_op
    transform.yield
  }
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

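// Two-stage split: the 10x34 op is first split after 4 along dimension 0
// (4x34 and 6x34 parts); `transform.split_handle` then isolates the second
// part, which is split again after 16 along dimension 1 (6x16 and 6x18 parts).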
// CHECK-LABEL: @two_d
func.func @two_d(%arg0: tensor<10x34xf32>,
                 %arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
  // Check the overall structure: split along dimension 0, and then split the
  // second half only along dimension 1.
  // CHECK:      %[[IN_1:.+]] = tensor.extract_slice %[[IN:.+]][0, 0]
  // CHECK:      %[[OUT_1:.+]] = tensor.extract_slice %[[OUT:.+]][0, 0]
  // CHECK:      %[[RES_1:.+]] = linalg.generic
  // CHECK-SAME:   ins(%[[IN_1]] : tensor<4x34xf32>)
  // CHECK-SAME:   outs(%[[OUT_1]] : tensor<4x34xf32>)
  // CHECK:      %[[PARTIAL_1:.+]] = tensor.insert_slice %[[RES_1]] into %[[OUT]]
  //
  // CHECK:      %[[IN_2:.+]] = tensor.extract_slice %[[IN]]
  // CHECK:      %[[OUT_2:.+]] = tensor.extract_slice %[[PARTIAL_1]]
  // Note that `extract_slice` taking a slice from another `extract_slice` result
  // is folded to use the operand of the first `extract_slice`.
  // CHECK:      %[[IN_21:.+]] = tensor.extract_slice %[[IN_2]]
  // CHECK:      %[[OUT_21:.+]] = tensor.extract_slice %[[OUT_2]]
  // CHECK:      %[[RES_21:.+]] = linalg.generic
  // CHECK-SAME:   ins(%[[IN_21]] : tensor<6x16xf32>)
  // CHECK-SAME:   outs(%[[OUT_21]] : tensor<6x16xf32>)
  // CHECK:      %[[PARTIAL_21:.+]] = tensor.insert_slice %[[RES_21]] into %[[OUT_2]]
  //
  // CHECK:      %[[IN_22:.+]] = tensor.extract_slice %[[IN_2]]
  // CHECK:      %[[OUT_22:.+]] = tensor.extract_slice %[[PARTIAL_21]]
  // CHECK:      %[[RES_22:.+]] = linalg.generic
  // CHECK-SAME:   ins(%[[IN_22]] : tensor<6x18xf32>)
  // CHECK-SAME:   outs(%[[OUT_22]] : tensor<6x18xf32>)
  // CHECK:      %[[PARTIAL_22:.+]] = tensor.insert_slice %[[RES_22]] into %[[PARTIAL_21]]
  // CHECK:      %[[PARTIAL_2:.+]] = tensor.insert_slice %[[PARTIAL_22]] into %[[PARTIAL_1]]
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i, j) -> (i, j)>,
                     affine_map<(i, j) -> (i, j)>],
    iterator_types = ["parallel", "parallel"]
  }
  ins(%arg0: tensor<10x34xf32>)
  outs(%arg1: tensor<10x34xf32>) {
  ^bb0(%0: f32, %1: f32):
    %i = linalg.index 0 : index
    %j = linalg.index 1 : index
    %call_res = func.call @elem(%0, %i, %j) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<10x34xf32>
  return %0 : tensor<10x34xf32>
}

// -----

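// Negative test: the static chunk size is set to the "no static split point"
// sentinel (i64 minimum) and no dynamic split point operand is supplied, so
// the op is given no usable split point at all.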
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) {
    // expected-error @below {{expects either a dynamic or a static split point to be provided}}
    %0 = "transform.structured.split"(%arg1) { dimension = 1, static_chunk_sizes = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}

// -----

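// Negative test: the dynamic split point handle resolves to an op whose
// single result is i64 rather than index, which the transform rejects.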
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.match ops{["func.call"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{expected dynamic split point handle to point to a single-result index-typed op}}
    transform.structured.split %0 after %1 { dimension = 0 } : !transform.any_op, !transform.any_op
    transform.yield
  }
}

func.func private @get_size() -> i64

func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // expected-note @below {{dynamic split point}}
  %0 = func.call @get_size() : () -> i64
  %1 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%3: f32, %4: f32):
    linalg.yield %3 : f32
  } -> tensor<100xf32>
  return %1 : tensor<100xf32>
}

// -----

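// Negative test: the payload contains no func.call, so the dynamic split point
// handle maps to zero ops while the target handle maps to one.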
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.match ops{["func.call"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{expected the dynamic split point handle to point to as many operations (0) as the target handle (1)}}
    transform.structured.split %0 after %1 { dimension = 0 } : !transform.any_op, !transform.any_op
    transform.yield
  }
}

func.func private @get_size() -> i64

func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  %1 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%3: f32, %4: f32):
    linalg.yield %3 : f32
  } -> tensor<100xf32>
  return %1 : tensor<100xf32>
}

// -----

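// Negative test: the matched target is a func.return, not a structured
// (linalg) op, so the split does not apply.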
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.return"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{only applies to structured ops}}
    transform.structured.split %0 after 16 { dimension = 1 } : !transform.any_op
    transform.yield
  }
}

func.func @noop(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // expected-note @below {{target op}}
  return %arg0 : tensor<100xf32>
}

// -----

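// Negative test: the target generic is 1-D, so splitting along dimension 1 is
// out of range.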
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{dimension 1 does not exist in target op}}
    transform.structured.split %0 after 16 { dimension = 1 } : !transform.any_op
    transform.yield
  }
}

func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // expected-note @below {{target op}}
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%0: f32, %1: f32):
    linalg.yield %0 : f32
  } -> tensor<100xf32>
  return %0 : tensor<100xf32>
}

// -----

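// Negative test: splitting two targets after 142 produces a second part only
// for the 200-element op (the 100-element op has nothing left over), and the
// transform requires all-or-none behavior across its targets.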
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{splitting does not produce the second part for a subset of targets}}
    // expected-note @below {{expected splitting to produce the second part of all or none of the targets}}
    %1 = transform.structured.split %0 after 142 { dimension = 0 } : !transform.any_op
    transform.yield
  }
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

func.func @split_one_but_not_other(
    %arg0: tensor<100xf32>, %arg1: tensor<100xf32>,
    %arg2: tensor<200xf32>, %arg3: tensor<200xf32>)
    -> (tensor<100xf32>, tensor<200xf32>) {
  // expected-note @below {{first target with no second part}}
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%arg4: f32, %arg5: f32):
    %i = linalg.index 0 : index
    %call_res = func.call @elem(%arg4, %i, %i) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<100xf32>

  %1 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg2: tensor<200xf32>) outs(%arg3: tensor<200xf32>) {
  ^bb0(%arg4: f32, %arg5: f32):
    %i = linalg.index 0 : index
    %call_res = func.call @elem(%arg4, %i, %i) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<200xf32>

  return %0, %1 : tensor<100xf32>, tensor<200xf32>
}