1 //===- Tensor.cpp - C API for SparseTensor dialect ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/Dialect/SparseTensor.h" 10 #include "mlir-c/IR.h" 11 #include "mlir/CAPI/AffineMap.h" 12 #include "mlir/CAPI/Registration.h" 13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 14 #include "mlir/Support/LLVM.h" 15 16 using namespace llvm; 17 using namespace mlir::sparse_tensor; 18 19 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor, 20 mlir::sparse_tensor::SparseTensorDialect) 21 22 // Ensure the C-API enums are int-castable to C++ equivalents. 23 static_assert( 24 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) == 25 static_cast<int>(SparseTensorEncodingAttr::DimLevelType::Dense) && 26 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) == 27 static_cast<int>( 28 SparseTensorEncodingAttr::DimLevelType::Compressed) && 29 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) == 30 static_cast<int>( 31 SparseTensorEncodingAttr::DimLevelType::CompressedNu) && 32 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) == 33 static_cast<int>( 34 SparseTensorEncodingAttr::DimLevelType::CompressedNo) && 35 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO) == 36 static_cast<int>( 37 SparseTensorEncodingAttr::DimLevelType::CompressedNuNo) && 38 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) == 39 static_cast<int>( 40 SparseTensorEncodingAttr::DimLevelType::Singleton) && 41 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) == 42 static_cast<int>( 43 SparseTensorEncodingAttr::DimLevelType::SingletonNu) && 44 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) == 45 static_cast<int>( 46 SparseTensorEncodingAttr::DimLevelType::SingletonNo) && 47 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO) == 48 static_cast<int>( 49 SparseTensorEncodingAttr::DimLevelType::SingletonNuNo), 50 "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch"); 51 52 bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { 53 return unwrap(attr).isa<SparseTensorEncodingAttr>(); 54 } 55 56 MlirAttribute mlirSparseTensorEncodingAttrGet( 57 MlirContext ctx, intptr_t numDimLevelTypes, 58 MlirSparseTensorDimLevelType const *dimLevelTypes, 59 MlirAffineMap dimOrdering, MlirAffineMap higherOrdering, 60 int pointerBitWidth, int indexBitWidth) { 61 SmallVector<SparseTensorEncodingAttr::DimLevelType> cppDimLevelTypes; 62 cppDimLevelTypes.resize(numDimLevelTypes); 63 for (intptr_t i = 0; i < numDimLevelTypes; ++i) 64 cppDimLevelTypes[i] = 65 static_cast<SparseTensorEncodingAttr::DimLevelType>(dimLevelTypes[i]); 66 return wrap(SparseTensorEncodingAttr::get( 67 unwrap(ctx), cppDimLevelTypes, unwrap(dimOrdering), 68 unwrap(higherOrdering), pointerBitWidth, indexBitWidth)); 69 } 70 71 MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) { 72 return wrap(unwrap(attr).cast<SparseTensorEncodingAttr>().getDimOrdering()); 73 } 74 75 MlirAffineMap 76 mlirSparseTensorEncodingAttrGetHigherOrdering(MlirAttribute attr) { 77 return wrap( 78 unwrap(attr).cast<SparseTensorEncodingAttr>().getHigherOrdering()); 79 } 80 81 intptr_t mlirSparseTensorEncodingGetNumDimLevelTypes(MlirAttribute attr) { 82 return unwrap(attr).cast<SparseTensorEncodingAttr>().getDimLevelType().size(); 83 } 84 85 MlirSparseTensorDimLevelType 86 mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t pos) { 87 return static_cast<MlirSparseTensorDimLevelType>( 88 unwrap(attr).cast<SparseTensorEncodingAttr>().getDimLevelType()[pos]); 89 } 90 91 int mlirSparseTensorEncodingAttrGetPointerBitWidth(MlirAttribute attr) { 92 return unwrap(attr).cast<SparseTensorEncodingAttr>().getPointerBitWidth(); 93 } 94 95 int mlirSparseTensorEncodingAttrGetIndexBitWidth(MlirAttribute attr) { 96 return unwrap(attr).cast<SparseTensorEncodingAttr>().getIndexBitWidth(); 97 } 98