xref: /llvm-project/mlir/test/python/ir/value.py (revision 21df32511b558b2c1e24fe23f677fffaad4da333)
1# RUN: %PYTHON %s | FileCheck %s --enable-var-scope=false
2
3import gc
4from mlir.ir import *
5from mlir.dialects import func
6
7
def run(f):
    """Run test function ``f`` immediately and verify it leaked no Context.

    Used as a decorator on every test in this file.
    1. Prints a ``TEST: <name>`` header, which the FileCheck labels anchor on.
    2. Calls ``f``.
    3. Forces a garbage-collection pass.
    4. Asserts that ``Context._get_live_count()`` reports zero live contexts.

    Returns ``f`` unchanged so the decorated name stays bound to the function.
    """
    print("\nTEST:", f.__name__)
    f()
    # Collect first so contexts kept alive only by reference cycles are freed
    # before counting; anything still counted afterwards was leaked by ``f``.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0, (
        f"{f.__name__} left {live_contexts} live Context object(s)"
    )
    return f
14
15
16# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
    """Round-trip a Value through its C-API capsule and back."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        result = Operation.create(
            "custom.op1", results=[IntegerType.get_signless(32)]
        ).result
        capsule = result._CAPIPtr
        # The capsule's repr names the capsule type used by the bindings.
        assert '"mlir.ir.Value._CAPIPtr"' in repr(capsule)
        restored = Value._CAPICreate(capsule)
        assert restored == result
28
29
30# CHECK-LABEL: TEST: testOpResultOwner
@run
def testOpResultOwner():
    """An OpResult's ``owner`` is the operation that produced it."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        producer = Operation.create(
            "custom.op1", results=[IntegerType.get_signless(32)]
        )
        assert producer.result.owner == producer
39
40
41# CHECK-LABEL: TEST: testBlockArgOwner
@run
def testBlockArgOwner():
    """A BlockArgument's ``owner`` is the block that declares it."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @foo(%arg0: f32) {
      return
    }""",
        ctx,
    )
    # Named ``func_op`` rather than ``func`` so the imported
    # ``mlir.dialects.func`` module is not shadowed inside this test.
    func_op = module.body.operations[0]
    block = func_op.regions[0].blocks[0]
    assert block.arguments[0].owner == block
56
57
58# CHECK-LABEL: TEST: testValueIsInstance
@run
def testValueIsInstance():
    """``BlockArgument.isinstance``/``OpResult.isinstance`` tell Value kinds apart."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @foo(%arg0: f32) {
      %0 = "some_dialect.some_op"() : () -> f64
      return
    }""",
        ctx,
    )
    # ``func_op`` instead of ``func`` avoids shadowing the imported
    # ``mlir.dialects.func`` module.
    func_op = module.body.operations[0]
    entry = func_op.regions[0].blocks[0]

    # The function's first argument is a block argument, not an op result.
    arg = entry.arguments[0]
    assert BlockArgument.isinstance(arg)
    assert not OpResult.isinstance(arg)

    # The first op's result is an op result, not a block argument.
    op = entry.operations[0]
    assert not BlockArgument.isinstance(op.results[0])
    assert OpResult.isinstance(op.results[0])
78
79
80# CHECK-LABEL: TEST: testValueHash
@run
def testValueHash():
    """The same Value reached through different accessors hashes identically."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @foo(%arg0: f32) -> f32 {
      %0 = "some_dialect.some_op"(%arg0) : (f32) -> f32
      return %0 : f32
    }""",
        ctx,
    )

    # ``func_op`` instead of ``func`` keeps the imported dialect module visible.
    [func_op] = module.body.operations
    block = func_op.entry_block
    op, ret = block.operations
    # %arg0 seen as a block argument and as op's operand.
    assert hash(block.arguments[0]) == hash(op.operands[0])
    # %0 seen as op's result and as the return's operand.
    assert hash(op.result) == hash(ret.operands[0])
99
100
101# CHECK-LABEL: TEST: testValueUses
@run
def testValueUses():
    """Iterating ``Value.uses`` yields one entry per consuming operand."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            source = Operation.create("custom.op1", results=[i32]).results[0]
            # Two identical consumers, each taking `source` as operand 0.
            consumers = [
                Operation.create("custom.op2", operands=[source]) for _ in range(2)
            ]

    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    for use in source.uses:
        assert use.owner in consumers
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")
122
123
124# CHECK-LABEL: TEST: testValueReplaceAllUsesWith
@run
def testValueReplaceAllUsesWith():
    """``replace_all_uses_with`` moves every use of a value onto another."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            old_value = Operation.create("custom.op1", results=[i32]).results[0]
            user_a = Operation.create("custom.op2", operands=[old_value])
            user_b = Operation.create("custom.op2", operands=[old_value])
            new_value = Operation.create("custom.op3", results=[i32]).results[0]
            old_value.replace_all_uses_with(new_value)

    # Nothing may still refer to the replaced value.
    assert not list(old_value.uses)

    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    users = (user_a, user_b)
    for use in new_value.uses:
        assert use.owner in users
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")
149
150
151# CHECK-LABEL: TEST: testValueReplaceAllUsesWithExcept
@run
def testValueReplaceAllUsesWithExcept():
    """``replace_all_uses_except`` with a single op keeps that op's use intact."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            value = Operation.create("custom.op1", results=[i32]).results[0]
            op1 = Operation.create("custom.op1", operands=[value])
            op2 = Operation.create("custom.op2", operands=[value])
            value2 = Operation.create("custom.op3", results=[i32]).results[0]
            # Redirect every use of `value` to `value2`, except the one in `op1`.
            value.replace_all_uses_except(value2, op1)

    # `op1` still uses `value`; only `op2`'s use moved to `value2`. Checking
    # both counts catches a replacement that dropped or duplicated a use.
    assert len(list(value.uses)) == 1
    assert len(list(value2.uses)) == 1

    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    for use in value2.uses:
        assert use.owner in [op2]
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")

    # CHECK: Use owner: "custom.op1"
    # CHECK: Use operand_number: 0
    for use in value.uses:
        assert use.owner in [op1]
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")
181
182
183# CHECK-LABEL: TEST: testValueReplaceAllUsesWithMultipleExceptions
@run
def testValueReplaceAllUsesWithMultipleExceptions():
    """``replace_all_uses_except`` also accepts a list of ops to exclude."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            source = Operation.create("custom.op1", results=[i32]).results[0]
            kept_a = Operation.create("custom.op1", operands=[source])
            kept_b = Operation.create("custom.op2", operands=[source])
            moved = Operation.create("custom.op3", operands=[source])
            target = Operation.create("custom.op4", results=[i32]).results[0]

            # Every use of `source` except those in `kept_a`/`kept_b` switches
            # over to `target`.
            source.replace_all_uses_except(target, [kept_a, kept_b])

    # Only `moved` now uses `target`; both excluded ops still use `source`.
    assert len(list(source.uses)) == 2
    assert len(list(target.uses)) == 1

    # CHECK: Use owner: "custom.op3"
    # CHECK: Use operand_number: 0
    for use in target.uses:
        assert use.owner == moved
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")

    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    # CHECK: Use owner: "custom.op1"
    # CHECK: Use operand_number: 0
    for use in source.uses:
        assert use.owner in (kept_a, kept_b)
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")
220
221
222# CHECK-LABEL: TEST: testValuePrintAsOperand
@run
def testValuePrintAsOperand():
    """Exercise ``Value.get_name`` with default, AsmState and local-scope naming."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            outer1 = Operation.create("custom.op1", results=[i32]).results[0]
            # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
            print(outer1)

            outer2 = Operation.create("custom.op2", results=[i32]).results[0]
            # CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32)
            print(outer2)

            top_fn = func.FuncOp("test", ([i32, i32], []))
            entry_block = Block.create_at_start(top_fn.operation.regions[0], [i32, i32])

            with InsertionPoint(entry_block):
                inner1 = Operation.create("custom.op3", results=[i32]).results[0]
                # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32)
                print(inner1)
                inner2 = Operation.create("custom.op4", results=[i32]).results[0]
                # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32)
                print(inner2)
                func.ReturnOp([])

        # With no arguments, get_name() yields the names captured above.
        # CHECK: %[[VAL1]]
        # CHECK: %[[VAL2]]
        # CHECK: %[[VAL3]]
        # CHECK: %[[VAL4]]
        for named in (outer1, outer2, inner1, inner2):
            print(named.get_name())

        print("With AsmState")
        # CHECK-LABEL: With AsmState
        state = AsmState(top_fn.operation, use_local_scope=True)
        # CHECK: %0
        # CHECK: %1
        for named in (inner1, inner2):
            print(named.get_name(state=state))

        print("With use_local_scope")
        # CHECK-LABEL: With use_local_scope
        # CHECK: %0
        # CHECK: %1
        for named in (inner1, inner2):
            print(named.get_name(use_local_scope=True))

        # CHECK: %[[ARG0:.*]]
        # CHECK: %[[ARG1:.*]]
        for block_arg in entry_block.arguments:
            print(block_arg.get_name())

        # CHECK: module {
        # CHECK:   %[[VAL1]] = "custom.op1"() : () -> i32
        # CHECK:   %[[VAL2]] = "custom.op2"() : () -> i32
        # CHECK:   func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) {
        # CHECK:     %[[VAL3]] = "custom.op3"() : () -> i32
        # CHECK:     %[[VAL4]] = "custom.op4"() : () -> i32
        # CHECK:     return
        # CHECK:   }
        # CHECK: }
        print(module)

        # After detaching its defining op from the module, the value is
        # renumbered on its own.
        outer2.owner.detach_from_parent()
        # CHECK: %0
        print(outer2.get_name())
294
295
296# CHECK-LABEL: TEST: testValueSetType
@run
def testValueSetType():
    """``Value.set_type`` retypes an op result in place."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32, i64 = (IntegerType.get_signless(width) for width in (32, 64))
        module = Module.create()
        with InsertionPoint(module.body):
            result = Operation.create("custom.op1", results=[i32]).results[0]
            # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
            print(result)

            result.set_type(i64)
            # CHECK: Value(%[[VAL1]] = "custom.op1"() : () -> i64)
            print(result)

            # The defining operation reflects the new result type too.
            # CHECK: %[[VAL1]] = "custom.op1"() : () -> i64
            print(result.owner)
316
317
318# CHECK-LABEL: TEST: testValueCasters
@run
def testValueCasters():
    """Per-type value casters wrap results, operands and block arguments.

    Registers ``cast_int`` for ``IntegerType`` values and checks that these
    access paths go through it:
    - result indexing, slicing and unpacking;
    - operand access;
    - ``from_py_func`` arguments.
    It then checks that a second registration without ``replace=True`` raises
    ``RuntimeError``, and that ``replace=True`` installs a different caster.
    """

    class NOPResult(OpResult):
        def __init__(self, v):
            super().__init__(v)

        def __str__(self):
            # Relabel the printed "Value(...)" form with this subclass name.
            return super().__str__().replace(Value.__name__, NOPResult.__name__)

    class NOPValue(Value):
        def __init__(self, v):
            super().__init__(v)

        def __str__(self):
            return super().__str__().replace(Value.__name__, NOPValue.__name__)

    class NOPBlockArg(BlockArgument):
        def __init__(self, v):
            super().__init__(v)

        def __str__(self):
            return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)

    # Tried in order: the specific value kinds before the generic Value.
    wrappers = (
        (OpResult, NOPResult),
        (BlockArgument, NOPBlockArg),
        (Value, NOPValue),
    )

    # The name ``cast_int`` appears in the duplicate-registration error
    # expected further down, so it must stay as is.
    @register_value_caster(IntegerType.static_typeid)
    def cast_int(v) -> Value:
        print("in caster", v.__class__.__name__)
        for kind, wrapper in wrappers:
            if isinstance(v, kind):
                return wrapper(v)

    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            op_results = Operation.create("custom.op1", results=[i32, i32]).results
            # Each index access goes through the caster again.
            # CHECK: in caster OpResult
            # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", op_results[0].result_number, op_results[0])
            # CHECK: in caster OpResult
            # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", op_results[1].result_number, op_results[1])

            # CHECK: results slice 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("results slice", op_results[:1][0].result_number, op_results[:1][0])

            first, second = op_results
            # CHECK: in caster OpResult
            # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", first.result_number, op_results[0])
            # CHECK: in caster OpResult
            # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", second.result_number, op_results[1])

            consumer = Operation.create("custom.op2", operands=[first, second])
            # CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> ()
            print(consumer)

            # Operands reach the caster as plain Values.
            # CHECK: in caster Value
            # CHECK: operand 0 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("operand 0", consumer.operands[0])
            # CHECK: in caster Value
            # CHECK: operand 1 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("operand 1", consumer.operands[1])

            # CHECK: in caster BlockArgument
            # CHECK: in caster BlockArgument
            @func.FuncOp.from_py_func(i32, i32)
            def reduction(arg0, arg1):
                # CHECK: as func arg 0 NOPBlockArg
                print("as func arg", arg0.arg_number, arg0.__class__.__name__)
                # CHECK: as func arg 1 NOPBlockArg
                print("as func arg", arg1.arg_number, arg1.__class__.__name__)

            # CHECK: args slice 0 NOPBlockArg(<block argument> of type 'i32' at index: 0)
            print(
                "args slice",
                reduction.func_op.arguments[:1][0].arg_number,
                reduction.func_op.arguments[:1][0],
            )

    # A second caster for the same type id is rejected unless replace=True.
    try:

        @register_value_caster(IntegerType.static_typeid)
        def dont_cast_int_shouldnt_register(v):
            ...

    except RuntimeError as err:
        # CHECK: Value caster is already registered: {{.*}}cast_int
        print(err)

    @register_value_caster(IntegerType.static_typeid, replace=True)
    def dont_cast_int(v) -> OpResult:
        assert isinstance(v, OpResult)
        print("don't cast", v.result_number, v)
        return v

    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            # CHECK: don't cast 0 Value(%0 = "custom.op1"() : () -> i32)
            res = Operation.create("custom.op1", results=[i32]).result
            # CHECK: result 0 Value(%0 = "custom.op1"() : () -> i32)
            print("result", res.result_number, res)

            # CHECK: don't cast 0 Value(%1 = "custom.op2"() : () -> i32)
            res = Operation.create("custom.op2", results=[i32]).results[0]
            # CHECK: result 0 Value(%1 = "custom.op2"() : () -> i32)
            print("result", res.result_number, res)
433