xref: /llvm-project/mlir/test/python/pass_manager.py (revision 2e51e150e161bd5fb5b8adb8655744a672ced002)
1# RUN: %PYTHON %s 2>&1 | FileCheck %s
2
3import gc, os, sys, tempfile
4from mlir.ir import *
5from mlir.passmanager import *
6from mlir.dialects.func import FuncOp
7from mlir.dialects.builtin import ModuleOp
8
9
# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
    """Print ``args`` (space-separated) to stderr and flush right away.

    Flushing on every call keeps this script's output correctly interleaved
    with the diagnostics MLIR writes to the same stream.
    """
    print(*args, file=sys.stderr, flush=True)
15
16
def run(f):
    """Execute one test function and verify that no MLIR Context leaked.

    Logs a ``TEST: <name>`` header (the labels in this file match on it), calls
    ``f``, forces a garbage collection and asserts that the live Context count
    is zero. Returns ``None``, so a function decorated with ``@run`` is rebound
    to ``None`` once it has executed.
    """
    test_name = f.__name__
    log("\nTEST:", test_name)
    f()
    # Free objects only reachable through reference cycles before counting.
    gc.collect()
    assert Context._get_live_count() == 0
22
23
# Verify capsule interop.
# CHECK-LABEL: TEST: testCapsule
def testCapsule():
    """Round-trip a PassManager through its C-API capsule."""
    with Context():
        original = PassManager()
        capsule = original._CAPIPtr
        assert '"mlir.passmanager.PassManager._CAPIPtr"' in repr(capsule)
        # NOTE(review): presumably detaches the Python wrapper from the native
        # pass manager so the wrapper re-created below can take it over.
        original._testing_release()
        adopted = PassManager._CAPICreate(capsule)
        assert adopted is not None  # And does not crash.


run(testCapsule)
37
38
# CHECK-LABEL: TEST: testConstruct
@run
def testConstruct():
    """Print the textual pipeline of a default and an anchored PassManager."""
    with Context():
        # CHECK: pm1: 'any()'
        # CHECK: pm2: 'builtin.module()'
        default_pm = PassManager()
        module_pm = PassManager("builtin.module")
        for label, manager in (("pm1", default_pm), ("pm2", module_pm)):
            log(f"{label}: '{manager}'")
49
50
# Verify that an unregistered pass is rejected and a registered one round-trips.
# CHECK-LABEL: TEST: testParseSuccess
def testParseSuccess():
    """Reject an unregistered pass, then round-trip a registered pipeline."""
    with Context():
        # An unregistered pass should not parse.
        try:
            PassManager.parse(
                "builtin.module(func.func(not-existing-pass{json=false}))"
            )
        except ValueError as err:
            # CHECK: ValueError exception: {{.+}} 'not-existing-pass' does not refer to a registered pass
            log("ValueError exception:", err)
        else:
            log("Exception not produced")

        # A registered pass should parse successfully.
        pipeline_text = "builtin.module(func.func(print-op-stats{json=false}))"
        parsed = PassManager.parse(pipeline_text)
        # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
        log("Roundtrip: ", parsed)


run(testParseSuccess)
73
74
# Verify successful round-trip.
# CHECK-LABEL: TEST: testParseSpacedPipeline
def testParseSpacedPipeline():
    """Whitespace inside a textual pipeline must not change how it parses."""
    spaced_pipeline = """builtin.module(
        func.func( print-op-stats{ json=false } )
    )"""
    with Context():
        # A registered pass should parse successfully even if it has extra
        # spaces for readability.
        parsed = PassManager.parse(spaced_pipeline)
        # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
        log("Roundtrip: ", parsed)


run(testParseSpacedPipeline)
90
91
# Verify failure on unregistered pass.
# CHECK-LABEL: TEST: testParseFail
def testParseFail():
    """Parsing an unknown pass name raises a ValueError with a caret diagnostic."""
    bad_pipeline = "any(unknown-pass)"
    with Context():
        try:
            PassManager.parse(bad_pipeline)
        except ValueError as err:
            #      CHECK: ValueError exception: MLIR Textual PassPipeline Parser:1:1: error:
            # CHECK-SAME: 'unknown-pass' does not refer to a registered pass or pass pipeline
            #      CHECK: unknown-pass
            #      CHECK: ^
            log("ValueError exception:", err)
        else:
            log("Exception not produced")


run(testParseFail)
109
110
# Check that adding to a pass manager works
# CHECK-LABEL: TEST: testAdd
@run
def testAdd():
    """Appending passes one at a time extends the printed pipeline."""
    pm = PassManager("any", Context())
    # CHECK: pm: 'any()'
    # CHECK: pm: 'any(cse)'
    # CHECK: pm: 'any(cse,cse)'
    log(f"pm: '{pm}'")
    for _ in range(2):
        pm.add("cse")
        log(f"pm: '{pm}'")
124
125
# Verify failure on incorrect level of nesting.
# CHECK-LABEL: TEST: testInvalidNesting
def testInvalidNesting():
    """A builtin.module-only pass cannot be placed directly under func.func."""
    misnested = "func.func(normalize-memrefs)"
    with Context():
        try:
            PassManager.parse(misnested)
        except ValueError as err:
            # CHECK: ValueError exception: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?
            log("ValueError exception:", err)
        else:
            log("Exception not produced")


run(testInvalidNesting)
140
141
# Verify that a pass manager can execute on IR
# CHECK-LABEL: TEST: testRunPipeline
def testRunPipeline():
    """Run the op-statistics pass over a freshly parsed function."""
    with Context():
        stats_pm = PassManager.parse("any(print-op-stats{json=false})")
        func_op = FuncOp.parse(r"""func.func @successfulParse() { return }""")
        stats_pm.run(func_op)


# Expected statistics emitted by print-op-stats while the test executes.
# CHECK: Operations encountered:
# CHECK: func.func      , 1
# CHECK: func.return        , 1
run(testRunPipeline)
155
156
# CHECK-LABEL: TEST: testRunPipelineError
@run
def testRunPipelineError():
    """Running a pipeline on an unregistered op raises MLIRError with diagnostics."""
    with Context() as ctx:
        # Required so the unregistered "test.op" below can be parsed at all.
        ctx.allow_unregistered_dialects = True
        op = Operation.parse('"test.op"() : () -> ()')
        pm = PassManager.parse("any(cse)")
        try:
            pm.run(op)
        except MLIRError as e:
            # CHECK: Exception: <
            # CHECK:   Failure while executing pass pipeline:
            # CHECK:   error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
            # CHECK:    note: "-":1:1: see current operation: "test.op"() : () -> ()
            # CHECK: >
            log(f"Exception: <{e}>")
        else:
            # Report a missing error explicitly, like the other negative tests
            # in this file, so the failure is obvious in the logged output.
            log("Exception not produced")
173
174
# CHECK-LABEL: TEST: testPostPassOpInvalidation
@run
def testPostPassOpInvalidation():
    """Check which Python operation handles survive running a pass pipeline.

    With ``invalidate_ops=False`` handles obtained before the run stay usable
    and reflect the canonicalized IR. With a plain ``run(module)`` only the
    root module handle stays live; using any other handle raises RuntimeError.
    """
    # The same IR is parsed for both phases: one constant at module scope and
    # one inside @foo.
    module_asm = """
          module {
            arith.constant 10
            func.func @foo() {
              arith.constant 10
              return
            }
          }
        """
    with Context() as ctx:

        def log_op_count():
            log("live ops:", ctx._get_live_operation_count())

        # CHECK: invalidate_ops=False
        log("invalidate_ops=False")

        # CHECK: live ops: 0
        log_op_count()

        module = ModuleOp.parse(module_asm)

        # CHECK: live ops: 1
        log_op_count()

        outer_const_op = module.body.operations[0]
        # CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
        log(outer_const_op)

        func_op = module.body.operations[1]
        # CHECK: func.func @[[FOO:.*]]() {
        # CHECK:   %[[VAL1:.*]] = arith.constant 10 : i64
        # CHECK:   return
        # CHECK: }
        log(func_op)

        inner_const_op = func_op.body.blocks[0].operations[0]
        # CHECK: %[[VAL1]] = arith.constant 10 : i64
        log(inner_const_op)

        # CHECK: live ops: 4
        log_op_count()

        PassManager.parse("builtin.module(canonicalize)").run(
            module, invalidate_ops=False
        )
        # CHECK: func.func @foo() {
        # CHECK:   return
        # CHECK: }
        log(func_op)

        # CHECK: func.func @foo() {
        # CHECK:   return
        # CHECK: }
        log(module)

        # CHECK: invalidate_ops=True
        log("invalidate_ops=True")

        # The handles from the first phase are still referenced here.
        # CHECK: live ops: 4
        log_op_count()

        module = ModuleOp.parse(module_asm)
        outer_const_op = module.body.operations[0]
        func_op = module.body.operations[1]
        inner_const_op = func_op.body.blocks[0].operations[0]

        # CHECK: live ops: 4
        log_op_count()

        PassManager.parse("builtin.module(canonicalize)").run(module)

        # CHECK: live ops: 1
        log_op_count()

        # Every non-root handle must now refuse to be used.
        # CHECK: the operation has been invalidated
        # CHECK: the operation has been invalidated
        # CHECK: the operation has been invalidated
        for stale_op in (func_op, outer_const_op, inner_const_op):
            try:
                log(stale_op)
            except RuntimeError as e:
                log(e)
            else:
                log("Exception not produced")

        # CHECK: func.func @foo() {
        # CHECK:   return
        # CHECK: }
        log(module)
284
285
# CHECK-LABEL: TEST: testPrintIrAfterAll
@run
def testPrintIrAfterAll():
    """With default IR-printing options the module is dumped after each pass."""
    main_module_asm = """
          module {
            func.func @main() {
              %0 = arith.constant 10
              return
            }
          }
        """
    with Context() as context:
        module = ModuleOp.parse(main_module_asm)
        pipeline = PassManager.parse("builtin.module(canonicalize)")
        # NOTE(review): disabled before enabling IR printing, presumably so the
        # dumps come out in a deterministic order — confirm.
        context.enable_multithreading(False)
        pipeline.enable_ir_printing()
        # CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- //
        # CHECK: module {
        # CHECK:   func.func @main() {
        # CHECK:     return
        # CHECK:   }
        # CHECK: }
        pipeline.run(module)
310
311
# CHECK-LABEL: TEST: testPrintIrBeforeAndAfterAll
@run
def testPrintIrBeforeAndAfterAll():
    """Dump the module both before and after the canonicalizer runs."""
    main_module_asm = """
          module {
            func.func @main() {
              %0 = arith.constant 10
              return
            }
          }
        """
    with Context() as context:
        module = ModuleOp.parse(main_module_asm)
        pipeline = PassManager.parse("builtin.module(canonicalize)")
        context.enable_multithreading(False)
        pipeline.enable_ir_printing(print_after_all=True, print_before_all=True)
        # The "before" dump still holds the constant; the "after" dump does not.
        # CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) //----- //
        # CHECK: module {
        # CHECK:   func.func @main() {
        # CHECK:     %[[C10:.*]] = arith.constant 10 : i64
        # CHECK:     return
        # CHECK:   }
        # CHECK: }
        # CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- //
        # CHECK: module {
        # CHECK:   func.func @main() {
        # CHECK:     return
        # CHECK:   }
        # CHECK: }
        pipeline.run(module)
343
344
# CHECK-LABEL: TEST: testPrintIrLargeLimitElements
@run
def testPrintIrLargeLimitElements():
    """Dense constants above ``large_elements_limit`` are elided in IR dumps."""
    tensor_module_asm = """
          module {
            func.func @main() -> tensor<3xi64> {
              %0 = arith.constant dense<[1, 2, 3]> : tensor<3xi64>
              return %0 : tensor<3xi64>
            }
          }
        """
    with Context() as context:
        module = ModuleOp.parse(tensor_module_asm)
        pipeline = PassManager.parse("builtin.module(canonicalize)")
        context.enable_multithreading(False)
        # The 3-element constant exceeds a limit of 2, so its payload is elided.
        pipeline.enable_ir_printing(large_elements_limit=2)
        # CHECK:     %[[CST:.*]] = arith.constant dense_resource<__elided__> : tensor<3xi64>
        pipeline.run(module)
364
365
# CHECK-LABEL: TEST: testPrintIrTree
@run
def testPrintIrTree():
    """Dump IR into a per-pass directory tree and log the resulting layout.

    The pipeline first runs once with plain IR printing. It is then rerun with
    ``tree_printing_dir_path`` set to a temporary directory, and every file and
    directory created there is logged between the begin/end marker lines.
    """

    def print_file_tree(directory, prefix=""):
        # Log the entries of ``directory`` recursively in sorted order (so the
        # output is deterministic). The last entry at each level gets a
        # backslash connector, every other entry gets "|-- ".
        entries = sorted(os.listdir(directory))
        last_index = len(entries) - 1
        for index, entry in enumerate(entries):
            is_last = index == last_index
            # Escaped backslash: a bare "\-" is an invalid escape sequence,
            # which Python warns about (SyntaxWarning since 3.12).
            connector = "\\-- " if is_last else "|-- "
            log(f"{prefix}{connector}{entry}")
            path = os.path.join(directory, entry)
            if os.path.isdir(path):
                # Plain ASCII guide, consistent with the "|-- " connector and
                # encodable regardless of the stderr encoding.
                child_prefix = prefix + ("    " if is_last else "|   ")
                print_file_tree(path, child_prefix)

    with Context() as ctx:
        module = ModuleOp.parse(
            """
          module {
            func.func @main() {
              %0 = arith.constant 10
              return
            }
          }
        """
        )
        pm = PassManager.parse("builtin.module(canonicalize)")
        ctx.enable_multithreading(False)
        pm.enable_ir_printing()
        # CHECK-LABEL: // Tree printing begin
        # CHECK: \-- builtin_module_no-symbol-name
        # CHECK:     \-- 0_canonicalize.mlir
        # CHECK-LABEL: // Tree printing end
        pm.run(module)
        log("// Tree printing begin")
        with tempfile.TemporaryDirectory() as temp_dir:
            pm.enable_ir_printing(tree_printing_dir_path=temp_dir)
            pm.run(module)
            print_file_tree(temp_dir)
        log("// Tree printing end")
406