xref: /llvm-project/mlir/test/python/dialects/func.py (revision d898ff650ae09e3ef942592aee2e87627f45d7c6)
123aa5a74SRiver Riddle# RUN: %PYTHON %s | FileCheck %s
223aa5a74SRiver Riddle
323aa5a74SRiver Riddlefrom mlir.ir import *
423aa5a74SRiver Riddlefrom mlir.dialects import arith
523aa5a74SRiver Riddlefrom mlir.dialects import builtin
623aa5a74SRiver Riddlefrom mlir.dialects import func
723aa5a74SRiver Riddle
823aa5a74SRiver Riddle
def constructAndPrintInModule(f):
    """Run a test body inside a fresh MLIR module and print the result.

    Intended as a decorator: ``f`` is executed once, at definition time.
    A ``TEST: <name>`` header is printed first so FileCheck ``CHECK-LABEL``
    directives can anchor on it. ``f`` then runs with an active ``Context``,
    an unknown default ``Location`` and an insertion point on the body of a
    newly created module, and the module IR is printed afterwards.

    Returns ``f`` unchanged, so the decorated name still refers to the test
    function itself.
    """
    print(f"\nTEST: {f.__name__}")
    with Context(), Location.unknown():
        mod = Module.create()
        body_ip = InsertionPoint(mod.body)
        with body_ip:
            f()
        # Printed while the context is still alive.
        print(mod)
    return f
1723aa5a74SRiver Riddle
1823aa5a74SRiver Riddle
1923aa5a74SRiver Riddle# CHECK-LABEL: TEST: testConstantOp
2023aa5a74SRiver Riddle
2123aa5a74SRiver Riddle
@constructAndPrintInModule
def testConstantOp():
    """Build scalar ``arith.constant`` ops and print each ``literal_value``.

    Covers i32, i64, f32 and f64. The f32 value is printed after rounding to
    single precision, hence the long expected decimal for 3.14.
    """
    # Ops are created (and inserted) in this exact order; the printed module
    # below relies on it.
    scalar_consts = [
        arith.ConstantOp(IntegerType.get_signless(32), 42),
        arith.ConstantOp(IntegerType.get_signless(64), 100),
        arith.ConstantOp(F32Type.get(), 3.14),
        arith.ConstantOp(F64Type.get(), 1.23),
    ]
    # CHECK: 42
    # CHECK: 100
    # CHECK: 3.140000104904175
    # CHECK: 1.23
    for const_op in scalar_consts:
        print(const_op.literal_value)
3923aa5a74SRiver Riddle
4023aa5a74SRiver Riddle
4123aa5a74SRiver Riddle# CHECK: = arith.constant 42 : i32
4223aa5a74SRiver Riddle# CHECK: = arith.constant 100 : i64
4323aa5a74SRiver Riddle# CHECK: = arith.constant 3.140000e+00 : f32
4423aa5a74SRiver Riddle# CHECK: = arith.constant 1.230000e+00 : f64
4523aa5a74SRiver Riddle
4623aa5a74SRiver Riddle
4723aa5a74SRiver Riddle# CHECK-LABEL: TEST: testVectorConstantOp
@constructAndPrintInModule
def testVectorConstantOp():
    """Check that a vector splat constant exposes no scalar ``literal_value``.

    Reading ``literal_value`` on a ``vector<2x2xi32>`` dense splat constant
    must raise ``ValueError`` (its message says only integer and float
    constants have literal values). The op itself is still built and shows
    up in the printed module.
    """
    int_type = IntegerType.get_signless(32)
    vec_type = VectorType.get([2, 2], int_type)
    c1 = arith.ConstantOp(
        vec_type, DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42))
    )
    try:
        # Keep the try body to the single access expected to raise.
        value = c1.literal_value
    except ValueError as e:
        assert "only integer and float constants have literal values" in str(e)
    else:
        # `assert False` is stripped under `python -O`; raise explicitly so a
        # missing exception can never pass silently, and report what came back.
        raise AssertionError(
            f"expected ValueError from literal_value, got {value!r}"
        )
6123aa5a74SRiver Riddle
6223aa5a74SRiver Riddle
6323aa5a74SRiver Riddle# CHECK: = arith.constant dense<42> : vector<2x2xi32>
6423aa5a74SRiver Riddle
6523aa5a74SRiver Riddle
6623aa5a74SRiver Riddle# CHECK-LABEL: TEST: testConstantIndexOp
@constructAndPrintInModule
def testConstantIndexOp():
    """Create an ``index`` constant via ``create_index`` and print its value."""
    # CHECK: 10
    print(arith.ConstantOp.create_index(10).literal_value)
7223aa5a74SRiver Riddle
7323aa5a74SRiver Riddle
7423aa5a74SRiver Riddle# CHECK: = arith.constant 10 : index
7523aa5a74SRiver Riddle
7623aa5a74SRiver Riddle
7723aa5a74SRiver Riddle# CHECK-LABEL: TEST: testFunctionCalls
@constructAndPrintInModule
def testFunctionCalls():
    """Declare private callees and call them through three ``CallOp`` forms.

    The calls inside ``@caller`` use three different callee spellings: a
    ``FuncOp`` object, a plain symbol-name string, and a
    ``FlatSymbolRefAttr``. A ``func.ConstantOp`` referring to ``@qux`` is also
    built and must carry ``qux``'s function type.
    """
    # Each declaration is created and immediately marked private, in this
    # order, so the printed module lists foo, bar, qux.
    callees = {}
    for sym, result_types in (
        ("foo", []),
        ("bar", [IndexType.get()]),
        ("qux", [F32Type.get()]),
    ):
        decl = func.FuncOp(sym, ([], result_types))
        decl.sym_visibility = StringAttr.get("private")
        callees[sym] = decl
    foo, qux = callees["foo"], callees["qux"]

    fn_ref = func.ConstantOp(qux.type, qux.sym_name.value)
    assert fn_ref.type == qux.type

    caller = func.FuncOp("caller", ([], []))
    with InsertionPoint(caller.add_entry_block()):
        func.CallOp(foo, [])
        func.CallOp([IndexType.get()], "bar", [])
        func.CallOp([F32Type.get()], FlatSymbolRefAttr.get("qux"), [])
        func.ReturnOp([])
9523aa5a74SRiver Riddle
9623aa5a74SRiver Riddle
9723aa5a74SRiver Riddle# CHECK: func private @foo()
9823aa5a74SRiver Riddle# CHECK: func private @bar() -> index
9923aa5a74SRiver Riddle# CHECK: func private @qux() -> f32
100dd473f1dSMaksim Levental# CHECK: %f = func.constant @qux : () -> f32
10123aa5a74SRiver Riddle# CHECK: func @caller() {
10223aa5a74SRiver Riddle# CHECK:   call @foo() : () -> ()
10323aa5a74SRiver Riddle# CHECK:   %0 = call @bar() : () -> index
10423aa5a74SRiver Riddle# CHECK:   %1 = call @qux() : () -> f32
10523aa5a74SRiver Riddle# CHECK:   return
10623aa5a74SRiver Riddle# CHECK: }
107*d898ff65SPerry Gibson
108*d898ff65SPerry Gibson
109*d898ff65SPerry Gibson# CHECK-LABEL: TEST: testFunctionArgAttrs
@constructAndPrintInModule
def testFunctionArgAttrs():
    """Read and overwrite per-argument attribute dictionaries on ``FuncOp``.

    A freshly created ``FuncOp`` reports one empty ``DictAttr`` per argument
    in ``arg_attrs``. Assigning a list of ``DictAttr`` replaces them, and the
    new attributes must appear in the printed signatures checked below.
    """
    foo = func.FuncOp("foo", ([F32Type.get()], []))
    foo.sym_visibility = StringAttr.get("private")
    foo2 = func.FuncOp("foo2", ([F32Type.get(), F32Type.get()], []))
    foo2.sym_visibility = StringAttr.get("private")

    empty_attr = DictAttr.get({})
    test_attr = DictAttr.get({"test.foo": StringAttr.get("bar")})
    test_attr2 = DictAttr.get({"test.baz": StringAttr.get("qux")})

    # Default: one empty dictionary per argument.
    assert len(foo.arg_attrs) == 1
    assert foo.arg_attrs[0] == empty_attr

    foo.arg_attrs = [test_attr]
    assert foo.arg_attrs[0]["test.foo"] == StringAttr.get("bar")

    # arg_attrs compares equal to an ArrayAttr of the per-argument dicts.
    assert len(foo2.arg_attrs) == 2
    assert foo2.arg_attrs == ArrayAttr.get([empty_attr, empty_attr])

    # Only the second argument gets a non-empty dictionary.
    foo2.arg_attrs = [empty_attr, test_attr2]
    assert foo2.arg_attrs == ArrayAttr.get([empty_attr, test_attr2])


# CHECK: func private @foo(f32 {test.foo = "bar"})
# CHECK: func private @foo2(f32, f32 {test.baz = "qux"})
136