xref: /llvm-project/mlir/lib/CAPI/Dialect/SparseTensor.cpp (revision 1b434652c56704be90d01039f4329ea9320bc742)
//===- SparseTensor.cpp - C API for SparseTensor dialect ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir-c/Dialect/SparseTensor.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Support/LLVM.h"

#include <cassert>
15 
16 using namespace llvm;
17 using namespace mlir::sparse_tensor;
18 
// Instantiates the C-API dialect-registration boilerplate (macro from
// mlir/CAPI/Registration.h) for the SparseTensor dialect. The arguments are
// the C-API name, the dialect namespace, and the C++ dialect class. See that
// header for the exact symbols the macro generates.
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
                                      mlir::sparse_tensor::SparseTensorDialect)
21 
22 // Ensure the C-API enums are int-castable to C++ equivalents.
23 static_assert(
24     static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) ==
25             static_cast<int>(SparseTensorEncodingAttr::DimLevelType::Dense) &&
26         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) ==
27             static_cast<int>(
28                 SparseTensorEncodingAttr::DimLevelType::Compressed) &&
29         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) ==
30             static_cast<int>(
31                 SparseTensorEncodingAttr::DimLevelType::CompressedNu) &&
32         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) ==
33             static_cast<int>(
34                 SparseTensorEncodingAttr::DimLevelType::CompressedNo) &&
35         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO) ==
36             static_cast<int>(
37                 SparseTensorEncodingAttr::DimLevelType::CompressedNuNo) &&
38         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) ==
39             static_cast<int>(
40                 SparseTensorEncodingAttr::DimLevelType::Singleton) &&
41         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) ==
42             static_cast<int>(
43                 SparseTensorEncodingAttr::DimLevelType::SingletonNu) &&
44         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) ==
45             static_cast<int>(
46                 SparseTensorEncodingAttr::DimLevelType::SingletonNo) &&
47         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO) ==
48             static_cast<int>(
49                 SparseTensorEncodingAttr::DimLevelType::SingletonNuNo),
50     "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
51 
52 bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
53   return unwrap(attr).isa<SparseTensorEncodingAttr>();
54 }
55 
56 MlirAttribute mlirSparseTensorEncodingAttrGet(
57     MlirContext ctx, intptr_t numDimLevelTypes,
58     MlirSparseTensorDimLevelType const *dimLevelTypes,
59     MlirAffineMap dimOrdering, int pointerBitWidth, int indexBitWidth) {
60   SmallVector<SparseTensorEncodingAttr::DimLevelType> cppDimLevelTypes;
61   cppDimLevelTypes.resize(numDimLevelTypes);
62   for (intptr_t i = 0; i < numDimLevelTypes; ++i)
63     cppDimLevelTypes[i] =
64         static_cast<SparseTensorEncodingAttr::DimLevelType>(dimLevelTypes[i]);
65   return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppDimLevelTypes,
66                                             unwrap(dimOrdering),
67                                             pointerBitWidth, indexBitWidth));
68 }
69 
70 MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) {
71   return wrap(unwrap(attr).cast<SparseTensorEncodingAttr>().getDimOrdering());
72 }
73 
74 intptr_t mlirSparseTensorEncodingGetNumDimLevelTypes(MlirAttribute attr) {
75   return unwrap(attr).cast<SparseTensorEncodingAttr>().getDimLevelType().size();
76 }
77 
78 MlirSparseTensorDimLevelType
79 mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t pos) {
80   return static_cast<MlirSparseTensorDimLevelType>(
81       unwrap(attr).cast<SparseTensorEncodingAttr>().getDimLevelType()[pos]);
82 }
83 
84 int mlirSparseTensorEncodingAttrGetPointerBitWidth(MlirAttribute attr) {
85   return unwrap(attr).cast<SparseTensorEncodingAttr>().getPointerBitWidth();
86 }
87 
88 int mlirSparseTensorEncodingAttrGetIndexBitWidth(MlirAttribute attr) {
89   return unwrap(attr).cast<SparseTensorEncodingAttr>().getIndexBitWidth();
90 }
91