xref: /llvm-project/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp (revision dc4cfdbb8f9f665c1699e6289b6edfbc8d1bb443)
//===- Storage.cpp - TACO-flavored sparse tensor representation -----------===//
20fca5c5fSwren romano //
30fca5c5fSwren romano // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40fca5c5fSwren romano // See https://llvm.org/LICENSE.txt for license information.
50fca5c5fSwren romano // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60fca5c5fSwren romano //
70fca5c5fSwren romano //===----------------------------------------------------------------------===//
80fca5c5fSwren romano //
90fca5c5fSwren romano // This file contains method definitions for `SparseTensorStorageBase`.
100fca5c5fSwren romano // In particular we want to ensure that the default implementations of
110fca5c5fSwren romano // the "partial method specialization" trick aren't inline (since there's
12427f120fSAart Bik // no benefit).
130fca5c5fSwren romano //
140fca5c5fSwren romano //===----------------------------------------------------------------------===//
150fca5c5fSwren romano 
160fca5c5fSwren romano #include "mlir/ExecutionEngine/SparseTensor/Storage.h"
170fca5c5fSwren romano 
180fca5c5fSwren romano using namespace mlir::sparse_tensor;
190fca5c5fSwren romano 
isAllDense(uint64_t lvlRank,const LevelType * lvlTypes)206fb7c2d7SAart Bik static inline bool isAllDense(uint64_t lvlRank, const LevelType *lvlTypes) {
216fb7c2d7SAart Bik   for (uint64_t l = 0; l < lvlRank; l++)
226fb7c2d7SAart Bik     if (!isDenseLT(lvlTypes[l]))
236fb7c2d7SAart Bik       return false;
246fb7c2d7SAart Bik   return true;
256fb7c2d7SAart Bik }
266fb7c2d7SAart Bik 
SparseTensorStorageBase(uint64_t dimRank,const uint64_t * dimSizes,uint64_t lvlRank,const uint64_t * lvlSizes,const LevelType * lvlTypes,const uint64_t * dim2lvl,const uint64_t * lvl2dim)27c518745bSwren romano SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
28c518745bSwren romano     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
291944c4f7SAart Bik     const uint64_t *lvlSizes, const LevelType *lvlTypes,
30db1d40f3SAart Bik     const uint64_t *dim2lvl, const uint64_t *lvl2dim)
31c518745bSwren romano     : dimSizes(dimSizes, dimSizes + dimRank),
32c518745bSwren romano       lvlSizes(lvlSizes, lvlSizes + lvlRank),
33c518745bSwren romano       lvlTypes(lvlTypes, lvlTypes + lvlRank),
34306f4c30SAart Bik       dim2lvlVec(dim2lvl, dim2lvl + lvlRank),
35306f4c30SAart Bik       lvl2dimVec(lvl2dim, lvl2dim + dimRank),
366fb7c2d7SAart Bik       map(dimRank, lvlRank, dim2lvlVec.data(), lvl2dimVec.data()),
376fb7c2d7SAart Bik       allDense(isAllDense(lvlRank, lvlTypes)) {
38db1d40f3SAart Bik   assert(dimSizes && lvlSizes && lvlTypes && dim2lvl && lvl2dim);
39c518745bSwren romano   // Validate dim-indexed parameters.
40c518745bSwren romano   assert(dimRank > 0 && "Trivial shape is unsupported");
416fb7c2d7SAart Bik   for (uint64_t d = 0; d < dimRank; d++)
42c518745bSwren romano     assert(dimSizes[d] > 0 && "Dimension size zero has trivial storage");
43427f120fSAart Bik   // Validate lvl-indexed parameters.
44c518745bSwren romano   assert(lvlRank > 0 && "Trivial shape is unsupported");
456fb7c2d7SAart Bik   for (uint64_t l = 0; l < lvlRank; l++) {
46c518745bSwren romano     assert(lvlSizes[l] > 0 && "Level size zero has trivial storage");
47e6005d5aSAart Bik     assert(isDenseLvl(l) || isCompressedLvl(l) || isLooseCompressedLvl(l) ||
48e5924d64SYinying Li            isSingletonLvl(l) || isNOutOfMLvl(l));
490fca5c5fSwren romano   }
50427f120fSAart Bik }
510fca5c5fSwren romano 
// Helper macro for wrong "partial method specialization" errors: prints a
// one-line diagnostic naming the offending method to stderr and terminates
// the process with exit status 1.
//
// `NAME` must be a string expression; the callers in this file pass
// string-literal concatenations such as `"getPositions" "64"`. It is printed
// through `%s` rather than stringized with `#`, because stringizing a string
// literal would echo its quotes (`"getPositions" "64"`) into the message.
// The do/while(0) wrapper makes the expansion a single statement, so
// `FATAL_PIV(...);` is safe as the body of an unbraced `if`/`else`.
#define FATAL_PIV(NAME)                                                        \
  do {                                                                         \
    fprintf(stderr, "<P,I,V> type mismatch for: %s\n", NAME);                  \
    exit(1);                                                                   \
  } while (0)
560fca5c5fSwren romano 
5784cd51bbSwren romano #define IMPL_GETPOSITIONS(PNAME, P)                                            \
5884cd51bbSwren romano   void SparseTensorStorageBase::getPositions(std::vector<P> **, uint64_t) {    \
5984cd51bbSwren romano     FATAL_PIV("getPositions" #PNAME);                                          \
600fca5c5fSwren romano   }
6184cd51bbSwren romano MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETPOSITIONS)
6284cd51bbSwren romano #undef IMPL_GETPOSITIONS
630fca5c5fSwren romano 
6484cd51bbSwren romano #define IMPL_GETCOORDINATES(CNAME, C)                                          \
6584cd51bbSwren romano   void SparseTensorStorageBase::getCoordinates(std::vector<C> **, uint64_t) {  \
6684cd51bbSwren romano     FATAL_PIV("getCoordinates" #CNAME);                                        \
670fca5c5fSwren romano   }
6884cd51bbSwren romano MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATES)
6984cd51bbSwren romano #undef IMPL_GETCOORDINATES
700fca5c5fSwren romano 
71*dc4cfdbbSAart Bik #define IMPL_GETCOORDINATESBUFFER(CNAME, C)                                    \
72*dc4cfdbbSAart Bik   void SparseTensorStorageBase::getCoordinatesBuffer(std::vector<C> **,        \
73*dc4cfdbbSAart Bik                                                      uint64_t) {               \
74*dc4cfdbbSAart Bik     FATAL_PIV("getCoordinatesBuffer" #CNAME);                                  \
75*dc4cfdbbSAart Bik   }
76*dc4cfdbbSAart Bik MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATESBUFFER)
77*dc4cfdbbSAart Bik #undef IMPL_GETCOORDINATESBUFFER
78*dc4cfdbbSAart Bik 
790fca5c5fSwren romano #define IMPL_GETVALUES(VNAME, V)                                               \
800fca5c5fSwren romano   void SparseTensorStorageBase::getValues(std::vector<V> **) {                 \
810fca5c5fSwren romano     FATAL_PIV("getValues" #VNAME);                                             \
820fca5c5fSwren romano   }
830011c0a1Swren romano MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETVALUES)
840fca5c5fSwren romano #undef IMPL_GETVALUES
850fca5c5fSwren romano 
860fca5c5fSwren romano #define IMPL_LEXINSERT(VNAME, V)                                               \
870fca5c5fSwren romano   void SparseTensorStorageBase::lexInsert(const uint64_t *, V) {               \
880fca5c5fSwren romano     FATAL_PIV("lexInsert" #VNAME);                                             \
890fca5c5fSwren romano   }
900011c0a1Swren romano MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
910fca5c5fSwren romano #undef IMPL_LEXINSERT
920fca5c5fSwren romano 
930fca5c5fSwren romano #define IMPL_EXPINSERT(VNAME, V)                                               \
940fca5c5fSwren romano   void SparseTensorStorageBase::expInsert(uint64_t *, V *, bool *, uint64_t *, \
95ab6334ddSAart Bik                                           uint64_t, uint64_t) {                \
960fca5c5fSwren romano     FATAL_PIV("expInsert" #VNAME);                                             \
970fca5c5fSwren romano   }
980011c0a1Swren romano MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)
990fca5c5fSwren romano #undef IMPL_EXPINSERT
1000fca5c5fSwren romano 
1010fca5c5fSwren romano #undef FATAL_PIV
102