xref: /llvm-project/mlir/test/python/ir/diagnostic_handler.py (revision 8934b10642664c0824f45f115b2a0afcb56a5e5f)
1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
6
def run(f):
    """Run test function ``f`` and verify that no Context outlives it.

    Prints a ``TEST: <function name>`` banner preceded by a blank line, calls
    ``f`` with no arguments, forces a garbage collection and then asserts
    that ``Context._get_live_count()`` is zero.

    Returns ``f`` unchanged, so this can be applied as a decorator.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Collect first so that contexts which are merely unreachable (but not
    # yet reclaimed) are not counted as live.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
    return f
13
14
@run
def testLifecycleContextDestroy():
    """A handler that outlives its context reports itself as detached."""

    def ignore_diagnostic(_diag):
        """Placeholder callback; this test never emits a diagnostic."""

    context = Context()
    handler = context.attach_diagnostic_handler(ignore_diagnostic)
    assert handler.attached

    # Release the context while the handler object is still referenced; the
    # handler is expected to auto-detach once the context is gone.
    context = None
    gc.collect()
    assert not handler.attached

    # Releasing the handler afterwards must not cause any trouble either.
    handler = None
    gc.collect()
32
33
@run
def testLifecycleExplicitDetach():
    """An explicit ``detach()`` call clears the handler's ``attached`` flag."""

    def ignore_diagnostic(_diag):
        """Placeholder callback; this test never emits a diagnostic."""

    # The context is bound to a local so it stays alive for the whole test.
    context = Context()
    handler = context.attach_diagnostic_handler(ignore_diagnostic)
    assert handler.attached

    handler.detach()
    assert not handler.attached
45
46
@run
def testLifecycleWith():
    """Leaving a ``with`` block detaches the handler it manages."""

    def ignore_diagnostic(_diag):
        """Placeholder callback; this test never emits a diagnostic."""

    context = Context()
    with context.attach_diagnostic_handler(ignore_diagnostic) as handler:
        # Inside the block the handler is still registered.
        assert handler.attached

    # After the block has exited the handler must be detached.
    assert not handler.attached
57
58
@run
def testLifecycleWithAndExplicitDetach():
    """Detaching inside a ``with`` block, then exiting it, is tolerated."""

    def ignore_diagnostic(_diag):
        """Placeholder callback; this test never emits a diagnostic."""

    context = Context()
    with context.attach_diagnostic_handler(ignore_diagnostic) as handler:
        assert handler.attached
        # Detach manually before the block's own exit handling runs.
        handler.detach()

    assert not handler.attached
70
71
72# CHECK-LABEL: TEST: testDiagnosticCallback
@run
def testDiagnosticCallback():
    """An emitted error reaches the attached callback with its fields set.

    The callback prints the diagnostic's message, severity and location; the
    FileCheck directive inside it pins the expected output line.
    """

    def report(d):
        # CHECK: DIAGNOSTIC: message='foobar', severity=DiagnosticSeverity.ERROR, loc=loc(unknown)
        print(
            f"DIAGNOSTIC: message='{d.message}', severity={d.severity}, loc={d.location}"
        )
        return True

    context = Context()
    handler = context.attach_diagnostic_handler(report)
    Location.unknown(context).emit_error("foobar")

    # The callback returned normally, so no handler error may be recorded.
    assert not handler.had_error
88
89
90# CHECK-LABEL: TEST: testDiagnosticEmptyNotes
91# TODO: Come up with a way to inject a diagnostic with notes from this API.
@run
def testDiagnosticEmptyNotes():
    """A diagnostic emitted via ``emit_error`` prints ``notes`` as ``()``."""

    def report_notes(d):
        # CHECK: DIAGNOSTIC: notes=()
        print(f"DIAGNOSTIC: notes={d.notes}")
        return True

    context = Context()
    handler = context.attach_diagnostic_handler(report_notes)
    Location.unknown(context).emit_error("foobar")

    # The callback returned normally, so no handler error may be recorded.
    assert not handler.had_error
105
106
107# CHECK-LABEL: TEST: testDiagnosticNonEmptyNotes
@run
def testDiagnosticNonEmptyNotes():
    """Verifying a malformed op yields a diagnostic that carries notes.

    An ``arith.addi`` operation is created without operands or results and
    verified with ``emit_error_diagnostics`` enabled on the context. The
    callback prints the diagnostic's message and the string form of each
    note; the FileCheck directives inside it pin the expected output.
    """
    ctx = Context()
    ctx.emit_error_diagnostics = True

    def callback(d):
        # CHECK: DIAGNOSTIC:
        # CHECK:   message='arith.addi' op requires one result
        # CHECK:   notes=['see current operation: "arith.addi"() {{.*}} : () -> ()']
        print("DIAGNOSTIC:")
        print(f"  message={d.message}")
        print(f"  notes={list(map(str, d.notes))}")
        return True

    handler = ctx.attach_diagnostic_handler(callback)
    loc = Location.unknown(ctx)
    try:
        Operation.create("arith.addi", loc=loc).verify()
    except MLIRError:
        # Deliberately ignored: this test only inspects what the handler
        # printed, not whether verification raised.
        pass
    assert not handler.had_error
129
130
131# CHECK-LABEL: TEST: testDiagnosticCallbackException
@run
def testDiagnosticCallbackException():
    """An exception raised inside the callback sets ``handler.had_error``."""

    def failing_callback(d):
        raise ValueError("Error in handler")

    context = Context()
    handler = context.attach_diagnostic_handler(failing_callback)
    Location.unknown(context).emit_error("foobar")

    # The ValueError must not propagate out of emit_error; it is expected to
    # be recorded on the handler instead.
    assert handler.had_error
143
144
145# CHECK-LABEL: TEST: testEscapingDiagnostic
@run
def testEscapingDiagnostic():
    """A diagnostic kept beyond its callback becomes invalid.

    The callback stores the diagnostic object in a list. Once the callback
    has returned, the stored object must stringify as
    ``<Invalid Diagnostic>`` and each of its ``severity``, ``location``,
    ``message`` and ``notes`` attributes must raise ``ValueError``.
    """
    diags = []

    def keep(d):
        diags.append(d)
        return True

    context = Context()
    handler = context.attach_diagnostic_handler(keep)
    Location.unknown(context).emit_error("foobar")
    assert not handler.had_error

    # CHECK: DIAGNOSTIC: <Invalid Diagnostic>
    print(f"DIAGNOSTIC: {str(diags[0])}")

    # Same order as the attributes are probed one by one: each access has to
    # fail with ValueError, otherwise the test aborts with RuntimeError.
    for attribute in ("severity", "location", "message", "notes"):
        try:
            getattr(diags[0], attribute)
        except ValueError:
            continue
        raise RuntimeError("expected exception")
182
183
184# CHECK-LABEL: TEST: testDiagnosticReturnTrueHandles
@run
def testDiagnosticReturnTrueHandles():
    """A callback returning True stops the diagnostic from travelling on.

    Two callbacks are attached to the same context. The expected output has
    only the later-attached one printing; the earlier one must stay silent.
    """

    def attached_first(d):
        print(f"CALLBACK1: {d}")
        return True

    def attached_second(d):
        print(f"CALLBACK2: {d}")
        return True

    context = Context()
    for handler_fn in (attached_first, attached_second):
        context.attach_diagnostic_handler(handler_fn)

    unknown_loc = Location.unknown(context)
    # CHECK-NOT: CALLBACK1
    # CHECK: CALLBACK2: foobar
    # CHECK-NOT: CALLBACK1
    unknown_loc.emit_error("foobar")
204
205
206# CHECK-LABEL: TEST: testDiagnosticReturnFalseDoesNotHandle
@run
def testDiagnosticReturnFalseDoesNotHandle():
    """A callback returning False lets the diagnostic reach the next one.

    Two callbacks are attached to the same context. The later-attached one
    returns False, and the expected output shows it printing first, followed
    by the earlier-attached one.
    """

    def attached_first(d):
        print(f"CALLBACK1: {d}")
        return True

    def attached_second(d):
        print(f"CALLBACK2: {d}")
        return False

    context = Context()
    for handler_fn in (attached_first, attached_second):
        context.attach_diagnostic_handler(handler_fn)

    unknown_loc = Location.unknown(context)
    # CHECK: CALLBACK2: foobar
    # CHECK: CALLBACK1: foobar
    unknown_loc.emit_error("foobar")
225