# RUN: env MLIR_RUNNER_UTILS=%mlir_runner_utils MLIR_C_RUNNER_UTILS=%mlir_c_runner_utils %PYTHON %s 2>&1 | FileCheck %s
# REQUIRES: host-supports-jit
import ctypes
import gc
import os
import sys
import tempfile

# numpy and ctypes are used directly by the tests below; import them
# explicitly instead of relying on them leaking out of `mlir.runtime`'s
# star-export.
import numpy as np

from mlir.ir import *
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.runtime import *

try:
    from ml_dtypes import bfloat16, float8_e5m2

    HAS_ML_DTYPES = True
except ModuleNotFoundError:
    HAS_ML_DTYPES = False


# Paths to the MLIR runner utility libraries. The RUN line above sets both
# environment variables; the defaults are relative fallbacks.
MLIR_RUNNER_UTILS = os.getenv(
    "MLIR_RUNNER_UTILS", "../../../../lib/libmlir_runner_utils.so"
)
MLIR_C_RUNNER_UTILS = os.getenv(
    "MLIR_C_RUNNER_UTILS", "../../../../lib/libmlir_c_runner_utils.so"
)


def log(*args):
    """Print `args` to stderr and flush immediately.

    Everything is logged to stderr (and flushed right away) so that it forms a
    single, correctly ordered stream with the errors/info MLIR itself emits to
    stderr; the RUN line merges stderr into the FileCheck input.
    """
    print(*args, file=sys.stderr, flush=True)


def run(f):
    """Log a `TEST:` label for `f`, run it, then verify no Context leaked.

    A garbage collection is forced after `f` returns so that any unreachable
    MLIR objects are released before asserting that no `Context` is alive.
    """
    log("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0


# Verify capsule interop.
# CHECK-LABEL: TEST: testCapsule
def testCapsule():
    """Round-trip an ExecutionEngine through its C API capsule."""
    with Context():
        module = Module.parse(
            r"""
llvm.func @none() {
  llvm.return
}
    """
        )
        execution_engine = ExecutionEngine(module)
        execution_engine_capsule = execution_engine._CAPIPtr
        # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
        log(repr(execution_engine_capsule))
        # Testing-only hook; presumably detaches this wrapper from the native
        # engine so the capsule can be re-wrapped below without two owners.
        # NOTE(review): confirm the exact ownership semantics.
        execution_engine._testing_release()
        execution_engine1 = ExecutionEngine._CAPICreate(execution_engine_capsule)
        # CHECK: _mlirExecutionEngine.ExecutionEngine
        log(repr(execution_engine1))


run(testCapsule)


# Test invalid ExecutionEngine creation
# CHECK-LABEL: TEST: testInvalidModule
def testInvalidModule():
    """Creating an engine from a module that was not lowered must raise."""
    with Context():
        # Builtin function
        module = Module.parse(
            r"""
    func.func @foo() { return }
    """
        )
        # CHECK: Got RuntimeError: Failure while creating the ExecutionEngine.
        try:
            ExecutionEngine(module)
        except RuntimeError as e:
            log("Got RuntimeError: ", e)
        else:
            # Make an unexpected success visible in the output instead of
            # silently printing nothing.
            log("Unexpectedly created an ExecutionEngine")


run(testInvalidModule)


def lowerToLLVM(module):
    """Lower `module` to the LLVM dialect in place and return it.

    Runs the conversion pipeline below on `module.operation`, so the same
    module object is returned, ready to be passed to `ExecutionEngine`.
    """
    pm = PassManager.parse(
        "builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,convert-arith-to-llvm,convert-cf-to-llvm,reconcile-unrealized-casts)"
    )
    pm.run(module.operation)
    return module


# Test simple ExecutionEngine execution
# CHECK-LABEL: TEST: testInvokeVoid
def testInvokeVoid():
    with Context():
        module = Module.parse(
            r"""
func.func @void() attributes { llvm.emit_c_interface } {
  return
}
    """
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        # Nothing to check other than no exception thrown here.
        execution_engine.invoke("void")


run(testInvokeVoid)


# Test argument passing and result with a simple float addition.
# CHECK-LABEL: TEST: testInvokeFloatAdd
def testInvokeFloatAdd():
    with Context():
        module = Module.parse(
            r"""
func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
  %add = arith.addf %arg0, %arg1 : f32
  return %add : f32
}
    """
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        # Prepare arguments: two input floats and one result.
        # Arguments must be passed as pointers.
        c_float_p = ctypes.c_float * 1
        arg0 = c_float_p(42.0)
        arg1 = c_float_p(2.0)
        res = c_float_p(-1.0)
        execution_engine.invoke("add", arg0, arg1, res)
        # CHECK: 42.0 + 2.0 = 44.0
        log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]))


run(testInvokeFloatAdd)


# Test calling back into a Python function from JIT-compiled code.
# CHECK-LABEL: TEST: testBasicCallback
def testBasicCallback():
    # Define a callback function that takes a float and an integer and returns
    # half of their sum as a float.
    @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
    def callback(a, b):
        return a / 2 + b / 2

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(
            r"""
func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
  %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
  return %resf : f32
}
func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
    """
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)

        # Prepare arguments: one input float, one input int and one float result.
        # Arguments must be passed as pointers.
        c_float_p = ctypes.c_float * 1
        c_int_p = ctypes.c_int * 1
        arg0 = c_float_p(42.0)
        arg1 = c_int_p(2)
        res = c_float_p(-1.0)
        execution_engine.invoke("add", arg0, arg1, res)
        # The callback returns half of the sum, so double it back for printing.
        # CHECK: 42.0 + 2 = 44.0
        log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0] * 2))


run(testBasicCallback)


# Test callback with an unranked memref
# CHECK-LABEL: TEST: testUnrankedMemRefCallback
def testUnrankedMemRefCallback():
    # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
    def callback(a):
        arr = unranked_memref_to_numpy(a, np.float32)
        log("Inside callback: ")
        log(arr)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(
            r"""
func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
"""
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)
        inp_arr = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[1. 2.]
        # CHECK{LITERAL}: [3. 4.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(inp_arr))),
        )
        inp_arr_1 = np.array([5, 6, 7], dtype=np.float32)
        # Non-contiguous view: a row stride of 4 bytes (one f32) and a column
        # stride of 0, so each of the 3 rows repeats one element 4 times.
        strided_arr = np.lib.stride_tricks.as_strided(
            inp_arr_1, strides=(4, 0), shape=(3, 4)
        )
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[5. 5. 5. 5.]
        # CHECK{LITERAL}: [6. 6. 6. 6.]
        # CHECK{LITERAL}: [7. 7. 7. 7.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(strided_arr))),
        )


run(testUnrankedMemRefCallback)


# Test callback with a ranked memref.
# CHECK-LABEL: TEST: testRankedMemRefCallback
def testRankedMemRefCallback():
    # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
    @ctypes.CFUNCTYPE(
        None,
        ctypes.POINTER(
            make_nd_memref_descriptor(2, np.ctypeslib.as_ctypes_type(np.float32))
        ),
    )
    def callback(a):
        arr = ranked_memref_to_numpy(a)
        log("Inside Callback: ")
        log(arr)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(
            r"""
func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
"""
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)
        inp_arr = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [[1. 5.]
        # CHECK{LITERAL}: [6. 7.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))),
        )


run(testRankedMemRefCallback)


# Test callback with a ranked memref with non-zero offset.
# CHECK-LABEL: TEST: testRankedMemRefWithOffsetCallback
def testRankedMemRefWithOffsetCallback():
    # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
    @ctypes.CFUNCTYPE(
        None,
        ctypes.POINTER(
            make_nd_memref_descriptor(1, np.ctypeslib.as_ctypes_type(np.float32))
        ),
    )
    def callback(a):
        arr = ranked_memref_to_numpy(a)
        log("Inside Callback: ")
        log(arr)

    with Context():
        # The module reinterprets the argument's base buffer as a 2-element view
        # starting at offset 3, casts it to a dynamically-strided memref and calls
        # the callback with it.
        module = Module.parse(
            r"""
func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
  %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
  %cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<?xf32, strided<[?], offset: ?>>
  call @some_callback_into_python(%cast) : (memref<?xf32, strided<[?], offset: ?>>) -> ()
  return
}
func.func private @some_callback_into_python(memref<?xf32, strided<[?], offset: ?>>) attributes {llvm.emit_c_interface}
"""
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)
        inp_arr = np.array([0, 0, 0, 1, 2], np.float32)
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [1. 2.]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))),
        )


run(testRankedMemRefWithOffsetCallback)


# Test callback with an unranked memref with non-zero offset
# CHECK-LABEL: TEST: testUnrankedMemRefWithOffsetCallback
def testUnrankedMemRefWithOffsetCallback():
    # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
    def callback(a):
        arr = unranked_memref_to_numpy(a, np.float32)
        log("Inside callback: ")
        log(arr)

    with Context():
        # The module reinterprets the argument's base buffer as a 2-element view
        # starting at offset 3, casts it to an unranked memref and calls the
        # callback with it.
        module = Module.parse(
            r"""
func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
  %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
  %cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<*xf32>
  call @some_callback_into_python(%cast) : (memref<*xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<*xf32>) attributes {llvm.emit_c_interface}
"""
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)
        inp_arr = np.array([1, 2, 3, 4, 5], np.float32)
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [4. 5.]
        # @callback_memref itself takes a ranked memref<5xf32>, hence the ranked
        # descriptor here even though the callback receives an unranked one.
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))),
        )


run(testUnrankedMemRefWithOffsetCallback)


# Test addition of a 1-element memref and a 0-d memref into a 1-element memref.
# CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
    with Context():
        module = Module.parse(
            """
    module  {
      func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf32>
        %2 = memref.load %arg1[] : memref<f32>
        %3 = arith.addf %1, %2 : f32
        memref.store %3, %arg2[%0] : memref<1xf32>
        return
      }
    } """
        )
        arg1 = np.array([32.5]).astype(np.float32)
        arg2 = np.array(6).astype(np.float32)
        res = np.array([0]).astype(np.float32)

        arg1_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg1))
        )
        arg2_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg2))
        )
        res_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(res))
        )

        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.invoke(
            "main", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
        )
        # CHECK: [32.5] + 6.0 = [38.5]
        log("{0} + {1} = {2}".format(arg1, arg2, res))


run(testMemrefAdd)


# Test addition of two f16 memrefs
# CHECK-LABEL: TEST: testF16MemrefAdd
def testF16MemrefAdd():
    with Context():
        module = Module.parse(
            """
    module  {
      func.func @main(%arg0: memref<1xf16>,
                      %arg1: memref<1xf16>,
                      %arg2: memref<1xf16>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf16>
        %2 = memref.load %arg1[%0] : memref<1xf16>
        %3 = arith.addf %1, %2 : f16
        memref.store %3, %arg2[%0] : memref<1xf16>
        return
      }
    } """
        )

        arg1 = np.array([11.0]).astype(np.float16)
        arg2 = np.array([22.0]).astype(np.float16)
        arg3 = np.array([0.0]).astype(np.float16)

        arg1_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg1))
        )
        arg2_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg2))
        )
        arg3_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg3))
        )

        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.invoke(
            "main", arg1_memref_ptr, arg2_memref_ptr, arg3_memref_ptr
        )
        # CHECK: [11.] + [22.] = [33.]
        log("{0} + {1} = {2}".format(arg1, arg2, arg3))

        # test to-numpy utility
        # CHECK: [33.]
        npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
        log(npout)


run(testF16MemrefAdd)


# Test addition of two complex memrefs
# CHECK-LABEL: TEST: testComplexMemrefAdd
def testComplexMemrefAdd():
    with Context():
        module = Module.parse(
            """
    module  {
      func.func @main(%arg0: memref<1xcomplex<f64>>,
                      %arg1: memref<1xcomplex<f64>>,
                      %arg2: memref<1xcomplex<f64>>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xcomplex<f64>>
        %2 = memref.load %arg1[%0] : memref<1xcomplex<f64>>
        %3 = complex.add %1, %2 : complex<f64>
        memref.store %3, %arg2[%0] : memref<1xcomplex<f64>>
        return
      }
    } """
        )

        arg1 = np.array([1.0 + 2.0j]).astype(np.complex128)
        arg2 = np.array([3.0 + 4.0j]).astype(np.complex128)
        arg3 = np.array([0.0 + 0.0j]).astype(np.complex128)

        arg1_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg1))
        )
        arg2_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg2))
        )
        arg3_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg3))
        )

        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.invoke(
            "main", arg1_memref_ptr, arg2_memref_ptr, arg3_memref_ptr
        )
        # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
        log("{0} + {1} = {2}".format(arg1, arg2, arg3))

        # test to-numpy utility
        # CHECK: [4.+6.j]
        npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
        log(npout)


run(testComplexMemrefAdd)


# Test addition of two complex unranked memrefs
# CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
def testComplexUnrankedMemrefAdd():
    with Context():
        module = Module.parse(
            """
    module  {
      func.func @main(%arg0: memref<*xcomplex<f32>>,
                      %arg1: memref<*xcomplex<f32>>,
                      %arg2: memref<*xcomplex<f32>>) attributes { llvm.emit_c_interface } {
        %A = memref.cast %arg0 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %B = memref.cast %arg1 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %C = memref.cast %arg2 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %0 = arith.constant 0 : index
        %1 = memref.load %A[%0] : memref<1xcomplex<f32>>
        %2 = memref.load %B[%0] : memref<1xcomplex<f32>>
        %3 = complex.add %1, %2 : complex<f32>
        memref.store %3, %C[%0] : memref<1xcomplex<f32>>
        return
      }
    } """
        )

        arg1 = np.array([5.0 + 6.0j]).astype(np.complex64)
        arg2 = np.array([7.0 + 8.0j]).astype(np.complex64)
        arg3 = np.array([0.0 + 0.0j]).astype(np.complex64)

        arg1_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_unranked_memref_descriptor(arg1))
        )
        arg2_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_unranked_memref_descriptor(arg2))
        )
        arg3_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_unranked_memref_descriptor(arg3))
        )

        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.invoke(
            "main", arg1_memref_ptr, arg2_memref_ptr, arg3_memref_ptr
        )
        # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
        log("{0} + {1} = {2}".format(arg1, arg2, arg3))

        # test to-numpy utility
        # CHECK: [12.+14.j]
        npout = unranked_memref_to_numpy(arg3_memref_ptr[0], np.dtype(np.complex64))
        log(npout)
run(testComplexUnrankedMemrefAdd)


def _testSingleElementMemrefCopy(elt_type, np_dtype):
    """Copy one element between two memrefs in JIT-compiled code and verify it.

    Compiles `@main(%arg0, %arg1)`, which loads element 0 of `%arg0` and stores
    it into element 0 of `%arg1` (both `memref<1x{elt_type}>`). It then runs
    `@main` on `[0.5]` -> `[0.0]` arrays of `np_dtype` and asserts that the
    output converts back to numpy as a single element equal to 0.5.

    Args:
      elt_type: MLIR spelling of the element type, e.g. "bf16" or "f8E5M2".
      np_dtype: numpy-compatible dtype matching `elt_type`.
    """
    memref_type = f"memref<1x{elt_type}>"
    with Context():
        module = Module.parse(
            f"""
    module  {{
      func.func @main(%arg0: {memref_type},
                      %arg1: {memref_type}) attributes {{ llvm.emit_c_interface }} {{
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : {memref_type}
        memref.store %1, %arg1[%0] : {memref_type}
        return
      }}
    }} """
        )

        arg1 = np.array([0.5]).astype(np_dtype)
        arg2 = np.array([0.0]).astype(np_dtype)

        arg1_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg1))
        )
        arg2_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg2))
        )

        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)

        # test to-numpy utility
        x = ranked_memref_to_numpy(arg2_memref_ptr[0])
        assert len(x) == 1
        assert x[0] == 0.5


# Test bf16 memrefs
# CHECK-LABEL: TEST: testBF16Memref
def testBF16Memref():
    _testSingleElementMemrefCopy("bf16", bfloat16)


if HAS_ML_DTYPES:
    run(testBF16Memref)
else:
    # Still emit the label so the FileCheck label above matches when the
    # optional ml_dtypes package is unavailable.
    log("TEST: testBF16Memref")


# Test f8E5M2 memrefs
# CHECK-LABEL: TEST: testF8E5M2Memref
def testF8E5M2Memref():
    _testSingleElementMemrefCopy("f8E5M2", float8_e5m2)


if HAS_ML_DTYPES:
    run(testF8E5M2Memref)
else:
    # Still emit the label so the FileCheck label above matches when the
    # optional ml_dtypes package is unavailable.
    log("TEST: testF8E5M2Memref")


# Test addition of two 2-D memrefs (one of them dynamically shaped).
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
    # Element-wise add of two 2x2 f32 buffers. The second operand is passed
    # as a dynamically shaped memref<?x?xf32>; the kernel writes the sum into
    # a third 2x2 buffer, which is then compared against NumPy's result.
    kernel_asm = """
    module {
      func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
        %c0 = arith.constant 0 : index
        %c2 = arith.constant 2 : index
        %c1 = arith.constant 1 : index
        cf.br ^bb1(%c0 : index)
      ^bb1(%0: index): // 2 preds: ^bb0, ^bb5
        %1 = arith.cmpi slt, %0, %c2 : index
        cf.cond_br %1, ^bb2, ^bb6
      ^bb2: // pred: ^bb1
        %c0_0 = arith.constant 0 : index
        %c2_1 = arith.constant 2 : index
        %c1_2 = arith.constant 1 : index
        cf.br ^bb3(%c0_0 : index)
      ^bb3(%2: index): // 2 preds: ^bb2, ^bb4
        %3 = arith.cmpi slt, %2, %c2_1 : index
        cf.cond_br %3, ^bb4, ^bb5
      ^bb4: // pred: ^bb3
        %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
        %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
        %6 = arith.addf %4, %5 : f32
        memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
        %7 = arith.addi %2, %c1_2 : index
        cf.br ^bb3(%7 : index)
      ^bb5: // pred: ^bb3
        %8 = arith.addi %0, %c1 : index
        cf.br ^bb1(%8 : index)
      ^bb6: // pred: ^bb1
        return
      }
    }
    """
    with Context():
        module = Module.parse(kernel_asm)

        # Two random inputs plus a (random-filled) output buffer, created in
        # the order lhs, rhs, out.
        lhs, rhs, out = (np.random.randn(2, 2).astype(np.float32) for _ in range(3))

        # The C interface takes each memref as a pointer to a pointer to its
        # ranked descriptor.
        descriptor_ptrs = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(buf)))
            for buf in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("memref_add_2d", *descriptor_ptrs)
        # CHECK: True
        log(np.allclose(lhs + rhs, out))


run(testDynamicMemrefAdd2D)


# Test loading of shared libraries.
# CHECK-LABEL: TEST: testSharedLibLoad
def testSharedLibLoad():
    # The module calls the externally declared `printMemrefF32`, so it can
    # only be resolved when the runner-utils libraries are passed to the
    # ExecutionEngine via `shared_libs`.
    with Context():
        module = Module.parse(
            """
    module {
      func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %c0 = arith.constant 0 : index
        %cst42 = arith.constant 42.0 : f32
        memref.store %cst42, %arg0[%c0] : memref<1xf32>
        %u_memref = memref.cast %arg0 : memref<1xf32> to memref<*xf32>
        call @printMemrefF32(%u_memref) : (memref<*xf32>) -> ()
        return
      }
      func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
    } """
        )
        arg0 = np.array([0.0]).astype(np.float32)

        arg0_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg0))
        )

        if sys.platform == "win32":
            shared_libs = [
                "../../../../bin/mlir_runner_utils.dll",
                "../../../../bin/mlir_c_runner_utils.dll",
            ]
        elif sys.platform == "darwin":
            shared_libs = [
                "../../../../lib/libmlir_runner_utils.dylib",
                "../../../../lib/libmlir_c_runner_utils.dylib",
            ]
        else:
            # Taken from the environment (set by the RUN line), falling back
            # to the defaults defined at the top of this file.
            shared_libs = [
                MLIR_RUNNER_UTILS,
                MLIR_C_RUNNER_UTILS,
            ]

        execution_engine = ExecutionEngine(
            lowerToLLVM(module), opt_level=3, shared_libs=shared_libs
        )
        execution_engine.invoke("main", arg0_memref_ptr)
        # CHECK: Unranked Memref
        # CHECK-NEXT: [42]


run(testSharedLibLoad)


# Test that nano time clock is available.
# CHECK-LABEL: TEST: testNanoTime
def testNanoTime():
    # Calls the externally declared `nanoTime` and prints the result with
    # `printMemrefI64`; both must be resolved from the runner-utils libraries.
    with Context():
        module = Module.parse(
            """
    module {
      func.func @main() attributes { llvm.emit_c_interface } {
        %now = call @nanoTime() : () -> i64
        %memref = memref.alloca() : memref<1xi64>
        %c0 = arith.constant 0 : index
        memref.store %now, %memref[%c0] : memref<1xi64>
        %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
        call @printMemrefI64(%u_memref) : (memref<*xi64>) -> ()
        return
      }
      func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface }
      func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface }
    }"""
        )

        if sys.platform == "win32":
            shared_libs = [
                "../../../../bin/mlir_runner_utils.dll",
                "../../../../bin/mlir_c_runner_utils.dll",
            ]
        else:
            shared_libs = [
                MLIR_RUNNER_UTILS,
                MLIR_C_RUNNER_UTILS,
            ]

        execution_engine = ExecutionEngine(
            lowerToLLVM(module), opt_level=3, shared_libs=shared_libs
        )
        execution_engine.invoke("main")
        # The timestamp value varies between runs, so only its presence is
        # matched.
        # CHECK: Unranked Memref
        # CHECK: [{{.*}}]


run(testNanoTime)


# Test dumping the JIT-compiled module to an object file.
# CHECK-LABEL: TEST: testDumpToObjectFile
def testDumpToObjectFile():
    # mkstemp creates an empty file and returns an open descriptor. Only the
    # path is needed, so release the descriptor right away rather than keeping
    # it open while the ExecutionEngine writes to the same path.
    fd, object_path = tempfile.mkstemp(suffix=".o")
    os.close(fd)

    try:
        with Context():
            module = Module.parse(
                """
        module {
          func.func @main() attributes { llvm.emit_c_interface } {
            return
          }
        }"""
            )

            execution_engine = ExecutionEngine(lowerToLLVM(module), opt_level=3)

            # Report through `log` (stderr, flushed) like the rest of this
            # file, so these lines stay ordered relative to the TEST: labels
            # when stdout and stderr are merged by the RUN line.
            # CHECK: Object file exists: True
            log(f"Object file exists: {os.path.exists(object_path)}")
            # CHECK: Object file is empty: True
            log(f"Object file is empty: {os.path.getsize(object_path) == 0}")

            execution_engine.dump_to_object_file(object_path)

            # CHECK: Object file exists: True
            log(f"Object file exists: {os.path.exists(object_path)}")
            # CHECK: Object file is empty: False
            log(f"Object file is empty: {os.path.getsize(object_path) == 0}")

    finally:
        os.remove(object_path)


run(testDumpToObjectFile)