xref: /llvm-project/mlir/test/python/integration/dialects/linalg/opsrun.py (revision eb6c4197d5263ed2e086925b2b2f032a19442d2b)
16944f7daSTobias Gysi# RUN: %PYTHON %s 2>&1 | FileCheck %s
26944f7daSTobias Gysi
37e2174c2SStella Laurenzoimport ctypes
46944f7daSTobias Gysiimport sys
56944f7daSTobias Gysifrom mlir.ir import *
66944f7daSTobias Gysifrom mlir.dialects import builtin
723aa5a74SRiver Riddlefrom mlir.dialects import func
86944f7daSTobias Gysifrom mlir.dialects import linalg
96944f7daSTobias Gysifrom mlir.passmanager import *
106944f7daSTobias Gysifrom mlir.execution_engine import *
116944f7daSTobias Gysi
1251fdd802Sgysitfrom mlir.dialects.linalg.opdsl.lang import *
1351fdd802Sgysit
146944f7daSTobias Gysi
# All diagnostics go to stderr and are flushed immediately, so that lines
# printed from Python interleave correctly with whatever MLIR itself writes to
# stderr (the RUN line merges both streams before FileCheck sees them).
def log(*args):
    """Print ``args`` (space-separated, newline-terminated) to stderr and flush."""
    stream = sys.stderr
    print(*args, file=stream)
    stream.flush()
206944f7daSTobias Gysi
216944f7daSTobias Gysi
# MLIR driver appended by `transform` to the Python-built elementwise
# functions. `main` fills a 0-d f32 buffer with 1.0 and a 4x8 f32 buffer with
# 2.0, zeroes two 4x8 outputs, and calls `elemwise_exp_add_on_buffers` and
# `elemwise_log_mul_on_buffers`; the Python test must define both. It then
# returns the sum of element [0, 0] of the two outputs as an f32.
elemwise_boiler = """
func.func @main() -> f32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0.0 : f32
  %v1 = arith.constant 1.0 : f32
  %v2 = arith.constant 2.0 : f32

  %lhs = memref.alloc() : memref<f32>
  %rhs = memref.alloc() : memref<4x8xf32>
  %O0 = memref.alloc() : memref<4x8xf32>
  %O1 = memref.alloc() : memref<4x8xf32>
  linalg.fill ins(%v1 : f32) outs(%lhs : memref<f32>)
  linalg.fill ins(%v2 : f32) outs(%rhs : memref<4x8xf32>)
  linalg.fill ins(%v0 : f32) outs(%O0 : memref<4x8xf32>)
  linalg.fill ins(%v0 : f32) outs(%O1 : memref<4x8xf32>)

  call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
  call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) :
    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()

  %c0 = arith.constant 0 : index
  %res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32>
  %res1 = memref.load %O1[%c0, %c0] : memref<4x8xf32>

  %0 = arith.addf %res0, %res1 : f32

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : f32
}
"""
5224357fecSgysit
# MLIR driver for the fill tests. `main` calls `fill_0d_on_buffers`,
# `fill_1d_on_buffers` and `fill_2d_on_buffers` with the f32 values 1.0, 2.0
# and 3.0, targeting i32 buffers of shape [], [16] and [4, 16]. It then returns
# the i32 sum of O0[], O1[8] and O2[2, 8]. (`%c0` is defined but never used.)
fill_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O0 = memref.alloc() : memref<i32>
  %O1 = memref.alloc() : memref<16xi32>
  %O2 = memref.alloc() : memref<4x16xi32>

  %val0 = arith.constant 1.0 : f32
  %val1 = arith.constant 2.0 : f32
  %val2 = arith.constant 3.0 : f32

  call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> ()
  call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> ()
  call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> ()

  %c0 = arith.constant 0 : index
  %res0 = memref.load %O0[] : memref<i32>
  %c8 = arith.constant 8 : index
  %res1 = memref.load %O1[%c8] : memref<16xi32>
  %c2 = arith.constant 2 : index
  %res2 = memref.load %O2[%c2, %c8] : memref<4x16xi32>

  %0 = arith.addi %res0, %res1 : i32
  %1 = arith.addi %0, %res2 : i32

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %1 : i32
}
"""
81a3655de2Sgysit
# MLIR driver for the random-fill tests. `main` calls `fill_rng_on_buffers`
# with min -1000.0 and max 1000.0 (both f64), seed 42 (i32) and a 4x16 i32
# buffer, then returns element [0, 0] of that buffer.
fill_rng_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O = memref.alloc() : memref<4x16xi32>
  %min = arith.constant -1000.0 : f64
  %max = arith.constant 1000.0 : f64
  %seed = arith.constant 42 : i32

  call @fill_rng_on_buffers(%min, %max, %seed, %O) :
    (f64, f64, i32, memref<4x16xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %O[%c0, %c0] : memref<4x16xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""
996944f7daSTobias Gysi
# MLIR driver for a convolution test. `main` fills a 1x4x16x1 f64 input with
# 1.0 and a 2x2x1 f64 filter with 2.0, zeroes a 1x2x4x1 i32 output, calls
# `conv_on_buffers`, and returns output[0, 0, 0, 0].
# NOTE(review): no test in the code visible here defines `conv_on_buffers` or
# passes `conv_boiler` to `transform`, so this template looks unused. Confirm
# before relying on it or removing it.
conv_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v1 = arith.constant 1.0 : f64
  %v2 = arith.constant 2.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %filter = memref.alloc() : memref<2x2x1xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill ins(%v1 : f64) outs(%input : memref<1x4x16x1xf64>)
  linalg.fill ins(%v2 : f64) outs(%filter : memref<2x2x1xf64>)
  linalg.fill ins(%v0 : i32) outs(%output : memref<1x2x4x1xi32>)

  call @conv_on_buffers(%input, %filter, %output) :
    (memref<1x4x16x1xf64>, memref<2x2x1xf64>, memref<1x2x4x1xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""
1236944f7daSTobias Gysi
# MLIR driver for the pooling tests. `main` fills a 1x4x16x1 f64 input and a
# 2x2 f64 `%shape` buffer with 1.0 and zeroes a 1x2x4x1 i32 output. It then
# overwrites three input elements: 42.0 at [0, 0, 0, 0], 77.0 at [0, 0, 1, 0]
# and -13.0 at [0, 1, 0, 0]. Finally it calls `pooling_on_buffers` and returns
# output[0, 0, 0, 0].
pooling_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v42 = arith.constant 42.0 : f64
  %v77 = arith.constant 77.0 : f64
  %v-13 = arith.constant -13.0 : f64
  %v1 = arith.constant 1.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %shape = memref.alloc() : memref<2x2xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill ins(%v1 : f64) outs(%input : memref<1x4x16x1xf64>)
  linalg.fill ins(%v1 : f64) outs(%shape : memref<2x2xf64>)
  linalg.fill ins(%v0 : i32) outs(%output : memref<1x2x4x1xi32>)

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
  memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64>
  memref.store %v-13, %input[%c0, %c1, %c0, %c0] : memref<1x4x16x1xf64>

  call @pooling_on_buffers(%input, %shape, %output) :
    (memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()

  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""
1556944f7daSTobias Gysi
1566944f7daSTobias Gysi
def transform(module, boilerplate):
    """Merge ``module``'s top-level ops with ``boilerplate`` and lower to LLVM.

    The ops in ``module``'s body are printed, concatenated with the MLIR text
    in ``boilerplate``, and re-parsed as a fresh module. That module is then
    run through a fixed lowering pipeline. Returns the lowered module; the
    tests hand it to ``ExecutionEngine``.
    """
    # TODO: Allow cloning functions from one module to another.
    # Atm we have to resort to string concatenation.
    body_ops = module.operation.regions[0].blocks[0].operations
    printed = "\n".join(str(op) for op in body_ops)
    lowered = Module.parse(printed + boilerplate)

    # Pipeline in application order. Entries wrapped in "func.func(...)" are
    # nested passes anchored on func.func operations; the rest run on the
    # whole module.
    pipeline = (
        "func.func(convert-linalg-to-loops)",
        "func.func(lower-affine)",
        "func.func(convert-math-to-llvm)",
        "func.func(convert-scf-to-cf)",
        "func.func(arith-expand)",
        "func.func(memref-expand)",
        "convert-vector-to-llvm",
        "finalize-memref-to-llvm",
        "convert-func-to-llvm",
        "convert-arith-to-llvm",
        "convert-cf-to-llvm",
        "reconcile-unrealized-casts",
    )
    pm = PassManager("builtin.module")
    for pass_spec in pipeline:
        pm.add(pass_spec)
    pm.run(lowered.operation)
    return lowered
1786944f7daSTobias Gysi
1796944f7daSTobias Gysi
def test_elemwise_builtin():
    """Lower and execute the named ``linalg.elemwise_unary``/``_binary`` ops.

    Two functions are built and linked with ``elemwise_boiler``, whose
    ``main`` returns the sum of element [0, 0] of both result buffers:

    * ``elemwise_exp_add_on_buffers`` uses the default unary/binary functions
      (exp, then add; see the expected-value comment below).
    * ``elemwise_log_mul_on_buffers`` selects ``UnaryFn.log`` and
      ``BinaryFn.mul``.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        # Both functions share one signature: a 0-d f32 buffer, a 4x8 operand
        # and a 4x8 output.
        buffer_types = (
            MemRefType.get((), f32),
            MemRefType.get((4, 8), f32),
            MemRefType.get((4, 8), f32),
        )
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(*buffer_types)
            def elemwise_exp_add_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(lhs, outs=[out])
                linalg.elemwise_binary(out, rhs, outs=[out])

            @func.FuncOp.from_py_func(*buffer_types)
            def elemwise_log_mul_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
                linalg.elemwise_binary(out, rhs, outs=[out], fun=BinaryFn.mul)

        execution_engine = ExecutionEngine(transform(module, elemwise_boiler))

        # TODO: FFI-based solution to allow testing and printing with python code.
        # Prepare arguments: one result f32.
        # Arguments must be passed as pointers.
        c_float_p = ctypes.c_float * 1
        res = c_float_p(-1.0)
        execution_engine.invoke("main", res)

        log("RESULT: ", res[0])
        # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
        # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
        # CHECK: RESULT: 4.71828


test_elemwise_builtin()
22124357fecSgysit
22224357fecSgysit
def test_elemwise_generic():
    """Same as ``test_elemwise_builtin`` but emits ``linalg.generic`` ops.

    Every elementwise op is built with ``emit_generic=True``. The result must
    match the named-op variant: ``main`` in ``elemwise_boiler`` returns the sum
    of element [0, 0] of both result buffers.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        # Both functions share one signature: a 0-d f32 buffer, a 4x8 operand
        # and a 4x8 output.
        buffer_types = (
            MemRefType.get((), f32),
            MemRefType.get((4, 8), f32),
            MemRefType.get((4, 8), f32),
        )
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(*buffer_types)
            def elemwise_exp_add_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
                linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)

            @func.FuncOp.from_py_func(*buffer_types)
            def elemwise_log_mul_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(
                    lhs, outs=[out], fun=UnaryFn.log, emit_generic=True
                )
                linalg.elemwise_binary(
                    out, rhs, outs=[out], fun=BinaryFn.mul, emit_generic=True
                )

        execution_engine = ExecutionEngine(transform(module, elemwise_boiler))

        # TODO: FFI-based solution to allow testing and printing with python code.
        # Prepare arguments: one result f32.
        # Arguments must be passed as pointers.
        c_float_p = ctypes.c_float * 1
        res = c_float_p(-1.0)
        execution_engine.invoke("main", res)

        log("RESULT: ", res[0])
        # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
        # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
        # CHECK: RESULT: 4.71828


test_elemwise_generic()
26824357fecSgysit
26924357fecSgysit
def test_fill_builtin():
    """Lower the named ``linalg.fill`` op on 0-d, 1-d and 2-d buffers and run it.

    Each fill function takes an f32 value and an i32 buffer. ``main`` in
    ``fill_boiler`` fills with 1.0, 2.0 and 3.0 and returns the sum of one
    element from each buffer.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
            def fill_0d_on_buffers(value, out):
                linalg.fill(value, outs=[out])

            @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
            def fill_1d_on_buffers(value, out):
                linalg.fill(value, outs=[out])

            @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
            def fill_2d_on_buffers(value, out):
                linalg.fill(value, outs=[out])

        lowered = transform(module, fill_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # main's single i32 result is returned through a pointer, so pass a
        # one-element ctypes array pre-set to -1.
        result = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result)

        log("RESULT: ", result[0])
        # 1 + 2 + 3, one element from each filled buffer.
        # CHECK: RESULT: 6


test_fill_builtin()
3036944f7daSTobias Gysi
3046944f7daSTobias Gysi
def test_fill_generic():
    """Same as ``test_fill_builtin`` but each fill is emitted as ``linalg.generic``.

    ``main`` in ``fill_boiler`` fills the 0-d, 1-d and 2-d i32 buffers with
    1.0, 2.0 and 3.0 and returns the sum of one element from each.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        i32 = IntegerType.get_signless(32)
        scalar_buf = MemRefType.get([], i32)
        vector_buf = MemRefType.get([16], i32)
        matrix_buf = MemRefType.get([4, 16], i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f32, scalar_buf)
            def fill_0d_on_buffers(value, out):
                linalg.fill(value, outs=[out], emit_generic=True)

            @func.FuncOp.from_py_func(f32, vector_buf)
            def fill_1d_on_buffers(value, out):
                linalg.fill(value, outs=[out], emit_generic=True)

            @func.FuncOp.from_py_func(f32, matrix_buf)
            def fill_2d_on_buffers(value, out):
                linalg.fill(value, outs=[out], emit_generic=True)

        engine = ExecutionEngine(transform(module, fill_boiler))

        # TODO: FFI-based solution to allow testing and printing with python code.
        # The i32 result comes back through a pointer to a one-element array.
        result = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result)

        log("RESULT: ", result[0])
        # CHECK: RESULT: 6


test_fill_generic()
338a3655de2Sgysit
339a3655de2Sgysit
def test_fill_rng_builtin():
    """Lower the named ``linalg.fill_rng_2d`` op and run it through ``main``.

    ``fill_rng_boiler`` calls ``fill_rng_on_buffers`` with bounds -1000.0 and
    1000.0, seed 42 and a 4x16 i32 buffer, then returns element [0, 0].
    """
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        out_type = MemRefType.get((4, 16), i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f64, f64, i32, out_type)
            def fill_rng_on_buffers(min_value, max_value, seed, out):
                linalg.fill_rng_2d(min_value, max_value, seed, outs=[out])

        engine = ExecutionEngine(transform(module, fill_rng_boiler))

        # TODO: FFI-based solution to allow testing and printing with python code.
        # The i32 result comes back through a pointer to a one-element array.
        result = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result)

        log("RESULT: ", result[0])
        # CHECK: RESULT: -480


test_fill_rng_builtin()
365a3655de2Sgysit
366a3655de2Sgysit
def test_fill_rng_generic():
    """Same as ``test_fill_rng_builtin`` but emits the fill as ``linalg.generic``.

    The generic form must produce the same element [0, 0] as the named op.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        out_type = MemRefType.get((4, 16), i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f64, f64, i32, out_type)
            def fill_rng_on_buffers(min_value, max_value, seed, out):
                linalg.fill_rng_2d(
                    min_value, max_value, seed, outs=[out], emit_generic=True
                )

        lowered = transform(module, fill_rng_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # main writes its i32 result through this one-element array.
        result = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result)

        log("RESULT: ", result[0])
        # CHECK: RESULT: -480


test_fill_rng_generic()
3926944f7daSTobias Gysi
3936944f7daSTobias Gysi
def test_max_pooling_builtin():
    """Lower the named ``linalg.pooling_nhwc_max`` op and run it via ``main``.

    Uses strides [2, 4] and dilations [1, 2] over the input prepared by
    ``pooling_boiler``. ``main`` returns output element [0, 0, 0, 0].
    """
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        src_type = MemRefType.get((1, 4, 16, 1), f64)
        shape_type = MemRefType.get((2, 2), f64)
        dst_type = MemRefType.get((1, 2, 4, 1), i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(src_type, shape_type, dst_type)
            def pooling_on_buffers(src, shape, dst):
                linalg.pooling_nhwc_max(
                    src, shape, outs=[dst], strides=[2, 4], dilations=[1, 2]
                )

        engine = ExecutionEngine(transform(module, pooling_boiler))

        # TODO: FFI-based solution to allow testing and printing with python code.
        # main writes its i32 result through this one-element array.
        result = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result)

        log("RESULT: ", result[0])
        # With dilation 2 along the window's second dimension, the 77 stored at
        # w=1 is skipped, so the maximum is 42.
        # CHECK: RESULT: 42


test_max_pooling_builtin()
4266944f7daSTobias Gysi
4276944f7daSTobias Gysi
def test_max_pooling_generic():
    """Same as ``test_max_pooling_builtin`` but emits ``linalg.generic``.

    Strides [2, 4] and dilations [1, 2] are unchanged. The generic lowering
    must yield the same output element [0, 0, 0, 0].
    """
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        src_type = MemRefType.get((1, 4, 16, 1), f64)
        shape_type = MemRefType.get((2, 2), f64)
        dst_type = MemRefType.get((1, 2, 4, 1), i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(src_type, shape_type, dst_type)
            def pooling_on_buffers(src, shape, dst):
                linalg.pooling_nhwc_max(
                    src,
                    shape,
                    outs=[dst],
                    strides=[2, 4],
                    dilations=[1, 2],
                    emit_generic=True,
                )

        lowered = transform(module, pooling_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # main writes its i32 result through this one-element array.
        result = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result)

        log("RESULT: ", result[0])
        # 77 is not selected due to the dilation 2 in the second dimension.
        # CHECK: RESULT: 42


test_max_pooling_generic()
465f239026fSTobias Gysi
466f239026fSTobias Gysi
def test_min_pooling_builtin():
    """Lower the named ``linalg.pooling_nhwc_min`` op and run it via ``main``.

    Only strides [2, 4] are given; dilations are left at the op's default.
    ``main`` in ``pooling_boiler`` returns output element [0, 0, 0, 0].
    """
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        src_type = MemRefType.get((1, 4, 16, 1), f64)
        shape_type = MemRefType.get((2, 2), f64)
        dst_type = MemRefType.get((1, 2, 4, 1), i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(src_type, shape_type, dst_type)
            def pooling_on_buffers(src, shape, dst):
                # Set the strides and use the default dilations.
                linalg.pooling_nhwc_min(src, shape, outs=[dst], strides=[2, 4])

        engine = ExecutionEngine(transform(module, pooling_boiler))

        # TODO: FFI-based solution to allow testing and printing with python code.
        # main writes its i32 result through this one-element array.
        result = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result)

        log("RESULT: ", result[0])
        # -13, stored at input [0, 1, 0, 0], is the smallest value in the
        # first window.
        # CHECK: RESULT: -13


test_min_pooling_builtin()
497f239026fSTobias Gysi
498f239026fSTobias Gysi
def test_min_pooling_generic():
    """Same as ``test_min_pooling_builtin`` but emits ``linalg.generic``.

    Strides [2, 4] with default dilations. The generic lowering must yield the
    same output element [0, 0, 0, 0].
    """
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        src_type = MemRefType.get((1, 4, 16, 1), f64)
        shape_type = MemRefType.get((2, 2), f64)
        dst_type = MemRefType.get((1, 2, 4, 1), i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(src_type, shape_type, dst_type)
            def pooling_on_buffers(src, shape, dst):
                # Set the strides and use the default dilations.
                linalg.pooling_nhwc_min(
                    src, shape, outs=[dst], strides=[2, 4], emit_generic=True
                )

        lowered = transform(module, pooling_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # main writes its i32 result through this one-element array.
        result = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result)

        log("RESULT: ", result[0])
        # CHECK: RESULT: -13


test_min_pooling_generic()
531