# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
import mlir.dialects.builtin as builtin
import mlir.dialects.func as func
import numpy as np


def run(f):
    """Run test function `f` immediately after printing a "TEST: <name>" label.

    Used as a decorator so each test executes at import time; returns `f`
    unchanged.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: testFromPyFunc
@run
def testFromPyFunc():
    """Build functions via `func.FuncOp.from_py_func` and print the module."""
    with Context() as ctx, Location.unknown():
        # Needed for the "custom.op1" operation created below.
        ctx.allow_unregistered_dialects = True
        m = builtin.ModuleOp()
        f32 = F32Type.get()
        f64 = F64Type.get()
        with InsertionPoint(m.body):
            # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
            # CHECK: return %arg0 : f64
            @func.FuncOp.from_py_func(f64)
            def unary_return(a):
                return a

            # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
            # CHECK: return %arg0, %arg1 : f32, f64
            @func.FuncOp.from_py_func(f32, f64)
            def binary_return(a, b):
                return a, b

            # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
            # CHECK: return
            @func.FuncOp.from_py_func(f32, f64)
            def none_return(a, b):
                pass

            # CHECK-LABEL: func @call_unary
            # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
            # CHECK: return %0 : f64
            @func.FuncOp.from_py_func(f64)
            def call_unary(a):
                return unary_return(a)

            # CHECK-LABEL: func @call_binary
            # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
            # CHECK: return %0#0, %0#1 : f32, f64
            @func.FuncOp.from_py_func(f32, f64)
            def call_binary(a, b):
                return binary_return(a, b)

            # We expect coercion of a single result operation to a returned value.
            # CHECK-LABEL: func @single_result_op
            # CHECK: %0 = "custom.op1"() : () -> f32
            # CHECK: return %0 : f32
            @func.FuncOp.from_py_func()
            def single_result_op():
                return Operation.create("custom.op1", results=[f32])

            # CHECK-LABEL: func @call_none
            # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
            # CHECK: return
            @func.FuncOp.from_py_func(f32, f64)
            def call_none(a, b):
                return none_return(a, b)

            ## Variants and optional feature tests.
            # CHECK-LABEL: func @from_name_arg
            @func.FuncOp.from_py_func(f32, f64, name="from_name_arg")
            def explicit_name(a, b):
                return b

            # The created op may be received positionally as `func_op`...
            @func.FuncOp.from_py_func(f32, f64)
            def positional_func_op(a, b, func_op):
                assert isinstance(func_op, func.FuncOp)
                return b

            # ...as a keyword argument...
            @func.FuncOp.from_py_func(f32, f64)
            def kw_func_op(a, b=None, func_op=None):
                assert isinstance(func_op, func.FuncOp)
                return b

            # ...or through **kwargs.
            @func.FuncOp.from_py_func(f32, f64)
            def kwargs_func_op(a, b=None, **kwargs):
                assert isinstance(kwargs["func_op"], func.FuncOp)
                return b

            # With explicit `results=`, the body must emit its own return op.
            # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
            # CHECK: return %arg1 : f64
            @func.FuncOp.from_py_func(f32, f64, results=[f64])
            def explicit_results(a, b):
                func.ReturnOp([b])

        print(m)


# CHECK-LABEL: TEST: testFromPyFuncErrors
@run
def testFromPyFuncErrors():
    """Check that explicit `results=` plus a non-None Python return is rejected."""
    with Context(), Location.unknown():
        m = builtin.ModuleOp()
        f32 = F32Type.get()
        f64 = F64Type.get()
        with InsertionPoint(m.body):
            try:

                @func.FuncOp.from_py_func(f64, results=[f64])
                def unary_return(a):
                    return a

            except AssertionError as e:
                # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
                print(e)


# CHECK-LABEL: TEST: testBuildFuncOp
@run
def testBuildFuncOp():
    """Construct `func.FuncOp`s directly and exercise their accessors."""
    ctx = Context()
    with Location.unknown(ctx):
        m = builtin.ModuleOp()

        f32 = F32Type.get()
        tensor_type = RankedTensorType.get((2, 3, 4), f32)
        with InsertionPoint.at_block_begin(m.body):
            f = func.FuncOp(
                name="some_func",
                type=FunctionType.get(
                    inputs=[tensor_type, tensor_type], results=[tensor_type]
                ),
                visibility="nested",
            )
            # CHECK: Name is: "some_func"
            print("Name is: ", f.name)

            # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
            print("Type is: ", f.type)

            # CHECK: Visibility is: "nested"
            print("Visibility is: ", f.visibility)

            # A freshly created function has no body, so accessing the entry
            # block is expected to raise.
            try:
                _ = f.entry_block
            except IndexError as e:
                # CHECK: External function does not have a body
                print(e)

            with InsertionPoint(f.add_entry_block()):
                func.ReturnOp([f.entry_block.arguments[0]])

            # Adding a second entry block is expected to raise.
            try:
                f.add_entry_block()
            except IndexError as e:
                # CHECK: The function already has an entry block!
                print(e)

            # Try the callback builder and passing type as tuple.
            f = func.FuncOp(
                name="some_other_func",
                type=([tensor_type, tensor_type], [tensor_type]),
                visibility="nested",
                body_builder=lambda op: func.ReturnOp([op.entry_block.arguments[0]]),
            )

        # CHECK: module {
        # CHECK: func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
        # CHECK: return %arg0 : tensor<2x3x4xf32>
        # CHECK: }
        # CHECK: func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
        # CHECK: return %arg0 : tensor<2x3x4xf32>
        # CHECK: }
        print(m)


# CHECK-LABEL: TEST: testFuncArgumentAccess
@run
def testFuncArgumentAccess():
    """Set and read per-argument and per-result attribute dictionaries."""
    with Context() as ctx, Location.unknown():
        # Needed for the "custom_dialect.*" attribute names used below.
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        f32 = F32Type.get()
        f64 = F64Type.get()
        with InsertionPoint(module.body):
            f = func.FuncOp("some_func", ([f32, f32], [f32, f32]))
            with InsertionPoint(f.add_entry_block()):
                func.ReturnOp(f.arguments)
            f.arg_attrs = ArrayAttr.get(
                [
                    DictAttr.get(
                        {
                            "custom_dialect.foo": StringAttr.get("bar"),
                            "custom_dialect.baz": UnitAttr.get(),
                        }
                    ),
                    DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])}),
                ]
            )
            f.result_attrs = ArrayAttr.get(
                [
                    DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
                    DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)}),
                ]
            )

            other = func.FuncOp("other_func", ([f32, f32], []))
            with InsertionPoint(other.add_entry_block()):
                func.ReturnOp([])
            # Assigning a plain Python list (rather than an ArrayAttr) is also
            # exercised here.
            other.arg_attrs = [
                DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
                DictAttr.get(),
            ]

        # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
        print(f.arg_attrs)

        # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
        print(f.result_attrs)

        # CHECK: func @some_func(
        # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
        # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
        # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
        # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
        # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
        #
        # CHECK: func @other_func(
        # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
        # CHECK: %{{.*}}: f32)
        print(module)


# CHECK-LABEL: TEST: testDenseElementsAttr
@run
def testDenseElementsAttr():
    """Build DenseElementsAttr values from numpy arrays with explicit types."""
    with Context(), Location.unknown():
        values = np.arange(4, dtype=np.int32)
        i32 = IntegerType.get_signless(32)
        print(DenseElementsAttr.get(values, type=i32))
        # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : tensor<4xi32>
        print(DenseElementsAttr.get(values, type=i32, shape=(2, 2)))
        # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
        print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32)))
        # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
        idx_values = np.arange(4, dtype=np.int64)
        idx_type = IndexType.get()
        print(DenseElementsAttr.get(idx_values, type=VectorType.get([4], idx_type)))
        # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : vector<4xindex>