xref: /llvm-project/mlir/lib/Bindings/Python/DialectSparseTensor.cpp (revision 2a6b521b36fb538a49564323ecd457d7b08b1325)
1 //===- DialectSparseTensor.cpp - 'sparse_tensor' dialect submodule --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/AffineMap.h"
10 #include "mlir-c/Dialect/SparseTensor.h"
11 #include "mlir-c/IR.h"
12 #include "mlir/Bindings/Python/PybindAdaptors.h"
13 #include <optional>
14 #include <pybind11/cast.h>
15 #include <pybind11/detail/common.h>
16 #include <pybind11/pybind11.h>
17 #include <pybind11/pytypes.h>
18 #include <vector>
19 
20 namespace py = pybind11;
21 using namespace llvm;
22 using namespace mlir;
23 using namespace mlir::python::adaptors;
24 
25 static void populateDialectSparseTensorSubmodule(const py::module &m) {
26   py::enum_<MlirBaseSparseTensorLevelType>(m, "LevelType", py::module_local())
27       .value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
28       .value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M)
29       .value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED)
30       .value("compressed_nu", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU)
31       .value("compressed_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO)
32       .value("compressed_nu_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO)
33       .value("singleton", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON)
34       .value("singleton_nu", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU)
35       .value("singleton_no", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO)
36       .value("singleton_nu_no", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO)
37       .value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED)
38       .value("loose_compressed_nu",
39              MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU)
40       .value("loose_compressed_no",
41              MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO)
42       .value("loose_compressed_nu_no",
43              MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO);
44 
45   mlir_attribute_subclass(m, "EncodingAttr",
46                           mlirAttributeIsASparseTensorEncodingAttr)
47       .def_classmethod(
48           "get",
49           [](py::object cls, std::vector<MlirSparseTensorLevelType> lvlTypes,
50              std::optional<MlirAffineMap> dimToLvl,
51              std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
52              MlirContext context) {
53             return cls(mlirSparseTensorEncodingAttrGet(
54                 context, lvlTypes.size(), lvlTypes.data(),
55                 dimToLvl ? *dimToLvl : MlirAffineMap{nullptr},
56                 lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth,
57                 crdWidth));
58           },
59           py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"),
60           py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"),
61           py::arg("context") = py::none(),
62           "Gets a sparse_tensor.encoding from parameters.")
63       .def_classmethod(
64           "build_level_type",
65           [](py::object cls, MlirBaseSparseTensorLevelType lvlType, unsigned n,
66              unsigned m) {
67             return mlirSparseTensorEncodingAttrBuildLvlType(lvlType, n, m);
68           },
69           py::arg("cls"), py::arg("lvl_type"), py::arg("n") = 0,
70           py::arg("m") = 0,
71           "Builds a sparse_tensor.encoding.level_type from parameters.")
72       .def_property_readonly(
73           "lvl_types",
74           [](MlirAttribute self) {
75             const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
76             std::vector<MlirSparseTensorLevelType> ret;
77             ret.reserve(lvlRank);
78             for (int l = 0; l < lvlRank; ++l)
79               ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l));
80             return ret;
81           })
82       .def_property_readonly(
83           "dim_to_lvl",
84           [](MlirAttribute self) -> std::optional<MlirAffineMap> {
85             MlirAffineMap ret = mlirSparseTensorEncodingAttrGetDimToLvl(self);
86             if (mlirAffineMapIsNull(ret))
87               return {};
88             return ret;
89           })
90       .def_property_readonly(
91           "lvl_to_dim",
92           [](MlirAttribute self) -> std::optional<MlirAffineMap> {
93             MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self);
94             if (mlirAffineMapIsNull(ret))
95               return {};
96             return ret;
97           })
98       .def_property_readonly("pos_width",
99                              mlirSparseTensorEncodingAttrGetPosWidth)
100       .def_property_readonly("crd_width",
101                              mlirSparseTensorEncodingAttrGetCrdWidth)
102       .def_property_readonly(
103           "structured_n",
104           [](MlirAttribute self) -> unsigned {
105             const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
106             return mlirSparseTensorEncodingAttrGetStructuredN(
107                 mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
108           })
109       .def_property_readonly(
110           "structured_m",
111           [](MlirAttribute self) -> unsigned {
112             const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
113             return mlirSparseTensorEncodingAttrGetStructuredM(
114                 mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
115           })
116       .def_property_readonly("lvl_types_enum", [](MlirAttribute self) {
117         const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
118         std::vector<MlirBaseSparseTensorLevelType> ret;
119         ret.reserve(lvlRank);
120         for (int l = 0; l < lvlRank; l++) {
121           // Convert level type to 32 bits to ignore n and m for n_out_of_m
122           // format.
123           ret.push_back(
124               static_cast<MlirBaseSparseTensorLevelType>(static_cast<uint32_t>(
125                   mlirSparseTensorEncodingAttrGetLvlType(self, l))));
126         }
127         return ret;
128       });
129 }
130 
131 PYBIND11_MODULE(_mlirDialectsSparseTensor, m) {
132   m.doc() = "MLIR SparseTensor dialect.";
133   populateDialectSparseTensorSubmodule(m);
134 }
135