xref: /llvm-project/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp (revision dc4cfdbb8f9f665c1699e6289b6edfbc8d1bb443)
//===- Storage.cpp - TACO-flavored sparse tensor representation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains method definitions for `SparseTensorStorageBase`.
10 // In particular we want to ensure that the default implementations of
11 // the "partial method specialization" trick aren't inline (since there's
12 // no benefit).
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "mlir/ExecutionEngine/SparseTensor/Storage.h"
17 
18 using namespace mlir::sparse_tensor;
19 
isAllDense(uint64_t lvlRank,const LevelType * lvlTypes)20 static inline bool isAllDense(uint64_t lvlRank, const LevelType *lvlTypes) {
21   for (uint64_t l = 0; l < lvlRank; l++)
22     if (!isDenseLT(lvlTypes[l]))
23       return false;
24   return true;
25 }
26 
SparseTensorStorageBase(uint64_t dimRank,const uint64_t * dimSizes,uint64_t lvlRank,const uint64_t * lvlSizes,const LevelType * lvlTypes,const uint64_t * dim2lvl,const uint64_t * lvl2dim)27 SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
28     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
29     const uint64_t *lvlSizes, const LevelType *lvlTypes,
30     const uint64_t *dim2lvl, const uint64_t *lvl2dim)
31     : dimSizes(dimSizes, dimSizes + dimRank),
32       lvlSizes(lvlSizes, lvlSizes + lvlRank),
33       lvlTypes(lvlTypes, lvlTypes + lvlRank),
34       dim2lvlVec(dim2lvl, dim2lvl + lvlRank),
35       lvl2dimVec(lvl2dim, lvl2dim + dimRank),
36       map(dimRank, lvlRank, dim2lvlVec.data(), lvl2dimVec.data()),
37       allDense(isAllDense(lvlRank, lvlTypes)) {
38   assert(dimSizes && lvlSizes && lvlTypes && dim2lvl && lvl2dim);
39   // Validate dim-indexed parameters.
40   assert(dimRank > 0 && "Trivial shape is unsupported");
41   for (uint64_t d = 0; d < dimRank; d++)
42     assert(dimSizes[d] > 0 && "Dimension size zero has trivial storage");
43   // Validate lvl-indexed parameters.
44   assert(lvlRank > 0 && "Trivial shape is unsupported");
45   for (uint64_t l = 0; l < lvlRank; l++) {
46     assert(lvlSizes[l] > 0 && "Level size zero has trivial storage");
47     assert(isDenseLvl(l) || isCompressedLvl(l) || isLooseCompressedLvl(l) ||
48            isSingletonLvl(l) || isNOutOfMLvl(l));
49   }
50 }
51 
// Helper macro for wrong "partial method specialization" errors: writes a
// "<P,I,V> type mismatch" message naming the stringified NAME argument to
// stderr and terminates the process with exit status 1.
//
// The two calls are wrapped in `do { ... } while (0)` so that the macro
// expands to exactly one statement. That keeps it safe as the body of an
// unbraced `if`/`else`; the call site supplies the trailing semicolon.
#define FATAL_PIV(NAME)                                                        \
  do {                                                                         \
    fprintf(stderr, "<P,I,V> type mismatch for: " #NAME);                      \
    exit(1);                                                                   \
  } while (0)
56 
57 #define IMPL_GETPOSITIONS(PNAME, P)                                            \
58   void SparseTensorStorageBase::getPositions(std::vector<P> **, uint64_t) {    \
59     FATAL_PIV("getPositions" #PNAME);                                          \
60   }
61 MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETPOSITIONS)
62 #undef IMPL_GETPOSITIONS
63 
64 #define IMPL_GETCOORDINATES(CNAME, C)                                          \
65   void SparseTensorStorageBase::getCoordinates(std::vector<C> **, uint64_t) {  \
66     FATAL_PIV("getCoordinates" #CNAME);                                        \
67   }
68 MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATES)
69 #undef IMPL_GETCOORDINATES
70 
71 #define IMPL_GETCOORDINATESBUFFER(CNAME, C)                                    \
72   void SparseTensorStorageBase::getCoordinatesBuffer(std::vector<C> **,        \
73                                                      uint64_t) {               \
74     FATAL_PIV("getCoordinatesBuffer" #CNAME);                                  \
75   }
76 MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATESBUFFER)
77 #undef IMPL_GETCOORDINATESBUFFER
78 
79 #define IMPL_GETVALUES(VNAME, V)                                               \
80   void SparseTensorStorageBase::getValues(std::vector<V> **) {                 \
81     FATAL_PIV("getValues" #VNAME);                                             \
82   }
83 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETVALUES)
84 #undef IMPL_GETVALUES
85 
86 #define IMPL_LEXINSERT(VNAME, V)                                               \
87   void SparseTensorStorageBase::lexInsert(const uint64_t *, V) {               \
88     FATAL_PIV("lexInsert" #VNAME);                                             \
89   }
90 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
91 #undef IMPL_LEXINSERT
92 
93 #define IMPL_EXPINSERT(VNAME, V)                                               \
94   void SparseTensorStorageBase::expInsert(uint64_t *, V *, bool *, uint64_t *, \
95                                           uint64_t, uint64_t) {                \
96     FATAL_PIV("expInsert" #VNAME);                                             \
97   }
98 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)
99 #undef IMPL_EXPINSERT
100 
101 #undef FATAL_PIV
102