xref: /llvm-project/mlir/test/python/dialects/func.py (revision d898ff650ae09e3ef942592aee2e87627f45d7c6)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import arith
5from mlir.dialects import builtin
6from mlir.dialects import func
7
8
def constructAndPrintInModule(f):
    # Decorator used by every test in this file: it announces the test name,
    # runs `f` right away with the insertion point set to a fresh module's
    # body, prints the resulting module, and hands `f` back unchanged.
    print("\nTEST:", f.__name__)
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                # Ops created by `f` are inserted into the module body.
                f()
            print(mod)
    return f
17
18
19# CHECK-LABEL: TEST: testConstantOp
20
21
@constructAndPrintInModule
def testConstantOp():
    # Creates one arith constant per scalar type (i32, i64, f32, f64) and
    # prints the `literal_value` of each op in creation order.
    i32 = IntegerType.get_signless(32)
    i64 = IntegerType.get_signless(64)
    int32_const = arith.ConstantOp(i32, 42)
    int64_const = arith.ConstantOp(i64, 100)
    float32_const = arith.ConstantOp(F32Type.get(), 3.14)
    float64_const = arith.ConstantOp(F64Type.get(), 1.23)

    # CHECK: 42
    print(int32_const.literal_value)

    # CHECK: 100
    print(int64_const.literal_value)

    # CHECK: 3.140000104904175
    print(float32_const.literal_value)

    # CHECK: 1.23
    print(float64_const.literal_value)
39
40
41# CHECK: = arith.constant 42 : i32
42# CHECK: = arith.constant 100 : i64
43# CHECK: = arith.constant 3.140000e+00 : f32
44# CHECK: = arith.constant 1.230000e+00 : f64
45
46
47# CHECK-LABEL: TEST: testVectorConstantOp
@constructAndPrintInModule
def testVectorConstantOp():
    # Builds a 2x2 vector-of-i32 splat constant and verifies that reading
    # `literal_value` on it raises ValueError with the expected message.
    i32 = IntegerType.get_signless(32)
    vec2x2 = VectorType.get([2, 2], i32)
    splat_attr = DenseElementsAttr.get_splat(vec2x2, IntegerAttr.get(i32, 42))
    vector_const = arith.ConstantOp(vec2x2, splat_attr)
    try:
        print(vector_const.literal_value)
    except ValueError as err:
        assert "only integer and float constants have literal values" in str(err)
    else:
        # Reaching this branch means no exception was raised, which is a failure.
        assert False
61
62
63# CHECK: = arith.constant dense<42> : vector<2x2xi32>
64
65
66# CHECK-LABEL: TEST: testConstantIndexOp
@constructAndPrintInModule
def testConstantIndexOp():
    # Creates a constant through the `create_index` helper and prints the
    # value it reports as its literal.
    index_const = arith.ConstantOp.create_index(10)
    literal = index_const.literal_value
    # CHECK: 10
    print(literal)
72
73
74# CHECK: = arith.constant 10 : index
75
76
77# CHECK-LABEL: TEST: testFunctionCalls
@constructAndPrintInModule
def testFunctionCalls():
    # Declares three private functions, takes a func.constant reference to
    # one of them, then builds a `caller` function whose body calls each
    # callee using a different way of naming it (op, string, symbol ref attr).
    private = StringAttr.get("private")
    index_type = IndexType.get()
    f32 = F32Type.get()

    foo = func.FuncOp("foo", ([], []))
    foo.sym_visibility = private
    bar = func.FuncOp("bar", ([], [index_type]))
    bar.sym_visibility = private
    qux = func.FuncOp("qux", ([], [f32]))
    qux.sym_visibility = private

    qux_ref = func.ConstantOp(qux.type, qux.sym_name.value)
    assert qux_ref.type == qux.type

    caller = func.FuncOp("caller", ([], []))
    entry_block = caller.add_entry_block()
    with InsertionPoint(entry_block):
        func.CallOp(foo, [])
        func.CallOp([index_type], "bar", [])
        func.CallOp([f32], FlatSymbolRefAttr.get("qux"), [])
        func.ReturnOp([])
95
96
97# CHECK: func private @foo()
98# CHECK: func private @bar() -> index
99# CHECK: func private @qux() -> f32
100# CHECK: %f = func.constant @qux : () -> f32
101# CHECK: func @caller() {
102# CHECK:   call @foo() : () -> ()
103# CHECK:   %0 = call @bar() : () -> index
104# CHECK:   %1 = call @qux() : () -> f32
105# CHECK:   return
106# CHECK: }
107
108
109# CHECK-LABEL: TEST: testFunctionArgAttrs
@constructAndPrintInModule
def testFunctionArgAttrs():
    # Exercises the `arg_attrs` accessor on FuncOp: checks the initial
    # per-argument attribute dictionaries, then assigns new ones and reads
    # them back.
    f32 = F32Type.get()
    private = StringAttr.get("private")

    one_arg_fn = func.FuncOp("foo", ([f32], []))
    one_arg_fn.sym_visibility = private
    two_arg_fn = func.FuncOp("foo2", ([f32, f32], []))
    two_arg_fn.sym_visibility = private

    no_attrs = DictAttr.get({})
    foo_bar_attrs = DictAttr.get({"test.foo": StringAttr.get("bar")})
    baz_qux_attrs = DictAttr.get({"test.baz": StringAttr.get("qux")})

    # Before any assignment, each argument reports an empty dictionary.
    assert len(one_arg_fn.arg_attrs) == 1
    assert one_arg_fn.arg_attrs[0] == no_attrs

    one_arg_fn.arg_attrs = [foo_bar_attrs]
    assert one_arg_fn.arg_attrs[0]["test.foo"] == StringAttr.get("bar")

    assert len(two_arg_fn.arg_attrs) == 2
    assert two_arg_fn.arg_attrs == ArrayAttr.get([no_attrs, no_attrs])

    # Only the second argument receives attributes here.
    two_arg_fn.arg_attrs = [no_attrs, baz_qux_attrs]
    assert two_arg_fn.arg_attrs == ArrayAttr.get([no_attrs, baz_qux_attrs])
133
134# CHECK: func private @foo(f32 {test.foo = "bar"})
135# CHECK: func private @foo2(f32, f32  {test.baz = "qux"})
136