# RUN: env SUPPORT_LIB=%mlir_c_runner_utils \
# RUN:   %PYTHON %s | FileCheck %s

import ctypes
import errno
import os
import sys
import tempfile

from mlir import ir
from mlir import runtime as rt
from mlir.dialects import builtin
from mlir.dialects import sparse_tensor as st

_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparse_compiler


def boilerplate(attr: st.EncodingAttr):
    """Returns the MLIR source of the test's `main` function.

    The function materializes a constant 10x10 f64 tensor with five nonzero
    entries, converts it to a sparse tensor with encoding `attr`, and writes
    it with `sparse_tensor.out` to the output file whose name is passed in
    through `%p`.
    """
    return f"""
func.func @main(%p : !llvm.ptr) -> () attributes {{ llvm.emit_c_interface }} {{
  %d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],
                             [1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>
  %a = sparse_tensor.convert %d : tensor<10x10xf64> to tensor<10x10xf64, {attr}>
  sparse_tensor.out %a, %p : tensor<10x10xf64, {attr}>, !llvm.ptr
  return
}}
"""


def expected(id_map):
    """Returns the expected contents of the output file.

    Entries are printed as 1-based dimension coordinates followed by the
    value, but they are sorted lexicographically by level coordinates.
    With the identity ordering (`id_map` true) that is row-major order;
    with the transposed ordering it is column-major order.
    """
    if id_map:
        return """# extended FROSTT format
2 5
10 10
1 1 1
1 10 3
2 2 2
5 5 5
10 1 4
"""
    return """# extended FROSTT format
2 5
10 10
1 1 1
10 1 4
2 2 2
5 5 5
1 10 3
"""


def build_compile_and_run_output(attr: st.EncodingAttr, compiler, expected):
    """Builds, JIT-compiles and runs the kernel for `attr`, then checks output.

    Args:
      attr: sparse tensor encoding the constant tensor is converted to.
      compiler: a `sparse_compiler.SparseCompiler` used to compile and JIT.
      expected: exact text the kernel is expected to write to its file.

    Exits the process with the message "FAILURE" on a mismatch.
    """
    # Build and compile.
    module = ir.Module.parse(boilerplate(attr))
    engine = compiler.compile_and_jit(module)
    # Invoke the kernel and compare output.
    with tempfile.TemporaryDirectory() as test_dir:
        out = os.path.join(test_dir, "out.tns")
        buf = out.encode("utf-8")
        # The kernel argument is a C string (!llvm.ptr); the execution
        # engine's invoke() receives a pointer to each argument value,
        # hence the two levels of ctypes.pointer.
        mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
        engine.invoke("main", mem_a)
        # Read (and close) the file while the temporary directory exists.
        with open(out, encoding="utf-8") as f:
            actual = f.read()
    if actual != expected:
        sys.exit("FAILURE")


def main():
    """Runs the output test over many encodings and prints the test count."""
    support_lib = os.getenv("SUPPORT_LIB")
    assert support_lib is not None, "SUPPORT_LIB is undefined"
    if not os.path.exists(support_lib):
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), support_lib
        )

    # CHECK-LABEL: TEST: test_output
    print("\nTEST: test_output")
    count = 0
    with ir.Context(), ir.Location.unknown():
        # Loop over various sparse types (COO, CSR, DCSR, CSC, DCSC) with
        # regular and loose compression and various metadata bitwidths.
        # For these simple orderings, dim2lvl and lvl2dim are the same.
        levels = [
            [st.DimLevelType.compressed_nu, st.DimLevelType.singleton],
            [st.DimLevelType.dense, st.DimLevelType.compressed],
            [st.DimLevelType.dense, st.DimLevelType.loose_compressed],
            [st.DimLevelType.compressed, st.DimLevelType.compressed],
        ]
        # Each ordering is paired with whether it is the identity map, which
        # selects the expected (row-major vs. column-major) output.
        orderings = [
            (ir.AffineMap.get_permutation([0, 1]), True),
            (ir.AffineMap.get_permutation([1, 0]), False),
        ]
        bitwidths = [8, 16, 32, 64]
        compiler = sparse_compiler.SparseCompiler(
            options="", opt_level=2, shared_libs=[support_lib]
        )
        for level in levels:
            for ordering, id_map in orderings:
                for bwidth in bitwidths:
                    attr = st.EncodingAttr.get(
                        level, ordering, ordering, bwidth, bwidth
                    )
                    build_compile_and_run_output(attr, compiler, expected(id_map))
                    count = count + 1

        # Now do the same for BSR (2x2 blocks).
        level = [
            st.DimLevelType.dense,
            st.DimLevelType.compressed,
            st.DimLevelType.dense,
            st.DimLevelType.dense,
        ]
        d0 = ir.AffineDimExpr.get(0)
        d1 = ir.AffineDimExpr.get(1)
        c2 = ir.AffineConstantExpr.get(2)
        dim2lvl = ir.AffineMap.get(
            2,
            0,
            [
                ir.AffineExpr.get_floor_div(d0, c2),
                ir.AffineExpr.get_floor_div(d1, c2),
                ir.AffineExpr.get_mod(d0, c2),
                ir.AffineExpr.get_mod(d1, c2),
            ],
        )
        l0 = ir.AffineDimExpr.get(0)
        l1 = ir.AffineDimExpr.get(1)
        l2 = ir.AffineDimExpr.get(2)
        l3 = ir.AffineDimExpr.get(3)
        lvl2dim = ir.AffineMap.get(4, 0, [2 * l0 + l2, 2 * l1 + l3])
        attr = st.EncodingAttr.get(level, dim2lvl, lvl2dim, 0, 0)
        # TODO: enable this once CONVERSION on BSR is working; a
        #       block_expected() helper also still needs to be written.
        # build_compile_and_run_output(attr, compiler, block_expected())
        # NOTE: counted even though the run above is disabled; the expected
        # total printed below (4 levels x 2 orderings x 4 widths + 1) relies
        # on this.
        count = count + 1

    # CHECK: Passed 33 tests
    print("Passed", count, "tests")


if __name__ == "__main__":
    main()