# RUN: env SUPPORT_LIB=%mlir_c_runner_utils \
# RUN: %PYTHON %s | FileCheck %s

import ctypes
import errno
import os
import sys
import tempfile

from mlir import ir
from mlir import runtime as rt
from mlir.dialects import builtin
from mlir.dialects import sparse_tensor as st

_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparsifier


def boilerplate(attr: st.EncodingAttr):
    """Returns boilerplate main method.

    The generated `@main` converts a fixed 10x10 f64 constant with five
    nonzeros into a tensor carrying the sparse encoding `attr`, and hands it
    to `sparse_tensor.out` together with the pointer argument `%p` (the
    caller passes the output file path through `%p`).
    """
    return f"""
func.func @main(%p : !llvm.ptr) -> () attributes {{ llvm.emit_c_interface }} {{
  %d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],
                             [1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>
  %a = sparse_tensor.convert %d : tensor<10x10xf64> to tensor<10x10xf64, {attr}>
  sparse_tensor.out %a, %p : tensor<10x10xf64, {attr}>, !llvm.ptr
  return
}}
"""


def expected(id_map):
    """Returns expected contents of output.

    +-----+-----+-----+-----+-----+
    | 1 0 | . . | . . | . . | 0 3 |
    | 0 2 | . . | . . | . . | 0 0 |
    +-----+-----+-----+-----+-----+
    | . . | . . | . . | . . | . . |
    | . . | . . | . . | . . | . . |
    +-----+-----+-----+-----+-----+
    | . . | . . | 5 0 | . . | . . |
    | . . | . . | 0 0 | . . | . . |
    +-----+-----+-----+-----+-----+
    | . . | . . | . . | . . | . . |
    | . . | . . | . . | . . | . . |
    +-----+-----+-----+-----+-----+
    | 0 0 | . . | . . | . . | . . |
    | 4 0 | . . | . . | . . | . . |
    +-----+-----+-----+-----+-----+

    Output appears as dimension coordinates but lexicographically
    sorted by level coordinates. For BSR, the blocks are filled.

    The text is in extended FROSTT format: a header comment line, a line
    with the rank and the number of stored entries, a line with the
    dimension sizes, then one line per stored entry holding its 1-based
    coordinates followed by its value.

    Args:
      id_map: 0 for the identity ordering, 1 for the transposed ordering,
        2 for the 2x2 block (BSR) layout.

    Raises:
      AssertionError: if `id_map` is not one of the values above.
    """
    # Compare by value; `is` against an int literal depends on CPython's
    # small-int cache and triggers a SyntaxWarning since Python 3.8.
    if id_map == 0:
        return """# extended FROSTT format
2 5
10 10
1 1 1
1 10 3
2 2 2
5 5 5
10 1 4
"""
    if id_map == 1:
        return """# extended FROSTT format
2 5
10 10
1 1 1
10 1 4
2 2 2
5 5 5
1 10 3
"""
    if id_map == 2:
        return """# extended FROSTT format
2 16
10 10
1 1 1
1 2 0
2 1 0
2 2 2
1 9 0
1 10 3
2 9 0
2 10 0
5 5 5
5 6 0
6 5 0
6 6 0
9 1 0
9 2 0
10 1 4
10 2 0
"""
    raise AssertionError("unexpected id_map")


def build_compile_and_run_output(attr: st.EncodingAttr, compiler, expected):
    """Builds, JITs and runs `boilerplate(attr)`, then checks its output file.

    Args:
      attr: sparse encoding applied to the converted tensor.
      compiler: the `sparsifier.Sparsifier` used to `compile_and_jit` the module.
      expected: exact text the written output file must contain.

    Exits the process with the message "FAILURE" on a mismatch.
    """
    # Build and Compile.
    module = ir.Module.parse(boilerplate(attr))
    engine = compiler.compile_and_jit(module)
    # Invoke the kernel and compare output.
    with tempfile.TemporaryDirectory() as test_dir:
        out = os.path.join(test_dir, "out.tns")
        buf = out.encode("utf-8")
        # The output path travels to the kernel as a NUL-terminated C string.
        mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
        engine.invoke("main", mem_a)
        # Read (and close) the file before the temporary directory is removed.
        with open(out, encoding="utf-8") as out_file:
            actual = out_file.read()
        if actual != expected:
            # sys.exit does not depend on the `site` module, unlike quit().
            sys.exit("FAILURE")


def main():
    """Runs the output round-trip for every tested sparse encoding.

    Each encoding is compiled, executed and its output file compared with
    `expected(...)`; a mismatch exits early, so the final printed count is
    the number of encodings that matched.

    Raises:
      FileNotFoundError: if SUPPORT_LIB names a path that does not exist.
    """
    support_lib = os.getenv("SUPPORT_LIB")
    assert support_lib is not None, "SUPPORT_LIB is undefined"
    if not os.path.exists(support_lib):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

    # CHECK-LABEL: TEST: test_output
    print("\nTEST: test_output")
    count = 0
    with ir.Context(), ir.Location.unknown():
        # Loop over various sparse types (COO, CSR, DCSR, CSC, DCSC) with
        # regular and loose compression and various metadata bitwidths.
        # For these simple orderings, dim2lvl and lvl2dim are the same.
        # 5 level combinations x 2 orderings x 2 bitwidths = 20 runs, plus
        # one BSR run below, for 21 in total.
        builder = st.EncodingAttr.build_level_type
        fmt = st.LevelFormat
        prop = st.LevelProperty
        levels = [
            [builder(fmt.compressed, [prop.non_unique]), builder(fmt.singleton)],
            [
                builder(fmt.compressed, [prop.non_unique]),
                builder(fmt.singleton, [prop.soa]),
            ],
            [builder(fmt.dense), builder(fmt.compressed)],
            [builder(fmt.dense), builder(fmt.loose_compressed)],
            [builder(fmt.compressed), builder(fmt.compressed)],
        ]
        orderings = [
            (ir.AffineMap.get_permutation([0, 1]), 0),
            (ir.AffineMap.get_permutation([1, 0]), 1),
        ]
        bitwidths = [8, 64]
        compiler = sparsifier.Sparsifier(
            extras="", options="", opt_level=2, shared_libs=[support_lib]
        )
        for level in levels:
            for ordering, id_map in orderings:
                for bwidth in bitwidths:
                    attr = st.EncodingAttr.get(
                        level, ordering, ordering, bwidth, bwidth
                    )
                    build_compile_and_run_output(attr, compiler, expected(id_map))
                    count = count + 1

        # Now do the same for BSR.
        # 2x2 blocks: dim2lvl = (i floordiv 2, j floordiv 2, i mod 2, j mod 2)
        # and lvl2dim = (2*l0 + l2, 2*l1 + l3) maps block coordinates back.
        level = [
            builder(fmt.dense),
            builder(fmt.compressed),
            builder(fmt.dense),
            builder(fmt.dense),
        ]
        d0 = ir.AffineDimExpr.get(0)
        d1 = ir.AffineDimExpr.get(1)
        c2 = ir.AffineConstantExpr.get(2)
        dim2lvl = ir.AffineMap.get(
            2,
            0,
            [
                ir.AffineExpr.get_floor_div(d0, c2),
                ir.AffineExpr.get_floor_div(d1, c2),
                ir.AffineExpr.get_mod(d0, c2),
                ir.AffineExpr.get_mod(d1, c2),
            ],
        )
        l0 = ir.AffineDimExpr.get(0)
        l1 = ir.AffineDimExpr.get(1)
        l2 = ir.AffineDimExpr.get(2)
        l3 = ir.AffineDimExpr.get(3)
        lvl2dim = ir.AffineMap.get(4, 0, [2 * l0 + l2, 2 * l1 + l3])
        attr = st.EncodingAttr.get(level, dim2lvl, lvl2dim, 0, 0)
        build_compile_and_run_output(attr, compiler, expected(2))
        count = count + 1

    # CHECK: Passed 21 tests
    print("Passed", count, "tests")


if __name__ == "__main__":
    main()