# RUN: env SUPPORT_LIB=%mlir_c_runner_utils \
# RUN: %PYTHON %s | FileCheck %s

"""Integration test for `sparse_tensor.out` over many sparse encodings.

Each encoding is used to convert a fixed 10x10 tensor, which is then written
to a temporary file; the file contents are compared against the expected
extended FROSTT text.
"""

import ctypes
import errno
import os
import sys
import tempfile

from mlir import ir
from mlir import runtime as rt
from mlir.dialects import builtin
from mlir.dialects import sparse_tensor as st

_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparsifier


def boilerplate(attr: st.EncodingAttr):
    """Returns boilerplate main method.

    The generated `@main` builds a 10x10 f64 constant with five nonzeros,
    converts it to a sparse tensor with encoding `attr`, and writes it with
    `sparse_tensor.out` to the destination passed in as `%p`.
    """
    return f"""
func.func @main(%p : !llvm.ptr) -> () attributes {{ llvm.emit_c_interface }} {{
  %d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],
                             [1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>
  %a = sparse_tensor.convert %d : tensor<10x10xf64> to tensor<10x10xf64, {attr}>
  sparse_tensor.out %a, %p : tensor<10x10xf64, {attr}>, !llvm.ptr
  return
}}
"""


def expected(id_map):
    """Returns expected contents of output.

      +-----+-----+-----+-----+-----+
      | 1 0 | . . | . . | . . | 0 3 |
      | 0 2 | . . | . . | . . | 0 0 |
      +-----+-----+-----+-----+-----+
      | . . | . . | . . | . . | . . |
      | . . | . . | . . | . . | . . |
      +-----+-----+-----+-----+-----+
      | . . | . . | 5 0 | . . | . . |
      | . . | . . | 0 0 | . . | . . |
      +-----+-----+-----+-----+-----+
      | . . | . . | . . | . . | . . |
      | . . | . . | . . | . . | . . |
      +-----+-----+-----+-----+-----+
      | 0 0 | . . | . . | . . | . . |
      | 4 0 | . . | . . | . . | . . |
      +-----+-----+-----+-----+-----+

    Output appears as dimension coordinates but lexicographically
    sorted by level coordinates. For BSR, the blocks are filled.

    Args:
      id_map: selects the expected text: 0 for the identity ordering,
        1 for the transposed ordering, 2 for the 2x2 BSR encoding.

    Raises:
      AssertionError: if `id_map` is not one of 0, 1 or 2.
    """
    # Compare by value (not identity): `is` on int literals is
    # implementation-dependent and warns on Python 3.8+.
    if id_map == 0:
        return """# extended FROSTT format
2 5
10 10
1 1 1
1 10 3
2 2 2
5 5 5
10 1 4
"""
    if id_map == 1:
        return """# extended FROSTT format
2 5
10 10
1 1 1
10 1 4
2 2 2
5 5 5
1 10 3
"""
    if id_map == 2:
        return """# extended FROSTT format
2 16
10 10
1 1 1
1 2 0
2 1 0
2 2 2
1 9 0
1 10 3
2 9 0
2 10 0
5 5 5
5 6 0
6 5 0
6 6 0
9 1 0
9 2 0
10 1 4
10 2 0
"""
    raise AssertionError("unexpected id_map")


def build_compile_and_run_output(attr: st.EncodingAttr, compiler, expected):
    """Builds, compiles and runs the kernel for `attr`, then checks its output.

    Args:
      attr: sparse tensor encoding used by the generated kernel.
      compiler: object whose `compile_and_jit(module)` returns an engine with
        an `invoke` method (here a `sparsifier.Sparsifier`).
      expected: exact text the kernel is expected to write. Note that this
        parameter shadows the module-level `expected` function.

    Exits the process with the message "FAILURE" on a mismatch.
    """
    # Build and Compile.
    module = ir.Module.parse(boilerplate(attr))
    engine = compiler.compile_and_jit(module)
    # Invoke the kernel and compare output.
    with tempfile.TemporaryDirectory() as test_dir:
        out = os.path.join(test_dir, "out.tns")
        buf = out.encode("utf-8")
        # `invoke` takes a pointer to each argument; the argument itself is a
        # pointer to the NUL-terminated file-name buffer.
        mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
        engine.invoke("main", mem_a)
        # Read (and close) the file before the temporary directory is removed.
        with open(out) as f:
            actual = f.read()
    if actual != expected:
        # sys.exit works even when the `site` module (and thus `quit`) is absent.
        sys.exit("FAILURE")


def main():
    """Runs the output test over all encodings and prints the number passed.

    Raises:
      AssertionError: if the SUPPORT_LIB environment variable is not set.
      FileNotFoundError: if SUPPORT_LIB does not name an existing path.
    """
    support_lib = os.getenv("SUPPORT_LIB")
    assert support_lib is not None, "SUPPORT_LIB is undefined"
    if not os.path.exists(support_lib):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

    # CHECK-LABEL: TEST: test_output
    print("\nTEST: test_output")
    count = 0
    with ir.Context(), ir.Location.unknown():
        # Loop over various sparse types (COO in AoS and SoA form, CSR,
        # loose CSR, DCSR, and their transposed CSC/DCSC variants) with
        # regular and loose compression and various metadata bitwidths.
        # For these simple orderings, dim2lvl and lvl2dim are the same.
        builder = st.EncodingAttr.build_level_type
        fmt = st.LevelFormat
        prop = st.LevelProperty
        levels = [
            [builder(fmt.compressed, [prop.non_unique]), builder(fmt.singleton)],
            [
                builder(fmt.compressed, [prop.non_unique]),
                builder(fmt.singleton, [prop.soa]),
            ],
            [builder(fmt.dense), builder(fmt.compressed)],
            [builder(fmt.dense), builder(fmt.loose_compressed)],
            [builder(fmt.compressed), builder(fmt.compressed)],
        ]
        # Each ordering is paired with the `expected()` selector for its output.
        orderings = [
            (ir.AffineMap.get_permutation([0, 1]), 0),
            (ir.AffineMap.get_permutation([1, 0]), 1),
        ]
        bitwidths = [8, 64]
        compiler = sparsifier.Sparsifier(
            extras="", options="", opt_level=2, shared_libs=[support_lib]
        )
        for level in levels:
            for ordering, id_map in orderings:
                for bwidth in bitwidths:
                    attr = st.EncodingAttr.get(
                        level, ordering, ordering, bwidth, bwidth
                    )
                    build_compile_and_run_output(attr, compiler, expected(id_map))
                    count += 1

        # Now do the same for BSR.
        level = [
            builder(fmt.dense),
            builder(fmt.compressed),
            builder(fmt.dense),
            builder(fmt.dense),
        ]
        # dim2lvl: (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2).
        d0 = ir.AffineDimExpr.get(0)
        d1 = ir.AffineDimExpr.get(1)
        c2 = ir.AffineConstantExpr.get(2)
        dim2lvl = ir.AffineMap.get(
            2,
            0,
            [
                ir.AffineExpr.get_floor_div(d0, c2),
                ir.AffineExpr.get_floor_div(d1, c2),
                ir.AffineExpr.get_mod(d0, c2),
                ir.AffineExpr.get_mod(d1, c2),
            ],
        )
        # lvl2dim: (l0, l1, l2, l3) -> (2 * l0 + l2, 2 * l1 + l3), the inverse.
        l0 = ir.AffineDimExpr.get(0)
        l1 = ir.AffineDimExpr.get(1)
        l2 = ir.AffineDimExpr.get(2)
        l3 = ir.AffineDimExpr.get(3)
        lvl2dim = ir.AffineMap.get(4, 0, [2 * l0 + l2, 2 * l1 + l3])
        attr = st.EncodingAttr.get(level, dim2lvl, lvl2dim, 0, 0)
        build_compile_and_run_output(attr, compiler, expected(2))
        count += 1

    # CHECK: Passed 21 tests
    print("Passed", count, "tests")


if __name__ == "__main__":
    main()