xref: /llvm-project/mlir/test/python/execution_engine.py (revision ba44d7ba1fb3e27f51d65ea1af280e00382e09e0)
1# RUN: env MLIR_RUNNER_UTILS=%mlir_runner_utils MLIR_C_RUNNER_UTILS=%mlir_c_runner_utils %PYTHON %s 2>&1 | FileCheck %s
2# REQUIRES: host-supports-jit
3import gc, sys, os, tempfile
4from mlir.ir import *
5from mlir.passmanager import *
6from mlir.execution_engine import *
7from mlir.runtime import *
8
9try:
10    from ml_dtypes import bfloat16, float8_e5m2
11
12    HAS_ML_DTYPES = True
13except ModuleNotFoundError:
14    HAS_ML_DTYPES = False
15
16
# Paths of the MLIR runner utility shared libraries. Each one is read from the
# environment variable of the same name, with a relative fallback path used
# when that variable is not set.
MLIR_RUNNER_UTILS = os.environ.get(
    "MLIR_RUNNER_UTILS",
    "../../../../lib/libmlir_runner_utils.so",
)
MLIR_C_RUNNER_UTILS = os.environ.get(
    "MLIR_C_RUNNER_UTILS",
    "../../../../lib/libmlir_c_runner_utils.so",
)
23
24# Log everything to stderr and flush so that we have a unified stream to match
25# errors/info emitted by MLIR to stderr.
def log(*args):
    """Print ``args`` to stderr and flush the stream right away.

    The arguments are forwarded to ``print`` unchanged, so they are separated
    by single spaces and followed by a newline.
    """
    print(*args, file=sys.stderr, flush=True)
29
30
def run(f):
    """Run one test function and verify that it leaked no MLIR context.

    A ``TEST: <name>`` banner (preceded by a blank line) is logged first, then
    ``f`` is called with no arguments. Afterwards a garbage collection is
    forced and the number of live ``Context`` objects is asserted to be zero.
    """
    test_name = f.__name__
    log("\nTEST:", test_name)
    f()
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
36
37
38# Verify capsule interop.
39# CHECK-LABEL: TEST: testCapsule
def testCapsule():
    """Round-trip an ExecutionEngine through its C-API capsule.

    The capsule is read from ``_CAPIPtr``, its repr is logged, and a second
    engine object is then built from it with ``_CAPICreate``.
    """
    empty_func_asm = r"""
llvm.func @none() {
  llvm.return
}
    """
    with Context():
        module = Module.parse(empty_func_asm)
        engine = ExecutionEngine(module)
        capsule = engine._CAPIPtr
        # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
        log(repr(capsule))
        # NOTE(review): presumably hands ownership over so that the capsule
        # can be wrapped again below -- confirm against the bindings.
        engine._testing_release()
        engine_from_capsule = ExecutionEngine._CAPICreate(capsule)
        # CHECK: _mlirExecutionEngine.ExecutionEngine
        log(repr(engine_from_capsule))


run(testCapsule)
60
61
62# Test invalid ExecutionEngine creation
63# CHECK-LABEL: TEST: testInvalidModule
def testInvalidModule():
    """Build an ExecutionEngine from a module that was not lowered to LLVM.

    The module still contains a builtin ``func.func``; a ``RuntimeError``
    raised by the constructor is caught and logged.
    """
    builtin_func_asm = r"""
    func.func @foo() { return }
    """
    with Context():
        # Builtin function
        module = Module.parse(builtin_func_asm)
        # CHECK: Got RuntimeError:  Failure while creating the ExecutionEngine.
        try:
            engine = ExecutionEngine(module)
        except RuntimeError as err:
            log("Got RuntimeError: ", err)


run(testInvalidModule)
80
81
def lowerToLLVM(module):
    """Run the LLVM lowering pipeline on ``module`` and return it.

    The passes are applied to ``module.operation``; the same ``module``
    object that was passed in is returned.
    """
    pipeline = "builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,convert-arith-to-llvm,convert-cf-to-llvm,reconcile-unrealized-casts)"
    pass_manager = PassManager.parse(pipeline)
    pass_manager.run(module.operation)
    return module
88
89
90# Test simple ExecutionEngine execution
91# CHECK-LABEL: TEST: testInvokeVoid
def testInvokeVoid():
    """Invoke a function that takes no arguments and returns nothing."""
    void_func_asm = r"""
func.func @void() attributes { llvm.emit_c_interface } {
  return
}
    """
    with Context():
        module = Module.parse(void_func_asm)
        engine = ExecutionEngine(lowerToLLVM(module))
        # Nothing to check other than no exception thrown here.
        engine.invoke("void")


run(testInvokeVoid)
107
108
109# Test argument passing and result with a simple float addition.
110# CHECK-LABEL: TEST: testInvokeFloatAdd
def testInvokeFloatAdd():
    """Add two f32 values through the engine and log the sum."""
    float_add_asm = r"""
func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
  %add = arith.addf %arg0, %arg1 : f32
  return %add : f32
}
    """
    with Context():
        module = Module.parse(float_add_asm)
        engine = ExecutionEngine(lowerToLLVM(module))
        # Both inputs and the result slot are one-element ctypes float
        # arrays, because arguments must be passed as pointers.
        float_cell = ctypes.c_float * 1
        lhs = float_cell(42.0)
        rhs = float_cell(2.0)
        total = float_cell(-1.0)
        engine.invoke("add", lhs, rhs, total)
        # CHECK: 42.0 + 2.0 = 44.0
        log("{0} + {1} = {2}".format(lhs[0], rhs[0], total[0]))


run(testInvokeFloatAdd)
134
135
136# Test callback
137# CHECK-LABEL: TEST: testBasicCallback
def testBasicCallback():
    """Call back into Python from compiled code via a registered symbol.

    A ctypes callback taking a float and an int is registered under the name
    ``some_callback_into_python``; the compiled ``add`` function forwards its
    arguments to it and returns the callback's result.
    """

    # Python implementation of the runtime symbol: half of each operand, summed.
    def half_sum(a, b):
        return a / 2 + b / 2

    callback_type = ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
    callback = callback_type(half_sum)

    forwarding_asm = r"""
func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
  %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
  return %resf : f32
}
func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
    """
    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(forwarding_asm)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", callback)

        # One float input, one int input and one float result slot, each
        # wrapped in a one-element ctypes array so it is passed as a pointer.
        float_arg = (ctypes.c_float * 1)(42.0)
        int_arg = (ctypes.c_int * 1)(2)
        result = (ctypes.c_float * 1)(-1.0)
        engine.invoke("add", float_arg, int_arg, result)
        # The callback halves both operands, so the result is doubled here.
        # CHECK: 42.0 + 2 = 44.0
        log("{0} + {1} = {2}".format(float_arg[0], int_arg[0], result[0] * 2))


run(testBasicCallback)
171
172
173# Test callback with an unranked memref
174# CHECK-LABEL: TEST: testUnrankedMemRefCallback
def testUnrankedMemRefCallback():
    """Pass unranked memrefs to a Python callback that logs them as arrays.

    The compiled function is invoked twice: once with a plain 2x2 array and
    once with a strided view built by ``as_strided``.
    """

    # Converts the incoming unranked memref to a numpy array and logs it.
    def print_unranked(a):
        arr = unranked_memref_to_numpy(a, np.float32)
        log("Inside callback: ")
        log(arr)

    callback_type = ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
    callback = callback_type(print_unranked)

    def invoke_with(engine, array):
        # The descriptor is passed as a pointer to a pointer.
        descriptor = get_unranked_memref_descriptor(array)
        engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(descriptor)),
        )

    forwarding_asm = r"""
func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
"""
    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(forwarding_asm)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", callback)

        dense = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[1. 2.]
        # CHECK{LITERAL}:  [3. 4.]]
        invoke_with(engine, dense)

        base = np.array([5, 6, 7], dtype=np.float32)
        # A stride of 0 along the second axis repeats each element 4 times.
        broadcast_view = np.lib.stride_tricks.as_strided(
            base, strides=(4, 0), shape=(3, 4)
        )
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[5. 5. 5. 5.]
        # CHECK{LITERAL}:  [6. 6. 6. 6.]
        # CHECK{LITERAL}:  [7. 7. 7. 7.]]
        invoke_with(engine, broadcast_view)


run(testUnrankedMemRefCallback)
219
220
221# Test callback with a ranked memref.
222# CHECK-LABEL: TEST: testRankedMemRefCallback
def testRankedMemRefCallback():
    """Pass a ranked 2-D f32 memref to a Python callback that logs it."""
    # ctypes signature of the callback: no result, one pointer to a rank-2
    # memref descriptor whose element type is the ctypes form of float32.
    element_ctype = np.ctypeslib.as_ctypes_type(np.float32)
    descriptor_type = make_nd_memref_descriptor(2, element_ctype)
    callback_type = ctypes.CFUNCTYPE(None, ctypes.POINTER(descriptor_type))

    # Converts the incoming ranked memref to a numpy array and logs it.
    @callback_type
    def callback(a):
        as_numpy = ranked_memref_to_numpy(a)
        log("Inside Callback: ")
        log(as_numpy)

    forwarding_asm = r"""
func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
"""
    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(forwarding_asm)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", callback)

        matrix = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
        matrix_descriptor = get_ranked_memref_descriptor(matrix)
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [[1. 5.]
        # CHECK{LITERAL}:  [6. 7.]]
        engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(matrix_descriptor)),
        )


run(testRankedMemRefCallback)
260
261
262# Test callback with a ranked memref with non-zero offset.
263# CHECK-LABEL: TEST: testRankedMemRefWithOffsetCallback
def testRankedMemRefWithOffsetCallback():
    """Pass a ranked memref with a non-zero offset to a Python callback.

    The compiled function reinterprets its 5-element argument as a
    2-element memref at offset 3 and forwards that to the callback, which
    logs it as a numpy array.
    """
    # ctypes signature of the callback: no result, one pointer to a rank-1
    # memref descriptor whose element type is the ctypes form of float32.
    element_ctype = np.ctypeslib.as_ctypes_type(np.float32)
    descriptor_type = make_nd_memref_descriptor(1, element_ctype)
    callback_type = ctypes.CFUNCTYPE(None, ctypes.POINTER(descriptor_type))

    # Converts the incoming ranked memref to a numpy array and logs it.
    @callback_type
    def callback(a):
        as_numpy = ranked_memref_to_numpy(a)
        log("Inside Callback: ")
        log(as_numpy)

    subview_asm = r"""
func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
  %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
  %cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<?xf32, strided<[?], offset: ?>>
  call @some_callback_into_python(%cast) : (memref<?xf32, strided<[?], offset: ?>>) -> ()
  return
}
func.func private @some_callback_into_python(memref<?xf32, strided<[?], offset: ?>>) attributes {llvm.emit_c_interface}
"""
    with Context():
        # The module takes a subview of the argument memref and calls the callback with it
        module = Module.parse(subview_asm)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", callback)

        values = np.array([0, 0, 0, 1, 2], np.float32)
        values_descriptor = get_ranked_memref_descriptor(values)
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [1. 2.]
        engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(values_descriptor)),
        )


run(testRankedMemRefWithOffsetCallback)
303
304
305# Test callback with an unranked memref with non-zero offset
306# CHECK-LABEL: TEST: testUnrankedMemRefWithOffsetCallback
def testUnrankedMemRefWithOffsetCallback():
    """Pass an unranked memref with a non-zero offset to a Python callback.

    The compiled function reinterprets its 5-element argument as a
    2-element memref at offset 3, casts that to an unranked memref and
    forwards it to the callback, which logs it as a numpy array.
    """

    # Converts the incoming unranked memref to a numpy array and logs it.
    def print_unranked(a):
        as_numpy = unranked_memref_to_numpy(a, np.float32)
        log("Inside callback: ")
        log(as_numpy)

    callback_type = ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
    callback = callback_type(print_unranked)

    subview_asm = r"""
func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
    %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
    %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
    %cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<*xf32>
    call @some_callback_into_python(%cast) : (memref<*xf32>) -> ()
    return
}
func.func private @some_callback_into_python(memref<*xf32>) attributes {llvm.emit_c_interface}
"""
    with Context():
        # The module takes a subview of the argument memref, casts it to an unranked memref and
        # calls the callback with it.
        module = Module.parse(subview_asm)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", callback)

        values = np.array([1, 2, 3, 4, 5], np.float32)
        # The entry point itself takes a ranked memref, so a ranked
        # descriptor is passed in here.
        values_descriptor = get_ranked_memref_descriptor(values)
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [4. 5.]
        engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(values_descriptor)),
        )


run(testUnrankedMemRefWithOffsetCallback)
341
342
343#  Test addition of two memrefs.
344# CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
    """Add a 1-element f32 memref and a 0-d f32 memref into a result memref."""
    add_asm = """
    module  {
      func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf32>
        %2 = memref.load %arg1[] : memref<f32>
        %3 = arith.addf %1, %2 : f32
        memref.store %3, %arg2[%0] : memref<1xf32>
        return
      }
    } """

    def as_memref_arg(array):
        # Pointer to a pointer to the ranked memref descriptor of ``array``.
        descriptor = get_ranked_memref_descriptor(array)
        return ctypes.pointer(ctypes.pointer(descriptor))

    with Context():
        module = Module.parse(add_asm)
        lhs = np.array([32.5]).astype(np.float32)
        rhs = np.array(6).astype(np.float32)
        total = np.array([0]).astype(np.float32)

        memref_args = [as_memref_arg(array) for array in (lhs, rhs, total)]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", *memref_args)
        # CHECK: [32.5] + 6.0 = [38.5]
        log("{0} + {1} = {2}".format(lhs, rhs, total))


run(testMemrefAdd)
383
384
385# Test addition of two f16 memrefs
386# CHECK-LABEL: TEST: testF16MemrefAdd
def testF16MemrefAdd():
    """Add two 1-element f16 memrefs and read the result back via numpy."""
    add_asm = """
    module  {
      func.func @main(%arg0: memref<1xf16>,
                      %arg1: memref<1xf16>,
                      %arg2: memref<1xf16>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf16>
        %2 = memref.load %arg1[%0] : memref<1xf16>
        %3 = arith.addf %1, %2 : f16
        memref.store %3, %arg2[%0] : memref<1xf16>
        return
      }
    } """
    with Context():
        module = Module.parse(add_asm)

        lhs = np.array([11.0]).astype(np.float16)
        rhs = np.array([22.0]).astype(np.float16)
        total = np.array([0.0]).astype(np.float16)

        # Each argument is a pointer to a pointer to a ranked descriptor.
        lhs_descriptor = get_ranked_memref_descriptor(lhs)
        lhs_ptr = ctypes.pointer(ctypes.pointer(lhs_descriptor))
        rhs_descriptor = get_ranked_memref_descriptor(rhs)
        rhs_ptr = ctypes.pointer(ctypes.pointer(rhs_descriptor))
        total_descriptor = get_ranked_memref_descriptor(total)
        total_ptr = ctypes.pointer(ctypes.pointer(total_descriptor))

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, total_ptr)
        # CHECK: [11.] + [22.] = [33.]
        log("{0} + {1} = {2}".format(lhs, rhs, total))

        # test to-numpy utility
        # CHECK: [33.]
        roundtrip = ranked_memref_to_numpy(total_ptr[0])
        log(roundtrip)


run(testF16MemrefAdd)
433
434
435# Test addition of two complex memrefs
436# CHECK-LABEL: TEST: testComplexMemrefAdd
def testComplexMemrefAdd():
    """Add two 1-element complex<f64> memrefs and read the result back."""
    add_asm = """
    module  {
      func.func @main(%arg0: memref<1xcomplex<f64>>,
                      %arg1: memref<1xcomplex<f64>>,
                      %arg2: memref<1xcomplex<f64>>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xcomplex<f64>>
        %2 = memref.load %arg1[%0] : memref<1xcomplex<f64>>
        %3 = complex.add %1, %2 : complex<f64>
        memref.store %3, %arg2[%0] : memref<1xcomplex<f64>>
        return
      }
    } """

    def as_memref_arg(array):
        # Pointer to a pointer to the ranked memref descriptor of ``array``.
        descriptor = get_ranked_memref_descriptor(array)
        return ctypes.pointer(ctypes.pointer(descriptor))

    with Context():
        module = Module.parse(add_asm)

        lhs = np.array([1.0 + 2.0j]).astype(np.complex128)
        rhs = np.array([3.0 + 4.0j]).astype(np.complex128)
        total = np.array([0.0 + 0.0j]).astype(np.complex128)

        lhs_ptr = as_memref_arg(lhs)
        rhs_ptr = as_memref_arg(rhs)
        total_ptr = as_memref_arg(total)

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, total_ptr)
        # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
        log("{0} + {1} = {2}".format(lhs, rhs, total))

        # test to-numpy utility
        # CHECK: [4.+6.j]
        roundtrip = ranked_memref_to_numpy(total_ptr[0])
        log(roundtrip)


run(testComplexMemrefAdd)
483
484
485# Test addition of two complex unranked memrefs
486# CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
def testComplexUnrankedMemrefAdd():
    """Add two complex<f32> values passed as unranked memrefs.

    The compiled function casts each unranked argument to a 1-element ranked
    memref before loading and storing; the result is also read back through
    ``unranked_memref_to_numpy``.
    """
    add_asm = """
    module  {
      func.func @main(%arg0: memref<*xcomplex<f32>>,
                      %arg1: memref<*xcomplex<f32>>,
                      %arg2: memref<*xcomplex<f32>>) attributes { llvm.emit_c_interface } {
        %A = memref.cast %arg0 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %B = memref.cast %arg1 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %C = memref.cast %arg2 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %0 = arith.constant 0 : index
        %1 = memref.load %A[%0] : memref<1xcomplex<f32>>
        %2 = memref.load %B[%0] : memref<1xcomplex<f32>>
        %3 = complex.add %1, %2 : complex<f32>
        memref.store %3, %C[%0] : memref<1xcomplex<f32>>
        return
      }
    } """
    with Context():
        module = Module.parse(add_asm)

        lhs = np.array([5.0 + 6.0j]).astype(np.complex64)
        rhs = np.array([7.0 + 8.0j]).astype(np.complex64)
        total = np.array([0.0 + 0.0j]).astype(np.complex64)

        # Each argument is a pointer to a pointer to an unranked descriptor.
        memref_args = [
            ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(array)))
            for array in (lhs, rhs, total)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", *memref_args)
        # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
        log("{0} + {1} = {2}".format(lhs, rhs, total))

        # test to-numpy utility
        # CHECK: [12.+14.j]
        result_ptr = memref_args[2]
        roundtrip = unranked_memref_to_numpy(result_ptr[0], np.dtype(np.complex64))
        log(roundtrip)


run(testComplexUnrankedMemrefAdd)
536
537
538# Test bf16 memrefs
539# CHECK-LABEL: TEST: testBF16Memref
def testBF16Memref():
    """Copy one bf16 element between memrefs and verify it via numpy."""
    copy_asm = """
    module  {
      func.func @main(%arg0: memref<1xbf16>,
                      %arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xbf16>
        memref.store %1, %arg1[%0] : memref<1xbf16>
        return
      }
    } """
    with Context():
        module = Module.parse(copy_asm)

        source = np.array([0.5]).astype(bfloat16)
        target = np.array([0.0]).astype(bfloat16)

        # Each argument is a pointer to a pointer to a ranked descriptor.
        source_descriptor = get_ranked_memref_descriptor(source)
        source_ptr = ctypes.pointer(ctypes.pointer(source_descriptor))
        target_descriptor = get_ranked_memref_descriptor(target)
        target_ptr = ctypes.pointer(ctypes.pointer(target_descriptor))

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", source_ptr, target_ptr)

        # test to-numpy utility
        copied = ranked_memref_to_numpy(target_ptr[0])
        assert len(copied) == 1
        assert copied[0] == 0.5


# Without ml_dtypes the test body cannot run; only its banner line is logged.
if HAS_ML_DTYPES:
    run(testBF16Memref)
else:
    log("TEST: testBF16Memref")
578
579
580# Test f8E5M2 memrefs
581# CHECK-LABEL: TEST: testF8E5M2Memref
def testF8E5M2Memref():
    """Copy one f8E5M2 element between memrefs and verify it via numpy."""
    copy_asm = """
    module  {
      func.func @main(%arg0: memref<1xf8E5M2>,
                      %arg1: memref<1xf8E5M2>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf8E5M2>
        memref.store %1, %arg1[%0] : memref<1xf8E5M2>
        return
      }
    } """

    def as_memref_arg(array):
        # Pointer to a pointer to the ranked memref descriptor of ``array``.
        descriptor = get_ranked_memref_descriptor(array)
        return ctypes.pointer(ctypes.pointer(descriptor))

    with Context():
        module = Module.parse(copy_asm)

        source = np.array([0.5]).astype(float8_e5m2)
        target = np.array([0.0]).astype(float8_e5m2)

        source_ptr = as_memref_arg(source)
        target_ptr = as_memref_arg(target)

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", source_ptr, target_ptr)

        # test to-numpy utility
        copied = ranked_memref_to_numpy(target_ptr[0])
        assert len(copied) == 1
        assert copied[0] == 0.5


# Without ml_dtypes the test body cannot run; only its banner line is logged.
if HAS_ML_DTYPES:
    run(testF8E5M2Memref)
else:
    log("TEST: testF8E5M2Memref")
620
621
622#  Test addition of two 2d_memref
623# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
    """Add two 2x2 f32 arrays element-wise and compare against numpy.

    The second operand is declared with dynamic dimensions in the module;
    the logged value is ``np.allclose`` of the numpy sum and the result.
    """
    add_2d_asm = """
      module  {
        func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
          %c0 = arith.constant 0 : index
          %c2 = arith.constant 2 : index
          %c1 = arith.constant 1 : index
          cf.br ^bb1(%c0 : index)
        ^bb1(%0: index):  // 2 preds: ^bb0, ^bb5
          %1 = arith.cmpi slt, %0, %c2 : index
          cf.cond_br %1, ^bb2, ^bb6
        ^bb2:  // pred: ^bb1
          %c0_0 = arith.constant 0 : index
          %c2_1 = arith.constant 2 : index
          %c1_2 = arith.constant 1 : index
          cf.br ^bb3(%c0_0 : index)
        ^bb3(%2: index):  // 2 preds: ^bb2, ^bb4
          %3 = arith.cmpi slt, %2, %c2_1 : index
          cf.cond_br %3, ^bb4, ^bb5
        ^bb4:  // pred: ^bb3
          %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
          %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
          %6 = arith.addf %4, %5 : f32
          memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
          %7 = arith.addi %2, %c1_2 : index
          cf.br ^bb3(%7 : index)
        ^bb5:  // pred: ^bb3
          %8 = arith.addi %0, %c1 : index
          cf.br ^bb1(%8 : index)
        ^bb6:  // pred: ^bb1
          return
        }
      }
        """
    with Context():
        module = Module.parse(add_2d_asm)

        # Random inputs; the result buffer also starts out with random data.
        lhs = np.random.randn(2, 2).astype(np.float32)
        rhs = np.random.randn(2, 2).astype(np.float32)
        total = np.random.randn(2, 2).astype(np.float32)

        # Each argument is a pointer to a pointer to a ranked descriptor.
        memref_args = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
            for array in (lhs, rhs, total)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("memref_add_2d", *memref_args)
        # CHECK: True
        log(np.allclose(lhs + rhs, total))


run(testDynamicMemrefAdd2D)
684
685
686#  Test loading of shared libraries.
687# CHECK-LABEL: TEST: testSharedLibLoad
def testSharedLibLoad():
    """Call ``printMemrefF32`` from a shared library given to the engine.

    The compiled function stores 42.0 into its argument and prints it via
    ``printMemrefF32``; the libraries to load are chosen per platform and
    passed to the engine through ``shared_libs``.
    """
    print_asm = """
      module  {
      func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %c0 = arith.constant 0 : index
        %cst42 = arith.constant 42.0 : f32
        memref.store %cst42, %arg0[%c0] : memref<1xf32>
        %u_memref = memref.cast %arg0 : memref<1xf32> to memref<*xf32>
        call @printMemrefF32(%u_memref) : (memref<*xf32>) -> ()
        return
      }
      func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
     } """

    # Fixed library paths for the platforms that do not use the configured
    # defaults; every other platform falls back to the module-level paths.
    fixed_libs_by_platform = {
        "win32": [
            "../../../../bin/mlir_runner_utils.dll",
            "../../../../bin/mlir_c_runner_utils.dll",
        ],
        "darwin": [
            "../../../../lib/libmlir_runner_utils.dylib",
            "../../../../lib/libmlir_c_runner_utils.dylib",
        ],
    }

    with Context():
        module = Module.parse(print_asm)
        buffer = np.array([0.0]).astype(np.float32)

        buffer_descriptor = get_ranked_memref_descriptor(buffer)
        buffer_ptr = ctypes.pointer(ctypes.pointer(buffer_descriptor))

        shared_libs = fixed_libs_by_platform.get(
            sys.platform,
            [MLIR_RUNNER_UTILS, MLIR_C_RUNNER_UTILS],
        )

        engine = ExecutionEngine(
            lowerToLLVM(module), opt_level=3, shared_libs=shared_libs
        )
        engine.invoke("main", buffer_ptr)
        # CHECK: Unranked Memref
        # CHECK-NEXT: [42]


run(testSharedLibLoad)
735
736
737#  Test that nano time clock is available.
738# CHECK-LABEL: TEST: testNanoTime
def testNanoTime():
    """Call ``nanoTime`` from compiled code and print the returned value.

    The compiled function stores the result of ``nanoTime`` in a 1-element
    i64 memref and prints it with ``printMemrefI64``; both symbols are
    declared private in the module and the engine is given ``shared_libs``.
    """
    nano_time_asm = """
      module {
      func.func @main() attributes { llvm.emit_c_interface } {
        %now = call @nanoTime() : () -> i64
        %memref = memref.alloca() : memref<1xi64>
        %c0 = arith.constant 0 : index
        memref.store %now, %memref[%c0] : memref<1xi64>
        %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
        call @printMemrefI64(%u_memref) : (memref<*xi64>) -> ()
        return
      }
      func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface }
      func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface }
    }"""
    with Context():
        module = Module.parse(nano_time_asm)

        # Windows uses fixed DLL paths; everything else uses the
        # module-level library paths.
        if sys.platform != "win32":
            shared_libs = [
                MLIR_RUNNER_UTILS,
                MLIR_C_RUNNER_UTILS,
            ]
        else:
            shared_libs = [
                "../../../../bin/mlir_runner_utils.dll",
                "../../../../bin/mlir_c_runner_utils.dll",
            ]

        engine = ExecutionEngine(
            lowerToLLVM(module), opt_level=3, shared_libs=shared_libs
        )
        engine.invoke("main")
        # CHECK: Unranked Memref
        # CHECK: [{{.*}}]


run(testNanoTime)
778
779
#  Test that the execution engine can dump the compiled module to an object file.
781# CHECK-LABEL: TEST: testDumpToObjectFile
def testDumpToObjectFile():
    """Dump a compiled module to an object file and check the file is filled.

    A temporary ``.o`` file is created empty with ``tempfile.mkstemp``. Its
    existence and emptiness are reported before and after calling
    ``dump_to_object_file``; the file descriptor is closed and the file
    removed in all cases.

    The status lines go through ``log`` like the rest of this file, so they
    are written to stderr and flushed immediately. Plain ``print`` writes to
    stdout, which is buffered when piped and therefore not guaranteed to
    stay ordered with the test banner emitted on stderr.
    """
    fd, object_path = tempfile.mkstemp(suffix=".o")

    try:
        with Context():
            module = Module.parse(
                """
        module {
        func.func @main() attributes { llvm.emit_c_interface } {
          return
        }
      }"""
            )

            execution_engine = ExecutionEngine(lowerToLLVM(module), opt_level=3)

            # Before dumping: the temporary file exists but has no content.
            # CHECK: Object file exists: True
            log(f"Object file exists: {os.path.exists(object_path)}")
            # CHECK: Object file is empty: True
            log(f"Object file is empty: {os.path.getsize(object_path) == 0}")

            execution_engine.dump_to_object_file(object_path)

            # After dumping: the same path must now hold data.
            # CHECK: Object file exists: True
            log(f"Object file exists: {os.path.exists(object_path)}")
            # CHECK: Object file is empty: False
            log(f"Object file is empty: {os.path.getsize(object_path) == 0}")

    finally:
        # Release the descriptor before deleting the temporary file.
        os.close(fd)
        os.remove(object_path)


run(testDumpToObjectFile)
816