// xref: /llvm-project/mlir/test/Integration/Dialect/Transform/match_reduction.mlir (revision 2798b72ae7e5caad793169b77cbac47fe2362d0f)
// RUN: mlir-opt %s --transform-interpreter --verify-diagnostics

module attributes { transform.with_named_sequence } {
  // Matches an elementwise op that may precede ("leading") or follow
  // ("trailing") a reduction: a structured op whose iteration dimensions are
  // all parallel, whose inputs are all indexed by projected permutations, whose
  // inits are all indexed by permutations, and that has exactly one init.
  // On success, yields the matched op unchanged.
  transform.named_sequence @_reduce_leading_trailing(%entry: !transform.any_op {transform.readonly})
      -> (!transform.any_op) {
    %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>

    transform.match.structured %entry : !transform.any_op {
    ^bb0(%struct: !transform.any_op):
      transform.match.structured.dim %struct[all] {parallel} : !transform.any_op
      transform.match.structured.input %struct[all] {projected_permutation} : !transform.any_op
      transform.match.structured.init %struct[all] {permutation} : !transform.any_op
      %ni = transform.match.structured.num_inits %struct : (!transform.any_op) -> !transform.param<i64>
      transform.match.param.cmpi eq %ni, %c1 : !transform.param<i64>
    }
    transform.yield %entry : !transform.any_op
  }

  // Matches a fill -> reduction pattern with optional leading/trailing
  // elementwise ops around the reduction. %entry must be a structured op with:
  //   * rank in [2, 4];
  //   * the last dimension a reduction and all other dimensions parallel;
  //   * exactly one input and one init, each indexed by a projected permutation;
  //   * a body accepted by `match.structured.body` with reduction_position = 0;
  //   * an init value defined by a `linalg.fill`.
  // Yields, in order: the leading op (producer of input 0), the fill, the
  // reduction itself, the trailing op (single user of result 0), then the rank,
  // the dimension sizes and the init element bitwidth as parameters. The leading
  // and trailing handles are empty when the corresponding neighbor is missing or
  // does not match @_reduce_leading_trailing.
  transform.named_sequence @fill_reduce_leading_trailing(%entry: !transform.any_op {transform.readonly})
      -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op,
          !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) {
    %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
    %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
    %c4 = transform.param.constant 4 : i64 -> !transform.param<i64>

    %rk, %dms, %bw, %operand_o, %init_v, %trailing_o = transform.match.structured failures(propagate) %entry
        : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
                                  !transform.any_op, !transform.any_value, !transform.any_op) {
    ^bb0(%struct: !transform.any_op):
      // The rank must be within [2, 4].
      %rank = transform.match.structured.rank %struct : (!transform.any_op) -> !transform.param<i64>
      transform.match.param.cmpi ge %rank, %c2 : !transform.param<i64>
      transform.match.param.cmpi le %rank, %c4 : !transform.param<i64>

      // The last dimension is reduced; all other dimensions are parallel.
      transform.match.structured.dim %struct[-1] {reduction} : !transform.any_op
      transform.match.structured.dim %struct[except(-1)] {parallel} : !transform.any_op
      %dims = transform.match.structured.dim %struct[all] : (!transform.any_op) -> !transform.param<i64>

      // Exactly one input and one init.
      %n_inputs = transform.match.structured.num_inputs %struct : (!transform.any_op) -> !transform.param<i64>
      %n_outputs = transform.match.structured.num_inits %struct : (!transform.any_op) -> !transform.param<i64>
      transform.match.param.cmpi eq %n_inputs, %c1 : !transform.param<i64>
      transform.match.param.cmpi eq %n_outputs, %c1 : !transform.param<i64>

      // Both operands are accessed through projected permutations. The init
      // value is kept so that its bitwidth and producer can be checked later.
      transform.match.structured.input %struct[0] {projected_permutation} : !transform.any_op
      transform.match.structured.init %struct[0] {projected_permutation} : !transform.any_op
      %init = transform.match.structured.init %struct[0] : (!transform.any_op) -> !transform.any_value

      // This dance is necessary to produce an empty handle, instead of failing
      // the entire match, when result 0 does not have a single user or when
      // that user does not match @_reduce_leading_trailing.
      %trailing_optional = transform.sequence %struct : (!transform.any_op) -> !transform.any_op failures(suppress) {
      ^bb0(%struct_inner: !transform.any_op):
        %result = transform.match.structured failures(propagate) %struct_inner : (!transform.any_op) -> !transform.any_op {
        ^bb0(%struct_inner_inner: !transform.any_op):
          %result_inner = transform.match.structured.result %struct_inner_inner[0] {single} : (!transform.any_op) -> !transform.any_op
          %trailing = transform.include @_reduce_leading_trailing failures(propagate) (%result_inner) : (!transform.any_op) -> !transform.any_op
          transform.match.structured.yield %trailing : !transform.any_op
        }
        transform.yield %result: !transform.any_op
      }

      // Suppress errors as a way to implement optionality. We cannot suppress them in
      // the include because it keeps matching after "get_defining_op" fails, which
      // breaks the single-op precondition of the following ops. We don't want to
      // propagate that failure though.
      //
      // Additionally, we cannot put the sequence inside the call because its first
      // operand must be an operation handle (the verifier asserts!) and there is
      // no such handle available there.
      //
      // TODO: extend the structured matching to gracefully handle empty handles
      // or provide the suppress-errors-but-stop failure mode for includes to
      // implement optionality.
      %operand_optional = transform.sequence %struct : (!transform.any_op) -> !transform.any_op failures(suppress) {
      ^bb0(%struct_inner: !transform.any_op):
        %operand3 = transform.match.structured failures(propagate) %struct_inner : (!transform.any_op) -> !transform.any_op {
        ^bb0(%struct_inner_inner: !transform.any_op):
          %operand = transform.match.structured.input %struct_inner_inner[0] : (!transform.any_op) -> !transform.any_op
          %operand2 = transform.include @_reduce_leading_trailing failures(propagate) (%operand) : (!transform.any_op) -> !transform.any_op
          transform.match.structured.yield %operand2 : !transform.any_op
        }
        transform.yield %operand3 : !transform.any_op
      }

      %bitwidth = transform.match.structured.elemental_bitwidth %init : (!transform.any_value) -> !transform.param<i64>

      transform.match.structured.body %struct { reduction_position = 0 } : !transform.any_op
      transform.match.structured.yield %rank, %dims, %bitwidth, %operand_optional, %init, %trailing_optional
        : !transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
          !transform.any_op, !transform.any_value, !transform.any_op
    }

    // The reduction's init must be produced by a linalg.fill.
    %init_o = transform.get_defining_op %init_v : (!transform.any_value) -> !transform.any_op
    transform.match.operation_name %init_o ["linalg.fill"] : !transform.any_op

    transform.yield %operand_o, %init_o, %entry, %trailing_o, %rk, %dms, %bw
        : !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op,
          !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
  }

  // Action paired with @fill_reduce_leading_trailing. It emits one remark per
  // matched op ("leading", "fill", "reduction", "trailing") and reports the
  // rank, dimensions and bitwidth parameters as remarks at the reduction.
  // Payload reductions without a leading/trailing neighbor get no such remark.
  transform.named_sequence @print_reduce_leading_trailing(
      %leading: !transform.any_op {transform.readonly},
      %fill: !transform.any_op {transform.readonly},
      %reduction: !transform.any_op {transform.readonly},
      %trailing: !transform.any_op {transform.readonly},
      %rank: !transform.param<i64> {transform.readonly},
      %dims: !transform.param<i64> {transform.readonly},
      %bitwidth: !transform.param<i64> {transform.readonly}) {
    transform.debug.emit_remark_at %leading, "leading" : !transform.any_op
    transform.debug.emit_remark_at %fill, "fill" : !transform.any_op
    transform.debug.emit_remark_at %reduction, "reduction" : !transform.any_op
    transform.debug.emit_remark_at %trailing, "trailing" : !transform.any_op
    transform.debug.emit_param_as_remark %rank, "rank" at %reduction : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %dims, "dimensions" at %reduction : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %bitwidth, "bitwidth" at %reduction : !transform.param<i64>, !transform.any_op
    transform.yield
  }

  // Interpreter entry point: runs the matcher on the ops nested in %root and
  // applies the printing action to every successful match.
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.consumed}) {
    transform.foreach_match in %root
      @fill_reduce_leading_trailing -> @print_reduce_leading_trailing
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// Tensor type aliases shared by the payload functions that follow.
!in_tensor_t = tensor<8x64xf32>
!out_tensor_t = tensor<8xf32>

// An elementwise op feeds the reduction (leading). There is no trailing op
// because the reduction's only user is func.return.
func.func @eltwise_reduce(%arg : !in_tensor_t) -> (!out_tensor_t) {
  %cst = arith.constant -0.000000e+00 : f32

  %0 = tensor.empty() : !out_tensor_t
  // expected-remark @below {{fill}}
  %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t
  %2 = tensor.empty() : !in_tensor_t
  // expected-remark @below {{leading}}
  %3 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg : !in_tensor_t) outs(%2 : !in_tensor_t) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = arith.addf %arg3, %arg3 : f32
      %5 = arith.addf %4, %4 : f32
      linalg.yield %5 : f32
    } -> !in_tensor_t

  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %6 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%3 : !in_tensor_t) outs(%1 : !out_tensor_t) {
      ^bb0(%arg3: f32, %arg4: f32):
        %4 = arith.addf %arg3, %arg4 : f32
        linalg.yield %4 : f32
      } -> !out_tensor_t

  return %6 : !out_tensor_t
}

// The reduction reads a function argument directly, so there is no leading op.
// Its result feeds an elementwise op (trailing).
func.func @reduce_eltwise(%arg : !in_tensor_t) -> (!out_tensor_t) {
  %cst = arith.constant -0.000000e+00 : f32

  %0 = tensor.empty() : !out_tensor_t
  // expected-remark @below {{fill}}
  %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t
  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %5 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%arg : !in_tensor_t) outs(%1 : !out_tensor_t) {
      ^bb0(%arg3: f32, %arg4: f32):
        %4 = arith.addf %arg3, %arg4 : f32
        linalg.yield %4 : f32
      } -> !out_tensor_t

  %6 = tensor.empty() : !out_tensor_t
  // expected-remark @below {{trailing}}
  %7 = linalg.generic {
    indexing_maps = [affine_map<(d0) -> (d0)>,
                     affine_map<(d0) -> (d0)>],
    iterator_types = ["parallel"]}
    ins(%5 : !out_tensor_t) outs(%6 : !out_tensor_t) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = math.sqrt %arg3 : f32
      linalg.yield %4 : f32
    } -> !out_tensor_t
  return %7 : !out_tensor_t
}

// The reduction has both a leading elementwise producer and a trailing
// elementwise consumer.
func.func @eltwise_reduce_eltwise(%arg : !in_tensor_t) -> (!out_tensor_t) {
  %cst = arith.constant -0.000000e+00 : f32

  %0 = tensor.empty() : !out_tensor_t
  // expected-remark @below {{fill}}
  %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t
  %2 = tensor.empty() : !in_tensor_t
  // expected-remark @below {{leading}}
  %3 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg : !in_tensor_t) outs(%2 : !in_tensor_t) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = arith.addf %arg3, %arg3 : f32
      %5 = arith.addf %4, %4 : f32
      linalg.yield %5 : f32
    } -> !in_tensor_t

  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %6 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%3 : !in_tensor_t) outs(%1 : !out_tensor_t) {
      ^bb0(%arg3: f32, %arg4: f32):
        %4 = arith.addf %arg3, %arg4 : f32
        linalg.yield %4 : f32
      } -> !out_tensor_t

  %7 = tensor.empty() : !out_tensor_t
  // expected-remark @below {{trailing}}
  %8 = linalg.generic {
    indexing_maps = [affine_map<(d0) -> (d0)>,
                     affine_map<(d0) -> (d0)>],
    iterator_types = ["parallel"]}
    ins(%6 : !out_tensor_t) outs(%7 : !out_tensor_t) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = math.sqrt %arg3 : f32
      linalg.yield %4 : f32
    } -> !out_tensor_t

  return %8 : !out_tensor_t
}

// Same leading/reduction/trailing chain as @eltwise_reduce_eltwise, except that
// the leading op is defined before the fill. The matches must not depend on
// the textual order of the reduction's producers.
func.func @eltwise_reduce_eltwise_swapped(%arg : !in_tensor_t) -> (!out_tensor_t) {
  %cst = arith.constant -0.000000e+00 : f32

  %2 = tensor.empty() : !in_tensor_t
  // expected-remark @below {{leading}}
  %3 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg : !in_tensor_t) outs(%2 : !in_tensor_t) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = arith.addf %arg3, %arg3 : f32
      %5 = arith.addf %4, %4 : f32
      linalg.yield %5 : f32
    } -> !in_tensor_t

  %0 = tensor.empty() : !out_tensor_t
  // expected-remark @below {{fill}}
  %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t
  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %6 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%3 : !in_tensor_t) outs(%1 : !out_tensor_t) {
      ^bb0(%arg3: f32, %arg4: f32):
        %4 = arith.addf %arg3, %arg4 : f32
        linalg.yield %4 : f32
      } -> !out_tensor_t

  %7 = tensor.empty() : !out_tensor_t
  // expected-remark @below {{trailing}}
  %8 = linalg.generic {
    indexing_maps = [affine_map<(d0) -> (d0)>,
                     affine_map<(d0) -> (d0)>],
    iterator_types = ["parallel"]}
    ins(%6 : !out_tensor_t) outs(%7 : !out_tensor_t) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = math.sqrt %arg3 : f32
      linalg.yield %4 : f32
    } -> !out_tensor_t

  return %8 : !out_tensor_t
}

// The reduction reads a function argument directly (no leading op), and its
// only user is func.return (no trailing op). The unrelated linalg.fill %fill2
// in the same function must produce no remark.
func.func @reduction_with_extra_op_in_func(%arg0: tensor<8x479xf32>, %arg1: tensor<32x32xf32>) -> (tensor<8xf32>, tensor<32xf32>) {
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<8xf32>
  // expected-remark @below {{fill}}
  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<8xf32>) -> tensor<8xf32>
  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 479 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %result = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%arg0 : tensor<8x479xf32>)
    outs(%fill : tensor<8xf32>) {
  ^bb0(%in: f32, %out: f32):
    %6 = arith.addf %in, %out : f32
    linalg.yield %6 : f32
  } -> tensor<8xf32>

  %empty2 = tensor.empty() : tensor<32xf32>
  %fill2 = linalg.fill ins(%cst : f32) outs(%empty2 : tensor<32xf32>) -> tensor<32xf32>
  return %result, %fill2 : tensor<8xf32>, tensor<32xf32>
}
