//===- Tensor.cpp - C API for SparseTensor dialect ------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir-c/Dialect/SparseTensor.h" #include "mlir-c/IR.h" #include "mlir/CAPI/AffineMap.h" #include "mlir/CAPI/Registration.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Support/LLVM.h" using namespace llvm; using namespace mlir::sparse_tensor; MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor, mlir::sparse_tensor::SparseTensorDialect) // Ensure the C-API enums are int-castable to C++ equivalents. static_assert( static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) == static_cast(SparseTensorEncodingAttr::DimLevelType::Dense) && static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) == static_cast( SparseTensorEncodingAttr::DimLevelType::Compressed) && static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) == static_cast(SparseTensorEncodingAttr::DimLevelType::Singleton), "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch"); bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirSparseTensorEncodingAttrGet( MlirContext ctx, intptr_t numDimLevelTypes, MlirSparseTensorDimLevelType const *dimLevelTypes, MlirAffineMap dimOrdering, int pointerBitWidth, int indexBitWidth) { SmallVector cppDimLevelTypes; cppDimLevelTypes.resize(numDimLevelTypes); for (intptr_t i = 0; i < numDimLevelTypes; ++i) cppDimLevelTypes[i] = static_cast(dimLevelTypes[i]); return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppDimLevelTypes, unwrap(dimOrdering), pointerBitWidth, indexBitWidth)); } MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) { return wrap(unwrap(attr).cast().getDimOrdering()); } intptr_t mlirSparseTensorEncodingGetNumDimLevelTypes(MlirAttribute attr) { return unwrap(attr).cast().getDimLevelType().size(); } MlirSparseTensorDimLevelType mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t pos) { return static_cast( unwrap(attr).cast().getDimLevelType()[pos]); } int mlirSparseTensorEncodingAttrGetPointerBitWidth(MlirAttribute attr) { return unwrap(attr).cast().getPointerBitWidth(); } int mlirSparseTensorEncodingAttrGetIndexBitWidth(MlirAttribute attr) { return unwrap(attr).cast().getIndexBitWidth(); }