xref: /llvm-project/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir (revision 91bbebc7e118cceae1fc0e349de08094a3cd2fe7)
1// RUN: mlir-opt %s -transform-interpreter -split-input-file -canonicalize -cse -verify-diagnostics | FileCheck %s
2
// Payload: sum-of-squares reduction of each row (d0 parallel, d1 reduction).
// Used below to tile the reduction dimension with tile_reduction_using_for.
3func.func @reduction_tile(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
4  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
5                                          affine_map<(d0, d1) -> (d0)>],
6   iterator_types = ["parallel", "reduction"]}
7   ins(%arg0 : tensor<?x?xf32>)
8   outs(%out : tensor<?xf32>) {
9    ^bb0(%arg7: f32, %arg9: f32):
10      %1 = arith.mulf %arg7, %arg7 : f32
11      %2 = arith.addf %1, %arg9 : f32
12      linalg.yield %2 : f32
13    } -> tensor<?xf32>
14  return %red : tensor<?xf32>
15}
16
// Transform script: tile the reduction dimension (d1) by 5 using an scf.for.
// The four returned handles are the fill, the partial (tiled) reduction, the
// final merging linalg.reduce, and the generated loop.
17module attributes {transform.with_named_sequence} {
18  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
19    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
20    %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
21      by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
22      transform.yield
23  }
24}
25
26// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
27// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
28//     CHECK: func @reduction_tile(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
29// CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
30// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
31// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
32// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
33// CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
34// CHECK-DAG:   %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
35//     CHECK:   %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
36//     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
37//     CHECK:   %[[L:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[D1]] step %[[C5]] iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<?x5xf32>) {
38//     CHECK:     %[[PS:.*]] = affine.min #[[MAP0]](%[[K]])[%[[D1]]]
39//     CHECK:     %[[EXT2:.*]] = tensor.extract_slice %[[ARG0]][0, %[[K:.*]]] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
40//     CHECK:     %[[EXT:.*]] = tensor.extract_slice %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x5xf32> to tensor<?x?xf32>
41//     CHECK:     %[[PR:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXT2]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>) {
42//     CHECK:       arith.mulf
43//     CHECK:       arith.addf
44//     CHECK:       linalg.yield
45//     CHECK:     } -> tensor<?x?xf32>
46//     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
47//     CHECK:     scf.yield %[[INS]] : tensor<?x5xf32>
48//     CHECK:   }
49//     CHECK:   %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
50//     CHECK:     arith.addf
51//     CHECK:     linalg.yield
52//     CHECK:   }
53//     CHECK:   return %[[R]] : tensor<?xf32>
54
55// -----
56
// Payload: column-wise add reduction where the reduction dimension is d0
// (leading), so the tiled partial accumulator must be written transposed.
57func.func @reduction_tile_transpose(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
58  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
59                                          affine_map<(d0, d1) -> (d1)>],
60   iterator_types = ["reduction", "parallel"]}
61   ins(%arg0 : tensor<?x?xf32>)
62   outs(%out : tensor<?xf32>) {
63    ^bb0(%arg7: f32, %arg9: f32):
64      %42 = arith.addf %arg7, %arg9 : f32
65      linalg.yield %42 : f32
66    } -> tensor<?xf32>
67  return %red : tensor<?xf32>
68}
69
// Transform script: tile the leading reduction dimension (d0) by 5 with an
// scf.for; same handle layout as the previous test.
70module attributes {transform.with_named_sequence} {
71  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
72    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
73    %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
74      by tile_sizes = [5, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
75      transform.yield
76  }
77}
78
79// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
80// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
81// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
82//     CHECK: func @reduction_tile_transpose
83//     CHECK:   tensor.empty(%{{.*}}) : tensor<?x5xf32>
84//     CHECK:   linalg.fill {{.*}} : tensor<?x5xf32>) -> tensor<?x5xf32>
85//     CHECK:   scf.for
86//     CHECK:     %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<?x5xf32> to tensor<?x?xf32>
87//     CHECK:     %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
88//     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
89//     CHECK:     scf.yield {{.*}} : tensor<?x5xf32>
90//     CHECK:   }
91//     CHECK:   linalg.reduce
92//     CHECK:   return
93
94// -----
95
// Payload: same sum-of-squares row reduction as @reduction_tile, here tiled
// across threads with tile_reduction_using_forall.
96func.func @reduction_tile_parallel(
97  %arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
98  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
99                                          affine_map<(d0, d1) -> (d0)>],
100   iterator_types = ["parallel", "reduction"]}
101   ins(%arg0 : tensor<?x?xf32>)
102   outs(%out : tensor<?xf32>) {
103    ^bb0(%arg7: f32, %arg9: f32):
104      %1 = arith.mulf %arg7, %arg7 : f32
105      %2 = arith.addf %1, %arg9 : f32
106      linalg.yield %2 : f32
107    } -> tensor<?xf32>
108  return %red : tensor<?xf32>
109}
110
// Transform script: split the reduction dimension over 5 threads with an
// scf.forall (no per-thread tile size), producing a ?x5 partial accumulator
// that is merged by a trailing linalg.reduce.
111module attributes {transform.with_named_sequence} {
112  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
113    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
114    %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0
115      by num_threads = [0, 5], tile_sizes = [] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
116      transform.yield
117  }
118}
119
120// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
121// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (0, d0)>
122// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 5))>
123// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, d1)>
124// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1) -> (d0)>
125//     CHECK: func @reduction_tile_parallel(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
126// CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
127// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
128// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
129// CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
130// CHECK-DAG:   %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
131//     CHECK:   %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
132//     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
133//     CHECK:   %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
134// CHECK-DAG:     %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
135// CHECK-DAG:     %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
136// CHECK-DAG:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
137//     CHECK:     %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]]
138//     CHECK:     %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
139//     CHECK:     %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0] [%[[D0]]] [1] : tensor<?xf32> to tensor<?xf32>
140//     CHECK:     %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor<?x?xf32>) outs(%[[TEMPEXT]] : tensor<?xf32>) {
141//     CHECK:       arith.mulf
142//     CHECK:       arith.addf
143//     CHECK:       linalg.yield
144//     CHECK:     } -> tensor<?xf32>
145//     CHECK:     scf.forall.in_parallel {
146//     CHECK:       tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
147//     CHECK:     }
148//     CHECK:   }
149//     CHECK:   %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
150//     CHECK:   {
151//     CHECK:     arith.addf
152//     CHECK:     linalg.yield
153//     CHECK:   }
154//     CHECK:   return %[[R]] : tensor<?xf32>
155
156// -----
157
// Payload: a named op (linalg.matmul) whose K dimension is split across
// threads by the transform below; checks the named-op path of the transform.
158func.func @matmul_tile_parallel(
159  %A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %out: tensor<?x?xf32>) -> tensor<?x?xf32> {
160  %matmul = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
161                     outs(%out: tensor<?x?xf32>) -> tensor<?x?xf32>
162  return %matmul : tensor<?x?xf32>
163}
164
// Transform script: split the matmul's reduction dimension (the third loop, K)
// over 5 threads with an scf.forall; partials land in a ?x?x5 accumulator.
165module attributes {transform.with_named_sequence} {
166  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
167    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
168    %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0
169      by num_threads = [0, 0, 5], tile_sizes = [] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
170      transform.yield
171  }
172}
173
174// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
175// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (0, d0)>
176// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 5))>
177//     CHECK: func @matmul_tile_parallel(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>
178// CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
179// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
180// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
181// CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
182// CHECK-DAG:   %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
183// CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
184//     CHECK:   %[[E:.*]] = tensor.empty(%[[D0]], %[[D2]]) : tensor<?x?x5xf32>
185//     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
186//     CHECK:   %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x?x5xf32>) {
187// CHECK-DAG:     %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
188// CHECK-DAG:     %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
189// CHECK-DAG:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?xf32>
190//     CHECK:     %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]]
191//     CHECK:     %[[INCHUNKA:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
192//     CHECK:     %[[INCHUNKB:.+]] = tensor.extract_slice %[[ARG1]][%[[TINDEX]], 0] [%[[TS1]], %[[D2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
193//     CHECK:     %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0, 0] [%[[D0]], %[[D2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
194//     CHECK:     %[[PARTIAL:.+]] = linalg.matmul ins(%[[INCHUNKA]], %[[INCHUNKB]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[TEMPEXT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
195//     CHECK:     scf.forall.in_parallel {
196//     CHECK:       tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x5xf32>
197//     CHECK:     }
198//     CHECK:   }
199//     CHECK:   %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x?x5xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) dimensions = [2]
200//     CHECK:     arith.addf
201//     CHECK:     linalg.yield
202//     CHECK:   }
203//     CHECK:   return %[[R]] : tensor<?x?xf32>
204
205// -----
206
// Payload: sum-of-squares row reduction again; this time the transform uses
// both num_threads and a tile size, yielding a cyclic (strided) distribution
// with an inner scf.for per thread.
207func.func @reduction_tile_parallel_cyclic_dist(
208  %arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
209  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
210                                          affine_map<(d0, d1) -> (d0)>],
211   iterator_types = ["parallel", "reduction"]}
212   ins(%arg0 : tensor<?x?xf32>)
213   outs(%out : tensor<?xf32>) {
214    ^bb0(%arg7: f32, %arg9: f32):
215      %1 = arith.mulf %arg7, %arg7 : f32
216      %2 = arith.addf %1, %arg9 : f32
217      linalg.yield %2 : f32
218    } -> tensor<?xf32>
219  return %red : tensor<?xf32>
220}
221
// Transform script: 5 threads, each reducing chunks of 3 elements with a
// cyclic stride of 15 (= 5 * 3), mapped to gpu thread x.
222module attributes {transform.with_named_sequence} {
223  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
224    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
225    %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0
226      by num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
227      transform.yield
228  }
229}
230
231// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 3)>
232// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
233// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
234// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0)>
235
236//     CHECK: func @reduction_tile_parallel_cyclic_dist(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
237// CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
238// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
239// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
240// CHECK-DAG:   %[[C15:.*]] = arith.constant 15 : index
241// CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
242//     CHECK:   %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
243//     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
244//     CHECK:   %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
245//     CHECK:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
246//     CHECK:     %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
247//     CHECK:     %[[LB:.+]] = affine.apply #[[MAP0]]()[%[[IV]]]
248//     CHECK:     %[[CARRY:.+]] = scf.for %[[IV1:.+]] = %[[LB]] to %[[D1]] step %[[C15]] iter_args(%[[ACC:.+]] = %[[ET]]) -> (tensor<?xf32>) {
249//     CHECK:       %[[TS0:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[D1]]]
250//     CHECK:       %[[D3:.+]] = tensor.dim %[[ACC]], %[[C0]] : tensor<?xf32>
251//     CHECK:       %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV1]]] [%[[D0]], %[[TS0]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
252//     CHECK:       %[[TEMPEXT:.+]] = tensor.extract_slice %[[ACC]][0] [%[[D3]]] [1] : tensor<?xf32> to tensor<?xf32>
253//     CHECK:       %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor<?x?xf32>) outs(%[[TEMPEXT]] : tensor<?xf32>) {
254//     CHECK:         arith.mulf
255//     CHECK:         arith.addf
256//     CHECK:         linalg.yield
257//     CHECK:       } -> tensor<?xf32>
258//     CHECK:       %[[INS:.+]] = tensor.insert_slice %[[PARTIAL]] into %[[ACC]][0] [%[[D3]]] [1] : tensor<?xf32> into tensor<?xf32>
259//     CHECK:       scf.yield %[[INS]] : tensor<?xf32>
260//     CHECK:     }
261//     CHECK:     scf.forall.in_parallel {
262//     CHECK:       tensor.parallel_insert_slice %[[CARRY]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
263//     CHECK:     }
264//     CHECK:   }
265//     CHECK:   %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
266//     CHECK:     arith.addf
267//     CHECK:     linalg.yield
268//     CHECK:   }
269//     CHECK:   return %[[R]] : tensor<?xf32>
270
271// -----
272
// Payload: identical to the previous test case; this copy only exercises the
// transform.print handle checks in the script below (no IR FileCheck block).
273func.func @reduction_tile_parallel_cyclic_dist(
274  %arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
275  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
276                                          affine_map<(d0, d1) -> (d0)>],
277   iterator_types = ["parallel", "reduction"]}
278   ins(%arg0 : tensor<?x?xf32>)
279   outs(%out : tensor<?xf32>) {
280    ^bb0(%arg7: f32, %arg9: f32):
281      %1 = arith.mulf %arg7, %arg7 : f32
282      %2 = arith.addf %1, %arg9 : f32
283      linalg.yield %2 : f32
284    } -> tensor<?xf32>
285    return %red : tensor<?xf32>
286}
287
// Transform script: same tiling as above, then prints each returned handle to
// verify that %1 is the fill, %2 the partial (still-reducing) generic, and %3
// the merging linalg.reduce.
288module attributes {transform.with_named_sequence} {
289  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
290    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
291    %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0
292      by num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
293
294    //      CHECK:     expecting fill
295    // CHECK-NEXT:     linalg.fill
296    transform.print %1 {name = "expecting fill"} : !transform.any_op
297    //      CHECK:     expecting parallel reduction
298    // CHECK-NEXT:     linalg.generic
299    //      CHECK:     iterator_types = ["parallel", "reduction"]
300    transform.print %2 {name = "expecting parallel reduction"} : !transform.any_op
301    // NOTE(review): %3 is the final merge (linalg.reduce); the print label
302    // below reuses the same text, which is only a remark label, not a claim
303    // about the op's iterator types.
304    //      CHECK:     expecting parallel reduction
305    // CHECK-NEXT:     linalg.reduce
306    //      CHECK:     iterator_types = ["parallel", "reduction"]
307    transform.print %3 {name = "expecting parallel reduction"} : !transform.any_op
308    transform.yield
309  }
310}
308
309// -----
310
// Payload for a negative test: the transform below passes num_threads/tile
// sizes only for the parallel dimension, so tiling the reduction must fail
// (see the expected-error in the transform script).
311func.func @reduction_untiled_forall(
312  %arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
313  // expected-note @below {{target operation}}
314  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
315                                          affine_map<(d0, d1) -> (d0)>],
316   iterator_types = ["parallel", "reduction"]}
317   ins(%arg0 : tensor<?x?xf32>)
318   outs(%out : tensor<?xf32>) {
319    ^bb0(%arg7: f32, %arg9: f32):
320      %1 = arith.mulf %arg7, %arg7 : f32
321      %2 = arith.addf %1, %arg9 : f32
322      linalg.yield %2 : f32
323    } -> tensor<?xf32>
324  return %red : tensor<?xf32>
325}
326
// Transform script (negative test): num_threads/tile_sizes only cover the
// parallel dimension, so the transform emits "could not tile reduction",
// matched by -verify-diagnostics.
327module attributes {transform.with_named_sequence} {
328  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
329    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
330    // expected-error @below {{could not tile reduction}}
331    %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0
332      by num_threads = [5], tile_sizes = [3], mapping = [#gpu.thread<x>] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
333
334      transform.yield
335  }
336}
337
338// -----
339
// Negative test: the reduction body uses llvm.fmul/llvm.fadd instead of arith
// ops, so the transform cannot derive an identity (neutral) value for the
// combiner and must fail with the diagnostics expected below.
340#map = affine_map<(d0, d1) -> (d0, d1)>
341#map1 = affine_map<(d0, d1) -> (d0)>
342
343module {
344  func.func @fail_for_float_neutral(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
345    // expected-error @below {{'linalg.generic' op Failed to get an identity value for the reduction operation.}}
346    // expected-note @below {{when applied to this op}}
347    %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
348    ^bb0(%in: f32, %out: f32):
349      %1 = llvm.fmul %in, %in  : f32
350      %2 = llvm.fadd %1, %out  : f32
351      linalg.yield %2 : f32
352    } -> tensor<?xf32>
353    return %0 : tensor<?xf32>
354  }
355  module attributes {transform.with_named_sequence} {
356    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
357      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
358      // expected-error @below {{transform.structured.tile_reduction_using_for failed to apply}}
359      %fill_op, %split_linalg_op, %combining_linalg_op, %for_op = transform.structured.tile_reduction_using_for %0 by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
360      transform.yield
361    }
362  }
363}
364
365// -----
366
// Test with two reduction dimensions (d1, d2) tiled simultaneously
// (tile_sizes = [0, 2, 64]); expects two nested scf.for loops and a single
// merging linalg.reduce over the 4096x2x64 partial accumulator.
367#map = affine_map<(d0, d1, d2) -> (d1, d2)>
368#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
369#map2 = affine_map<(d0, d1, d2) -> (d0)>
370module {
371  func.func @reduction_tile_multiple_reduction(%arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> {
372    %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) {
373    ^bb0(%in: f32, %in_0: f32, %out: f32):
374      %1 = arith.mulf %in, %in_0 : f32
375      %2 = arith.addf %1, %out : f32
376      linalg.yield %2 : f32
377    } -> tensor<4096xf32>
378    return %0 : tensor<4096xf32>
379  }
380  module attributes {transform.with_named_sequence} {
381    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
382      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
383      %fill_op, %split_linalg_op, %combining_linalg_op, %for_op = transform.structured.tile_reduction_using_for %0 by tile_sizes = [0, 2, 64] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
384      transform.yield
385    }
386  }
387}
388
389// CHECK: func @reduction_tile_multiple_reduction(%[[ARG0:.+]]: tensor<86x128xf32>, %[[ARG1:.+]]: tensor<4096x86x128xf32>, %[[ARG2:.+]]: tensor<4096xf32>
390// CHECK:   %[[F:.*]] = linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<4096x2x64xf32>) -> tensor<4096x2x64xf32>
391// CHECK:   %[[L0:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<4096x2x64xf32>)
392// CHECK:     %[[L1:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]) -> (tensor<4096x2x64xf32>)
393// CHECK:       %[[OUT:.*]] = linalg.generic  {indexing_maps = [{{.*}}, {{.*}}, {{.*}}], iterator_types = ["parallel", "parallel", "parallel"]} ins(%{{.*}}, %{{.*}}: tensor<2x64xf32>, tensor<4096x2x64xf32>) outs(%{{.*}}: tensor<4096x2x64xf32>)
394// CHECK:       scf.yield %[[OUT]] : tensor<4096x2x64xf32>
395// CHECK:     scf.yield %[[L1]] : tensor<4096x2x64xf32>
396// CHECK:   %[[OUT2:.*]] = linalg.reduce ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
397// CHECK:  return %[[OUT2]] : tensor<4096xf32>
398
399// -----
400
// Payload with two reduction results (sum of squares and max of squares) over
// the same input; each result needs its own identity fill and its own merge.
401func.func @reduction_tile_multiple_results(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>, %out2: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
402  %red:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
403                                            affine_map<(d0, d1) -> (d0)>,
404                                            affine_map<(d0, d1) -> (d0)>],
405   iterator_types = ["parallel", "reduction"]}
406   ins(%arg0 : tensor<?x?xf32>)
407   outs(%out, %out2 : tensor<?xf32>, tensor<?xf32>) {
408    ^bb0(%arg7: f32, %arg9: f32, %arg9_1: f32):
409      %1 = arith.mulf %arg7, %arg7 : f32
410      %2 = arith.addf %1, %arg9 : f32
411      %3 = arith.maximumf %1, %arg9_1 : f32
412      linalg.yield %2, %3 : f32, f32
413    } -> (tensor<?xf32>, tensor<?xf32>)
414  return %red#0, %red#1 : tensor<?xf32>, tensor<?xf32>
415}
416
// Transform script: tiling a two-result reduction returns six handles —
// presumably the two fills, the partial op, the two merges, and the loop
// (ordering inferred from the single-result form; confirm against the op's
// transform-dialect documentation).
417module attributes {transform.with_named_sequence} {
418  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
419    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
420    %1, %12, %2, %3, %4, %loop = transform.structured.tile_reduction_using_for %0
421      by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
422      transform.yield
423  }
424}
425
426// CHECK: func @reduction_tile_multiple_results
427// CHECK-DAG:   %[[SUM_ID:.+]] = arith.constant 0.000000e+00 : f32
428// CHECK-DAG:   %[[MAX_ID:.+]] = arith.constant 0xFF800000 : f32
429// CHECK-DAG:   %[[SUM_INIT:.+]] = linalg.fill ins(%[[SUM_ID]] : f32) outs(%{{.*}} : tensor<?x5xf32>) -> tensor<?x5xf32>
430// CHECK-DAG:   %[[MAX_INIT:.+]] = linalg.fill ins(%[[MAX_ID]] : f32) outs(%{{.*}} : tensor<?x5xf32>) -> tensor<?x5xf32>
431// CHECK:       %[[OUT:.+]]:2 = scf.for
432// CHECK-SAME:            iter_args(%[[SUM:.+]] = %[[SUM_INIT]], %[[MAX:.+]] = %[[MAX_INIT]])
433// CHECK:         %[[UPDATED:.*]]:2 = linalg.generic
434// CHECK:         arith.mulf
435// CHECK:         arith.addf
436// CHECK:         arith.maximumf
437// CHECK:       %[[INSERT1:.+]] = tensor.insert_slice %[[UPDATED]]#0 into %[[SUM]]
438// CHECK:       %[[INSERT2:.+]] = tensor.insert_slice %[[UPDATED]]#1 into %[[MAX]]
439// CHECK:       scf.yield %[[INSERT1]], %[[INSERT1]]
439// The loop yields the two insert_slice results in iter_args order
439// (SUM then MAX), i.e. INSERT1 then INSERT2; the check above only matched
439// because "%inserted_slice" is a string prefix of "%inserted_slice_0".
440// CHECK:       linalg.reduce
441// CHECK:         arith.addf
442// CHECK:       linalg.reduce
443// CHECK:         arith.maximumf
444
445// -----
446
// Payload: 3-D add reduction whose output map transposes the parallel
// dimensions ((d0,d1,d2) -> (d2,d0)); checks the transposed-output path when
// the middle (reduction) dimension is tiled.
447func.func @reduction_tile_multi_dim_transpose(%arg0: tensor<?x?x?xf32>, %out: tensor<?x?xf32>) -> tensor<?x?xf32> {
448  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
449                                          affine_map<(d0, d1, d2) -> (d2, d0)>],
450   iterator_types = ["parallel", "reduction", "parallel"]}
451   ins(%arg0 : tensor<?x?x?xf32>)
452   outs(%out : tensor<?x?xf32>) {
453    ^bb0(%arg7: f32, %arg9: f32):
454      %42 = arith.addf %arg7, %arg9 : f32
455      linalg.yield %42 : f32
456    } -> tensor<?x?xf32>
457  return %red : tensor<?x?xf32>
458}
459
// Transform script: tile the middle reduction dimension (d1) by 5 with an
// scf.for; partials accumulate in a ?x?x5 tensor merged by linalg.reduce.
460module attributes {transform.with_named_sequence} {
461  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
462    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
463    %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
464      by tile_sizes = [0, 5, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
465      transform.yield
466  }
467}
468
469// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
470// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
471//     CHECK: func @reduction_tile_multi_dim_transpose
472//     CHECK:   tensor.empty(%{{.*}}) : tensor<?x?x5xf32>
473//     CHECK:   linalg.fill {{.*}} : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
474//     CHECK:   scf.for
475//     CHECK:     %[[K:.*]] = affine.min
476//     CHECK:     %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0, 0] [%[[D2:.*]], %[[D0:.*]], %[[K]]] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?x?xf32>
477//     CHECK:     %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?x?xf32>) outs(%[[EXT]] : tensor<?x?x?xf32>)
478//     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0, 0] [%[[D2]], %[[D0]], %[[K]]] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x5xf32>
479//     CHECK:     scf.yield {{.*}} : tensor<?x?x5xf32>
480//     CHECK:   }
481//     CHECK:   linalg.reduce
482//     CHECK:   return
483