# RUN: %PYTHON %s | FileCheck %s

# Tests for the `func` dialect Python bindings. Each test prints to stdout and
# the output is verified by FileCheck against the directives embedded in the
# comments below, so the text and order of those comments is significant.

from mlir.ir import *
import mlir.dialects.builtin as builtin
import mlir.dialects.func as func
import numpy as np


def run(f):
    """Print a banner naming test function `f`, call it, and return it.

    Used as a decorator, so every decorated test executes as soon as it is
    defined. The banner line is what the per-test label directives match.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: testFromPyFunc
@run
def testFromPyFunc():
    """Build functions with the `FuncOp.from_py_func` decorator and print the module."""
    with Context() as ctx, Location.unknown() as loc:
        # Presumably needed because the unregistered "custom.op1" operation is
        # created further down -- confirm against the bindings documentation.
        ctx.allow_unregistered_dialects = True
        m = builtin.ModuleOp()
        f32 = F32Type.get()
        f64 = F64Type.get()
        with InsertionPoint(m.body):
            # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
            # CHECK: return %arg0 : f64
            @func.FuncOp.from_py_func(f64)
            def unary_return(a):
                return a

            # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
            # CHECK: return %arg0, %arg1 : f32, f64
            @func.FuncOp.from_py_func(f32, f64)
            def binary_return(a, b):
                return a, b

            # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
            # CHECK: return
            @func.FuncOp.from_py_func(f32, f64)
            def none_return(a, b):
                pass

            # Calling a previously wrapped function from inside another one is
            # expected to show up as a `call` op in the printed IR.
            # CHECK-LABEL: func @call_unary
            # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
            # CHECK: return %0 : f64
            @func.FuncOp.from_py_func(f64)
            def call_unary(a):
                return unary_return(a)

            # CHECK-LABEL: func @call_binary
            # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
            # CHECK: return %0#0, %0#1 : f32, f64
            @func.FuncOp.from_py_func(f32, f64)
            def call_binary(a, b):
                return binary_return(a, b)

            # We expect coercion of a single result operation to a returned value.
            # CHECK-LABEL: func @single_result_op
            # CHECK: %0 = "custom.op1"() : () -> f32
            # CHECK: return %0 : f32
            @func.FuncOp.from_py_func()
            def single_result_op():
                return Operation.create("custom.op1", results=[f32])

            # CHECK-LABEL: func @call_none
            # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
            # CHECK: return
            @func.FuncOp.from_py_func(f32, f64)
            def call_none(a, b):
                return none_return(a, b)

            ## Variants and optional feature tests.
            # The `name=` keyword overrides the Python function name in the
            # emitted symbol (the expected output uses @from_name_arg).
            # CHECK-LABEL: func @from_name_arg
            @func.FuncOp.from_py_func(f32, f64, name="from_name_arg")
            def explicit_name(a, b):
                return b

            # The next three variants each assert that they receive a
            # func.FuncOp through a `func_op` parameter, declared respectively
            # as an extra positional parameter, as a keyword default, and via
            # **kwargs. Their IR output is not matched by any directive.
            @func.FuncOp.from_py_func(f32, f64)
            def positional_func_op(a, b, func_op):
                assert isinstance(func_op, func.FuncOp)
                return b

            @func.FuncOp.from_py_func(f32, f64)
            def kw_func_op(a, b=None, func_op=None):
                assert isinstance(func_op, func.FuncOp)
                return b

            @func.FuncOp.from_py_func(f32, f64)
            def kwargs_func_op(a, b=None, **kwargs):
                assert isinstance(kwargs["func_op"], func.FuncOp)
                return b

            # With explicit `results=`, the body emits its own ReturnOp and the
            # Python function itself returns None.
            # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
            # CHECK: return %arg1 : f64
            @func.FuncOp.from_py_func(f32, f64, results=[f64])
            def explicit_results(a, b):
                func.ReturnOp([b])

    print(m)


# CHECK-LABEL: TEST: testFromPyFuncErrors
@run
def testFromPyFuncErrors():
    """Combine explicit `results=` with a value-returning function and print the error."""
    with Context() as ctx, Location.unknown() as loc:
        m = builtin.ModuleOp()
        f32 = F32Type.get()
        f64 = F64Type.get()
        with InsertionPoint(m.body):
            try:

                # Returning a value while also passing `results=` is the
                # misuse under test; the decorator is expected to raise.
                @func.FuncOp.from_py_func(f64, results=[f64])
                def unary_return(a):
                    return a

            except AssertionError as e:
                # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
                print(e)


# CHECK-LABEL: TEST: testBuildFuncOp
@run
def testBuildFuncOp():
    """Construct func.FuncOp directly and exercise its entry-block accessors.

    Prints the name, type and visibility of the first function, the messages
    of the two IndexError cases, and finally the whole module.
    """
    # The context is not entered here; it is passed explicitly to
    # Location.unknown instead.
    ctx = Context()
    with Location.unknown(ctx) as loc:
        m = builtin.ModuleOp()

        f32 = F32Type.get()
        tensor_type = RankedTensorType.get((2, 3, 4), f32)
        with InsertionPoint.at_block_begin(m.body):
            f = func.FuncOp(
                name="some_func",
                type=FunctionType.get(
                    inputs=[tensor_type, tensor_type], results=[tensor_type]
                ),
                visibility="nested",
            )
            # CHECK: Name is: "some_func"
            print("Name is: ", f.name)

            # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
            print("Type is: ", f.type)

            # CHECK: Visibility is: "nested"
            print("Visibility is: ", f.visibility)

            # No entry block has been added yet, so this access is expected to
            # raise; only the exception message matters, not `entry_block`.
            try:
                entry_block = f.entry_block
            except IndexError as e:
                # CHECK: External function does not have a body
                print(e)

            # Give the function a body that returns its first argument.
            with InsertionPoint(f.add_entry_block()):
                func.ReturnOp([f.entry_block.arguments[0]])
                pass

            # Adding a second entry block is expected to raise.
            try:
                f.add_entry_block()
            except IndexError as e:
                # CHECK: The function already has an entry block!
                print(e)

            # Try the callback builder and passing type as tuple.
            f = func.FuncOp(
                name="some_other_func",
                type=([tensor_type, tensor_type], [tensor_type]),
                visibility="nested",
                body_builder=lambda f: func.ReturnOp([f.entry_block.arguments[0]]),
            )

    # CHECK: module {
    # CHECK: func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
    # CHECK: return %arg0 : tensor<2x3x4xf32>
    # CHECK: }
    # CHECK: func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
    # CHECK: return %arg0 : tensor<2x3x4xf32>
    # CHECK: }
    print(m)


# CHECK-LABEL: TEST: testFuncArgumentAccess
@run
def testFuncArgumentAccess():
    """Set `arg_attrs` / `result_attrs` on functions and print them and the module."""
    with Context() as ctx, Location.unknown():
        # Presumably needed for the `custom_dialect.*` attribute names used
        # below -- confirm against the bindings documentation.
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        f32 = F32Type.get()
        f64 = F64Type.get()
        with InsertionPoint(module.body):
            f = func.FuncOp("some_func", ([f32, f32], [f32, f32]))
            with InsertionPoint(f.add_entry_block()):
                # Return both arguments unchanged.
                func.ReturnOp(f.arguments)
            # One DictAttr per argument, wrapped in an ArrayAttr.
            f.arg_attrs = ArrayAttr.get(
                [
                    DictAttr.get(
                        {
                            "custom_dialect.foo": StringAttr.get("bar"),
                            "custom_dialect.baz": UnitAttr.get(),
                        }
                    ),
                    DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])}),
                ]
            )
            # One DictAttr per result.
            f.result_attrs = ArrayAttr.get(
                [
                    DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
                    DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)}),
                ]
            )

            other = func.FuncOp("other_func", ([f32, f32], []))
            with InsertionPoint(other.add_entry_block()):
                func.ReturnOp([])
            # Here `arg_attrs` is assigned a plain Python list of DictAttr
            # rather than an ArrayAttr; the second entry is an empty DictAttr.
            other.arg_attrs = [
                DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
                DictAttr.get(),
            ]

    # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
    print(f.arg_attrs)

    # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
    print(f.result_attrs)

    # CHECK: func @some_func(
    # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
    # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
    # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
    # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
    # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
    #
    # CHECK: func @other_func(
    # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
    # CHECK: %{{.*}}: f32)
    print(module)


# CHECK-LABEL: testDenseElementsAttr
@run
def testDenseElementsAttr():
    """Print DenseElementsAttr values built from NumPy arrays with explicit types.

    Covers an element type alone, an element type plus `shape=`, a vector
    type, and an index-typed vector built from int64 data.
    """
    with Context(), Location.unknown():
        values = np.arange(4, dtype=np.int32)
        i32 = IntegerType.get_signless(32)
        print(DenseElementsAttr.get(values, type=i32))
        # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : tensor<4xi32>
        print(DenseElementsAttr.get(values, type=i32, shape=(2, 2)))
        # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
        print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32)))
        # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
        idx_values = np.arange(4, dtype=np.int64)
        idx_type = IndexType.get()
        print(DenseElementsAttr.get(idx_values, type=VectorType.get([4], idx_type)))
        # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : vector<4xindex>