xref: /llvm-project/mlir/test/Dialect/Transform/test-pattern-application.mlir (revision 37b26bf48b9894ed0c13fd1aede23472660fb75e)
1// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
2
// The test patterns rewrite the tracked "test.foo" into a new "test.foo" op;
// the handle to the old op must follow the replacement so that the later
// annotation lands on the newly created op.
3// CHECK-LABEL: func @update_tracked_op_mapping()
4//       CHECK:   "test.container"() ({
5//       CHECK:     %0 = "test.foo"() {annotated} : () -> i32
6//       CHECK:   }) : () -> ()
7func.func @update_tracked_op_mapping() {
8  "test.container"() ({
9    %0 = "test.foo"() {replace_with_new_op = "test.foo"} : () -> (i32)
10  }) : () -> ()
11  return
12}
13
14module attributes {transform.with_named_sequence} {
15  transform.named_sequence @__transform_main(%root: !transform.any_op) {
16    %container = transform.structured.match ops{["test.container"]} in %root : (!transform.any_op) -> !transform.any_op
17    %foo = transform.structured.match ops{["test.foo"]} in %root : (!transform.any_op) -> !transform.any_op
18    transform.apply_patterns to %container {
19      transform.apply_patterns.transform.test_patterns
20    } : !transform.any_op
21    // %foo is now mapped to the replacement op; tagging it proves the remap.
22    transform.annotate %foo "annotated" : !transform.any_op
23    transform.yield
24  }
25}
26
27// -----
28
// With max_num_rewrites = 1 the pattern driver is capped at a single rewrite,
// so only one of the two "test.foo" ops may be replaced.
29// CHECK-LABEL: @limited_updates
30func.func @limited_updates() {
31  "test.container"() ({
32    // Exactly one of these two ops ends up rewritten.
33    // CHECK: "test.foo"() {replace_with_new_op = "test.foo"}
34    // CHECK: "test.foo"() : ()
35    %0 = "test.foo"() {replace_with_new_op = "test.foo"} : () -> (i32)
36    %1 = "test.foo"() {replace_with_new_op = "test.foo"} : () -> (i32)
37  }) : () -> ()
38  return
39}
40
41module attributes {transform.with_named_sequence} {
42  transform.named_sequence @__transform_main(%root: !transform.any_op) {
43    // Reaching the rewrite limit makes apply_patterns fail; the enclosing
44    // sequence uses failures(suppress) so that error is silenced.
45    transform.sequence %root : !transform.any_op failures(suppress) {
46    ^bb0(%seq_root: !transform.any_op):
47      %container = transform.structured.match ops{["test.container"]} in %seq_root : (!transform.any_op) -> !transform.any_op
48      %foo = transform.structured.match ops{["test.foo"]} in %seq_root : (!transform.any_op) -> !transform.any_op
49      transform.apply_patterns to %container {
50        transform.apply_patterns.transform.test_patterns
51      } {max_num_rewrites = 1} : !transform.any_op
52    }
53    transform.yield
54  }
55}
56
57// -----
58
// "test.foo" is rewritten into "test.bar" while the handle %1 still has a
// later use. The tracking listener finds no suitable replacement op and must
// report an error. Notes point at the replaced op, its replacement value, and
// the handle that required the update.
59func.func @replacement_op_not_found() {
60  "test.container"() ({
61    // expected-note @below {{[0] replaced op}}
62    // expected-note @below {{[0] replacement value 0}}
63    %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
64  }) : () -> ()
65  return
66}
67
68module attributes {transform.with_named_sequence} {
69  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
70    %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
71    // expected-note @below {{replacement is required because this handle must be updated}}
72    %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
73    // expected-error @below {{tracking listener failed to find replacement op during application of this transform op}}
74    // expected-note @below {{ran out of suitable replacement values}}
75    transform.apply_patterns to %0 {
76      transform.apply_patterns.transform.test_patterns
77    } : !transform.any_op
78    // %1 must be used after pattern application: a missing replacement payload
79    // op is reported as an error only when the affected handle is not dead.
80    transform.annotate %1 "annotated" : !transform.any_op
81    transform.yield
82  }
83}
84
85// -----
86
// "test.foo" is rewritten into "test.bar" and no replacement op is found for
// the handle to it. That handle has no later uses, so no error is reported and
// the rewrite still happens.
87// CHECK-LABEL: func @replacement_op_for_dead_handle_not_found()
88//       CHECK:   "test.container"() ({
89//       CHECK:     %0 = "test.bar"() : () -> i32
90//       CHECK:   }) : () -> ()
91func.func @replacement_op_for_dead_handle_not_found() {
92  "test.container"() ({
93    %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
94  }) : () -> ()
95  return
96}
97
98module attributes {transform.with_named_sequence} {
99  transform.named_sequence @__transform_main(%root: !transform.any_op) {
100    %container = transform.structured.match ops{["test.container"]} in %root : (!transform.any_op) -> !transform.any_op
101    %foo = transform.structured.match ops{["test.foo"]} in %root : (!transform.any_op) -> !transform.any_op
102    // %foo has no further uses, so losing track of its op is not an error.
103    transform.apply_patterns to %container {
104      transform.apply_patterns.transform.test_patterns
105    } : !transform.any_op
106    transform.yield
107  }
108}
109
110// -----
111
// Same missing-replacement situation for a live handle, but the
// transform.silence_tracking_failures attribute keeps it from being an error.
// The printed "test.bar" carries no "annotated" attribute afterwards.
112// CHECK-LABEL: func @replacement_op_not_found_silenced()
113//       CHECK:   "test.container"() ({
114//       CHECK:     %0 = "test.bar"() : () -> i32
115//       CHECK:   }) : () -> ()
116func.func @replacement_op_not_found_silenced() {
117  "test.container"() ({
118    %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
119  }) : () -> ()
120  return
121}
122
123module attributes {transform.with_named_sequence} {
124  transform.named_sequence @__transform_main(%root: !transform.any_op) {
125    %container = transform.structured.match ops{["test.container"]} in %root : (!transform.any_op) -> !transform.any_op
126    %foo = transform.structured.match ops{["test.foo"]} in %root : (!transform.any_op) -> !transform.any_op
127    transform.apply_patterns to %container {
128      transform.apply_patterns.transform.test_patterns
129    } {transform.silence_tracking_failures} : !transform.any_op
130    transform.annotate %foo "annotated" : !transform.any_op
131    transform.yield
132  }
133}
134
135// -----
136
// Patterns are applied to the body of the target op only. The targeted
// "test.foo" itself must come out unchanged, still carrying its attribute.
137// CHECK-LABEL: func @patterns_apply_only_to_target_body()
138//       CHECK:   %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> i32
139func.func @patterns_apply_only_to_target_body() {
140  %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
141  return
142}
143
144module attributes {transform.with_named_sequence} {
145  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
146    %0 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
147    transform.apply_patterns to %0 {
148      transform.apply_patterns.transform.test_patterns
149    } : !transform.any_op
150    transform.yield
151  }
152}
153
154// -----
155
// The test patterns erase "test.erase_op" (the container's region is printed
// empty). The handle %1 must then map to no payload op, so a remark emitted
// through it after pattern application produces no diagnostic.
156// CHECK-LABEL: func @erase_tracked_op()
157//       CHECK:   "test.container"() ({
158//  CHECK-NEXT:   ^bb0:
159//  CHECK-NEXT:   }) : () -> ()
160func.func @erase_tracked_op() {
161  "test.container"() ({
162    // expected-remark @below {{matched op}}
163    %0 = "test.erase_op"() {replace_with_new_op = "test.foo"} : () -> (i32)
164  }) : () -> ()
165  return
166}
167
168module attributes {transform.with_named_sequence} {
169  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
170    %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
171    %1 = transform.structured.match ops{["test.erase_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op
172    transform.debug.emit_remark_at %1, "matched op" : !transform.any_op
173    transform.apply_patterns to %0 {
174      transform.apply_patterns.transform.test_patterns
175    } : !transform.any_op
176    // %1 is empty now, so this "op was deleted" remark must not be emitted.
177    transform.debug.emit_remark_at %1, "op was deleted" : !transform.any_op
178    transform.yield
179  }
180}
181
182// -----
183
// The test patterns erase "test.erase_op" from inside the included named
// sequence @foo. The caller's handle %1 must still be updated so that it maps
// to no payload op afterwards.
184// CHECK-LABEL: func @erase_tracked_op_in_named_sequence()
185//       CHECK:   "test.container"() ({
186//  CHECK-NEXT:   ^bb0:
187//  CHECK-NEXT:   }) : () -> ()
188module attributes {transform.with_named_sequence} {
189  func.func @erase_tracked_op_in_named_sequence() {
190    "test.container"() ({
191      // expected-remark @below {{matched op}}
192      %0 = "test.erase_op"() {replace_with_new_op = "test.foo"} : () -> (i32)
193    }) : () -> ()
194    return
195  }
196
  // Applies the test patterns to the given container; the handle argument is
  // declared transform.readonly.
197  transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () {
198    transform.apply_patterns to %arg0 {
199      transform.apply_patterns.transform.test_patterns
200    } : !transform.any_op
201    transform.yield
202  }
203
204  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
205    %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
206    %1 = transform.structured.match ops{["test.erase_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op
207    transform.debug.emit_remark_at %1, "matched op" : !transform.any_op
208    transform.include @foo failures(propagate) (%0) : (!transform.any_op) -> ()
209    // %1 is empty now, so this "op was deleted" remark must not be emitted.
210    transform.debug.emit_remark_at %1, "op was deleted" : !transform.any_op
211    transform.yield
212  }
213}
214
215// -----
216
// Canonicalizing the function folds tensor.dim on a statically shaped tensor
// into the constant 5.
217// CHECK-LABEL: func @canonicalization(
218//       CHECK:   %[[c5:.*]] = arith.constant 5 : index
219//       CHECK:   return %[[c5]]
220func.func @canonicalization(%t: tensor<5xf32>) -> index {
221  %c0 = arith.constant 0 : index
222  %dim = tensor.dim %t, %c0 : tensor<5xf32>
223  return %dim : index
224}
225
226module attributes {transform.with_named_sequence} {
227  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    // NOTE(review): the tensor.dim handle below is never used, so it is
    // presumably a leftover. It is kept because it is a dead handle to an op that
    // canonicalization replaces, which may be intentional coverage; confirm
    // before removing.
228    %dim_op = transform.structured.match ops{["tensor.dim"]} in %root : (!transform.any_op) -> !transform.any_op
229    %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
230    transform.apply_patterns to %func {
231      transform.apply_patterns.canonicalization
232    } : !transform.any_op
233    transform.yield
234  }
235}
236
237// -----
238
// Applying patterns to the root module would also rewrite the transform IR
// currently being interpreted, which is nested inside that module. The
// interpreter must reject this and point at the outer module as the target.
239// expected-note @below{{target payload op}}
240module {
241  func.func @invalid_pattern_application_to_transform_ir() {
242    return
243  }
244
245  module attributes {transform.with_named_sequence} {
246    transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
247      // expected-error @below {{cannot apply transform to itself (or one of its ancestors)}}
248      transform.apply_patterns to %arg1 {
249        transform.apply_patterns.canonicalization
250      } : !transform.any_op
251      transform.yield
252    }
253  }
254}
255
256// -----
257
// With apply_cse, canonicalization is combined with CSE. No memref.subview or
// memref.copy may remain afterwards. Presumably the constant size is folded so
// both subviews become identical, CSE merges them, and the resulting self-copy
// is dropped.
258// CHECK-LABEL: func @canonicalization_and_cse(
259//   CHECK-NOT:   memref.subview
260//   CHECK-NOT:   memref.copy
261func.func @canonicalization_and_cse(%m: memref<5xf32>) {
262  %c2 = arith.constant 2 : index
263  %s0 = memref.subview %m[1] [2] [1] : memref<5xf32> to memref<2xf32, strided<[1], offset: 1>>
264  %s1 = memref.subview %m[1] [%c2] [1] : memref<5xf32> to memref<?xf32, strided<[1], offset: 1>>
265  memref.copy %s0, %s1 : memref<2xf32, strided<[1], offset: 1>> to memref<?xf32, strided<[1], offset: 1>>
266  return
267}
268
269module attributes {transform.with_named_sequence} {
270  transform.named_sequence @__transform_main(%root: !transform.any_op) {
271    %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
272    transform.apply_patterns to %func {
273      transform.apply_patterns.canonicalization
274    } {apply_cse} : !transform.any_op
275    transform.yield
276  }
277}
278
279// -----
280
// Full conversion turns "test.foo" into "test.new_op" and bridges its new
// memref result back to the original tensor type with an
// unrealized_conversion_cast. The cast operand is matched through the captured
// name of the "test.new_op" result rather than a hard-coded SSA number.
281// CHECK-LABEL: func @full_dialect_conversion
282//  CHECK-NEXT:   %[[m:.*]] = "test.new_op"() : () -> memref<5xf32>
283//  CHECK-NEXT:   %[[cast:.*]] = builtin.unrealized_conversion_cast %[[m]] : memref<5xf32> to tensor<5xf32>
284//  CHECK-NEXT:   return %[[cast]]
285func.func @full_dialect_conversion() -> tensor<5xf32> {
286  %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (tensor<5xf32>)
287  return %0 : tensor<5xf32>
288}
289
290module attributes {transform.with_named_sequence} {
291  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
292    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
293    transform.apply_conversion_patterns to %0 {
294      transform.apply_conversion_patterns.transform.test_conversion_patterns
295    } with type_converter {
296      transform.apply_conversion_patterns.transform.test_type_converter
297    } {legal_ops = ["func.func", "func.return", "test.new_op"]}
298        : !transform.any_op
299    transform.yield
300  }
301}
302
303// -----
304
305// Full dialect conversion fails because the standalone "test.bar" op is
// neither converted by the test patterns nor listed in legal_ops.
306
307// expected-note @below{{target op}}
308func.func @full_dialect_conversion_failed() -> tensor<5xf32> {
309  %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (tensor<5xf32>)
310  // expected-error @below{{failed to legalize operation 'test.bar'}}
311  "test.bar"() : () -> ()
312  return %0 : tensor<5xf32>
313}
314
315module attributes {transform.with_named_sequence} {
316  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
317    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
318    // expected-error @below{{dialect conversion failed}}
319    transform.apply_conversion_patterns to %0 {
320      transform.apply_conversion_patterns.transform.test_conversion_patterns
321    } with type_converter {
322      transform.apply_conversion_patterns.transform.test_type_converter
323    } {legal_ops = ["func.func", "func.return", "test.new_op"]}
324        : !transform.any_op
325    transform.yield
326  }
327}
328
329// -----
330
331// Partial dialect conversion succeeds because test.bar is not explicitly
332// illegal, so it may survive the conversion untouched.
333
// The cast operand is matched through the captured name of the "test.new_op"
// result rather than a hard-coded SSA number.
334// CHECK-LABEL: func @partial_dialect_conversion
335//  CHECK-NEXT:   %[[m:.*]] = "test.new_op"() : () -> memref<5xf32>
336//  CHECK-NEXT:   %[[cast:.*]] = builtin.unrealized_conversion_cast %[[m]] : memref<5xf32> to tensor<5xf32>
337//  CHECK-NEXT:   "test.bar"
338//  CHECK-NEXT:   return %[[cast]]
339func.func @partial_dialect_conversion() -> tensor<5xf32> {
340  %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (tensor<5xf32>)
341  "test.bar"() : () -> ()
342  return %0 : tensor<5xf32>
343}
344
345module attributes {transform.with_named_sequence} {
346  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
347    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
348    transform.apply_conversion_patterns to %0 {
349      transform.apply_conversion_patterns.transform.test_conversion_patterns
350    } with type_converter {
351      transform.apply_conversion_patterns.transform.test_type_converter
352    } {legal_ops = ["func.func", "func.return", "test.new_op"],
353       partial_conversion} : !transform.any_op
354    transform.yield
355  }
356}
357
358// -----
359
// test_conversion_patterns is given no type converter. The pattern descriptor
// does not specify one and apply_conversion_patterns has no default
// type_converter region, so the op must be diagnosed.
360module attributes {transform.with_named_sequence} {
361  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
362    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
363    // expected-error @below{{pattern descriptor does not specify type converter and apply_conversion_patterns op has no default type converter}}
364    transform.apply_conversion_patterns to %0 {
365      // expected-note @below{{pattern descriptor op}}
366      transform.apply_conversion_patterns.transform.test_conversion_patterns
367    } {illegal_ops = ["test.foo"]} : !transform.any_op
368    transform.yield
369  }
370}
371
372// -----
373
// dialect_to_llvm requires an LLVM type converter. Supplying the test type
// converter as the default converter must be diagnosed on the descriptor op.
374module attributes {transform.with_named_sequence} {
375  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
376    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
377    transform.apply_conversion_patterns to %0 {
378      // expected-error @below{{expected LLVMTypeConverter}}
379      transform.apply_conversion_patterns.dialect_to_llvm "test"
380    } with type_converter {
381      transform.apply_conversion_patterns.transform.test_type_converter
382    } {illegal_ops = ["test.foo"],
383       legal_ops = ["func.func", "func.return", "test.new_op"]}
384        : !transform.any_op
385    transform.yield
386  }
387}
388
389// -----
390
// dialect_to_llvm must reject a dialect name that is unknown or not loaded,
// even when a valid LLVM type converter is provided.
391module attributes {transform.with_named_sequence} {
392  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
393    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
394    transform.apply_conversion_patterns to %0 {
395      // expected-error @below{{unknown dialect or dialect not loaded: this_dialect_does_not_exist}}
396      transform.apply_conversion_patterns.dialect_to_llvm "this_dialect_does_not_exist"
397    } with type_converter {
398      transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
399    } {illegal_ops = ["test.foo"],
400       legal_ops = ["func.func", "func.return", "test.new_op"]}
401        : !transform.any_op
402    transform.yield
403  }
404}
405
406// -----
407
// dialect_to_llvm must reject a loaded dialect (here: transform) for which no
// ConvertToLLVMPatternInterface implementation is available.
408module attributes {transform.with_named_sequence} {
409  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
410    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
411    transform.apply_conversion_patterns to %0 {
412      // expected-error @below{{dialect does not implement ConvertToLLVMPatternInterface or extension was not loaded: transform}}
413      transform.apply_conversion_patterns.dialect_to_llvm "transform"
414    } with type_converter {
415      transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
416    } {illegal_ops = ["test.foo"],
417       legal_ops = ["func.func", "func.return", "test.new_op"]}
418        : !transform.any_op
419    transform.yield
420  }
421}
422
423// -----
424
425module attributes {transform.with_named_sequence} {
426  func.func @replacement_op_not_found() {
427    // "test.foo" gets no usable replacement op, but its handle is last used
428    // before the patterns run, so nothing must be updated and no error is reported.
429    "test.container"() ({
430      %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
431    }) : () -> ()
432    return
433  }
434
435  transform.named_sequence @patterns(%target: !transform.any_op {transform.readonly}) {
436    transform.apply_patterns to %target {
437      transform.apply_patterns.transform.test_patterns
438    } : !transform.any_op
439    transform.yield
440  }
441
442  transform.named_sequence @__transform_main(%root: !transform.any_op) {
443    %container = transform.structured.match ops{["test.container"]} in %root : (!transform.any_op) -> !transform.any_op
444    %foo = transform.structured.match ops{["test.foo"]} in %root : (!transform.any_op) -> !transform.any_op
445    transform.annotate %foo "annotated" : !transform.any_op
446    transform.include @patterns failures(propagate) (%container) : (!transform.any_op) -> ()
447    transform.yield
448  }
449}
450
451// -----
452
453// "test.foo" is tracked and replaced with "test.new_op" during a dialect
454// conversion. Make sure that the handle is updated accordingly.
// The cast operand is matched through the captured name of the "test.new_op"
// result rather than a hard-coded SSA number.
455
456// CHECK-LABEL: func @dialect_conversion_tracking
457//  CHECK-NEXT:   %[[m:.*]] = "test.new_op"() {annotated} : () -> memref<5xf32>
458//  CHECK-NEXT:   %[[cast:.*]] = builtin.unrealized_conversion_cast %[[m]] : memref<5xf32> to tensor<5xf32>
459//  CHECK-NEXT:   return %[[cast]]
460func.func @dialect_conversion_tracking() -> tensor<5xf32> {
461  %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (tensor<5xf32>)
462  return %0 : tensor<5xf32>
463}
464
465module attributes {transform.with_named_sequence} {
466  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
467    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
468    %1 = transform.structured.match ops{["test.foo"]} in %0 : (!transform.any_op) -> !transform.any_op
469    transform.apply_conversion_patterns to %0 {
470      transform.apply_conversion_patterns.transform.test_conversion_patterns
471    } with type_converter {
472      transform.apply_conversion_patterns.transform.test_type_converter
473    } {legal_ops = ["func.func", "func.return", "test.new_op"], preserve_handles}
474        : !transform.any_op
475    // Add an attribute to %1, which is now mapped to a new op.
476    transform.annotate %1 "annotated" : !transform.any_op
477    transform.yield
478  }
479}
480