1 //===- DialectSparseTensor.cpp - 'sparse_tensor' dialect submodule --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/Dialect/SparseTensor.h" 10 #include "mlir-c/IR.h" 11 #include "mlir/Bindings/Python/PybindAdaptors.h" 12 #include <optional> 13 14 namespace py = pybind11; 15 using namespace llvm; 16 using namespace mlir; 17 using namespace mlir::python::adaptors; 18 19 static void populateDialectSparseTensorSubmodule(const py::module &m) { 20 py::enum_<MlirSparseTensorDimLevelType>(m, "DimLevelType", py::module_local()) 21 .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) 22 .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) 23 .value("compressed-nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) 24 .value("compressed-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) 25 .value("compressed-nu-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO) 26 .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) 27 .value("singleton-nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) 28 .value("singleton-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) 29 .value("singleton-nu-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO); 30 31 mlir_attribute_subclass(m, "EncodingAttr", 32 mlirAttributeIsASparseTensorEncodingAttr) 33 .def_classmethod( 34 "get", 35 [](py::object cls, 36 std::vector<MlirSparseTensorDimLevelType> dimLevelTypes, 37 llvm::Optional<MlirAffineMap> dimOrdering, 38 llvm::Optional<MlirAffineMap> higherOrdering, int pointerBitWidth, 39 int indexBitWidth, MlirContext context) { 40 return cls(mlirSparseTensorEncodingAttrGet( 41 context, dimLevelTypes.size(), dimLevelTypes.data(), 42 dimOrdering ? *dimOrdering : MlirAffineMap{nullptr}, 43 higherOrdering ? *higherOrdering : MlirAffineMap{nullptr}, 44 pointerBitWidth, indexBitWidth)); 45 }, 46 py::arg("cls"), py::arg("dim_level_types"), py::arg("dim_ordering"), 47 py::arg("higher_ordering"), py::arg("pointer_bit_width"), 48 py::arg("index_bit_width"), py::arg("context") = py::none(), 49 "Gets a sparse_tensor.encoding from parameters.") 50 .def_property_readonly( 51 "dim_level_types", 52 [](MlirAttribute self) { 53 std::vector<MlirSparseTensorDimLevelType> ret; 54 for (int i = 0, 55 e = mlirSparseTensorEncodingGetNumDimLevelTypes(self); 56 i < e; ++i) 57 ret.push_back( 58 mlirSparseTensorEncodingAttrGetDimLevelType(self, i)); 59 return ret; 60 }) 61 .def_property_readonly( 62 "dim_ordering", 63 [](MlirAttribute self) -> llvm::Optional<MlirAffineMap> { 64 MlirAffineMap ret = 65 mlirSparseTensorEncodingAttrGetDimOrdering(self); 66 if (mlirAffineMapIsNull(ret)) 67 return {}; 68 return ret; 69 }) 70 .def_property_readonly( 71 "higher_ordering", 72 [](MlirAttribute self) -> llvm::Optional<MlirAffineMap> { 73 MlirAffineMap ret = 74 mlirSparseTensorEncodingAttrGetHigherOrdering(self); 75 if (mlirAffineMapIsNull(ret)) 76 return {}; 77 return ret; 78 }) 79 .def_property_readonly( 80 "pointer_bit_width", 81 [](MlirAttribute self) { 82 return mlirSparseTensorEncodingAttrGetPointerBitWidth(self); 83 }) 84 .def_property_readonly("index_bit_width", [](MlirAttribute self) { 85 return mlirSparseTensorEncodingAttrGetIndexBitWidth(self); 86 }); 87 } 88 89 PYBIND11_MODULE(_mlirDialectsSparseTensor, m) { 90 m.doc() = "MLIR SparseTensor dialect."; 91 populateDialectSparseTensorSubmodule(m); 92 } 93