xref: /llvm-project/mlir/lib/CAPI/Dialect/SparseTensor.cpp (revision 9921ef73c864c5aa7a2f1e539a09d5cbd487def9)
//===- SparseTensor.cpp - C API for SparseTensor dialect ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/SparseTensor.h"
10 #include "mlir-c/IR.h"
11 #include "mlir/CAPI/AffineMap.h"
12 #include "mlir/CAPI/Registration.h"
13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14 #include "mlir/Support/LLVM.h"
15 
16 using namespace llvm;
17 using namespace mlir::sparse_tensor;
18 
19 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
20                                       mlir::sparse_tensor::SparseTensorDialect)
21 
22 // Ensure the C-API enums are int-castable to C++ equivalents.
23 static_assert(
24     static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) ==
25             static_cast<int>(SparseTensorEncodingAttr::DimLevelType::Dense) &&
26         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) ==
27             static_cast<int>(
28                 SparseTensorEncodingAttr::DimLevelType::Compressed),
29     "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
30 
31 bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
32   return unwrap(attr).isa<SparseTensorEncodingAttr>();
33 }
34 
35 MlirAttribute mlirSparseTensorEncodingAttrGet(
36     MlirContext ctx, intptr_t numDimLevelTypes,
37     MlirSparseTensorDimLevelType const *dimLevelTypes,
38     MlirAffineMap dimOrdering, int pointerBitWidth, int indexBitWidth) {
39   SmallVector<SparseTensorEncodingAttr::DimLevelType> cppDimLevelTypes;
40   cppDimLevelTypes.resize(numDimLevelTypes);
41   for (intptr_t i = 0; i < numDimLevelTypes; ++i)
42     cppDimLevelTypes[i] =
43         static_cast<SparseTensorEncodingAttr::DimLevelType>(dimLevelTypes[i]);
44   return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppDimLevelTypes,
45                                             unwrap(dimOrdering),
46                                             pointerBitWidth, indexBitWidth));
47 }
48 
49 MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) {
50   return wrap(unwrap(attr).cast<SparseTensorEncodingAttr>().getDimOrdering());
51 }
52 
53 intptr_t mlirSparseTensorEncodingGetNumDimLevelTypes(MlirAttribute attr) {
54   return unwrap(attr).cast<SparseTensorEncodingAttr>().getDimLevelType().size();
55 }
56 
57 MlirSparseTensorDimLevelType
58 mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t pos) {
59   return static_cast<MlirSparseTensorDimLevelType>(
60       unwrap(attr).cast<SparseTensorEncodingAttr>().getDimLevelType()[pos]);
61 }
62 
63 int mlirSparseTensorEncodingAttrGetPointerBitWidth(MlirAttribute attr) {
64   return unwrap(attr).cast<SparseTensorEncodingAttr>().getPointerBitWidth();
65 }
66 
67 int mlirSparseTensorEncodingAttrGetIndexBitWidth(MlirAttribute attr) {
68   return unwrap(attr).cast<SparseTensorEncodingAttr>().getIndexBitWidth();
69 }
70