// xref: /llvm-project/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir (revision 0981dca7779d4acfcbb92fbb29a7a1033e283b88)
// RUN: mlir-opt --split-input-file --transform-interpreter %s | FileCheck %s

// Split the matmul reduction dimension (k = 256) with split_factor = 4 and
// insert_split_dimension = 2: k is reshaped to 4x64 and a first generic
// reduces over the 64-chunk, producing tensor<16x32x4xf32> of partial sums;
// a second generic then reduces the trailing size-4 dimension.
func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
                    outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
//  CHECK-LABEL: @matmul_split
//  CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [16, 4, 64] : tensor<16x256xf32> into tensor<16x4x64xf32>
//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [4, 64, 32] : tensor<256x32xf32> into tensor<4x64x32xf32>
//  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<16x32x4xf32>
//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME:   , iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
// CHECK-SAME:   ins(%[[I1]], %[[I2]] : tensor<16x4x64xf32>, tensor<4x64x32xf32>) outs(%[[F]] : tensor<16x32x4xf32>) {
//      CHECK:   arith.mulf
//      CHECK:   arith.addf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<16x32x4xf32>
//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]],
// CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[G]] : tensor<16x32x4xf32>) outs(%{{.*}} : tensor<16x32xf32>) {
//      CHECK:   arith.addf
//      CHECK:   linalg.yield %{{.*}} : f32
//      CHECK: } -> tensor<16x32xf32>
//      CHECK: return %[[R]] : tensor<16x32xf32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
      transform.yield
  }
}

// -----
44
// 1-D multiplicative reduction (sub/exp/mul body): splitting at dimension 0
// reshapes the 32-element input to 4x8. The neutral element for the fill is
// 1.0 since the combiner is arith.mulf.
func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor<f32>, %out: tensor<f32>) -> tensor<f32> {
  %red = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                          affine_map<(d0) -> ()>,
                                          affine_map<(d0) -> ()>],
   iterator_types = ["reduction"]}
   ins(%arg0, %arg1 : tensor<32xf32>, tensor<f32>)
   outs(%out : tensor<f32>) {
    ^bb0(%arg7: f32, %arg8: f32, %arg9: f32):
      %40 = arith.subf %arg7, %arg8 : f32
      %41 = math.exp %40 : f32
      %42 = arith.mulf %41, %arg9 : f32
      linalg.yield %42 : f32
    } -> tensor<f32>
  return %red : tensor<f32>
}

//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
//CHECK-LABEL: @generic_split_1d
//  CHECK-DAG: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] output_shape [4, 8] : tensor<32xf32> into tensor<4x8xf32>
//  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
//      CHECK: %[[G:.*]] = linalg.generic
//      CHECK:   {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
//      CHECK:   iterator_types = ["parallel", "reduction"]} ins(%[[I1]], %{{.*}} : tensor<4x8xf32>, tensor<f32>) outs(%[[F]] : tensor<4xf32>) {
//      CHECK:   arith.subf
//      CHECK:   math.exp
//      CHECK:   arith.mulf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<4xf32>
//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["reduction"]} ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
//      CHECK:   arith.mulf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<f32>
//      CHECK: return %[[R]] : tensor<f32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 0}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
      transform.yield
  }
}

// -----
95
// 3-D generic with a maximumf combiner: the size-32 reduction dimension is
// split 4x8. Without ninf fastmath the fill uses -inf (0xFF800000) as the
// neutral element for maximumf.
func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>)
  -> tensor<5x2xf32>
{
  %0 = linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2) -> (d1, d0)>,
        affine_map<(d0, d1, d2) -> (d2, d1)>,
        affine_map<(d0, d1, d2) -> (d2, d0)>
      ],
      iterator_types = ["parallel", "reduction", "parallel"]
    } ins(%input, %input_2 : tensor<32x2xf32>, tensor<5x32xf32>) outs(%output : tensor<5x2xf32>) {
    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
      %3 = arith.addf %arg0, %arg1 : f32
      %4 = arith.maximumf %3, %arg2 : f32
      linalg.yield %4 : f32
    } -> tensor<5x2xf32>
  return %0 : tensor<5x2xf32>
}

//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL:  func @generic_split_3d
//  CHECK-DAG: %[[ID:.*]] = arith.constant 0xFF800000 : f32
//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [4, 8, 2] : tensor<32x2xf32> into tensor<4x8x2xf32>
//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [5, 4, 8] : tensor<5x32xf32> into tensor<5x4x8xf32>
//  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
// CHECK-SAME:   ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
//      CHECK:   arith.addf
//      CHECK:   arith.maximumf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<5x2x4xf32>
//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
// CHECK-SAME:   ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
//      CHECK:   arith.maximumf
//      CHECK:   linalg.yield
//      CHECK:  } -> tensor<5x2xf32>
//      CHECK: return %[[R]] : tensor<5x2xf32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
      transform.yield
  }
}

// -----
149
// Check that we don't use -inf as the neutral element for maxf when maxf has
// ninf. Instead check that we use the smallest finite floating point value.
// Also check that the fastmath flags are set on the created maxf
// instructions.
func.func @generic_split_3d_ninf(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>)
  -> tensor<5x2xf32>
{
  %0 = linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2) -> (d1, d0)>,
        affine_map<(d0, d1, d2) -> (d2, d1)>,
        affine_map<(d0, d1, d2) -> (d2, d0)>
      ],
      iterator_types = ["parallel", "reduction", "parallel"]
    } ins(%input, %input_2 : tensor<32x2xf32>, tensor<5x32xf32>) outs(%output : tensor<5x2xf32>) {
    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
      %3 = arith.addf %arg0, %arg1 : f32
      %4 = arith.maximumf %3, %arg2 fastmath<nnan,ninf> : f32
      linalg.yield %4 : f32
    } -> tensor<5x2xf32>
  return %0 : tensor<5x2xf32>
}

//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL:  func @generic_split_3d_ninf
//  CHECK-DAG: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32
//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [4, 8, 2] : tensor<32x2xf32> into tensor<4x8x2xf32>
//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [5, 4, 8] : tensor<5x32xf32> into tensor<5x4x8xf32>
//  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
// CHECK-SAME:   ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
//      CHECK:   arith.addf
//      CHECK:   arith.maximumf {{.*}} fastmath<nnan,ninf>
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<5x2x4xf32>
//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
// CHECK-SAME:   ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
//      CHECK:   arith.maximumf {{.*}} fastmath<nnan,ninf>
//      CHECK:   linalg.yield
//      CHECK:  } -> tensor<5x2xf32>
//      CHECK: return %[[R]] : tensor<5x2xf32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
      transform.yield
  }
}

// -----
207
// Same matmul split as above but with the inner_parallel attribute: the new
// size-4 parallel dimension is the innermost factor of k (k is reshaped to
// 64x4 instead of 4x64), so the intermediate generic iterates
// ["parallel", "parallel", "reduction", "parallel"].
func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
                    outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
//  CHECK-LABEL: @matmul_split
//  CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [16, 64, 4] : tensor<16x256xf32> into tensor<16x64x4xf32>
//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [64, 4, 32] : tensor<256x32xf32> into tensor<64x4x32xf32>
//  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<16x32x4xf32>
//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME:   , iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
// CHECK-SAME:   ins(%[[I1]], %[[I2]] : tensor<16x64x4xf32>, tensor<64x4x32xf32>) outs(%[[F]] : tensor<16x32x4xf32>) {
//      CHECK:   arith.mulf
//      CHECK:   arith.addf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<16x32x4xf32>
//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]],
// CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[G]] : tensor<16x32x4xf32>) outs(%{{.*}} : tensor<16x32xf32>) {
//      CHECK:   arith.addf
//      CHECK:   linalg.yield %{{.*}} : f32
//      CHECK: } -> tensor<16x32xf32>
//      CHECK: return %[[R]] : tensor<16x32xf32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2, inner_parallel}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
      transform.yield
  }
}

// -----
249
// 1-D multiplicative reduction with inner_parallel: the 32-element input is
// reshaped to 8x4 (split factor innermost) and the first generic iterates
// ["reduction", "parallel"].
func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor<f32>, %out: tensor<f32>) -> tensor<f32> {
  %red = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                          affine_map<(d0) -> ()>,
                                          affine_map<(d0) -> ()>],
   iterator_types = ["reduction"]}
   ins(%arg0, %arg1 : tensor<32xf32>, tensor<f32>)
   outs(%out : tensor<f32>) {
    ^bb0(%arg7: f32, %arg8: f32, %arg9: f32):
      %40 = arith.subf %arg7, %arg8 : f32
      %41 = math.exp %40 : f32
      %42 = arith.mulf %41, %arg9 : f32
      linalg.yield %42 : f32
    } -> tensor<f32>
  return %red : tensor<f32>
}

//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
//CHECK-LABEL: @generic_split_1d
//  CHECK-DAG: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] output_shape [8, 4] : tensor<32xf32> into tensor<8x4xf32>
//  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
//      CHECK: %[[G:.*]] = linalg.generic
//      CHECK:   {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
//      CHECK:   iterator_types = ["reduction", "parallel"]} ins(%[[I1]], %{{.*}} : tensor<8x4xf32>, tensor<f32>) outs(%[[F]] : tensor<4xf32>) {
//      CHECK:   arith.subf
//      CHECK:   math.exp
//      CHECK:   arith.mulf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<4xf32>
//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["reduction"]} ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
//      CHECK:   arith.mulf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<f32>
//      CHECK: return %[[R]] : tensor<f32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 0, inner_parallel}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
      transform.yield
  }
}

// -----
300
// 3-D generic with a minimumf combiner, split with inner_parallel (reduction
// dimension reshaped to 8x4). Without ninf fastmath the fill uses +inf
// (0x7F800000) as the neutral element for minimumf.
func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>)
  -> tensor<5x2xf32>
{
  %0 = linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2) -> (d1, d0)>,
        affine_map<(d0, d1, d2) -> (d2, d1)>,
        affine_map<(d0, d1, d2) -> (d2, d0)>
      ],
      iterator_types = ["parallel", "reduction", "parallel"]
    } ins(%input, %input_2 : tensor<32x2xf32>, tensor<5x32xf32>) outs(%output : tensor<5x2xf32>) {
    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
      %3 = arith.addf %arg0, %arg1 : f32
      %4 = arith.minimumf %3, %arg2 : f32
      linalg.yield %4 : f32
    } -> tensor<5x2xf32>
  return %0 : tensor<5x2xf32>
}

//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d0)>
//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>
//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL:  func @generic_split_3d
//  CHECK-DAG: %[[ID:.*]] = arith.constant 0x7F800000 : f32
//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [8, 4, 2] : tensor<32x2xf32> into tensor<8x4x2xf32>
//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [5, 8, 4] : tensor<5x32xf32> into tensor<5x8x4xf32>
//  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
// CHECK-SAME:   ins(%[[I1]], %[[I2]] : tensor<8x4x2xf32>, tensor<5x8x4xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
//      CHECK:   arith.addf
//      CHECK:   arith.minimumf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<5x2x4xf32>
//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
// CHECK-SAME:   ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
//      CHECK:   arith.minimumf
//      CHECK:   linalg.yield
//      CHECK:  } -> tensor<5x2xf32>
//      CHECK: return %[[R]] : tensor<5x2xf32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2, inner_parallel}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
      transform.yield
  }
}

// -----
354
// Check that we don't use +inf as the neutral element for minf when minf has
// ninf. Instead check that we use the largest finite floating point value.
// Also check that the fastmath flags are set on the created minf
// instructions.
func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>)
  -> tensor<5x2xf32>
{
  %0 = linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2) -> (d1, d0)>,
        affine_map<(d0, d1, d2) -> (d2, d1)>,
        affine_map<(d0, d1, d2) -> (d2, d0)>
      ],
      iterator_types = ["parallel", "reduction", "parallel"]
    } ins(%input, %input_2 : tensor<32x2xf32>, tensor<5x32xf32>) outs(%output : tensor<5x2xf32>) {
    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
      %3 = arith.addf %arg0, %arg1 : f32
      %4 = arith.minimumf %3, %arg2 fastmath<ninf> : f32
      linalg.yield %4 : f32
    } -> tensor<5x2xf32>
  return %0 : tensor<5x2xf32>
}

//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d0)>
//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>
//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL:  func @generic_split_3d
//  CHECK-DAG: %[[ID:.*]] = arith.constant 3.40282347E+38 : f32
//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [8, 4, 2] : tensor<32x2xf32> into tensor<8x4x2xf32>
//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [5, 8, 4] : tensor<5x32xf32> into tensor<5x8x4xf32>
//  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
// CHECK-SAME:   ins(%[[I1]], %[[I2]] : tensor<8x4x2xf32>, tensor<5x8x4xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
//      CHECK:   arith.addf
//      CHECK:   arith.minimumf {{.*}} fastmath<ninf>
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<5x2x4xf32>
//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
// CHECK-SAME:   ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
//      CHECK:   arith.minimumf {{.*}} fastmath<ninf>
//      CHECK:   linalg.yield
//      CHECK:  } -> tensor<5x2xf32>
//      CHECK: return %[[R]] : tensor<5x2xf32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2, inner_parallel}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
      transform.yield
  }
}

// -----
// Checks we use nan as the neutral element for maxnumf op.
func.func @generic_split_maxnumf(%in: tensor<32xf32>, %out: tensor<f32>) -> tensor<f32> {
  %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                        affine_map<(d0) -> ()>],
        iterator_types = ["reduction"]}
  ins(%in : tensor<32xf32>)
  outs(%out : tensor<f32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    %y = arith.maxnumf %arg1, %arg2 : f32
    linalg.yield %y : f32
  } -> tensor<f32>
  return %r : tensor<f32>
}

//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> ()>
// CHECK-LABEL:  func @generic_split_maxnumf
//  The float value 0xFFC00000 that is filled into the init tensor represents negative NaN.
//  CHECK-DAG: %[[ID:.*]] = arith.constant 0xFFC00000 : f32
//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] output_shape [8, 4] : tensor<32xf32> into tensor<8x4xf32>
//  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]}
// CHECK-SAME:   ins(%[[I1]] : tensor<8x4xf32>) outs(%[[F]] : tensor<4xf32>) {
//      CHECK:   arith.maxnumf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<4xf32>
//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["reduction"]}
// CHECK-SAME:   ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
//      CHECK:   arith.maxnumf {{.*}}
//      CHECK:   linalg.yield
//      CHECK:  } -> tensor<f32>
//      CHECK: return %[[R]] : tensor<f32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 0, inner_parallel}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
      transform.yield
  }
}

// -----
// Checks we use nan as the neutral element for minnumf op.
func.func @generic_split_minnumf(%in: tensor<32xf32>, %out: tensor<f32>) -> tensor<f32> {
  %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                        affine_map<(d0) -> ()>],
        iterator_types = ["reduction"]}
  ins(%in : tensor<32xf32>)
  outs(%out : tensor<f32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    %y = arith.minnumf %arg1, %arg2 : f32
    linalg.yield %y : f32
  } -> tensor<f32>
  return %r : tensor<f32>
}

//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> ()>
// CHECK-LABEL:  func @generic_split_minnumf
//  The float value 0x7FC00000 that is filled into the init tensor represents positive NaN.
//  CHECK-DAG: %[[ID:.*]] = arith.constant 0x7FC00000 : f32
//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] output_shape [8, 4] : tensor<32xf32> into tensor<8x4xf32>
//  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]}
// CHECK-SAME:   ins(%[[I1]] : tensor<8x4xf32>) outs(%[[F]] : tensor<4xf32>) {
//      CHECK:   arith.minnumf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<4xf32>
//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["reduction"]}
// CHECK-SAME:   ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
//      CHECK:   arith.minnumf {{.*}}
//      CHECK:   linalg.yield
//      CHECK:  } -> tensor<f32>
//      CHECK: return %[[R]] : tensor<f32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 0, inner_parallel}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
      transform.yield
  }
}