xref: /llvm-project/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h (revision b3fbb67379a4e67d54d7693e88c05697d01a9a5f)
1 //===- COO.h - Coordinate-scheme sparse tensor representation ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a coordinate-scheme representation of sparse tensors.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_COO_H
14 #define MLIR_EXECUTIONENGINE_SPARSETENSOR_COO_H
15 
16 #include <algorithm>
17 #include <cassert>
18 #include <cinttypes>
19 #include <functional>
20 #include <vector>
21 
22 namespace mlir {
23 namespace sparse_tensor {
24 
/// A single stored entry of a coordinate-scheme sparse tensor: the
/// coordinates of the entry paired with the value found there. For
/// instance, an entry of a rank-1 vector reads
///   ({i}, a[i])
/// whereas an entry of a rank-5 tensor reads
///   ({i,j,k,l,m}, a[i,j,k,l,m])
///
/// The coordinates are not stored inline. Instead, `coords` is a
/// non-owning pointer into a pool of coordinates shared with the other
/// elements, which keeps every element small and leaves the memory
/// management of the coordinates to whoever owns the pool. The price is
/// that an element does not know how many coordinates it has: the rank of
/// the tensor it belongs to is not recorded here and must be supplied by
/// the caller in order to read the coordinates back.
template <typename V>
struct Element final {
  /// Creates an element that refers to the coordinates starting at
  /// `coordsPtr` and holds a copy of `val`. Only the pointer is stored;
  /// the coordinates themselves are not copied.
  Element(const uint64_t *coordsPtr, V val) : coords(coordsPtr), value(val) {}

  /// Non-owning pointer to this element's first coordinate in the pool.
  const uint64_t *coords;
  /// The value stored at those coordinates.
  V value;
};
45 
46 /// Closure object for `operator<` on `Element` with a given rank.
47 template <typename V>
48 struct ElementLT final {
ElementLTfinal49   ElementLT(uint64_t rank) : rank(rank) {}
operatorfinal50   bool operator()(const Element<V> &e1, const Element<V> &e2) const {
51     for (uint64_t d = 0; d < rank; ++d) {
52       if (e1.coords[d] == e2.coords[d])
53         continue;
54       return e1.coords[d] < e2.coords[d];
55     }
56     return false;
57   }
58   const uint64_t rank;
59 };
60 
61 /// A memory-resident sparse tensor in coordinate-scheme representation
62 /// (a collection of `Element`s). This data structure is used as an
63 /// intermediate representation, e.g., for reading sparse tensors from
64 /// external formats into memory.
65 template <typename V>
66 class SparseTensorCOO final {
67 public:
68   /// Constructs a new coordinate-scheme sparse tensor with the given
69   /// sizes and an optional initial storage capacity.
70   explicit SparseTensorCOO(const std::vector<uint64_t> &dimSizes,
71                            uint64_t capacity = 0)
72       : SparseTensorCOO(dimSizes.size(), dimSizes.data(), capacity) {}
73 
74   /// Constructs a new coordinate-scheme sparse tensor with the given
75   /// sizes and an optional initial storage capacity. The size of the
76   /// dimSizes array is determined by dimRank.
77   explicit SparseTensorCOO(uint64_t dimRank, const uint64_t *dimSizes,
78                            uint64_t capacity = 0)
79       : dimSizes(dimSizes, dimSizes + dimRank), isSorted(true) {
80     assert(dimRank > 0 && "Trivial shape is not supported");
81     for (uint64_t d = 0; d < dimRank; ++d)
82       assert(dimSizes[d] > 0 && "Dimension size zero has trivial storage");
83     if (capacity) {
84       elements.reserve(capacity);
85       coordinates.reserve(capacity * dimRank);
86     }
87   }
88 
89   /// Gets the dimension-rank of the tensor.
getRank()90   uint64_t getRank() const { return dimSizes.size(); }
91 
92   /// Gets the dimension-sizes array.
getDimSizes()93   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
94 
95   /// Gets the elements array.
getElements()96   const std::vector<Element<V>> &getElements() const { return elements; }
97 
98   /// Returns the `operator<` closure object for the COO's element type.
getElementLT()99   ElementLT<V> getElementLT() const { return ElementLT<V>(getRank()); }
100 
101   /// Adds an element to the tensor.
add(const std::vector<uint64_t> & dimCoords,V val)102   void add(const std::vector<uint64_t> &dimCoords, V val) {
103     const uint64_t *base = coordinates.data();
104     const uint64_t size = coordinates.size();
105     const uint64_t dimRank = getRank();
106     assert(dimCoords.size() == dimRank && "Element rank mismatch");
107     for (uint64_t d = 0; d < dimRank; ++d) {
108       assert(dimCoords[d] < dimSizes[d] &&
109              "Coordinate is too large for the dimension");
110       coordinates.push_back(dimCoords[d]);
111     }
112     // This base only changes if `coordinates` was reallocated. In which
113     // case, we need to correct all previous pointers into the vector.
114     // Note that this only happens if we did not set the initial capacity
115     // right, and then only for every internal vector reallocation (which
116     // with the doubling rule should only incur an amortized linear overhead).
117     const uint64_t *const newBase = coordinates.data();
118     if (newBase != base) {
119       for (uint64_t i = 0, n = elements.size(); i < n; ++i)
120         elements[i].coords = newBase + (elements[i].coords - base);
121       base = newBase;
122     }
123     // Add the new element and update the sorted bit.
124     const Element<V> addedElem(base + size, val);
125     if (!elements.empty() && isSorted)
126       isSorted = getElementLT()(elements.back(), addedElem);
127     elements.push_back(addedElem);
128   }
129 
130   /// Sorts elements lexicographically by coordinates. If a coordinate
131   /// is mapped to multiple values, then the relative order of those
132   /// values is unspecified.
sort()133   void sort() {
134     if (isSorted)
135       return;
136     std::sort(elements.begin(), elements.end(), getElementLT());
137     isSorted = true;
138   }
139 
140 private:
141   const std::vector<uint64_t> dimSizes; // per-dimension sizes
142   std::vector<Element<V>> elements;     // all COO elements
143   std::vector<uint64_t> coordinates;    // shared coordinate pool
144   bool isSorted;
145 };
146 
147 } // namespace sparse_tensor
148 } // namespace mlir
149 
150 #endif // MLIR_EXECUTIONENGINE_SPARSETENSOR_COO_H
151