xref: /llvm-project/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py (revision 52491c99fa8b30a558749da231fed7544159edca)
1# RUN: env SUPPORT_LIB=%mlir_c_runner_utils \
2# RUN:   %PYTHON %s | FileCheck %s
3
4import ctypes
5import os
6import sys
7import tempfile
8
9from mlir import ir
10from mlir import runtime as rt
11from mlir.dialects import builtin
12from mlir.dialects import sparse_tensor as st
13
# Add this script's directory to the module search path so that the
# `tools` package (presumably shipped alongside this test) is importable.
# Order matters: the path must be appended before the import below runs.
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparse_compiler
17
18
def boilerplate(attr: st.EncodingAttr):
    """Return the MLIR source of the `main` kernel for encoding `attr`.

    The generated `main` materializes a constant 10x10 f64 matrix with five
    nonzeros, converts it to a sparse tensor carrying `attr`, and hands it to
    `sparse_tensor.out` together with the `!llvm.ptr` argument `%p`.
    The function is tagged `llvm.emit_c_interface`.
    """
    sparse_type = f"tensor<10x10xf64, {attr}>"
    mlir_lines = [
        "",
        "func.func @main(%p : !llvm.ptr) -> () attributes { llvm.emit_c_interface } {",
        "  %d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],",
        # Continuation line aligned under the coordinate list above.
        " " * 29 + "[1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>",
        f"  %a = sparse_tensor.convert %d : tensor<10x10xf64> to {sparse_type}",
        f"  sparse_tensor.out %a, %p : {sparse_type}, !llvm.ptr",
        "  return",
        "}",
        "",
    ]
    # Leading/trailing empty items yield the leading and trailing newline.
    return "\n".join(mlir_lines)
30
31
def expected(id_map):
    """Return the exact text the output kernel is expected to write.

    Each stored entry is printed as 1-based dimension coordinates followed by
    its value. The entries are ordered by their level coordinates:
    (row, col) order for the identity ordering, and (col, row) order for the
    transposed ordering.

    Args:
        id_map: truthy for the identity dim-to-level ordering, falsy for the
            transposed one.
    """
    # Format tag, then "2 5" (presumably rank and number of stored entries)
    # and "10 10" (the sizes of the 10x10 input matrix).
    header = "# extended FROSTT format\n2 5\n10 10\n"
    if id_map:
        entries = ["1 1 1", "1 10 3", "2 2 2", "5 5 5", "10 1 4"]
    else:
        entries = ["1 1 1", "10 1 4", "2 2 2", "5 5 5", "1 10 3"]
    return header + "".join(entry + "\n" for entry in entries)
59
60
def build_compile_and_run_output(attr: st.EncodingAttr, compiler, expected):
    """Compile the output kernel for `attr`, run it, and verify the file.

    Parses `boilerplate(attr)` into a module and JIT-compiles it with
    `compiler.compile_and_jit`. It then invokes `main` with the path of a
    file inside a fresh temporary directory, and compares that file's
    contents with `expected`. On mismatch the process exits via
    `sys.exit("FAILURE")`: the message goes to stderr and the exit status
    is 1.

    Args:
        attr: sparse tensor encoding applied to the 10x10 test tensor.
        compiler: object providing `compile_and_jit(module)`, such as the
            SparseCompiler created in `main`.
        expected: exact text the kernel must write.
    """
    # Build and compile.
    module = ir.Module.parse(boilerplate(attr))
    engine = compiler.compile_and_jit(module)
    # Invoke the kernel and compare output.
    with tempfile.TemporaryDirectory() as test_dir:
        out = os.path.join(test_dir, "out.tns")
        path_buf = ctypes.create_string_buffer(out.encode("utf-8"))
        # NOTE(review): double indirection presumably because
        # ExecutionEngine.invoke takes every argument by pointer and `main`
        # itself takes an !llvm.ptr to the path string -- confirm.
        mem_a = ctypes.pointer(ctypes.pointer(path_buf))
        engine.invoke("main", mem_a)
        # Read and close the file before the temporary directory is removed.
        with open(out, encoding="utf-8") as out_file:
            actual = out_file.read()
    if actual != expected:
        # sys.exit instead of quit(): quit is only installed by the `site`
        # module and is missing under `python -S`.
        sys.exit("FAILURE")
74
75
def main():
    """Run the sparse_tensor.out test over a matrix of sparse encodings.

    Reads the runtime support library path from the SUPPORT_LIB environment
    variable and builds one SparseCompiler linked against it. It then calls
    `build_compile_and_run_output` for every combination of level types,
    dimension ordering and bitwidth. Finally it prints the test count, which
    the FileCheck directives in this function verify.

    Raises:
        AssertionError: if SUPPORT_LIB is not set.
        FileNotFoundError: if SUPPORT_LIB does not name an existing path.
    """
    # `errno` is not imported at file level; without this import the
    # FileNotFoundError branch below would itself fail with a NameError.
    import errno

    support_lib = os.getenv("SUPPORT_LIB")
    assert support_lib is not None, "SUPPORT_LIB is undefined"
    if not os.path.exists(support_lib):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

    # CHECK-LABEL: TEST: test_output
    print("\nTEST: test_output")
    count = 0
    with ir.Context(), ir.Location.unknown():
        # Loop over various sparse types (COO, CSR, DCSR, CSC, DCSC) with
        # regular and loose compression and various metadata bitwidths.
        # For these simple orderings, dim2lvl and lvl2dim are the same.
        levels = [
            [st.DimLevelType.compressed_nu, st.DimLevelType.singleton],
            [st.DimLevelType.dense, st.DimLevelType.compressed],
            [st.DimLevelType.dense, st.DimLevelType.loose_compressed],
            [st.DimLevelType.compressed, st.DimLevelType.compressed],
        ]
        # Each ordering carries the flag `expected` uses to choose between
        # (row, col)-sorted and (col, row)-sorted output.
        orderings = [
            (ir.AffineMap.get_permutation([0, 1]), True),
            (ir.AffineMap.get_permutation([1, 0]), False),
        ]
        bitwidths = [8, 16, 32, 64]
        compiler = sparse_compiler.SparseCompiler(
            options="", opt_level=2, shared_libs=[support_lib]
        )
        for level in levels:
            for ordering, id_map in orderings:
                for bwidth in bitwidths:
                    # The same permutation serves as dim2lvl and lvl2dim;
                    # the same width is passed for both bitwidth arguments.
                    attr = st.EncodingAttr.get(
                        level, ordering, ordering, bwidth, bwidth
                    )
                    build_compile_and_run_output(attr, compiler, expected(id_map))
                    count += 1

        # Now do the same for BSR with 2x2 blocks.
        level = [
            st.DimLevelType.dense,
            st.DimLevelType.compressed,
            st.DimLevelType.dense,
            st.DimLevelType.dense,
        ]
        d0 = ir.AffineDimExpr.get(0)
        d1 = ir.AffineDimExpr.get(1)
        c2 = ir.AffineConstantExpr.get(2)
        # dim2lvl: (i, j) -> (i floordiv 2, j floordiv 2, i mod 2, j mod 2).
        dim2lvl = ir.AffineMap.get(
            2,
            0,
            [
                ir.AffineExpr.get_floor_div(d0, c2),
                ir.AffineExpr.get_floor_div(d1, c2),
                ir.AffineExpr.get_mod(d0, c2),
                ir.AffineExpr.get_mod(d1, c2),
            ],
        )
        # lvl2dim maps back: (l0, l1, l2, l3) -> (2*l0 + l2, 2*l1 + l3).
        l0 = ir.AffineDimExpr.get(0)
        l1 = ir.AffineDimExpr.get(1)
        l2 = ir.AffineDimExpr.get(2)
        l3 = ir.AffineDimExpr.get(3)
        lvl2dim = ir.AffineMap.get(4, 0, [2 * l0 + l2, 2 * l1 + l3])
        attr = st.EncodingAttr.get(level, dim2lvl, lvl2dim, 0, 0)
        # TODO: enable this once CONVERSION on BSR is working
        # build_compile_and_run_output(attr, compiler, block_expected())
        # NOTE(review): this case is counted although its run is disabled:
        # 4 level sets x 2 orderings x 4 bitwidths = 32, plus this one = 33,
        # which the pass-count expectation below relies on.
        count += 1

    # CHECK: Passed 33 tests
    print("Passed", count, "tests")
144
145
# Run the test only when executed as a script (lit invokes it directly via
# `%PYTHON %s`), not when imported.
if __name__ == "__main__":
    main()
148