# RUN: %PYTHON %s | FileCheck %s

import gc
import sys
from mlir.ir import *
from mlir.dialects._ods_common import _cext


def run(f):
    """Decorator that executes test `f` at definition time.

    Prints the "TEST: <name>" banner that the CHECK-LABEL directives anchor
    on, calls `f`, forces a garbage collection and then asserts that
    `Context._get_live_count()` is zero (i.e. the test left no live
    contexts behind). Returns `f` unchanged.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    gc.collect()
    assert Context._get_live_count() == 0
    return f


def _expect_raises(error_type, action):
    """Call `action()` and fail the test unless it raises `error_type`.

    Exceptions of any other type propagate unchanged.
    """
    try:
        action()
    except error_type:
        return
    assert False, "Expected exception"


# CHECK-LABEL: TEST: testDialectDescriptor
@run
def testDialectDescriptor():
    ctx = Context()
    descriptor = ctx.get_dialect_descriptor("func")
    # CHECK: <DialectDescriptor func>
    print(descriptor)
    # CHECK: func
    print(descriptor.namespace)
    # Looking up an unknown dialect name is rejected with ValueError.
    _expect_raises(ValueError, lambda: ctx.get_dialect_descriptor("not_existing"))


# CHECK-LABEL: TEST: testUserDialectClass
@run
def testUserDialectClass():
    ctx = Context()

    # Attribute-style lookup; unknown names raise AttributeError.
    dialect = ctx.dialects.func
    # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
    print(dialect)
    _expect_raises(AttributeError, lambda: ctx.dialects.not_existing)

    # Subscript lookup; unknown names raise IndexError.
    dialect = ctx.dialects["func"]
    # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
    print(dialect)
    _expect_raises(IndexError, lambda: ctx.dialects["not_existing"])

    # Subscript lookup through the short `d` alias.
    # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
    print(ctx.d["func"])


# CHECK-LABEL: TEST: testCustomOpView
# This test uses the arith dialect's AddFOp as an example of a user op.
# TODO: Op creation and access is still quite verbose: simplify this test as
# additional capabilities come online.
@run
def testCustomOpView():
    """Create arith.addf ops via `ctx.dialects` and via a direct import."""

    def createInput():
        # Creates an unregistered dummy op with a single f32 result. `f32` is
        # a free variable resolved at call time, so this must only be called
        # after `f32` is bound below.
        op = Operation.create("pytest_dummy.intinput", results=[f32])
        # TODO: Auto result cast from operation
        return op.results[0]

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()

        with InsertionPoint(m.body):
            f32 = F32Type.get()
            # Create via dialects context collection.
            input1 = createInput()
            input2 = createInput()
            op1 = ctx.dialects.arith.AddFOp(input1, input2)

            # Create via an import
            from mlir.dialects.arith import AddFOp

            AddFOp(input1, op1.result)

        # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
        # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
        # CHECK: %[[R0:.*]] = arith.addf %[[INPUT0]], %[[INPUT1]] : f32
        # CHECK: %[[R1:.*]] = arith.addf %[[INPUT0]], %[[R0]] : f32
        m.operation.print()


# CHECK-LABEL: TEST: testIsRegisteredOperation
@run
def testIsRegisteredOperation():
    """Query operation registration for a known and an unknown op name."""
    ctx = Context()

    # CHECK: cf.cond_br: True
    print(f"cf.cond_br: {ctx.is_registered_operation('cf.cond_br')}")
    # CHECK: func.not_existing: False
    print(f"func.not_existing: {ctx.is_registered_operation('func.not_existing')}")


# CHECK-LABEL: TEST: testAppendPrefixSearchPath
@run
def testAppendPrefixSearchPath():
    """Check a dialect module is only found after its search prefix is added."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        assert not _cext.globals._check_dialect_module_loaded("custom")
        Operation.create("custom.op")
        assert not _cext.globals._check_dialect_module_loaded("custom")

    # NOTE(review): assumes the `custom_dialect` package is importable from
    # the current working directory when this test runs.
    sys.path.append(".")
    try:
        _cext.globals.append_dialect_search_prefix("custom_dialect")
        assert _cext.globals._check_dialect_module_loaded("custom")
    finally:
        # Undo the sys.path change so it does not leak into later code in
        # this process.
        sys.path.remove(".")


# CHECK-LABEL: TEST: testDialectLoadOnCreate
@run
def testDialectLoadOnCreate():
    """Exercise `load_on_create_dialects` on Context and the global list."""
    # No dialects are loaded on create, so `arith.addi` is created as an op of
    # an unloaded dialect. Once unregistered dialects are disallowed,
    # verification must fail and the diagnostic must reach `callback`.
    with Context(load_on_create_dialects=[]) as ctx:
        # Presumably makes error diagnostics reach the attached handler
        # (FileCheck below expects `callback` to see the error).
        ctx.emit_error_diagnostics = True
        ctx.allow_unregistered_dialects = True

        def callback(d):
            # CHECK: DIAGNOSTIC
            # CHECK-SAME: op created with unregistered dialect
            print(f"DIAGNOSTIC={d.message}")
            return True

        handler = ctx.attach_diagnostic_handler(callback)
        loc = Location.unknown(ctx)
        op = Operation.create("arith.addi", loc=loc)
        ctx.allow_unregistered_dialects = False
        # Only verification is expected to fail; fail loudly if it does not.
        try:
            op.verify()
        except MLIRError:
            pass
        else:
            assert False, "Expected exception"

    # With `func` loaded on create, `func.func` can be created without
    # allowing unregistered dialects.
    with Context(load_on_create_dialects=["func"]) as ctx:
        Operation.create("func.func", loc=Location.unknown(ctx))

    # TODO: This may require an update if a site wide policy is set.
    # CHECK: Load on create: []
    print(f"Load on create: {get_load_on_create_dialects()}")
    append_load_on_create_dialect("func")
    # CHECK: Load on create:
    # CHECK-SAME: func
    print(f"Load on create: {get_load_on_create_dialects()}")