//===- StorageBase.cpp - TACO-flavored sparse tensor representation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains method definitions for `SparseTensorStorageBase`.
// In particular we want to ensure that the default implementations of
// the "partial method specialization" trick aren't inline (since there's
// no benefit).
//
//===----------------------------------------------------------------------===//
150fca5c5fSwren romano
160fca5c5fSwren romano #include "mlir/ExecutionEngine/SparseTensor/Storage.h"
170fca5c5fSwren romano
180fca5c5fSwren romano using namespace mlir::sparse_tensor;
190fca5c5fSwren romano
isAllDense(uint64_t lvlRank,const LevelType * lvlTypes)206fb7c2d7SAart Bik static inline bool isAllDense(uint64_t lvlRank, const LevelType *lvlTypes) {
216fb7c2d7SAart Bik for (uint64_t l = 0; l < lvlRank; l++)
226fb7c2d7SAart Bik if (!isDenseLT(lvlTypes[l]))
236fb7c2d7SAart Bik return false;
246fb7c2d7SAart Bik return true;
256fb7c2d7SAart Bik }
266fb7c2d7SAart Bik
SparseTensorStorageBase(uint64_t dimRank,const uint64_t * dimSizes,uint64_t lvlRank,const uint64_t * lvlSizes,const LevelType * lvlTypes,const uint64_t * dim2lvl,const uint64_t * lvl2dim)27c518745bSwren romano SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
28c518745bSwren romano uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
291944c4f7SAart Bik const uint64_t *lvlSizes, const LevelType *lvlTypes,
30db1d40f3SAart Bik const uint64_t *dim2lvl, const uint64_t *lvl2dim)
31c518745bSwren romano : dimSizes(dimSizes, dimSizes + dimRank),
32c518745bSwren romano lvlSizes(lvlSizes, lvlSizes + lvlRank),
33c518745bSwren romano lvlTypes(lvlTypes, lvlTypes + lvlRank),
34306f4c30SAart Bik dim2lvlVec(dim2lvl, dim2lvl + lvlRank),
35306f4c30SAart Bik lvl2dimVec(lvl2dim, lvl2dim + dimRank),
366fb7c2d7SAart Bik map(dimRank, lvlRank, dim2lvlVec.data(), lvl2dimVec.data()),
376fb7c2d7SAart Bik allDense(isAllDense(lvlRank, lvlTypes)) {
38db1d40f3SAart Bik assert(dimSizes && lvlSizes && lvlTypes && dim2lvl && lvl2dim);
39c518745bSwren romano // Validate dim-indexed parameters.
40c518745bSwren romano assert(dimRank > 0 && "Trivial shape is unsupported");
416fb7c2d7SAart Bik for (uint64_t d = 0; d < dimRank; d++)
42c518745bSwren romano assert(dimSizes[d] > 0 && "Dimension size zero has trivial storage");
43427f120fSAart Bik // Validate lvl-indexed parameters.
44c518745bSwren romano assert(lvlRank > 0 && "Trivial shape is unsupported");
456fb7c2d7SAart Bik for (uint64_t l = 0; l < lvlRank; l++) {
46c518745bSwren romano assert(lvlSizes[l] > 0 && "Level size zero has trivial storage");
47e6005d5aSAart Bik assert(isDenseLvl(l) || isCompressedLvl(l) || isLooseCompressedLvl(l) ||
48e5924d64SYinying Li isSingletonLvl(l) || isNOutOfMLvl(l));
490fca5c5fSwren romano }
50427f120fSAart Bik }
510fca5c5fSwren romano
// Helper macro for wrong "partial method specialization" errors: writes a
// "<P,I,V> type mismatch" diagnostic containing the stringified NAME to
// stderr, then terminates the process with exit status 1.
//
// The two statements are wrapped in `do { ... } while (0)` so that an
// invocation followed by `;` is exactly one statement. Without the wrapper,
// `exit(1)` would fall outside an unbraced `if` and run unconditionally, and
// an `else` after the invocation would not compile.
#define FATAL_PIV(NAME)                                                        \
  do {                                                                         \
    fprintf(stderr, "<P,I,V> type mismatch for: " #NAME);                      \
    exit(1);                                                                   \
  } while (0)
560fca5c5fSwren romano
5784cd51bbSwren romano #define IMPL_GETPOSITIONS(PNAME, P) \
5884cd51bbSwren romano void SparseTensorStorageBase::getPositions(std::vector<P> **, uint64_t) { \
5984cd51bbSwren romano FATAL_PIV("getPositions" #PNAME); \
600fca5c5fSwren romano }
6184cd51bbSwren romano MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETPOSITIONS)
6284cd51bbSwren romano #undef IMPL_GETPOSITIONS
630fca5c5fSwren romano
6484cd51bbSwren romano #define IMPL_GETCOORDINATES(CNAME, C) \
6584cd51bbSwren romano void SparseTensorStorageBase::getCoordinates(std::vector<C> **, uint64_t) { \
6684cd51bbSwren romano FATAL_PIV("getCoordinates" #CNAME); \
670fca5c5fSwren romano }
6884cd51bbSwren romano MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATES)
6984cd51bbSwren romano #undef IMPL_GETCOORDINATES
700fca5c5fSwren romano
71*dc4cfdbbSAart Bik #define IMPL_GETCOORDINATESBUFFER(CNAME, C) \
72*dc4cfdbbSAart Bik void SparseTensorStorageBase::getCoordinatesBuffer(std::vector<C> **, \
73*dc4cfdbbSAart Bik uint64_t) { \
74*dc4cfdbbSAart Bik FATAL_PIV("getCoordinatesBuffer" #CNAME); \
75*dc4cfdbbSAart Bik }
76*dc4cfdbbSAart Bik MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATESBUFFER)
77*dc4cfdbbSAart Bik #undef IMPL_GETCOORDINATESBUFFER
78*dc4cfdbbSAart Bik
790fca5c5fSwren romano #define IMPL_GETVALUES(VNAME, V) \
800fca5c5fSwren romano void SparseTensorStorageBase::getValues(std::vector<V> **) { \
810fca5c5fSwren romano FATAL_PIV("getValues" #VNAME); \
820fca5c5fSwren romano }
830011c0a1Swren romano MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETVALUES)
840fca5c5fSwren romano #undef IMPL_GETVALUES
850fca5c5fSwren romano
860fca5c5fSwren romano #define IMPL_LEXINSERT(VNAME, V) \
870fca5c5fSwren romano void SparseTensorStorageBase::lexInsert(const uint64_t *, V) { \
880fca5c5fSwren romano FATAL_PIV("lexInsert" #VNAME); \
890fca5c5fSwren romano }
900011c0a1Swren romano MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
910fca5c5fSwren romano #undef IMPL_LEXINSERT
920fca5c5fSwren romano
930fca5c5fSwren romano #define IMPL_EXPINSERT(VNAME, V) \
940fca5c5fSwren romano void SparseTensorStorageBase::expInsert(uint64_t *, V *, bool *, uint64_t *, \
95ab6334ddSAart Bik uint64_t, uint64_t) { \
960fca5c5fSwren romano FATAL_PIV("expInsert" #VNAME); \
970fca5c5fSwren romano }
980011c0a1Swren romano MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)
990fca5c5fSwren romano #undef IMPL_EXPINSERT
1000fca5c5fSwren romano
1010fca5c5fSwren romano #undef FATAL_PIV
102