xref: /llvm-project/mlir/test/python/ir/blocks.py (revision 55d2fffdae5531759569e4ea8985c3de2e96bcc1)
18e6c55c9SStella Laurenzo# RUN: %PYTHON %s | FileCheck %s
28e6c55c9SStella Laurenzo
38e6c55c9SStella Laurenzoimport gc
48e6c55c9SStella Laurenzoimport io
58e6c55c9SStella Laurenzoimport itertools
68e6c55c9SStella Laurenzofrom mlir.ir import *
78e6c55c9SStella Laurenzofrom mlir.dialects import builtin
8fe23a6fbSStella Laurenzofrom mlir.dialects import cf
923aa5a74SRiver Riddlefrom mlir.dialects import func
108e6c55c9SStella Laurenzo
118e6c55c9SStella Laurenzo
def run(f):
    """Execute ``f`` as a test case and return it unchanged.

    Prints a blank line followed by a ``TEST: <name>`` banner (the labels in
    this file anchor on it), calls ``f``, then forces a garbage collection and
    asserts that no MLIR ``Context`` objects are still alive afterwards.

    Used as a decorator, so every test in this file runs at module load.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Collect reference cycles first so only genuinely leaked contexts count.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
    return f
188e6c55c9SStella Laurenzo
198e6c55c9SStella Laurenzo
# CHECK-LABEL: TEST: testBlockCreation
# CHECK: func @test(%[[ARG0:.*]]: i32 loc("arg0"), %[[ARG1:.*]]: i16 loc("arg1"))
# CHECK:   cf.br ^bb1(%[[ARG1]] : i16)
# CHECK: ^bb1(%[[PHI0:.*]]: i16 loc("middle")):
# CHECK:   cf.br ^bb2(%[[ARG0]] : i32)
# CHECK: ^bb2(%[[PHI1:.*]]: i32 loc("successor")):
# CHECK:   return
@run
def testBlockCreation():
    """Build a three-block ``func.func`` with explicitly located block arguments.

    Blocks are created out of order:
    - the final ("successor") block via ``create_after`` on the entry block;
    - then the "middle" block via ``create_before`` on the successor.

    Control flows entry -> middle -> successor, and each branch forwards one of
    the entry block's arguments. The module is printed with debug info so the
    argument locations show up in the output.
    """
    with Context(), Location.unknown():
        module = builtin.ModuleOp()
        with InsertionPoint(module.body):
            i32 = IntegerType.get_signless(32)
            i16 = IntegerType.get_signless(16)
            f_op = func.FuncOp("test", FunctionType.get([i32, i16], []))
            entry = f_op.add_entry_block(
                [Location.name("arg0"), Location.name("arg1")]
            )
            i32_arg, i16_arg = entry.arguments

            # Create the final block first and terminate it with a return.
            successor = entry.create_after(
                i32_arg.type, arg_locs=[Location.name("successor")]
            )
            with InsertionPoint(successor) as ip:
                assert ip.block == successor
                func.ReturnOp([])

            # Insert the middle block immediately before the successor.
            middle = successor.create_before(
                i16_arg.type, arg_locs=[Location.name("middle")]
            )

            with InsertionPoint(entry) as ip:
                assert ip.block == entry
                cf.BranchOp([i16_arg], dest=middle)

            with InsertionPoint(middle) as ip:
                assert ip.block == middle
                cf.BranchOp([i32_arg], dest=successor)

        module.print(enable_debug_info=True)
        # All three blocks must report the same parent region.
        assert entry.region == middle.region == successor.region
6078f2dae0SAlex Zinenko
6178f2dae0SAlex Zinenko
# CHECK-LABEL: TEST: testBlockCreationArgLocs
@run
def testBlockCreationArgLocs():
    """Exercise argument-location handling when appending blocks to a region.

    Location rules exercised here:
    - An argument's location is taken from the active ``Location`` context, or
      from an explicit ``arg_locs`` list.
    - An ``arg_locs`` list must have one entry per argument type.
    - A block without arguments needs no location.

    Both error paths must raise. If either call unexpectedly succeeds, an
    ``AssertionError`` is raised rather than relying only on a missing output
    line to flag the failure.
    """
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        f32 = F32Type.get()
        op = Operation.create("test", regions=1, loc=Location.unknown())
        blocks = op.regions[0].blocks

        with Location.name("default_loc"):
            blocks.append(f32)
        blocks.append()
        # CHECK:      ^bb0(%{{.+}}: f32 loc("default_loc")):
        # CHECK-NEXT: ^bb1:
        op.print(enable_debug_info=True)

        # No Location is active here, so the argument location cannot be inferred.
        try:
            blocks.append(f32)
        except RuntimeError as err:
            # CHECK: Missing loc: An MLIR function requires a Location but none was provided
            print("Missing loc:", err)
        else:
            raise AssertionError(
                "expected RuntimeError when appending an argument without a location"
            )

        # Two argument types but only one explicit location.
        try:
            blocks.append(f32, f32, arg_locs=[Location.unknown()])
        except ValueError as err:
            # CHECK: Wrong loc count: Expected 2 locations, got: 1
            print("Wrong loc count:", err)
        else:
            raise AssertionError(
                "expected ValueError for an arg_locs list of the wrong length"
            )
89514dddbeSRahul Kayaith
90514dddbeSRahul Kayaith
# CHECK-LABEL: TEST: testFirstBlockCreation
# CHECK: func @test(%{{.*}}: f32 loc("arg_loc"))
# CHECK:   return
@run
def testFirstBlockCreation():
    """Create a function's first block with ``Block.create_at_start``.

    The block is inserted at the start of the function op's region with one
    explicitly located ``f32`` argument and a ``return`` terminator. The test
    then checks three things:
    - the module verifies;
    - it prints with the expected argument location;
    - the new block is ``f.body.blocks[0]``.
    """
    with Context(), Location.unknown():
        module = builtin.ModuleOp()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            f = func.FuncOp("test", ([f32], []))
            body_region = f.operation.regions[0]
            arg_locs = [Location.name("arg_loc")]
            entry_block = Block.create_at_start(body_region, [f32], arg_locs)
            with InsertionPoint(entry_block):
                func.ReturnOp([])

        module.print(enable_debug_info=True)
        assert module.verify()
        assert f.body.blocks[0] == entry_block
1108d8738f6SJohn Demme
1118d8738f6SJohn Demme
# CHECK-LABEL: TEST: testBlockMove
# CHECK:  %0 = "realop"() ({
# CHECK:  ^bb0([[ARG0:%.+]]: f32):
# CHECK:    "ret"([[ARG0]]) : (f32) -> ()
# CHECK:  }) : () -> f32
@run
def testBlockMove():
    """Re-parent a populated block from one operation's region to another's.

    Steps:
    1. Build the block, plus a ``ret`` op that uses its argument, inside a
       temporary ``dummy`` op.
    2. Create ``realop`` with result types taken from ``ret``'s operands.
    3. Move the block into ``realop`` with ``Block.append_to``.
    4. Erase ``dummy``.

    The printed module shows the block intact inside ``realop``.
    """
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            holder = Operation.create("dummy", regions=1)
            block = Block.create_at_start(holder.operation.regions[0], [f32])
            with InsertionPoint(block):
                ret_op = Operation.create("ret", operands=[block.arguments[0]])
            result_types = [operand.type for operand in ret_op.operands]
            realop = Operation.create("realop", results=result_types, regions=1)
            block.append_to(realop.operation.regions[0])
            holder.operation.erase()
        print(module)
134fa45b2fbSMike Urbach
135fa45b2fbSMike Urbach
# CHECK-LABEL: TEST: testBlockHash
@run
def testBlockHash():
    """Assert that two distinct blocks in the same region hash differently."""
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            dummy = Operation.create("dummy", regions=1)
            region = dummy.operation.regions[0]
            first = Block.create_at_start(region, [f32])
            second = Block.create_at_start(region, [f32])
            assert hash(first) != hash(second)
148*55d2fffdSSandeep Dasgupta
149*55d2fffdSSandeep Dasgupta
# CHECK-LABEL: TEST: testBlockAddArgs
@run
def testBlockAddArgs():
    """Give an argument-less block an ``f32`` argument via ``add_argument``.

    The op is printed before and after, so the output shows ``^bb0`` gaining
    an argument located at ``loc(unknown)``.
    """
    with Context() as ctx, Location.unknown(ctx) as loc:
        ctx.allow_unregistered_dialects = True
        f32 = F32Type.get()
        op = Operation.create("test", regions=1, loc=Location.unknown())
        region_blocks = op.regions[0].blocks
        region_blocks.append()
        # CHECK: ^bb0:
        op.print(enable_debug_info=True)
        region_blocks[0].add_argument(f32, loc)
        # CHECK: ^bb0(%{{.+}}: f32 loc(unknown)):
        op.print(enable_debug_info=True)
164*55d2fffdSSandeep Dasgupta
165*55d2fffdSSandeep Dasgupta
# CHECK-LABEL: TEST: testBlockEraseArgs
@run
def testBlockEraseArgs():
    """Remove a block's only argument via ``erase_argument``.

    The op is printed before and after, so the output shows ``^bb0`` losing
    its ``f32`` argument.
    """
    with Context() as ctx, Location.unknown(ctx):
        ctx.allow_unregistered_dialects = True
        f32 = F32Type.get()
        op = Operation.create("test", regions=1, loc=Location.unknown())
        region_blocks = op.regions[0].blocks
        # The argument picks up its location from the active unknown Location.
        region_blocks.append(f32)
        # CHECK: ^bb0(%{{.+}}: f32 loc(unknown)):
        op.print(enable_debug_info=True)
        region_blocks[0].erase_argument(0)
        # CHECK: ^bb0:
        op.print(enable_debug_info=True)
180