1 //===- Tensor.cpp - C API for SparseTensor dialect ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/Dialect/SparseTensor.h" 10 #include "mlir-c/IR.h" 11 #include "mlir/CAPI/AffineMap.h" 12 #include "mlir/CAPI/Registration.h" 13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 14 #include "mlir/Support/LLVM.h" 15 16 using namespace llvm; 17 using namespace mlir::sparse_tensor; 18 19 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor, 20 mlir::sparse_tensor::SparseTensorDialect) 21 22 // Ensure the C-API enums are int-castable to C++ equivalents. 23 static_assert( 24 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) == 25 static_cast<int>(LevelFormat::Dense) && 26 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) == 27 static_cast<int>(LevelFormat::Compressed) && 28 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) == 29 static_cast<int>(LevelFormat::Singleton) && 30 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) == 31 static_cast<int>(LevelFormat::LooseCompressed) && 32 static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) == 33 static_cast<int>(LevelFormat::NOutOfM), 34 "MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch"); 35 36 static_assert(static_cast<int>(MLIR_SPARSE_PROPERTY_NON_ORDERED) == 37 static_cast<int>(LevelPropNonDefault::Nonordered) && 38 static_cast<int>(MLIR_SPARSE_PROPERTY_NON_UNIQUE) == 39 static_cast<int>(LevelPropNonDefault::Nonunique), 40 "MlirSparseTensorLevelProperty (C-API) and " 41 "LevelPropertyNondefault (C++) mismatch"); 42 43 bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { 44 return isa<SparseTensorEncodingAttr>(unwrap(attr)); 45 } 46 47 MlirAttribute mlirSparseTensorEncodingAttrGet( 48 MlirContext ctx, intptr_t lvlRank, 49 MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl, 50 MlirAffineMap lvlToDim, int posWidth, int crdWidth, 51 MlirAttribute explicitVal, MlirAttribute implicitVal) { 52 SmallVector<LevelType> cppLvlTypes; 53 54 cppLvlTypes.reserve(lvlRank); 55 for (intptr_t l = 0; l < lvlRank; ++l) 56 cppLvlTypes.push_back(static_cast<LevelType>(lvlTypes[l])); 57 58 return wrap(SparseTensorEncodingAttr::get( 59 unwrap(ctx), cppLvlTypes, unwrap(dimToLvl), unwrap(lvlToDim), posWidth, 60 crdWidth, unwrap(explicitVal), unwrap(implicitVal))); 61 } 62 63 MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) { 64 return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getDimToLvl()); 65 } 66 67 MlirAffineMap mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr) { 68 return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlToDim()); 69 } 70 71 intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) { 72 return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank(); 73 } 74 75 MlirSparseTensorLevelType 76 mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) { 77 return static_cast<MlirSparseTensorLevelType>( 78 cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl)); 79 } 80 81 enum MlirSparseTensorLevelFormat 82 mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) { 83 LevelType lt = 84 static_cast<LevelType>(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl)); 85 return static_cast<MlirSparseTensorLevelFormat>(lt.getLvlFmt()); 86 } 87 88 int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) { 89 return cast<SparseTensorEncodingAttr>(unwrap(attr)).getPosWidth(); 90 } 91 92 int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) { 93 return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth(); 94 } 95 96 MlirAttribute mlirSparseTensorEncodingAttrGetExplicitVal(MlirAttribute attr) { 97 return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getExplicitVal()); 98 } 99 100 MlirAttribute mlirSparseTensorEncodingAttrGetImplicitVal(MlirAttribute attr) { 101 return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getImplicitVal()); 102 } 103 104 MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType( 105 enum MlirSparseTensorLevelFormat lvlFmt, 106 const enum MlirSparseTensorLevelPropertyNondefault *properties, 107 unsigned size, unsigned n, unsigned m) { 108 109 std::vector<LevelPropNonDefault> props; 110 for (unsigned i = 0; i < size; i++) 111 props.push_back(static_cast<LevelPropNonDefault>(properties[i])); 112 113 return static_cast<MlirSparseTensorLevelType>( 114 *buildLevelType(static_cast<LevelFormat>(lvlFmt), props, n, m)); 115 } 116 117 unsigned 118 mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType) { 119 return getN(static_cast<LevelType>(lvlType)); 120 } 121 122 unsigned 123 mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType) { 124 return getM(static_cast<LevelType>(lvlType)); 125 } 126