xref: /llvm-project/mlir/lib/Bindings/Python/DialectSparseTensor.cpp (revision cd481fa827b76953cd12dae9319face96670c0b3)
1 //===- DialectSparseTensor.cpp - 'sparse_tensor' dialect submodule --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/AffineMap.h"
10 #include "mlir-c/Dialect/SparseTensor.h"
11 #include "mlir-c/IR.h"
12 #include "mlir/Bindings/Python/PybindAdaptors.h"
13 #include <optional>
14 #include <pybind11/cast.h>
15 #include <pybind11/detail/common.h>
16 #include <pybind11/pybind11.h>
17 #include <pybind11/pytypes.h>
18 #include <vector>
19 
20 namespace py = pybind11;
21 using namespace llvm;
22 using namespace mlir;
23 using namespace mlir::python::adaptors;
24 
25 static void populateDialectSparseTensorSubmodule(const py::module &m) {
26   py::enum_<MlirBaseSparseTensorLevelType>(m, "LevelType", py::module_local())
27       .value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
28       .value("compressed24", MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR)
29       .value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED)
30       .value("compressed_nu", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU)
31       .value("compressed_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO)
32       .value("compressed_nu_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO)
33       .value("singleton", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON)
34       .value("singleton_nu", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU)
35       .value("singleton_no", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO)
36       .value("singleton_nu_no", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO)
37       .value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED)
38       .value("loose_compressed_nu",
39              MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU)
40       .value("loose_compressed_no",
41              MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO)
42       .value("loose_compressed_nu_no",
43              MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO);
44 
45   mlir_attribute_subclass(m, "EncodingAttr",
46                           mlirAttributeIsASparseTensorEncodingAttr)
47       .def_classmethod(
48           "get",
49           [](py::object cls, std::vector<MlirSparseTensorLevelType> lvlTypes,
50              std::optional<MlirAffineMap> dimToLvl,
51              std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
52              MlirContext context) {
53             return cls(mlirSparseTensorEncodingAttrGet(
54                 context, lvlTypes.size(), lvlTypes.data(),
55                 dimToLvl ? *dimToLvl : MlirAffineMap{nullptr},
56                 lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth,
57                 crdWidth));
58           },
59           py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"),
60           py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"),
61           py::arg("context") = py::none(),
62           "Gets a sparse_tensor.encoding from parameters.")
63       .def_property_readonly(
64           "lvl_types",
65           [](MlirAttribute self) {
66             const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
67             std::vector<MlirSparseTensorLevelType> ret;
68             ret.reserve(lvlRank);
69             for (int l = 0; l < lvlRank; ++l)
70               ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l));
71             return ret;
72           })
73       .def_property_readonly(
74           "dim_to_lvl",
75           [](MlirAttribute self) -> std::optional<MlirAffineMap> {
76             MlirAffineMap ret = mlirSparseTensorEncodingAttrGetDimToLvl(self);
77             if (mlirAffineMapIsNull(ret))
78               return {};
79             return ret;
80           })
81       .def_property_readonly(
82           "lvl_to_dim",
83           [](MlirAttribute self) -> std::optional<MlirAffineMap> {
84             MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self);
85             if (mlirAffineMapIsNull(ret))
86               return {};
87             return ret;
88           })
89       .def_property_readonly("pos_width",
90                              mlirSparseTensorEncodingAttrGetPosWidth)
91       .def_property_readonly("crd_width",
92                              mlirSparseTensorEncodingAttrGetCrdWidth);
93 }
94 
95 PYBIND11_MODULE(_mlirDialectsSparseTensor, m) {
96   m.doc() = "MLIR SparseTensor dialect.";
97   populateDialectSparseTensorSubmodule(m);
98 }
99