xref: /llvm-project/mlir/test/python/ir/module.py (revision d1fdb416299c0efa5979ed989f7c1f39973dcb73)
1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
6
def run(f):
    """Execute the test function ``f`` right away and hand it back unchanged.

    Meant to be applied as a decorator. It prints a ``TEST:`` label (consumed
    by FileCheck), calls ``f``, forces a garbage collection, and then asserts
    that ``Context._get_live_count()`` reports zero live contexts.
    """
    label = f.__name__
    print("\nTEST:", label)
    f()
    # Collect first so contexts that are only unreachable (not yet freed) are
    # not counted.
    gc.collect()
    remaining_contexts = Context._get_live_count()
    assert remaining_contexts == 0
    return f
13
14
15# Verify successful parse.
16# CHECK-LABEL: TEST: testParseSuccess
17# CHECK: module @successfulParse
@run
def testParseSuccess():
    """Parse a named empty module and check it outlives our context reference."""
    context = Context()
    parsed = Module.parse(r"""module @successfulParse {}""", context)
    assert parsed.context is context
    print("CLEAR CONTEXT")
    # Drop the local reference; the module itself must keep the context alive.
    del context
    gc.collect()
    parsed.dump()  # Writes to stderr; only proves that dumping works.
    print(str(parsed))
28
29
30# Verify parse error.
31# CHECK-LABEL: TEST: testParseError
32# CHECK: testParseError: <
33# CHECK:   Unable to parse module assembly:
34# CHECK:   error: "-":1:1: expected operation name in quotes
35# CHECK: >
@run
def testParseError():
    """Check that malformed assembly raises ``MLIRError`` carrying diagnostics."""
    ctx = Context()
    try:
        # Parsing is expected to fail, so the result is deliberately not kept.
        Module.parse(r"""}SYNTAX ERROR{""", ctx)
    except MLIRError as e:
        # The rendered error text is printed between angle brackets so the
        # expected FileCheck output can match it line by line.
        print(f"testParseError: <{e}>")
    else:
        print("Exception not produced")
45
46
# Verify creation of an empty module.
48# CHECK-LABEL: TEST: testCreateEmpty
49# CHECK: module {
@run
def testCreateEmpty():
    """Create an empty module and check that it keeps its context alive."""
    ctx = Context()
    loc = Location.unknown(ctx)
    module = Module.create(loc)
    print("CLEAR CONTEXT")
    # Drop every other local reference to the context before collecting.
    # The Location was created from ``ctx`` and is context-bound, so if it were
    # left alive it could keep the context alive by itself. That would hide a
    # module that failed to capture the context.
    ctx = None
    loc = None
    gc.collect()
    print(str(module))
59
60
61# Verify round-trip of ASM that contains unicode.
62# Note that this does not test that the print path converts unicode properly
63# because MLIR asm always normalizes it to the hex encoding.
64# CHECK-LABEL: TEST: testRoundtripUnicode
65# CHECK: func private @roundtripUnicode()
66# CHECK: foo = "\F0\9F\98\8A"
@run
def testRoundtripUnicode():
    """Parse assembly with a non-ASCII string attribute and print it back."""
    ctx = Context()
    # The attribute value is U+1F60A, whose UTF-8 bytes are F0 9F 98 8A. Those
    # bytes are what the expected output `foo = "\F0\9F\98\8A"` matches. It is
    # spelled as a Python escape (non-raw string) so the source stays ASCII and
    # cannot be corrupted by re-encoding.
    module = Module.parse(
        """
    func.func private @roundtripUnicode() attributes { foo = "\U0001F60A" }
  """,
        ctx,
    )
    print(str(module))
77
78
# Verify round-trip of ASM that contains unicode when the ASM is exchanged as
# bytes (get_asm(binary=True)) rather than as str.
80# Note that this does not test that the print path converts unicode properly
81# because MLIR asm always normalizes it to the hex encoding.
82# CHECK-LABEL: TEST: testRoundtripBinary
83# CHECK: func private @roundtripUnicode()
84# CHECK: foo = "\F0\9F\98\8A"
@run
def testRoundtripBinary():
    """Round-trip unicode-bearing assembly through its ``bytes`` form."""
    with Context():
        # U+1F60A (UTF-8 F0 9F 98 8A, the expected `\F0\9F\98\8A` output),
        # written as a Python escape so the source stays pure ASCII.
        module = Module.parse(
            """
      func.func private @roundtripUnicode() attributes { foo = "\U0001F60A" }
    """
        )
        # With binary=True the printed assembly comes back as bytes, not str.
        binary_asm = module.operation.get_asm(binary=True)
        assert isinstance(binary_asm, bytes)
        # Module.parse must also accept that bytes object.
        module = Module.parse(binary_asm)
        print(module)
97
98
99# Tests that module.operation works and correctly interns instances.
100# CHECK-LABEL: TEST: testModuleOperation
@run
def testModuleOperation():
    """Check that ``module.operation`` is interned and tracked as a live op.

    The exact order of assignments, ``None`` rebinding and ``gc.collect()``
    calls matters here. Each asserted live-operation count depends on which
    Python references are still held at that point.
    """
    ctx = Context()
    module = Module.parse(r"""module @successfulParse {}""", ctx)
    assert ctx._get_live_module_count() == 1
    op1 = module.operation
    assert ctx._get_live_operation_count() == 1
    live_ops = ctx._get_live_operation_objects()
    assert len(live_ops) == 1
    assert live_ops[0] is op1
    # Release the list: it holds a reference to op1 that would otherwise
    # survive until the final live-count assertions.
    live_ops = None
    # CHECK: module @successfulParse
    print(op1)

    # Ensure that operations are the same on multiple calls.
    op2 = module.operation
    assert ctx._get_live_operation_count() == 1
    assert op1 is op2

    # Test live operation clearing.
    op1 = module.operation
    assert ctx._get_live_operation_count() == 1
    # Empties the context's live-operation tracking. The return value (number
    # of entries cleared) is asserted to be 1, and the count drops to 0.
    # NOTE(review): per the variable name, previously obtained objects
    # (op1/op2) are presumably invalidated; op2 is not used again below.
    num_invalidated = ctx._clear_live_operations()
    assert num_invalidated == 1
    assert ctx._get_live_operation_count() == 0
    op1 = None
    gc.collect()
    # Re-fetch after clearing to obtain a fresh operation object.
    op1 = module.operation

    # Ensure that if module is de-referenced, the operations are still valid.
    module = None
    gc.collect()
    print(op1)

    # Collect and verify lifetime.
    op1 = None
    op2 = None
    gc.collect()
    print("LIVE OPERATIONS:", ctx._get_live_operation_count())
    assert ctx._get_live_operation_count() == 0
    assert ctx._get_live_module_count() == 0
142
143
144# CHECK-LABEL: TEST: testModuleCapsule
@run
def testModuleCapsule():
    """Round-trip a Module through its ``_CAPIPtr`` capsule and check interning."""
    ctx = Context()
    mod = Module.parse(r"""module @successfulParse {}""", ctx)
    assert ctx._get_live_module_count() == 1
    # CHECK: "mlir.ir.Module._CAPIPtr"
    capsule = mod._CAPIPtr
    print(capsule)
    # Rebuilding from the capsule must return the very same Python object.
    mod_from_capsule = Module._CAPICreate(capsule)
    assert mod is mod_from_capsule
    assert mod_from_capsule.context is ctx
    # Drop every reference, collect, and confirm the module was destroyed.
    del mod, capsule, mod_from_capsule
    gc.collect()
    assert ctx._get_live_module_count() == 0
162