# xref: /llvm-project/mlir/test/python/dialects/builtin.py (revision 5d3ae5161210c068d01ffba36c8e0761e9971179)
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
import mlir.dialects.builtin as builtin
import mlir.dialects.func as func
import numpy as np


def run(f):
    """Execute a test function immediately and hand it back unchanged.

    A banner of the form ``TEST: <name>`` (preceded by a blank line) is
    printed first so the FileCheck labels in this file can anchor on each
    test's output; ``f`` is then called with no arguments.  Returning ``f``
    lets this be used as a decorator without rebinding the decorated name.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    return f


# CHECK-LABEL: TEST: testFromPyFunc
@run
def testFromPyFunc():
    """Exercise ``func.FuncOp.from_py_func`` on a variety of Python callables.

    Each decorated inner function is materialized as a function op inside the
    module ``m``; the module is printed at the end and its textual IR is
    matched by the FileCheck directives interleaved below.  The Python names
    of the decorated functions show up as the printed symbol names (unless
    ``name=`` overrides them), so they must not be renamed.
    """
    with Context() as ctx, Location.unknown() as loc:
        # Allows the unregistered "custom.op1" operation created further down.
        ctx.allow_unregistered_dialects = True
        m = builtin.ModuleOp()
        f32 = F32Type.get()
        f64 = F64Type.get()
        with InsertionPoint(m.body):
            # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
            # CHECK: return %arg0 : f64
            @func.FuncOp.from_py_func(f64)
            def unary_return(a):
                return a

            # Returning a Python tuple produces a multi-result function.
            # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
            # CHECK: return %arg0, %arg1 : f32, f64
            @func.FuncOp.from_py_func(f32, f64)
            def binary_return(a, b):
                return a, b

            # Returning None produces a function with no results.
            # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
            # CHECK: return
            @func.FuncOp.from_py_func(f32, f64)
            def none_return(a, b):
                pass

            # Invoking an already-decorated function emits a call to it.
            # CHECK-LABEL: func @call_unary
            # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
            # CHECK: return %0 : f64
            @func.FuncOp.from_py_func(f64)
            def call_unary(a):
                return unary_return(a)

            # CHECK-LABEL: func @call_binary
            # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
            # CHECK: return %0#0, %0#1 : f32, f64
            @func.FuncOp.from_py_func(f32, f64)
            def call_binary(a, b):
                return binary_return(a, b)

            # We expect coercion of a single result operation to a returned value.
            # CHECK-LABEL: func @single_result_op
            # CHECK: %0 = "custom.op1"() : () -> f32
            # CHECK: return %0 : f32
            @func.FuncOp.from_py_func()
            def single_result_op():
                return Operation.create("custom.op1", results=[f32])

            # Returning the result of a zero-result call yields a bare return.
            # CHECK-LABEL: func @call_none
            # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
            # CHECK: return
            @func.FuncOp.from_py_func(f32, f64)
            def call_none(a, b):
                return none_return(a, b)

            ## Variants and optional feature tests.
            # ``name=`` overrides the symbol name taken from the Python function.
            # CHECK-LABEL: func @from_name_arg
            @func.FuncOp.from_py_func(f32, f64, name="from_name_arg")
            def explicit_name(a, b):
                return b

            # The next three variants have no FileCheck directives; they only
            # assert, inside the body, that the created FuncOp is handed in
            # positionally, as a keyword, or through **kwargs respectively.
            @func.FuncOp.from_py_func(f32, f64)
            def positional_func_op(a, b, func_op):
                assert isinstance(func_op, func.FuncOp)
                return b

            @func.FuncOp.from_py_func(f32, f64)
            def kw_func_op(a, b=None, func_op=None):
                assert isinstance(func_op, func.FuncOp)
                return b

            @func.FuncOp.from_py_func(f32, f64)
            def kwargs_func_op(a, b=None, **kwargs):
                assert isinstance(kwargs["func_op"], func.FuncOp)
                return b

            # With explicit ``results=``, the body builds its own terminator and
            # returns None.
            # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
            # CHECK: return %arg1 : f64
            @func.FuncOp.from_py_func(f32, f64, results=[f64])
            def explicit_results(a, b):
                func.ReturnOp([b])

    print(m)


# CHECK-LABEL: TEST: testFromPyFuncErrors
@run
def testFromPyFuncErrors():
    """Check the diagnostic for a value-returning body combined with ``results=``.

    Wrapping a function that returns a value while also passing explicit
    ``results=`` is expected to raise ``AssertionError``; its message is
    printed so FileCheck can match it.
    """
    with Context(), Location.unknown():
        module_op = builtin.ModuleOp()
        f32_type = F32Type.get()
        f64_type = F64Type.get()
        with InsertionPoint(module_op.body):
            try:

                @func.FuncOp.from_py_func(f64_type, results=[f64_type])
                def unary_return(a):
                    return a

            except AssertionError as err:
                # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
                print(err)


# CHECK-LABEL: TEST: testBuildFuncOp
@run
def testBuildFuncOp():
    """Build ``func.FuncOp`` objects directly and inspect their properties.

    Covers construction from an explicit ``FunctionType`` and from an
    ``(inputs, results)`` tuple with a ``body_builder`` callback, the
    ``name``/``type``/``visibility`` accessors, and the ``IndexError`` raised
    both by ``entry_block`` on a body-less function and by a second
    ``add_entry_block`` call.  The resulting module is printed at the end.
    """
    ctx = Context()
    with Location.unknown(ctx) as loc:
        m = builtin.ModuleOp()

        f32 = F32Type.get()
        tensor_type = RankedTensorType.get((2, 3, 4), f32)
        with InsertionPoint.at_block_begin(m.body):
            f = func.FuncOp(
                name="some_func",
                type=FunctionType.get(
                    inputs=[tensor_type, tensor_type], results=[tensor_type]
                ),
                visibility="nested",
            )
            # CHECK: Name is: "some_func"
            print("Name is: ", f.name)

            # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
            print("Type is: ", f.type)

            # CHECK: Visibility is: "nested"
            print("Visibility is: ", f.visibility)

            # No body has been added yet, so asking for the entry block raises.
            try:
                entry_block = f.entry_block
            except IndexError as e:
                # CHECK: External function does not have a body
                print(e)

            # Give the function a body that returns its first argument.
            with InsertionPoint(f.add_entry_block()):
                func.ReturnOp([f.entry_block.arguments[0]])
                pass

            # A second entry block cannot be added.
            try:
                f.add_entry_block()
            except IndexError as e:
                # CHECK: The function already has an entry block!
                print(e)

            # Try the callback builder and passing type as tuple.
            f = func.FuncOp(
                name="some_other_func",
                type=([tensor_type, tensor_type], [tensor_type]),
                visibility="nested",
                body_builder=lambda f: func.ReturnOp([f.entry_block.arguments[0]]),
            )

    # CHECK: module  {
    # CHECK:  func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
    # CHECK:   return %arg0 : tensor<2x3x4xf32>
    # CHECK:  }
    # CHECK:  func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
    # CHECK:   return %arg0 : tensor<2x3x4xf32>
    # CHECK:  }
    print(m)


# CHECK-LABEL: TEST: testFuncArgumentAccess
@run
def testFuncArgumentAccess():
    """Set and read back per-argument and per-result attribute dictionaries.

    ``some_func`` gets ``arg_attrs``/``result_attrs`` assigned as an
    ``ArrayAttr`` of ``DictAttr`` values, while ``other_func`` gets
    ``arg_attrs`` assigned from a plain Python list.  The attributes and the
    printed module are matched by the FileCheck directives below.
    """
    with Context() as ctx, Location.unknown():
        # Presumably enabled because the attribute names below use the
        # unregistered ``custom_dialect`` prefix.
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        f32 = F32Type.get()
        f64 = F64Type.get()
        with InsertionPoint(module.body):
            # ``some_func`` simply returns both of its arguments.
            f = func.FuncOp("some_func", ([f32, f32], [f32, f32]))
            with InsertionPoint(f.add_entry_block()):
                func.ReturnOp(f.arguments)
            # One dictionary per argument, in argument order.
            f.arg_attrs = ArrayAttr.get(
                [
                    DictAttr.get(
                        {
                            "custom_dialect.foo": StringAttr.get("bar"),
                            "custom_dialect.baz": UnitAttr.get(),
                        }
                    ),
                    DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])}),
                ]
            )
            # One dictionary per result, in result order.
            f.result_attrs = ArrayAttr.get(
                [
                    DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
                    DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)}),
                ]
            )

            other = func.FuncOp("other_func", ([f32, f32], []))
            with InsertionPoint(other.add_entry_block()):
                func.ReturnOp([])
            # Assigned from a Python list; the empty DictAttr leaves the second
            # argument without attributes in the printed IR.
            other.arg_attrs = [
                DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
                DictAttr.get(),
            ]

    # The attributes are read back after leaving the ``with`` block.
    # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
    print(f.arg_attrs)

    # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
    print(f.result_attrs)

    # CHECK: func @some_func(
    # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
    # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
    # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
    # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
    # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
    #
    # CHECK: func @other_func(
    # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
    # CHECK: %{{.*}}: f32)
    print(module)


# CHECK-LABEL: TEST: testDenseElementsAttr
@run
def testDenseElementsAttr():
    """Build ``DenseElementsAttr`` values from NumPy buffers with explicit types.

    Covers an element type with the shape taken from the 1-D input, an
    explicit ``shape=``, a full vector type passed as ``type=``, and an
    ``index``-typed vector built from int64 data.
    """
    with Context(), Location.unknown():
        values = np.arange(4, dtype=np.int32)
        i32 = IntegerType.get_signless(32)
        print(DenseElementsAttr.get(values, type=i32))
        # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : tensor<4xi32>
        print(DenseElementsAttr.get(values, type=i32, shape=(2, 2)))
        # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
        print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32)))
        # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
        # Here the ``index`` elements are supplied as int64 data.
        idx_values = np.arange(4, dtype=np.int64)
        idx_type = IndexType.get()
        print(DenseElementsAttr.get(idx_values, type=VectorType.get([4], idx_type)))
        # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : vector<4xindex>