1 //===- DialectSparseTensor.cpp - 'sparse_tensor' dialect submodule --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/Dialect/SparseTensor.h" 10 #include "mlir-c/IR.h" 11 #include "mlir/Bindings/Python/PybindAdaptors.h" 12 #include <optional> 13 14 namespace py = pybind11; 15 using namespace llvm; 16 using namespace mlir; 17 using namespace mlir::python::adaptors; 18 19 static void populateDialectSparseTensorSubmodule(const py::module &m) { 20 py::enum_<MlirSparseTensorDimLevelType>(m, "DimLevelType", py::module_local()) 21 .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) 22 .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) 23 .value("compressed-nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) 24 .value("compressed-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) 25 .value("compressed-nu-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO) 26 .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) 27 .value("singleton-nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) 28 .value("singleton-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) 29 .value("singleton-nu-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO); 30 31 mlir_attribute_subclass(m, "EncodingAttr", 32 mlirAttributeIsASparseTensorEncodingAttr) 33 .def_classmethod( 34 "get", 35 [](py::object cls, 36 std::vector<MlirSparseTensorDimLevelType> dimLevelTypes, 37 std::optional<MlirAffineMap> dimOrdering, 38 std::optional<MlirAffineMap> higherOrdering, int posWidth, 39 int crdWidth, MlirContext context) { 40 return cls(mlirSparseTensorEncodingAttrGet( 41 context, dimLevelTypes.size(), dimLevelTypes.data(), 42 dimOrdering ? *dimOrdering : MlirAffineMap{nullptr}, 43 higherOrdering ? *higherOrdering : MlirAffineMap{nullptr}, 44 posWidth, crdWidth)); 45 }, 46 py::arg("cls"), py::arg("dim_level_types"), py::arg("dim_ordering"), 47 py::arg("higher_ordering"), py::arg("pos_width"), 48 py::arg("crd_width"), py::arg("context") = py::none(), 49 "Gets a sparse_tensor.encoding from parameters.") 50 .def_property_readonly( 51 "dim_level_types", 52 [](MlirAttribute self) { 53 const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); 54 std::vector<MlirSparseTensorDimLevelType> ret; 55 ret.reserve(lvlRank); 56 for (int l = 0; l < lvlRank; ++l) 57 ret.push_back( 58 mlirSparseTensorEncodingAttrGetDimLevelType(self, l)); 59 return ret; 60 }) 61 .def_property_readonly( 62 "dim_ordering", 63 [](MlirAttribute self) -> std::optional<MlirAffineMap> { 64 MlirAffineMap ret = 65 mlirSparseTensorEncodingAttrGetDimOrdering(self); 66 if (mlirAffineMapIsNull(ret)) 67 return {}; 68 return ret; 69 }) 70 .def_property_readonly( 71 "higher_ordering", 72 [](MlirAttribute self) -> std::optional<MlirAffineMap> { 73 MlirAffineMap ret = 74 mlirSparseTensorEncodingAttrGetHigherOrdering(self); 75 if (mlirAffineMapIsNull(ret)) 76 return {}; 77 return ret; 78 }) 79 .def_property_readonly("pos_width", 80 mlirSparseTensorEncodingAttrGetPosWidth) 81 .def_property_readonly("crd_width", 82 mlirSparseTensorEncodingAttrGetCrdWidth); 83 } 84 85 PYBIND11_MODULE(_mlirDialectsSparseTensor, m) { 86 m.doc() = "MLIR SparseTensor dialect."; 87 populateDialectSparseTensorSubmodule(m); 88 } 89