// RUN: mlir-opt --transform-interpreter --split-input-file %s -verify-diagnostics | FileCheck %s
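
// This file exercises transform.structured.fuse_into_containing_op: a
// producer that implements the TilingInterface is tiled and sunk into the
// containing scf.forall, while an untileable producer is cloned into the loop
// body instead. The cases at the end of the file exercise the failure path
// under -verify-diagnostics.
//
// For the tileable case the rewrite is, schematically (an illustrative
// sketch, not verbatim output):
//
//   scf.forall (%i) ... {
//     %in_tile = tensor.extract_slice <producer operand> ...
//     %prod    = <tiled producer> outs(%in_tile ...)
//     <consumer uses %prod instead of a slice of the original producer>
//   }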

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
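// The three maps above compute, in order: the number of tiles
// (s0 ceildiv s1), the offset of tile d0 (d0 * s0), and the clamped size of
// that tile, i.e. min(s1, s0 - d0 * s1).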

module {
  // CHECK-LABEL: func.func @fuse_tileable_op
  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?xf32>
  //  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_op(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index
    %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
    %1 = affine.apply #map0()[%d0, %arg0]

    // CHECK: scf.forall {{.*}} {
    %2 = scf.forall (%arg3) in (%1) shared_outs(%o = %arg2) -> (tensor<?xf32>) {
      %3 = affine.apply #map1(%arg3)[%arg0]
      %4 = affine.min #map2(%arg3)[%d0, %arg0]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%{{.*}}] [%{{.*}}] [{{.*}}]
      // CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]]
      %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]
      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: }
    func.return %2 : tensor<?xf32>
  }

  // Check no failure when nothing happens.
  func.func @dummy1() { return }
  func.func @dummy2() { return }
  func.func @dummy3() { return }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.fill">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // linalg.fill is tileable. The op is tiled and fused.
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----
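
// The producer below is a tensor.empty, which does not implement the
// TilingInterface. Rather than tiling it, the transform clones it into the
// scf.forall body and rewires the consumer to the clone.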

#map0 = affine_map<()[s0] -> (64 ceildiv s0)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)>

module {
  // CHECK-LABEL: func.func @fuse_untileable_op
  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<64xf32>
  //  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<64xf32>
  func.func @fuse_untileable_op(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
    %0 = tensor.empty(%arg0) : tensor<?xf32>
    %1 = affine.apply #map0()[%arg0]

    // CHECK: scf.forall {{.*}} {
    %2 = scf.forall (%arg3) in (%1) shared_outs(%o = %arg2) -> (tensor<64xf32>) {
      // CHECK: %[[INIT_TENSOR:.*]] = tensor.empty
      %3 = affine.apply #map1(%arg3)[%arg0]
      %4 = affine.min #map2(%arg3)[%arg0]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<64xf32> to tensor<?xf32>

      // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[INIT_TENSOR]]
      %7 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<64xf32>
      }
    }
    // CHECK: }

    func.return %2 : tensor<64xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["tensor.empty"]} in %arg1 : (!transform.any_op) -> !transform.op<"tensor.empty">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // tensor.empty is not tileable. The op is cloned and fused.
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"tensor.empty">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----
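
// Here fusion must preserve rank-reduction semantics: the tiled fill is
// created on a tensor<1xf32> slice, and a rank-reducing extract_slice brings
// it back down to the tensor<f32> the consumer expects.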

module {
  func.func @foo(%0: tensor<f32>) -> tensor<f32> {
    return %0: tensor<f32>
  }

  // CHECK-LABEL: func.func @fuse_tileable_op_rank_reducing
  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?xf32>
  //  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_op_rank_reducing(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index
    %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>

    // CHECK: scf.forall {{.*}} -> (tensor<?xf32>) {
    %2 = scf.forall (%arg3) in (%d0) shared_outs(%o = %0) -> (tensor<?xf32>) {
      %5 = tensor.extract_slice %o[%arg3] [1] [1] : tensor<?xf32> to tensor<f32>

      // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [1] [1] : tensor<?xf32> to tensor<1xf32>
      // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<1xf32>) -> tensor<1xf32>
      // CHECK: tensor.extract_slice %{{.*}}[0] [1] [1] : tensor<1xf32> to tensor<f32>
      // CHECK: func.call @foo(%{{.*}}) : (tensor<f32>) -> tensor<f32>
      %7 = func.call @foo(%5) : (tensor<f32>) -> tensor<f32>

      scf.forall.in_parallel {
        // CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [1] [1] : tensor<f32> into tensor<?xf32>
        tensor.parallel_insert_slice %7 into %o[%arg3] [1] [1] : tensor<f32> into tensor<?xf32>
      }
    }
    // CHECK: }
    func.return %2 : tensor<?xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.fill">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // linalg.fill is tileable. The op is tiled and fused.
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----
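
// The fused producer is reached through the shared_outs block argument of the
// scf.forall rather than directly: the fill initializes the loop's
// destination, so the tiled fill reads its slice from the corresponding block
// argument.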

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

module {
  // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg
  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?xf32>
  //  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_op_through_bbarg(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index
    %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
    %1 = affine.apply #map0()[%d0, %arg0]

    // CHECK: scf.forall {{.*}} shared_outs(%[[BBARGOUT:.*]] = %[[OUT]]) -> (tensor<?xf32>) {
    %2 = scf.forall (%arg3) in (%1) shared_outs(%o = %0) -> (tensor<?xf32>) {
      %3 = affine.apply #map1(%arg3)[%arg0]
      %4 = affine.min #map2(%arg3)[%d0, %arg0]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T0:.*]] = tensor.extract_slice %[[BBARGOUT]][%{{.*}}] [%{{.*}}] [{{.*}}]
      // CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]]
      %6 = tensor.extract_slice %arg1[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T2:.*]] = linalg.elemwise_unary {{.*}} outs(%[[T1]]
      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: }
    func.return %2 : tensor<?xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op

      // linalg.fill is tileable. The op is tiled and fused.
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----
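
// The producer below has two results but only the first one is used inside
// the loop. Fusion tiles the whole multi-result op and rewires the consumer
// to result #0 of the tiled op.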

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

module {
  // CHECK-LABEL: func.func @fuse_tileable_multi_output_op
  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?xf32>
  //  CHECK-SAME:   %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
  //  CHECK-SAME:   %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
  //  CHECK-SAME:   %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_multi_output_op(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>) -> tensor<?xf32> {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index

    %0:2 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]
    } ins(%in : tensor<?xf32>) outs(%out_1, %out_3 : tensor<?xf32>, tensor<?xf32>) {
      ^bb0(%a: f32, %b: f32, %c: f32):
        %d = arith.addf %a, %b : f32
        %e = arith.addf %d, %c : f32
        linalg.yield %d, %e : f32, f32
    } -> (tensor<?xf32>, tensor<?xf32>)
    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>

    %1 = affine.apply #map0()[%d0, %idx]

    // CHECK: scf.forall {{.*}} {
    %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
      %3 = affine.apply #map1(%i)[%idx]
      %4 = affine.min #map2(%i)[%d0, %idx]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%{{.*}}] [%{{.*}}] [{{.*}}]
      // CHECK: %[[T1:.*]]:2 = linalg.generic {{.*}} ins(%[[T0]]
      %6 = tensor.extract_slice %0#0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]#0
      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: }
    func.return %2 : tensor<?xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // linalg.generic is tileable. The op is tiled and fused.
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----
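
// A handle that maps to the same linalg.fill twice (built with
// transform.merge_handles) must fuse without error; the duplicate payload op
// must not make the transform fail.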

module {
  // CHECK-LABEL: func.func @fuse_repeated
  func.func @fuse_repeated(%fill: tensor<2xf32>, %output: tensor<2xf32>) -> tensor<2xf32> {
    %c0 = arith.constant 0.0 : f32
    %0 = linalg.fill ins(%c0 : f32) outs(%fill : tensor<2xf32>) -> tensor<2xf32>

    // CHECK: scf.forall
    %1 = scf.forall (%i) in (2) shared_outs(%arg1 = %output) -> (tensor<2xf32>) {
      %2 = tensor.extract_slice %0[%i][1][1] : tensor<2xf32> to tensor<1xf32>
      %3 = tensor.extract_slice %arg1[%i][1][1] : tensor<2xf32> to tensor<1xf32>
      // CHECK: %[[FUSED:.+]] = linalg.fill
      // CHECK: elemwise_unary ins(%[[FUSED]]
      %4 = linalg.elemwise_unary ins(%2 : tensor<1xf32>) outs(%3 : tensor<1xf32>) -> tensor<1xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %4 into %arg1[%i][1][1] : tensor<1xf32> into tensor<2xf32>
      }
    }

    return %1 : tensor<2xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op

      // Create a new handle that points to `linalg.fill` twice.
      %2 = transform.merge_handles %0, %0 : !transform.any_op

      // It shouldn't be a problem to fuse this handle.
      transform.structured.fuse_into_containing_op %2 into %1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----
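
// The producer's first result is also used after the loop, so fusion cannot
// just tile it in place: the scf.forall gains a second shared_outs operand
// and result that forwards the tiled value to the outside user, while the
// original linalg.generic is kept alive for its second result.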

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

module {
  // CHECK-LABEL: func.func @fuse_tileable_multi_output_op_multi_use
  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?xf32>
  //  CHECK-SAME:   %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
  //  CHECK-SAME:   %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
  //  CHECK-SAME:   %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_multi_output_op_multi_use(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
    -> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index

    // CHECK: %[[G0:.*]]:2 = linalg.generic
    %0:2 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]
    } ins(%in : tensor<?xf32>) outs(%out_1, %out_3 : tensor<?xf32>, tensor<?xf32>) {
      ^bb0(%a: f32, %b: f32, %c: f32):
        %d = arith.addf %a, %b : f32
        %e = arith.addf %d, %c : f32
        linalg.yield %d, %e : f32, f32
    } -> (tensor<?xf32>, tensor<?xf32>)
    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>

    %1 = affine.apply #map0()[%d0, %idx]

    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
    // expected-remark @below{{new containing op}}
    %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
      %3 = affine.apply #map1(%i)[%idx]
      // CHECK: %[[I1:.*]] = affine.min {{.*}}
      %4 = affine.min #map2(%i)[%d0, %idx]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T1:.*]]:2 = linalg.generic {{.*}}
      %6 = tensor.extract_slice %0#0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        // CHECK: tensor.parallel_insert_slice %[[T1]]#0 into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: return %[[R0]]#0, %[[R0]]#1, %[[G0]]#1
    func.return %2, %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>
    // CHECK: }
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // linalg.generic is tileable. The op is tiled and fused.
      %fused, %containing = transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.debug.emit_remark_at %containing, "new containing op" : !transform.any_op
      transform.yield
    }
  }
}

// -----
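
// Mixed uses: the tensor.dim on the producer dominates the loop, while the
// extract_slice use sits inside it. The dominating use keeps the original
// linalg.generic alive; the in-loop use is served by the fused, tiled copy.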

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

module {
  // CHECK-LABEL: func.func @fuse_tileable_mixed_dominating_uses
  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?xf32>
  //  CHECK-SAME:   %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
  //  CHECK-SAME:   %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
  //  CHECK-SAME:   %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_mixed_dominating_uses(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
    -> (tensor<?xf32>, tensor<?xf32>) {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index

    // CHECK: %[[G0:.*]] = linalg.generic
    %0 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]
    } ins(%in : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
      ^bb0(%a: f32, %b: f32):
        %d = arith.addf %a, %b : f32
        linalg.yield %d : f32
    } -> tensor<?xf32>
    // CHECK: %[[D0:.*]] = tensor.dim %[[G0]]
    %d0 = tensor.dim %0, %c0 : tensor<?xf32>

    %1 = affine.apply #map0()[%d0, %idx]

    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
    %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
      %3 = affine.apply #map1(%i)[%idx]
      // CHECK: %[[I1:.*]] = affine.min {{.*}}
      %4 = affine.min #map2(%i)[%d0, %idx]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
      %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: return %[[R0]]#0, %[[R0]]#1
    func.return %2, %0 : tensor<?xf32>, tensor<?xf32>
    // CHECK: }
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // linalg.generic is tileable. The op is tiled and fused.
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----
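
// Same pattern as above, but the producer performs a reduction. Only its
// parallel result dimension is tiled by the slice taken inside the loop; the
// reduction dimension stays intact within each iteration.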

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
#map4 = affine_map<(d0, d1) -> (d0)>

module {
  // CHECK-LABEL: func.func @fuse_tileable_reductions
  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?x?xf32>
  //  CHECK-SAME:   %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
  //  CHECK-SAME:   %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
  //  CHECK-SAME:   %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_reductions(%idx: index, %in: tensor<?x?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
    -> (tensor<?xf32>, tensor<?xf32>) {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index

    %0 = linalg.generic {
      indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]
      } ins(%in : tensor<?x?xf32>) outs(%out_1 : tensor<?xf32>) {
        ^bb0(%a: f32, %b: f32):
          %d = arith.maximumf %a, %b : f32
          linalg.yield %d : f32
        } -> tensor<?xf32>
    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>

    %1 = affine.apply #map0()[%d0, %idx]

    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
    %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
      %3 = affine.apply #map1(%i)[%idx]
      // CHECK: %[[I1:.*]] = affine.min {{.*}}
      %4 = affine.min #map2(%i)[%d0, %idx]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
      %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: return %[[R0]]#0, %[[R0]]#1
    func.return %2, %0 : tensor<?xf32>, tensor<?xf32>
    // CHECK: }
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // linalg.generic is tileable. The op is tiled and fused.
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----
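
// Two producers are fused one after the other. The second
// fuse_into_containing_op must consume the scf.forall handle returned by the
// first one, since the first fusion rebuilds the loop and invalidates the
// handle to the original scf.forall.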

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
#map3 = affine_map<(d0) -> (d0)>

module {
  // CHECK-LABEL: func.func @fuse_tileable_using_new_handle
  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?xf32>
  //  CHECK-SAME:   %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
  //  CHECK-SAME:   %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
  //  CHECK-SAME:   %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_using_new_handle(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
    -> (tensor<?xf32>, tensor<?xf32>) {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index

    %0 = linalg.generic {
      indexing_maps = [#map3, #map3], iterator_types = ["parallel"]
      } ins(%in : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
        ^bb0(%a: f32, %b: f32):
          %d = arith.addf %a, %b : f32
          linalg.yield %d : f32
        } -> tensor<?xf32>

    %1 = linalg.generic {
      indexing_maps = [#map3, #map3], iterator_types = ["parallel"]
      } ins(%0 : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
        ^bb0(%a: f32, %b: f32):
          %d = arith.mulf %a, %b : f32
          linalg.yield %d : f32
        } -> tensor<?xf32>
    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>

    %2 = affine.apply #map0()[%d0, %idx]

    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
    %3 = scf.forall (%i) in (%2) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
      %4 = affine.apply #map1(%i)[%idx]
      // CHECK: %[[I1:.*]] = affine.min {{.*}}
      %5 = affine.min #map2(%i)[%d0, %idx]
      %6 = tensor.extract_slice %o[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
      // CHECK: %[[T2:.*]] = linalg.generic {{.*}}
      %7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>

      %8 = linalg.elemwise_unary ins(%7 : tensor<?xf32>) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        // CHECK: tensor.parallel_insert_slice %[[T2]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
        tensor.parallel_insert_slice %8 into %o[%4] [%5] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: return %[[R0]]#0, %[[R0]]#1
    func.return %3, %1 : tensor<?xf32>, tensor<?xf32>
    // CHECK: }
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
      %add, %reduce = transform.split_handle %0 : (!transform.op<"linalg.generic">) -> (!transform.op<"linalg.generic">, !transform.op<"linalg.generic">)
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      %fused_ops, %new_forall = transform.structured.fuse_into_containing_op %reduce into %1
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">)
      %fused_ops_2, %new_forall_2 = transform.structured.fuse_into_containing_op %add into %new_forall
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">)
      transform.yield
    }
  }
}

// -----

// This is a regression test. Make sure that the transform succeeds and valid
// IR is generated.

module {
  // CHECK-LABEL: func.func @softmax_dispatch_0_generic_16x128x128_f32
  func.func @softmax_dispatch_0_generic_16x128x128_f32() -> tensor<16x128x128xf32> {
    %c0 = arith.constant 0 : index
    %cst = arith.constant dense<5.000000e+00> : tensor<16x128x128xf32>
    %cst_1 = arith.constant 5.000000e+00 : f32
    %1 = tensor.empty() : tensor<16x128xf32>
    %2 = tensor.empty() : tensor<16x128x128xf32>
    %3 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<16x128xf32>) -> tensor<16x128xf32>
    %4 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<16x128xf32>) -> tensor<16x128xf32>
    %5 = linalg.generic {producer, indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%cst : tensor<16x128x128xf32>) outs(%4 : tensor<16x128xf32>) {
    ^bb0(%in: f32, %out: f32):
      %8 = arith.maximumf %in, %out : f32
      linalg.yield %8 : f32
    } -> tensor<16x128xf32>
    %c16 = arith.constant 16 : index
    %c32 = arith.constant 32 : index
    %7 = scf.forall (%arg0, %arg1) in (16, 32) shared_outs(%arg2 = %2) -> (tensor<16x128x128xf32>) {
      %11 = affine.apply affine_map<(d0) -> (d0 * 4)>(%arg1)
      %extracted_slice = tensor.extract_slice %5[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
      %extracted_slice_3 = tensor.extract_slice %2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<16x128x128xf32> to tensor<1x4x128xf32>
      %extracted_slice_4 = tensor.extract_slice %3[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
      %15:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%extracted_slice : tensor<1x4xf32>) outs(%extracted_slice_3, %extracted_slice_4 : tensor<1x4x128xf32>, tensor<1x4xf32>) {
      ^bb0(%in: f32, %out: f32, %out_9: f32):
        %22 = arith.subf %cst_1, %in : f32
        %23 = math.exp %22 : f32
        %24 = arith.addf %23, %out_9 : f32
        linalg.yield %23, %24 : f32, f32
      } -> (tensor<1x4x128xf32>, tensor<1x4xf32>)
      %extracted_slice_5 = tensor.extract_slice %5[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
      %extracted_slice_6 = tensor.extract_slice %2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<16x128x128xf32> to tensor<1x4x128xf32>
      %extracted_slice_7 = tensor.extract_slice %3[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
      %19:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%extracted_slice_5 : tensor<1x4xf32>) outs(%extracted_slice_6, %extracted_slice_7 : tensor<1x4x128xf32>, tensor<1x4xf32>) {
      ^bb0(%in: f32, %out: f32, %out_9: f32):
        %22 = arith.subf %cst_1, %in : f32
        %23 = math.exp %22 : f32
        %24 = arith.addf %23, %out_9 : f32
        linalg.yield %23, %24 : f32, f32
      } -> (tensor<1x4x128xf32>, tensor<1x4xf32>)
      %extracted_slice_8 = tensor.extract_slice %arg2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<16x128x128xf32> to tensor<1x4x128xf32>
      %20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15#0, %19#1 : tensor<1x4x128xf32>, tensor<1x4xf32>) outs(%extracted_slice_8 : tensor<1x4x128xf32>) {
      ^bb0(%in: f32, %in_9: f32, %out: f32):
        %22 = arith.divf %in, %in_9 : f32
        linalg.yield %22 : f32
      } -> tensor<1x4x128xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %20 into %arg2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<1x4x128xf32> into tensor<16x128x128xf32>
      }
    }
    return %7 : tensor<16x128x128xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match attributes{producer} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// Tests below are expected to fail.
////////////////////////////////////////////////////////////////////////////////

// -----
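
// The op passed as the "containing" op is the tiled matmul itself rather
// than the surrounding scf.forall, so the fill cannot be fused into it and
// the transform is expected to fail with a diagnostic.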

// NO-CHECK-LABEL-ON-EXPECTED-ERROR
func.func @copy_1d_1024xf16(%arg0: tensor<123x456xf32>, %arg1: tensor<456x789xf32>, %arg2 : tensor<123x789xf32>) -> tensor<123x789xf32> {
  %0 = arith.constant 0.000000e+00 : f32
  %1 = linalg.fill ins(%0 : f32) outs(%arg2 : tensor<123x789xf32>) -> tensor<123x789xf32>
  // expected-note @below {{containing op}}
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<123x456xf32>, tensor<456x789xf32>) outs(%1 : tensor<123x789xf32>) -> tensor<123x789xf32>
  return %2 : tensor<123x789xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.match ops{["linalg.matmul"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %tiled_op, %forall_op = transform.structured.tile_using_forall %1
      num_threads [] tile_sizes [50, 16]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    // Note that we pass in %tiled_op, which isn't a container op.
    // expected-error @+2 {{could not find next producer to fuse into container}}
    %fused_op, %new_containing_op =
      transform.structured.fuse_into_containing_op %0 into %tiled_op
        : (!transform.any_op, !transform.any_op)
          -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}