xref: /llvm-project/mlir/lib/CAPI/Dialect/SparseTensor.cpp (revision c48e90877f936710491614b39147410f711c9931)
//===- SparseTensor.cpp - C API for SparseTensor dialect ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/SparseTensor.h"
10 #include "mlir-c/IR.h"
11 #include "mlir/CAPI/AffineMap.h"
12 #include "mlir/CAPI/Registration.h"
13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14 #include "mlir/Support/LLVM.h"
15 
16 using namespace llvm;
17 using namespace mlir::sparse_tensor;
18 
19 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
20                                       mlir::sparse_tensor::SparseTensorDialect)
21 
22 // Ensure the C-API enums are int-castable to C++ equivalents.
23 static_assert(
24     static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) ==
25             static_cast<int>(SparseTensorEncodingAttr::DimLevelType::Dense) &&
26         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) ==
27             static_cast<int>(
28                 SparseTensorEncodingAttr::DimLevelType::Compressed) &&
29         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) ==
30             static_cast<int>(
31                 SparseTensorEncodingAttr::DimLevelType::CompressedNu) &&
32         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) ==
33             static_cast<int>(
34                 SparseTensorEncodingAttr::DimLevelType::CompressedNo) &&
35         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO) ==
36             static_cast<int>(
37                 SparseTensorEncodingAttr::DimLevelType::CompressedNuNo) &&
38         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) ==
39             static_cast<int>(
40                 SparseTensorEncodingAttr::DimLevelType::Singleton) &&
41         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) ==
42             static_cast<int>(
43                 SparseTensorEncodingAttr::DimLevelType::SingletonNu) &&
44         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) ==
45             static_cast<int>(
46                 SparseTensorEncodingAttr::DimLevelType::SingletonNo) &&
47         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO) ==
48             static_cast<int>(
49                 SparseTensorEncodingAttr::DimLevelType::SingletonNuNo),
50     "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
51 
52 bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
53   return unwrap(attr).isa<SparseTensorEncodingAttr>();
54 }
55 
56 MlirAttribute mlirSparseTensorEncodingAttrGet(
57     MlirContext ctx, intptr_t numDimLevelTypes,
58     MlirSparseTensorDimLevelType const *dimLevelTypes,
59     MlirAffineMap dimOrdering, MlirAffineMap higherOrdering,
60     int pointerBitWidth, int indexBitWidth) {
61   SmallVector<SparseTensorEncodingAttr::DimLevelType> cppDimLevelTypes;
62   cppDimLevelTypes.resize(numDimLevelTypes);
63   for (intptr_t i = 0; i < numDimLevelTypes; ++i)
64     cppDimLevelTypes[i] =
65         static_cast<SparseTensorEncodingAttr::DimLevelType>(dimLevelTypes[i]);
66   return wrap(SparseTensorEncodingAttr::get(
67       unwrap(ctx), cppDimLevelTypes, unwrap(dimOrdering),
68       unwrap(higherOrdering), pointerBitWidth, indexBitWidth));
69 }
70 
71 MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) {
72   return wrap(unwrap(attr).cast<SparseTensorEncodingAttr>().getDimOrdering());
73 }
74 
75 MlirAffineMap
76 mlirSparseTensorEncodingAttrGetHigherOrdering(MlirAttribute attr) {
77   return wrap(
78       unwrap(attr).cast<SparseTensorEncodingAttr>().getHigherOrdering());
79 }
80 
81 intptr_t mlirSparseTensorEncodingGetNumDimLevelTypes(MlirAttribute attr) {
82   return unwrap(attr).cast<SparseTensorEncodingAttr>().getDimLevelType().size();
83 }
84 
85 MlirSparseTensorDimLevelType
86 mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t pos) {
87   return static_cast<MlirSparseTensorDimLevelType>(
88       unwrap(attr).cast<SparseTensorEncodingAttr>().getDimLevelType()[pos]);
89 }
90 
91 int mlirSparseTensorEncodingAttrGetPointerBitWidth(MlirAttribute attr) {
92   return unwrap(attr).cast<SparseTensorEncodingAttr>().getPointerBitWidth();
93 }
94 
95 int mlirSparseTensorEncodingAttrGetIndexBitWidth(MlirAttribute attr) {
96   return unwrap(attr).cast<SparseTensorEncodingAttr>().getIndexBitWidth();
97 }
98