//===- Storage.h - TACO-flavored sparse tensor representation ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions for the following classes:
//
// * `SparseTensorStorageBase`
// * `SparseTensorStorage<P, C, V>`
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_STORAGE_H
#define MLIR_EXECUTIONENGINE_SPARSETENSOR_STORAGE_H

#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/ExecutionEngine/Float16bits.h"
#include "mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h"
#include "mlir/ExecutionEngine/SparseTensor/COO.h"
#include "mlir/ExecutionEngine/SparseTensor/MapRef.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
//
//  SparseTensorStorage Classes
//
//===----------------------------------------------------------------------===//

/// Abstract base class for `SparseTensorStorage<P,C,V>`. This class
/// takes responsibility for all the `<P,C,V>`-independent aspects
/// of the tensor (e.g., sizes, sparsity, mapping). In addition,
/// we use function overloading to implement "partial" method
/// specialization, which the C-API relies on to catch type errors
/// arising from our use of opaque pointers.
///
/// Because this class forms a bridge between the denotational semantics
/// of "tensors" and the operational semantics of how we store and
/// compute with them, it also distinguishes between two different
/// coordinate spaces (and their associated rank, sizes, etc).
/// Denotationally, we have the *dimensions* of the tensor represented
/// by this object.  Operationally, we have the *levels* of the storage
/// representation itself.
///
/// The *size* of an axis is the cardinality of possible coordinate
/// values along that axis (regardless of which coordinates have stored
/// element values). As such, each size must be non-zero: if any axis
/// had size zero, then the whole tensor would have trivial storage
/// (since there are no possible coordinates). Thus we use the plural
/// term *sizes* for a collection of non-zero cardinalities, and use
/// this term whenever referring to run-time cardinalities. In contrast,
/// we use the term *shape* for a collection of compile-time cardinalities,
/// where zero indicates a cardinality which is dynamic (i.e.,
/// unknown/unspecified at compile-time). At run-time, these dynamic
/// cardinalities are inferred from or checked against the sizes otherwise
/// specified. Thus, dynamic cardinalities always have an "immutable but
/// unknown" value; the term "dynamic" should not be taken to indicate
/// run-time mutability.
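///
/// For example (an illustrative sketch, not tied to any particular
/// frontend encoding): a 10x20 matrix stored in CSR format has
///   dimRank = 2, dimSizes = {10, 20}   // the tensor's dimensions
///   lvlRank = 2, lvlSizes = {10, 20}   // dense level 0, compressed level 1
/// while the same matrix under a transposed (CSC-like) mapping keeps
/// dimSizes = {10, 20} but uses lvlSizes = {20, 10}, with dim2lvl and
/// lvl2dim recording the permutation between the two coordinate spaces.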
class SparseTensorStorageBase {
protected:
  SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
  SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;

public:
  /// Constructs a new sparse-tensor storage object with the given encoding.
  SparseTensorStorageBase(uint64_t dimRank, const uint64_t *dimSizes,
                          uint64_t lvlRank, const uint64_t *lvlSizes,
                          const LevelType *lvlTypes, const uint64_t *dim2lvl,
                          const uint64_t *lvl2dim);
  virtual ~SparseTensorStorageBase() = default;

  /// Gets the number of tensor-dimensions.
  uint64_t getDimRank() const { return dimSizes.size(); }

  /// Gets the number of storage-levels.
  uint64_t getLvlRank() const { return lvlSizes.size(); }

  /// Gets the tensor-dimension sizes array.
  const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }

  /// Safely looks up the size of the given tensor-dimension.
  uint64_t getDimSize(uint64_t d) const {
    assert(d < getDimRank());
    return dimSizes[d];
  }

  /// Gets the storage-level sizes array.
  const std::vector<uint64_t> &getLvlSizes() const { return lvlSizes; }

  /// Safely looks up the size of the given storage-level.
  uint64_t getLvlSize(uint64_t l) const {
    assert(l < getLvlRank());
    return lvlSizes[l];
  }

  /// Gets the level-types array.
  const std::vector<LevelType> &getLvlTypes() const { return lvlTypes; }

  /// Safely looks up the type of the given level.
  LevelType getLvlType(uint64_t l) const {
    assert(l < getLvlRank());
    return lvlTypes[l];
  }

  /// Safely checks if the level uses dense storage.
  bool isDenseLvl(uint64_t l) const { return isDenseLT(getLvlType(l)); }

  /// Safely checks if the level uses compressed storage.
  bool isCompressedLvl(uint64_t l) const {
    return isCompressedLT(getLvlType(l));
  }

  /// Safely checks if the level uses loose compressed storage.
  bool isLooseCompressedLvl(uint64_t l) const {
    return isLooseCompressedLT(getLvlType(l));
  }

  /// Safely checks if the level uses singleton storage.
  bool isSingletonLvl(uint64_t l) const { return isSingletonLT(getLvlType(l)); }

  /// Safely checks if the level uses n out of m storage.
  bool isNOutOfMLvl(uint64_t l) const { return isNOutOfMLT(getLvlType(l)); }

  /// Safely checks if the level is ordered.
  bool isOrderedLvl(uint64_t l) const { return isOrderedLT(getLvlType(l)); }

  /// Safely checks if the level is unique.
  bool isUniqueLvl(uint64_t l) const { return isUniqueLT(getLvlType(l)); }

  /// Gets positions-overhead storage for the given level.
#define DECL_GETPOSITIONS(PNAME, P)                                            \
  virtual void getPositions(std::vector<P> **, uint64_t);
  MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DECL_GETPOSITIONS)
#undef DECL_GETPOSITIONS

  /// Gets coordinates-overhead storage for the given level.
#define DECL_GETCOORDINATES(INAME, C)                                          \
  virtual void getCoordinates(std::vector<C> **, uint64_t);
  MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DECL_GETCOORDINATES)
#undef DECL_GETCOORDINATES

  /// Gets coordinates-overhead storage buffer for the given level.
#define DECL_GETCOORDINATESBUFFER(INAME, C)                                    \
  virtual void getCoordinatesBuffer(std::vector<C> **, uint64_t);
  MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DECL_GETCOORDINATESBUFFER)
#undef DECL_GETCOORDINATESBUFFER

  /// Gets primary storage.
#define DECL_GETVALUES(VNAME, V) virtual void getValues(std::vector<V> **);
  MLIR_SPARSETENSOR_FOREVERY_V(DECL_GETVALUES)
#undef DECL_GETVALUES

  /// Element-wise insertion in lexicographic coordinate order. The first
  /// argument is the level-coordinates for the value being inserted.
#define DECL_LEXINSERT(VNAME, V) virtual void lexInsert(const uint64_t *, V);
  MLIR_SPARSETENSOR_FOREVERY_V(DECL_LEXINSERT)
#undef DECL_LEXINSERT

  /// Expanded insertion.  Note that this method resets the
  /// values/filled-switch array back to all-zero/false while only
  /// iterating over the nonzero elements.
#define DECL_EXPINSERT(VNAME, V)                                               \
  virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t,        \
                         uint64_t);
  MLIR_SPARSETENSOR_FOREVERY_V(DECL_EXPINSERT)
#undef DECL_EXPINSERT

  /// Finalizes lexicographic insertions.
  virtual void endLexInsert() = 0;
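  // A minimal usage sketch (not part of the API proper): a client holding a
  // `SparseTensorStorageBase *tensor` with two levels can insert elements
  // one by one, provided the level-coordinates arrive in lexicographic
  // order and the value type matches the tensor's `V`, and must finalize
  // with `endLexInsert`:
  //
  //   uint64_t coords[2];
  //   coords[0] = 0; coords[1] = 2;
  //   tensor->lexInsert(coords, 3.14);
  //   coords[0] = 1; coords[1] = 0;
  //   tensor->lexInsert(coords, 2.72);
  //   tensor->endLexInsert();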

private:
  const std::vector<uint64_t> dimSizes;
  const std::vector<uint64_t> lvlSizes;
  const std::vector<LevelType> lvlTypes;
  const std::vector<uint64_t> dim2lvlVec;
  const std::vector<uint64_t> lvl2dimVec;

protected:
  const MapRef map; // non-owning pointers into dim2lvl/lvl2dim vectors
  const bool allDense;
};

/// A memory-resident sparse tensor using a storage scheme based on
/// per-level sparse/dense annotations. This data structure provides
/// a bufferized form of a sparse tensor type. In contrast to generating
/// setup methods for each differently annotated sparse tensor, this
/// method provides a convenient "one-size-fits-all" solution that simply
/// takes an input tensor and annotations to implement all required setup
/// in a general manner.
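///
/// The template parameters select the overhead and value types used by the
/// storage components below: `P` is the element type of the positions
/// arrays, `C` the element type of the coordinates arrays, and `V` the type
/// of the stored values. For instance (an illustrative choice, not mandated
/// anywhere in this file), a CSR-like matrix of doubles with 32-bit
/// overhead could be instantiated as
/// `SparseTensorStorage<uint32_t, uint32_t, double>`.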
template <typename P, typename C, typename V>
class SparseTensorStorage final : public SparseTensorStorageBase {
  /// Private constructor to share code between the other constructors.
  /// Beware that the object is not necessarily guaranteed to be in a
  /// valid state after this constructor alone; e.g., `isCompressedLvl(l)`
  /// doesn't entail `!(positions[l].empty())`.
  SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
                      uint64_t lvlRank, const uint64_t *lvlSizes,
                      const LevelType *lvlTypes, const uint64_t *dim2lvl,
                      const uint64_t *lvl2dim)
      : SparseTensorStorageBase(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
                                dim2lvl, lvl2dim),
        positions(lvlRank), coordinates(lvlRank), lvlCursor(lvlRank) {}

public:
  /// Constructs a sparse tensor with the given encoding, and allocates
  /// overhead storage according to some simple heuristics. When `lvlCOO`
  /// is set, the sparse tensor is initialized with the contents of that
  /// data structure. Otherwise, an empty sparse tensor results.
  SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
                      uint64_t lvlRank, const uint64_t *lvlSizes,
                      const LevelType *lvlTypes, const uint64_t *dim2lvl,
                      const uint64_t *lvl2dim, SparseTensorCOO<V> *lvlCOO);

  /// Constructs a sparse tensor with the given encoding, and initializes
  /// the contents from the level buffers. The constructor assumes that the
  /// data provided by `lvlBufs` can be used directly to interpret the
  /// resulting sparse tensor and performs no integrity test on the input data.
  SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
                      uint64_t lvlRank, const uint64_t *lvlSizes,
                      const LevelType *lvlTypes, const uint64_t *dim2lvl,
                      const uint64_t *lvl2dim, const intptr_t *lvlBufs);

  /// Allocates a new empty sparse tensor.
  static SparseTensorStorage<P, C, V> *
  newEmpty(uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
           const uint64_t *lvlSizes, const LevelType *lvlTypes,
           const uint64_t *dim2lvl, const uint64_t *lvl2dim);

  /// Allocates a new sparse tensor and initializes it from the given COO.
  static SparseTensorStorage<P, C, V> *
  newFromCOO(uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
             const uint64_t *lvlSizes, const LevelType *lvlTypes,
             const uint64_t *dim2lvl, const uint64_t *lvl2dim,
             SparseTensorCOO<V> *lvlCOO);

  /// Allocates a new sparse tensor and initializes it from the given buffers.
  static SparseTensorStorage<P, C, V> *
  newFromBuffers(uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
                 const uint64_t *lvlSizes, const LevelType *lvlTypes,
                 const uint64_t *dim2lvl, const uint64_t *lvl2dim,
                 uint64_t srcRank, const intptr_t *buffers);

  ~SparseTensorStorage() final = default;

  /// Partially specialize these getter methods based on template types.
  void getPositions(std::vector<P> **out, uint64_t lvl) final {
    assert(out && "Received nullptr for out parameter");
    assert(lvl < getLvlRank());
    *out = &positions[lvl];
  }
  void getCoordinates(std::vector<C> **out, uint64_t lvl) final {
    assert(out && "Received nullptr for out parameter");
    assert(lvl < getLvlRank());
    *out = &coordinates[lvl];
  }
  void getCoordinatesBuffer(std::vector<C> **out, uint64_t lvl) final {
    assert(out && "Received nullptr for out parameter");
    assert(lvl < getLvlRank());
    // Note that the sparse tensor support library always stores COO in SoA
    // format, even when AoS is requested. This is never an issue, since all
    // actual code/library generation requests "views" into the coordinate
    // storage for the individual levels, which is trivially provided for
    // both AoS and SoA (as well as all the other storage formats). The only
    // exception is when the buffer version of coordinate storage is requested
    // (currently only for printing). In that case, we do the following
    // potentially expensive transformation to provide that view. If this
    // operation becomes more common beyond debugging, we should consider
    // implementing proper AoS in the support library as well.
    uint64_t lvlRank = getLvlRank();
    uint64_t nnz = values.size();
    crdBuffer.clear();
    crdBuffer.reserve(nnz * (lvlRank - lvl));
    for (uint64_t i = 0; i < nnz; i++) {
      for (uint64_t l = lvl; l < lvlRank; l++) {
        assert(i < coordinates[l].size());
        crdBuffer.push_back(coordinates[l][i]);
      }
    }
    *out = &crdBuffer;
  }
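
  // To illustrate the SoA-to-AoS flattening above (a sketch with made-up
  // data): for a rank-2 COO region holding coordinates (0,2), (1,0), (1,3),
  // the SoA storage is
  //   coordinates[0] = {0, 1, 1},  coordinates[1] = {2, 0, 3}
  // while the AoS view produced into `crdBuffer` (for lvl == 0) interleaves
  // them per element:
  //   crdBuffer = {0, 2,  1, 0,  1, 3}
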
  void getValues(std::vector<V> **out) final {
    assert(out && "Received nullptr for out parameter");
    *out = &values;
  }

  /// Partially specialize lexicographical insertions based on template types.
  void lexInsert(const uint64_t *lvlCoords, V val) final {
    assert(lvlCoords);
    if (allDense) {
      uint64_t lvlRank = getLvlRank();
      uint64_t valIdx = 0;
      // Linearize the address.
      for (uint64_t l = 0; l < lvlRank; l++)
        valIdx = valIdx * getLvlSize(l) + lvlCoords[l];
      values[valIdx] = val;
      return;
    }
    // First, wrap up pending insertion path.
    uint64_t diffLvl = 0;
    uint64_t full = 0;
    if (!values.empty()) {
      diffLvl = lexDiff(lvlCoords);
      endPath(diffLvl + 1);
      full = lvlCursor[diffLvl] + 1;
    }
    // Then continue with insertion path.
    insPath(lvlCoords, diffLvl, full, val);
  }

  /// Partially specialize expanded insertions based on template types.
  void expInsert(uint64_t *lvlCoords, V *values, bool *filled, uint64_t *added,
                 uint64_t count, uint64_t expsz) final {
    assert((lvlCoords && values && filled && added) && "Received nullptr");
    if (count == 0)
      return;
    // Sort.
    std::sort(added, added + count);
    // Restore insertion path for first insert.
    const uint64_t lastLvl = getLvlRank() - 1;
    uint64_t c = added[0];
    assert(c <= expsz);
    assert(filled[c] && "added coordinate is not filled");
    lvlCoords[lastLvl] = c;
    lexInsert(lvlCoords, values[c]);
    values[c] = 0;
    filled[c] = false;
    // Subsequent insertions are quick.
    for (uint64_t i = 1; i < count; i++) {
      assert(c < added[i] && "non-lexicographic insertion");
      c = added[i];
      assert(c <= expsz);
      assert(filled[c] && "added coordinate is not filled");
      lvlCoords[lastLvl] = c;
      insPath(lvlCoords, lastLvl, added[i - 1] + 1, values[c]);
      values[c] = 0;
      filled[c] = false;
    }
  }
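
  // A sketch of the access pattern served by expInsert above (the array
  // names below are the caller's, not defined in this file): the caller
  // keeps a dense scratch `values[0..expsz)` and booleans `filled[0..expsz)`
  // over the last level, plus a worklist `added[0..count)` of the last-level
  // coordinates actually touched. After accumulating into the scratch, a
  // single call
  //   tensor->expInsert(lvlCoords, values, filled, added, count, expsz);
  // flushes those entries into sparse storage (with `lvlCoords` supplying
  // the coordinates of the enclosing levels) and resets the touched scratch
  // entries back to zero/false.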

  /// Finalizes lexicographic insertions.
  void endLexInsert() final {
    if (!allDense) {
      if (values.empty())
        finalizeSegment(0);
      else
        endPath(0);
    }
  }

  /// Sorts the unordered tensor in place; this method assumes that it is
  /// an unordered COO tensor.
  void sortInPlace() {
    uint64_t nnz = values.size();
#ifndef NDEBUG
    for (uint64_t l = 0; l < getLvlRank(); l++)
      assert(nnz == coordinates[l].size());
#endif

    // In-place permutation.
    auto applyPerm = [this](std::vector<uint64_t> &perm) {
      uint64_t length = perm.size();
      uint64_t lvlRank = getLvlRank();
      // Cache for the current level coordinates.
      std::vector<P> lvlCrds(lvlRank);
      for (uint64_t i = 0; i < length; i++) {
        uint64_t current = i;
        if (i != perm[current]) {
          for (uint64_t l = 0; l < lvlRank; l++)
            lvlCrds[l] = coordinates[l][i];
          V val = values[i];
          // Deals with a permutation cycle.
          while (i != perm[current]) {
            uint64_t next = perm[current];
            // Swaps the level coordinates and value.
            for (uint64_t l = 0; l < lvlRank; l++)
              coordinates[l][current] = coordinates[l][next];
            values[current] = values[next];
            perm[current] = current;
            current = next;
          }
          for (uint64_t l = 0; l < lvlRank; l++)
            coordinates[l][current] = lvlCrds[l];
          values[current] = val;
          perm[current] = current;
        }
      }
    };

    std::vector<uint64_t> sortedIdx(nnz, 0);
    for (uint64_t i = 0; i < nnz; i++)
      sortedIdx[i] = i;

    std::sort(sortedIdx.begin(), sortedIdx.end(),
              [this](uint64_t lhs, uint64_t rhs) {
                for (uint64_t l = 0; l < getLvlRank(); l++) {
                  if (coordinates[l][lhs] == coordinates[l][rhs])
                    continue;
                  return coordinates[l][lhs] < coordinates[l][rhs];
                }
                assert(lhs == rhs && "duplicate coordinates");
                return false;
              });

    applyPerm(sortedIdx);
  }
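
  // To illustrate applyPerm above (a sketch with made-up data): sorting the
  // unordered 1-D COO {(2,c), (0,a), (1,b)} yields sortedIdx = {1, 2, 0};
  // applyPerm then follows the permutation cycle 0 -> 1 -> 2 -> 0, moving
  // each coordinate/value tuple directly to its final slot, so the storage
  // ends up as {(0,a), (1,b), (2,c)} without materializing a second copy of
  // the coordinate and value arrays.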

private:
  /// Appends coordinate `crd` to level `lvl`, in the semantically
  /// general sense.  For non-dense levels, that means appending to the
  /// `coordinates[lvl]` array, checking that `crd` is representable in
  /// the `C` type; however, we do not verify other semantic requirements
  /// (e.g., that `crd` is in bounds for `lvlSizes[lvl]`, and not previously
  /// occurring in the same segment).  For dense levels, this method instead
  /// appends the appropriate number of zeros to the `values` array, where
  /// `full` is the number of "entries" already written to `values` for this
  /// segment (aka one after the highest coordinate previously appended).
  void appendCrd(uint64_t lvl, uint64_t full, uint64_t crd) {
    if (!isDenseLvl(lvl)) {
      assert(isCompressedLvl(lvl) || isLooseCompressedLvl(lvl) ||
             isSingletonLvl(lvl) || isNOutOfMLvl(lvl));
      coordinates[lvl].push_back(detail::checkOverflowCast<C>(crd));
    } else { // Dense level.
      assert(crd >= full && "Coordinate was already filled");
      if (crd == full)
        return; // Short-circuit, since it'll be a nop.
      if (lvl + 1 == getLvlRank())
        values.insert(values.end(), crd - full, 0);
      else
        finalizeSegment(lvl + 1, 0, crd - full);
    }
  }

  /// Computes the assembled-size associated with the `l`-th level,
  /// given the assembled-size associated with the `(l-1)`-th level.
  uint64_t assembledSize(uint64_t parentSz, uint64_t l) const {
    if (isCompressedLvl(l))
      return positions[l][parentSz];
    if (isLooseCompressedLvl(l))
      return positions[l][2 * parentSz - 1];
    if (isSingletonLvl(l) || isNOutOfMLvl(l))
      return parentSz; // new size same as the parent
    assert(isDenseLvl(l));
    return parentSz * getLvlSize(l);
  }
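
  // A worked example of the size bookkeeping above (illustrative numbers
  // only): for a 4x5 CSR-like matrix with 7 stored entries, level 0 is
  // dense and level 1 is compressed, so
  //   assembledSize(1, 0) == 1 * 4 == 4            // dense: parent * lvlSize
  //   assembledSize(4, 1) == positions[1][4] == 7  // compressed: last position
  // i.e. the assembled size of the last level equals the number of stored
  // values.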

  /// Initializes sparse tensor storage scheme from a memory-resident sparse
  /// tensor in coordinate scheme. This method prepares the positions and
  /// coordinates arrays under the given per-level dense/sparse annotations.
  void fromCOO(const std::vector<Element<V>> &lvlElements, uint64_t lo,
               uint64_t hi, uint64_t l) {
    const uint64_t lvlRank = getLvlRank();
    assert(l <= lvlRank && hi <= lvlElements.size());
    // Once levels are exhausted, insert the numerical values.
    if (l == lvlRank) {
      assert(lo < hi);
      values.push_back(lvlElements[lo].value);
      return;
    }
    // Visit all elements in this interval.
    uint64_t full = 0;
    while (lo < hi) { // If `hi` is unchanged, then `lo < lvlElements.size()`.
      // Find segment in interval with same coordinate at this level.
      const uint64_t c = lvlElements[lo].coords[l];
      uint64_t seg = lo + 1;
      if (isUniqueLvl(l))
        while (seg < hi && lvlElements[seg].coords[l] == c)
          seg++;
      // Handle segment in interval for sparse or dense level.
      appendCrd(l, full, c);
      full = c + 1;
      fromCOO(lvlElements, lo, seg, l + 1);
      // And move on to next segment in interval.
      lo = seg;
    }
    // Finalize the sparse position structure at this level.
    finalizeSegment(l, full);
  }
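
  // A worked example of the recursion above (illustrative data): building a
  // 3x4 CSR-like matrix (dense level 0, compressed level 1) from the sorted
  // COO elements {(0,1)=a, (0,3)=b, (2,2)=c} proceeds one level-0 segment
  // at a time and yields
  //   positions[1]   = {0, 2, 2, 3}
  //   coordinates[1] = {1, 3, 2}
  //   values         = {a, b, c}
  // where the empty row 1 is finalized by repeating the running position.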

  /// Finalizes the sparse position structure at this level.
  void finalizeSegment(uint64_t l, uint64_t full = 0, uint64_t count = 1) {
    if (count == 0)
      return; // Short-circuit, since it'll be a nop.
    if (isCompressedLvl(l)) {
      uint64_t pos = coordinates[l].size();
      positions[l].insert(positions[l].end(), count,
                          detail::checkOverflowCast<P>(pos));
    } else if (isLooseCompressedLvl(l)) {
      // Finish this level, push pairs for the empty ones, and add one
      // more for the next level. Note that this always leaves one extra
      // unused element at the end.
      uint64_t pos = coordinates[l].size();
      positions[l].insert(positions[l].end(), 2 * count,
                          detail::checkOverflowCast<P>(pos));
    } else if (isSingletonLvl(l) || isNOutOfMLvl(l)) {
      return; // Nothing to finalize.
    } else {  // Dense dimension.
      assert(isDenseLvl(l));
      const uint64_t sz = getLvlSizes()[l];
      assert(sz >= full && "Segment is overfull");
      count = detail::checkedMul(count, sz - full);
      // For dense storage we must enumerate all the remaining coordinates
      // in this level (i.e., coordinates after the last non-zero
      // element), and either fill in their zero values or else recurse
      // to finalize some deeper level.
      if (l + 1 == getLvlRank())
        values.insert(values.end(), count, 0);
      else
        finalizeSegment(l + 1, 0, count);
    }
  }

  /// Wraps up a single insertion path, inner to outer.
  void endPath(uint64_t diffLvl) {
    const uint64_t lvlRank = getLvlRank();
    const uint64_t lastLvl = lvlRank - 1;
    assert(diffLvl <= lvlRank);
    const uint64_t stop = lvlRank - diffLvl;
    for (uint64_t i = 0; i < stop; i++) {
      const uint64_t l = lastLvl - i;
      finalizeSegment(l, lvlCursor[l] + 1);
    }
  }

  /// Continues a single insertion path, outer to inner. The first
  /// argument is the level-coordinates for the value being inserted.
  void insPath(const uint64_t *lvlCoords, uint64_t diffLvl, uint64_t full,
               V val) {
    const uint64_t lvlRank = getLvlRank();
    assert(diffLvl <= lvlRank);
    for (uint64_t l = diffLvl; l < lvlRank; l++) {
      const uint64_t c = lvlCoords[l];
      appendCrd(l, full, c);
      full = 0;
      lvlCursor[l] = c;
    }
    values.push_back(val);
  }

  /// Finds the lexicographically first level where the level-coordinates
  /// in the argument differ from those in the current cursor.
  uint64_t lexDiff(const uint64_t *lvlCoords) const {
    const uint64_t lvlRank = getLvlRank();
    for (uint64_t l = 0; l < lvlRank; l++) {
      const auto crd = lvlCoords[l];
      const auto cur = lvlCursor[l];
      if (crd > cur || (crd == cur && !isUniqueLvl(l)) ||
          (crd < cur && !isOrderedLvl(l))) {
        return l;
      }
      if (crd < cur) {
        assert(false && "non-lexicographic insertion");
        return -1u;
      }
    }
    assert(false && "duplicate insertion");
    return -1u;
  }

  // Sparse tensor storage components.
  std::vector<std::vector<P>> positions;
  std::vector<std::vector<C>> coordinates;
  std::vector<V> values;

  // Auxiliary data structures.
  std::vector<uint64_t> lvlCursor;
  std::vector<C> crdBuffer; // just for AoS view
};

//===----------------------------------------------------------------------===//
//
//  SparseTensorStorage Factories
//
//===----------------------------------------------------------------------===//

template <typename P, typename C, typename V>
SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newEmpty(
    uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
    const uint64_t *lvlSizes, const LevelType *lvlTypes,
    const uint64_t *dim2lvl, const uint64_t *lvl2dim) {
  SparseTensorCOO<V> *noLvlCOO = nullptr;
  return new SparseTensorStorage<P, C, V>(dimRank, dimSizes, lvlRank, lvlSizes,
                                          lvlTypes, dim2lvl, lvl2dim, noLvlCOO);
}

template <typename P, typename C, typename V>
SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromCOO(
    uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
    const uint64_t *lvlSizes, const LevelType *lvlTypes,
    const uint64_t *dim2lvl, const uint64_t *lvl2dim,
    SparseTensorCOO<V> *lvlCOO) {
  assert(lvlCOO);
  return new SparseTensorStorage<P, C, V>(dimRank, dimSizes, lvlRank, lvlSizes,
                                          lvlTypes, dim2lvl, lvl2dim, lvlCOO);
}

template <typename P, typename C, typename V>
SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromBuffers(
    uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
    const uint64_t *lvlSizes, const LevelType *lvlTypes,
    const uint64_t *dim2lvl, const uint64_t *lvl2dim, uint64_t srcRank,
    const intptr_t *buffers) {
  return new SparseTensorStorage<P, C, V>(dimRank, dimSizes, lvlRank, lvlSizes,
                                          lvlTypes, dim2lvl, lvl2dim, buffers);
}
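
// A minimal usage sketch for the factories above (illustrative only; the
// identity mappings and the concrete `LevelType` values for a dense and a
// compressed level are assumptions supplied by the caller, typically via
// generated code or the helpers in Enums.h):
//
//   const uint64_t dimSizes[2] = {3, 4};
//   const uint64_t lvlSizes[2] = {3, 4};
//   const uint64_t identity[2] = {0, 1};
//   const LevelType lvlTypes[2] = {/*dense*/ ..., /*compressed*/ ...};
//   SparseTensorCOO<double> *coo = ...; // rank-2 COO holding the elements
//   auto *csr = SparseTensorStorage<uint64_t, uint64_t, double>::newFromCOO(
//       2, dimSizes, 2, lvlSizes, lvlTypes, identity, identity, coo);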

//===----------------------------------------------------------------------===//
//
//  SparseTensorStorage Constructors
//
//===----------------------------------------------------------------------===//

template <typename P, typename C, typename V>
SparseTensorStorage<P, C, V>::SparseTensorStorage(
    uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
    const uint64_t *lvlSizes, const LevelType *lvlTypes,
    const uint64_t *dim2lvl, const uint64_t *lvl2dim,
    SparseTensorCOO<V> *lvlCOO)
    : SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
                          dim2lvl, lvl2dim) {
  // Provide hints on capacity of positions and coordinates.
  // TODO: needs much fine-tuning based on actual sparsity; currently
  // we reserve position/coordinate space based on all previous dense
  // levels, which works well up to the first sparse level, but we should
  // really use nnz and the dense/sparse distribution.
  uint64_t sz = 1;
  for (uint64_t l = 0; l < lvlRank; l++) {
    if (isCompressedLvl(l)) {
      positions[l].reserve(sz + 1);
      positions[l].push_back(0);
      coordinates[l].reserve(sz);
      sz = 1;
    } else if (isLooseCompressedLvl(l)) {
      positions[l].reserve(2 * sz + 1); // last one unused
      positions[l].push_back(0);
      coordinates[l].reserve(sz);
      sz = 1;
    } else if (isSingletonLvl(l)) {
      coordinates[l].reserve(sz);
      sz = 1;
    } else if (isNOutOfMLvl(l)) {
      assert(l == lvlRank - 1 && "unexpected n:m usage");
      sz = detail::checkedMul(sz, lvlSizes[l]) / 2;
      coordinates[l].reserve(sz);
      values.reserve(sz);
    } else { // Dense level.
      assert(isDenseLvl(l));
      sz = detail::checkedMul(sz, lvlSizes[l]);
    }
  }
  if (lvlCOO) {
    /* New from COO: ensure it is sorted. */
    assert(lvlCOO->getRank() == lvlRank);
    lvlCOO->sort();
    // Now actually insert the `elements`.
    const auto &elements = lvlCOO->getElements();
    const uint64_t nse = elements.size();
    assert(values.size() == 0);
    values.reserve(nse);
    fromCOO(elements, 0, nse, 0);
  } else if (allDense) {
    /* New empty (all dense) */
    values.resize(sz, 0);
  }
}

template <typename P, typename C, typename V>
SparseTensorStorage<P, C, V>::SparseTensorStorage(
    uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
    const uint64_t *lvlSizes, const LevelType *lvlTypes,
    const uint64_t *dim2lvl, const uint64_t *lvl2dim, const intptr_t *lvlBufs)
    : SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
                          dim2lvl, lvl2dim) {
  // Note that none of the buffers can be reused because ownership
  // of the memory passed from clients is not necessarily transferred.
  // Therefore, all data is copied over into a new SparseTensorStorage.
  uint64_t trailCOOLen = 0, parentSz = 1, bufIdx = 0;
  for (uint64_t l = 0; l < lvlRank; l++) {
    if (!isUniqueLvl(l) && (isCompressedLvl(l) || isLooseCompressedLvl(l))) {
      // A `(loose)compressed_nu` level marks the start of the trailing
      // COO region. Since the coordinate buffer used for the trailing COO
      // is passed in as an AoS scheme and SparseTensorStorage uses an SoA
      // scheme, we cannot simply copy the coordinates from the provided
      // buffers.
      trailCOOLen = lvlRank - l;
      break;
    }
    if (isCompressedLvl(l) || isLooseCompressedLvl(l)) {
      P *posPtr = reinterpret_cast<P *>(lvlBufs[bufIdx++]);
      C *crdPtr = reinterpret_cast<C *>(lvlBufs[bufIdx++]);
      if (isLooseCompressedLvl(l)) {
        positions[l].assign(posPtr, posPtr + 2 * parentSz);
        coordinates[l].assign(crdPtr, crdPtr + positions[l][2 * parentSz - 1]);
      } else {
        positions[l].assign(posPtr, posPtr + parentSz + 1);
        coordinates[l].assign(crdPtr, crdPtr + positions[l][parentSz]);
      }
    } else if (isSingletonLvl(l)) {
      assert(0 && "general singleton not supported yet");
    } else if (isNOutOfMLvl(l)) {
703       assert(0 && "n ouf of m not supported yet");
    } else {
      assert(isDenseLvl(l));
    }
    parentSz = assembledSize(parentSz, l);
  }

  // Handle AoS vs. SoA mismatch for COO.
  if (trailCOOLen != 0) {
    uint64_t cooStartLvl = lvlRank - trailCOOLen;
    assert(!isUniqueLvl(cooStartLvl) &&
           (isCompressedLvl(cooStartLvl) || isLooseCompressedLvl(cooStartLvl)));
    P *posPtr = reinterpret_cast<P *>(lvlBufs[bufIdx++]);
    C *aosCrdPtr = reinterpret_cast<C *>(lvlBufs[bufIdx++]);
    P crdLen;
    if (isLooseCompressedLvl(cooStartLvl)) {
      positions[cooStartLvl].assign(posPtr, posPtr + 2 * parentSz);
      crdLen = positions[cooStartLvl][2 * parentSz - 1];
    } else {
      positions[cooStartLvl].assign(posPtr, posPtr + parentSz + 1);
      crdLen = positions[cooStartLvl][parentSz];
    }
    for (uint64_t l = cooStartLvl; l < lvlRank; l++) {
      coordinates[l].resize(crdLen);
      for (uint64_t n = 0; n < crdLen; n++) {
        coordinates[l][n] = *(aosCrdPtr + (l - cooStartLvl) + n * trailCOOLen);
      }
    }
    parentSz = assembledSize(parentSz, cooStartLvl);
  }

  // Copy the values buffer.
  V *valPtr = reinterpret_cast<V *>(lvlBufs[bufIdx]);
  values.assign(valPtr, valPtr + parentSz);
}
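
// To illustrate the AoS-to-SoA deinterleaving in the loop above (made-up
// data): a trailing COO region spanning the last two levels whose AoS
// coordinate buffer holds {0,2, 1,0, 1,3} (three elements, two levels each)
// is split into the SoA form used internally:
//   coordinates[cooStartLvl]     = {0, 1, 1}
//   coordinates[cooStartLvl + 1] = {2, 0, 3}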

} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_EXECUTIONENGINE_SPARSETENSOR_STORAGE_H