# RUN: %PYTHON %s | FileCheck %s

from mlir import ir
from mlir.dialects.transform import interpreter as interp


def test_in_context(f):
    """Run ``f`` immediately inside a fresh MLIR context and unknown location.

    Used as a decorator so each test body executes when the module is loaded;
    the function object is returned unchanged so the module-level name stays
    bound to it.
    """
    with ir.Context(), ir.Location.unknown():
        f()
    return f


def _expect_value_error(expected_substring, fn, *args):
    """Call ``fn(*args)`` and require it to raise ``ValueError``.

    The error text must contain ``expected_substring``. A plain try/except
    would let the test pass silently if no exception were raised at all, so
    the ``else`` branch turns that case into a hard failure.
    """
    try:
        fn(*args)
    except ValueError as e:
        assert expected_substring in str(e), str(e)
    else:
        raise AssertionError(
            f"expected ValueError containing {expected_substring!r}, "
            "but no exception was raised"
        )


# Transform script whose entry point prints its root operation. Tests replace
# the "from interpreter" printer name so each test's output is tagged.
print_root_module = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    transform.print %root { name = \"from interpreter\" }: !transform.any_op
    transform.yield
  }
}"""


@test_in_context
def print_self():
    """Apply a transform script to itself, so the printed IR is the script."""
    m = ir.Module.parse(print_root_module.replace("from interpreter", "print_self"))
    interp.apply_named_sequence(m, m.body.operations[0], m)


# CHECK-LABEL: print_self
# CHECK: transform.named_sequence @__transform_main
# CHECK: transform.print
# CHECK: transform.yield


@test_in_context
def print_other():
    """Apply a transform script to a separate payload module."""
    transform = ir.Module.parse(
        print_root_module.replace("from interpreter", "print_other")
    )
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    interp.apply_named_sequence(payload, transform.body.operations[0], transform)


# CHECK-LABEL: print_other
# CHECK-NOT: transform
# CHECK: this.is.payload


@test_in_context
def failed():
    """Using a non-transform op as the transform root must raise ValueError."""
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    _expect_value_error(
        "must implement TransformOpInterface to be used as transform root",
        interp.apply_named_sequence,
        payload,
        payload,
        payload,
    )


# Main script: declares @callee1/@callee2 privately and includes @callee2.
print_root_via_include_module = """
module @print_root_via_include_module attributes {transform.with_named_sequence} {
  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
  transform.named_sequence private @callee2(%root: !transform.any_op {transform.readonly})
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    transform.include @callee2 failures(propagate)
        (%root) : (!transform.any_op) -> ()
    transform.yield
  }
}"""

# Library module defining @callee2, which itself includes the (declared) @callee1.
callee2_definition = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
  transform.named_sequence @callee2(%root: !transform.any_op {transform.readonly}) {
    transform.include @callee1 failures(propagate)
        (%root) : (!transform.any_op) -> ()
    transform.yield
  }
}
"""

# Library module defining @callee1, which prints its root operation.
callee1_definition = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence @callee1(%root: !transform.any_op {transform.readonly}) {
    transform.print %root { name = \"from interpreter\" }: !transform.any_op
    transform.yield
  }
}
"""


@test_in_context
def include():
    """Merge both callee definitions into the main script, then apply it.

    The include chain __transform_main -> @callee2 -> @callee1 reaches the
    print op, which prints the merged main module.
    """
    main = ir.Module.parse(print_root_via_include_module)
    callee1 = ir.Module.parse(callee1_definition)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee1)
    interp.copy_symbols_and_merge_into(main, callee2)

    # CHECK: @print_root_via_include_module
    # CHECK: transform.named_sequence @__transform_main
    # CHECK: transform.include @callee2
    #
    # CHECK: transform.named_sequence @callee1
    # CHECK: transform.print
    #
    # CHECK: transform.named_sequence @callee2
    # CHECK: transform.include @callee1
    interp.apply_named_sequence(main, main.body.operations[0], main)


@test_in_context
def partial_include():
    """Merge only @callee2; applying must fail since @callee1 is only declared."""
    main = ir.Module.parse(print_root_via_include_module)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee2)

    _expect_value_error(
        "Failed to apply",
        interp.apply_named_sequence,
        main,
        main.body.operations[0],
        main,
    )


@test_in_context
def repeated_include():
    """Merging the same definition twice must report a duplicate symbol."""
    main = ir.Module.parse(print_root_via_include_module)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee2)

    _expect_value_error(
        "doubly defined symbol @callee2",
        interp.copy_symbols_and_merge_into,
        main,
        callee2,
    )