# RUN: %PYTHON %s | FileCheck %s

import gc
from mlir.ir import *


def run(f):
    """Run test function `f` immediately and verify that no contexts leak.

    Used as a decorator: prints a "TEST:" header line (the labels the
    FileCheck directives in this file anchor on), calls `f`, forces a garbage
    collection and asserts that no `Context` objects are still alive.
    Returns `f` unchanged so the decorated name stays bound to the function.
    """
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0
    return f


# Verify successful parse.
# CHECK-LABEL: TEST: testParseSuccess
# CHECK: module @successfulParse
@run
def testParseSuccess():
    """Parse a module and verify it keeps its context alive on its own."""
    ctx = Context()
    module = Module.parse(r"""module @successfulParse {}""", ctx)
    assert module.context is ctx
    print("CLEAR CONTEXT")
    ctx = None  # Ensure that module captures the context.
    gc.collect()
    module.dump()  # Just outputs to stderr. Verifies that it functions.
    print(str(module))


# Verify parse error.
# CHECK-LABEL: TEST: testParseError
# CHECK: testParseError: <
# CHECK: Unable to parse module assembly:
# CHECK: error: "-":1:1: expected operation name in quotes
# CHECK: >
@run
def testParseError():
    """Verify that malformed assembly raises MLIRError with diagnostics."""
    ctx = Context()
    try:
        # Only the exception matters here; the parse result is never used.
        Module.parse(r"""}SYNTAX ERROR{""", ctx)
    except MLIRError as e:
        print(f"testParseError: <{e}>")
    else:
        print("Exception not produced")


# Verify creation of an empty module.
# CHECK-LABEL: TEST: testCreateEmpty
# CHECK: module {
@run
def testCreateEmpty():
    """Create an empty module and verify it keeps its context alive."""
    ctx = Context()
    loc = Location.unknown(ctx)
    module = Module.create(loc)
    print("CLEAR CONTEXT")
    ctx = None  # Ensure that module captures the context.
    gc.collect()
    print(str(module))


# Verify round-trip of ASM that contains unicode.
# Note that this does not test that the print path converts unicode properly
# because MLIR asm always normalizes it to the hex encoding.
# CHECK-LABEL: TEST: testRoundtripUnicode
# CHECK: func private @roundtripUnicode()
# CHECK: foo = "\F0\9F\98\8A"
@run
def testRoundtripUnicode():
    """Parse asm containing a non-ASCII string attribute and print it."""
    ctx = Context()
    # The attribute value is U+1F60A (UTF-8 bytes F0 9F 98 8A, as expected by
    # the directive above). It is spelled as a \U escape so this source file
    # stays ASCII-only and the character cannot be lost by re-encoding; the
    # Python string handed to the parser still contains the real code point.
    module = Module.parse(
        """
  func.func private @roundtripUnicode() attributes { foo = "\U0001F60A" }
""",
        ctx,
    )
    print(str(module))


# Verify round-trip of ASM that contains unicode.
# Note that this does not test that the print path converts unicode properly
# because MLIR asm always normalizes it to the hex encoding. Here the asm is
# fetched as a `bytes` object and the module is re-parsed from those bytes.
# CHECK-LABEL: TEST: testRoundtripBinary
# CHECK: func private @roundtripUnicode()
# CHECK: foo = "\F0\9F\98\8A"
@run
def testRoundtripBinary():
    """Verify Module.parse accepts the `bytes` form produced by get_asm."""
    with Context():
        # The attribute value is U+1F60A (UTF-8 bytes F0 9F 98 8A, as expected
        # by the directive above), written as a \U escape so the source file
        # stays ASCII-only and the character cannot be lost by re-encoding.
        module = Module.parse(
            """
  func.func private @roundtripUnicode() attributes { foo = "\U0001F60A" }
"""
        )
        binary_asm = module.operation.get_asm(binary=True)
        assert isinstance(binary_asm, bytes)
        module = Module.parse(binary_asm)
        print(module)


# Tests that module.operation works and correctly interns instances.
# CHECK-LABEL: TEST: testModuleOperation
@run
def testModuleOperation():
    """Verify interning, invalidation and lifetime of `module.operation`.

    Repeated `module.operation` accesses must return the same Python object,
    `_clear_live_operations` must drop it from the live set, and an operation
    obtained afterwards must remain printable after the Module is released.
    """
    ctx = Context()
    module = Module.parse(r"""module @successfulParse {}""", ctx)
    assert ctx._get_live_module_count() == 1
    op1 = module.operation
    assert ctx._get_live_operation_count() == 1
    live_ops = ctx._get_live_operation_objects()
    assert len(live_ops) == 1
    assert live_ops[0] is op1
    # Drop the list, which holds a reference to op1.
    live_ops = None
    # CHECK: module @successfulParse
    print(op1)

    # Ensure that operations are the same on multiple calls.
    op2 = module.operation
    assert ctx._get_live_operation_count() == 1
    assert op1 is op2

    # Test live operation clearing.
    op1 = module.operation
    assert ctx._get_live_operation_count() == 1
    num_invalidated = ctx._clear_live_operations()
    assert num_invalidated == 1
    assert ctx._get_live_operation_count() == 0
    op1 = None
    gc.collect()
    op1 = module.operation

    # Ensure that if module is de-referenced, the operations are still valid.
    module = None
    gc.collect()
    print(op1)

    # Collect and verify lifetime.
    op1 = None
    op2 = None
    gc.collect()
    print("LIVE OPERATIONS:", ctx._get_live_operation_count())
    assert ctx._get_live_operation_count() == 0
    assert ctx._get_live_module_count() == 0


# CHECK-LABEL: TEST: testModuleCapsule
@run
def testModuleCapsule():
    """Verify a Module round-trips through its C-API capsule.

    `_CAPICreate` on the module's `_CAPIPtr` capsule must return the very same
    Python object, and once every reference is dropped the module must no
    longer be counted as live by its context.
    """
    ctx = Context()
    module = Module.parse(r"""module @successfulParse {}""", ctx)
    assert ctx._get_live_module_count() == 1
    # CHECK: "mlir.ir.Module._CAPIPtr"
    module_capsule = module._CAPIPtr
    print(module_capsule)
    module_dup = Module._CAPICreate(module_capsule)
    assert module is module_dup
    assert module_dup.context is ctx
    # Gc and verify destructed.
    module = None
    module_capsule = None
    module_dup = None
    gc.collect()
    assert ctx._get_live_module_count() == 0