xref: /llvm-project/mlir/lib/CAPI/Dialect/SparseTensor.cpp (revision b50ce4c81e71855bc01b9564d3bd239437847184)
//===- SparseTensor.cpp - C API for SparseTensor dialect ------------------===//
2bcfa7baeSStella Laurenzo //
3bcfa7baeSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4bcfa7baeSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5bcfa7baeSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6bcfa7baeSStella Laurenzo //
7bcfa7baeSStella Laurenzo //===----------------------------------------------------------------------===//
8bcfa7baeSStella Laurenzo 
#include "mlir-c/Dialect/SparseTensor.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/ErrorHandling.h"
15bcfa7baeSStella Laurenzo 
16bcfa7baeSStella Laurenzo using namespace llvm;
17bcfa7baeSStella Laurenzo using namespace mlir::sparse_tensor;
18bcfa7baeSStella Laurenzo 
// Registers the SparseTensor dialect with the C API. The macro comes from
// mlir/CAPI/Registration.h and is parameterized by the C-API name
// (`SparseTensor`), the dialect namespace (`sparse_tensor`) and the C++
// dialect class. NOTE(review): presumably this expands to the
// `mlirGetDialectHandle__sparse_tensor__()` entry point declared in
// mlir-c/Dialect/SparseTensor.h — confirm against the macro definition.
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
                                      mlir::sparse_tensor::SparseTensorDialect)
21bcfa7baeSStella Laurenzo 
22bcfa7baeSStella Laurenzo // Ensure the C-API enums are int-castable to C++ equivalents.
23e5924d64SYinying Li static_assert(
24e5924d64SYinying Li     static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) ==
25429919e3SPeiming Liu             static_cast<int>(LevelFormat::Dense) &&
261944c4f7SAart Bik         static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) ==
27429919e3SPeiming Liu             static_cast<int>(LevelFormat::Compressed) &&
281944c4f7SAart Bik         static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) ==
29429919e3SPeiming Liu             static_cast<int>(LevelFormat::Singleton) &&
30e5924d64SYinying Li         static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) ==
31429919e3SPeiming Liu             static_cast<int>(LevelFormat::LooseCompressed) &&
32e5924d64SYinying Li         static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) ==
33429919e3SPeiming Liu             static_cast<int>(LevelFormat::NOutOfM),
34429919e3SPeiming Liu     "MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch");
35429919e3SPeiming Liu 
36429919e3SPeiming Liu static_assert(static_cast<int>(MLIR_SPARSE_PROPERTY_NON_ORDERED) ==
37aaf91645SPeiming Liu                       static_cast<int>(LevelPropNonDefault::Nonordered) &&
38429919e3SPeiming Liu                   static_cast<int>(MLIR_SPARSE_PROPERTY_NON_UNIQUE) ==
39*b50ce4c8SMateusz Sokół                       static_cast<int>(LevelPropNonDefault::Nonunique) &&
40*b50ce4c8SMateusz Sokół                   static_cast<int>(MLIR_SPARSE_PROPERTY_SOA) ==
41*b50ce4c8SMateusz Sokół                       static_cast<int>(LevelPropNonDefault::SoA),
42429919e3SPeiming Liu               "MlirSparseTensorLevelProperty (C-API) and "
43429919e3SPeiming Liu               "LevelPropertyNondefault (C++) mismatch");
44bcfa7baeSStella Laurenzo 
45bcfa7baeSStella Laurenzo bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
465550c821STres Popp   return isa<SparseTensorEncodingAttr>(unwrap(attr));
47bcfa7baeSStella Laurenzo }
48bcfa7baeSStella Laurenzo 
49a10d67f9SYinying Li MlirAttribute mlirSparseTensorEncodingAttrGet(
50a10d67f9SYinying Li     MlirContext ctx, intptr_t lvlRank,
51a10d67f9SYinying Li     MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl,
52a10d67f9SYinying Li     MlirAffineMap lvlToDim, int posWidth, int crdWidth,
53a10d67f9SYinying Li     MlirAttribute explicitVal, MlirAttribute implicitVal) {
541944c4f7SAart Bik   SmallVector<LevelType> cppLvlTypes;
55a10d67f9SYinying Li 
56a0615d02Swren romano   cppLvlTypes.reserve(lvlRank);
5784cd51bbSwren romano   for (intptr_t l = 0; l < lvlRank; ++l)
581944c4f7SAart Bik     cppLvlTypes.push_back(static_cast<LevelType>(lvlTypes[l]));
59a10d67f9SYinying Li 
60a10d67f9SYinying Li   return wrap(SparseTensorEncodingAttr::get(
61a10d67f9SYinying Li       unwrap(ctx), cppLvlTypes, unwrap(dimToLvl), unwrap(lvlToDim), posWidth,
62a10d67f9SYinying Li       crdWidth, unwrap(explicitVal), unwrap(implicitVal)));
63bcfa7baeSStella Laurenzo }
64bcfa7baeSStella Laurenzo 
6576647fceSwren romano MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) {
6676647fceSwren romano   return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getDimToLvl());
67c48e9087SAart Bik }
68c48e9087SAart Bik 
69836411b9SAart Bik MlirAffineMap mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr) {
70836411b9SAart Bik   return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlToDim());
71836411b9SAart Bik }
72836411b9SAart Bik 
7384cd51bbSwren romano intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) {
745550c821STres Popp   return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank();
75bcfa7baeSStella Laurenzo }
76bcfa7baeSStella Laurenzo 
771944c4f7SAart Bik MlirSparseTensorLevelType
78a0615d02Swren romano mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) {
791944c4f7SAart Bik   return static_cast<MlirSparseTensorLevelType>(
805550c821STres Popp       cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl));
81bcfa7baeSStella Laurenzo }
82bcfa7baeSStella Laurenzo 
83429919e3SPeiming Liu enum MlirSparseTensorLevelFormat
84429919e3SPeiming Liu mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) {
85429919e3SPeiming Liu   LevelType lt =
86429919e3SPeiming Liu       static_cast<LevelType>(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl));
87aaf91645SPeiming Liu   return static_cast<MlirSparseTensorLevelFormat>(lt.getLvlFmt());
88429919e3SPeiming Liu }
89429919e3SPeiming Liu 
9084cd51bbSwren romano int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) {
915550c821STres Popp   return cast<SparseTensorEncodingAttr>(unwrap(attr)).getPosWidth();
92bcfa7baeSStella Laurenzo }
93bcfa7baeSStella Laurenzo 
9484cd51bbSwren romano int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) {
955550c821STres Popp   return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth();
96bcfa7baeSStella Laurenzo }
972a6b521bSYinying Li 
98a10d67f9SYinying Li MlirAttribute mlirSparseTensorEncodingAttrGetExplicitVal(MlirAttribute attr) {
99a10d67f9SYinying Li   return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getExplicitVal());
100a10d67f9SYinying Li }
101a10d67f9SYinying Li 
102a10d67f9SYinying Li MlirAttribute mlirSparseTensorEncodingAttrGetImplicitVal(MlirAttribute attr) {
103a10d67f9SYinying Li   return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getImplicitVal());
104a10d67f9SYinying Li }
105a10d67f9SYinying Li 
106429919e3SPeiming Liu MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType(
107429919e3SPeiming Liu     enum MlirSparseTensorLevelFormat lvlFmt,
108429919e3SPeiming Liu     const enum MlirSparseTensorLevelPropertyNondefault *properties,
109429919e3SPeiming Liu     unsigned size, unsigned n, unsigned m) {
110429919e3SPeiming Liu 
111aaf91645SPeiming Liu   std::vector<LevelPropNonDefault> props;
11236c57532SAdrian Kuegel   props.reserve(size);
113429919e3SPeiming Liu   for (unsigned i = 0; i < size; i++)
114aaf91645SPeiming Liu     props.push_back(static_cast<LevelPropNonDefault>(properties[i]));
115429919e3SPeiming Liu 
116429919e3SPeiming Liu   return static_cast<MlirSparseTensorLevelType>(
117429919e3SPeiming Liu       *buildLevelType(static_cast<LevelFormat>(lvlFmt), props, n, m));
1182a6b521bSYinying Li }
1192a6b521bSYinying Li 
1202a6b521bSYinying Li unsigned
1212a6b521bSYinying Li mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType) {
1222a6b521bSYinying Li   return getN(static_cast<LevelType>(lvlType));
1232a6b521bSYinying Li }
1242a6b521bSYinying Li 
1252a6b521bSYinying Li unsigned
1262a6b521bSYinying Li mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType) {
1272a6b521bSYinying Li   return getM(static_cast<LevelType>(lvlType));
1282a6b521bSYinying Li }
129