xref: /llvm-project/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py (revision e3686f1e44676fa28789c6732076b8998be23527)
1# RUN: env SUPPORT_LIB=%mlir_c_runner_utils \
2# RUN:   %PYTHON %s | FileCheck %s
3
4import ctypes
5import os
6import sys
7import tempfile
8
9from mlir import ir
10from mlir import runtime as rt
11from mlir.dialects import builtin
12from mlir.dialects import sparse_tensor as st
13
# Put this script's own directory on the import path so that helper modules
# kept next to it (presumably the `tools` package that provides the
# sparsifier wrapper) can be imported regardless of the working directory.
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparsifier  # noqa: E402 -- must follow the sys.path tweak.
17
18
19def boilerplate(attr: st.EncodingAttr):
20    """Returns boilerplate main method."""
21    return f"""
22func.func @main(%p : !llvm.ptr) -> () attributes {{ llvm.emit_c_interface }} {{
23  %d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],
24                             [1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>
25  %a = sparse_tensor.convert %d : tensor<10x10xf64> to tensor<10x10xf64, {attr}>
26  sparse_tensor.out %a, %p : tensor<10x10xf64, {attr}>, !llvm.ptr
27  return
28}}
29"""
30
31
32def expected(id_map):
33    """Returns expected contents of output.
34
35    +-----+-----+-----+-----+-----+
36    | 1 0 | . . | . . | . . | 0 3 |
37    | 0 2 | . . | . . | . . | 0 0 |
38    +-----+-----+-----+-----+-----+
39    | . . | . . | . . | . . | . . |
40    | . . | . . | . . | . . | . . |
41    +-----+-----+-----+-----+-----+
42    | . . | . . | 5 0 | . . | . . |
43    | . . | . . | 0 0 | . . | . . |
44    +-----+-----+-----+-----+-----+
45    | . . | . . | . . | . . | . . |
46    | . . | . . | . . | . . | . . |
47    +-----+-----+-----+-----+-----+
48    | 0 0 | . . | . . | . . | . . |
49    | 4 0 | . . | . . | . . | . . |
50    +-----+-----+-----+-----+-----+
51
52    Output appears as dimension coordinates but lexicographically
53    sorted by level coordinates. For BSR, the blocks are filled.
54    """
55    if id_map is 0:
56        return f"""# extended FROSTT format
572 5
5810 10
591 1 1
601 10 3
612 2 2
625 5 5
6310 1 4
64"""
65    if id_map is 1:
66        return f"""# extended FROSTT format
672 5
6810 10
691 1 1
7010 1 4
712 2 2
725 5 5
731 10 3
74"""
75    if id_map is 2:
76        return f"""# extended FROSTT format
772 16
7810 10
791 1 1
801 2 0
812 1 0
822 2 2
831 9 0
841 10 3
852 9 0
862 10 0
875 5 5
885 6 0
896 5 0
906 6 0
919 1 0
929 2 0
9310 1 4
9410 2 0
95"""
96    raise AssertionError("unexpected id_map")
97
98
99def build_compile_and_run_output(attr: st.EncodingAttr, compiler, expected):
100    # Build and Compile.
101    module = ir.Module.parse(boilerplate(attr))
102    engine = compiler.compile_and_jit(module)
103    # Invoke the kernel and compare output.
104    with tempfile.TemporaryDirectory() as test_dir:
105        out = os.path.join(test_dir, "out.tns")
106        buf = out.encode("utf-8")
107        mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
108        engine.invoke("main", mem_a)
109        actual = open(out).read()
110        if actual != expected:
111            quit("FAILURE")
112
113
114def main():
115    support_lib = os.getenv("SUPPORT_LIB")
116    assert support_lib is not None, "SUPPORT_LIB is undefined"
117    if not os.path.exists(support_lib):
118        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
119
120    # CHECK-LABEL: TEST: test_output
121    print("\nTEST: test_output")
122    count = 0
123    with ir.Context() as ctx, ir.Location.unknown():
124        # Loop over various sparse types (COO, CSR, DCSR, CSC, DCSC) with
125        # regular and loose compression and various metadata bitwidths.
126        # For these simple orderings, dim2lvl and lvl2dim are the same.
127        builder = st.EncodingAttr.build_level_type
128        fmt = st.LevelFormat
129        prop = st.LevelProperty
130        levels = [
131            [builder(fmt.compressed, [prop.non_unique]), builder(fmt.singleton)],
132            [
133                builder(fmt.compressed, [prop.non_unique]),
134                builder(fmt.singleton, [prop.soa]),
135            ],
136            [builder(fmt.dense), builder(fmt.compressed)],
137            [builder(fmt.dense), builder(fmt.loose_compressed)],
138            [builder(fmt.compressed), builder(fmt.compressed)],
139        ]
140        orderings = [
141            (ir.AffineMap.get_permutation([0, 1]), 0),
142            (ir.AffineMap.get_permutation([1, 0]), 1),
143        ]
144        bitwidths = [8, 64]
145        compiler = sparsifier.Sparsifier(
146            extras="", options="", opt_level=2, shared_libs=[support_lib]
147        )
148        for level in levels:
149            for ordering, id_map in orderings:
150                for bwidth in bitwidths:
151                    attr = st.EncodingAttr.get(
152                        level, ordering, ordering, bwidth, bwidth
153                    )
154                    build_compile_and_run_output(attr, compiler, expected(id_map))
155                    count = count + 1
156
157        # Now do the same for BSR.
158        level = [
159            builder(fmt.dense),
160            builder(fmt.compressed),
161            builder(fmt.dense),
162            builder(fmt.dense),
163        ]
164        d0 = ir.AffineDimExpr.get(0)
165        d1 = ir.AffineDimExpr.get(1)
166        c2 = ir.AffineConstantExpr.get(2)
167        dim2lvl = ir.AffineMap.get(
168            2,
169            0,
170            [
171                ir.AffineExpr.get_floor_div(d0, c2),
172                ir.AffineExpr.get_floor_div(d1, c2),
173                ir.AffineExpr.get_mod(d0, c2),
174                ir.AffineExpr.get_mod(d1, c2),
175            ],
176        )
177        l0 = ir.AffineDimExpr.get(0)
178        l1 = ir.AffineDimExpr.get(1)
179        l2 = ir.AffineDimExpr.get(2)
180        l3 = ir.AffineDimExpr.get(3)
181        lvl2dim = ir.AffineMap.get(4, 0, [2 * l0 + l2, 2 * l1 + l3])
182        attr = st.EncodingAttr.get(level, dim2lvl, lvl2dim, 0, 0)
183        build_compile_and_run_output(attr, compiler, expected(2))
184        count = count + 1
185
186    # CHECK: Passed 21 tests
187    print("Passed", count, "tests")
188
189
190if __name__ == "__main__":
191    main()
192