# RUN: %PYTHON %s | FileCheck %s
"""Tests for the MLIR Python ``SymbolTable`` bindings."""

import gc
import io
import itertools
from mlir.ir import *


def run(f):
    """Run ``f`` immediately as a test case and return it unchanged.

    Prints a ``TEST:`` banner (matched by the FileCheck labels below), calls
    ``f``, then forces a garbage collection and asserts that no MLIR
    ``Context`` objects are still alive, so leaked references fail the test.
    """
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0
    return f


# CHECK-LABEL: TEST: testSymbolTableInsert
@run
def testSymbolTableInsert():
    """Exercise lookup, erasure and insertion through ``SymbolTable``.

    Also covers renaming of a conflicting symbol on insert and rejection of
    an operation that has no symbol name.
    """
    with Context() as ctx:
        # m2 contains the unregistered "foo.bar" op, which must still parse.
        ctx.allow_unregistered_dialects = True
        m1 = Module.parse(
            """
      func.func private @foo()
      func.func private @bar()"""
        )
        m2 = Module.parse(
            """
      func.func private @qux()
      func.func private @foo()
      "foo.bar"() : () -> ()"""
        )

        symbol_table = SymbolTable(m1.operation)

        # CHECK: func private @foo
        # CHECK: func private @bar
        assert "foo" in symbol_table
        print(symbol_table["foo"])
        assert "bar" in symbol_table
        bar = symbol_table["bar"]
        print(symbol_table["bar"])

        assert "qux" not in symbol_table

        del symbol_table["bar"]
        try:
            # The lookup itself must fail now that "bar" has been deleted.
            symbol_table.erase(symbol_table["bar"])
        except KeyError:
            pass
        else:
            raise AssertionError("expected KeyError")

        # CHECK: module
        # CHECK: func private @foo()
        print(m1)
        assert "bar" not in symbol_table

        # `bar` still refers to the erased op; using it must raise.
        try:
            print(bar)
        except RuntimeError as e:
            if "the operation has been invalidated" not in str(e):
                raise
        else:
            raise AssertionError(
                "expected RuntimeError due to invalidated operation"
            )

        # Appending to m1 moves the op out of m2, so operations[0] advances.
        qux = m2.body.operations[0]
        m1.body.append(qux)
        symbol_table.insert(qux)
        assert "qux" in symbol_table

        # Check that insertion actually renames this symbol in the symbol table.
        foo2 = m2.body.operations[0]
        m1.body.append(foo2)
        updated_name = symbol_table.insert(foo2)
        assert foo2.name.value != "foo"
        assert foo2.name == updated_name
        assert isinstance(updated_name, StringAttr)

        # CHECK: module
        # CHECK: func private @foo()
        # CHECK: func private @qux()
        # CHECK: func private @foo{{.*}}
        print(m1)

        # Only the non-symbol "foo.bar" op is left in m2.
        try:
            symbol_table.insert(m2.body.operations[0])
        except ValueError as e:
            if "Expected operation to have a symbol name" not in str(e):
                raise
        else:
            raise AssertionError("expected ValueError when adding a non-symbol")


# CHECK-LABEL: testSymbolTableRAUW
@run
def testSymbolTableRAUW():
    """Rename a symbol and rewrite its uses, first within ``foo``, then module-wide."""
    with Context():
        m = Module.parse(
            """
      func.func private @foo() {
        call @bar() : () -> ()
        return
      }
      func.func private @bar()
      """
        )
        foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2]

        # Do renaming just within `foo`.
        SymbolTable.set_symbol_name(bar, "bam")
        SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
        # CHECK: call @bam()
        # CHECK: func private @bam
        print(m)
        # CHECK: Foo symbol: StringAttr("foo")
        # CHECK: Bar symbol: StringAttr("bam")
        print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
        print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")

        # Do renaming within the module.
        SymbolTable.set_symbol_name(bar, "baz")
        SymbolTable.replace_all_symbol_uses("bam", "baz", m.operation)
        # CHECK: call @baz()
        # CHECK: func private @baz
        print(m)
        # CHECK: Foo symbol: StringAttr("foo")
        # CHECK: Bar symbol: StringAttr("baz")
        print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
        print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")


# CHECK-LABEL: testSymbolTableVisibility
@run
def testSymbolTableVisibility():
    """Read a symbol's visibility and change it from private to public."""
    with Context():
        m = Module.parse(
            """
      func.func private @foo() {
        return
      }
      """
        )
        foo = m.operation.regions[0].blocks[0].operations[0]
        # CHECK: Existing visibility: StringAttr("private")
        print(f"Existing visibility: {repr(SymbolTable.get_visibility(foo))}")
        SymbolTable.set_visibility(foo, "public")
        # CHECK: func public @foo
        print(m)


# CHECK-LABEL: TEST: testWalkSymbolTables
@run
def testWalkSymbolTables():
    """Walk nested symbol-table ops; a failing callback must surface as RuntimeError."""
    with Context():
        m = Module.parse(
            """
      module @outer {
        module @inner{
        }
      }
      """
        )

        def callback(symbol_table_op, uses_visible):
            print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}")

        # CHECK: SYMBOL TABLE: True: module @inner
        # CHECK: SYMBOL TABLE: True: module @outer
        SymbolTable.walk_symbol_tables(m.operation, True, callback)

        # Make sure exceptions in the callback are handled.
        def error_callback(symbol_table_op, uses_visible):
            # Explicit raise (not `assert False`) so this still fires under -O.
            raise AssertionError("Raised from python")

        try:
            SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
        except RuntimeError as e:
            # CHECK: GOT EXCEPTION: Exception raised in callback:
            # CHECK: AssertionError: Raised from python
            print(f"GOT EXCEPTION: {e}")
        else:
            raise AssertionError("expected RuntimeError from the failing callback")