xref: /llvm-project/mlir/lib/Bindings/Python/DialectSparseTensor.cpp (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
1bc1df1faSAlex Zinenko //===- DialectSparseTensor.cpp - 'sparse_tensor' dialect submodule --------===//
2f13893f6SStella Laurenzo //
3f13893f6SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f13893f6SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5f13893f6SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f13893f6SStella Laurenzo //
7f13893f6SStella Laurenzo //===----------------------------------------------------------------------===//
8f13893f6SStella Laurenzo 
9*5cd42747SPeter Hawkins #include <optional>
10*5cd42747SPeter Hawkins #include <vector>
11*5cd42747SPeter Hawkins 
12285a229fSMehdi Amini #include "mlir-c/AffineMap.h"
13f13893f6SStella Laurenzo #include "mlir-c/Dialect/SparseTensor.h"
14f13893f6SStella Laurenzo #include "mlir-c/IR.h"
15*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/NanobindAdaptors.h"
16*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h"
17f13893f6SStella Laurenzo 
18*5cd42747SPeter Hawkins namespace nb = nanobind;
19f13893f6SStella Laurenzo using namespace llvm;
20f13893f6SStella Laurenzo using namespace mlir;
21*5cd42747SPeter Hawkins using namespace mlir::python::nanobind_adaptors;
22f13893f6SStella Laurenzo 
23*5cd42747SPeter Hawkins static void populateDialectSparseTensorSubmodule(const nb::module_ &m) {
24*5cd42747SPeter Hawkins   nb::enum_<MlirSparseTensorLevelFormat>(m, "LevelFormat", nb::is_arithmetic(),
25*5cd42747SPeter Hawkins                                          nb::is_flag())
261944c4f7SAart Bik       .value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
27e5924d64SYinying Li       .value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M)
281944c4f7SAart Bik       .value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED)
291944c4f7SAart Bik       .value("singleton", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON)
30429919e3SPeiming Liu       .value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED);
31429919e3SPeiming Liu 
32*5cd42747SPeter Hawkins   nb::enum_<MlirSparseTensorLevelPropertyNondefault>(m, "LevelProperty")
33429919e3SPeiming Liu       .value("non_ordered", MLIR_SPARSE_PROPERTY_NON_ORDERED)
34b50ce4c8SMateusz Sokół       .value("non_unique", MLIR_SPARSE_PROPERTY_NON_UNIQUE)
35b50ce4c8SMateusz Sokół       .value("soa", MLIR_SPARSE_PROPERTY_SOA);
36f13893f6SStella Laurenzo 
37f13893f6SStella Laurenzo   mlir_attribute_subclass(m, "EncodingAttr",
3895ddbed9SAlex Zinenko                           mlirAttributeIsASparseTensorEncodingAttr)
39f13893f6SStella Laurenzo       .def_classmethod(
40f13893f6SStella Laurenzo           "get",
41*5cd42747SPeter Hawkins           [](nb::object cls, std::vector<MlirSparseTensorLevelType> lvlTypes,
42d4088e7dSYinying Li              std::optional<MlirAffineMap> dimToLvl,
43d4088e7dSYinying Li              std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
44a10d67f9SYinying Li              std::optional<MlirAttribute> explicitVal,
45a10d67f9SYinying Li              std::optional<MlirAttribute> implicitVal, MlirContext context) {
46f13893f6SStella Laurenzo             return cls(mlirSparseTensorEncodingAttrGet(
47a0615d02Swren romano                 context, lvlTypes.size(), lvlTypes.data(),
48d4088e7dSYinying Li                 dimToLvl ? *dimToLvl : MlirAffineMap{nullptr},
49d4088e7dSYinying Li                 lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth,
50a10d67f9SYinying Li                 crdWidth, explicitVal ? *explicitVal : MlirAttribute{nullptr},
51a10d67f9SYinying Li                 implicitVal ? *implicitVal : MlirAttribute{nullptr}));
52f13893f6SStella Laurenzo           },
53*5cd42747SPeter Hawkins           nb::arg("cls"), nb::arg("lvl_types"), nb::arg("dim_to_lvl").none(),
54*5cd42747SPeter Hawkins           nb::arg("lvl_to_dim").none(), nb::arg("pos_width"),
55*5cd42747SPeter Hawkins           nb::arg("crd_width"), nb::arg("explicit_val").none() = nb::none(),
56*5cd42747SPeter Hawkins           nb::arg("implicit_val").none() = nb::none(),
57*5cd42747SPeter Hawkins           nb::arg("context").none() = nb::none(),
58f13893f6SStella Laurenzo           "Gets a sparse_tensor.encoding from parameters.")
592a6b521bSYinying Li       .def_classmethod(
602a6b521bSYinying Li           "build_level_type",
61*5cd42747SPeter Hawkins           [](nb::object cls, MlirSparseTensorLevelFormat lvlFmt,
62429919e3SPeiming Liu              const std::vector<MlirSparseTensorLevelPropertyNondefault>
63429919e3SPeiming Liu                  &properties,
64429919e3SPeiming Liu              unsigned n, unsigned m) {
65429919e3SPeiming Liu             return mlirSparseTensorEncodingAttrBuildLvlType(
66429919e3SPeiming Liu                 lvlFmt, properties.data(), properties.size(), n, m);
672a6b521bSYinying Li           },
68*5cd42747SPeter Hawkins           nb::arg("cls"), nb::arg("lvl_fmt"),
69*5cd42747SPeter Hawkins           nb::arg("properties") =
70429919e3SPeiming Liu               std::vector<MlirSparseTensorLevelPropertyNondefault>(),
71*5cd42747SPeter Hawkins           nb::arg("n") = 0, nb::arg("m") = 0,
722a6b521bSYinying Li           "Builds a sparse_tensor.encoding.level_type from parameters.")
73f13893f6SStella Laurenzo       .def_property_readonly(
74a0615d02Swren romano           "lvl_types",
75f13893f6SStella Laurenzo           [](MlirAttribute self) {
7684cd51bbSwren romano             const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
771944c4f7SAart Bik             std::vector<MlirSparseTensorLevelType> ret;
7884cd51bbSwren romano             ret.reserve(lvlRank);
7984cd51bbSwren romano             for (int l = 0; l < lvlRank; ++l)
80a0615d02Swren romano               ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l));
81f13893f6SStella Laurenzo             return ret;
82f13893f6SStella Laurenzo           })
83f13893f6SStella Laurenzo       .def_property_readonly(
8476647fceSwren romano           "dim_to_lvl",
850a81ace0SKazu Hirata           [](MlirAttribute self) -> std::optional<MlirAffineMap> {
8676647fceSwren romano             MlirAffineMap ret = mlirSparseTensorEncodingAttrGetDimToLvl(self);
87c48e9087SAart Bik             if (mlirAffineMapIsNull(ret))
88c48e9087SAart Bik               return {};
89c48e9087SAart Bik             return ret;
90c48e9087SAart Bik           })
91d4088e7dSYinying Li       .def_property_readonly(
92d4088e7dSYinying Li           "lvl_to_dim",
93d4088e7dSYinying Li           [](MlirAttribute self) -> std::optional<MlirAffineMap> {
94d4088e7dSYinying Li             MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self);
95d4088e7dSYinying Li             if (mlirAffineMapIsNull(ret))
96d4088e7dSYinying Li               return {};
97d4088e7dSYinying Li             return ret;
98d4088e7dSYinying Li           })
9984cd51bbSwren romano       .def_property_readonly("pos_width",
10084cd51bbSwren romano                              mlirSparseTensorEncodingAttrGetPosWidth)
10184cd51bbSwren romano       .def_property_readonly("crd_width",
1022a6b521bSYinying Li                              mlirSparseTensorEncodingAttrGetCrdWidth)
1032a6b521bSYinying Li       .def_property_readonly(
104a10d67f9SYinying Li           "explicit_val",
105a10d67f9SYinying Li           [](MlirAttribute self) -> std::optional<MlirAttribute> {
106a10d67f9SYinying Li             MlirAttribute ret =
107a10d67f9SYinying Li                 mlirSparseTensorEncodingAttrGetExplicitVal(self);
108a10d67f9SYinying Li             if (mlirAttributeIsNull(ret))
109a10d67f9SYinying Li               return {};
110a10d67f9SYinying Li             return ret;
111a10d67f9SYinying Li           })
112a10d67f9SYinying Li       .def_property_readonly(
113a10d67f9SYinying Li           "implicit_val",
114a10d67f9SYinying Li           [](MlirAttribute self) -> std::optional<MlirAttribute> {
115a10d67f9SYinying Li             MlirAttribute ret =
116a10d67f9SYinying Li                 mlirSparseTensorEncodingAttrGetImplicitVal(self);
117a10d67f9SYinying Li             if (mlirAttributeIsNull(ret))
118a10d67f9SYinying Li               return {};
119a10d67f9SYinying Li             return ret;
120a10d67f9SYinying Li           })
121a10d67f9SYinying Li       .def_property_readonly(
1222a6b521bSYinying Li           "structured_n",
1232a6b521bSYinying Li           [](MlirAttribute self) -> unsigned {
1242a6b521bSYinying Li             const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
1252a6b521bSYinying Li             return mlirSparseTensorEncodingAttrGetStructuredN(
1262a6b521bSYinying Li                 mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
1272a6b521bSYinying Li           })
1282a6b521bSYinying Li       .def_property_readonly(
1292a6b521bSYinying Li           "structured_m",
1302a6b521bSYinying Li           [](MlirAttribute self) -> unsigned {
1312a6b521bSYinying Li             const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
1322a6b521bSYinying Li             return mlirSparseTensorEncodingAttrGetStructuredM(
1332a6b521bSYinying Li                 mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
1342a6b521bSYinying Li           })
135429919e3SPeiming Liu       .def_property_readonly("lvl_formats_enum", [](MlirAttribute self) {
1362a6b521bSYinying Li         const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
137429919e3SPeiming Liu         std::vector<MlirSparseTensorLevelFormat> ret;
1382a6b521bSYinying Li         ret.reserve(lvlRank);
139429919e3SPeiming Liu         for (int l = 0; l < lvlRank; l++)
140429919e3SPeiming Liu           ret.push_back(mlirSparseTensorEncodingAttrGetLvlFmt(self, l));
1412a6b521bSYinying Li         return ret;
1422a6b521bSYinying Li       });
143f13893f6SStella Laurenzo }
14495ddbed9SAlex Zinenko 
145*5cd42747SPeter Hawkins NB_MODULE(_mlirDialectsSparseTensor, m) {
14695ddbed9SAlex Zinenko   m.doc() = "MLIR SparseTensor dialect.";
14795ddbed9SAlex Zinenko   populateDialectSparseTensorSubmodule(m);
14895ddbed9SAlex Zinenko }
149