xref: /llvm-project/mlir/test/python/ir/module.py (revision d1fdb416299c0efa5979ed989f7c1f39973dcb73)
19f3f6d7bSStella Laurenzo# RUN: %PYTHON %s | FileCheck %s
29f3f6d7bSStella Laurenzo
39f3f6d7bSStella Laurenzoimport gc
49f3f6d7bSStella Laurenzofrom mlir.ir import *
59f3f6d7bSStella Laurenzo
6f9008e63STobias Hieta
def run(f):
    """Execute test function *f* once, then check that no Context leaked.

    Prints the ``TEST:`` banner that the CHECK-LABEL lines match, calls *f*,
    forces a garbage collection and asserts that no ``Context`` objects are
    still alive. *f* is returned unchanged so this works as a decorator.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Force a collection so objects the test dropped are really freed
    # before the live-context count is read.
    gc.collect()
    assert Context._get_live_count() == 0
    return f
139f3f6d7bSStella Laurenzo
149f3f6d7bSStella Laurenzo
159f3f6d7bSStella Laurenzo# Verify successful parse.
169f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testParseSuccess
179f3f6d7bSStella Laurenzo# CHECK: module @successfulParse
@run
def testParseSuccess():
    """Parse a named empty module; it must stay usable after its context handle goes away."""
    context = Context()
    parsed = Module.parse(r"""module @successfulParse {}""", context)
    assert parsed.context is context
    print("CLEAR CONTEXT")
    # Drop our only handle on the context; the module must keep it alive itself.
    del context
    gc.collect()
    # dump() only writes to stderr; calling it checks that it works at all.
    parsed.dump()
    print(str(parsed))
289f3f6d7bSStella Laurenzo
299f3f6d7bSStella Laurenzo
309f3f6d7bSStella Laurenzo# Verify parse error.
319f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testParseError
323ea4c501SRahul Kayaith# CHECK: testParseError: <
333ea4c501SRahul Kayaith# CHECK:   Unable to parse module assembly:
343ea4c501SRahul Kayaith# CHECK:   error: "-":1:1: expected operation name in quotes
353ea4c501SRahul Kayaith# CHECK: >
@run
def testParseError():
    """Malformed assembly must raise MLIRError carrying the parser diagnostic.

    The exception text is printed so FileCheck can match the diagnostic lines.
    If parsing unexpectedly succeeds, a sentinel message is printed instead.
    """
    ctx = Context()
    try:
        # The result is intentionally discarded: this parse is expected to fail.
        Module.parse(r"""}SYNTAX ERROR{""", ctx)
    except MLIRError as e:
        print(f"testParseError: <{e}>")
    else:
        print("Exception not produced")
459f3f6d7bSStella Laurenzo
469f3f6d7bSStella Laurenzo
# Verify that an empty module can be created and outlives the Python context handle.
489f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testCreateEmpty
499f3f6d7bSStella Laurenzo# CHECK: module {
@run
def testCreateEmpty():
    """Create an empty module and check that it keeps its context alive on its own.

    The Location is built inline rather than bound to a local. A bound
    ``Location`` local is created from ``ctx`` and presumably holds its own
    reference to the context. That could mask a module that fails to capture
    the context once ``ctx`` is cleared below.
    """
    ctx = Context()
    module = Module.create(Location.unknown(ctx))
    print("CLEAR CONTEXT")
    ctx = None  # Ensure that module captures the context.
    gc.collect()
    print(str(module))
599f3f6d7bSStella Laurenzo
609f3f6d7bSStella Laurenzo
619f3f6d7bSStella Laurenzo# Verify round-trip of ASM that contains unicode.
629f3f6d7bSStella Laurenzo# Note that this does not test that the print path converts unicode properly
639f3f6d7bSStella Laurenzo# because MLIR asm always normalizes it to the hex encoding.
649f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testRoundtripUnicode
659f3f6d7bSStella Laurenzo# CHECK: func private @roundtripUnicode()
669f3f6d7bSStella Laurenzo# CHECK: foo = "\F0\9F\98\8A"
@run
def testRoundtripUnicode():
    """Parse assembly with a non-ASCII string attribute and print it back.

    The attribute value is the emoji U+1F60A. The printer emits it as its
    escaped UTF-8 bytes F0 9F 98 8A, which is what the CHECK lines match.
    """
    ctx = Context()
    # The emoji is written as a \U escape so this source stays ASCII-only and
    # cannot be corrupted by re-encoding (a mangled literal breaks the CHECK).
    module = Module.parse(
        """
    func.func private @roundtripUnicode() attributes { foo = "\U0001F60A" }
  """,
        ctx,
    )
    print(str(module))
769f3f6d7bSStella Laurenzo    print(str(module))
779f3f6d7bSStella Laurenzo
78ace1d0adSStella Laurenzo
# Verify round-trip of ASM that contains unicode through the binary format:
# print it with get_asm(binary=True), then re-parse the resulting bytes.
# Note that this does not test that the print path converts unicode properly
# because MLIR asm always normalizes it to the hex encoding.
82ace1d0adSStella Laurenzo# CHECK-LABEL: TEST: testRoundtripBinary
83ace1d0adSStella Laurenzo# CHECK: func private @roundtripUnicode()
84ace1d0adSStella Laurenzo# CHECK: foo = "\F0\9F\98\8A"
@run
def testRoundtripBinary():
    """Round-trip assembly containing unicode through the binary serialization.

    ``get_asm(binary=True)`` must return ``bytes``. ``Module.parse`` must accept
    those bytes, and the reparsed module is printed for FileCheck. The
    attribute value is the emoji U+1F60A, which is checked as F0 9F 98 8A.
    """
    with Context():
        # The emoji is written as a \U escape so this source stays ASCII-only
        # and cannot be corrupted by re-encoding (a mangled literal breaks the CHECK).
        module = Module.parse(
            """
      func.func private @roundtripUnicode() attributes { foo = "\U0001F60A" }
    """
        )
        binary_asm = module.operation.get_asm(binary=True)
        assert isinstance(binary_asm, bytes)
        module = Module.parse(binary_asm)
        print(module)
96ace1d0adSStella Laurenzo        print(module)
979f3f6d7bSStella Laurenzo
989f3f6d7bSStella Laurenzo
999f3f6d7bSStella Laurenzo# Tests that module.operation works and correctly interns instances.
1009f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testModuleOperation
@run
def testModuleOperation():
    """Check that ``module.operation`` is interned and tracked by the context.

    Verifies the context's live-module and live-operation counters, that
    repeated ``module.operation`` accesses return the same Python object, and
    that ``Context._clear_live_operations()`` empties the tracking table. It
    also checks that an operation handle stays printable after its ``Module``
    is dropped, and that everything is released once all handles are gone.
    """
    ctx = Context()
    module = Module.parse(r"""module @successfulParse {}""", ctx)
    assert ctx._get_live_module_count() == 1
    op1 = module.operation
    assert ctx._get_live_operation_count() == 1
    live_ops = ctx._get_live_operation_objects()
    assert len(live_ops) == 1
    assert live_ops[0] is op1
    # Drop the list: it references op1 and would otherwise keep it alive
    # through the final lifetime checks.
    live_ops = None
    # CHECK: module @successfulParse
    print(op1)

    # Ensure that operations are the same on multiple calls.
    op2 = module.operation
    assert ctx._get_live_operation_count() == 1
    assert op1 is op2

    # Test live operation clearing.
    op1 = module.operation
    assert ctx._get_live_operation_count() == 1
    num_invalidated = ctx._clear_live_operations()
    assert num_invalidated == 1
    assert ctx._get_live_operation_count() == 0
    op1 = None
    gc.collect()
    # Re-fetch after clearing; presumably this creates a fresh tracked handle
    # since the tracking table was just emptied.
    op1 = module.operation

    # Ensure that if module is de-referenced, the operations are still valid.
    module = None
    gc.collect()
    print(op1)

    # Collect and verify lifetime.
    op1 = None
    op2 = None
    gc.collect()
    print("LIVE OPERATIONS:", ctx._get_live_operation_count())
    assert ctx._get_live_operation_count() == 0
    assert ctx._get_live_module_count() == 0
1429f3f6d7bSStella Laurenzo
1439f3f6d7bSStella Laurenzo
1449f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testModuleCapsule
@run
def testModuleCapsule():
    """Round-trip a Module through its C-API capsule and check identity.

    Recreating a Module from its ``_CAPIPtr`` capsule must return the very
    same Python object, bound to the same context. Once every Python handle
    is dropped, the context must report no live modules.
    """
    context = Context()
    original = Module.parse(r"""module @successfulParse {}""", context)
    assert context._get_live_module_count() == 1
    # CHECK: "mlir.ir.Module._CAPIPtr"
    capsule = original._CAPIPtr
    print(capsule)
    roundtripped = Module._CAPICreate(capsule)
    assert roundtripped is original
    assert roundtripped.context is context
    # Release every handle (module, capsule, duplicate), then collect.
    del original, capsule, roundtripped
    gc.collect()
    assert context._get_live_module_count() == 0
162