# RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
# RUN:   %PYTHON %s | FileCheck %s
"""Integration test for writing a sparse tensor to a file.

Builds a small kernel that converts a constant 10x10 tensor into a sparse
tensor and writes it out through `sparse_tensor.out`, then verifies that the
file contents are identical for every tested sparse encoding.
"""

import ctypes
import errno
import os
import sys
import tempfile

import mlir.all_passes_registration

from mlir import execution_engine
from mlir import ir
from mlir import passmanager
from mlir import runtime as rt

from mlir.dialects import builtin
from mlir.dialects import sparse_tensor as st


# TODO: move more into actual IR building.
def boilerplate(attr: st.EncodingAttr):
  """Returns boilerplate main method.

  Args:
    attr: The sparse tensor encoding that is spliced (via its string form)
      into the tensor types of the generated IR.

  Returns:
    The textual IR of a `main` function that takes a pointer argument,
    converts a constant 10x10 tensor with five nonzero entries to the given
    sparse encoding, and passes the result and the pointer to
    `sparse_tensor.out`.
  """
  return f"""
func @main(%p : !llvm.ptr<i8>) -> () attributes {{ llvm.emit_c_interface }} {{
  %d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],
                             [1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>
  %a = sparse_tensor.convert %d : tensor<10x10xf64> to tensor<10x10xf64, {attr}>
  sparse_tensor.out %a, %p : tensor<10x10xf64, {attr}>, !llvm.ptr<i8>
  return
}}
"""


def expected():
  """Returns expected contents of output.

  Regardless of the dimension ordering, compression, and bitwidths that are
  used in the sparse tensor, the output is always lexicographically sorted
  by natural index order.
  """
  return f"""; extended FROSTT format
2 5
10 10
1 1 1
1 10 3
2 2 2
5 5 5
10 1 4
"""


def build_compile_and_run_output(attr: st.EncodingAttr, support_lib: str,
                                 compiler):
  """Builds, compiles, and runs the kernel, then checks the written file.

  Terminates the process with the message 'FAILURE' when the file written by
  the kernel differs from `expected()`.

  Args:
    attr: The sparse tensor encoding under test.
    support_lib: Path of the shared library handed to the execution engine.
    compiler: A callable that is applied to the parsed module before it is
      given to the execution engine.
  """
  # Build and Compile.
  module = ir.Module.parse(boilerplate(attr))
  compiler(module)
  engine = execution_engine.ExecutionEngine(
      module, opt_level=0, shared_libs=[support_lib])

  # Invoke the kernel and compare output.
  with tempfile.TemporaryDirectory() as test_dir:
    out = os.path.join(test_dir, 'out.tns')
    buf = out.encode('utf-8')
    # The kernel receives the output file name as a C string; the argument is
    # wrapped in one more pointer before it is passed to invoke().
    mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
    engine.invoke('main', mem_a)

    # Close the file before the temporary directory is cleaned up.
    with open(out) as f:
      actual = f.read()
    if actual != expected():
      # sys.exit is always available, unlike the site-provided quit().
      sys.exit('FAILURE')


class SparseCompiler:
  """Sparse compiler passes."""

  def __init__(self):
    """Assembles the textual pass pipeline that is run on each module."""
    pipeline = (
        f'builtin.func(linalg-generalize-named-ops,linalg-fuse-elementwise-ops),'
        f'sparse-compiler{{reassociate-fp-reductions=1 enable-index-optimizations=1}}')
    self.pipeline = pipeline

  def __call__(self, module: ir.Module):
    """Parses the stored pipeline and runs it on the given module."""
    passmanager.PassManager.parse(self.pipeline).run(module)


def main():
  """Runs the output test over all tested sparse encodings.

  The support library path is read from the SUPPORT_LIB environment variable.

  Raises:
    FileNotFoundError: If the support library path does not exist.
  """
  support_lib = os.getenv('SUPPORT_LIB')
  assert support_lib is not None, 'SUPPORT_LIB is undefined'
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                            support_lib)

  # CHECK-LABEL: TEST: test_output
  print('\nTEST: test_output')
  count = 0
  with ir.Context(), ir.Location.unknown():
    # Loop over various sparse types: CSR, DCSR, CSC, DCSC.
    levels = [[st.DimLevelType.dense, st.DimLevelType.compressed],
              [st.DimLevelType.compressed, st.DimLevelType.compressed]]
    orderings = [
        ir.AffineMap.get_permutation([0, 1]),
        ir.AffineMap.get_permutation([1, 0])
    ]
    bitwidths = [8, 16, 32, 64]
    # The compiler only holds a constant pipeline string, so a single
    # instance is shared by all combinations.
    compiler = SparseCompiler()
    for level in levels:
      for ordering in orderings:
        for bwidth in bitwidths:
          attr = st.EncodingAttr.get(level, ordering, bwidth, bwidth)
          build_compile_and_run_output(attr, support_lib, compiler)
          count = count + 1

  # CHECK: Passed 16 tests
  print('Passed', count, 'tests')


if __name__ == '__main__':
  main()