xref: /llvm-project/mlir/test/python/ir/operation.py (revision abad8455ab08d4ca25b893e6a4605c1abe4aac02)
1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4import io
5import itertools
6from mlir.ir import *
7from mlir.dialects.builtin import ModuleOp
8from mlir.dialects import arith
9from mlir.dialects._ods_common import _cext
10
11
def run(f):
    """Execute test function ``f`` immediately and hand it back unchanged.

    Used as a decorator on every test below. Prints a blank line followed by
    a ``TEST: <function name>`` label line that the FileCheck labels anchor
    on, calls ``f``, forces a garbage collection, and then asserts that
    ``Context._get_live_count()`` is zero (presumably: the test left no MLIR
    context alive).
    """
    label = f.__name__
    print("\nTEST:", label)
    f()
    # Collect first so that objects only kept alive by reference cycles do
    # not show up in the live count.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
    return f
18
19
def expect_index_error(callback):
    """Call ``callback`` and require that it raises ``IndexError``.

    Returns ``None`` when ``IndexError`` is raised. Any other exception raised
    by ``callback`` propagates unchanged. If ``callback`` returns normally,
    ``RuntimeError("Expected IndexError")`` is raised instead.
    """
    try:
        callback()
    except IndexError:
        return
    raise RuntimeError("Expected IndexError")
26
27
28# Verify iterator based traversal of the op/region/block hierarchy.
29# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
@run
def testTraverseOpRegionBlockIterators():
    """Walk a parsed module's op/region/block hierarchy with Python iterators.

    Parses a one-function module and checks that the named collections
    (``op.regions``, ``region.blocks``, ``block.operations``) agree with
    default iteration over a region or block. It then recursively prints the
    nesting and prints the iterator objects so their class names can be
    matched.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """,
        ctx,
    )
    op = module.operation
    assert op.context is ctx
    # Get the block using iterators off of the named collections.
    regions = list(op.regions)
    blocks = list(regions[0].blocks)
    # CHECK: MODULE REGIONS=1 BLOCKS=1
    print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")

    # Should verify.
    # CHECK: .verify = True
    print(f".verify = {module.operation.verify()}")

    # Get the blocks from the default collection.
    default_blocks = list(regions[0])
    # They should compare equal regardless of how obtained.
    assert default_blocks == blocks

    # Should be able to get the operations from either the named collection
    # or the block.
    operations = list(blocks[0].operations)
    default_operations = list(blocks[0])
    assert default_operations == operations

    def walk_operations(indent, op):
        # Recursively print every region, block and op nested under `op`,
        # iterating regions and blocks directly (not via `.blocks` /
        # `.operations`). Each nesting level adds six spaces of indent.
        for i, region in enumerate(op.regions):
            print(f"{indent}REGION {i}:")
            for j, block in enumerate(region):
                print(f"{indent}  BLOCK {j}:")
                for k, child_op in enumerate(block):
                    print(f"{indent}    OP {k}: {child_op}")
                    walk_operations(indent + "      ", child_op)

    # CHECK: REGION 0:
    # CHECK:   BLOCK 0:
    # CHECK:     OP 0: func
    # CHECK:       REGION 0:
    # CHECK:         BLOCK 0:
    # CHECK:           OP 0: %0 = "custom.addi"
    # CHECK:           OP 1: func.return
    walk_operations("", op)

    # The iterator objects' reprs expose their binding class names.
    # CHECK:    Region iter: <mlir.{{.+}}.RegionIterator
    # CHECK:     Block iter: <mlir.{{.+}}.BlockIterator
    # CHECK: Operation iter: <mlir.{{.+}}.OperationIterator
    print("   Region iter:", iter(op.regions))
    print("    Block iter:", iter(op.regions[0]))
    print("Operation iter:", iter(op.regions[0].blocks[0]))
90
91
92# Verify index based traversal of the op/region/block hierarchy.
93# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
@run
def testTraverseOpRegionBlockIndices():
    """Walk the op/region/block hierarchy using ``len()`` and integer indexing.

    This deliberately uses ``range(len(...))`` plus ``[i]`` instead of
    iteration, so that the indexed-access paths of the region, block and
    operation collections are exercised. It also prints the name of each
    op's ``parent``.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """,
        ctx,
    )

    def walk_operations(indent, op):
        # Index-based recursive walk. The range(len(...)) form is intentional
        # (it is what is under test); each nesting level adds six spaces.
        for i in range(len(op.regions)):
            region = op.regions[i]
            print(f"{indent}REGION {i}:")
            for j in range(len(region.blocks)):
                block = region.blocks[j]
                print(f"{indent}  BLOCK {j}:")
                for k in range(len(block.operations)):
                    child_op = block.operations[k]
                    print(f"{indent}    OP {k}: {child_op}")
                    print(
                        f"{indent}    OP {k}: parent {child_op.operation.parent.name}"
                    )
                    walk_operations(indent + "      ", child_op)

    # CHECK: REGION 0:
    # CHECK:   BLOCK 0:
    # CHECK:     OP 0: func
    # CHECK:     OP 0: parent builtin.module
    # CHECK:       REGION 0:
    # CHECK:         BLOCK 0:
    # CHECK:           OP 0: %0 = "custom.addi"
    # CHECK:           OP 0: parent func.func
    # CHECK:           OP 1: func.return
    # CHECK:           OP 1: parent func.func
    walk_operations("", module.operation)
134
135
136# CHECK-LABEL: TEST: testBlockAndRegionOwners
@run
def testBlockAndRegionOwners():
    """Regions and blocks report the operation that contains them via ``.owner``.

    Checks this both for the top-level module operation and for the nested
    ``func.func``.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    source = r"""
    builtin.module {
      func.func @f() {
        func.return
      }
    }
  """
    module = Module.parse(source, ctx)

    top = module.operation
    top_region = top.regions[0]
    assert top_region.owner == top
    assert top_region.blocks[0].owner == top

    func = module.body.operations[0]
    func_region = func.operation.regions[0]
    assert func_region.owner == func
    assert func_region.blocks[0].owner == func
158
159
160# CHECK-LABEL: TEST: testBlockArgumentList
@run
def testBlockArgumentList():
    """Exercise a block's argument list.

    Covers iteration, ``arg_number``/``type``, in-place ``set_type``,
    slicing, concatenating slices, and the ``.types`` view, including
    ``.types`` on a slice.
    """
    with Context() as ctx:
        module = Module.parse(
            r"""
      func.func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
        return
      }
    """,
            ctx,
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        assert len(entry_block.arguments) == 3
        # CHECK: Argument 0, type i32
        # CHECK: Argument 1, type f64
        # CHECK: Argument 2, type index
        for arg in entry_block.arguments:
            print(f"Argument {arg.arg_number}, type {arg.type}")
            # Retype argument N to a signless integer of width 8 * (N + 1),
            # giving i8, i16 and i24 for the three arguments.
            new_type = IntegerType.get_signless(8 * (arg.arg_number + 1))
            arg.set_type(new_type)

        # CHECK: Argument 0, type i8
        # CHECK: Argument 1, type i16
        # CHECK: Argument 2, type i24
        for arg in entry_block.arguments:
            print(f"Argument {arg.arg_number}, type {arg.type}")

        # Check that slicing works for block argument lists.
        # CHECK: Argument 1, type i16
        # CHECK: Argument 2, type i24
        for arg in entry_block.arguments[1:]:
            print(f"Argument {arg.arg_number}, type {arg.type}")

        # Check that we can concatenate slices of argument lists.
        # [:2] and [1:] each hold two arguments, so the result has 4 (with
        # argument 1 appearing twice).
        # CHECK: Length: 4
        print("Length: ", len(entry_block.arguments[:2] + entry_block.arguments[1:]))

        # CHECK: Type: i8
        # CHECK: Type: i16
        # CHECK: Type: i24
        for t in entry_block.arguments.types:
            print("Type: ", t)

        # Check that slicing and type access compose.
        # CHECK: Sliced type: i16
        # CHECK: Sliced type: i24
        for t in entry_block.arguments[1:].types:
            print("Sliced type: ", t)

        # Check that slice addition works as expected.
        # CHECK: Argument 2, type i24
        # CHECK: Argument 0, type i8
        restructured = entry_block.arguments[-1:] + entry_block.arguments[:1]
        for arg in restructured:
            print(f"Argument {arg.arg_number}, type {arg.type}")
217
218
219# CHECK-LABEL: TEST: testOperationOperands
@run
def testOperationOperands():
    """An op's operands are exposed in order and report their types.

    ``test.consumer`` takes a block argument (``i32``) and the ``i64`` result
    of ``test.producer``.
    """
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      func.func @f1(%arg0: i32) {
        %0 = "test.producer"() : () -> i64
        "test.consumer"(%arg0, %0) : (i32, i64) -> ()
        return
      }"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        # operations[0] is test.producer; [1] is test.consumer.
        consumer = entry_block.operations[1]
        assert len(consumer.operands) == 2
        # CHECK: Operand 0, type i32
        # CHECK: Operand 1, type i64
        for i, operand in enumerate(consumer.operands):
            print(f"Operand {i}, type {operand.type}")
240
241
242# CHECK-LABEL: TEST: testOperationOperandsSlice
@run
def testOperationOperandsSlice():
    """Operand lists support Python slicing.

    Covers full, bounded, stepped and reversed slices, and slices of slices.
    Printing an operand shows its defining op, which is how each selected
    producer is identified in the output.
    """
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      func.func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        %3 = "test.producer3"() : () -> i64
        %4 = "test.producer4"() : () -> i64
        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
        return
      }"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        # Five producers precede the consumer, so it is operation 5.
        consumer = entry_block.operations[5]
        assert len(consumer.operands) == 5
        # Reversing twice must give back the original order.
        for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]):
            assert left == right

        # CHECK: test.producer0
        # CHECK: test.producer1
        # CHECK: test.producer2
        # CHECK: test.producer3
        # CHECK: test.producer4
        full_slice = consumer.operands[:]
        for operand in full_slice:
            print(operand)

        # CHECK: test.producer0
        # CHECK: test.producer1
        first_two = consumer.operands[0:2]
        for operand in first_two:
            print(operand)

        # CHECK: test.producer3
        # CHECK: test.producer4
        last_two = consumer.operands[3:]
        for operand in last_two:
            print(operand)

        # CHECK: test.producer0
        # CHECK: test.producer2
        # CHECK: test.producer4
        even = consumer.operands[::2]
        for operand in even:
            print(operand)

        # Slice of a slice: [::2] keeps operands 0, 2 and 4; [1::2] of that
        # keeps only operand 2.
        # CHECK: test.producer2
        fourth = consumer.operands[::2][1::2]
        for operand in fourth:
            print(operand)
298
299
300# CHECK-LABEL: TEST: testOperationOperandsSet
@run
def testOperationOperandsSet():
    """Operands can be replaced in place with ``operands[index] = value``.

    Replaces the single operand of ``test.consumer`` first through a positive
    index and then through a negative one. After each replacement it prints
    the new operand, whose defining producer appears in the output. Every
    producer yields ``i64``, so the operand type must not change.
    """
    with Context() as ctx, Location.unknown(ctx):
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      func.func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        "test.consumer"(%0) : (i64) -> ()
        return
      }"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        producer1 = entry_block.operations[1]
        producer2 = entry_block.operations[2]
        consumer = entry_block.operations[3]
        assert len(consumer.operands) == 1
        # Record the original operand type (without shadowing the `type`
        # builtin) so that each replacement can be checked against it.
        operand_type = consumer.operands[0].type

        # CHECK: test.producer1
        consumer.operands[0] = producer1.result
        print(consumer.operands[0])
        assert consumer.operands[0].type == operand_type

        # With a single operand, index -1 must address the same slot as 0,
        # which is why slot 0 is printed here.
        # CHECK: test.producer2
        consumer.operands[-1] = producer2.result
        print(consumer.operands[0])
        assert len(consumer.operands) == 1
        assert consumer.operands[0].type == operand_type
330
331
332# CHECK-LABEL: TEST: testDetachedOperation
@run
def testDetachedOperation():
    """``Operation.create`` builds a detached op with results, regions and attributes.

    The expected output lists ``bar`` before ``foo``, even though ``foo`` is
    given first in the ``attributes`` dict.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        # NOTE: despite the name, this is a *signed* 32-bit type (prints as si32).
        i32 = IntegerType.get_signed(32)
        op1 = Operation.create(
            "custom.op1",
            results=[i32, i32],
            regions=1,
            attributes={
                "foo": StringAttr.get("foo_value"),
                "bar": StringAttr.get("bar_value"),
            },
        )
        # CHECK: %0:2 = "custom.op1"() ({
        # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
        print(op1)

    # TODO: Check successors once enough infra exists to do it properly.
353
354
355# CHECK-LABEL: TEST: testOperationInsertionPoint
@run
def testOperationInsertionPoint():
    """Insert detached ops through an ``InsertionPoint`` at a block's beginning.

    Also checks that inserting an op which is already attached raises
    ``ValueError``.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """,
        ctx,
    )

    # Create test op.
    with Location.unknown(ctx):
        op1 = Operation.create("custom.op1")
        op2 = Operation.create("custom.op2")

        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        ip = InsertionPoint.at_block_begin(entry_block)
        # The expected output has op1 before op2. So successive inserts go in
        # order, ahead of the block's original first op, rather than each one
        # landing at the very front.
        ip.insert(op1)
        ip.insert(op2)
        # CHECK: func @f1
        # CHECK: "custom.op1"()
        # CHECK: "custom.op2"()
        # CHECK: %0 = "custom.addi"
        print(module)

    # Trying to add a previously added op should raise.
    try:
        ip.insert(op1)
    except ValueError:
        pass
    else:
        assert False, "expected insert of attached op to raise"
393
394
395# CHECK-LABEL: TEST: testOperationWithRegion
@run
def testOperationWithRegion():
    """Build an op whose region holds a block and a terminator, then move it into a function.

    The block is created with two ``si32`` arguments. The whole op is then
    inserted at the start of a parsed function's entry block.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        # NOTE: signed 32-bit type despite the name (prints as si32).
        i32 = IntegerType.get_signed(32)
        op1 = Operation.create("custom.op1", regions=1)
        # append() takes the new block's argument types and returns the block.
        block = op1.regions[0].blocks.append(i32, i32)
        # CHECK: "custom.op1"() ({
        # CHECK: ^bb0(%arg0: si32, %arg1: si32):
        # CHECK:   "custom.terminator"() : () -> ()
        # CHECK: }) : () -> ()
        terminator = Operation.create("custom.terminator")
        ip = InsertionPoint(block)
        ip.insert(terminator)
        print(op1)

        # Now add the whole operation to another op.
        # TODO: Verify lifetime hazard by nulling out the new owning module and
        # accessing op1.
        # TODO: Also verify accessing the terminator once both parents are nulled
        # out.
        module = Module.parse(
            r"""
      func.func @f1(%arg0: i32) -> i32 {
        %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
        return %1 : i32
      }
    """
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        ip = InsertionPoint.at_block_begin(entry_block)
        ip.insert(op1)
        # CHECK: func @f1
        # CHECK: "custom.op1"()
        # CHECK:   "custom.terminator"
        # CHECK: %0 = "custom.addi"
        print(module)
435
436
437# CHECK-LABEL: TEST: testOperationResultList
@run
def testOperationResultList():
    """Exercise an op's result list.

    Covers iteration, ``result_number``/``type``, the ``.types`` view,
    out-of-range indexing (both signs), and ``.owner`` on the empty result
    list of a call with no results.
    """
    ctx = Context()
    module = Module.parse(
        r"""
    func.func @f1() {
      %0:3 = call @f2() : () -> (i32, f64, index)
      call @f3() : () -> ()
      return
    }
    func.func private @f2() -> (i32, f64, index)
    func.func private @f3() -> ()
  """,
        ctx,
    )
    caller = module.body.operations[0]
    call = caller.regions[0].blocks[0].operations[0]
    assert len(call.results) == 3
    # CHECK: Result 0, type i32
    # CHECK: Result 1, type f64
    # CHECK: Result 2, type index
    for res in call.results:
        print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result type i32
    # CHECK: Result type f64
    # CHECK: Result type index
    for t in call.results.types:
        print(f"Result type {t}")

    # Out of range: valid indices for 3 results are -3..2, so 3 and -4 are
    # the first invalid values on each side.
    expect_index_error(lambda: call.results[3])
    expect_index_error(lambda: call.results[-4])

    no_results_call = caller.regions[0].blocks[0].operations[1]
    assert len(no_results_call.results) == 0
    # Even an empty result list knows which op it belongs to.
    assert no_results_call.results.owner == no_results_call
475
476
477# CHECK-LABEL: TEST: testOperationResultListSlice
@run
def testOperationResultListSlice():
    """Result lists support Python slicing.

    Covers full, bounded, stepped and negative-step slices. Results keep
    their original ``result_number`` when accessed through a slice.
    """
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      func.func @f1() {
        "some.op"() : () -> (i1, i2, i3, i4, i5)
        return
      }
    """
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        producer = entry_block.operations[0]

        assert len(producer.results) == 5
        # Reversing twice must give back the original order and numbering.
        for left, right in zip(producer.results, producer.results[::-1][::-1]):
            assert left == right
            assert left.result_number == right.result_number

        # CHECK: Result 0, type i1
        # CHECK: Result 1, type i2
        # CHECK: Result 2, type i3
        # CHECK: Result 3, type i4
        # CHECK: Result 4, type i5
        full_slice = producer.results[:]
        for res in full_slice:
            print(f"Result {res.result_number}, type {res.type}")

        # CHECK: Result 1, type i2
        # CHECK: Result 2, type i3
        # CHECK: Result 3, type i4
        middle = producer.results[1:4]
        for res in middle:
            print(f"Result {res.result_number}, type {res.type}")

        # CHECK: Result 1, type i2
        # CHECK: Result 3, type i4
        odd = producer.results[1::2]
        for res in odd:
            print(f"Result {res.result_number}, type {res.type}")

        # Start at index -2 (result 3) and step back by 2, stopping before
        # index 0: this yields results 3 and 1.
        # CHECK: Result 3, type i4
        # CHECK: Result 1, type i2
        inverted_middle = producer.results[-2:0:-2]
        for res in inverted_middle:
            print(f"Result {res.result_number}, type {res.type}")
526
527
528# CHECK-LABEL: TEST: testOperationAttributes
@run
def testOperationAttributes():
    """Read an op's attributes by name and by iteration, and check lookup errors.

    A missing string key raises ``KeyError``. An out-of-range integer index
    raises ``IndexError``, which suggests that integer keys index
    positionally (to be confirmed against the bindings).
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    "some.op"() { some.attribute = 1 : i8,
                  other.attribute = 3.0,
                  dependent = "text" } : () -> ()
  """,
        ctx,
    )
    op = module.body.operations[0]
    assert len(op.attributes) == 3
    iattr = op.attributes["some.attribute"]
    fattr = op.attributes["other.attribute"]
    sattr = op.attributes["dependent"]
    # CHECK: Attribute type i8, value 1
    print(f"Attribute type {iattr.type}, value {iattr.value}")
    # CHECK: Attribute type f64, value 3.0
    print(f"Attribute type {fattr.type}, value {fattr.value}")
    # CHECK: Attribute value text
    print(f"Attribute value {sattr.value}")
    # CHECK: Attribute value b'text'
    print(f"Attribute value {sattr.value_bytes}")

    # We don't know in which order the attributes are stored.
    # CHECK-DAG: NamedAttribute(dependent="text")
    # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
    # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
    for attr in op.attributes:
        print(str(attr))

    # Check that exceptions are raised as expected.
    try:
        op.attributes["does_not_exist"]
    except KeyError:
        pass
    else:
        assert False, "expected KeyError on accessing a non-existent attribute"

    try:
        op.attributes[42]
    except IndexError:
        pass
    else:
        assert False, "expected IndexError on accessing an out-of-bounds attribute"
576
577
578# CHECK-LABEL: TEST: testOperationPrint
@run
def testOperationPrint():
    """Exercise ``Operation.print`` and bytecode output.

    Covers printing to stdout, to text and binary file-like objects, a
    bytecode write/parse round trip, debug-info printing, printing with an
    ``AsmState``, generic form with large elements elided, and
    ``skip_regions``.
    """
    ctx = Context()
    module = Module.parse(
        r"""
    func.func @f1(%arg0: i32) -> i32 {
      %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
      return %arg0 : i32
    }
  """,
        ctx,
    )

    # Test print to stdout.
    # CHECK: return %arg0 : i32
    module.operation.print()

    # Test print to text file.
    f = io.StringIO()
    # CHECK: <class 'str'>
    # CHECK: return %arg0 : i32
    module.operation.print(file=f)
    str_value = f.getvalue()
    print(str_value.__class__)
    print(f.getvalue())

    # Test roundtrip to bytecode.
    bytecode_stream = io.BytesIO()
    module.operation.write_bytecode(bytecode_stream, desired_version=1)
    bytecode = bytecode_stream.getvalue()
    # Bytecode begins with the magic bytes "ML", 0xEF, "R".
    assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
    module_roundtrip = Module.parse(bytecode, ctx)
    f = io.StringIO()
    module_roundtrip.operation.print(file=f)
    roundtrip_value = f.getvalue()
    # The textual form must survive the text -> bytecode -> text round trip.
    assert str_value == roundtrip_value, "Mismatch after roundtrip bytecode"

    # Test print to binary file.
    f = io.BytesIO()
    # CHECK: <class 'bytes'>
    # CHECK: return %arg0 : i32
    module.operation.print(file=f, binary=True)
    bytes_value = f.getvalue()
    print(bytes_value.__class__)
    print(bytes_value)

    # Test print local_scope.
    # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
    module.operation.print(enable_debug_info=True, use_local_scope=True)

    # Test printing using state.
    state = AsmState(module.operation)
    # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32>
    module.operation.print(state)

    # Test print with options.
    # With a 2-element limit the 4-element constant is printed as an elided
    # resource. `-:4:7` is the parse location of `return`: line 4, column 7 of
    # the source string above (`-` presumably names the unnamed input buffer).
    # CHECK: value = dense_resource<__elided__> : tensor<4xi32>
    # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
    module.operation.print(
        large_elements_limit=2,
        enable_debug_info=True,
        pretty_debug_info=True,
        print_generic_op_form=True,
        use_local_scope=True,
    )

    # Test print with skip_regions option
    # CHECK: func.func @f1(%arg0: i32) -> i32
    # CHECK-NOT: func.return
    module.body.operations[0].print(
        skip_regions=True,
    )
651
652
653# CHECK-LABEL: TEST: testKnownOpView
@run
def testKnownOpView():
    """Parsed ops come back as the most specific registered OpView class.

    Generated classes (``AddFOp``) and hand-written extension classes
    (``arith.ConstantOp``) are used for registered ops, and the generic
    ``OpView`` for unregistered ``custom.*`` ops. A class registered later
    with ``replace=True`` takes effect for subsequent lookups.
    """
    with Context(), Location.unknown():
        Context.current.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      %1 = "custom.f32"() : () -> f32
      %2 = "custom.f32"() : () -> f32
      %3 = arith.addf %1, %2 : f32
      %4 = arith.constant 0 : i32
    """
        )
        print(module)

        # addf should map to a known OpView class in the arithmetic dialect.
        # That OpView exposes an `lhs` operand accessor; printing the operand
        # shows its defining op, custom.f32.
        addf = module.body.operations[2]
        # CHECK: <mlir.dialects._arith_ops_gen.AddFOp object
        print(repr(addf))
        # CHECK: "custom.f32"()
        print(addf.lhs)

        # One of the custom ops should resolve to the default OpView.
        custom = module.body.operations[0]
        # CHECK: OpView object
        print(repr(custom))

        # Check again to make sure negative caching works.
        custom = module.body.operations[0]
        # CHECK: OpView object
        print(repr(custom))

        # constant should map to an extension OpView class in the arithmetic dialect.
        constant = module.body.operations[3]
        # CHECK: <mlir.dialects.arith.ConstantOp object
        print(repr(constant))
        # Checks that the arith extension is being registered successfully
        # (literal_value is a property on the extension class but not on the default OpView).
        # CHECK: literal value 0
        print("literal value", constant.literal_value)

        # Checks that "late" registration/replacement (i.e., post all module loading/initialization)
        # is working correctly.
        # replace=True allows overriding the ConstantOp class already
        # registered for the arith dialect.
        @_cext.register_operation(arith._Dialect, replace=True)
        class ConstantOp(arith.ConstantOp):
            def __init__(self, result, value, *, loc=None, ip=None):
                # Wrap plain Python ints/floats into typed attributes; pass
                # anything else through unchanged.
                if isinstance(value, int):
                    super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
                elif isinstance(value, float):
                    super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
                else:
                    super().__init__(value, loc=loc, ip=ip)

        # Looking the same op up again now yields the locally defined class.
        constant = module.body.operations[3]
        # CHECK: <__main__.testKnownOpView.<locals>.ConstantOp object
        print(repr(constant))
710
711
712# CHECK-LABEL: TEST: testSingleResultProperty
@run
def testSingleResultProperty():
    """``.result`` returns the sole result of an op and rejects other arities.

    For ops with zero or several results, ``.result`` must raise
    ``ValueError`` with a message naming the op and its result count. For an
    op with exactly one result it must return that result, and this path is
    now asserted as well.
    """
    with Context(), Location.unknown():
        Context.current.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      "custom.no_result"() : () -> ()
      %0:2 = "custom.two_result"() : () -> (f32, f32)
      %1 = "custom.one_result"() : () -> f32
    """
        )
        print(module)

    try:
        module.body.operations[0].result
    except ValueError as e:
        # CHECK: Cannot call .result on operation custom.no_result which has 0 results
        print(e)
    else:
        assert False, "Expected exception"

    try:
        module.body.operations[1].result
    except ValueError as e:
        # CHECK: Cannot call .result on operation custom.two_result which has 2 results
        print(e)
    else:
        assert False, "Expected exception"

    # Success path: on a single-result op, `.result` is that op's result 0.
    one_result_op = module.body.operations[2]
    single = one_result_op.result
    assert single == one_result_op.results[0]
    assert str(single.type) == "f32"

    # CHECK: %1 = "custom.one_result"() : () -> f32
    print(module.body.operations[2])
744
745
def create_invalid_operation():
    """Create a detached ``builtin.module`` op that fails verification.

    The op is given two regions, while the verifier expects
    ``builtin.module`` to have exactly one ("requires one region"). Tests use
    it to check that printing an invalid op falls back to the generic printer
    for safety. Only the first region gets an (empty) block. Callers wrap the
    call in a ``Location`` context (assumption: ``Operation.create`` takes its
    location from there).
    """
    # This module has two regions and is therefore invalid; it lets callers
    # verify that we fall back to the generic printer for safety.
    op = Operation.create("builtin.module", regions=2)
    op.regions[0].blocks.append()
    return op
752
753
754# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
@run
def testInvalidOperationStrSoftFails():
    """Printing an invalid op falls back to generic form, and ``verify()`` raises.

    ``str()``/``print`` of the invalid op must not fail. ``verify()`` must
    raise ``MLIRError``, whose message includes the diagnostic and the op in
    generic form. If ``verify()`` returns without raising, the test now fails
    immediately instead of silently skipping the expected output.
    """
    ctx = Context()
    with Location.unknown(ctx):
        invalid_op = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: "builtin.module"() ({
        # CHECK: }) : () -> ()
        print(invalid_op)
        try:
            invalid_op.verify()
        except MLIRError as e:
            # CHECK: Exception: <
            # CHECK:   Verification failed:
            # CHECK:   error: unknown: 'builtin.module' op requires one region
            # CHECK:    note: unknown: see current operation:
            # CHECK:     "builtin.module"() ({
            # CHECK:     ^bb0:
            # CHECK:     }, {
            # CHECK:     }) : () -> ()
            # CHECK: >
            print(f"Exception: <{e}>")
        else:
            assert False, "expected verify() of an invalid op to raise MLIRError"
777
778
779# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
@run
def testInvalidModuleStrSoftFails():
    """A module containing an invalid op still prints, using the generic form."""
    ctx = Context()
    with Location.unknown(ctx):
        module = Module.create()
        # Created under InsertionPoint(module.body), so the invalid op is
        # presumably inserted into the module. The `invalid_op` name is not
        # used afterwards.
        with InsertionPoint(module.body):
            invalid_op = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: "builtin.module"() ({
        # CHECK: }) : () -> ()
        print(module)
791
792
793# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
@run
def testInvalidOperationGetAsmBinarySoftFails():
    """``get_asm(binary=True)`` on an invalid op returns generic-form bytes.

    The expected value shows both regions: the first has an empty ``^bb0``
    block and the second has no blocks.
    """
    ctx = Context()
    with Location.unknown(ctx):
        invalid_op = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: b'"builtin.module"() ({\n^bb0:\n}, {\n}) : () -> ()\n'
        print(invalid_op.get_asm(binary=True))
802
803
804# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
@run
def testCreateWithInvalidAttributes():
    """``Operation.create`` rejects malformed attribute dictionaries with a clear error.

    Covers non-string keys (``None``, ``42``), a non-attribute value (a
    ``Context``) and a ``None`` value. Each case must raise. The message,
    which names the offending key and the op being created, is printed for
    matching. A case that unexpectedly succeeds now fails immediately instead
    of surfacing later as missing output.
    """
    ctx = Context()

    def expect_create_error(attributes):
        # Try to create the op; print the error, or fail if none was raised.
        try:
            Operation.create("builtin.module", attributes=attributes)
        except Exception as e:
            print(e)
        else:
            assert False, f"expected Operation.create to reject {attributes!r}"

    with Location.unknown(ctx):
        # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
        expect_create_error({None: StringAttr.get("name")})
        # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
        expect_create_error({42: StringAttr.get("name")})
        # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
        expect_create_error({"some_key": ctx})
        # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
        expect_create_error({"some_key": None})
831
832
833# CHECK-LABEL: TEST: testOperationName
@run
def testOperationName():
    """Each op reports its fully qualified name through ``.operation.name``.

    Names are printed one per line, in block order.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    source = r"""
    %0 = "custom.op1"() : () -> f32
    %1 = "custom.op2"() : () -> i32
    %2 = "custom.op1"() : () -> f32
  """
    module = Module.parse(source, ctx)

    # CHECK: custom.op1
    # CHECK: custom.op2
    # CHECK: custom.op1
    names = [child.operation.name for child in module.body.operations]
    print("\n".join(names))
852
853
854# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
    """An Operation round-trips through its C-API capsule back to the same Python object."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        original = Operation.create("custom.op1").operation
        capsule = original._CAPIPtr
        assert '"mlir.ir.Operation._CAPIPtr"' in repr(capsule)
        # Identity, not merely equality: the live Python object is returned.
        assert Operation._CAPICreate(capsule) is original
865
866
867# CHECK-LABEL: TEST: testOperationErase
@run
def testOperationErase():
    """Erasing an op removes it from its block.

    The active insertion point remains usable afterwards for creating
    another op.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        m = Module.create()
        with InsertionPoint(m.body):
            op = Operation.create("custom.op1")

            # CHECK: "custom.op1"
            print(m)

            op.operation.erase()

            # CHECK-NOT: "custom.op1"
            print(m)

            # Ensure we can create another operation
            Operation.create("custom.op2")
887
888
889# CHECK-LABEL: TEST: testOperationClone
@run
def testOperationClone():
    """Cloning an op and then erasing the original leaves the clone in the module.

    The second print still expects ``custom.op1`` after the original was
    erased. This implies ``clone()`` inserted the copy into ``m.body``
    (presumably at the active insertion point; confirm against the bindings).
    The ``clone`` name itself is not used further.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        m = Module.create()
        with InsertionPoint(m.body):
            op = Operation.create("custom.op1")

            # CHECK: "custom.op1"
            print(m)

            clone = op.operation.clone()
            op.operation.erase()

            # CHECK: "custom.op1"
            print(m)
907
908
909# CHECK-LABEL: TEST: testOperationLoc
@run
def testOperationLoc():
    """An op created with an explicit ``loc`` reports it from both OpView and Operation."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with ctx:
        named_loc = Location.name("loc")
        op = Operation.create("custom.op", loc=named_loc)
        for reported in (op.location, op.operation.location):
            assert reported == named_loc
919
920
921# CHECK-LABEL: TEST: testModuleMerge
@run
def testModuleMerge():
    """``move_before`` and ``move_after`` move ops across modules.

    ``@bar`` and ``@qux`` are moved out of ``m2`` to either side of ``@foo``
    in ``m1``, which leaves ``m2`` empty.
    """
    with Context():
        m1 = Module.parse("func.func private @foo()")
        m2 = Module.parse(
            """
      func.func private @bar()
      func.func private @qux()
    """
        )
        foo = m1.body.operations[0]
        bar = m2.body.operations[0]
        qux = m2.body.operations[1]
        bar.move_before(foo)
        qux.move_after(foo)

        # CHECK: module
        # CHECK: func private @bar
        # CHECK: func private @foo
        # CHECK: func private @qux
        print(m1)

        # CHECK: module {
        # CHECK-NEXT: }
        print(m2)
947
948
949# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
@run
def testAppendMoveFromAnotherBlock():
    """Appending an op owned by another block moves it out of that block."""
    with Context():
        source = Module.parse("func.func private @foo()")
        target = Module.parse("func.func private @bar()")
        target.body.append(source.body.operations[0])

        # CHECK: module
        # CHECK: func private @bar
        # CHECK: func private @foo

        print(target)
        # CHECK: module {
        # CHECK-NEXT: }
        print(source)
966
967
968# CHECK-LABEL: TEST: testDetachFromParent
@run
def testDetachFromParent():
    """Detaching removes an op from its module; detaching it again fails."""
    with Context():
        m1 = Module.parse("func.func private @foo()")
        detached = m1.body.operations[0].detach_from_parent()

        raised = False
        try:
            detached.detach_from_parent()
        except ValueError as err:
            # Only the "no parent" error is expected; anything else propagates.
            if "has no parent" not in str(err):
                raise
            raised = True
        assert raised, "expected ValueError when detaching a detached operation"

        # CHECK-NOT: func private @foo
        print(m1)
985
986
987# CHECK-LABEL: TEST: testOperationHash
@run
def testOperationHash():
    """The object from ``Operation.create`` hashes like its ``.operation``."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with ctx:
        with Location.unknown():
            created = Operation.create("custom.op1")
            view_hash, op_hash = hash(created), hash(created.operation)
            assert view_hash == op_hash
995
996
997# CHECK-LABEL: TEST: testOperationParse
@run
def testOperationParse():
    """Parse ops generically via ``Operation.parse`` and via ``ModuleOp.parse``."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        generic_asm = '"test.foo"() : () -> ()'

        # Generic operation parsing.
        parsed_module = Operation.parse("module {}")
        parsed_generic = Operation.parse(generic_asm)
        assert isinstance(parsed_module, ModuleOp)
        # `test.foo` is expected to come back as exactly OpView, not a subclass.
        assert type(parsed_generic) is OpView

        # Parsing specific operation.
        parsed_module = ModuleOp.parse("module {}")
        assert isinstance(parsed_module, ModuleOp)
        try:
            ModuleOp.parse(generic_asm)
        except MLIRError as e:
            # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo'
            print(f"error: {e}")
        else:
            assert False, "expected error"

        with_source = Operation.parse(generic_asm, source_name="my-source-string")
        asm = with_source.get_asm(enable_debug_info=True, use_local_scope=True)
        # CHECK: op_with_source_name: "test.foo"() : () -> () loc("my-source-string":1:1)
        print(f"op_with_source_name: {asm}")
1025
1026
1027# CHECK-LABEL: TEST: testOpWalk
@run
def testOpWalk():
    """Exercise ``Operation.walk`` orders, walk results and callback errors.

    The parsed IR nests func.return inside func.func inside builtin.module,
    so the printed visit sequence shows both the traversal order and the
    point at which a walk stops.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    builtin.module {
      func.func @f() {
        func.return
      }
    }
  """,
        ctx,
    )

    def advance_callback(op):
        print(op.name)
        return WalkResult.ADVANCE

    # Test post-order walk (default).
    # CHECK-NEXT:  Post-order
    # CHECK-NEXT:  func.return
    # CHECK-NEXT:  func.func
    # CHECK-NEXT:  builtin.module
    print("Post-order")
    module.operation.walk(advance_callback)

    # Test pre-order walk.
    # CHECK-NEXT:  Pre-order
    # CHECK-NEXT:  builtin.module
    # CHECK-NEXT:  func.func
    # CHECK-NEXT:  func.return
    print("Pre-order")
    module.operation.walk(advance_callback, WalkOrder.PRE_ORDER)

    # Test interrupt: the walk stops right after the first visited op.
    # CHECK-NEXT:  Interrupt post-order
    # CHECK-NEXT:  func.return
    print("Interrupt post-order")

    def interrupt_callback(op):
        print(op.name)
        return WalkResult.INTERRUPT

    module.operation.walk(interrupt_callback)

    # Test skip: in pre-order, skipping the root prunes every nested op.
    # CHECK-NEXT:  Skip pre-order
    # CHECK-NEXT:  builtin.module
    print("Skip pre-order")

    def skip_callback(op):
        print(op.name)
        return WalkResult.SKIP

    module.operation.walk(skip_callback, WalkOrder.PRE_ORDER)

    # Test exception.
    # CHECK: Exception
    # CHECK-NEXT: func.return
    # CHECK-NEXT: Exception raised
    print("Exception")

    def raising_callback(op):
        print(op.name)
        # Always raises, so nothing after the first visited op is printed;
        # the test expects the binding to surface this as a RuntimeError.
        raise ValueError

    try:
        module.operation.walk(raising_callback)
    except RuntimeError:
        print("Exception raised")
1100