1 //===- Tensor.cpp - C API for SparseTensor dialect ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/Dialect/SparseTensor.h" 10 #include "mlir-c/IR.h" 11 #include "mlir/CAPI/AffineMap.h" 12 #include "mlir/CAPI/Registration.h" 13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 14 #include "mlir/Support/LLVM.h" 15 16 using namespace llvm; 17 using namespace mlir::sparse_tensor; 18 19 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor, 20 mlir::sparse_tensor::SparseTensorDialect) 21 22 // Ensure the C-API enums are int-castable to C++ equivalents. 23 static_assert( 24 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) == 25 static_cast<int>(DimLevelType::Dense) && 26 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) == 27 static_cast<int>(DimLevelType::Compressed) && 28 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) == 29 static_cast<int>(DimLevelType::CompressedNu) && 30 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) == 31 static_cast<int>(DimLevelType::CompressedNo) && 32 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO) == 33 static_cast<int>(DimLevelType::CompressedNuNo) && 34 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) == 35 static_cast<int>(DimLevelType::Singleton) && 36 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) == 37 static_cast<int>(DimLevelType::SingletonNu) && 38 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) == 39 static_cast<int>(DimLevelType::SingletonNo) && 40 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO) == 41 static_cast<int>(DimLevelType::SingletonNuNo), 42 "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch"); 43 44 bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { 45 return isa<SparseTensorEncodingAttr>(unwrap(attr)); 46 } 47 48 MlirAttribute mlirSparseTensorEncodingAttrGet( 49 MlirContext ctx, intptr_t lvlRank, 50 MlirSparseTensorDimLevelType const *dimLevelTypes, 51 MlirAffineMap dimOrdering, MlirAffineMap higherOrdering, int posWidth, 52 int crdWidth) { 53 SmallVector<DimLevelType> cppDimLevelTypes; 54 cppDimLevelTypes.reserve(lvlRank); 55 for (intptr_t l = 0; l < lvlRank; ++l) 56 cppDimLevelTypes.push_back(static_cast<DimLevelType>(dimLevelTypes[l])); 57 return wrap(SparseTensorEncodingAttr::get( 58 unwrap(ctx), cppDimLevelTypes, unwrap(dimOrdering), 59 unwrap(higherOrdering), posWidth, crdWidth)); 60 } 61 62 MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) { 63 return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getDimOrdering()); 64 } 65 66 MlirAffineMap 67 mlirSparseTensorEncodingAttrGetHigherOrdering(MlirAttribute attr) { 68 return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getHigherOrdering()); 69 } 70 71 intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) { 72 return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank(); 73 } 74 75 MlirSparseTensorDimLevelType 76 mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t lvl) { 77 return static_cast<MlirSparseTensorDimLevelType>( 78 cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl)); 79 } 80 81 int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) { 82 return cast<SparseTensorEncodingAttr>(unwrap(attr)).getPosWidth(); 83 } 84 85 int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) { 86 return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth(); 87 } 88