xref: /llvm-project/mlir/test/python/dialects/sparse_tensor/dialect.py (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
1f13893f6SStella Laurenzo# RUN: %PYTHON %s | FileCheck %s
2f13893f6SStella Laurenzo
3f13893f6SStella Laurenzofrom mlir.ir import *
4a9746675SMateusz Sokółfrom mlir.dialects import sparse_tensor as st, tensor
5a10d67f9SYinying Liimport textwrap
6f13893f6SStella Laurenzo
7f9008e63STobias Hieta
def run(f):
    """Print a ``TEST:`` banner for ``f``, execute ``f``, and return it unchanged.

    Used as a decorator, so each decorated test function runs as soon as the
    module defines it. The banner (preceded by a blank line) is the text that
    each test's FileCheck label directive anchors on; returning ``f`` keeps the
    module-level name bound to the original function.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    return f
12f13893f6SStella Laurenzo
13f13893f6SStella Laurenzo
# CHECK-LABEL: TEST: testEncodingAttr1D
@run
def testEncodingAttr1D():
    """Parse, downcast, inspect and rebuild a 1-D compressed sparse encoding.

    The rebuilt attribute is created from the parsed level types with both
    maps passed as None and both bit widths passed as 0, but it keeps the
    explicit value. It therefore prints without posWidth/crdWidth and compares
    unequal to the parsed attribute.
    """
    with Context():
        parsed_attr = Attribute.parse(
            textwrap.dedent(
                """\
                #sparse_tensor.encoding<{
                    map = (d0) -> (d0 : compressed),
                    posWidth = 16,
                    crdWidth = 32,
                    explicitVal = 1.0 : f64
                }>\
            """
            )
        )
        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32, explicitVal = 1.000000e+00 : f64 }>
        print(parsed_attr)

        enc = st.EncodingAttr(parsed_attr)
        # CHECK: equal: True
        print(f"equal: {enc == parsed_attr}")

        # Accessors exposed by the downcast attribute.
        # CHECK: lvl_types: [262144]
        print(f"lvl_types: {enc.lvl_types}")
        # CHECK: dim_to_lvl: (d0) -> (d0)
        print(f"dim_to_lvl: {enc.dim_to_lvl}")
        # CHECK: lvl_to_dim: (d0) -> (d0)
        print(f"lvl_to_dim: {enc.lvl_to_dim}")
        # CHECK: pos_width: 16
        print(f"pos_width: {enc.pos_width}")
        # CHECK: crd_width: 32
        print(f"crd_width: {enc.crd_width}")
        # CHECK: explicit_val: 1.000000e+00
        print(f"explicit_val: {enc.explicit_val}")
        # CHECK: implicit_val: None
        print(f"implicit_val: {enc.implicit_val}")

        # Rebuild from the level types alone, carrying over only the explicit value.
        one_f64 = FloatAttr.get_f64(1.0)
        rebuilt = st.EncodingAttr.get(enc.lvl_types, None, None, 0, 0, one_f64)
        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), explicitVal = 1.000000e+00 : f64 }>
        print(rebuilt)
        # CHECK: created_equal: False
        print(f"created_equal: {rebuilt == enc}")

        # The factory is expected to return an st.EncodingAttr instance.
        # CHECK: is_proper_instance: True
        print(f"is_proper_instance: {isinstance(rebuilt, st.EncodingAttr)}")
        # CHECK: created_pos_width: 0
        print(f"created_pos_width: {rebuilt.pos_width}")
66f13893f6SStella Laurenzo
67f13893f6SStella Laurenzo
# CHECK-LABEL: TEST: testEncodingAttrStructure
@run
def testEncodingAttrStructure():
    """Exercise a 2:4 structured sparse encoding and rebuild it from parts.

    The parsed encoding's innermost level is ``structured[2, 4]``. After its
    accessors are printed, it is reconstructed twice, both times with bit
    widths of 0:
    - once from the parsed level types and maps;
    - once from level types made by ``build_level_type`` plus hand-built
      affine maps.

    The two reconstructions must compare equal.
    """
    with Context():
        parsed_attr = Attribute.parse(
            textwrap.dedent(
                """\
                #sparse_tensor.encoding<{
                    map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense,
                    d1 mod 4 : structured[2, 4]),
                    posWidth = 16,
                    crdWidth = 32,
                }>\
            """
            )
        )
        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]), posWidth = 16, crdWidth = 32 }>
        print(parsed_attr)

        enc = st.EncodingAttr(parsed_attr)
        # CHECK: equal: True
        print(f"equal: {enc == parsed_attr}")

        # Accessors exposed by the downcast attribute.
        # CHECK: lvl_types: [65536, 65536, 4406638542848]
        print(f"lvl_types: {enc.lvl_types}")
        # CHECK: lvl_formats_enum: [{{65536|LevelFormat.dense}}, {{65536|LevelFormat.dense}}, {{2097152|LevelFormat.n_out_of_m}}]
        print(f"lvl_formats_enum: {enc.lvl_formats_enum}")
        # CHECK: structured_n: 2
        print(f"structured_n: {enc.structured_n}")
        # CHECK: structured_m: 4
        print(f"structured_m: {enc.structured_m}")
        # CHECK: dim_to_lvl: (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
        print(f"dim_to_lvl: {enc.dim_to_lvl}")
        # CHECK: lvl_to_dim: (d0, d1, d2) -> (d0, d1 * 4 + d2)
        print(f"lvl_to_dim: {enc.lvl_to_dim}")
        # CHECK: pos_width: 16
        print(f"pos_width: {enc.pos_width}")
        # CHECK: crd_width: 32
        print(f"crd_width: {enc.crd_width}")

        # First rebuild: reuse the parsed level types and maps, widths set to 0.
        from_parts = st.EncodingAttr.get(
            enc.lvl_types, enc.dim_to_lvl, enc.lvl_to_dim, 0, 0
        )
        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]) }>
        print(from_parts)
        # CHECK: created_equal: False
        print(f"created_equal: {from_parts == enc}")

        # Second rebuild: construct every ingredient by hand.
        dense_lvl = st.EncodingAttr.build_level_type(st.LevelFormat.dense)
        two_of_four_lvl = st.EncodingAttr.build_level_type(
            st.LevelFormat.n_out_of_m, [], 2, 4
        )
        d0, d1, d2 = (AffineExpr.get_dim(pos) for pos in range(3))
        # (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
        dim_to_lvl = AffineMap.get(
            2, 0, [d0, AffineExpr.get_floor_div(d1, 4), AffineExpr.get_mod(d1, 4)]
        )
        # (d0, d1, d2) -> (d0, d1 * 4 + d2)
        lvl_to_dim = AffineMap.get(
            3, 0, [d0, AffineExpr.get_add(AffineExpr.get_mul(d1, 4), d2)]
        )
        by_hand = st.EncodingAttr.get(
            [dense_lvl, dense_lvl, two_of_four_lvl], dim_to_lvl, lvl_to_dim, 0, 0
        )
        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]) }>
        print(by_hand)
        # CHECK: built_equal: True
        print(f"built_equal: {by_hand == from_parts}")

        # The factory is expected to return an st.EncodingAttr instance.
        # CHECK: is_proper_instance: True
        print(f"is_proper_instance: {isinstance(from_parts, st.EncodingAttr)}")
        # CHECK: created_pos_width: 0
        print(f"created_pos_width: {from_parts.pos_width}")
1572a6b521bSYinying Li
1582a6b521bSYinying Li
# CHECK-LABEL: TEST: testEncodingAttr2D
@run
def testEncodingAttr2D():
    """Exercise a 2-D sparse encoding whose level order swaps the dimensions.

    Rebuilding through ``EncodingAttr.get`` with the parsed level types, the
    parsed maps and the same bit widths (8 and 32) must yield an attribute
    equal to the parsed one.
    """
    with Context():
        parsed_attr = Attribute.parse(
            textwrap.dedent(
                """\
                #sparse_tensor.encoding<{
                    map = (d0, d1) -> (d1 : dense, d0 : compressed),
                    posWidth = 8,
                    crdWidth = 32,
                }>\
            """
            )
        )
        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
        print(parsed_attr)

        enc = st.EncodingAttr(parsed_attr)
        # CHECK: equal: True
        print(f"equal: {enc == parsed_attr}")

        # Accessors exposed by the downcast attribute.
        # CHECK: lvl_types: [65536, 262144]
        print(f"lvl_types: {enc.lvl_types}")
        # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
        print(f"dim_to_lvl: {enc.dim_to_lvl}")
        # CHECK: lvl_to_dim: (d0, d1) -> (d1, d0)
        print(f"lvl_to_dim: {enc.lvl_to_dim}")
        # CHECK: pos_width: 8
        print(f"pos_width: {enc.pos_width}")
        # CHECK: crd_width: 32
        print(f"crd_width: {enc.crd_width}")

        # Same widths as the parsed text, so the rebuild should be identical.
        pos_bits, crd_bits = 8, 32
        rebuilt = st.EncodingAttr.get(
            enc.lvl_types, enc.dim_to_lvl, enc.lvl_to_dim, pos_bits, crd_bits
        )
        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
        print(rebuilt)
        # CHECK: created_equal: True
        print(f"created_equal: {rebuilt == enc}")
203a2c8aebdSStella Laurenzo
204a2c8aebdSStella Laurenzo
# CHECK-LABEL: TEST: testEncodingAttrOnTensorType
@run
def testEncodingAttrOnTensorType():
    """Attach a sparse encoding to a ranked tensor type and read it back.

    The encoding reported by the tensor type must compare equal to the one
    that was passed in.
    """
    with Context(), Location.unknown():
        sparse_enc = st.EncodingAttr(
            Attribute.parse(
                textwrap.dedent(
                    """\
                    #sparse_tensor.encoding<{
                        map = (d0) -> (d0 : compressed),
                        posWidth = 64,
                        crdWidth = 32,
                    }>\
                """
                )
            )
        )
        ranked = RankedTensorType.get((1024,), F32Type.get(), encoding=sparse_enc)
        # CHECK: tensor<1024xf32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>>
        print(ranked)
        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>
        print(ranked.encoding)
        assert ranked.encoding == sparse_enc
228a9746675SMateusz Sokół
229a9746675SMateusz Sokół
# CHECK-LABEL: TEST: testEncodingEmptyTensor
@run
def testEncodingEmptyTensor():
    """Emit ``tensor.empty`` whose result type carries a sparse encoding.

    The level type is produced with ``EncodingAttr.build_level_type`` rather
    than by passing the ``LevelFormat.compressed`` enum member directly.
    ``EncodingAttr.get`` takes level types, and a level type is not in general
    the same value as its format enum: an n_out_of_m level type also encodes
    n and m. Building it explicitly avoids relying on an implicit enum-to-int
    coincidence.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            levels = [st.EncodingAttr.build_level_type(st.LevelFormat.compressed)]
            ordering = AffineMap.get_permutation([0])
            encoding = st.EncodingAttr.get(levels, ordering, ordering, 32, 32)
            tensor.empty((1024,), F32Type.get(), encoding=encoding)

        # CHECK: #sparse = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 32, crdWidth = 32 }>
        # CHECK: module {
        # CHECK:   %[[VAL_0:.*]] = tensor.empty() : tensor<1024xf32, #sparse>
        # CHECK: }
        print(module)
246