xref: /llvm-project/mlir/test/python/ir/symbol_table.py (revision b56d1ec6cb8b5cb3ff46cba39a1049ecf3831afb)
1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4import io
5import itertools
6from mlir.ir import *
7
8
def run(f):
    """Execute test function ``f`` and verify that no MLIR context is leaked.

    Prints a ``TEST: <name>`` banner, calls ``f`` with no arguments, forces a
    garbage collection and asserts that the live ``Context`` count is zero.
    Returns ``f`` unchanged, so this can be applied as a decorator.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Collect garbage first, then look at the number of live contexts.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
    return f
15
16
17# CHECK-LABEL: TEST: testSymbolTableInsert
@run
def testSymbolTableInsert():
    """Exercise lookup, erasure and insertion through ``SymbolTable``.

    Builds two modules, wraps the first one in a ``SymbolTable`` and checks:
    membership and item lookup, deletion (including the ``KeyError`` raised
    for a missing symbol and the invalidation of a previously fetched
    operation), insertion of operations taken from the second module
    (including the rename on a name clash), and the ``ValueError`` raised when
    inserting an operation without a symbol name.
    """
    with Context() as ctx:
        # "foo.bar" below belongs to no registered dialect.
        ctx.allow_unregistered_dialects = True
        m1 = Module.parse(
            """
      func.func private @foo()
      func.func private @bar()"""
        )
        m2 = Module.parse(
            """
      func.func private @qux()
      func.func private @foo()
      "foo.bar"() : () -> ()"""
        )

        symbol_table = SymbolTable(m1.operation)

        # CHECK: func private @foo
        # CHECK: func private @bar
        assert "foo" in symbol_table
        print(symbol_table["foo"])
        assert "bar" in symbol_table
        # Keep a handle to @bar so its invalidation can be verified after the
        # symbol has been deleted.
        bar = symbol_table["bar"]
        print(symbol_table["bar"])

        assert "qux" not in symbol_table

        del symbol_table["bar"]
        # Looking up the deleted symbol again must fail with KeyError.
        try:
            symbol_table.erase(symbol_table["bar"])
        except KeyError:
            pass
        else:
            assert False, "expected KeyError"

        # CHECK: module
        # CHECK:   func private @foo()
        print(m1)
        assert "bar" not in symbol_table

        # Using the stale handle must raise; any other RuntimeError is
        # re-raised as a genuine failure.
        try:
            print(bar)
        except RuntimeError as e:
            if "the operation has been invalidated" not in str(e):
                raise
        else:
            assert False, "expected RuntimeError due to invalidated operation"

        qux = m2.body.operations[0]
        m1.body.append(qux)
        symbol_table.insert(qux)
        assert "qux" in symbol_table

        # Check that insertion actually renames this symbol in the symbol table.
        # With @qux appended to m1, index 0 of m2 is expected to be its @foo,
        # which clashes with the @foo already present in m1.
        foo2 = m2.body.operations[0]
        m1.body.append(foo2)
        updated_name = symbol_table.insert(foo2)
        assert foo2.name.value != "foo"
        assert foo2.name == updated_name
        assert isinstance(updated_name, StringAttr)

        # CHECK: module
        # CHECK:   func private @foo()
        # CHECK:   func private @qux()
        # CHECK:   func private @foo{{.*}}
        print(m1)

        # The remaining operation in m2 is expected to be "foo.bar", which
        # carries no symbol name, so inserting it must be rejected.
        try:
            symbol_table.insert(m2.body.operations[0])
        except ValueError as e:
            if "Expected operation to have a symbol name" not in str(e):
                raise
        else:
            assert False, "expected ValueError when adding a non-symbol"
93
94
95# CHECK-LABEL: testSymbolTableRAUW
@run
def testSymbolTableRAUW():
    """Rename a symbol and replace its uses, first inside one function and
    then across the whole module, printing the result after each step."""
    with Context():
        module = Module.parse(
            """
      func.func private @foo() {
        call @bar() : () -> ()
        return
      }
      func.func private @bar()
      """
        )
        top_level_ops = list(module.operation.regions[0].blocks[0].operations)
        foo, bar = top_level_ops[:2]

        # Do renaming just within `foo`.
        SymbolTable.set_symbol_name(bar, "bam")
        SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
        # CHECK: call @bam()
        # CHECK: func private @bam
        print(module)
        # CHECK: Foo symbol: StringAttr("foo")
        # CHECK: Bar symbol: StringAttr("bam")
        foo_symbol = SymbolTable.get_symbol_name(foo)
        print(f"Foo symbol: {foo_symbol!r}")
        bar_symbol = SymbolTable.get_symbol_name(bar)
        print(f"Bar symbol: {bar_symbol!r}")

        # Do renaming within the module.
        SymbolTable.set_symbol_name(bar, "baz")
        SymbolTable.replace_all_symbol_uses("bam", "baz", module.operation)
        # CHECK: call @baz()
        # CHECK: func private @baz
        print(module)
        # CHECK: Foo symbol: StringAttr("foo")
        # CHECK: Bar symbol: StringAttr("baz")
        foo_symbol = SymbolTable.get_symbol_name(foo)
        print(f"Foo symbol: {foo_symbol!r}")
        bar_symbol = SymbolTable.get_symbol_name(bar)
        print(f"Bar symbol: {bar_symbol!r}")
131
132
133# CHECK-LABEL: testSymbolTableVisibility
@run
def testSymbolTableVisibility():
    """Print the visibility of a private function, then set it to "public"
    and print the module."""
    with Context():
        module = Module.parse(
            """
      func.func private @foo() {
        return
      }
      """
        )
        body_ops = module.operation.regions[0].blocks[0].operations
        foo = body_ops[0]
        # CHECK: Existing visibility: StringAttr("private")
        current_visibility = SymbolTable.get_visibility(foo)
        print(f"Existing visibility: {current_visibility!r}")
        SymbolTable.set_visibility(foo, "public")
        # CHECK: func public @foo
        print(module)
150
151
152# CHECK: testWalkSymbolTables
@run
def testWalkSymbolTables():
    """Walk nested symbol tables and check callback error propagation.

    First walks a module containing two nested named modules with a callback
    that prints each visited symbol-table operation. Then repeats the walk
    with a callback that always fails and verifies that the failure surfaces
    to the caller as a ``RuntimeError``.
    """
    with Context():
        m = Module.parse(
            """
      module @outer {
        module @inner{
        }
      }
      """
        )

        def callback(symbol_table_op, uses_visible):
            print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}")

        # The expected output lists @inner before @outer.
        # CHECK: SYMBOL TABLE: True: module @inner
        # CHECK: SYMBOL TABLE: True: module @outer
        SymbolTable.walk_symbol_tables(m.operation, True, callback)

        # Make sure exceptions in the callback are handled.
        def error_callback(symbol_table_op, uses_visible):
            assert False, "Raised from python"

        try:
            SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
        except RuntimeError as e:
            # CHECK: GOT EXCEPTION: Exception raised in callback:
            # CHECK: AssertionError: Raised from python
            print(f"GOT EXCEPTION: {e}")
        else:
            # Fail at the Python level too if the callback's error was
            # swallowed instead of being propagated.
            assert False, "expected RuntimeError from failing callback"
182