# xref: /llvm-project/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py (revision 1944c4f76b47c0b86c91845987baca24fd4775f8)
# RUN: env SUPPORT_LIB=%mlir_c_runner_utils \
# RUN:   %PYTHON %s | FileCheck %s
3
4import ctypes
5import os
6import sys
7import tempfile
8
9from mlir import ir
10from mlir import runtime as rt
11from mlir.dialects import builtin
12from mlir.dialects import sparse_tensor as st
13
# Put this script's own directory on the module search path so that the
# sibling `tools` package can be imported; this must happen before the import.
_SCRIPT_PATH = os.path.split(os.path.abspath(__file__))[0]
sys.path.extend([_SCRIPT_PATH])
from tools import sparsifier
17
18
def boilerplate(attr: st.EncodingAttr):
    """Returns the MLIR source of the `@main` kernel for encoding `attr`.

    The kernel builds a constant 10x10 f64 matrix with five nonzeros, converts
    it to a sparse tensor annotated with `attr` (rendered via string
    formatting), and emits it with `sparse_tensor.out` through its single
    `!llvm.ptr` argument.
    """
    sparse_type = f"tensor<10x10xf64, {attr}>"
    source_lines = [
        "",
        "func.func @main(%p : !llvm.ptr) -> () attributes { llvm.emit_c_interface } {",
        "  %d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],",
        "                             [1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>",
        f"  %a = sparse_tensor.convert %d : tensor<10x10xf64> to {sparse_type}",
        f"  sparse_tensor.out %a, %p : {sparse_type}, !llvm.ptr",
        "  return",
        "}",
        "",
    ]
    # Leading "" yields the initial newline and trailing "" the final one,
    # matching a triple-quoted literal that opens and closes on its own lines.
    return "\n".join(source_lines)
30
31
def expected(id_map):
    """Returns expected contents of output.

    +-----+-----+-----+-----+-----+
    | 1 0 | . . | . . | . . | 0 3 |
    | 0 2 | . . | . . | . . | 0 0 |
    +-----+-----+-----+-----+-----+
    | . . | . . | . . | . . | . . |
    | . . | . . | . . | . . | . . |
    +-----+-----+-----+-----+-----+
    | . . | . . | 5 0 | . . | . . |
    | . . | . . | 0 0 | . . | . . |
    +-----+-----+-----+-----+-----+
    | . . | . . | . . | . . | . . |
    | . . | . . | . . | . . | . . |
    +-----+-----+-----+-----+-----+
    | 0 0 | . . | . . | . . | . . |
    | 4 0 | . . | . . | . . | . . |
    +-----+-----+-----+-----+-----+

    Output appears as dimension coordinates but lexicographically
    sorted by level coordinates. For BSR, the blocks are filled.

    Args:
      id_map: which encoding the output belongs to, as used in `main`:
        0 for the identity ordering, 1 for the transposed ordering, and
        2 for the 2x2 BSR encoding.

    Returns:
      The file text in extended FROSTT format: a header comment, a line with
      the rank and the number of stored entries, a line with the dimension
      sizes, then one line per stored entry holding its 1-based coordinates
      followed by its value.

    Raises:
      AssertionError: if `id_map` is not 0, 1 or 2.
    """
    # Compare by value: `is` on ints relies on CPython's small-int cache and
    # triggers a SyntaxWarning when used with a literal.
    if id_map == 0:
        return """# extended FROSTT format
2 5
10 10
1 1 1
1 10 3
2 2 2
5 5 5
10 1 4
"""
    if id_map == 1:
        return """# extended FROSTT format
2 5
10 10
1 1 1
10 1 4
2 2 2
5 5 5
1 10 3
"""
    if id_map == 2:
        return """# extended FROSTT format
2 16
10 10
1 1 1
1 2 0
2 1 0
2 2 2
1 9 0
1 10 3
2 9 0
2 10 0
5 5 5
5 6 0
6 5 0
6 6 0
9 1 0
9 2 0
10 1 4
10 2 0
"""
    raise AssertionError("unexpected id_map")
97
98
def build_compile_and_run_output(attr: st.EncodingAttr, compiler, expected):
    """Compiles the boilerplate kernel for `attr`, runs it and checks its output.

    Args:
      attr: sparse tensor encoding spliced into the kernel by `boilerplate`.
      compiler: object providing `compile_and_jit(module)`, which returns an
        engine with an `invoke` method (a `sparsifier.Sparsifier` in `main`).
      expected: exact text the kernel must write to its output file.

    On a mismatch, the expected and actual text are printed to stderr and the
    process exits with the message "FAILURE".
    """
    # Build and compile.
    module = ir.Module.parse(boilerplate(attr))
    engine = compiler.compile_and_jit(module)
    # Invoke the kernel and collect what it wrote.
    with tempfile.TemporaryDirectory() as test_dir:
        out = os.path.join(test_dir, "out.tns")
        # @main takes an !llvm.ptr to the NUL-terminated output file name, and
        # invoke() takes a pointer to each argument: hence two levels of
        # ctypes.pointer. The buffer is kept in a named local for clarity.
        name_buf = ctypes.create_string_buffer(out.encode("utf-8"))
        mem_a = ctypes.pointer(ctypes.pointer(name_buf))
        engine.invoke("main", mem_a)
        # Close the file deterministically (before the directory is removed).
        with open(out, encoding="utf-8") as f:
            actual = f.read()
    if actual != expected:
        print(f"expected:\n{expected}\nactual:\n{actual}", file=sys.stderr)
        # sys.exit is always available, unlike the site-provided quit().
        sys.exit("FAILURE")
111            quit("FAILURE")
112
113
def main():
    """Runs the sparse_tensor.out round-trip test for many sparse encodings.

    The path of the C runner utils shared library is read from the SUPPORT_LIB
    environment variable. For every combination of level types, ordering and
    metadata bitwidth, plus one 2x2 BSR encoding, a kernel that writes a fixed
    sparse matrix to a file is compiled and run, and the file is compared with
    `expected`. The final "Passed N tests" line is checked by FileCheck.

    Raises:
      AssertionError: if SUPPORT_LIB is not set (when asserts are enabled).
      FileNotFoundError: if SUPPORT_LIB does not name an existing path.
    """
    import errno  # Provides ENOENT for the FileNotFoundError raised below.

    support_lib = os.getenv("SUPPORT_LIB")
    assert support_lib is not None, "SUPPORT_LIB is undefined"
    if not os.path.exists(support_lib):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

    # CHECK-LABEL: TEST: test_output
    print("\nTEST: test_output")
    count = 0
    with ir.Context(), ir.Location.unknown():
        # Loop over various sparse types (COO, CSR, DCSR, CSC, DCSC) with
        # regular and loose compression and various metadata bitwidths.
        # For these simple orderings, dim2lvl and lvl2dim are the same.
        levels = [
            [st.LevelType.compressed_nu, st.LevelType.singleton],
            [st.LevelType.dense, st.LevelType.compressed],
            [st.LevelType.dense, st.LevelType.loose_compressed],
            [st.LevelType.compressed, st.LevelType.compressed],
        ]
        orderings = [
            (ir.AffineMap.get_permutation([0, 1]), 0),
            (ir.AffineMap.get_permutation([1, 0]), 1),
        ]
        bitwidths = [8, 64]
        compiler = sparsifier.Sparsifier(
            options="", opt_level=2, shared_libs=[support_lib]
        )
        for level in levels:
            for ordering, id_map in orderings:
                for bwidth in bitwidths:
                    attr = st.EncodingAttr.get(
                        level, ordering, ordering, bwidth, bwidth
                    )
                    build_compile_and_run_output(attr, compiler, expected(id_map))
                    count += 1

        # Now do the same for BSR with 2x2 blocks: dim2lvl maps (i, j) to
        # (i floordiv 2, j floordiv 2, i mod 2, j mod 2), and lvl2dim maps
        # the level coordinates back to (2*l0 + l2, 2*l1 + l3).
        level = [
            st.LevelType.dense,
            st.LevelType.compressed,
            st.LevelType.dense,
            st.LevelType.dense,
        ]
        d0 = ir.AffineDimExpr.get(0)
        d1 = ir.AffineDimExpr.get(1)
        c2 = ir.AffineConstantExpr.get(2)
        dim2lvl = ir.AffineMap.get(
            2,
            0,
            [
                ir.AffineExpr.get_floor_div(d0, c2),
                ir.AffineExpr.get_floor_div(d1, c2),
                ir.AffineExpr.get_mod(d0, c2),
                ir.AffineExpr.get_mod(d1, c2),
            ],
        )
        l0 = ir.AffineDimExpr.get(0)
        l1 = ir.AffineDimExpr.get(1)
        l2 = ir.AffineDimExpr.get(2)
        l3 = ir.AffineDimExpr.get(3)
        lvl2dim = ir.AffineMap.get(4, 0, [2 * l0 + l2, 2 * l1 + l3])
        attr = st.EncodingAttr.get(level, dim2lvl, lvl2dim, 0, 0)
        build_compile_and_run_output(attr, compiler, expected(2))
        count += 1

    # 4 level types x 2 orderings x 2 bitwidths, plus 1 BSR encoding = 17.
    # CHECK: Passed 17 tests
    print("Passed", count, "tests")
181
182
if __name__ == "__main__":
    # Run the test only when the file is executed directly (as the RUN line
    # does), not when it is imported.
    main()
185