# RUN: %PYTHON %s | FileCheck %s

import gc
import sys
from mlir.ir import *
from mlir.dialects._ods_common import _cext


def run(f):
    """Run test function `f` immediately (used as a decorator) and return it.

    Prints a "TEST:" label line followed by whatever `f` prints, then forces a
    garbage-collection pass and requires that no `Context` object is still
    alive, i.e. every test must drop all references to contexts it created.
    """
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    # A surviving Context here means the test leaked a reference to it.
    live_count = Context._get_live_count()
    assert live_count == 0, f"{f.__name__} leaked {live_count} Context object(s)"
    return f


# CHECK-LABEL: TEST: testDialectDescriptor
@run
def testDialectDescriptor():
    """Look up a dialect descriptor by namespace; unknown names must raise ValueError."""
    ctx = Context()
    d = ctx.get_dialect_descriptor("func")
    # CHECK: <DialectDescriptor func>
    print(d)
    # CHECK: func
    print(d.namespace)
    try:
        _ = ctx.get_dialect_descriptor("not_existing")
    except ValueError:
        pass
    else:
        # Explicit raise rather than `assert False`: assert statements are
        # removed under `python -O`, which would let this check pass silently.
        raise AssertionError("Expected exception")


# CHECK-LABEL: TEST: testUserDialectClass
@run
def testUserDialectClass():
    """Access dialect wrapper classes via attribute, index, and the `d` alias.

    A missing dialect must raise AttributeError for attribute access and
    IndexError for index access.
    """
    ctx = Context()
    # Access using attribute.
    d = ctx.dialects.func
    # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
    print(d)
    try:
        _ = ctx.dialects.not_existing
    except AttributeError:
        pass
    else:
        raise AssertionError("Expected exception")

    # Access using index.
    d = ctx.dialects["func"]
    # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
    print(d)
    try:
        _ = ctx.dialects["not_existing"]
    except IndexError:
        pass
    else:
        raise AssertionError("Expected exception")

    # Using the 'd' alias.
    d = ctx.d["func"]
    # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
    print(d)


# CHECK-LABEL: TEST: testCustomOpView
# This test uses the arith dialect's AddFOp as an example of a user op.
# TODO: Op creation and access is still quite verbose: simplify this test as
# additional capabilities come online.
@run
def testCustomOpView():
    """Build arith.addf ops both through ctx.dialects and through a direct import."""

    def createInput(result_type):
        # Unregistered placeholder op; its single result feeds the adds below.
        op = Operation.create("pytest_dummy.intinput", results=[result_type])
        # TODO: Auto result cast from operation
        return op.results[0]

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()

        with InsertionPoint(m.body):
            f32 = F32Type.get()
            # Create via dialects context collection.
            input1 = createInput(f32)
            input2 = createInput(f32)
            op1 = ctx.dialects.arith.AddFOp(input1, input2)

            # Create via an import
            from mlir.dialects.arith import AddFOp

            AddFOp(input1, op1.result)

        # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
        # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
        # CHECK: %[[R0:.*]] = arith.addf %[[INPUT0]], %[[INPUT1]] : f32
        # CHECK: %[[R1:.*]] = arith.addf %[[INPUT0]], %[[R0]] : f32
        m.operation.print()


# CHECK-LABEL: TEST: testIsRegisteredOperation
@run
def testIsRegisteredOperation():
    """Report whether a known and an unknown op name are registered."""
    ctx = Context()

    # CHECK: cf.cond_br: True
    print(f"cf.cond_br: {ctx.is_registered_operation('cf.cond_br')}")
    # CHECK: func.not_existing: False
    print(f"func.not_existing: {ctx.is_registered_operation('func.not_existing')}")


# CHECK-LABEL: TEST: testAppendPrefixSearchPath
@run
def testAppendPrefixSearchPath():
    """The "custom" dialect module loads only after its search prefix is appended.

    Creating a "custom.op" alone does not load it. Appending the
    "custom_dialect" prefix (with "." on sys.path) does.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        assert not _cext.globals._check_dialect_module_loaded("custom")
        Operation.create("custom.op")
        assert not _cext.globals._check_dialect_module_loaded("custom")

    # "." is appended so the "custom_dialect" package can be imported; this
    # presumably relies on the test running from a directory that contains it.
    # sys.path is restored afterwards so this test does not leak import state.
    saved_sys_path = list(sys.path)
    sys.path.append(".")
    try:
        _cext.globals.append_dialect_search_prefix("custom_dialect")
        assert _cext.globals._check_dialect_module_loaded("custom")
    finally:
        sys.path[:] = saved_sys_path


# CHECK-LABEL: TEST: testDialectLoadOnCreate
@run
def testDialectLoadOnCreate():
    """Exercise the `load_on_create_dialects` Context option and its global list."""
    with Context(load_on_create_dialects=[]) as ctx:
        ctx.emit_error_diagnostics = True
        ctx.allow_unregistered_dialects = True

        def callback(d):
            # CHECK: DIAGNOSTIC
            # CHECK-SAME: op created with unregistered dialect
            print(f"DIAGNOSTIC={d.message}")
            return True

        # NOTE(review): the handler is never detached explicitly; presumably it
        # is released together with the context -- confirm.
        handler = ctx.attach_diagnostic_handler(callback)
        loc = Location.unknown(ctx)
        try:
            # No dialects are loaded on create here, so "arith.addi" is built
            # as an op of an unregistered dialect.
            op = Operation.create("arith.addi", loc=loc)
            ctx.allow_unregistered_dialects = False
            op.verify()
        except MLIRError:
            # Expected: verification is presumed to fail once unregistered
            # dialects are disallowed. The diagnostic printed by `callback`
            # is what the test output is matched against.
            pass

    with Context(load_on_create_dialects=["func"]) as ctx:
        loc = Location.unknown(ctx)
        # Relies on "func" being loaded when this context is created.
        fn = Operation.create("func.func", loc=loc)

    # TODO: This may require an update if a site wide policy is set.
    # CHECK: Load on create: []
    print(f"Load on create: {get_load_on_create_dialects()}")
    # NOTE(review): this mutates process-wide state that is not reset here;
    # keep this test last, or restore the list if more tests are added after it.
    append_load_on_create_dialect("func")
    # CHECK: Load on create:
    # CHECK-SAME: func
    print(f"Load on create: {get_load_on_create_dialects()}")