# RUN: %PYTHON %s | FileCheck %s

from mlir import ir
from mlir.dialects.transform import interpreter as interp


def test_in_context(f):
    """Run ``f`` immediately inside a fresh MLIR context and unknown location.

    Used as a decorator: every decorated test executes at definition time, so
    tests run (and emit output) in the order they appear in this file. The
    function object is returned unchanged.
    """
    with ir.Context(), ir.Location.unknown():
        f()
    return f


# Transform script whose entry point prints the op it is applied to. The
# printer name "from interpreter" is a placeholder that each test replaces
# with its own name, so every dump can be located in the output by label.
print_root_module = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    transform.print %root { name = \"from interpreter\" }: !transform.any_op
    transform.yield
  }
}"""


@test_in_context
def print_self():
    """Apply the printing sequence with the transform module as its own payload."""
    m = ir.Module.parse(print_root_module.replace("from interpreter", "print_self"))
    interp.apply_named_sequence(m, m.body.operations[0], m)


# CHECK-LABEL: print_self
# CHECK: transform.named_sequence @__transform_main
# CHECK: transform.print
# CHECK: transform.yield


@test_in_context
def print_other():
    """Apply the printing sequence to a separate payload module."""
    transform = ir.Module.parse(
        print_root_module.replace("from interpreter", "print_other")
    )
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    interp.apply_named_sequence(payload, transform.body.operations[0], transform)


# CHECK-LABEL: print_other
# CHECK-NOT: transform
# CHECK: this.is.payload


@test_in_context
def transform_options():
    """Apply the printing sequence with explicitly configured TransformOptions."""
    options = interp.TransformOptions()
    options.expensive_checks = False
    options.enforce_single_top_level_transform_op = True
    m = ir.Module.parse(
        print_root_module.replace("from interpreter", "transform_options")
    )
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    interp.apply_named_sequence(payload, m.body.operations[0], m, options)


# Also verify the payload was actually printed, not just that the label exists.
# CHECK-LABEL: transform_options
# CHECK: this.is.payload


@test_in_context
def failed():
    """Using a non-transform op as the transform root must raise ValueError."""
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    try:
        interp.apply_named_sequence(payload, payload, payload)
    except ValueError as e:
        assert (
            "must implement TransformOpInterface to be used as transform root" in str(e)
        )
    else:
        # Without this branch the test would pass silently if no error occurred.
        raise AssertionError("expected apply_named_sequence to raise ValueError")
# Transform script whose entry point includes @callee2. Both @callee1 and
# @callee2 are only declared (no body) here; their definitions are merged in
# from the modules below.
print_root_via_include_module = """
module @print_root_via_include_module attributes {transform.with_named_sequence} {
  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
  transform.named_sequence private @callee2(%root: !transform.any_op {transform.readonly})
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    transform.include @callee2 failures(propagate)
      (%root) : (!transform.any_op) -> ()
    transform.yield
  }
}"""

# Defines @callee2, which includes @callee1 (still only declared here).
callee2_definition = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
  transform.named_sequence @callee2(%root: !transform.any_op {transform.readonly}) {
    transform.include @callee1 failures(propagate)
      (%root) : (!transform.any_op) -> ()
    transform.yield
  }
}
"""

# Defines @callee1, which prints its argument.
callee1_definition = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence @callee1(%root: !transform.any_op {transform.readonly}) {
    transform.print %root { name = \"from interpreter\" }: !transform.any_op
    transform.yield
  }
}
"""


@test_in_context
def include():
    """Merge both callee definitions into the main module and run it.

    The main module is both transform and payload, so the printed dump is the
    merged module itself.
    """
    main = ir.Module.parse(print_root_via_include_module)
    callee1 = ir.Module.parse(callee1_definition)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee1)
    interp.copy_symbols_and_merge_into(main, callee2)

    # CHECK: @print_root_via_include_module
    # CHECK: transform.named_sequence @__transform_main
    # CHECK: transform.include @callee2
    #
    # CHECK: transform.named_sequence @callee1
    # CHECK: transform.print
    #
    # CHECK: transform.named_sequence @callee2
    # CHECK: transform.include @callee1
    interp.apply_named_sequence(main, main.body.operations[0], main)


@test_in_context
def partial_include():
    """Applying the sequence while @callee1 is still undefined must fail."""
    main = ir.Module.parse(print_root_via_include_module)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee2)

    # NOTE(review): operations[0] relies on the op order left by merging;
    # confirm it is @__transform_main here and not the @callee1 declaration.
    try:
        interp.apply_named_sequence(main, main.body.operations[0], main)
    except ValueError as e:
        assert "Failed to apply" in str(e)
    else:
        # Without this branch the test would pass silently if no error occurred.
        raise AssertionError("expected apply_named_sequence to raise ValueError")


@test_in_context
def repeated_include():
    """Merging the same @callee2 definition twice must be rejected."""
    main = ir.Module.parse(print_root_via_include_module)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee2)

    try:
        interp.copy_symbols_and_merge_into(main, callee2)
    except ValueError as e:
        assert "doubly defined symbol @callee2" in str(e)
    else:
        # Without this branch the test would pass silently if no error occurred.
        raise AssertionError(
            "expected copy_symbols_and_merge_into to raise ValueError"
        )