# RUN: %PYTHON %s 2>&1 | FileCheck %s

# End-to-end execution tests for linalg ops built through the Python OpDSL
# bindings: each test emits functions into a module, appends a textual MLIR
# `main` boilerplate that calls them, lowers everything to the LLVM dialect,
# JIT-executes `main` and logs its scalar result so FileCheck can match it.

import ctypes
import sys
from mlir.ir import *
from mlir.dialects import builtin
from mlir.dialects import func
from mlir.dialects import linalg
from mlir.passmanager import *
from mlir.execution_engine import *

from mlir.dialects.linalg.opdsl.lang import *


# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
    print(*args, file=sys.stderr)
    sys.stderr.flush()


elemwise_boiler = """
func.func @main() -> f32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0.0 : f32
  %v1 = arith.constant 1.0 : f32
  %v2 = arith.constant 2.0 : f32

  %lhs = memref.alloc() : memref<f32>
  %rhs = memref.alloc() : memref<4x8xf32>
  %O0 = memref.alloc() : memref<4x8xf32>
  %O1 = memref.alloc() : memref<4x8xf32>
  linalg.fill ins(%v1 : f32) outs(%lhs : memref<f32>)
  linalg.fill ins(%v2 : f32) outs(%rhs : memref<4x8xf32>)
  linalg.fill ins(%v0 : f32) outs(%O0 : memref<4x8xf32>)
  linalg.fill ins(%v0 : f32) outs(%O1 : memref<4x8xf32>)

  call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
  call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) :
    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()

  %c0 = arith.constant 0 : index
  %res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32>
  %res1 = memref.load %O1[%c0, %c0] : memref<4x8xf32>

  %0 = arith.addf %res0, %res1 : f32

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : f32
}
"""

fill_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O0 = memref.alloc() : memref<i32>
  %O1 = memref.alloc() : memref<16xi32>
  %O2 = memref.alloc() : memref<4x16xi32>

  %val0 = arith.constant 1.0 : f32
  %val1 = arith.constant 2.0 : f32
  %val2 = arith.constant 3.0 : f32

  call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> ()
  call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> ()
  call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> ()

  %c0 = arith.constant 0 : index
  %res0 = memref.load %O0[] : memref<i32>
  %c8 = arith.constant 8 : index
  %res1 = memref.load %O1[%c8] : memref<16xi32>
  %c2 = arith.constant 2 : index
  %res2 = memref.load %O2[%c2, %c8] : memref<4x16xi32>

  %0 = arith.addi %res0, %res1 : i32
  %1 = arith.addi %0, %res2 : i32

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %1 : i32
}
"""

fill_rng_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O = memref.alloc() : memref<4x16xi32>
  %min = arith.constant -1000.0 : f64
  %max = arith.constant 1000.0 : f64
  %seed = arith.constant 42 : i32

  call @fill_rng_on_buffers(%min, %max, %seed, %O) :
    (f64, f64, i32, memref<4x16xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %O[%c0, %c0] : memref<4x16xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""

# NOTE(review): no test in this file currently references conv_boiler.
conv_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v1 = arith.constant 1.0 : f64
  %v2 = arith.constant 2.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %filter = memref.alloc() : memref<2x2x1xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill ins(%v1 : f64) outs(%input : memref<1x4x16x1xf64>)
  linalg.fill ins(%v2 : f64) outs(%filter : memref<2x2x1xf64>)
  linalg.fill ins(%v0 : i32) outs(%output : memref<1x2x4x1xi32>)

  call @conv_on_buffers(%input, %filter, %output) :
    (memref<1x4x16x1xf64>, memref<2x2x1xf64>, memref<1x2x4x1xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""

pooling_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v42 = arith.constant 42.0 : f64
  %v77 = arith.constant 77.0 : f64
  %v-13 = arith.constant -13.0 : f64
  %v1 = arith.constant 1.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %shape = memref.alloc() : memref<2x2xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill ins(%v1 : f64) outs(%input : memref<1x4x16x1xf64>)
  linalg.fill ins(%v1 : f64) outs(%shape : memref<2x2xf64>)
  linalg.fill ins(%v0 : i32) outs(%output : memref<1x2x4x1xi32>)

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
  memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64>
  memref.store %v-13, %input[%c0, %c1, %c0, %c0] : memref<1x4x16x1xf64>

  call @pooling_on_buffers(%input, %shape, %output) :
    (memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()

  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""


def transform(module, boilerplate):
    """Combine the ops of `module` with `boilerplate` and lower to LLVM.

    Args:
      module: Module whose top-level ops are the functions under test.
      boilerplate: Textual MLIR (a `main` function) appended after those ops.

    Returns:
      A newly parsed module, lowered in place by the pass pipeline below.
    """
    # TODO: Allow cloning functions from one module to another.
    # Atm we have to resort to string concatenation.
    ops = module.operation.regions[0].blocks[0].operations
    mod = Module.parse("\n".join(str(op) for op in ops) + boilerplate)

    pm = PassManager("builtin.module")
    pm.add("func.func(convert-linalg-to-loops)")
    pm.add("func.func(lower-affine)")
    pm.add("func.func(convert-math-to-llvm)")
    pm.add("func.func(convert-scf-to-cf)")
    pm.add("func.func(arith-expand)")
    pm.add("func.func(memref-expand)")
    pm.add("convert-vector-to-llvm")
    pm.add("finalize-memref-to-llvm")
    pm.add("convert-func-to-llvm")
    pm.add("convert-arith-to-llvm")
    pm.add("convert-cf-to-llvm")
    pm.add("reconcile-unrealized-casts")
    pm.run(mod.operation)
    return mod


def _invoke_main_and_log(module, boilerplate, result_ctype, initial):
    """JIT `module` + `boilerplate`, invoke `main` and log its single result.

    Must be called while the MLIR `Context` used to build `module` is active.

    Args:
      module: Module holding the functions under test.
      boilerplate: Textual MLIR `main` calling those functions.
      result_ctype: ctypes scalar type matching `main`'s result.
      initial: Sentinel written into the result slot before the call.
    """
    execution_engine = ExecutionEngine(transform(module, boilerplate))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # Prepare arguments: one result scalar.
    # Arguments must be passed as pointers, hence a one-element ctypes array.
    res = (result_ctype * 1)(initial)
    execution_engine.invoke("main", res)

    log("RESULT: ", res[0])


def test_elemwise_builtin():
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                MemRefType.get((), f32),
                MemRefType.get((4, 8), f32),
                MemRefType.get((4, 8), f32),
            )
            def elemwise_exp_add_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(lhs, outs=[out])
                linalg.elemwise_binary(out, rhs, outs=[out])

            @func.FuncOp.from_py_func(
                MemRefType.get((), f32),
                MemRefType.get((4, 8), f32),
                MemRefType.get((4, 8), f32),
            )
            def elemwise_log_mul_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
                linalg.elemwise_binary(out, rhs, outs=[out], fun=BinaryFn.mul)

        _invoke_main_and_log(module, elemwise_boiler, ctypes.c_float, -1.0)
        # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
        # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
        # CHECK: RESULT: 4.71828


test_elemwise_builtin()


def test_elemwise_generic():
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                MemRefType.get((), f32),
                MemRefType.get((4, 8), f32),
                MemRefType.get((4, 8), f32),
            )
            def elemwise_exp_add_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
                linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)

            @func.FuncOp.from_py_func(
                MemRefType.get((), f32),
                MemRefType.get((4, 8), f32),
                MemRefType.get((4, 8), f32),
            )
            def elemwise_log_mul_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(
                    lhs, outs=[out], fun=UnaryFn.log, emit_generic=True
                )
                linalg.elemwise_binary(
                    out, rhs, outs=[out], fun=BinaryFn.mul, emit_generic=True
                )

        _invoke_main_and_log(module, elemwise_boiler, ctypes.c_float, -1.0)
        # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
        # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
        # CHECK: RESULT: 4.71828


test_elemwise_generic()


def test_fill_builtin():
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
            def fill_0d_on_buffers(value, out):
                linalg.fill(value, outs=[out])

            @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
            def fill_1d_on_buffers(value, out):
                linalg.fill(value, outs=[out])

            @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
            def fill_2d_on_buffers(value, out):
                linalg.fill(value, outs=[out])

        _invoke_main_and_log(module, fill_boiler, ctypes.c_int, -1)
        # CHECK: RESULT: 6


test_fill_builtin()


def test_fill_generic():
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
            def fill_0d_on_buffers(value, out):
                linalg.fill(value, outs=[out], emit_generic=True)

            @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
            def fill_1d_on_buffers(value, out):
                linalg.fill(value, outs=[out], emit_generic=True)

            @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
            def fill_2d_on_buffers(value, out):
                linalg.fill(value, outs=[out], emit_generic=True)

        _invoke_main_and_log(module, fill_boiler, ctypes.c_int, -1)
        # CHECK: RESULT: 6


test_fill_generic()


def test_fill_rng_builtin():
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
            def fill_rng_on_buffers(min, max, seed, out):
                linalg.fill_rng_2d(min, max, seed, outs=[out])

        _invoke_main_and_log(module, fill_rng_boiler, ctypes.c_int, -1)
        # CHECK: RESULT: -480


test_fill_rng_builtin()


def test_fill_rng_generic():
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
            def fill_rng_on_buffers(min, max, seed, out):
                linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True)

        _invoke_main_and_log(module, fill_rng_boiler, ctypes.c_int, -1)
        # CHECK: RESULT: -480


test_fill_rng_generic()


def test_max_pooling_builtin():
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                MemRefType.get((1, 4, 16, 1), f64),
                MemRefType.get((2, 2), f64),
                MemRefType.get((1, 2, 4, 1), i32),
            )
            def pooling_on_buffers(input, shape, output):
                linalg.pooling_nhwc_max(
                    input, shape, outs=[output], strides=[2, 4], dilations=[1, 2]
                )

        _invoke_main_and_log(module, pooling_boiler, ctypes.c_int, -1)
        # 77 is not selected due to the dilation 2 in the second dimension.
        # CHECK: RESULT: 42


test_max_pooling_builtin()


def test_max_pooling_generic():
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                MemRefType.get((1, 4, 16, 1), f64),
                MemRefType.get((2, 2), f64),
                MemRefType.get((1, 2, 4, 1), i32),
            )
            def pooling_on_buffers(input, shape, output):
                linalg.pooling_nhwc_max(
                    input,
                    shape,
                    outs=[output],
                    strides=[2, 4],
                    dilations=[1, 2],
                    emit_generic=True,
                )

        _invoke_main_and_log(module, pooling_boiler, ctypes.c_int, -1)
        # 77 is not selected due to the dilation 2 in the second dimension.
        # CHECK: RESULT: 42


test_max_pooling_generic()


def test_min_pooling_builtin():
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):

            # Set the strides and use the default dilations.
            @func.FuncOp.from_py_func(
                MemRefType.get((1, 4, 16, 1), f64),
                MemRefType.get((2, 2), f64),
                MemRefType.get((1, 2, 4, 1), i32),
            )
            def pooling_on_buffers(input, shape, output):
                linalg.pooling_nhwc_min(input, shape, outs=[output], strides=[2, 4])

        _invoke_main_and_log(module, pooling_boiler, ctypes.c_int, -1)
        # CHECK: RESULT: -13


test_min_pooling_builtin()


def test_min_pooling_generic():
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):

            # Set the strides and use the default dilations.
            @func.FuncOp.from_py_func(
                MemRefType.get((1, 4, 16, 1), f64),
                MemRefType.get((2, 2), f64),
                MemRefType.get((1, 2, 4, 1), i32),
            )
            def pooling_on_buffers(input, shape, output):
                linalg.pooling_nhwc_min(
                    input, shape, outs=[output], strides=[2, 4], emit_generic=True
                )

        _invoke_main_and_log(module, pooling_boiler, ctypes.c_int, -1)
        # CHECK: RESULT: -13


test_min_pooling_generic()