//===- Tensor.cpp - C API for SparseTensor dialect ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir-c/Dialect/SparseTensor.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Support/LLVM.h"

using namespace llvm;
using namespace mlir::sparse_tensor;

// Emits the C-API dialect registration/handle entry point for the
// SparseTensor dialect, bound to the C++ class SparseTensorDialect. The exact
// symbol(s) generated are defined by the macro in mlir/CAPI/Registration.h
// (NOTE(review): not visible here -- consult that header for the name).
// The C++ class is spelled fully qualified on purpose so the expansion does
// not depend on the `using namespace` directives above.
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
                                      mlir::sparse_tensor::SparseTensorDialect)

// Ensure the C-API enums are int-castable to C++ equivalents.
23e5924d64SYinying Li static_assert( 24e5924d64SYinying Li static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) == 25429919e3SPeiming Liu static_cast<int>(LevelFormat::Dense) && 261944c4f7SAart Bik static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) == 27429919e3SPeiming Liu static_cast<int>(LevelFormat::Compressed) && 281944c4f7SAart Bik static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) == 29429919e3SPeiming Liu static_cast<int>(LevelFormat::Singleton) && 30e5924d64SYinying Li static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) == 31429919e3SPeiming Liu static_cast<int>(LevelFormat::LooseCompressed) && 32e5924d64SYinying Li static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) == 33429919e3SPeiming Liu static_cast<int>(LevelFormat::NOutOfM), 34429919e3SPeiming Liu "MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch"); 35429919e3SPeiming Liu 36429919e3SPeiming Liu static_assert(static_cast<int>(MLIR_SPARSE_PROPERTY_NON_ORDERED) == 37aaf91645SPeiming Liu static_cast<int>(LevelPropNonDefault::Nonordered) && 38429919e3SPeiming Liu static_cast<int>(MLIR_SPARSE_PROPERTY_NON_UNIQUE) == 39*b50ce4c8SMateusz Sokół static_cast<int>(LevelPropNonDefault::Nonunique) && 40*b50ce4c8SMateusz Sokół static_cast<int>(MLIR_SPARSE_PROPERTY_SOA) == 41*b50ce4c8SMateusz Sokół static_cast<int>(LevelPropNonDefault::SoA), 42429919e3SPeiming Liu "MlirSparseTensorLevelProperty (C-API) and " 43429919e3SPeiming Liu "LevelPropertyNondefault (C++) mismatch"); 44bcfa7baeSStella Laurenzo 45bcfa7baeSStella Laurenzo bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { 465550c821STres Popp return isa<SparseTensorEncodingAttr>(unwrap(attr)); 47bcfa7baeSStella Laurenzo } 48bcfa7baeSStella Laurenzo 49a10d67f9SYinying Li MlirAttribute mlirSparseTensorEncodingAttrGet( 50a10d67f9SYinying Li MlirContext ctx, intptr_t lvlRank, 51a10d67f9SYinying Li MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl, 52a10d67f9SYinying Li MlirAffineMap lvlToDim, int 
posWidth, int crdWidth, 53a10d67f9SYinying Li MlirAttribute explicitVal, MlirAttribute implicitVal) { 541944c4f7SAart Bik SmallVector<LevelType> cppLvlTypes; 55a10d67f9SYinying Li 56a0615d02Swren romano cppLvlTypes.reserve(lvlRank); 5784cd51bbSwren romano for (intptr_t l = 0; l < lvlRank; ++l) 581944c4f7SAart Bik cppLvlTypes.push_back(static_cast<LevelType>(lvlTypes[l])); 59a10d67f9SYinying Li 60a10d67f9SYinying Li return wrap(SparseTensorEncodingAttr::get( 61a10d67f9SYinying Li unwrap(ctx), cppLvlTypes, unwrap(dimToLvl), unwrap(lvlToDim), posWidth, 62a10d67f9SYinying Li crdWidth, unwrap(explicitVal), unwrap(implicitVal))); 63bcfa7baeSStella Laurenzo } 64bcfa7baeSStella Laurenzo 6576647fceSwren romano MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) { 6676647fceSwren romano return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getDimToLvl()); 67c48e9087SAart Bik } 68c48e9087SAart Bik 69836411b9SAart Bik MlirAffineMap mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr) { 70836411b9SAart Bik return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlToDim()); 71836411b9SAart Bik } 72836411b9SAart Bik 7384cd51bbSwren romano intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) { 745550c821STres Popp return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank(); 75bcfa7baeSStella Laurenzo } 76bcfa7baeSStella Laurenzo 771944c4f7SAart Bik MlirSparseTensorLevelType 78a0615d02Swren romano mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) { 791944c4f7SAart Bik return static_cast<MlirSparseTensorLevelType>( 805550c821STres Popp cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl)); 81bcfa7baeSStella Laurenzo } 82bcfa7baeSStella Laurenzo 83429919e3SPeiming Liu enum MlirSparseTensorLevelFormat 84429919e3SPeiming Liu mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) { 85429919e3SPeiming Liu LevelType lt = 86429919e3SPeiming Liu 
static_cast<LevelType>(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl)); 87aaf91645SPeiming Liu return static_cast<MlirSparseTensorLevelFormat>(lt.getLvlFmt()); 88429919e3SPeiming Liu } 89429919e3SPeiming Liu 9084cd51bbSwren romano int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) { 915550c821STres Popp return cast<SparseTensorEncodingAttr>(unwrap(attr)).getPosWidth(); 92bcfa7baeSStella Laurenzo } 93bcfa7baeSStella Laurenzo 9484cd51bbSwren romano int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) { 955550c821STres Popp return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth(); 96bcfa7baeSStella Laurenzo } 972a6b521bSYinying Li 98a10d67f9SYinying Li MlirAttribute mlirSparseTensorEncodingAttrGetExplicitVal(MlirAttribute attr) { 99a10d67f9SYinying Li return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getExplicitVal()); 100a10d67f9SYinying Li } 101a10d67f9SYinying Li 102a10d67f9SYinying Li MlirAttribute mlirSparseTensorEncodingAttrGetImplicitVal(MlirAttribute attr) { 103a10d67f9SYinying Li return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getImplicitVal()); 104a10d67f9SYinying Li } 105a10d67f9SYinying Li 106429919e3SPeiming Liu MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType( 107429919e3SPeiming Liu enum MlirSparseTensorLevelFormat lvlFmt, 108429919e3SPeiming Liu const enum MlirSparseTensorLevelPropertyNondefault *properties, 109429919e3SPeiming Liu unsigned size, unsigned n, unsigned m) { 110429919e3SPeiming Liu 111aaf91645SPeiming Liu std::vector<LevelPropNonDefault> props; 11236c57532SAdrian Kuegel props.reserve(size); 113429919e3SPeiming Liu for (unsigned i = 0; i < size; i++) 114aaf91645SPeiming Liu props.push_back(static_cast<LevelPropNonDefault>(properties[i])); 115429919e3SPeiming Liu 116429919e3SPeiming Liu return static_cast<MlirSparseTensorLevelType>( 117429919e3SPeiming Liu *buildLevelType(static_cast<LevelFormat>(lvlFmt), props, n, m)); 1182a6b521bSYinying Li } 1192a6b521bSYinying Li 
1202a6b521bSYinying Li unsigned 1212a6b521bSYinying Li mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType) { 1222a6b521bSYinying Li return getN(static_cast<LevelType>(lvlType)); 1232a6b521bSYinying Li } 1242a6b521bSYinying Li 1252a6b521bSYinying Li unsigned 1262a6b521bSYinying Li mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType) { 1272a6b521bSYinying Li return getM(static_cast<LevelType>(lvlType)); 1282a6b521bSYinying Li } 129