1 //===- Tensor.cpp - C API for SparseTensor dialect ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/Dialect/SparseTensor.h" 10 #include "mlir-c/IR.h" 11 #include "mlir/CAPI/AffineMap.h" 12 #include "mlir/CAPI/Registration.h" 13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 14 #include "mlir/Support/LLVM.h" 15 16 using namespace llvm; 17 using namespace mlir::sparse_tensor; 18 19 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor, 20 mlir::sparse_tensor::SparseTensorDialect) 21 22 // Ensure the C-API enums are int-castable to C++ equivalents. 23 static_assert( 24 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) == 25 static_cast<int>(LevelType::Dense) && 26 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) == 27 static_cast<int>(LevelType::Compressed) && 28 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) == 29 static_cast<int>(LevelType::CompressedNu) && 30 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) == 31 static_cast<int>(LevelType::CompressedNo) && 32 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) == 33 static_cast<int>(LevelType::CompressedNuNo) && 34 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) == 35 static_cast<int>(LevelType::Singleton) && 36 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) == 37 static_cast<int>(LevelType::SingletonNu) && 38 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) == 39 static_cast<int>(LevelType::SingletonNo) && 40 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) == 41 static_cast<int>(LevelType::SingletonNuNo) && 42 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) == 43 static_cast<int>(LevelType::LooseCompressed) && 44 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU) == 45 static_cast<int>(LevelType::LooseCompressedNu) && 46 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO) == 47 static_cast<int>(LevelType::LooseCompressedNo) && 48 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO) == 49 static_cast<int>(LevelType::LooseCompressedNuNo) && 50 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) == 51 static_cast<int>(LevelType::NOutOfM), 52 "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch"); 53 54 bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { 55 return isa<SparseTensorEncodingAttr>(unwrap(attr)); 56 } 57 58 MlirAttribute 59 mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank, 60 MlirSparseTensorLevelType const *lvlTypes, 61 MlirAffineMap dimToLvl, MlirAffineMap lvlToDim, 62 int posWidth, int crdWidth) { 63 SmallVector<LevelType> cppLvlTypes; 64 cppLvlTypes.reserve(lvlRank); 65 for (intptr_t l = 0; l < lvlRank; ++l) 66 cppLvlTypes.push_back(static_cast<LevelType>(lvlTypes[l])); 67 return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes, 68 unwrap(dimToLvl), unwrap(lvlToDim), 69 posWidth, crdWidth)); 70 } 71 72 MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) { 73 return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getDimToLvl()); 74 } 75 76 MlirAffineMap mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr) { 77 return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlToDim()); 78 } 79 80 intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) { 81 return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank(); 82 } 83 84 MlirSparseTensorLevelType 85 mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) { 86 return static_cast<MlirSparseTensorLevelType>( 87 cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl)); 88 } 89 90 int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) { 91 return cast<SparseTensorEncodingAttr>(unwrap(attr)).getPosWidth(); 92 } 93 94 int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) { 95 return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth(); 96 } 97 98 MlirSparseTensorLevelType 99 mlirSparseTensorEncodingAttrBuildLvlType(MlirBaseSparseTensorLevelType lvlType, 100 unsigned n, unsigned m) { 101 LevelType lt = static_cast<LevelType>(lvlType); 102 return static_cast<MlirSparseTensorLevelType>(*buildLevelType( 103 *getLevelFormat(lt), isOrderedLT(lt), isUniqueLT(lt), n, m)); 104 } 105 106 unsigned 107 mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType) { 108 return getN(static_cast<LevelType>(lvlType)); 109 } 110 111 unsigned 112 mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType) { 113 return getM(static_cast<LevelType>(lvlType)); 114 } 115