xref: /llvm-project/mlir/test/python/ir/dialects.py (revision c703b4645c79e889fd6a0f3f64f01f957d981aa4)
# RUN: %PYTHON %s | FileCheck %s
29f3f6d7bSStella Laurenzo
39f3f6d7bSStella Laurenzoimport gc
45192e299SMaksim Leventalimport sys
59f3f6d7bSStella Laurenzofrom mlir.ir import *
65192e299SMaksim Leventalfrom mlir.dialects._ods_common import _cext
79f3f6d7bSStella Laurenzo
89f3f6d7bSStella Laurenzo
def run(f):
    """Run test function ``f`` right away and return it unchanged.

    Intended to be used as a decorator. It prints a ``TEST: <name>`` banner
    (the text the ``CHECK-LABEL`` directives in this file match against),
    calls ``f``, forces a garbage collection and then asserts that
    ``Context._get_live_count()`` reports zero.
    """
    print("\nTEST:", f.__name__)
    f()
    # Collect before counting so unreachable objects do not keep a context
    # alive at the time of the check.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
    return f
159f3f6d7bSStella Laurenzo
169f3f6d7bSStella Laurenzo
# CHECK-LABEL: TEST: testDialectDescriptor
@run
def testDialectDescriptor():
    """Look up the ``func`` dialect descriptor and probe an unknown name."""
    context = Context()
    descriptor = context.get_dialect_descriptor("func")
    # CHECK: <DialectDescriptor func>
    print(descriptor)
    # CHECK: func
    print(descriptor.namespace)

    # The test expects a ValueError when the dialect name is unknown.
    raised = False
    try:
        _ = context.get_dialect_descriptor("not_existing")
    except ValueError:
        raised = True
    assert raised, "Expected exception"
329f3f6d7bSStella Laurenzo
339f3f6d7bSStella Laurenzo
# CHECK-LABEL: TEST: testUserDialectClass
@run
def testUserDialectClass():
    """Fetch the ``func`` dialect via attribute, index and the ``d`` alias."""
    context = Context()

    # Access using attribute.
    dialect = context.dialects.func
    # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
    print(dialect)
    # An unknown attribute name is expected to raise AttributeError.
    attribute_lookup_failed = False
    try:
        _ = context.dialects.not_existing
    except AttributeError:
        attribute_lookup_failed = True
    assert attribute_lookup_failed, "Expected exception"

    # Access using index.
    dialect = context.dialects["func"]
    # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
    print(dialect)
    # An unknown index key is expected to raise IndexError.
    index_lookup_failed = False
    try:
        _ = context.dialects["not_existing"]
    except IndexError:
        index_lookup_failed = True
    assert index_lookup_failed, "Expected exception"

    # Using the 'd' alias.
    dialect = context.d["func"]
    # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
    print(dialect)
649f3f6d7bSStella Laurenzo
659f3f6d7bSStella Laurenzo
# CHECK-LABEL: TEST: testCustomOpView
# This test uses the arith dialect AddFOp as an example of a user op.
# TODO: Op creation and access is still quite verbose: simplify this test as
# additional capabilities come online.
@run
def testCustomOpView():
    """Create two ``AddFOp`` ops via different entry points and print the module."""

    def make_input():
        # `f32` comes from the enclosing function scope and is looked up at
        # call time; it is assigned below before the first call.
        producer = Operation.create("pytest_dummy.intinput", results=[f32])
        # TODO: Auto result cast from operation
        return producer.results[0]

    with Context() as ctx, Location.unknown():
        # The dummy input ops belong to a dialect that is not registered.
        ctx.allow_unregistered_dialects = True
        module = Module.create()

        with InsertionPoint(module.body):
            f32 = F32Type.get()
            # Create via dialects context collection.
            lhs = make_input()
            rhs = make_input()
            first_add = ctx.dialects.arith.AddFOp(lhs, rhs)

            # Create via an import
            from mlir.dialects.arith import AddFOp

            AddFOp(lhs, first_add.result)

    # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
    # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
    # CHECK: %[[R0:.*]] = arith.addf %[[INPUT0]], %[[INPUT1]] : f32
    # CHECK: %[[R1:.*]] = arith.addf %[[INPUT0]], %[[R0]] : f32
    module.operation.print()
989f3f6d7bSStella Laurenzo
999f3f6d7bSStella Laurenzo
# CHECK-LABEL: TEST: testIsRegisteredOperation
@run
def testIsRegisteredOperation():
    """Print ``is_registered_operation`` for one op expected to be registered and one not."""
    context = Context()

    cond_br_registered = context.is_registered_operation("cf.cond_br")
    # CHECK: cf.cond_br: True
    print(f"cf.cond_br: {cond_br_registered}")

    not_existing_registered = context.is_registered_operation("func.not_existing")
    # CHECK: func.not_existing: False
    print(f"func.not_existing: {not_existing_registered}")
1095192e299SMaksim Levental
1105192e299SMaksim Levental
# CHECK-LABEL: TEST: testAppendPrefixSearchPath
@run
def testAppendPrefixSearchPath():
    """Check the ``custom`` dialect module reports loaded only after a search prefix is appended."""
    context = Context()
    context.allow_unregistered_dialects = True
    cext_globals = _cext.globals
    with Location.unknown(context):
        # Creating an op in the "custom" namespace is expected to leave the
        # dialect module reported as not loaded.
        assert not cext_globals._check_dialect_module_loaded("custom")
        Operation.create("custom.op")
        assert not cext_globals._check_dialect_module_loaded("custom")

        # Put the current directory on the import path, then add the prefix.
        # NOTE(review): presumably a `custom_dialect` package is importable
        # from "." when the test runs — confirm against the test inputs.
        sys.path.append(".")
        cext_globals.append_dialect_search_prefix("custom_dialect")
        assert cext_globals._check_dialect_module_loaded("custom")
124*c703b464SJacques Pienaar
125*c703b464SJacques Pienaar
# CHECK-LABEL: TEST: testDialectLoadOnCreate
@run
def testDialectLoadOnCreate():
    """Exercise the per-context and global "load on create" dialect lists."""
    with Context(load_on_create_dialects=[]) as ctx:
        ctx.emit_error_diagnostics = True
        ctx.allow_unregistered_dialects = True

        def on_diagnostic(diag):
            # CHECK: DIAGNOSTIC
            # CHECK-SAME: op created with unregistered dialect
            message = diag.message
            print(f"DIAGNOSTIC={message}")
            return True

        handler = ctx.attach_diagnostic_handler(on_diagnostic)
        loc = Location.unknown(ctx)
        try:
            op = Operation.create("arith.addi", loc=loc)
            # Unregistered dialects are disallowed only after creation, so the
            # verify call below is the step expected to report the diagnostic.
            ctx.allow_unregistered_dialects = False
            op.verify()
        except MLIRError:
            # The expected outcome is checked through the printed diagnostic;
            # an MLIRError raised here is deliberately ignored.
            pass

    with Context(load_on_create_dialects=["func"]) as ctx:
        loc = Location.unknown(ctx)
        fn = Operation.create("func.func", loc=loc)

    # TODO: This may require an update if a site wide policy is set.
    # CHECK: Load on create: []
    before_append = get_load_on_create_dialects()
    print(f"Load on create: {before_append}")
    append_load_on_create_dialect("func")
    # CHECK: Load on create:
    # CHECK-SAME: func
    after_append = get_load_on_create_dialects()
    print(f"Load on create: {after_append}")
    print(get_load_on_create_dialects())
160