xref: /llvm-project/mlir/test/python/dialects/transform_interpreter.py (revision 73140daebbf522dbb14dc4b2f3c67dc0aa1a62dd)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir import ir
4from mlir.dialects.transform import interpreter as interp
5
6
def test_in_context(f):
    """Decorator that runs ``f`` once, immediately, inside a fresh MLIR context.

    An unknown location is installed on top of the new context while ``f``
    executes. The function object is handed back unchanged, so the decorated
    module-level name still refers to ``f``.
    """
    # Enter the context first, then the location, mirroring a combined
    # `with A, B:` statement.
    with ir.Context():
        with ir.Location.unknown():
            f()
    return f
11
12
# Transform script whose @__transform_main entry point prints its root op
# with the label "from interpreter". Individual tests substitute their own
# label for that text via str.replace before parsing.
print_root_module = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    transform.print %root { name = "from interpreter" }: !transform.any_op
    transform.yield
  }
}"""
20
21
@test_in_context
def print_self():
    """Apply the print script to itself: the transform module is also the payload.

    The print label is set to "print_self", so the interpreter's output is
    tagged with this test's name.
    """
    script = print_root_module.replace("from interpreter", "print_self")
    module = ir.Module.parse(script)
    # The first op in the module body is the @__transform_main named sequence.
    entry_point = module.body.operations[0]
    interp.apply_named_sequence(module, entry_point, module)
26
27
28# CHECK-LABEL: print_self
29# CHECK: transform.named_sequence @__transform_main
30# CHECK: transform.print
31# CHECK: transform.yield
32
33
@test_in_context
def print_other():
    """Apply the print script to a separate, otherwise empty payload module.

    Only the payload (tagged with the `this.is.payload` attribute) is passed
    as the root, so the transform script itself should not appear in the
    printed output.
    """
    script_module = ir.Module.parse(
        print_root_module.replace("from interpreter", "print_other")
    )
    target = ir.Module.parse("module attributes { this.is.payload } {}")
    # The first op in the script module is the @__transform_main entry point.
    main_sequence = script_module.body.operations[0]
    interp.apply_named_sequence(target, main_sequence, script_module)
41
42
43# CHECK-LABEL: print_other
44# CHECK-NOT: transform
45# CHECK: this.is.payload
46
47
@test_in_context
def failed():
    """Using a plain (non-transform) op as the transform root must raise ValueError.

    The payload module is passed as payload, transform root, and transform
    module at once. Because it is not a transform op, the interpreter must
    reject it with a ValueError naming TransformOpInterface.
    """
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    try:
        interp.apply_named_sequence(payload, payload, payload)
    except ValueError as e:
        assert (
            "must implement TransformOpInterface to be used as transform root" in str(e)
        )
    else:
        # Without this branch the test would pass silently if no error were
        # raised at all, hiding a regression.
        raise AssertionError(
            "expected apply_named_sequence to raise ValueError for a non-transform root"
        )
57
58
# Entry-point transform script. It declares @callee1 and @callee2 as private
# named sequences without bodies, and its @__transform_main includes @callee2
# on the root op. The callee bodies come from the two modules below, which
# the include-related tests merge in with copy_symbols_and_merge_into.
print_root_via_include_module = """
module @print_root_via_include_module attributes {transform.with_named_sequence} {
  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
  transform.named_sequence private @callee2(%root: !transform.any_op {transform.readonly})
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    transform.include @callee2 failures(propagate)
        (%root) : (!transform.any_op) -> ()
    transform.yield
  }
}"""

# Defines @callee2, whose body includes @callee1. @callee1 is only declared
# (private, no body) here, so merging this module alone leaves it undefined.
callee2_definition = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
  transform.named_sequence @callee2(%root: !transform.any_op {transform.readonly}) {
    transform.include @callee1 failures(propagate)
        (%root) : (!transform.any_op) -> ()
    transform.yield
  }
}
"""

# Defines @callee1, which prints its root op with the label "from interpreter".
callee1_definition = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence @callee1(%root: !transform.any_op {transform.readonly}) {
    transform.print %root { name = \"from interpreter\" }: !transform.any_op
    transform.yield
  }
}
"""
89
90
@test_in_context
def include():
    """Merge both callee definitions into the entry-point module, then run it.

    @__transform_main includes @callee2, which includes @callee1, which prints
    the root. Since the root is the merged main module itself, the printed IR
    shows the entry point followed by both merged callees.
    """
    main = ir.Module.parse(print_root_via_include_module)
    # Parse callee1 before callee2 and keep both modules referenced until the
    # end of the function; then merge them into main in that same order.
    callee_modules = [
        ir.Module.parse(source) for source in (callee1_definition, callee2_definition)
    ]
    for callee_module in callee_modules:
        interp.copy_symbols_and_merge_into(main, callee_module)

    # CHECK: @print_root_via_include_module
    # CHECK: transform.named_sequence @__transform_main
    # CHECK: transform.include @callee2
    #
    # CHECK: transform.named_sequence @callee1
    # CHECK: transform.print
    #
    # CHECK: transform.named_sequence @callee2
    # CHECK: transform.include @callee1
    interp.apply_named_sequence(main, main.body.operations[0], main)
109
110
@test_in_context
def partial_include():
    """Applying main with only @callee2 merged in must raise ValueError.

    callee2_definition's @callee2 includes @callee1, but @callee1 is only
    declared (never defined) in both main and callee2_definition. Applying the
    entry point is therefore expected to fail with a "Failed to apply" error.
    """
    main = ir.Module.parse(print_root_via_include_module)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee2)

    try:
        interp.apply_named_sequence(main, main.body.operations[0], main)
    except ValueError as e:
        assert "Failed to apply" in str(e)
    else:
        # Guard against a silent pass: the test must fail if no error occurs.
        raise AssertionError(
            "expected apply_named_sequence to raise ValueError with @callee1 undefined"
        )
121
122
@test_in_context
def repeated_include():
    """Merging the same callee definition twice must raise ValueError.

    The first merge adds a definition of @callee2 to main. The second merge
    of the same module is expected to be rejected as a doubly defined symbol.
    """
    main = ir.Module.parse(print_root_via_include_module)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee2)

    try:
        interp.copy_symbols_and_merge_into(main, callee2)
    except ValueError as e:
        assert "doubly defined symbol @callee2" in str(e)
    else:
        # Guard against a silent pass: the test must fail if no error occurs.
        raise AssertionError(
            "expected the second copy_symbols_and_merge_into to raise ValueError"
        )
133