xref: /llvm-project/mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h (revision b33b91a21788d439f49d6db4e7224c20f740f1a7)
1 //===- RaggedArray.h - 2D array with different inner lengths ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Support/LLVM.h"
10 #include "llvm/ADT/STLExtras.h"
11 #include "llvm/ADT/SmallVector.h"
12 #include <iterator>
13 
14 namespace mlir {
15 /// A 2D array where each row may have different length. Elements of each row
16 /// are stored contiguously, but rows don't have a fixed order in the storage.
17 template <typename T>
18 class RaggedArray {
19 public:
20   /// Returns the number of rows in the 2D array.
size()21   size_t size() const { return slices.size(); }
22 
23   /// Returns true if the are no rows in the 2D array. Note that an array with a
24   /// non-zero number of empty rows is *NOT* empty.
empty()25   bool empty() const { return slices.empty(); }
26 
27   /// Accesses `pos`-th row.
28   ArrayRef<T> operator[](size_t pos) const { return at(pos); }
at(size_t pos)29   ArrayRef<T> at(size_t pos) const {
30     if (slices[pos].first == static_cast<size_t>(-1))
31       return ArrayRef<T>();
32     return ArrayRef<T>(storage).slice(slices[pos].first, slices[pos].second);
33   }
34   MutableArrayRef<T> operator[](size_t pos) { return at(pos); }
at(size_t pos)35   MutableArrayRef<T> at(size_t pos) {
36     if (slices[pos].first == static_cast<size_t>(-1))
37       return MutableArrayRef<T>();
38     return MutableArrayRef<T>(storage).slice(slices[pos].first,
39                                              slices[pos].second);
40   }
41 
42   /// Iterator over the rows.
43   class iterator
44       : public llvm::iterator_facade_base<
45             iterator, std::forward_iterator_tag, MutableArrayRef<T>,
46             std::ptrdiff_t, MutableArrayRef<T> *, MutableArrayRef<T>> {
47   public:
48     /// Creates the start iterator.
iterator(RaggedArray & ragged)49     explicit iterator(RaggedArray &ragged) : ragged(ragged), pos(0) {}
50 
51     /// Creates the end iterator.
iterator(RaggedArray & ragged,size_t pos)52     iterator(RaggedArray &ragged, size_t pos) : ragged(ragged), pos(pos) {}
53 
54     /// Dereferences the current iterator. Assumes in-bounds.
55     MutableArrayRef<T> operator*() const { return ragged[pos]; }
56 
57     /// Increments the iterator.
58     iterator &operator++() {
59       if (pos < ragged.slices.size())
60         ++pos;
61       return *this;
62     }
63 
64     /// Compares the two iterators. Iterators into different ragged arrays
65     /// compare not equal.
66     bool operator==(const iterator &other) const {
67       return &ragged == &other.ragged && pos == other.pos;
68     }
69 
70   private:
71     RaggedArray &ragged;
72     size_t pos;
73   };
74 
75   /// Constant iterator over the rows.
76   class const_iterator
77       : public llvm::iterator_facade_base<
78             const_iterator, std::forward_iterator_tag, ArrayRef<T>,
79             std::ptrdiff_t, ArrayRef<T> *, ArrayRef<T>> {
80   public:
81     /// Creates the start iterator.
const_iterator(const RaggedArray & ragged)82     explicit const_iterator(const RaggedArray &ragged)
83         : ragged(ragged), pos(0) {}
84 
85     /// Creates the end iterator.
const_iterator(const RaggedArray & ragged,size_t pos)86     const_iterator(const RaggedArray &ragged, size_t pos)
87         : ragged(ragged), pos(pos) {}
88 
89     /// Dereferences the current iterator. Assumes in-bounds.
90     ArrayRef<T> operator*() const { return ragged[pos]; }
91 
92     /// Increments the iterator.
93     const_iterator &operator++() {
94       if (pos < ragged.slices.size())
95         ++pos;
96       return *this;
97     }
98 
99     /// Compares the two iterators. Iterators into different ragged arrays
100     /// compare not equal.
101     bool operator==(const const_iterator &other) const {
102       return &ragged == &other.ragged && pos == other.pos;
103     }
104 
105   private:
106     const RaggedArray &ragged;
107     size_t pos;
108   };
109 
110   /// Iterator over rows.
begin()111   const_iterator begin() const { return const_iterator(*this); }
end()112   const_iterator end() const { return const_iterator(*this, slices.size()); }
begin()113   iterator begin() { return iterator(*this); }
end()114   iterator end() { return iterator(*this, slices.size()); }
115 
116   /// Reserve space to store `size` rows with `nestedSize` elements each.
117   void reserve(size_t size, size_t nestedSize = 0) {
118     slices.reserve(size);
119     storage.reserve(size * nestedSize);
120   }
121 
122   /// Appends the given range of elements as a new row to the 2D array. May
123   /// invalidate the end iterator.
124   template <typename Range>
push_back(Range && elements)125   void push_back(Range &&elements) {
126     slices.push_back(appendToStorage(std::forward<Range>(elements)));
127   }
128 
129   /// Replaces the `pos`-th row in the 2D array with the given range of
130   /// elements. Invalidates iterators and references to `pos`-th and all
131   /// succeeding rows.
132   template <typename Range>
replace(size_t pos,Range && elements)133   void replace(size_t pos, Range &&elements) {
134     if (slices[pos].first != static_cast<size_t>(-1)) {
135       auto from = std::next(storage.begin(), slices[pos].first);
136       auto to = std::next(from, slices[pos].second);
137       auto newFrom = storage.erase(from, to);
138       // Update the array refs after the underlying storage was shifted.
139       for (size_t i = pos + 1, e = size(); i < e; ++i) {
140         slices[i] = std::make_pair(std::distance(storage.begin(), newFrom),
141                                    slices[i].second);
142         std::advance(newFrom, slices[i].second);
143       }
144     }
145     slices[pos] = appendToStorage(std::forward<Range>(elements));
146   }
147 
148   /// Appends `num` empty rows to the array.
appendEmptyRows(size_t num)149   void appendEmptyRows(size_t num) {
150     slices.resize(slices.size() + num, std::pair<size_t, size_t>(-1, 0));
151   }
152 
153   /// Removes the first subarray in-place. Invalidates iterators to all rows.
removeFront()154   void removeFront() { slices.erase(slices.begin()); }
155 
156 private:
157   /// Appends the given elements to the storage and returns an ArrayRef
158   /// pointing to them in the storage.
159   template <typename Range>
appendToStorage(Range && elements)160   std::pair<size_t, size_t> appendToStorage(Range &&elements) {
161     size_t start = storage.size();
162     llvm::append_range(storage, std::forward<Range>(elements));
163     return std::make_pair(start, storage.size() - start);
164   }
165 
166   /// Outer elements of the ragged array. Each entry is an (offset, length)
167   /// pair identifying a contiguous segment in the `storage` list that
168   /// contains the actual elements. This allows for elements to be stored
169   /// contiguously without nested vectors and for different segments to be set
170   /// or replaced in any order.
171   SmallVector<std::pair<size_t, size_t>> slices;
172 
173   /// Dense storage for ragged array elements.
174   SmallVector<T> storage;
175 };
176 } // namespace mlir
177