xref: /llvm-project/mlir/lib/CAPI/Dialect/SparseTensor.cpp (revision bcfa7baec8bbf45b98bcde60305efa23df7399e6)
1 //===- Tensor.cpp - C API for SparseTensor dialect ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/SparseTensor.h"
10 #include "mlir-c/IR.h"
11 #include "mlir/CAPI/AffineMap.h"
12 #include "mlir/CAPI/Registration.h"
13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14 #include "mlir/Support/LLVM.h"
15 
16 using namespace llvm;
17 using namespace mlir::sparse_tensor;
18 
19 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
20                                       mlir::sparse_tensor::SparseTensorDialect)
21 
22 // Ensure the C-API enums are int-castable to C++ equivalents.
23 static_assert(
24     static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) ==
25             static_cast<int>(SparseTensorEncodingAttr::DimLevelType::Dense) &&
26         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) ==
27             static_cast<int>(
28                 SparseTensorEncodingAttr::DimLevelType::Compressed) &&
29         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) ==
30             static_cast<int>(SparseTensorEncodingAttr::DimLevelType::Singleton),
31     "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
32 
33 bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
34   return unwrap(attr).isa<SparseTensorEncodingAttr>();
35 }
36 
37 MlirAttribute mlirSparseTensorEncodingAttrGet(
38     MlirContext ctx, intptr_t numDimLevelTypes,
39     MlirSparseTensorDimLevelType const *dimLevelTypes,
40     MlirAffineMap dimOrdering, int pointerBitWidth, int indexBitWidth) {
41   SmallVector<SparseTensorEncodingAttr::DimLevelType> cppDimLevelTypes;
42   cppDimLevelTypes.resize(numDimLevelTypes);
43   for (intptr_t i = 0; i < numDimLevelTypes; ++i)
44     cppDimLevelTypes[i] =
45         static_cast<SparseTensorEncodingAttr::DimLevelType>(dimLevelTypes[i]);
46   return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppDimLevelTypes,
47                                             unwrap(dimOrdering),
48                                             pointerBitWidth, indexBitWidth));
49 }
50 
51 MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) {
52   return wrap(unwrap(attr).cast<SparseTensorEncodingAttr>().getDimOrdering());
53 }
54 
55 intptr_t mlirSparseTensorEncodingGetNumDimLevelTypes(MlirAttribute attr) {
56   return unwrap(attr).cast<SparseTensorEncodingAttr>().getDimLevelType().size();
57 }
58 
59 MlirSparseTensorDimLevelType
60 mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t pos) {
61   return static_cast<MlirSparseTensorDimLevelType>(
62       unwrap(attr).cast<SparseTensorEncodingAttr>().getDimLevelType()[pos]);
63 }
64 
65 int mlirSparseTensorEncodingAttrGetPointerBitWidth(MlirAttribute attr) {
66   return unwrap(attr).cast<SparseTensorEncodingAttr>().getPointerBitWidth();
67 }
68 
69 int mlirSparseTensorEncodingAttrGetIndexBitWidth(MlirAttribute attr) {
70   return unwrap(attr).cast<SparseTensorEncodingAttr>().getIndexBitWidth();
71 }
72