// xref: /llvm-project/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir (revision f6a756f35a4d0719a96b4e214905369d565d87da)
// RUN: mlir-opt %s --pass-pipeline="builtin.module(transform-interpreter{debug-payload-root-tag=start_here})" --split-input-file --verify-diagnostics

// Checks that a `transform.match.structured` matcher with an empty body accepts
// every Linalg structured op in the payload. The fill, the matmul and the
// generic below are expected to be reported. The non-structured ops
// (tensor.empty, arith.constant, return) carry no expected remark, so
// --verify-diagnostics rejects any remark on them.
module attributes { transform.with_named_sequence } {
  // Action: report the matched op.
  transform.named_sequence @print_structured(%arg0: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %arg0, "structured" : !transform.any_op
    transform.yield
  }

  // Matcher with no constraints in its body: it succeeds whenever
  // transform.match.structured accepts the op, and yields that op unchanged.
  transform.named_sequence @match_structured_empty(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
    %0 = transform.match.structured %arg0 : (!transform.any_op) -> !transform.any_op {
    ^bb0(%arg1: !transform.any_op):
      transform.match.structured.yield %arg1 : !transform.any_op
    }
    transform.yield %0 : !transform.any_op
  }

  // Entry point. Match any structured operation and emit a remark.
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
    transform.foreach_match in %arg0
        @match_structured_empty -> @print_structured
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }

  func.func @payload() attributes { transform.target_tag = "start_here" } {
    %preA = tensor.empty() : tensor<2x3xf32>
    %cA = arith.constant 1.0 : f32
    // expected-remark @below {{structured}}
    %A = linalg.fill ins(%cA : f32) outs(%preA : tensor<2x3xf32>) -> tensor<2x3xf32>

    %B = arith.constant dense<1.0> : tensor<3x4xf32>
    %C = arith.constant dense<1000.0> : tensor<2x4xf32>
    // expected-remark @below {{structured}}
    %D = linalg.matmul ins(%A, %B: tensor<2x3xf32>, tensor<3x4xf32>)
                       outs(%C: tensor<2x4xf32>) -> tensor<2x4xf32>

    %E = arith.constant dense<2.0> : tensor<2x4xf32>
    // expected-remark @below {{structured}}
    linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]
    } ins(%D : tensor<2x4xf32>) outs(%E : tensor<2x4xf32>) {
    ^bb0(%arg0: f32, %arg1: f32):
      linalg.yield %arg0 : f32
    } -> tensor<2x4xf32>

    return
  }
}

// -----

// Checks that a matcher sequence may contain `transform.print`. The matcher
// prints every op it is applied to and then yields it without failing, so the
// (empty) action is applied to each of those ops.
module attributes { transform.with_named_sequence } {
  // Action: intentionally empty.
  transform.named_sequence @do_nothing(%arg0: !transform.any_op {transform.readonly}) {
    transform.yield
  }

  // Matcher: print the candidate op, then accept it unconditionally.
  transform.named_sequence @print_in_matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.print %arg0 : !transform.any_op
    transform.yield %arg0 : !transform.any_op
  }

  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
    transform.foreach_match in %arg0
        @print_in_matcher -> @do_nothing
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }

  func.func @payload() attributes { transform.target_tag = "start_here" } {
    // NOTE(review): the RUN line uses only --verify-diagnostics and does not pipe
    // the output to FileCheck, so the CHECK lines below are never verified. As
    // written, this split only checks that printing inside a matcher emits no
    // diagnostics. If FileCheck is ever enabled, the `[[` must be escaped
    // (FileCheck treats it as variable syntax), and the expected printer header
    // text should be confirmed.
    // CHECK: [[ IR Printer ]]
    // CHECK: test.print_me
    %0 = "test.print_me"() : () -> (i1)
    return
  }
}

// -----


// Checks `transform.match.structured failures(suppress)`: a failed structured
// match is silenced, so the enclosing matcher still succeeds, with an empty
// result handle.
module attributes { transform.with_named_sequence } {
  // Action: intentionally empty. All observable output comes from the matcher.
  transform.named_sequence @do_nothing(%arg0: !transform.any_op {transform.readonly}) {
    transform.yield
  }

  // Matcher. Match any structured operation and emit a remark. Also emit a
  // different remark at all considered operations. When the structured match
  // fails, the failure is suppressed and the resulting handle is associated
  // with an empty list, hence no "structured" remark is emitted. Both remark
  // printing operations happen after the check in the sequence, so they only
  // apply if the check operation produced success (due to failure suppression
  // or not).
  transform.named_sequence @match_structured_suppress(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
    %0 = transform.match.structured failures(suppress) %arg0 : (!transform.any_op) -> !transform.any_op {
    ^bb0(%arg1: !transform.any_op):
      transform.match.structured.yield %arg1 : !transform.any_op
    }
    transform.debug.emit_remark_at %0, "structured" : !transform.any_op
    transform.debug.emit_remark_at %arg0, "other" : !transform.any_op
    transform.yield %0 : !transform.any_op
  }

  // Entry point. With `restrict_root`, the root payload op (the tagged
  // func.func) is matched as well. That is why "other" is also expected on the
  // function itself, not only on the ops nested in it.
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
    transform.foreach_match restrict_root in %arg0
        @match_structured_suppress -> @do_nothing
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }

  // expected-remark @below {{other}}
  func.func @payload() attributes { transform.target_tag = "start_here" } {
    // expected-remark @below {{other}}
    %D = arith.constant dense<1.0> : tensor<2x4xf32>
    // expected-remark @below {{other}}
    %E = arith.constant dense<2.0> : tensor<2x4xf32>
    // expected-remark @below {{structured}}
    // expected-remark @below {{other}}
    linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]
    } ins(%D : tensor<2x4xf32>) outs(%E : tensor<2x4xf32>) {
    ^bb0(%arg0: f32, %arg1: f32):
      // expected-remark @below {{other}}
      linalg.yield %arg0 : f32
    } -> tensor<2x4xf32>

    // expected-remark @below {{other}}
    return
  }
}

// -----

// Checks `transform.match.structured.body { passthrough }`. Only the generic
// whose body yields its input block argument unchanged, and linalg.copy, are
// expected to match. The generic computing arith.mulf is expected not to match.
module attributes { transform.with_named_sequence } {
  // Action: report the matched op.
  transform.named_sequence @print_passthrough(%arg0: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %arg0, "passthrough" : !transform.any_op
    transform.yield
  }

  // Matcher: accept structured ops whose body is a pure passthrough.
  transform.named_sequence @match_structured_body_passthrough(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
    %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
    ^bb0(%arg1: !transform.any_op):
      transform.match.structured.body %arg1 { passthrough } : !transform.any_op
      transform.match.structured.yield %arg1 : !transform.any_op
    }
    transform.yield %0 : !transform.any_op
  }

  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
    transform.foreach_match in %arg0
        @match_structured_body_passthrough -> @print_passthrough
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }

  func.func @payload(%in: tensor<2xf32>, %out: tensor<2xf32>) attributes { transform.target_tag = "start_here" } {
    // expected-remark @below {{passthrough}}
    linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]
    } ins(%in : tensor<2xf32>) outs(%out : tensor<2xf32>) {
    ^bb0(%arg0: f32, %arg1: f32):
      linalg.yield %arg0 : f32
    } -> tensor<2xf32>

    linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]
    } ins(%in : tensor<2xf32>) outs(%out : tensor<2xf32>) {
    ^bb0(%arg0: f32, %arg1: f32):
      %0 = arith.mulf %arg0, %arg1 : f32
      linalg.yield %0 : f32
    } -> tensor<2xf32>

    // expected-remark @below {{passthrough}}
    linalg.copy ins(%in : tensor<2xf32>) outs(%out : tensor<2xf32>) -> tensor<2xf32>

    return
  }
}

// -----

// Checks `transform.match.structured.body { elementwise }`.
// - linalg.fill, linalg.map and the broadcasting %add_bcast generic are
//   expected to match.
// - %non_elementwise has the same indexing maps and iterator types as
//   %add_bcast. Its body additionally reads the %add tensor through
//   tensor.dim/tensor.extract, and it is expected not to match.
module attributes { transform.with_named_sequence } {
  // Action: report the matched op.
  transform.named_sequence @print_elementwise(%arg0: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %arg0, "elementwise" : !transform.any_op
    transform.yield
  }

  // Matcher: accept structured ops whose body is elementwise.
  transform.named_sequence @match_structured_body_elementwise(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
    %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
    ^bb0(%arg1: !transform.any_op):
      transform.match.structured.body %arg1 { elementwise } : !transform.any_op
      transform.match.structured.yield %arg1 : !transform.any_op
    }
    transform.yield %0 : !transform.any_op
  }

  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
    transform.foreach_match in %arg0
        @match_structured_body_elementwise -> @print_elementwise
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }

  func.func @payload(%in1: tensor<2xf32>, %in2: tensor<2xf32>, %in3: tensor<2x3xf32>, %out: tensor<2xf32>, %out2: tensor<2x3xf32>) -> (tensor<2xf32>, tensor<2x3xf32>, tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } {
    %cst0 = arith.constant 0.0 : f32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    // expected-remark @below {{elementwise}}
    %fill = linalg.fill ins(%cst0: f32) outs(%out: tensor<2xf32>) -> tensor<2xf32>
    // expected-remark @below {{elementwise}}
    %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<2xf32>, tensor<2xf32>) outs(%fill: tensor<2xf32>)
    %non_elementwise = linalg.generic
      {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
      ins(%in1, %in3: tensor<2xf32>, tensor<2x3xf32>) outs(%out2: tensor<2x3xf32>) {
        ^bb0(%arg0: f32, %arg1: f32, %arg3: f32):
          %0 = arith.addf %arg0, %arg1 : f32
          %1 = tensor.dim %add, %c0 : tensor<2xf32>
          %2 = arith.subi %1, %c1 : index
          %3 = tensor.extract %add[%2] : tensor<2xf32>
          %4 = arith.mulf %0, %3 : f32
          linalg.yield %4 : f32
      } -> tensor<2x3xf32>
    // expected-remark @below {{elementwise}}
    %add_bcast = linalg.generic
      {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
      ins(%in1, %in3: tensor<2xf32>, tensor<2x3xf32>) outs(%out2: tensor<2x3xf32>) {
        ^bb0(%arg0: f32, %arg1: f32, %arg3: f32):
          %0 = arith.addf %arg0, %arg1 : f32
          linalg.yield %0 : f32
      } -> tensor<2x3xf32>
    return %add, %add_bcast, %non_elementwise : tensor<2xf32>, tensor<2x3xf32>, tensor<2x3xf32>
  }
}

// -----

// Checks `transform.match.structured.body { reduction_position = 0 }`.
// - The mul/add generic and linalg.matmul (single-result multiply-accumulate
//   ops) are expected to match.
// - The two-init, select-based generic is expected not to match.
// - The all-parallel generic that only copies its input is expected not to
//   match.
module attributes { transform.with_named_sequence } {
  // Action: report the matched op.
  transform.named_sequence @print_reduction(%arg0: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %arg0, "reduction" : !transform.any_op
    transform.yield
  }

  // Matcher: accept structured ops whose body is a reduction into init 0.
  transform.named_sequence @match_structured_body_reduction(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
    %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
    ^bb0(%arg1: !transform.any_op):
      transform.match.structured.body %arg1 { reduction_position = 0 } : !transform.any_op
      transform.match.structured.yield %arg1 : !transform.any_op
    }
    transform.yield %0 : !transform.any_op
  }

  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
    transform.foreach_match in %arg0
        @match_structured_body_reduction -> @print_reduction
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }

  func.func @payload(%lhs: tensor<2x4xf32>, %rhs: tensor<4x3xf32>, %out: tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } {
    // expected-remark @below {{reduction}}
    linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"]
    } ins(%lhs, %rhs: tensor<2x4xf32>, tensor<4x3xf32>) outs(%out: tensor<2x3xf32>) {
    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
      %0 = arith.mulf %arg0, %arg1 : f32
      %1 = arith.addf %0, %arg2 : f32
      linalg.yield %1 : f32
    } -> tensor<2x3xf32>

    %r = tensor.empty() : tensor<2x3xf32>
    linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"]
    } ins(%lhs, %rhs: tensor<2x4xf32>, tensor<4x3xf32>) outs(%out, %r: tensor<2x3xf32>, tensor<2x3xf32>) {
    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32):
      %0 = arith.mulf %arg0, %arg1 : f32
      %1 = arith.cmpf olt, %0, %arg2 : f32
      %2 = arith.select %1, %0, %arg2 : f32
      %3 = arith.select %1, %arg3, %0 : f32
      linalg.yield %2, %3 : f32, f32
    } -> (tensor<2x3xf32>, tensor<2x3xf32>)

    // expected-remark @below {{reduction}}
    linalg.matmul ins(%lhs, %rhs: tensor<2x4xf32>, tensor<4x3xf32>) outs(%out: tensor<2x3xf32>) -> tensor<2x3xf32>

    %e = tensor.empty() : tensor<2x4xf32>
    linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]
    } ins(%lhs: tensor<2x4xf32>) outs(%e: tensor<2x4xf32>) {
    ^bb0(%arg0: f32, %arg1: f32):
      linalg.yield %arg0 : f32
    } -> tensor<2x4xf32>

    return
  }
}


// -----

// Checks `transform.match.structured.dim`.
// - The first pass captures dimension sizes for several position
//   specifications ([all], [0], [-1], lists, except(...)) and reports them.
//   Failures are suppressed so that the captures can still be printed.
// - The second pass matches ops whose full list of sizes equals [2, 3, 4].
module attributes { transform.with_named_sequence } {
  // Action: intentionally empty. The capture matcher reports by itself.
  transform.named_sequence @do_nothing(%arg0: !transform.any_op {transform.readonly}) {
    transform.yield
  }

  // Action: report ops that passed the size check.
  transform.named_sequence @print_dimension_size_match(%arg0: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %arg0, "matched sizes" : !transform.any_op
    transform.yield
  }

  transform.named_sequence @match_dimension_capture(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
    // Capture multiple dimension values. Suppress failures so we can print them anyway after the capture.
    %0:9 = transform.match.structured failures(suppress) %arg0
      : (!transform.any_op) -> (!transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
            !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) {
    ^bb0(%arg1: !transform.any_op):
      // This also tests the positional specification used by other ops, which may not test it again.
      %1 = transform.match.structured.dim %arg1[all] : (!transform.any_op) -> !transform.param<i64>
      %2 = transform.match.structured.dim %arg1[0] : (!transform.any_op) -> !transform.param<i64>
      %3 = transform.match.structured.dim %arg1[-1] : (!transform.any_op) -> !transform.param<i64>
      %4 = transform.match.structured.dim %arg1[0, 2] : (!transform.any_op) -> !transform.param<i64>
      %5 = transform.match.structured.dim %arg1[0, -1] : (!transform.any_op) -> !transform.param<i64>
      %6 = transform.match.structured.dim %arg1[except(-1)] : (!transform.any_op) -> !transform.param<i64>
      %7 = transform.match.structured.dim %arg1[except(0, -2)] : (!transform.any_op) -> !transform.param<i64>
      %8 = transform.match.structured.dim %arg1[0, -3] : (!transform.any_op) -> !transform.param<i64>
      transform.match.structured.yield %arg1, %1, %2, %3, %4, %5, %6, %7, %8
          : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
            !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
    }
    transform.debug.emit_param_as_remark %0#1, "dimensions all:" at %0#0 : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %0#2, "dimension 0:" at %0#0 : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %0#3, "dimension -1:" at %0#0 : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %0#4, "dimensions 0, 2:" at %0#0 : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %0#5, "dimensions 0, -1:" at %0#0 : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %0#6, "dimensions except -1:" at %0#0 : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %0#7, "dimensions except 0, -2:" at %0#0 : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %0#8, "dimensions 0, -3:" at %0#0 : !transform.param<i64>, !transform.any_op
    transform.yield %0#0 : !transform.any_op
  }

  // Matcher: accept structured ops whose sizes for all dimensions are exactly
  // [2, 3, 4] (compared element-wise with transform.match.param.cmpi eq).
  transform.named_sequence @match_dimension_sizes(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
    %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
    ^bb0(%arg1: !transform.any_op):
      %1 = transform.match.structured.dim %arg1[all] : (!transform.any_op) -> !transform.param<i64>
      %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
      %c3 = transform.param.constant 3 : i64 -> !transform.param<i64>
      %c4 = transform.param.constant 4 : i64 -> !transform.param<i64>
      %2 = transform.merge_handles %c2, %c3, %c4 : !transform.param<i64>
      transform.match.param.cmpi eq %1, %2 : !transform.param<i64>

      transform.match.structured.yield %arg1 : !transform.any_op
    }
    transform.yield %0 : !transform.any_op
  }

  // Entry point: the size check runs on the handle produced by the capture pass.
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
    %0 = transform.foreach_match in %arg0 @match_dimension_capture -> @do_nothing : (!transform.any_op) -> !transform.any_op
    %1 = transform.foreach_match in %0 @match_dimension_sizes -> @print_dimension_size_match : (!transform.any_op) -> !transform.any_op
    transform.yield
  }

  func.func @payload(%lhs: tensor<2x4xf32>, %rhs: tensor<4x3xf32>, %out: tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } {
    // The last capture ("dimensions 0, -3:") carries no values: it fails to
    // match because 0 and -3 designate the same dimension of this 3-D op.
    // expected-remark @below {{dimensions all: 2 : i64, 3 : i64, 4 : i64}}
    // expected-remark @below {{dimension 0: 2 : i64}}
    // expected-remark @below {{dimension -1: 4 : i64}}
    // expected-remark @below {{dimensions 0, 2: 2 : i64, 4 : i64}}
    // expected-remark @below {{dimensions 0, -1: 2 : i64, 4 : i64}}
    // expected-remark @below {{dimensions except -1: 2 : i64, 3 : i64}}
    // expected-remark @below {{dimensions except 0, -2: 4 : i64}}
    // expected-remark @below {{dimensions 0, -3:}}
    // expected-remark @below {{matched sizes}}
    linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"]
    } ins(%lhs, %rhs: tensor<2x4xf32>, tensor<4x3xf32>) outs(%out: tensor<2x3xf32>) {
    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
      %0 = arith.mulf %arg0, %arg1 : f32
      %1 = arith.addf %0, %arg2 : f32
      linalg.yield %1 : f32
    } -> tensor<2x3xf32>

    return
  }
}

// -----

// Checks `transform.match.structured.dim ... { parallel }` and `{ reduction }`
// on all dimensions, on the last dimension, and on all but the last dimension.
// Each foreach_match runs on the handle produced by the previous one. No
// payload op has reduction iterators only, so "all reduction" is never expected.
module attributes { transform.with_named_sequence } {
  // Actions: one reporting sequence per iterator-kind matcher.
  transform.named_sequence @print_all_reduction(%arg0: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %arg0, "all reduction" : !transform.any_op
    transform.yield
  }
  transform.named_sequence @print_all_parallel(%arg0: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %arg0, "all parallel" : !transform.any_op
    transform.yield
  }
  transform.named_sequence @print_last_reduction(%arg0: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %arg0, "last reduction" : !transform.any_op
    transform.yield
  }
  transform.named_sequence @print_parallel_except_last(%arg0: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %arg0, "parallel except last" : !transform.any_op
    transform.yield
  }

  // Matchers: the structured match yields nothing; on success the sequence
  // forwards its own operand.
  transform.named_sequence @match_all_reduction(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.structured failures(propagate) %arg0 : !transform.any_op {
    ^bb0(%arg1: !transform.any_op):
      transform.match.structured.dim %arg1[all] { reduction } : !transform.any_op
      transform.match.structured.yield
    }
    transform.yield %arg0 : !transform.any_op
  }
  transform.named_sequence @match_all_parallel(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.structured failures(propagate) %arg0 : !transform.any_op {
    ^bb0(%arg1: !transform.any_op):
      transform.match.structured.dim %arg1[all] { parallel } : !transform.any_op
      transform.match.structured.yield
    }
    transform.yield %arg0 : !transform.any_op
  }
  transform.named_sequence @match_last_reduction(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.structured failures(propagate) %arg0 : !transform.any_op {
    ^bb0(%arg1: !transform.any_op):
      transform.match.structured.dim %arg1[-1] { reduction } : !transform.any_op
      transform.match.structured.yield
    }
    transform.yield %arg0 : !transform.any_op
  }
  transform.named_sequence @match_parallel_except_last(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.structured failures(propagate) %arg0 : !transform.any_op {
    ^bb0(%arg1: !transform.any_op):
      transform.match.structured.dim %arg1[except(-1)] { parallel } : !transform.any_op
      transform.match.structured.yield
    }
    transform.yield %arg0 : !transform.any_op
  }

  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
    %0 = transform.foreach_match in %arg0 @match_all_reduction -> @print_all_reduction : (!transform.any_op) -> !transform.any_op
    %1 = transform.foreach_match in %0 @match_all_parallel -> @print_all_parallel : (!transform.any_op) -> !transform.any_op
    %2 = transform.foreach_match in %1 @match_last_reduction -> @print_last_reduction : (!transform.any_op) -> !transform.any_op
    %3 = transform.foreach_match in %2 @match_parallel_except_last -> @print_parallel_except_last : (!transform.any_op) -> !transform.any_op
    transform.yield
  }

  func.func @payload(%lhs: tensor<2x4xf32>, %rhs: tensor<4x3xf32>, %out: tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } {
    // expected-remark @below {{last reduction}}
    // expected-remark @below {{parallel except last}}
    linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"]
    } ins(%lhs, %rhs: tensor<2x4xf32>, tensor<4x3xf32>) outs(%out: tensor<2x3xf32>) {
    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
      %0 = arith.mulf %arg0, %arg1 : f32
      %1 = arith.addf %0, %arg2 : f32
      linalg.yield %1 : f32
    } -> tensor<2x3xf32>

    // expected-remark @below {{last reduction}}
    // expected-remark @below {{parallel except last}}
    linalg.matmul ins(%lhs, %rhs : tensor<2x4xf32>, tensor<4x3xf32>) outs(%out : tensor<2x3xf32>) -> tensor<2x3xf32>

    %cst = arith.constant 1.0 : f32
    // expected-remark @below {{all parallel}}
    // expected-remark @below {{parallel except last}}
    linalg.fill ins(%cst : f32) outs(%out: tensor<2x3xf32>) -> tensor<2x3xf32>

    return
  }
}

// -----

// Checks `transform.match.structured.elemental_bitwidth` on the first init.
// The f32 fill reports 32. The fill on `index` elements is expected not to
// match, presumably because `index` has no fixed bitwidth.
module attributes { transform.with_named_sequence } {
  // Matcher: return the op together with the element bitwidth of its init 0.
  transform.named_sequence @match_bitwidth(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.param<i64>) {
    %bw = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.param<i64> {
    ^bb0(%arg1: !transform.any_op):
      %0 = transform.match.structured.init %arg1 [0] : (!transform.any_op) -> !transform.any_value
      %1 = transform.match.structured.elemental_bitwidth %0 : (!transform.any_value) -> !transform.param<i64>
      transform.match.structured.yield %1 : !transform.param<i64>
    }
    transform.yield %arg0, %bw : !transform.any_op, !transform.param<i64>
  }

  // Action: report the captured bitwidth at the matched op.
  transform.named_sequence @print_bitwidth(%arg0: !transform.any_op {transform.readonly}, %arg1: !transform.param<i64> {transform.readonly}) {
    transform.debug.emit_param_as_remark %arg1, "bitwidth:" at %arg0 : !transform.param<i64>, !transform.any_op
    transform.yield
  }

  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
    transform.foreach_match in %arg0 @match_bitwidth -> @print_bitwidth : (!transform.any_op) -> !transform.any_op
    transform.yield
  }

  func.func @payload(%f32: f32, %tf32: tensor<?xf32>,
                     %index: index, %tindex: tensor<?xindex>)
            attributes { transform.target_tag = "start_here" }  {
    // expected-remark @below {{bitwidth: 32}}
    linalg.fill ins(%f32: f32) outs(%tf32: tensor<?xf32>) -> tensor<?xf32>
    linalg.fill ins(%index: index) outs(%tindex: tensor<?xindex>) -> tensor<?xindex>
    return
  }
}

// -----

// Checks `transform.match.structured.init` in three forms:
// - value form for init [0];
// - value form for inits [all];
// - op form for init [0], which returns the op producing that init.
// Remarks land where the returned handles point, not on the matched ops: the
// function arguments, the fill result, and the fill op itself (as the producer
// of the generic's first init).
module attributes { transform.with_named_sequence } {
  // Matcher: return the op, its first init value, all init values, and the
  // producer of the first init.
  transform.named_sequence @match_init(%arg0: !transform.any_op {transform.readonly})
      -> (!transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op) {
    %outs:3 = transform.match.structured failures(suppress) %arg0
      : (!transform.any_op) -> (!transform.any_value, !transform.any_value, !transform.any_op) {
    ^bb0(%arg1: !transform.any_op):
      %0 = transform.match.structured.init %arg1 [0] : (!transform.any_op) -> !transform.any_value
      %1 = transform.match.structured.init %arg1 [all] : (!transform.any_op) -> !transform.any_value
      %2 = transform.match.structured.init %arg1 [0] : (!transform.any_op) -> !transform.any_op
      transform.match.structured.yield %0, %1, %2 : !transform.any_value, !transform.any_value, !transform.any_op
    }
    transform.yield %arg0, %outs#0, %outs#1, %outs#2 : !transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op
  }

  // Action: emit one remark per captured handle. %arg0, the matched op, is
  // unused.
  transform.named_sequence @print_init(%arg0: !transform.any_op {transform.readonly},
                                         %arg1: !transform.any_value {transform.readonly},
                                         %arg2: !transform.any_value {transform.readonly},
                                         %arg3: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %arg1, "output 0" : !transform.any_value
    transform.debug.emit_remark_at %arg3, "output producer" : !transform.any_op
    transform.debug.emit_remark_at %arg2, "all output" : !transform.any_value
    transform.yield
  }

  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
    transform.foreach_match in %arg0 @match_init -> @print_init : (!transform.any_op) -> !transform.any_op
    transform.yield
  }


  func.func @payload(%f32: f32,
            // expected-remark @below {{output 0}}
            // expected-remark @below {{all output}}
            // expected-note @below {{value handle points to a block argument #1 in block #0 in region #0}}
            %tf32: tensor<?xf32>,
            // expected-remark @below {{all output}}
            // expected-note @below {{value handle points to a block argument #2 in block #0 in region #0}}
            %tf32_2: tensor<?xf32>)
            attributes { transform.target_tag = "start_here" }  {
    // expected-remark @below {{output 0}}
    // expected-remark @below {{output producer}}
    // expected-remark @below {{all output}}
    // expected-note @below {{value handle points to an op result #0}}
    %0 = linalg.fill ins(%f32: f32) outs(%tf32: tensor<?xf32>) -> tensor<?xf32>

    linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]
    } ins(%tf32: tensor<?xf32>) outs(%0, %tf32_2: tensor<?xf32>, tensor<?xf32>) {
    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
      linalg.yield %arg0, %arg0 : f32, f32
    } -> (tensor<?xf32>, tensor<?xf32>)
    return
  }
}

// -----

574module attributes { transform.with_named_sequence } {
575  transform.named_sequence @match_init_0_permutation(%arg0: !transform.any_op {transform.readonly})
576      -> !transform.any_op {
577    %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
578    ^bb0(%arg1: !transform.any_op):
579      transform.match.structured.init %arg1[0] { permutation }: !transform.any_op
580      transform.match.structured.yield %arg1 : !transform.any_op
581    }
582    transform.yield %0 : !transform.any_op
583  }
584  transform.named_sequence @match_init_1_permutation(%arg0: !transform.any_op {transform.readonly})
585      -> !transform.any_op {
586    %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
587    ^bb0(%arg1: !transform.any_op):
588      transform.match.structured.init %arg1[1] { permutation }: !transform.any_op
589      transform.match.structured.yield %arg1 : !transform.any_op
590    }
591    transform.yield %0 : !transform.any_op
592  }
593  transform.named_sequence @match_init_2_projected_permutation(%arg0: !transform.any_op {transform.readonly})
594      -> !transform.any_op {
595    %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
596    ^bb0(%arg1: !transform.any_op):
597      transform.match.structured.init %arg1[2] { projected_permutation }: !transform.any_op
598      transform.match.structured.yield %arg1 : !transform.any_op
599    }
600    transform.yield %0 : !transform.any_op
601  }
602
603  transform.named_sequence @print_init_0_permutation(%arg0: !transform.any_op {transform.readonly}) {
604    transform.debug.emit_remark_at %arg0, "matched output 0 permutation" : !transform.any_op
605    transform.yield
606  }
607  transform.named_sequence @print_init_1_permutation(%arg0: !transform.any_op {transform.readonly}) {
608    transform.debug.emit_remark_at %arg0, "matched output 1 permutation" : !transform.any_op
609    transform.yield
610  }
611  transform.named_sequence @print_init_2_projected_permutation(%arg0: !transform.any_op {transform.readonly}) {
612    transform.debug.emit_remark_at %arg0, "matched output 2 projected permutation" : !transform.any_op
613    transform.yield
614  }
615
616  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
617    %0 = transform.foreach_match in %arg0 @match_init_0_permutation -> @print_init_0_permutation : (!transform.any_op) -> !transform.any_op
618    %1 = transform.foreach_match in %0 @match_init_1_permutation -> @print_init_1_permutation : (!transform.any_op) -> !transform.any_op
619    %2 = transform.foreach_match in %1 @match_init_2_projected_permutation -> @print_init_2_projected_permutation : (!transform.any_op) -> !transform.any_op
620    transform.yield
621  }
622
  // Payload for the init matchers: two generics, each with one input and three
  // inits. The same three init maps, (d0 + d1), (d1) and (d1, d0), appear in
  // different init positions, so each op is reported by a different subset of
  // the matchers.
  func.func @payload(%f32: f32,
            %oned: tensor<?xf32>,
            %oned2: tensor<?xf32>,
            %twod: tensor<?x?xf32>)
            attributes { transform.target_tag = "start_here" }  {
    // Inits: #0 -> (d0 + d1), #1 -> (d1), #2 -> (d1, d0). Only init #2 is
    // reported: (d1) is a projected permutation but not a permutation, so the
    // init #1 permutation matcher does not fire on this op.
    // expected-remark @below {{matched output 2 projected permutation}}
    linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d0 + d1)>,
                       affine_map<(d0, d1) -> (d1)>,
                       affine_map<(d0, d1) -> (d1, d0)>],
      iterator_types = ["parallel", "parallel"]
    } ins(%oned: tensor<?xf32>) outs(%oned, %oned2, %twod: tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>) {
    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32):
      linalg.yield %arg0, %arg0, %arg0 : f32, f32, f32
    } -> (tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>)

    // Inits: #0 -> (d0 + d1), #1 -> (d1, d0), #2 -> (d1). Init #1 is a full
    // permutation and init #2 a projected permutation, so both are reported.
    // expected-remark @below {{matched output 2 projected permutation}}
    // expected-remark @below {{matched output 1 permutation}}
    linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d0 + d1)>,
                       affine_map<(d0, d1) -> (d1, d0)>,
                       affine_map<(d0, d1) -> (d1)>],
      iterator_types = ["parallel", "parallel"]
    } ins(%oned: tensor<?xf32>) outs(%oned, %twod, %oned2: tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>) {
    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32):
      linalg.yield %arg0, %arg0, %arg0 : f32, f32, f32
    } -> (tensor<?xf32>,  tensor<?x?xf32>, tensor<?xf32>)
    return
  }
654}
655
656// -----
657
658
659
// num_inputs / num_inits: every structured op matches, and its input and init
// counts are reported as "inputs N" and "outputs N" remarks at the op.
module attributes { transform.with_named_sequence } {
  // Matcher: yields the input count, the init count and the op itself.
  transform.named_sequence @match_num_io(%candidate: !transform.any_op {transform.readonly})
      -> (!transform.param<i64>, !transform.param<i64>, !transform.any_op) {
    %io:3 = transform.match.structured failures(propagate) %candidate
        : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>, !transform.any_op) {
    ^bb0(%op: !transform.any_op):
      %input_count = transform.match.structured.num_inputs %op
          : (!transform.any_op) -> !transform.param<i64>
      %init_count = transform.match.structured.num_inits %op
          : (!transform.any_op) -> !transform.param<i64>
      transform.match.structured.yield %input_count, %init_count, %op
          : !transform.param<i64>, !transform.param<i64>, !transform.any_op
    }
    transform.yield %io#0, %io#1, %io#2
        : !transform.param<i64>, !transform.param<i64>, !transform.any_op
  }

  // Action: prints both counts as remarks anchored at the matched op.
  transform.named_sequence @print_num_io(
      %input_count: !transform.param<i64> {transform.readonly},
      %init_count: !transform.param<i64> {transform.readonly},
      %op: !transform.any_op {transform.readonly}) {
    transform.debug.emit_param_as_remark %input_count, "inputs" at %op
        : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %init_count, "outputs" at %op
        : !transform.param<i64>, !transform.any_op
    transform.yield
  }

  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.consumed}) {
    transform.foreach_match in %root
        @match_num_io -> @print_num_io
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }

  func.func @payload(%f32: f32,
            %oned: tensor<?xf32>,
            %oned2: tensor<?xf32>,
            %twod: tensor<?x?xf32>)
            attributes { transform.target_tag = "start_here" }  {
    // One input, three inits.
    // expected-remark @below {{inputs 1}}
    // expected-remark @below {{outputs 3}}
    linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d0 + d1)>,
                       affine_map<(d0, d1) -> (d1)>,
                       affine_map<(d0, d1) -> (d1, d0)>],
      iterator_types = ["parallel", "parallel"]
    } ins(%oned: tensor<?xf32>) outs(%oned, %oned2, %twod: tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>) {
    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32):
      linalg.yield %arg0, %arg0, %arg0 : f32, f32, f32
    } -> (tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>)

    // Two inputs, two inits.
    // expected-remark @below {{inputs 2}}
    // expected-remark @below {{outputs 2}}
    linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d1, d0)>,
                       affine_map<(d0, d1) -> (d0 + d1)>,
                       affine_map<(d0, d1) -> (d1)>],
      iterator_types = ["parallel", "parallel"]
    } ins(%oned, %twod: tensor<?xf32>, tensor<?x?xf32>) outs(%oned, %oned2: tensor<?xf32>, tensor<?xf32>) {
    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32):
      linalg.yield %arg0, %arg0 : f32, f32
    } -> (tensor<?xf32>, tensor<?xf32>)
    return
  }
}
722
723// -----
724
// match.structured.rank: the rank of each structured op is reported as a
// "rank N" remark (2 for the 2-D fill and 3 for matmul, per the payload).
module attributes { transform.with_named_sequence } {
  // Matcher: yields the op's rank together with the op.
  transform.named_sequence @match_rank(%candidate: !transform.any_op {transform.readonly})
      -> (!transform.param<i64>, !transform.any_op) {
    %matched:2 = transform.match.structured failures(propagate) %candidate
        : (!transform.any_op) -> (!transform.param<i64>, !transform.any_op) {
    ^bb0(%op: !transform.any_op):
      %rank = transform.match.structured.rank %op : (!transform.any_op) -> !transform.param<i64>
      transform.match.structured.yield %rank, %op : !transform.param<i64>, !transform.any_op
    }
    transform.yield %matched#0, %matched#1 : !transform.param<i64>, !transform.any_op
  }

  // Action: prints the rank at the matched op.
  transform.named_sequence @print_rank(%rank: !transform.param<i64> {transform.readonly},
                                       %op: !transform.any_op {transform.readonly}) {
    transform.debug.emit_param_as_remark %rank, "rank" at %op
        : !transform.param<i64>, !transform.any_op
    transform.yield
  }

  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.consumed}) {
    transform.foreach_match in %root
        @match_rank -> @print_rank
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }

  func.func @payload(%f32: f32,
            %twod: tensor<42x42xf32>)
            attributes { transform.target_tag = "start_here" } {
    %0 = tensor.empty() : tensor<42x42xf32>
    // expected-remark @below {{rank 2}}
    %1 = linalg.fill ins(%f32 : f32) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
    // matmul iterates over a 3-D space even though all operands are 2-D.
    // expected-remark @below {{rank 3}}
    linalg.matmul ins(%twod, %twod : tensor<42x42xf32>, tensor<42x42xf32>)
                  outs(%1 : tensor<42x42xf32>) -> tensor<42x42xf32>
    return
  }
}
761
762// -----
763
// match.structured.result, exercised three ways: result #0 as a value handle,
// result #0 required to have a single user (the user is yielded), and an
// `any` check on result position -1.
module attributes { transform.with_named_sequence } {
  // Matcher: result #0 must have a single user; yields that user and the op.
  transform.named_sequence @match_single_result(%candidate: !transform.any_op {transform.readonly})
      -> (!transform.any_op, !transform.any_op) {
    %matched:2 = transform.match.structured failures(propagate) %candidate
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op) {
    ^bb0(%op: !transform.any_op):
      %user = transform.match.structured.result %op[0] { single }
          : (!transform.any_op) -> !transform.any_op
      transform.match.structured.yield %user, %op : !transform.any_op, !transform.any_op
    }
    transform.yield %matched#0, %matched#1 : !transform.any_op, !transform.any_op
  }

  // Matcher: yields result #0 as a value handle, plus the op.
  transform.named_sequence @match_result_value(%candidate: !transform.any_op {transform.readonly})
      -> (!transform.any_value, !transform.any_op) {
    %matched:2 = transform.match.structured failures(propagate) %candidate
        : (!transform.any_op) -> (!transform.any_value, !transform.any_op) {
    ^bb0(%op: !transform.any_op):
      %value = transform.match.structured.result %op[0]
          : (!transform.any_op) -> !transform.any_value
      transform.match.structured.yield %value, %op : !transform.any_value, !transform.any_op
    }
    transform.yield %matched#0, %matched#1 : !transform.any_value, !transform.any_op
  }

  // Matcher: only the success of the `any` check matters; its handle is unused.
  transform.named_sequence @match_any_result(%candidate: !transform.any_op {transform.readonly})
      -> (!transform.any_op) {
    %matched = transform.match.structured failures(propagate) %candidate
        : (!transform.any_op) -> !transform.any_op {
    ^bb0(%op: !transform.any_op):
      %unused_user = transform.match.structured.result %op[-1] { any }
          : (!transform.any_op) -> !transform.any_op
      transform.match.structured.yield %op : !transform.any_op
    }
    transform.yield %matched : !transform.any_op
  }

  // Action: remarks at the matched op, then at its single user.
  transform.named_sequence @print_single_result(%user: !transform.any_op {transform.readonly},
                                                %op: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %op, "matched single result" : !transform.any_op
    transform.debug.emit_remark_at %user, "single user" : !transform.any_op
    transform.yield
  }

  // Action: remarks at the matched op, then at its result value.
  transform.named_sequence @print_result_value(%value: !transform.any_value {transform.readonly},
                                               %op: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %op, "matched result value" : !transform.any_op
    transform.debug.emit_remark_at %value, "op result" : !transform.any_value
    transform.yield
  }

  // Action: remarks at the matched op.
  transform.named_sequence @print_any_result(%op: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %op, "matched any result" : !transform.any_op
    transform.yield
  }

  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.consumed}) {
    %after_single = transform.foreach_match in %root
        @match_single_result -> @print_single_result
        : (!transform.any_op) -> !transform.any_op
    %after_value = transform.foreach_match in %after_single
        @match_result_value -> @print_result_value
        : (!transform.any_op) -> !transform.any_op
    transform.foreach_match in %after_value
        @match_any_result -> @print_any_result
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }

  func.func @payload(%f32: f32, %f322: f32, %f323: f32,
            %twod: tensor<42x42xf32>)
            attributes { transform.target_tag = "start_here" } {
    %0 = tensor.empty() : tensor<42x42xf32>

    // %1 has no users: only the result-value matcher reports it.
    // expected-remark @below {{matched result value}}
    // expected-remark @below {{op result}}
    // expected-note @below {{value handle points to an op result #0}}
    %1 = linalg.fill ins(%f32 : f32) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
    // %2 has exactly one user (the negf op below).
    // expected-remark @below {{matched result value}}
    // expected-remark @below {{op result}}
    // expected-note @below {{value handle points to an op result #0}}
    // expected-remark @below {{matched single result}}
    // expected-remark @below {{matched any result}}
    %2 = linalg.fill ins(%f322 : f32) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
    // %3 has two users (the two exp ops below), so it is not a single result.
    // expected-remark @below {{matched result value}}
    // expected-remark @below {{op result}}
    // expected-note @below {{value handle points to an op result #0}}
    // expected-remark @below {{matched any result}}
    %3 = linalg.fill ins(%f323 : f32) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>

    // The elementwise consumers below have unused results.
    // expected-remark @below {{matched result value}}
    // expected-remark @below {{op result}}
    // expected-note @below {{value handle points to an op result #0}}
    // expected-remark @below {{single user}}
    linalg.elemwise_unary {fun = #linalg.unary_fn<negf>} ins(%2 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
    // expected-remark @below {{matched result value}}
    // expected-remark @below {{op result}}
    // expected-note @below {{value handle points to an op result #0}}
    linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins(%3 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
    // expected-remark @below {{matched result value}}
    // expected-remark @below {{op result}}
    // expected-note @below {{value handle points to an op result #0}}
    linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins(%3 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
    return
  }
}
857
858// -----
859
// Indexing maps as params: the map of input #0 is printed as "indexing map 1"
// and the map of init #0 as "indexing map 2".
module attributes { transform.with_named_sequence } {
  // Matcher: yields the indexing map of input #0 together with the op.
  transform.named_sequence @match_input_indexing_map(%candidate: !transform.any_op {transform.readonly})
      -> (!transform.affine_map, !transform.any_op) {
    %map = transform.match.structured failures(propagate) %candidate
        : (!transform.any_op) -> !transform.affine_map {
    ^bb0(%op: !transform.any_op):
      %input_map = transform.match.structured.input %op[0] : (!transform.any_op) -> !transform.affine_map
      transform.match.structured.yield %input_map : !transform.affine_map
    }
    transform.yield %map, %candidate : !transform.affine_map, !transform.any_op
  }

  // Matcher: yields the indexing map of init #0 together with the op.
  transform.named_sequence @match_init_indexing_map(%candidate: !transform.any_op {transform.readonly})
      -> (!transform.affine_map, !transform.any_op) {
    %map = transform.match.structured failures(propagate) %candidate
        : (!transform.any_op) -> !transform.affine_map {
    ^bb0(%op: !transform.any_op):
      %init_map = transform.match.structured.init %op[0] : (!transform.any_op) -> !transform.affine_map
      transform.match.structured.yield %init_map : !transform.affine_map
    }
    transform.yield %map, %candidate : !transform.affine_map, !transform.any_op
  }

  // Action: prints the input map at the op.
  transform.named_sequence @print_indexing_map_1(%map: !transform.affine_map {transform.readonly},
                                                 %op: !transform.any_op {transform.readonly}) {
    transform.debug.emit_param_as_remark %map, "indexing map 1" at %op
        : !transform.affine_map, !transform.any_op
    transform.yield
  }

  // Action: prints the init map at the op.
  transform.named_sequence @print_indexing_map_2(%map: !transform.affine_map {transform.readonly},
                                                 %op: !transform.any_op {transform.readonly}) {
    transform.debug.emit_param_as_remark %map, "indexing map 2" at %op
        : !transform.affine_map, !transform.any_op
    transform.yield
  }

  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.consumed}) {
    %after_inputs = transform.foreach_match in %root
        @match_input_indexing_map -> @print_indexing_map_1
        : (!transform.any_op) -> !transform.any_op
    transform.foreach_match in %after_inputs
        @match_init_indexing_map -> @print_indexing_map_2
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }

  func.func @payload(%lhs: tensor<32x32xf32>, %rhs: tensor<32x32xf32>)
            attributes { transform.target_tag = "start_here" } {
    %out = tensor.empty() : tensor<32x32xf32>
    %cst = arith.constant 1.0 : f32
    // fill: the scalar input is indexed by the empty map.
    // expected-remark @below {{indexing map 1 affine_map<(d0, d1) -> ()>}}
    // expected-remark @below {{indexing map 2 affine_map<(d0, d1) -> (d0, d1)>}}
    %res = linalg.fill ins(%cst : f32) outs(%out : tensor<32x32xf32>) -> tensor<32x32xf32>
    // matmul: the LHS input map and the accumulator init map.
    // expected-remark @below {{indexing map 1 affine_map<(d0, d1, d2) -> (d0, d2)>}}
    // expected-remark @below {{indexing map 2 affine_map<(d0, d1, d2) -> (d0, d1)>}}
    linalg.matmul ins(%lhs, %rhs : tensor<32x32xf32>, tensor<32x32xf32>) outs(%res : tensor<32x32xf32>) -> tensor<32x32xf32>
    return
  }
}
912
913// -----
914
// Contraction classification: ops whose body passes the mulf/addf contraction
// check have their loops split into batch / m / n / k groups, each of which is
// printed as a remark.
module attributes { transform.with_named_sequence } {
  // Matcher: requires a mulf/addf contraction body, then yields the op and its
  // four dimension groups in the order batch, m, n, k.
  transform.named_sequence @match_contraction(%candidate: !transform.any_op {transform.readonly})
    -> (!transform.any_op, !transform.param<i64>, !transform.param<i64>,
        !transform.param<i64>, !transform.param<i64>) {
    %groups:4 = transform.match.structured %candidate
        : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>,
                                  !transform.param<i64>, !transform.param<i64>) {
    ^bb0(%op: !transform.any_op):
      transform.match.structured.body %op { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op
      %batch, %m, %n, %k = transform.match.structured.classify_contraction_dims %op
        : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>,
                                  !transform.param<i64>, !transform.param<i64>)
      transform.match.structured.yield %batch, %m, %n, %k
        : !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
    }
    transform.yield %candidate, %groups#0, %groups#1, %groups#2, %groups#3
      : !transform.any_op, !transform.param<i64>, !transform.param<i64>,
        !transform.param<i64>, !transform.param<i64>
  }

  // Action: one remark naming the op, then one remark per dimension group.
  transform.named_sequence @print_contraction(
      %contraction: !transform.any_op {transform.readonly},
      %batch_dims: !transform.param<i64> {transform.readonly},
      %m_dims: !transform.param<i64> {transform.readonly},
      %n_dims: !transform.param<i64> {transform.readonly},
      %k_dims: !transform.param<i64> {transform.readonly}) {
    transform.debug.emit_remark_at %contraction, "contraction" : !transform.any_op
    transform.debug.emit_param_as_remark %batch_dims, "batch dims" at %contraction
        : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %m_dims, "m dims" at %contraction
        : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %n_dims, "n dims" at %contraction
        : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %k_dims, "k dims" at %contraction
        : !transform.param<i64>, !transform.any_op
    transform.yield
  }

  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.consumed}) {
    transform.foreach_match in %root
        @match_contraction -> @print_contraction
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
948
// Payload for the contraction classifier. Each function holds one candidate
// op; the fills and the minimumf generic carry no remarks, meaning the
// contraction matcher must not fire on them.
module attributes { transform.target_tag = "start_here" } {
  // Plain matmul: one m, one n and one k loop, no batch loop.
  func.func @matmul_simple(%lhs: tensor<10x20xf32>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf64> {
    %cst = arith.constant 0.0 : f64
    %empty = tensor.empty() : tensor<10x15xf64>
    // The fill carries no remark: it is not matched as a contraction.
    %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf64>) -> tensor<10x15xf64>
    // expected-remark @below {{contraction}}
    // expected-remark @below {{batch dims}}
    // expected-remark @below {{m dims 0}}
    // expected-remark @below {{n dims 1}}
    // expected-remark @below {{k dims 2}}
    %result = linalg.matmul ins(%lhs, %rhs: tensor<10x20xf32>, tensor<20x15xf32>) outs(%fill: tensor<10x15xf64>) -> tensor<10x15xf64>
    return %result : tensor<10x15xf64>
  }

  // Vector-matrix product: there is no m loop, so the m group is empty.
  func.func @vecmat_simple(%lhs: tensor<20xf32>, %rhs: tensor<20x15xf32>) -> tensor<15xf64> {
    %cst = arith.constant 0.0 : f64
    %empty = tensor.empty() : tensor<15xf64>
    %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<15xf64>) -> tensor<15xf64>
    // expected-remark @below {{contraction}}
    // expected-remark @below {{batch dims}}
    // expected-remark @below {{m dims}}
    // expected-remark @below {{n dims 0}}
    // expected-remark @below {{k dims 1}}
    %result = linalg.vecmat ins(%lhs, %rhs: tensor<20xf32>, tensor<20x15xf32>) outs(%fill: tensor<15xf64>) -> tensor<15xf64>
    return %result : tensor<15xf64>
  }

  // Generic contraction in which d0 and d2 index all three operands, so both
  // are classified as batch dims; d1 is m, d3 is n and the reduction d4 is k.
  func.func @double_batch(%lhs: tensor<40x10x50x20xf32>, %rhs: tensor<40x20x50x15xf32>) -> tensor<40x10x50x15xf32> {
    %cst = arith.constant 0.0 : f32
    %empty = tensor.empty() : tensor<40x10x50x15xf32>
    %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<40x10x50x15xf32>) -> tensor<40x10x50x15xf32>
    // expected-remark @below {{contraction}}
    // expected-remark @below {{batch dims 0 : i64, 2 : i64}}
    // expected-remark @below {{m dims 1}}
    // expected-remark @below {{n dims 3}}
    // expected-remark @below {{k dims 4}}
    %result = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2, d3)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
    } ins(%lhs, %rhs : tensor<40x10x50x20xf32>, tensor<40x20x50x15xf32>)
      outs(%fill : tensor<40x10x50x15xf32>) {
    ^bb(%arg0: f32, %arg1: f32, %arg2: f32):
      %0 = arith.mulf %arg0, %arg1 : f32
      %1 = arith.addf %arg2, %0 : f32
      linalg.yield %1 : f32
    } -> tensor<40x10x50x15xf32>
    return %result : tensor<40x10x50x15xf32>
  }

  // Negative case: the body combines values with arith.minimumf rather than
  // mulf/addf, so the matcher does not fire and no remark is produced.
  func.func @generic_min(%arg0: tensor<1x7x4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<1x1x4xf32>) {
    linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1 * 2 + d3 * 2, d2)>,
      affine_map<(d0, d1, d2, d3) -> (d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%arg0, %arg1 : tensor<1x7x4xf32>, tensor<4xf32>)
      outs(%arg2 : tensor<1x1x4xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %5 = arith.minimumf %out, %in : f32
      linalg.yield %5 : f32
    } -> tensor<1x1x4xf32>
    return
  }
}
1015
1016// -----
1017
// Convolution classification: ops whose body passes the mulf/addf contraction
// check have their loops split into the eight convolution groups (batch,
// output image, output channel, filter loop, input channel, depth), plus
// strides and dilations, each printed as a remark.
module attributes { transform.with_named_sequence } {
  // Matcher: requires a mulf/addf body, then yields the op and its eight
  // classification params in the order printed by @print_convolution.
  transform.named_sequence @match_convolution(%candidate: !transform.any_op {transform.readonly})
    -> (!transform.any_op,
        !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
        !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) {
    %groups:8 = transform.match.structured %candidate
        : (!transform.any_op)
          -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
              !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) {
    ^bb0(%op: !transform.any_op):
      transform.match.structured.body %op { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op
      %batch, %out_image, %out_channel, %filter_loop, %in_channel, %depth, %strides, %dilations =
        transform.match.structured.classify_convolution_dims %op
          : (!transform.any_op)
            -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
                !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>)
      transform.match.structured.yield %batch, %out_image, %out_channel, %filter_loop,
                                       %in_channel, %depth, %strides, %dilations
        : !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
          !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
    }
    transform.yield %candidate, %groups#0, %groups#1, %groups#2, %groups#3,
                    %groups#4, %groups#5, %groups#6, %groups#7
      : !transform.any_op,
        !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
        !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
  }

  // Action: one remark naming the op, then one remark per classification param.
  transform.named_sequence @print_convolution(
      %conv: !transform.any_op {transform.readonly},
      %batch_dims: !transform.param<i64> {transform.readonly},
      %out_image_dims: !transform.param<i64> {transform.readonly},
      %out_channel_dims: !transform.param<i64> {transform.readonly},
      %filter_loop_dims: !transform.param<i64> {transform.readonly},
      %in_channel_dims: !transform.param<i64> {transform.readonly},
      %depth_dims: !transform.param<i64> {transform.readonly},
      %strides: !transform.param<i64> {transform.readonly},
      %dilations: !transform.param<i64> {transform.readonly}) {
    transform.debug.emit_remark_at %conv, "convolution" : !transform.any_op
    transform.debug.emit_param_as_remark %batch_dims, "batch dims" at %conv
        : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %out_image_dims, "output image dims" at %conv
        : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %out_channel_dims, "output channel dims" at %conv
        : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %filter_loop_dims, "filter loop dims" at %conv
        : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %in_channel_dims, "input channel dims" at %conv
        : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %depth_dims, "depth dims" at %conv
        : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %strides, "strides" at %conv
        : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %dilations, "dilations" at %conv
        : !transform.param<i64>, !transform.any_op
    transform.yield
  }

  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.consumed}) {
    transform.foreach_match in %root
        @match_convolution -> @print_convolution
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
1059
// Payload for the convolution classifier: a named 1-D convolution, a named
// depthwise 2-D convolution, and a hand-written generic convolution. The fills
// carry no remarks, meaning the convolution matcher must not fire on them.
module attributes { transform.target_tag = "start_here" } {
  // 1-D NWC convolution with unit strides and dilations.
  func.func @convolution_simple(%input: tensor<10x20x30xf32>, %filter: tensor<3x30x15xf32>) -> tensor<10x18x15xf64> {
    %cst = arith.constant 0.0 : f64
    %empty = tensor.empty() : tensor<10x18x15xf64>
    %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x18x15xf64>) -> tensor<10x18x15xf64>
    // expected-remark @below {{convolution}}
    // expected-remark @below {{batch dims 0}}
    // expected-remark @below {{output image dims 1}}
    // expected-remark @below {{output channel dims 2}}
    // expected-remark @below {{filter loop dims 3}}
    // expected-remark @below {{input channel dims 4}}
    // expected-remark @below {{depth dims}}
    // expected-remark @below {{strides 1}}
    // expected-remark @below {{dilations 1}}
    %result = linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>,
                                      strides = dense<1> : tensor<1xi64>}
       ins(%input, %filter: tensor<10x20x30xf32>, tensor<3x30x15xf32>) outs(%fill: tensor<10x18x15xf64>) -> tensor<10x18x15xf64>
    return %result : tensor<10x18x15xf64>
  }

  // Depthwise 2-D convolution: the channel loop is reported as a depth dim,
  // and the input and output channel groups are empty.
  func.func @convolution_depthwise(%input: tensor<1x10x196x48xf32>, %filter: tensor<1x4x48xf32>) -> tensor<1x10x191x48xf32> {
    %cst = arith.constant 0.0 : f32
    %empty = tensor.empty() : tensor<1x10x191x48xf32>
    %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x10x191x48xf32>) -> tensor<1x10x191x48xf32>
    // expected-remark @below {{convolution}}
    // expected-remark @below {{batch dims 0}}
    // expected-remark @below {{output image dims 1 : i64, 2 : i64}}
    // expected-remark @below {{output channel dims}}
    // expected-remark @below {{filter loop dims 4 : i64, 5 : i64}}
    // expected-remark @below {{input channel dims}}
    // expected-remark @below {{depth dims 3}}
    // expected-remark @below {{strides 1 : i64, 1 : i64}}
    // expected-remark @below {{dilations 1 : i64, 1 : i64}}
    %result = linalg.depthwise_conv_2d_nhwc_hwc {
      dilations = dense<1> : tensor<2xi64>,
      strides = dense<1> : tensor<2xi64>}
      ins(%input, %filter : tensor<1x10x196x48xf32>, tensor<1x4x48xf32>)
      outs(%fill : tensor<1x10x191x48xf32>) -> tensor<1x10x191x48xf32>

    return %result : tensor<1x10x191x48xf32>
  }

  // Generic convolution with two output-channel loops (d0, d3), two
  // input-channel loops (d4, d7) and a stride of 2 on d2, which comes from the
  // `2 * d2 + d6` input access.
  func.func @convolution_multi_channel(%input: tensor<2x34x68x16xf32>, %filter: tensor<8x2x3x5x16x16xf32>) -> tensor<8x32x32x16xf32> {
    %cst = arith.constant 0.0 : f32
    %empty = tensor.empty() : tensor<8x32x32x16xf32>
    %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<8x32x32x16xf32>) -> tensor<8x32x32x16xf32>
    // expected-remark @below {{convolution}}
    // expected-remark @below {{batch dims}}
    // expected-remark @below {{output image dims 1 : i64, 2 : i64}}
    // expected-remark @below {{output channel dims 0 : i64, 3 : i64}}
    // expected-remark @below {{filter loop dims 5 : i64, 6 : i64}}
    // expected-remark @below {{input channel dims 4 : i64, 7 : i64}}
    // expected-remark @below {{depth dims}}
    // expected-remark @below {{strides 1 : i64, 2 : i64}}
    // expected-remark @below {{dilations 1 : i64, 1 : i64}}
    %result = linalg.generic {
        indexing_maps = [
            affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d1 + d5, 2 * d2 + d6, d7)>,
            affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d4, d5, d6, d7, d3)>,
            affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3)>],
        iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "reduction"]}
        ins(%input, %filter : tensor<2x34x68x16xf32>, tensor<8x2x3x5x16x16xf32>) outs(%fill : tensor<8x32x32x16xf32>) {
          ^bb0(%in: f32, %in_0: f32, %out: f32):
            %mul = arith.mulf %in, %in_0 : f32
            %add = arith.addf %mul, %out : f32
            linalg.yield %add : f32
          } -> tensor<8x32x32x16xf32>
    return %result : tensor<8x32x32x16xf32>
  }
}
1130