# xref: /llvm-project/mlir/test/python/dialects/transform_interpreter.py (revision ff57f40673f0db2c1a867e5697d5407bc9f39a5e)
# RUN: %PYTHON %s | FileCheck %s
2
3from mlir import ir
4from mlir.dialects.transform import interpreter as interp
5
6
def test_in_context(f):
    """Decorator that runs ``f`` once, immediately, inside a fresh MLIR context.

    An unknown location is entered as well for the duration of the call.
    ``f`` itself is returned unchanged, so the decorated name keeps referring
    to the original function.
    """
    with ir.Context():
        # Nested so that ``Location.unknown()`` (called without an explicit
        # context) is evaluated only once the new context is active -- the
        # same order as a single two-item ``with`` statement.
        with ir.Location.unknown():
            f()
    return f
11
12
# Transform script whose entry point prints its root handle under the name
# "from interpreter"; the tests below substitute that name to tag their output.
print_root_module = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    transform.print %root { name = "from interpreter" }: !transform.any_op
    transform.yield
  }
}"""
20
21
@test_in_context
def print_self():
    """Apply the print script to its own module, tagging the output "print_self"."""
    script = ir.Module.parse(
        print_root_module.replace("from interpreter", "print_self")
    )
    entry = script.body.operations[0]
    # The script module serves as payload root and as transform module at once.
    interp.apply_named_sequence(script, entry, script)
26
27
# CHECK-LABEL: print_self
# CHECK: transform.named_sequence @__transform_main
# CHECK: transform.print
# CHECK: transform.yield
32
33
@test_in_context
def print_other():
    """Apply the print script to a separate, otherwise empty payload module.

    The payload carries the ``this.is.payload`` attribute so it can be told
    apart from the transform script in the printed output.
    """
    script_text = print_root_module.replace("from interpreter", "print_other")
    transform = ir.Module.parse(script_text)
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    entry = transform.body.operations[0]
    interp.apply_named_sequence(payload, entry, transform)
41
42
# CHECK-LABEL: print_other
# CHECK-NOT: transform
# CHECK: this.is.payload
46
47
@test_in_context
def transform_options():
    """Apply the print script with explicitly configured interpreter options."""
    opts = interp.TransformOptions()
    # Disable expensive checks and require exactly one top-level transform op.
    opts.expensive_checks = False
    opts.enforce_single_top_level_transform_op = True

    script = ir.Module.parse(
        print_root_module.replace("from interpreter", "transform_options")
    )
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    interp.apply_named_sequence(payload, script.body.operations[0], script, opts)
58
59
# CHECK-LABEL: transform_options
61
62
@test_in_context
def failed():
    """Using a non-transform op as the transform root must raise ValueError.

    The payload module is passed as payload root, transform root and transform
    module at once. It does not implement TransformOpInterface, so the
    interpreter is expected to reject it with a ValueError.
    """
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    try:
        interp.apply_named_sequence(payload, payload, payload)
    except ValueError as e:
        assert (
            "must implement TransformOpInterface to be used as transform root" in str(e)
        ), str(e)
    else:
        # Without this branch the test would pass silently if the call stopped
        # raising altogether.
        raise AssertionError("expected apply_named_sequence to raise ValueError")
72
73
# Entry-point script: only declares @callee1 and @callee2, and its
# @__transform_main forwards the root handle to @callee2 via transform.include.
print_root_via_include_module = """
module @print_root_via_include_module attributes {transform.with_named_sequence} {
  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
  transform.named_sequence private @callee2(%root: !transform.any_op {transform.readonly})
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    transform.include @callee2 failures(propagate)
        (%root) : (!transform.any_op) -> ()
    transform.yield
  }
}"""

# Library module defining @callee2, which includes @callee1 (declared here but
# not defined).
callee2_definition = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
  transform.named_sequence @callee2(%root: !transform.any_op {transform.readonly}) {
    transform.include @callee1 failures(propagate)
        (%root) : (!transform.any_op) -> ()
    transform.yield
  }
}
"""

# Library module defining @callee1, which prints its root handle under the
# name "from interpreter".
callee1_definition = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence @callee1(%root: !transform.any_op {transform.readonly}) {
    transform.print %root { name = "from interpreter" }: !transform.any_op
    transform.yield
  }
}
"""
104
105
@test_in_context
def include():
    """Resolve transform.include targets by merging in separately parsed modules.

    The main script only declares @callee1 and @callee2; their definitions are
    parsed from two library modules and their symbols copied into ``main``
    before the entry point is applied.
    """
    main = ir.Module.parse(print_root_via_include_module)
    # Parse both libraries up front, then merge callee1 before callee2.
    libraries = [
        ir.Module.parse(text) for text in (callee1_definition, callee2_definition)
    ]
    for library in libraries:
        interp.copy_symbols_and_merge_into(main, library)

    # CHECK: @print_root_via_include_module
    # CHECK: transform.named_sequence @__transform_main
    # CHECK: transform.include @callee2
    #
    # CHECK: transform.named_sequence @callee1
    # CHECK: transform.print
    #
    # CHECK: transform.named_sequence @callee2
    # CHECK: transform.include @callee1
    interp.apply_named_sequence(main, main.body.operations[0], main)
124
125
@test_in_context
def partial_include():
    """Applying a script whose include chain reaches an undefined symbol fails.

    Only @callee2 is merged in. It includes @callee1, which remains a bare
    declaration in ``main``, so applying the entry point must raise ValueError.
    """
    main = ir.Module.parse(print_root_via_include_module)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee2)

    try:
        interp.apply_named_sequence(main, main.body.operations[0], main)
    except ValueError as e:
        assert "Failed to apply" in str(e), str(e)
    else:
        # Without this branch the test would pass silently if the application
        # unexpectedly succeeded.
        raise AssertionError(
            "expected apply_named_sequence to fail on the undefined @callee1"
        )
136
137
@test_in_context
def repeated_include():
    """Merging the same symbol definitions twice must be rejected.

    The first merge gives ``main`` a definition of @callee2. The second merge
    of the same module is expected to raise ValueError for the duplicate.
    """
    main = ir.Module.parse(print_root_via_include_module)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee2)

    try:
        interp.copy_symbols_and_merge_into(main, callee2)
    except ValueError as e:
        assert "doubly defined symbol @callee2" in str(e), str(e)
    else:
        # Without this branch the test would pass silently if the duplicate
        # merge were accepted.
        raise AssertionError("expected the second merge of @callee2 to fail")
148