xref: /llvm-project/mlir/test/python/execution_engine.py (revision ba44d7ba1fb3e27f51d65ea1af280e00382e09e0)
14eee0cfcSTulio Magno Quites Machado Filho# RUN: env MLIR_RUNNER_UTILS=%mlir_runner_utils MLIR_C_RUNNER_UTILS=%mlir_c_runner_utils %PYTHON %s 2>&1 | FileCheck %s
2ca98e0ddSRainer Orth# REQUIRES: host-supports-jit
395c083f5SDenys Shabalinimport gc, sys, os, tempfile
49f3f6d7bSStella Laurenzofrom mlir.ir import *
59f3f6d7bSStella Laurenzofrom mlir.passmanager import *
69f3f6d7bSStella Laurenzofrom mlir.execution_engine import *
79f3f6d7bSStella Laurenzofrom mlir.runtime import *
834d50721SKonrad Kleine
try:
    from ml_dtypes import bfloat16, float8_e5m2

    HAS_ML_DTYPES = True
except ImportError:
    # Catch ImportError rather than only its subclass ModuleNotFoundError: an
    # installed ml_dtypes that lacks one of these names (or fails while being
    # imported) raises a plain ImportError, and that should disable the tests
    # gated on HAS_ML_DTYPES instead of aborting the whole file.
    HAS_ML_DTYPES = False
1534d50721SKonrad Kleine
1634d50721SKonrad Kleine
# Paths to the MLIR runner-utility shared libraries. The lit RUN line passes
# them in through the environment; otherwise fall back to a relative path
# (presumably into the build tree's lib directory -- confirm if relocating).
MLIR_RUNNER_UTILS = os.environ.get(
    "MLIR_RUNNER_UTILS", "../../../../lib/libmlir_runner_utils.so"
)
MLIR_C_RUNNER_UTILS = os.environ.get(
    "MLIR_C_RUNNER_UTILS", "../../../../lib/libmlir_c_runner_utils.so"
)
23a54f4eaeSMogball
249f3f6d7bSStella Laurenzo# Log everything to stderr and flush so that we have a unified stream to match
259f3f6d7bSStella Laurenzo# errors/info emitted by MLIR to stderr.
def log(*args):
    """Print ``args`` to stderr and flush right away.

    Writing to stderr (instead of stdout) keeps this output in order with the
    diagnostics MLIR itself emits there; the RUN line merges both streams with
    ``2>&1`` before piping them into FileCheck.
    """
    print(*args, file=sys.stderr, flush=True)
299f3f6d7bSStella Laurenzo
30a54f4eaeSMogball
def run(f):
    """Run a single test function and verify that it leaked no MLIR contexts.

    A ``TEST: <name>`` header is logged before ``f`` runs so the file's label
    directives can anchor on it. After ``f`` returns, a garbage collection is
    forced and the number of live ``Context`` objects must be zero.

    Raises:
        AssertionError: if any ``Context`` is still alive after ``f`` returns;
            the message names the test and the number of live contexts.
    """
    log("\nTEST:", f.__name__)
    f()
    # Collect first so objects kept alive only by reference cycles are freed
    # before counting.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0, "{0} leaked {1} live MLIR Context object(s)".format(
        f.__name__, live_contexts
    )
369f3f6d7bSStella Laurenzo
37a54f4eaeSMogball
389f3f6d7bSStella Laurenzo# Verify capsule interop.
399f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testCapsule
def testCapsule():
    """Round-trip an ExecutionEngine through its C-API capsule.

    The IR is already in the LLVM dialect, so it is handed to the engine
    without lowering. The capsule exposed by ``_CAPIPtr`` is logged. The
    original wrapper is then released with ``_testing_release()`` (presumably
    dropping its ownership of the native engine). A fresh wrapper is rebuilt
    from the same capsule with ``_CAPICreate``, and its repr is logged as well.
    """
    source = r"""
llvm.func @none() {
  llvm.return
}
    """
    with Context():
        module = Module.parse(source)
        engine = ExecutionEngine(module)
        capsule = engine._CAPIPtr
        # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
        log(repr(capsule))
        engine._testing_release()
        rebuilt = ExecutionEngine._CAPICreate(capsule)
        # CHECK: _mlirExecutionEngine.ExecutionEngine
        log(repr(rebuilt))


run(testCapsule)
609f3f6d7bSStella Laurenzo
61a54f4eaeSMogball
629f3f6d7bSStella Laurenzo# Test invalid ExecutionEngine creation
639f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testInvalidModule
def testInvalidModule():
    """Check that ExecutionEngine rejects a module that was not lowered to LLVM.

    The module holds a builtin ``func.func`` and is passed to the engine
    without running ``lowerToLLVM``, so construction is expected to raise
    ``RuntimeError``; the error text is logged for FileCheck. If construction
    unexpectedly succeeds, a distinct marker is logged so the mismatch is easy
    to spot in the output.
    """
    with Context():
        # Builtin function
        module = Module.parse(
            r"""
    func.func @foo() { return }
    """
        )
        # CHECK: Got RuntimeError:  Failure while creating the ExecutionEngine.
        try:
            ExecutionEngine(module)
        except RuntimeError as e:
            log("Got RuntimeError: ", e)
        else:
            log("Unexpectedly created an ExecutionEngine from an unlowered module")


run(testInvalidModule)
809f3f6d7bSStella Laurenzo
81a54f4eaeSMogball
def lowerToLLVM(module):
    """Lower ``module`` in place to the LLVM dialect and return it.

    Runs the complex, memref, func, arith and cf conversions to LLVM, then
    reconciles leftover unrealized casts. The same module object is returned
    so the call can be nested directly inside ``ExecutionEngine(...)``.
    """
    pipeline = "builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,convert-arith-to-llvm,convert-cf-to-llvm,reconcile-unrealized-casts)"
    PassManager.parse(pipeline).run(module.operation)
    return module
889f3f6d7bSStella Laurenzo
89a54f4eaeSMogball
909f3f6d7bSStella Laurenzo# Test simple ExecutionEngine execution
919f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testInvokeVoid
def testInvokeVoid():
    """Invoke a JIT-compiled function that takes no arguments and returns nothing.

    There is no result to inspect; the test passes as long as ``invoke`` does
    not raise.
    """
    source = r"""
func.func @void() attributes { llvm.emit_c_interface } {
  return
}
    """
    with Context():
        module = Module.parse(source)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("void")


run(testInvokeVoid)
1079f3f6d7bSStella Laurenzo
1089f3f6d7bSStella Laurenzo
1099f3f6d7bSStella Laurenzo# Test argument passing and result with a simple float addition.
1109f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testInvokeFloatAdd
def testInvokeFloatAdd():
    """Pass two f32 scalars to a JIT-compiled function and read back their sum.

    ``invoke`` takes every argument, including the result slot, by pointer, so
    each value is wrapped in a one-element ctypes float array.
    """
    source = r"""
func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
  %add = arith.addf %arg0, %arg1 : f32
  return %add : f32
}
    """
    with Context():
        module = Module.parse(source)
        engine = ExecutionEngine(lowerToLLVM(module))
        FloatCell = ctypes.c_float * 1
        lhs, rhs, out = FloatCell(42.0), FloatCell(2.0), FloatCell(-1.0)
        engine.invoke("add", lhs, rhs, out)
        # CHECK: 42.0 + 2.0 = 44.0
        log("{0} + {1} = {2}".format(lhs[0], rhs[0], out[0]))


run(testInvokeFloatAdd)
1349f3f6d7bSStella Laurenzo
1359f3f6d7bSStella Laurenzo
# Test calling back into Python from JIT-compiled code.
1379f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testBasicCallback
def testBasicCallback():
    """Call back from JIT-compiled code into a Python function.

    ``@add`` forwards its (f32, i32) arguments to the external symbol
    ``some_callback_into_python``, which ``register_runtime`` binds to a
    ctypes callback. The callback returns half of the sum of its inputs, so
    the logged result is doubled to show the full sum.
    """

    # ctypes callback with the C signature float(float, int).
    @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
    def halved_sum(a, b):
        return a / 2 + b / 2

    source = r"""
func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
  %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
  return %resf : f32
}
func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
    """
    with Context():
        module = Module.parse(source)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", halved_sum)

        # invoke() takes each argument and the result slot by pointer: one f32
        # input, one i32 input and one f32 output.
        x = (ctypes.c_float * 1)(42.0)
        y = (ctypes.c_int * 1)(2)
        out = (ctypes.c_float * 1)(-1.0)
        engine.invoke("add", x, y, out)
        # CHECK: 42.0 + 2 = 44.0
        log("{0} + {1} = {2}".format(x[0], y[0], out[0] * 2))


run(testBasicCallback)
1719f3f6d7bSStella Laurenzo
172a54f4eaeSMogball
1739f3f6d7bSStella Laurenzo# Test callback with an unranked memref
1749f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testUnrankedMemRefCallback
def testUnrankedMemRefCallback():
    """Pass unranked memrefs from JIT-compiled code to a Python callback.

    The callback converts the incoming ``UnrankedMemRefDescriptor`` to a NumPy
    array and logs it. It runs twice: once with a contiguous 2x2 array, and
    once with a 3x4 view whose second stride is 0, so each row repeats a
    single element.
    """

    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
    def log_unranked(descriptor_ptr):
        values = unranked_memref_to_numpy(descriptor_ptr, np.float32)
        log("Inside callback: ")
        log(values)

    source = r"""
func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
"""
    with Context():
        module = Module.parse(source)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", log_unranked)

        def invoke_with(array):
            descriptor = get_unranked_memref_descriptor(array)
            engine.invoke(
                "callback_memref", ctypes.pointer(ctypes.pointer(descriptor))
            )

        dense = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[1. 2.]
        # CHECK{LITERAL}:  [3. 4.]]
        invoke_with(dense)
        base = np.array([5, 6, 7], dtype=np.float32)
        repeated = np.lib.stride_tricks.as_strided(base, strides=(4, 0), shape=(3, 4))
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[5. 5. 5. 5.]
        # CHECK{LITERAL}:  [6. 6. 6. 6.]
        # CHECK{LITERAL}:  [7. 7. 7. 7.]]
        invoke_with(repeated)


run(testUnrankedMemRefCallback)
2199f3f6d7bSStella Laurenzo
220a54f4eaeSMogball
2219f3f6d7bSStella Laurenzo# Test callback with a ranked memref.
2229f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testRankedMemRefCallback
def testRankedMemRefCallback():
    """Pass a ranked 2x2 f32 memref from JIT-compiled code to a Python callback.

    The callback receives a pointer to a rank-2 memref descriptor, converts it
    with ``ranked_memref_to_numpy`` and logs the resulting array.
    """
    F32Memref2D = make_nd_memref_descriptor(2, np.ctypeslib.as_ctypes_type(np.float32))

    @ctypes.CFUNCTYPE(None, ctypes.POINTER(F32Memref2D))
    def log_ranked(descriptor_ptr):
        values = ranked_memref_to_numpy(descriptor_ptr)
        log("Inside Callback: ")
        log(values)

    source = r"""
func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
"""
    with Context():
        module = Module.parse(source)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", log_ranked)
        matrix = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [[1. 5.]
        # CHECK{LITERAL}:  [6. 7.]]
        descriptor = get_ranked_memref_descriptor(matrix)
        engine.invoke("callback_memref", ctypes.pointer(ctypes.pointer(descriptor)))


run(testRankedMemRefCallback)
2609f3f6d7bSStella Laurenzo
261a54f4eaeSMogball
2622a603deeSFelix Schneider# Test callback with a ranked memref with non-zero offset.
2632a603deeSFelix Schneider# CHECK-LABEL: TEST: testRankedMemRefWithOffsetCallback
def testRankedMemRefWithOffsetCallback():
    """Pass a ranked memref with a non-zero offset to a Python callback.

    ``@callback_memref`` reinterprets the base buffer of its 5-element
    argument as a 2-element view at offset 3. It casts that view to a
    dynamically strided rank-1 memref and hands it to the callback. The
    callback converts it with ``ranked_memref_to_numpy`` and logs it, so only
    the last two input elements should appear.
    """
    F32Memref1D = make_nd_memref_descriptor(1, np.ctypeslib.as_ctypes_type(np.float32))

    @ctypes.CFUNCTYPE(None, ctypes.POINTER(F32Memref1D))
    def log_ranked(descriptor_ptr):
        values = ranked_memref_to_numpy(descriptor_ptr)
        log("Inside Callback: ")
        log(values)

    source = r"""
func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
  %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
  %cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<?xf32, strided<[?], offset: ?>>
  call @some_callback_into_python(%cast) : (memref<?xf32, strided<[?], offset: ?>>) -> ()
  return
}
func.func private @some_callback_into_python(memref<?xf32, strided<[?], offset: ?>>) attributes {llvm.emit_c_interface}
"""
    with Context():
        module = Module.parse(source)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", log_ranked)
        buffer = np.array([0, 0, 0, 1, 2], np.float32)
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [1. 2.]
        descriptor = get_ranked_memref_descriptor(buffer)
        engine.invoke("callback_memref", ctypes.pointer(ctypes.pointer(descriptor)))


run(testRankedMemRefWithOffsetCallback)
3032a603deeSFelix Schneider
3042a603deeSFelix Schneider
3052a603deeSFelix Schneider# Test callback with an unranked memref with non-zero offset
3062a603deeSFelix Schneider# CHECK-LABEL: TEST: testUnrankedMemRefWithOffsetCallback
def testUnrankedMemRefWithOffsetCallback():
    """Pass an unranked memref with a non-zero offset to a Python callback.

    ``@callback_memref`` reinterprets the base buffer of its 5-element
    argument as a 2-element view starting at offset 3. It casts that view to
    ``memref<*xf32>`` and passes it to the callback. The callback converts it
    with ``unranked_memref_to_numpy`` and logs it.
    """

    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
    def log_unranked(descriptor_ptr):
        values = unranked_memref_to_numpy(descriptor_ptr, np.float32)
        log("Inside callback: ")
        log(values)

    source = r"""
func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
    %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
    %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
    %cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<*xf32>
    call @some_callback_into_python(%cast) : (memref<*xf32>) -> ()
    return
}
func.func private @some_callback_into_python(memref<*xf32>) attributes {llvm.emit_c_interface}
"""
    with Context():
        module = Module.parse(source)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", log_unranked)
        buffer = np.array([1, 2, 3, 4, 5], np.float32)
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [4. 5.]
        descriptor = get_ranked_memref_descriptor(buffer)
        engine.invoke("callback_memref", ctypes.pointer(ctypes.pointer(descriptor)))


run(testUnrankedMemRefWithOffsetCallback)
3412a603deeSFelix Schneider
3422a603deeSFelix Schneider
# Test addition of two memrefs.
3449f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
    """Add a one-element f32 memref and a rank-0 f32 memref through the JIT.

    ``@main`` loads element 0 of ``%arg0`` and the scalar in ``%arg1``, adds
    them and stores the sum into element 0 of ``%arg2``. Each NumPy array is
    passed as a ranked memref descriptor behind a double pointer.
    """
    source = """
    module  {
      func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf32>
        %2 = memref.load %arg1[] : memref<f32>
        %3 = arith.addf %1, %2 : f32
        memref.store %3, %arg2[%0] : memref<1xf32>
        return
      }
    } """
    # The sum is read back from `out` itself after the call, so the
    # descriptors alias the arrays' buffers; keep the arrays bound to names.
    lhs = np.array([32.5]).astype(np.float32)
    rhs = np.array(6).astype(np.float32)
    out = np.array([0]).astype(np.float32)

    def as_memref_arg(array):
        return ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))

    with Context():
        module = Module.parse(source)
        memref_args = [as_memref_arg(a) for a in (lhs, rhs, out)]
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", *memref_args)
        # CHECK: [32.5] + 6.0 = [38.5]
        log("{0} + {1} = {2}".format(lhs, rhs, out))


run(testMemrefAdd)
3839f3f6d7bSStella Laurenzo
384a54f4eaeSMogball
385f8b692ddSAart Bik# Test addition of two f16 memrefs
386f8b692ddSAart Bik# CHECK-LABEL: TEST: testF16MemrefAdd
def testF16MemrefAdd():
    """Add two one-element f16 memrefs through the JIT and read the sum back.

    After logging the arrays, the output descriptor is converted back with
    ``ranked_memref_to_numpy`` to exercise that helper for f16.
    """
    source = """
    module  {
      func.func @main(%arg0: memref<1xf16>,
                      %arg1: memref<1xf16>,
                      %arg2: memref<1xf16>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf16>
        %2 = memref.load %arg1[%0] : memref<1xf16>
        %3 = arith.addf %1, %2 : f16
        memref.store %3, %arg2[%0] : memref<1xf16>
        return
      }
    } """
    # The arrays back the descriptors below and must outlive the call.
    lhs = np.array([11.0]).astype(np.float16)
    rhs = np.array([22.0]).astype(np.float16)
    out = np.array([0.0]).astype(np.float16)
    lhs_ptr, rhs_ptr, out_ptr = (
        ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(a)))
        for a in (lhs, rhs, out)
    )

    with Context():
        module = Module.parse(source)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [11.] + [22.] = [33.]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

        # Round-trip the output descriptor through the to-NumPy helper.
        # CHECK: [33.]
        as_numpy = ranked_memref_to_numpy(out_ptr[0])
        log(as_numpy)


run(testF16MemrefAdd)
433f8b692ddSAart Bik
434f8b692ddSAart Bik
435d6682189SAart Bik# Test addition of two complex memrefs
436d6682189SAart Bik# CHECK-LABEL: TEST: testComplexMemrefAdd
def testComplexMemrefAdd():
    """Add two one-element ``complex<f64>`` memrefs through the JIT.

    The NumPy side uses ``complex128``. After logging the arrays, the output
    descriptor is converted back with ``ranked_memref_to_numpy``.
    """
    source = """
    module  {
      func.func @main(%arg0: memref<1xcomplex<f64>>,
                      %arg1: memref<1xcomplex<f64>>,
                      %arg2: memref<1xcomplex<f64>>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xcomplex<f64>>
        %2 = memref.load %arg1[%0] : memref<1xcomplex<f64>>
        %3 = complex.add %1, %2 : complex<f64>
        memref.store %3, %arg2[%0] : memref<1xcomplex<f64>>
        return
      }
    } """
    # The arrays back the descriptors below and must outlive the call.
    lhs = np.array([1.0 + 2.0j]).astype(np.complex128)
    rhs = np.array([3.0 + 4.0j]).astype(np.complex128)
    out = np.array([0.0 + 0.0j]).astype(np.complex128)
    lhs_ptr, rhs_ptr, out_ptr = (
        ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(a)))
        for a in (lhs, rhs, out)
    )

    with Context():
        module = Module.parse(source)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

        # Round-trip the output descriptor through the to-NumPy helper.
        # CHECK: [4.+6.j]
        as_numpy = ranked_memref_to_numpy(out_ptr[0])
        log(as_numpy)


run(testComplexMemrefAdd)
483d6682189SAart Bik
484d6682189SAart Bik
485d6682189SAart Bik# Test addition of two complex unranked memrefs
486d6682189SAart Bik# CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
def testComplexUnrankedMemrefAdd():
    """Add two ``complex<f32>`` values passed as unranked memrefs through the JIT.

    ``@main`` casts each ``memref<*xcomplex<f32>>`` argument to
    ``memref<1xcomplex<f32>>`` before loading and storing. The NumPy side uses
    ``complex64``. The output descriptor is converted back with
    ``unranked_memref_to_numpy``, which needs the element dtype explicitly.
    """
    source = """
    module  {
      func.func @main(%arg0: memref<*xcomplex<f32>>,
                      %arg1: memref<*xcomplex<f32>>,
                      %arg2: memref<*xcomplex<f32>>) attributes { llvm.emit_c_interface } {
        %A = memref.cast %arg0 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %B = memref.cast %arg1 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %C = memref.cast %arg2 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %0 = arith.constant 0 : index
        %1 = memref.load %A[%0] : memref<1xcomplex<f32>>
        %2 = memref.load %B[%0] : memref<1xcomplex<f32>>
        %3 = complex.add %1, %2 : complex<f32>
        memref.store %3, %C[%0] : memref<1xcomplex<f32>>
        return
      }
    } """
    # The arrays back the descriptors below and must outlive the call.
    lhs = np.array([5.0 + 6.0j]).astype(np.complex64)
    rhs = np.array([7.0 + 8.0j]).astype(np.complex64)
    out = np.array([0.0 + 0.0j]).astype(np.complex64)
    lhs_ptr, rhs_ptr, out_ptr = (
        ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(a)))
        for a in (lhs, rhs, out)
    )

    with Context():
        module = Module.parse(source)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

        # Round-trip the output descriptor through the to-NumPy helper.
        # CHECK: [12.+14.j]
        as_numpy = unranked_memref_to_numpy(out_ptr[0], np.dtype(np.complex64))
        log(as_numpy)


run(testComplexUnrankedMemrefAdd)
536d6682189SAart Bik
537d6682189SAart Bik
# Copy a one-element bf16 memref through JIT-compiled code and read it back.
# CHECK-LABEL: TEST: testBF16Memref
def testBF16Memref():
    """Round-trip a bf16 value through ranked memref descriptors.

    The JIT-compiled ``main`` loads element 0 of its first argument and stores
    it into element 0 of its second. The output buffer is then converted back
    to NumPy with ``ranked_memref_to_numpy`` and compared with the input value.
    """

    def _descriptor_ptr(array):
        # The C interface of ``main`` receives a pointer to a pointer to the
        # ranked memref descriptor built from the NumPy array.
        return ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))

    with Context():
        module = Module.parse(
            """
    module  {
      func.func @main(%arg0: memref<1xbf16>,
                      %arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xbf16>
        memref.store %1, %arg1[%0] : memref<1xbf16>
        return
      }
    } """
        )

        # Keep the arrays bound to names so their storage outlives the call.
        src = np.array([0.5]).astype(bfloat16)
        dst = np.array([0.0]).astype(bfloat16)
        src_ptr = _descriptor_ptr(src)
        dst_ptr = _descriptor_ptr(dst)

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", src_ptr, dst_ptr)

        # Exercise the descriptor -> NumPy conversion on the output buffer.
        copied = ranked_memref_to_numpy(dst_ptr[0])
        assert len(copied) == 1
        assert copied[0] == 0.5


if HAS_ML_DTYPES:
    run(testBF16Memref)
else:
    # Without ml_dtypes the test cannot run; still emit its header line so the
    # FileCheck label for this test is present in the output.
    log("TEST: testBF16Memref")
5785ef087b7SBimo
5795ef087b7SBimo
# Round-trip a one-element f8E5M2 memref through JIT-compiled code.
# CHECK-LABEL: TEST: testF8E5M2Memref
def testF8E5M2Memref():
    """Copy element 0 of an f8E5M2 memref into a second memref and verify it.

    Covers passing ml_dtypes ``float8_e5m2`` arrays as ranked memref
    descriptors and converting the output back with ``ranked_memref_to_numpy``.
    """
    ir_text = """
    module  {
      func.func @main(%arg0: memref<1xf8E5M2>,
                      %arg1: memref<1xf8E5M2>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf8E5M2>
        memref.store %1, %arg1[%0] : memref<1xf8E5M2>
        return
      }
    } """

    with Context():
        module = Module.parse(ir_text)

        source = np.array([0.5]).astype(float8_e5m2)
        target = np.array([0.0]).astype(float8_e5m2)
        # Pointer-to-pointer-to-descriptor, as expected by the C interface.
        source_ptr, target_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(buf)))
            for buf in (source, target)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", source_ptr, target_ptr)

        # Read the copied value back through the to-numpy utility.
        result = ranked_memref_to_numpy(target_ptr[0])
        assert len(result) == 1
        assert result[0] == 0.5


if not HAS_ML_DTYPES:
    # ml_dtypes is missing: emit only the test header so FileCheck still
    # finds the expected label for this test.
    log("TEST: testF8E5M2Memref")
else:
    run(testF8E5M2Memref)
620c8cac33aSPhrygianGates
621c8cac33aSPhrygianGates
# Add two 2x2 f32 memrefs, one of them typed as the dynamic memref<?x?xf32>,
# with a hand-written cf-dialect loop nest, and compare against NumPy.
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
    """Check element-wise addition where one operand has a dynamic shape.

    ``memref_add_2d`` walks both indices from 0 to 2 using explicit branches
    and stores ``arg0[i, j] + arg1[i, j]`` into ``arg2[i, j]``. The result is
    compared with the NumPy sum and the boolean outcome is logged.
    """
    with Context():
        module = Module.parse(
            """
      module  {
        func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
          %c0 = arith.constant 0 : index
          %c2 = arith.constant 2 : index
          %c1 = arith.constant 1 : index
          cf.br ^bb1(%c0 : index)
        ^bb1(%0: index):  // 2 preds: ^bb0, ^bb5
          %1 = arith.cmpi slt, %0, %c2 : index
          cf.cond_br %1, ^bb2, ^bb6
        ^bb2:  // pred: ^bb1
          %c0_0 = arith.constant 0 : index
          %c2_1 = arith.constant 2 : index
          %c1_2 = arith.constant 1 : index
          cf.br ^bb3(%c0_0 : index)
        ^bb3(%2: index):  // 2 preds: ^bb2, ^bb4
          %3 = arith.cmpi slt, %2, %c2_1 : index
          cf.cond_br %3, ^bb4, ^bb5
        ^bb4:  // pred: ^bb3
          %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
          %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
          %6 = arith.addf %4, %5 : f32
          memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
          %7 = arith.addi %2, %c1_2 : index
          cf.br ^bb3(%7 : index)
        ^bb5:  // pred: ^bb3
          %8 = arith.addi %0, %c1 : index
          cf.br ^bb1(%8 : index)
        ^bb6:  // pred: ^bb1
          return
        }
      }
        """
        )
        # Inputs and output are drawn in this order. The output starts with
        # random contents, so a kernel that never wrote it would almost
        # certainly fail the comparison.
        lhs, rhs, out = (np.random.randn(2, 2).astype(np.float32) for _ in range(3))

        lhs_ptr, rhs_ptr, out_ptr = (
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arr)))
            for arr in (lhs, rhs, out)
        )

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("memref_add_2d", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: True
        log(np.allclose(lhs + rhs, out))


run(testDynamicMemrefAdd2D)
684c8b8e8e0SUday Bondhugula
685a54f4eaeSMogball
# Load the MLIR runner utility libraries into the ExecutionEngine and call
# printMemrefF32 from JIT-compiled code.
# CHECK-LABEL: TEST: testSharedLibLoad
def testSharedLibLoad():
    """Resolve an external runtime symbol from shared libraries.

    The kernel stores 42.0 into a one-element f32 memref and prints it via
    ``printMemrefF32``. That function is only declared in the module, so it is
    expected to be resolved from the libraries passed as ``shared_libs``.
    """
    # Windows and macOS use paths relative to the working directory; every
    # other platform uses the environment-configurable module-level paths.
    libs_for_platform = {
        "win32": [
            "../../../../bin/mlir_runner_utils.dll",
            "../../../../bin/mlir_c_runner_utils.dll",
        ],
        "darwin": [
            "../../../../lib/libmlir_runner_utils.dylib",
            "../../../../lib/libmlir_c_runner_utils.dylib",
        ],
    }
    shared_libs = libs_for_platform.get(
        sys.platform, [MLIR_RUNNER_UTILS, MLIR_C_RUNNER_UTILS]
    )

    with Context():
        module = Module.parse(
            """
      module  {
      func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %c0 = arith.constant 0 : index
        %cst42 = arith.constant 42.0 : f32
        memref.store %cst42, %arg0[%c0] : memref<1xf32>
        %u_memref = memref.cast %arg0 : memref<1xf32> to memref<*xf32>
        call @printMemrefF32(%u_memref) : (memref<*xf32>) -> ()
        return
      }
      func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
     } """
        )
        buffer = np.array([0.0]).astype(np.float32)
        buffer_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(buffer))
        )

        engine = ExecutionEngine(
            lowerToLLVM(module), opt_level=3, shared_libs=shared_libs
        )
        engine.invoke("main", buffer_ptr)
        # CHECK: Unranked Memref
        # CHECK-NEXT: [42]


run(testSharedLibLoad)
735aaea92e1SDenys Shabalin
736aaea92e1SDenys Shabalin
# Call the runtime's nanoTime() from JIT-compiled code and print the value.
# CHECK-LABEL: TEST: testNanoTime
def testNanoTime():
    """Check that the ``nanoTime`` runtime function can be linked and called.

    ``main`` stores the ``nanoTime()`` result into a stack-allocated
    one-element i64 memref (``memref.alloca``) and prints it with
    ``printMemrefI64``. Both functions are only declared in the module and are
    resolved from ``shared_libs``.
    """
    # Windows uses paths relative to the working directory; other platforms
    # use the environment-configurable module-level paths.
    shared_libs = (
        [
            "../../../../bin/mlir_runner_utils.dll",
            "../../../../bin/mlir_c_runner_utils.dll",
        ]
        if sys.platform == "win32"
        else [MLIR_RUNNER_UTILS, MLIR_C_RUNNER_UTILS]
    )

    with Context():
        module = Module.parse(
            """
      module {
      func.func @main() attributes { llvm.emit_c_interface } {
        %now = call @nanoTime() : () -> i64
        %memref = memref.alloca() : memref<1xi64>
        %c0 = arith.constant 0 : index
        memref.store %now, %memref[%c0] : memref<1xi64>
        %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
        call @printMemrefI64(%u_memref) : (memref<*xi64>) -> ()
        return
      }
      func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface }
      func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface }
    }"""
        )

        engine = ExecutionEngine(
            lowerToLLVM(module), opt_level=3, shared_libs=shared_libs
        )
        engine.invoke("main")
        # CHECK: Unranked Memref
        # CHECK: [{{.*}}]


run(testNanoTime)
77895c083f5SDenys Shabalin
77995c083f5SDenys Shabalin
# Test dumping the JIT-compiled module to an object file on disk.
# CHECK-LABEL: TEST: testDumpToObjectFile
def testDumpToObjectFile():
    """Check that ``ExecutionEngine.dump_to_object_file`` writes object code.

    A fresh, empty temporary ``.o`` file is created. The file's existence and
    emptiness are logged before and after the engine dumps into it. The
    temporary file is always removed afterwards.
    """
    fd, object_path = tempfile.mkstemp(suffix=".o")

    try:
        # Only the path is needed; release the descriptor immediately instead
        # of holding it open while the engine rewrites the file.
        os.close(fd)

        with Context():
            module = Module.parse(
                """
        module {
        func.func @main() attributes { llvm.emit_c_interface } {
          return
        }
      }"""
            )

            execution_engine = ExecutionEngine(lowerToLLVM(module), opt_level=3)

            # Report through log() (stderr + flush) like the other tests, so
            # these lines stay in order with the rest of the merged output
            # instead of sitting in Python's stdout buffer.
            # CHECK: Object file exists: True
            log(f"Object file exists: {os.path.exists(object_path)}")
            # CHECK: Object file is empty: True
            log(f"Object file is empty: {os.path.getsize(object_path) == 0}")

            execution_engine.dump_to_object_file(object_path)

            # CHECK: Object file exists: True
            log(f"Object file exists: {os.path.exists(object_path)}")
            # CHECK: Object file is empty: False
            log(f"Object file is empty: {os.path.getsize(object_path) == 0}")

    finally:
        os.remove(object_path)


run(testDumpToObjectFile)
816