# RUN: env MLIR_RUNNER_UTILS=%mlir_runner_utils MLIR_C_RUNNER_UTILS=%mlir_c_runner_utils %PYTHON %s 2>&1 | FileCheck %s
# REQUIRES: host-supports-jit
import ctypes
import gc
import os
import sys
import tempfile

# ctypes and numpy are used directly throughout this file. They used to be
# reachable only implicitly (presumably leaked in by one of the mlir
# star-imports below); import them explicitly so the file does not depend on
# that.
import numpy as np

from mlir.ir import *
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.runtime import *

try:
    from ml_dtypes import bfloat16, float8_e5m2

    HAS_ML_DTYPES = True
except ModuleNotFoundError:
    HAS_ML_DTYPES = False


# Paths of the runtime support libraries; the RUN line above passes them in
# through the environment, with a relative .so path as the fallback.
MLIR_RUNNER_UTILS = os.getenv(
    "MLIR_RUNNER_UTILS", "../../../../lib/libmlir_runner_utils.so"
)
MLIR_C_RUNNER_UTILS = os.getenv(
    "MLIR_C_RUNNER_UTILS", "../../../../lib/libmlir_c_runner_utils.so"
)


# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
    """Print ``args`` to stderr and flush immediately."""
    print(*args, file=sys.stderr)
    sys.stderr.flush()


def run(f):
    """Log the label for test ``f``, run it, and verify no Context leaked.

    A garbage collection is forced first so that only Contexts still
    genuinely referenced are counted.
    """
    log("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0


# Verify capsule interop.
# CHECK-LABEL: TEST: testCapsule
def testCapsule():
    with Context():
        module = Module.parse(
            r"""
llvm.func @none() {
  llvm.return
}
    """
        )
        execution_engine = ExecutionEngine(module)
        execution_engine_capsule = execution_engine._CAPIPtr
        # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
        log(repr(execution_engine_capsule))
        # Give up ownership so the capsule can be re-wrapped below.
        execution_engine._testing_release()
        execution_engine1 = ExecutionEngine._CAPICreate(execution_engine_capsule)
        # CHECK: _mlirExecutionEngine.ExecutionEngine
        log(repr(execution_engine1))


run(testCapsule)


# Test invalid ExecutionEngine creation
# CHECK-LABEL: TEST: testInvalidModule
def testInvalidModule():
    with Context():
        # Builtin function
        module = Module.parse(
            r"""
    func.func @foo() { return }
    """
        )
        # The module is handed over without lowerToLLVM, so creating the
        # engine is expected to raise.
        # CHECK: Got RuntimeError: Failure while creating the ExecutionEngine.
        try:
            ExecutionEngine(module)
        except RuntimeError as e:
            log("Got RuntimeError: ", e)


run(testInvalidModule)


def lowerToLLVM(module):
    """Run the LLVM lowering pipeline on ``module`` in place and return it."""
    pm = PassManager.parse(
        "builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,convert-arith-to-llvm,convert-cf-to-llvm,reconcile-unrealized-casts)"
    )
    pm.run(module.operation)
    return module


# Test simple ExecutionEngine execution
# CHECK-LABEL: TEST: testInvokeVoid
def testInvokeVoid():
    with Context():
        module = Module.parse(
            r"""
func.func @void() attributes { llvm.emit_c_interface } {
  return
}
    """
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        # Nothing to check other than no exception thrown here.
        execution_engine.invoke("void")


run(testInvokeVoid)


# Test argument passing and result with a simple float addition.
# CHECK-LABEL: TEST: testInvokeFloatAdd
def testInvokeFloatAdd():
    with Context():
        module = Module.parse(
            r"""
func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
  %add = arith.addf %arg0, %arg1 : f32
  return %add : f32
}
    """
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        # Prepare arguments: two input floats and one result.
        # Arguments must be passed as pointers.
        c_float_p = ctypes.c_float * 1
        arg0 = c_float_p(42.0)
        arg1 = c_float_p(2.0)
        res = c_float_p(-1.0)
        execution_engine.invoke("add", arg0, arg1, res)
        # CHECK: 42.0 + 2.0 = 44.0
        log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]))


run(testInvokeFloatAdd)


# Test callback
# CHECK-LABEL: TEST: testBasicCallback
def testBasicCallback():
    # Define a callback function that takes a float and an integer and returns a float.
    @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
    def callback(a, b):
        return a / 2 + b / 2

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(
            r"""
func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
  %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
  return %resf : f32
}
func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
    """
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)

        # Prepare arguments: one float and one int input, plus one float result.
        # Arguments must be passed as pointers.
        c_float_p = ctypes.c_float * 1
        c_int_p = ctypes.c_int * 1
        arg0 = c_float_p(42.0)
        arg1 = c_int_p(2)
        res = c_float_p(-1.0)
        execution_engine.invoke("add", arg0, arg1, res)
        # The callback halves the sum, so double it back for display.
        # CHECK: 42.0 + 2 = 44.0
        log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0] * 2))


run(testBasicCallback)


# Test callback with an unranked memref
# CHECK-LABEL: TEST: testUnrankedMemRefCallback
def testUnrankedMemRefCallback():
    # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
    def callback(a):
        arr = unranked_memref_to_numpy(a, np.float32)
        log("Inside callback: ")
        log(arr)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(
            r"""
func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
"""
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)
        inp_arr = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[1. 2.]
        # CHECK{LITERAL}: [3. 4.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(inp_arr))),
        )
        # A zero stride along axis 1 repeats each element across its row.
        inp_arr_1 = np.array([5, 6, 7], dtype=np.float32)
        strided_arr = np.lib.stride_tricks.as_strided(
            inp_arr_1, strides=(4, 0), shape=(3, 4)
        )
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[5. 5. 5. 5.]
        # CHECK{LITERAL}: [6. 6. 6. 6.]
        # CHECK{LITERAL}: [7. 7. 7. 7.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(strided_arr))),
        )


run(testUnrankedMemRefCallback)


# Test callback with a ranked memref.
# CHECK-LABEL: TEST: testRankedMemRefCallback
def testRankedMemRefCallback():
    # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
    @ctypes.CFUNCTYPE(
        None,
        ctypes.POINTER(
            make_nd_memref_descriptor(2, np.ctypeslib.as_ctypes_type(np.float32))
        ),
    )
    def callback(a):
        arr = ranked_memref_to_numpy(a)
        log("Inside Callback: ")
        log(arr)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(
            r"""
func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
"""
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)
        inp_arr = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [[1. 5.]
        # CHECK{LITERAL}: [6. 7.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))),
        )


run(testRankedMemRefCallback)


# Test callback with a ranked memref with non-zero offset.
# CHECK-LABEL: TEST: testRankedMemRefWithOffsetCallback
def testRankedMemRefWithOffsetCallback():
    # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
    @ctypes.CFUNCTYPE(
        None,
        ctypes.POINTER(
            make_nd_memref_descriptor(1, np.ctypeslib.as_ctypes_type(np.float32))
        ),
    )
    def callback(a):
        arr = ranked_memref_to_numpy(a)
        log("Inside Callback: ")
        log(arr)

    with Context():
        # The module reinterprets the argument's base buffer as a 2-element view
        # starting at offset 3 and calls the callback with it.
        module = Module.parse(
            r"""
func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
  %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
  %cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<?xf32, strided<[?], offset: ?>>
  call @some_callback_into_python(%cast) : (memref<?xf32, strided<[?], offset: ?>>) -> ()
  return
}
func.func private @some_callback_into_python(memref<?xf32, strided<[?], offset: ?>>) attributes {llvm.emit_c_interface}
"""
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)
        inp_arr = np.array([0, 0, 0, 1, 2], np.float32)
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [1. 2.]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))),
        )


run(testRankedMemRefWithOffsetCallback)


# Test callback with an unranked memref with non-zero offset
# CHECK-LABEL: TEST: testUnrankedMemRefWithOffsetCallback
def testUnrankedMemRefWithOffsetCallback():
    # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
    def callback(a):
        arr = unranked_memref_to_numpy(a, np.float32)
        log("Inside callback: ")
        log(arr)

    with Context():
        # The module reinterprets the argument's base buffer as a 2-element view
        # starting at offset 3, casts it to an unranked memref and calls the
        # callback with it.
        module = Module.parse(
            r"""
func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
  %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
  %cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<*xf32>
  call @some_callback_into_python(%cast) : (memref<*xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<*xf32>) attributes {llvm.emit_c_interface}
"""
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)
        inp_arr = np.array([1, 2, 3, 4, 5], np.float32)
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [4. 5.]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))),
        )


run(testUnrankedMemRefWithOffsetCallback)


# Test addition of two memrefs.
# CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
    """Add element 0 of a memref<1xf32> to a rank-0 memref<f32> via the JIT."""
    with Context():
        module = Module.parse(
            """
    module {
      func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf32>
        %2 = memref.load %arg1[] : memref<f32>
        %3 = arith.addf %1, %2 : f32
        memref.store %3, %arg2[%0] : memref<1xf32>
        return
      }
    } """
        )
        lhs = np.array([32.5]).astype(np.float32)
        rhs = np.array(6).astype(np.float32)
        out = np.array([0]).astype(np.float32)

        # Each argument is passed as a pointer to a pointer to its descriptor.
        lhs_ptr, rhs_ptr, out_ptr = (
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arr)))
            for arr in (lhs, rhs, out)
        )

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [32.5] + 6.0 = [38.5]
        log("{0} + {1} = {2}".format(lhs, rhs, out))


run(testMemrefAdd)


# Test addition of two f16 memrefs
# CHECK-LABEL: TEST: testF16MemrefAdd
def testF16MemrefAdd():
    """Add two single-element f16 memrefs and read the result back to numpy."""
    with Context():
        module = Module.parse(
            """
    module {
      func.func @main(%arg0: memref<1xf16>,
                      %arg1: memref<1xf16>,
                      %arg2: memref<1xf16>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf16>
        %2 = memref.load %arg1[%0] : memref<1xf16>
        %3 = arith.addf %1, %2 : f16
        memref.store %3, %arg2[%0] : memref<1xf16>
        return
      }
    } """
        )

        lhs = np.array([11.0]).astype(np.float16)
        rhs = np.array([22.0]).astype(np.float16)
        out = np.array([0.0]).astype(np.float16)

        lhs_ptr, rhs_ptr, out_ptr = (
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arr)))
            for arr in (lhs, rhs, out)
        )

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [11.] + [22.] = [33.]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

        # Convert the output descriptor back into a numpy array.
        # CHECK: [33.]
        converted = ranked_memref_to_numpy(out_ptr[0])
        log(converted)


run(testF16MemrefAdd)


# Test addition of two complex memrefs
# CHECK-LABEL: TEST: testComplexMemrefAdd
def testComplexMemrefAdd():
    """Add two single-element complex<f64> memrefs through ranked descriptors."""
    with Context():
        module = Module.parse(
            """
    module {
      func.func @main(%arg0: memref<1xcomplex<f64>>,
                      %arg1: memref<1xcomplex<f64>>,
                      %arg2: memref<1xcomplex<f64>>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xcomplex<f64>>
        %2 = memref.load %arg1[%0] : memref<1xcomplex<f64>>
        %3 = complex.add %1, %2 : complex<f64>
        memref.store %3, %arg2[%0] : memref<1xcomplex<f64>>
        return
      }
    } """
        )

        lhs = np.array([1.0 + 2.0j]).astype(np.complex128)
        rhs = np.array([3.0 + 4.0j]).astype(np.complex128)
        out = np.array([0.0 + 0.0j]).astype(np.complex128)

        lhs_ptr, rhs_ptr, out_ptr = (
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arr)))
            for arr in (lhs, rhs, out)
        )

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

        # Convert the output descriptor back into a numpy array.
        # CHECK: [4.+6.j]
        converted = ranked_memref_to_numpy(out_ptr[0])
        log(converted)


run(testComplexMemrefAdd)


# Test addition of two complex unranked memrefs
# CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
def testComplexUnrankedMemrefAdd():
    """Add two complex<f32> values passed as unranked memrefs."""
    with Context():
        module = Module.parse(
            """
    module {
      func.func @main(%arg0: memref<*xcomplex<f32>>,
                      %arg1: memref<*xcomplex<f32>>,
                      %arg2: memref<*xcomplex<f32>>) attributes { llvm.emit_c_interface } {
        %A = memref.cast %arg0 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %B = memref.cast %arg1 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %C = memref.cast %arg2 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %0 = arith.constant 0 : index
        %1 = memref.load %A[%0] : memref<1xcomplex<f32>>
        %2 = memref.load %B[%0] : memref<1xcomplex<f32>>
        %3 = complex.add %1, %2 : complex<f32>
        memref.store %3, %C[%0] : memref<1xcomplex<f32>>
        return
      }
    } """
        )

        lhs = np.array([5.0 + 6.0j]).astype(np.complex64)
        rhs = np.array([7.0 + 8.0j]).astype(np.complex64)
        out = np.array([0.0 + 0.0j]).astype(np.complex64)

        lhs_ptr, rhs_ptr, out_ptr = (
            ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(arr)))
            for arr in (lhs, rhs, out)
        )

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

        # Convert the unranked output descriptor back into a numpy array.
        # CHECK: [12.+14.j]
        converted = unranked_memref_to_numpy(out_ptr[0], np.dtype(np.complex64))
        log(converted)


run(testComplexUnrankedMemrefAdd)


# Test bf16 memrefs
# CHECK-LABEL: TEST: testBF16Memref
def testBF16Memref():
    """Copy one bf16 value between memrefs and verify it via numpy."""
    with Context():
        module = Module.parse(
            """
    module {
      func.func @main(%arg0: memref<1xbf16>,
                      %arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xbf16>
        memref.store %1, %arg1[%0] : memref<1xbf16>
        return
      }
    } """
        )

        src = np.array([0.5]).astype(bfloat16)
        dst = np.array([0.0]).astype(bfloat16)

        src_ptr, dst_ptr = (
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arr)))
            for arr in (src, dst)
        )

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", src_ptr, dst_ptr)

        # Convert the destination descriptor back into a numpy array.
        copied = ranked_memref_to_numpy(dst_ptr[0])
        assert len(copied) == 1
        assert copied[0] == 0.5


# bfloat16 only exists when ml_dtypes is installed; otherwise just emit the
# label so the expected-output sequence still lines up.
if HAS_ML_DTYPES:
    run(testBF16Memref)
else:
    log("TEST: testBF16Memref")


# Test f8E5M2 memrefs
# CHECK-LABEL: TEST: testF8E5M2Memref
def testF8E5M2Memref():
    """Copy one f8E5M2 value between memrefs and verify it via numpy."""
    with Context():
        module = Module.parse(
            """
    module {
      func.func @main(%arg0: memref<1xf8E5M2>,
                      %arg1: memref<1xf8E5M2>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf8E5M2>
        memref.store %1, %arg1[%0] : memref<1xf8E5M2>
        return
      }
    } """
        )

        src = np.array([0.5]).astype(float8_e5m2)
        dst = np.array([0.0]).astype(float8_e5m2)

        src_ptr, dst_ptr = (
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arr)))
            for arr in (src, dst)
        )

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", src_ptr, dst_ptr)

        # Convert the destination descriptor back into a numpy array.
        copied = ranked_memref_to_numpy(dst_ptr[0])
        assert len(copied) == 1
        assert copied[0] == 0.5


# float8_e5m2 only exists when ml_dtypes is installed; otherwise just emit the
# label so the expected-output sequence still lines up.
if HAS_ML_DTYPES:
    run(testF8E5M2Memref)
else:
    log("TEST: testF8E5M2Memref")


# Test addition of two 2d_memref
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
    """Element-wise add a static 2x2 memref and a dynamic ?x? memref."""
    with Context():
        module = Module.parse(
            """
    module {
      func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
        %c0 = arith.constant 0 : index
        %c2 = arith.constant 2 : index
        %c1 = arith.constant 1 : index
        cf.br ^bb1(%c0 : index)
      ^bb1(%0: index):  // 2 preds: ^bb0, ^bb5
        %1 = arith.cmpi slt, %0, %c2 : index
        cf.cond_br %1, ^bb2, ^bb6
      ^bb2:  // pred: ^bb1
        %c0_0 = arith.constant 0 : index
        %c2_1 = arith.constant 2 : index
        %c1_2 = arith.constant 1 : index
        cf.br ^bb3(%c0_0 : index)
      ^bb3(%2: index):  // 2 preds: ^bb2, ^bb4
        %3 = arith.cmpi slt, %2, %c2_1 : index
        cf.cond_br %3, ^bb4, ^bb5
      ^bb4:  // pred: ^bb3
        %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
        %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
        %6 = arith.addf %4, %5 : f32
        memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
        %7 = arith.addi %2, %c1_2 : index
        cf.br ^bb3(%7 : index)
      ^bb5:  // pred: ^bb3
        %8 = arith.addi %0, %c1 : index
        cf.br ^bb1(%8 : index)
      ^bb6:  // pred: ^bb1
        return
      }
    }
    """
        )
        # Three independent random 2x2 float32 arrays, drawn in order.
        lhs, rhs, out = (np.random.randn(2, 2).astype(np.float32) for _ in range(3))

        lhs_ptr, rhs_ptr, out_ptr = (
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arr)))
            for arr in (lhs, rhs, out)
        )

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("memref_add_2d", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: True
        log(np.allclose(lhs + rhs, out))


run(testDynamicMemrefAdd2D)


# Test loading of shared libraries.
# CHECK-LABEL: TEST: testSharedLibLoad
def testSharedLibLoad():
    """Call printMemrefF32 from the runner-utils shared libraries."""
    with Context():
        module = Module.parse(
            """
    module {
      func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %c0 = arith.constant 0 : index
        %cst42 = arith.constant 42.0 : f32
        memref.store %cst42, %arg0[%c0] : memref<1xf32>
        %u_memref = memref.cast %arg0 : memref<1xf32> to memref<*xf32>
        call @printMemrefF32(%u_memref) : (memref<*xf32>) -> ()
        return
      }
      func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
    } """
        )
        buf = np.array([0.0]).astype(np.float32)

        buf_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(buf)))

        # Windows and macOS use hard-coded relative paths; every other
        # platform uses the paths taken from the environment.
        per_platform_libs = {
            "win32": [
                "../../../../bin/mlir_runner_utils.dll",
                "../../../../bin/mlir_c_runner_utils.dll",
            ],
            "darwin": [
                "../../../../lib/libmlir_runner_utils.dylib",
                "../../../../lib/libmlir_c_runner_utils.dylib",
            ],
        }
        shared_libs = per_platform_libs.get(
            sys.platform, [MLIR_RUNNER_UTILS, MLIR_C_RUNNER_UTILS]
        )

        engine = ExecutionEngine(
            lowerToLLVM(module), opt_level=3, shared_libs=shared_libs
        )
        engine.invoke("main", buf_ptr)
        # CHECK: Unranked Memref
        # CHECK-NEXT: [42]


run(testSharedLibLoad)


# Test that nano time clock is available.
# CHECK-LABEL: TEST: testNanoTime
def testNanoTime():
    """Call nanoTime() from JIT-compiled code and print the value it returns.

    The i64 result is stored into a one-element memref and printed through
    printMemrefI64, both resolved from the runner-utils shared libraries.
    """
    with Context():
        module = Module.parse(
            """
    module {
      func.func @main() attributes { llvm.emit_c_interface } {
        %now = call @nanoTime() : () -> i64
        %memref = memref.alloca() : memref<1xi64>
        %c0 = arith.constant 0 : index
        memref.store %now, %memref[%c0] : memref<1xi64>
        %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
        call @printMemrefI64(%u_memref) : (memref<*xi64>) -> ()
        return
      }
      func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface }
      func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface }
    }"""
        )

        # NOTE(review): unlike testSharedLibLoad there is no darwin-specific
        # branch here, so on macOS this relies on MLIR_RUNNER_UTILS /
        # MLIR_C_RUNNER_UTILS being set by the RUN line (their fallbacks end in
        # .so). Confirm which library selection is intended.
        if sys.platform == "win32":
            shared_libs = [
                "../../../../bin/mlir_runner_utils.dll",
                "../../../../bin/mlir_c_runner_utils.dll",
            ]
        else:
            shared_libs = [
                MLIR_RUNNER_UTILS,
                MLIR_C_RUNNER_UTILS,
            ]

        execution_engine = ExecutionEngine(
            lowerToLLVM(module), opt_level=3, shared_libs=shared_libs
        )
        execution_engine.invoke("main")
        # CHECK: Unranked Memref
        # CHECK: [{{.*}}]


run(testNanoTime)


# Test dumping the compiled module to an object file.
# CHECK-LABEL: TEST: testDumpToObjectFile
def testDumpToObjectFile():
    """Dump the JIT-compiled module to an object file and check it was written.

    mkstemp creates the file empty up front, so it must exist and be empty
    before dump_to_object_file and be non-empty afterwards. Results go through
    log() (stderr, flushed) like every other test in this file, instead of
    print() to stdout. Stdout is buffered separately when piped, so its output
    would not be ordered reliably against the stderr stream being matched.
    """
    fd, object_path = tempfile.mkstemp(suffix=".o")
    # The descriptor is never used for I/O in this test; close it right away
    # rather than holding it open while the engine writes the file.
    os.close(fd)

    try:
        with Context():
            module = Module.parse(
                """
        module {
          func.func @main() attributes { llvm.emit_c_interface } {
            return
          }
        }"""
            )

            execution_engine = ExecutionEngine(lowerToLLVM(module), opt_level=3)

            # CHECK: Object file exists: True
            log(f"Object file exists: {os.path.exists(object_path)}")
            # CHECK: Object file is empty: True
            log(f"Object file is empty: {os.path.getsize(object_path) == 0}")

            execution_engine.dump_to_object_file(object_path)

            # CHECK: Object file exists: True
            log(f"Object file exists: {os.path.exists(object_path)}")
            # CHECK: Object file is empty: False
            log(f"Object file is empty: {os.path.getsize(object_path) == 0}")

    finally:
        os.remove(object_path)


run(testDumpToObjectFile)