xref: /llvm-project/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py (revision e3686f1e44676fa28789c6732076b8998be23527)
19048ea28SMarkus Böck# RUN: env SUPPORT_LIB=%mlir_c_runner_utils \
292c1c63dSAart Bik# RUN:   %PYTHON %s | FileCheck %s
392c1c63dSAart Bik
492c1c63dSAart Bikimport ctypes
592c1c63dSAart Bikimport os
68b83b8f1SAart Bikimport sys
792c1c63dSAart Bikimport tempfile
892c1c63dSAart Bik
992c1c63dSAart Bikfrom mlir import ir
1092c1c63dSAart Bikfrom mlir import runtime as rt
1192c1c63dSAart Bikfrom mlir.dialects import builtin
1292c1c63dSAart Bikfrom mlir.dialects import sparse_tensor as st
1392c1c63dSAart Bik
148b83b8f1SAart Bik_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
158b83b8f1SAart Biksys.path.append(_SCRIPT_PATH)
16dce7a7cfSTim Harveyfrom tools import sparsifier
1792c1c63dSAart Bik
18ed2d0b0eSAart Bik
def boilerplate(attr: st.EncodingAttr):
    """Builds the MLIR source of the `main` kernel exercised by this test.

    The kernel materializes a constant 10x10 f64 matrix with five nonzeros,
    converts it to a sparse tensor carrying the encoding `attr`, and writes
    it with `sparse_tensor.out` to the location given by its `!llvm.ptr`
    argument.

    Args:
      attr: sparse tensor encoding, interpolated textually into the IR.

    Returns:
      The MLIR module text as a string.
    """
    sparse_type = f"tensor<10x10xf64, {attr}>"
    return (
        "\n"
        "func.func @main(%p : !llvm.ptr) -> () attributes { llvm.emit_c_interface } {\n"
        "  %d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],\n"
        "                             [1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>\n"
        f"  %a = sparse_tensor.convert %d : tensor<10x10xf64> to {sparse_type}\n"
        f"  sparse_tensor.out %a, %p : {sparse_type}, !llvm.ptr\n"
        "  return\n"
        "}\n"
    )
3092c1c63dSAart Bik
3192c1c63dSAart Bik
def expected(id_map):
    """Returns the expected contents of the output file in extended FROSTT format.

    The 10x10 input matrix, drawn as a grid of 2x2 blocks ('.' marks an
    all-zero block, which the BSR encoding does not store):

    +-----+-----+-----+-----+-----+
    | 1 0 | . . | . . | . . | 0 3 |
    | 0 2 | . . | . . | . . | 0 0 |
    +-----+-----+-----+-----+-----+
    | . . | . . | . . | . . | . . |
    | . . | . . | . . | . . | . . |
    +-----+-----+-----+-----+-----+
    | . . | . . | 5 0 | . . | . . |
    | . . | . . | 0 0 | . . | . . |
    +-----+-----+-----+-----+-----+
    | . . | . . | . . | . . | . . |
    | . . | . . | . . | . . | . . |
    +-----+-----+-----+-----+-----+
    | 0 0 | . . | . . | . . | . . |
    | 4 0 | . . | . . | . . | . . |
    +-----+-----+-----+-----+-----+

    After the comment line, the file holds the rank and the number of stored
    entries, then the dimension sizes, then one line per stored entry. Each
    entry line gives its 1-based dimension coordinates followed by its value.
    Entries appear in lexicographic order of their level coordinates. For
    BSR, every stored block is written out in full, including its explicit
    zeros.

    Args:
      id_map: selects the storage order:
        0 is the identity (row-major) ordering,
        1 is the transposed (column-major) ordering,
        2 is the 2x2 BSR encoding.

    Returns:
      The exact text the kernel is expected to write.

    Raises:
      AssertionError: if `id_map` is not one of the values above.
    """
    # Compare by value, not identity: `is` against an int literal only works
    # by accident of CPython's small-int cache (and warns since Python 3.8).
    if id_map == 0:
        return """# extended FROSTT format
2 5
10 10
1 1 1
1 10 3
2 2 2
5 5 5
10 1 4
"""
    if id_map == 1:
        return """# extended FROSTT format
2 5
10 10
1 1 1
10 1 4
2 2 2
5 5 5
1 10 3
"""
    if id_map == 2:
        return """# extended FROSTT format
2 16
10 10
1 1 1
1 2 0
2 1 0
2 2 2
1 9 0
1 10 3
2 9 0
2 10 0
5 5 5
5 6 0
6 5 0
6 6 0
9 1 0
9 2 0
10 1 4
10 2 0
"""
    raise AssertionError("unexpected id_map")
9792c1c63dSAart Bik
9892c1c63dSAart Bik
def build_compile_and_run_output(attr: st.EncodingAttr, compiler, expected):
    """Builds, JIT-compiles and runs the output kernel for one encoding.

    Args:
      attr: sparse tensor encoding interpolated into the `boilerplate` kernel.
      compiler: object whose `compile_and_jit(module)` returns an engine with
        an `invoke` method (in this file, a `sparsifier.Sparsifier`).
      expected: exact text the kernel must write to its output file.

    Exits the process with the message "FAILURE" if the file written by the
    kernel differs from `expected`.
    """
    # Build and compile.
    module = ir.Module.parse(boilerplate(attr))
    engine = compiler.compile_and_jit(module)
    # Invoke the kernel, writing into a scratch directory, and read back what
    # it produced.
    with tempfile.TemporaryDirectory() as test_dir:
        out = os.path.join(test_dir, "out.tns")
        buf = out.encode("utf-8")
        # `invoke` takes a pointer to each argument; the kernel's `!llvm.ptr`
        # argument is itself a pointer to the NUL-terminated path string.
        mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
        engine.invoke("main", mem_a)
        # Close the file before the temporary directory is removed.
        with open(out, encoding="utf-8") as f:
            actual = f.read()
    if actual != expected:
        # Diagnostics go to stderr so the stdout checked by FileCheck is unchanged.
        print(
            f"Output mismatch for encoding {attr}\n"
            f"--- actual ---\n{actual}--- expected ---\n{expected}",
            file=sys.stderr,
        )
        # `quit` is an interactive-shell helper injected by `site`; programs
        # should use `sys.exit`, which behaves the same (message to stderr,
        # exit status 1).
        sys.exit("FAILURE")
11292c1c63dSAart Bik
11392c1c63dSAart Bik
def main():
    """Runs the `sparse_tensor.out` round-trip test over many encodings.

    Reads the support library path from the SUPPORT_LIB environment variable.
    Then, for every combination of level types, dimension ordering and
    metadata bitwidth, plus one 2x2 BSR encoding, it builds, JIT-compiles and
    runs the kernel and compares the written file against `expected`. The
    printed lines are matched by the FileCheck directives in this function.

    Raises:
      AssertionError: if SUPPORT_LIB is not set.
      FileNotFoundError: if SUPPORT_LIB does not name an existing path.
    """
    # `errno` is not among this file's module-level imports; import it here so
    # the missing-library error path does not itself fail with a NameError.
    import errno

    support_lib = os.getenv("SUPPORT_LIB")
    assert support_lib is not None, "SUPPORT_LIB is undefined"
    if not os.path.exists(support_lib):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

    # CHECK-LABEL: TEST: test_output
    print("\nTEST: test_output")
    count = 0
    with ir.Context(), ir.Location.unknown():
        # Loop over various sparse types (COO, CSR, DCSR, CSC, DCSC) with
        # regular and loose compression and various metadata bitwidths.
        # For these simple orderings, dim2lvl and lvl2dim are the same.
        builder = st.EncodingAttr.build_level_type
        fmt = st.LevelFormat
        prop = st.LevelProperty
        levels = [
            # COO.
            [builder(fmt.compressed, [prop.non_unique]), builder(fmt.singleton)],
            # COO with the `soa` property on the singleton level.
            [
                builder(fmt.compressed, [prop.non_unique]),
                builder(fmt.singleton, [prop.soa]),
            ],
            # CSR.
            [builder(fmt.dense), builder(fmt.compressed)],
            # CSR with loose compression.
            [builder(fmt.dense), builder(fmt.loose_compressed)],
            # DCSR.
            [builder(fmt.compressed), builder(fmt.compressed)],
        ]
        # Each permutation is its own inverse, so it is passed as both dim2lvl
        # and lvl2dim; the second element selects the expected output.
        orderings = [
            (ir.AffineMap.get_permutation([0, 1]), 0),
            (ir.AffineMap.get_permutation([1, 0]), 1),
        ]
        bitwidths = [8, 64]
        compiler = sparsifier.Sparsifier(
            extras="", options="", opt_level=2, shared_libs=[support_lib]
        )
        for level in levels:
            for ordering, id_map in orderings:
                for bwidth in bitwidths:
                    attr = st.EncodingAttr.get(
                        level, ordering, ordering, bwidth, bwidth
                    )
                    build_compile_and_run_output(attr, compiler, expected(id_map))
                    count += 1

        # Now do the same for BSR with 2x2 blocks:
        #   dim2lvl: (i, j) -> (i floordiv 2, j floordiv 2, i mod 2, j mod 2)
        #   lvl2dim: (l0, l1, l2, l3) -> (2 * l0 + l2, 2 * l1 + l3)
        level = [
            builder(fmt.dense),
            builder(fmt.compressed),
            builder(fmt.dense),
            builder(fmt.dense),
        ]
        d0 = ir.AffineDimExpr.get(0)
        d1 = ir.AffineDimExpr.get(1)
        c2 = ir.AffineConstantExpr.get(2)
        dim2lvl = ir.AffineMap.get(
            2,
            0,
            [
                ir.AffineExpr.get_floor_div(d0, c2),
                ir.AffineExpr.get_floor_div(d1, c2),
                ir.AffineExpr.get_mod(d0, c2),
                ir.AffineExpr.get_mod(d1, c2),
            ],
        )
        l0 = ir.AffineDimExpr.get(0)
        l1 = ir.AffineDimExpr.get(1)
        l2 = ir.AffineDimExpr.get(2)
        l3 = ir.AffineDimExpr.get(3)
        lvl2dim = ir.AffineMap.get(4, 0, [2 * l0 + l2, 2 * l1 + l3])
        # NOTE(review): bitwidths of 0 presumably select the default widths — confirm.
        attr = st.EncodingAttr.get(level, dim2lvl, lvl2dim, 0, 0)
        build_compile_and_run_output(attr, compiler, expected(2))
        count += 1

    # 5 level formats x 2 orderings x 2 bitwidths, plus 1 BSR case = 21.
    # CHECK: Passed 21 tests
    print("Passed", count, "tests")
187f9008e63STobias Hieta    print("Passed", count, "tests")
18892c1c63dSAart Bik
18992c1c63dSAart Bik
if __name__ == "__main__":
    # Script entry point (lit runs this file directly with %PYTHON).
    main()
192