xref: /llvm-project/mlir/test/python/dialects/sparse_tensor/dialect.py (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import sparse_tensor as st, tensor
5import textwrap
6
7
def run(f):
    """Announce, execute, and return a test function.

    Used as a decorator so every test body runs as soon as the module is
    executed. The banner printed here ("TEST:" followed by the function name,
    preceded by a blank line) is the anchor each test's FileCheck label
    matches against.

    Args:
        f: A zero-argument callable.

    Returns:
        ``f`` itself, so the decorated name stays bound to the function.
    """
    name = f.__name__
    print("\nTEST:", name)
    f()
    return f
12
13
# CHECK-LABEL: TEST: testEncodingAttr1D
@run
def testEncodingAttr1D():
    """Round-trip a 1-D compressed encoding that carries an explicit value.

    Parses the encoding from its textual form, reads every property back
    through the ``EncodingAttr`` accessors, then rebuilds an encoding with
    ``EncodingAttr.get`` (``None`` for both maps, 0 for both bit widths, same
    explicit value) and reports that it does not equal the parsed one.
    """
    with Context():
        asm = textwrap.dedent(
            """\
                #sparse_tensor.encoding<{
                    map = (d0) -> (d0 : compressed),
                    posWidth = 16,
                    crdWidth = 32,
                    explicitVal = 1.0 : f64
                }>\
            """
        )
        attr = Attribute.parse(asm)
        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32, explicitVal = 1.000000e+00 : f64 }>
        print(attr)

        enc = st.EncodingAttr(attr)
        # CHECK: equal: True
        print(f"equal: {enc == attr}")

        # CHECK: lvl_types: [262144]
        print(f"lvl_types: {enc.lvl_types}")
        # CHECK: dim_to_lvl: (d0) -> (d0)
        print(f"dim_to_lvl: {enc.dim_to_lvl}")
        # CHECK: lvl_to_dim: (d0) -> (d0)
        print(f"lvl_to_dim: {enc.lvl_to_dim}")
        # CHECK: pos_width: 16
        print(f"pos_width: {enc.pos_width}")
        # CHECK: crd_width: 32
        print(f"crd_width: {enc.crd_width}")
        # CHECK: explicit_val: 1.000000e+00
        print(f"explicit_val: {enc.explicit_val}")
        # CHECK: implicit_val: None
        print(f"implicit_val: {enc.implicit_val}")

        # Keep only the level types and the explicit value; both maps are
        # None and both bit widths are 0, so the widths no longer match.
        rebuilt = st.EncodingAttr.get(
            enc.lvl_types,
            None,
            None,
            0,
            0,
            FloatAttr.get_f64(1.0),
        )
        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), explicitVal = 1.000000e+00 : f64 }>
        print(rebuilt)
        # CHECK: created_equal: False
        print(f"created_equal: {rebuilt == enc}")

        # Verify that the factory creates an instance of the proper type.
        # CHECK: is_proper_instance: True
        print(f"is_proper_instance: {isinstance(rebuilt, st.EncodingAttr)}")
        # CHECK: created_pos_width: 0
        print(f"created_pos_width: {rebuilt.pos_width}")
66
67
# CHECK-LABEL: TEST: testEncodingAttrStructure
@run
def testEncodingAttrStructure():
    """Exercise a 2:4 structured-sparsity encoding.

    The parsed encoding splits ``d1`` into ``d1 floordiv 4`` and ``d1 mod 4``
    and marks the innermost level ``structured[2, 4]``. After reading its
    properties back, the encoding is rebuilt twice with 0 bit widths:
      * from the parsed attribute's own level types and maps, and
      * from level types and affine maps assembled by hand.
    The two rebuilt attributes are expected to compare equal.
    """
    with Context():
        asm = textwrap.dedent(
            """\
                #sparse_tensor.encoding<{
                    map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense,
                    d1 mod 4 : structured[2, 4]),
                    posWidth = 16,
                    crdWidth = 32,
                }>\
            """
        )
        attr = Attribute.parse(asm)
        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]), posWidth = 16, crdWidth = 32 }>
        print(attr)

        enc = st.EncodingAttr(attr)
        # CHECK: equal: True
        print(f"equal: {enc == attr}")

        # CHECK: lvl_types: [65536, 65536, 4406638542848]
        print(f"lvl_types: {enc.lvl_types}")
        # CHECK: lvl_formats_enum: [{{65536|LevelFormat.dense}}, {{65536|LevelFormat.dense}}, {{2097152|LevelFormat.n_out_of_m}}]
        print(f"lvl_formats_enum: {enc.lvl_formats_enum}")
        # CHECK: structured_n: 2
        print(f"structured_n: {enc.structured_n}")
        # CHECK: structured_m: 4
        print(f"structured_m: {enc.structured_m}")
        # CHECK: dim_to_lvl: (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
        print(f"dim_to_lvl: {enc.dim_to_lvl}")
        # CHECK: lvl_to_dim: (d0, d1, d2) -> (d0, d1 * 4 + d2)
        print(f"lvl_to_dim: {enc.lvl_to_dim}")
        # CHECK: pos_width: 16
        print(f"pos_width: {enc.pos_width}")
        # CHECK: crd_width: 32
        print(f"crd_width: {enc.crd_width}")

        # Rebuild from the parsed properties, but with 0 for both bit widths.
        rebuilt = st.EncodingAttr.get(
            enc.lvl_types, enc.dim_to_lvl, enc.lvl_to_dim, 0, 0
        )
        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]) }>
        print(rebuilt)
        # CHECK: created_equal: False
        print(f"created_equal: {rebuilt == enc}")

        # Assemble the same encoding from scratch: two dense levels followed
        # by one n_out_of_m level with n=2, m=4, plus explicit maps.
        two_of_four_lvl = st.EncodingAttr.build_level_type(
            st.LevelFormat.n_out_of_m, [], 2, 4
        )
        dense_lvl = st.EncodingAttr.build_level_type(st.LevelFormat.dense)
        d0, d1, d2 = [AffineExpr.get_dim(pos) for pos in range(3)]
        # (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
        dim2lvl = AffineMap.get(
            2,
            0,
            [d0, AffineExpr.get_floor_div(d1, 4), AffineExpr.get_mod(d1, 4)],
        )
        # (d0, d1, d2) -> (d0, d1 * 4 + d2)
        lvl2dim = AffineMap.get(
            3,
            0,
            [d0, AffineExpr.get_add(AffineExpr.get_mul(d1, 4), d2)],
        )
        assembled = st.EncodingAttr.get(
            [dense_lvl, dense_lvl, two_of_four_lvl], dim2lvl, lvl2dim, 0, 0
        )
        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]) }>
        print(assembled)
        # CHECK: built_equal: True
        print(f"built_equal: {assembled == rebuilt}")

        # Verify that the factory creates an instance of the proper type.
        # CHECK: is_proper_instance: True
        print(f"is_proper_instance: {isinstance(rebuilt, st.EncodingAttr)}")
        # CHECK: created_pos_width: 0
        print(f"created_pos_width: {rebuilt.pos_width}")
157
158
# CHECK-LABEL: TEST: testEncodingAttr2D
@run
def testEncodingAttr2D():
    """Round-trip a 2-D encoding whose levels permute the dimensions.

    The parsed encoding maps ``d1`` to a dense first level and ``d0`` to a
    compressed second level. Rebuilding it with ``EncodingAttr.get`` from the
    same level types and maps, and the same bit widths (8 and 32), is
    expected to produce an equal attribute.
    """
    with Context():
        asm = textwrap.dedent(
            """\
                #sparse_tensor.encoding<{
                    map = (d0, d1) -> (d1 : dense, d0 : compressed),
                    posWidth = 8,
                    crdWidth = 32,
                }>\
            """
        )
        attr = Attribute.parse(asm)
        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
        print(attr)

        enc = st.EncodingAttr(attr)
        # CHECK: equal: True
        print(f"equal: {enc == attr}")

        # CHECK: lvl_types: [65536, 262144]
        print(f"lvl_types: {enc.lvl_types}")
        # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
        print(f"dim_to_lvl: {enc.dim_to_lvl}")
        # CHECK: lvl_to_dim: (d0, d1) -> (d1, d0)
        print(f"lvl_to_dim: {enc.lvl_to_dim}")
        # CHECK: pos_width: 8
        print(f"pos_width: {enc.pos_width}")
        # CHECK: crd_width: 32
        print(f"crd_width: {enc.crd_width}")

        # Same level types, maps, and widths as parsed, so equality holds.
        rebuilt = st.EncodingAttr.get(
            enc.lvl_types, enc.dim_to_lvl, enc.lvl_to_dim, 8, 32
        )
        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
        print(rebuilt)
        # CHECK: created_equal: True
        print(f"created_equal: {rebuilt == enc}")
203
204
# CHECK-LABEL: TEST: testEncodingAttrOnTensorType
@run
def testEncodingAttrOnTensorType():
    """Attach a sparse encoding to a ranked tensor type and read it back.

    Builds a 1024-element f32 ranked tensor type carrying a 1-D compressed
    encoding, prints the type and its ``encoding``, and asserts the encoding
    read back equals the one supplied.
    """
    with Context(), Location.unknown():
        asm = textwrap.dedent(
            """\
                    #sparse_tensor.encoding<{
                        map = (d0) -> (d0 : compressed),
                        posWidth = 64,
                        crdWidth = 32,
                    }>\
                """
        )
        encoding = st.EncodingAttr(Attribute.parse(asm))
        element_type = F32Type.get()
        tensor_type = RankedTensorType.get(
            (1024,), element_type, encoding=encoding
        )
        # CHECK: tensor<1024xf32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>>
        print(tensor_type)
        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>
        print(tensor_type.encoding)
        assert tensor_type.encoding == encoding
228
229
# CHECK-LABEL: TEST: testEncodingEmptyTensor
@run
def testEncodingEmptyTensor():
    """Attach a programmatically built encoding to a ``tensor.empty`` op.

    The encoding has one compressed level, uses the identity permutation for
    both the dim-to-lvl and lvl-to-dim maps, and sets both bit widths to 32.
    The module is printed so the aliased encoding and the op can be matched.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            # EncodingAttr.get takes level *types*. Build them from the level
            # format explicitly, as the structured-encoding test does, rather
            # than relying on a bare LevelFormat value happening to convert to
            # the same number as a property-less level type.
            levels = [st.EncodingAttr.build_level_type(st.LevelFormat.compressed)]
            ordering = AffineMap.get_permutation([0])
            encoding = st.EncodingAttr.get(levels, ordering, ordering, 32, 32)
            tensor.empty((1024,), F32Type.get(), encoding=encoding)

        # CHECK: #sparse = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 32, crdWidth = 32 }>
        # CHECK: module {
        # CHECK:   %[[VAL_0:.*]] = tensor.empty() : tensor<1024xf32, #sparse>
        # CHECK: }
        print(module)
246