# RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
# RUN:   %PYTHON %s | FileCheck %s

import ctypes
import errno
import os
import sys
import tempfile

from mlir import ir
from mlir import runtime as rt

from mlir.dialects import builtin
from mlir.dialects import sparse_tensor as st

_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparse_compiler


# TODO: move more into actual IR building.
def boilerplate(attr: st.EncodingAttr) -> str:
  """Returns boilerplate main method.

  The generated `@main` converts a constant 10x10 f64 matrix with five
  nonzeros into a tensor carrying the encoding `attr`, then hands it to
  `sparse_tensor.out` together with the `!llvm.ptr<i8>` argument (the caller
  passes an output filename through that argument).
  """
  return f"""
func @main(%p : !llvm.ptr<i8>) -> () attributes {{ llvm.emit_c_interface }} {{
  %d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],
                             [1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>
  %a = sparse_tensor.convert %d : tensor<10x10xf64> to tensor<10x10xf64, {attr}>
  sparse_tensor.out %a, %p : tensor<10x10xf64, {attr}>, !llvm.ptr<i8>
  return
}}
"""


def expected() -> str:
  """Returns expected contents of output.

  Regardless of the dimension ordering, compression, and bitwidths that are
  used in the sparse tensor, the output is always lexicographically sorted
  by natural index order.
  """
  # After the header comment: rank (2) and nonzero count (5), then the
  # dimension sizes, then one "row col value" line per nonzero using 1-based
  # indices (e.g. constant entry [0, 9] = 3.0 appears as "1 10 3").
  return """; extended FROSTT format
2 5
10 10
1 1 1
1 10 3
2 2 2
5 5 5
10 1 4
"""


def build_compile_and_run_output(attr: st.EncodingAttr, compiler) -> None:
  """Compiles and runs the boilerplate kernel for `attr` and checks its output.

  Args:
    attr: sparse tensor encoding under test.
    compiler: object providing `compile_and_jit(module)`; `main` passes a
      `sparse_compiler.SparseCompiler`.

  Exits the process with the message 'FAILURE' if the file written by the
  kernel differs from `expected()`.
  """
  # Build and Compile.
  module = ir.Module.parse(boilerplate(attr))
  engine = compiler.compile_and_jit(module)

  # Invoke the kernel and compare output.
  with tempfile.TemporaryDirectory() as test_dir:
    out = os.path.join(test_dir, 'out.tns')
    buf = out.encode('utf-8')
    # NOTE(review): the pointer-to-pointer wrapping presumably matches how the
    # execution engine forwards arguments to the C interface of `main` —
    # confirm against the runtime's invoke convention.
    mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
    engine.invoke('main', mem_a)

    # Read (and close) the result while the temporary directory still exists.
    with open(out, encoding='utf-8') as f:
      actual = f.read()
    if actual != expected():
      # sys.exit, unlike the site-injected quit(), is always available.
      sys.exit('FAILURE')


def main():
  """Runs the output test over 16 encodings and prints how many passed.

  The 16 encodings are 2 level-type lists x 2 orderings x 4 bit widths.

  Raises:
    AssertionError: if the SUPPORT_LIB environment variable is unset.
    FileNotFoundError: if SUPPORT_LIB names a path that does not exist.
  """
  support_lib = os.getenv('SUPPORT_LIB')
  assert support_lib is not None, 'SUPPORT_LIB is undefined'
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                            support_lib)

  # CHECK-LABEL: TEST: test_output
  print('\nTEST: test_output')
  count = 0
  with ir.Context(), ir.Location.unknown():
    # Loop over various sparse types: CSR, DCSR, CSC, DCSC.
    levels = [[st.DimLevelType.dense, st.DimLevelType.compressed],
              [st.DimLevelType.compressed, st.DimLevelType.compressed]]
    orderings = [
        ir.AffineMap.get_permutation([0, 1]),
        ir.AffineMap.get_permutation([1, 0])
    ]
    bitwidths = [8, 16, 32, 64]
    compiler = sparse_compiler.SparseCompiler(
        options='', opt_level=2, shared_libs=[support_lib])
    for level in levels:
      for ordering in orderings:
        for bwidth in bitwidths:
          # NOTE(review): the same width fills both bit-width slots of the
          # encoding (presumably pointer and index widths) — confirm.
          attr = st.EncodingAttr.get(level, ordering, bwidth, bwidth)
          build_compile_and_run_output(attr, compiler)
          count += 1

  # CHECK: Passed 16 tests
  print('Passed', count, 'tests')


if __name__ == '__main__':
  main()