1 //===- RaggedArray.h - 2D array with different inner lengths ----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Support/LLVM.h" 10 #include "llvm/ADT/STLExtras.h" 11 #include "llvm/ADT/SmallVector.h" 12 #include <iterator> 13 14 namespace mlir { 15 /// A 2D array where each row may have different length. Elements of each row 16 /// are stored contiguously, but rows don't have a fixed order in the storage. 17 template <typename T> 18 class RaggedArray { 19 public: 20 /// Returns the number of rows in the 2D array. size()21 size_t size() const { return slices.size(); } 22 23 /// Returns true if the are no rows in the 2D array. Note that an array with a 24 /// non-zero number of empty rows is *NOT* empty. empty()25 bool empty() const { return slices.empty(); } 26 27 /// Accesses `pos`-th row. 28 ArrayRef<T> operator[](size_t pos) const { return at(pos); } at(size_t pos)29 ArrayRef<T> at(size_t pos) const { 30 if (slices[pos].first == static_cast<size_t>(-1)) 31 return ArrayRef<T>(); 32 return ArrayRef<T>(storage).slice(slices[pos].first, slices[pos].second); 33 } 34 MutableArrayRef<T> operator[](size_t pos) { return at(pos); } at(size_t pos)35 MutableArrayRef<T> at(size_t pos) { 36 if (slices[pos].first == static_cast<size_t>(-1)) 37 return MutableArrayRef<T>(); 38 return MutableArrayRef<T>(storage).slice(slices[pos].first, 39 slices[pos].second); 40 } 41 42 /// Iterator over the rows. 43 class iterator 44 : public llvm::iterator_facade_base< 45 iterator, std::forward_iterator_tag, MutableArrayRef<T>, 46 std::ptrdiff_t, MutableArrayRef<T> *, MutableArrayRef<T>> { 47 public: 48 /// Creates the start iterator. iterator(RaggedArray & ragged)49 explicit iterator(RaggedArray &ragged) : ragged(ragged), pos(0) {} 50 51 /// Creates the end iterator. iterator(RaggedArray & ragged,size_t pos)52 iterator(RaggedArray &ragged, size_t pos) : ragged(ragged), pos(pos) {} 53 54 /// Dereferences the current iterator. Assumes in-bounds. 55 MutableArrayRef<T> operator*() const { return ragged[pos]; } 56 57 /// Increments the iterator. 58 iterator &operator++() { 59 if (pos < ragged.slices.size()) 60 ++pos; 61 return *this; 62 } 63 64 /// Compares the two iterators. Iterators into different ragged arrays 65 /// compare not equal. 66 bool operator==(const iterator &other) const { 67 return &ragged == &other.ragged && pos == other.pos; 68 } 69 70 private: 71 RaggedArray &ragged; 72 size_t pos; 73 }; 74 75 /// Constant iterator over the rows. 76 class const_iterator 77 : public llvm::iterator_facade_base< 78 const_iterator, std::forward_iterator_tag, ArrayRef<T>, 79 std::ptrdiff_t, ArrayRef<T> *, ArrayRef<T>> { 80 public: 81 /// Creates the start iterator. const_iterator(const RaggedArray & ragged)82 explicit const_iterator(const RaggedArray &ragged) 83 : ragged(ragged), pos(0) {} 84 85 /// Creates the end iterator. const_iterator(const RaggedArray & ragged,size_t pos)86 const_iterator(const RaggedArray &ragged, size_t pos) 87 : ragged(ragged), pos(pos) {} 88 89 /// Dereferences the current iterator. Assumes in-bounds. 90 ArrayRef<T> operator*() const { return ragged[pos]; } 91 92 /// Increments the iterator. 93 const_iterator &operator++() { 94 if (pos < ragged.slices.size()) 95 ++pos; 96 return *this; 97 } 98 99 /// Compares the two iterators. Iterators into different ragged arrays 100 /// compare not equal. 101 bool operator==(const const_iterator &other) const { 102 return &ragged == &other.ragged && pos == other.pos; 103 } 104 105 private: 106 const RaggedArray &ragged; 107 size_t pos; 108 }; 109 110 /// Iterator over rows. begin()111 const_iterator begin() const { return const_iterator(*this); } end()112 const_iterator end() const { return const_iterator(*this, slices.size()); } begin()113 iterator begin() { return iterator(*this); } end()114 iterator end() { return iterator(*this, slices.size()); } 115 116 /// Reserve space to store `size` rows with `nestedSize` elements each. 117 void reserve(size_t size, size_t nestedSize = 0) { 118 slices.reserve(size); 119 storage.reserve(size * nestedSize); 120 } 121 122 /// Appends the given range of elements as a new row to the 2D array. May 123 /// invalidate the end iterator. 124 template <typename Range> push_back(Range && elements)125 void push_back(Range &&elements) { 126 slices.push_back(appendToStorage(std::forward<Range>(elements))); 127 } 128 129 /// Replaces the `pos`-th row in the 2D array with the given range of 130 /// elements. Invalidates iterators and references to `pos`-th and all 131 /// succeeding rows. 132 template <typename Range> replace(size_t pos,Range && elements)133 void replace(size_t pos, Range &&elements) { 134 if (slices[pos].first != static_cast<size_t>(-1)) { 135 auto from = std::next(storage.begin(), slices[pos].first); 136 auto to = std::next(from, slices[pos].second); 137 auto newFrom = storage.erase(from, to); 138 // Update the array refs after the underlying storage was shifted. 139 for (size_t i = pos + 1, e = size(); i < e; ++i) { 140 slices[i] = std::make_pair(std::distance(storage.begin(), newFrom), 141 slices[i].second); 142 std::advance(newFrom, slices[i].second); 143 } 144 } 145 slices[pos] = appendToStorage(std::forward<Range>(elements)); 146 } 147 148 /// Appends `num` empty rows to the array. appendEmptyRows(size_t num)149 void appendEmptyRows(size_t num) { 150 slices.resize(slices.size() + num, std::pair<size_t, size_t>(-1, 0)); 151 } 152 153 /// Removes the first subarray in-place. Invalidates iterators to all rows. removeFront()154 void removeFront() { slices.erase(slices.begin()); } 155 156 private: 157 /// Appends the given elements to the storage and returns an ArrayRef 158 /// pointing to them in the storage. 159 template <typename Range> appendToStorage(Range && elements)160 std::pair<size_t, size_t> appendToStorage(Range &&elements) { 161 size_t start = storage.size(); 162 llvm::append_range(storage, std::forward<Range>(elements)); 163 return std::make_pair(start, storage.size() - start); 164 } 165 166 /// Outer elements of the ragged array. Each entry is an (offset, length) 167 /// pair identifying a contiguous segment in the `storage` list that 168 /// contains the actual elements. This allows for elements to be stored 169 /// contiguously without nested vectors and for different segments to be set 170 /// or replaced in any order. 171 SmallVector<std::pair<size_t, size_t>> slices; 172 173 /// Dense storage for ragged array elements. 174 SmallVector<T> storage; 175 }; 176 } // namespace mlir 177