xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h (revision f607102a0d6be0e2aebc1bfaed2ed0a6ae020145)
1 //===- SparseTensorIterator.h ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
10 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
11 
12 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
13 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
14 
15 namespace mlir {
16 namespace sparse_tensor {
17 
// Forward declaration: SparseIterationSpace::extractIterator() below returns a
// `std::unique_ptr<SparseIterator>` before SparseIterator is defined.
class SparseIterator;
20 
21 /// The base class for all types of sparse tensor levels. It provides interfaces
22 /// to query the loop range (see `peekRangeAt`) and look up the coordinates (see
23 /// `peekCrdAt`).
24 class SparseTensorLevel {
25   SparseTensorLevel(SparseTensorLevel &&) = delete;
26   SparseTensorLevel(const SparseTensorLevel &) = delete;
27   SparseTensorLevel &operator=(SparseTensorLevel &&) = delete;
28   SparseTensorLevel &operator=(const SparseTensorLevel &) = delete;
29 
30 public:
31   virtual ~SparseTensorLevel() = default;
32 
33   std::string toString() const {
34     return std::string(toMLIRString(lt)) + "[" + std::to_string(tid) + "," +
35            std::to_string(lvl) + "]";
36   }
37 
38   virtual Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
39                           Value iv) const = 0;
40 
41   /// Peeks the lower and upper bound to *fully* traverse the level with
42   /// the given position `parentPos`, see SparseTensorIterator::getCurPostion(),
43   /// that the immediate parent level is current at. Returns a pair of values
44   /// for *posLo* and *loopHi* respectively.
45   ///
46   /// For a dense level, the *posLo* is the linearized position at beginning,
47   /// while *loopHi* is the largest *coordinate*, it also implies that the
48   /// smallest *coordinate* to start the loop is 0.
49   ///
50   /// For a sparse level, [posLo, loopHi) specifies the range of index pointer
51   /// to load coordinate from the coordinate buffer.
52   virtual std::pair<Value, Value>
53   peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
54               ValueRange parentPos, Value inPadZone = nullptr) const = 0;
55 
56   virtual std::pair<Value, Value>
57   collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
58                        std::pair<Value, Value> parentRange) const {
59     llvm_unreachable("Not Implemented");
60   };
61 
62   Level getLevel() const { return lvl; }
63   LevelType getLT() const { return lt; }
64   Value getSize() const { return lvlSize; }
65   virtual ValueRange getLvlBuffers() const = 0;
66 
67   //
68   // Level properties
69   //
70   bool isUnique() const { return isUniqueLT(lt); }
71 
72 protected:
73   SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize)
74       : tid(tid), lvl(lvl), lt(lt), lvlSize(lvlSize) {};
75 
76 public:
77   const unsigned tid, lvl;
78   const LevelType lt;
79   const Value lvlSize;
80 };
81 
/// Discriminator identifying the concrete kind of a SparseIterator; stored in
/// `SparseIterator::kind` for LLVM-style RTTI.
/// NOTE(review): presumably one enumerator per iterator flavor produced by the
/// make*Iterator helpers declared at the end of this header (e.g. kPad for
/// makePaddedIterator) — confirm against the implementation file.
enum class IterKind : uint8_t {
  kTrivial,
  kDedup,
  kSubSect,
  kNonEmptySubSect,
  kFilter,
  kPad,
};
90 
91 /// A `SparseIterationSpace` represents a sparse set of coordinates defined by
92 /// (possibly multiple) levels of a specific sparse tensor.
93 /// TODO: remove `SparseTensorLevel` and switch to SparseIterationSpace when
94 /// feature complete.
95 class SparseIterationSpace {
96 public:
97   SparseIterationSpace() = default;
98   SparseIterationSpace(SparseIterationSpace &) = delete;
99   SparseIterationSpace(SparseIterationSpace &&) = default;
100 
101   // Constructs a N-D iteration space.
102   SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
103                        std::pair<Level, Level> lvlRange, ValueRange parentPos);
104 
105   // Constructs a 1-D iteration space.
106   SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
107                        Level lvl, ValueRange parentPos)
108       : SparseIterationSpace(loc, b, t, tid, {lvl, lvl + 1}, parentPos) {};
109 
110   bool isUnique() const { return lvls.back()->isUnique(); }
111 
112   unsigned getSpaceDim() const { return lvls.size(); }
113 
114   // Reconstructs a iteration space directly from the provided ValueRange.
115   static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values,
116                                          unsigned tid);
117 
118   // The inverse operation of `fromValues`.
119   SmallVector<Value> toValues() const {
120     SmallVector<Value> vals;
121     for (auto &stl : lvls) {
122       llvm::append_range(vals, stl->getLvlBuffers());
123       vals.push_back(stl->getSize());
124     }
125     vals.append({bound.first, bound.second});
126     return vals;
127   }
128 
129   const SparseTensorLevel &getLastLvl() const { return *lvls.back(); }
130   ArrayRef<std::unique_ptr<SparseTensorLevel>> getLvlRef() const {
131     return lvls;
132   }
133 
134   Value getBoundLo() const { return bound.first; }
135   Value getBoundHi() const { return bound.second; }
136 
137   // Extract an iterator to iterate over the sparse iteration space.
138   std::unique_ptr<SparseIterator> extractIterator(OpBuilder &b,
139                                                   Location l) const;
140 
141 private:
142   SmallVector<std::unique_ptr<SparseTensorLevel>> lvls;
143   std::pair<Value, Value> bound;
144 };
145 
146 /// Helper class that generates loop conditions, etc, to traverse a
147 /// sparse tensor level.
148 class SparseIterator {
149   SparseIterator(SparseIterator &&) = delete;
150   SparseIterator(const SparseIterator &) = delete;
151   SparseIterator &operator=(SparseIterator &&) = delete;
152   SparseIterator &operator=(const SparseIterator &) = delete;
153 
154 protected:
155   SparseIterator(IterKind kind, unsigned tid, unsigned lvl,
156                  unsigned cursorValsCnt,
157                  SmallVectorImpl<Value> &cursorValStorage)
158       : batchCrds(0), kind(kind), tid(tid), lvl(lvl), crd(nullptr),
159         cursorValsCnt(cursorValsCnt), cursorValsStorageRef(cursorValStorage) {};
160 
161   SparseIterator(IterKind kind, unsigned cursorValsCnt,
162                  SmallVectorImpl<Value> &cursorValStorage,
163                  const SparseIterator &delegate)
164       : SparseIterator(kind, delegate.tid, delegate.lvl, cursorValsCnt,
165                        cursorValStorage) {};
166 
167   SparseIterator(IterKind kind, const SparseIterator &wrap,
168                  unsigned extraCursorCnt = 0)
169       : SparseIterator(kind, wrap.tid, wrap.lvl,
170                        extraCursorCnt + wrap.cursorValsCnt,
171                        wrap.cursorValsStorageRef) {
172     assert(wrap.cursorValsCnt == wrap.cursorValsStorageRef.size());
173     cursorValsStorageRef.append(extraCursorCnt, nullptr);
174     assert(cursorValsStorageRef.size() == wrap.cursorValsCnt + extraCursorCnt);
175   };
176 
177 public:
178   virtual ~SparseIterator() = default;
179 
180   void setSparseEmitStrategy(SparseEmitStrategy strategy) {
181     emitStrategy = strategy;
182   }
183 
184   virtual std::string getDebugInterfacePrefix() const = 0;
185   virtual SmallVector<Type> getCursorValTypes(OpBuilder &b) const = 0;
186 
187   Value getCrd() const { return crd; }
188   ValueRange getBatchCrds() const { return batchCrds; }
189   ValueRange getCursor() const {
190     return ValueRange(cursorValsStorageRef).take_front(cursorValsCnt);
191   };
192 
193   // Sets the iterate to the specified position.
194   void seek(ValueRange vals) {
195     assert(vals.size() == cursorValsCnt);
196     std::copy(vals.begin(), vals.end(), cursorValsStorageRef.begin());
197     // Now that the iterator is re-positioned, the coordinate becomes invalid.
198     crd = nullptr;
199   }
200 
201   // Reconstructs a iteration space directly from the provided ValueRange.
202   static std::unique_ptr<SparseIterator>
203   fromValues(IteratorType dstTp, ValueRange values, unsigned tid);
204 
205   // The inverse operation of `fromValues`.
206   SmallVector<Value> toValues() const { llvm_unreachable("Not implemented"); }
207 
208   //
209   // Iterator properties.
210   //
211 
212   // Whether the iterator is a iterator over a batch level.
213   virtual bool isBatchIterator() const = 0;
214 
215   // Whether the iterator support random access (i.e., support look up by
216   // *coordinate*). A random access iterator must also traverses a dense space.
217   virtual bool randomAccessible() const = 0;
218 
219   // Whether the iterator can simply traversed by a for loop.
220   virtual bool iteratableByFor() const { return false; };
221 
222   // Get the upper bound of the sparse space that the iterator might visited. A
223   // sparse space is a subset of a dense space [0, bound), this function returns
224   // *bound*.
225   virtual Value upperBound(OpBuilder &b, Location l) const = 0;
226 
227   // Serializes and deserializes the current status to/from a set of values. The
228   // ValueRange should contain values that are sufficient to recover the current
229   // iterating postion (i.e., itVals) as well as loop bound.
230   //
231   // Not every type of iterator supports the operations, e.g., non-empty
232   // subsection iterator does not because the the number of non-empty
233   // subsections can not be determined easily.
234   //
235   // NOTE: All the values should have index type.
236   virtual SmallVector<Value> serialize() const {
237     llvm_unreachable("unsupported");
238   };
239   virtual void deserialize(ValueRange vs) { llvm_unreachable("unsupported"); };
240 
241   //
242   // Core functions.
243   //
244 
245   // Initializes the iterator according to the parent iterator's state.
246   void genInit(OpBuilder &b, Location l, const SparseIterator *p);
247 
248   // Forwards the iterator to the next element.
249   ValueRange forward(OpBuilder &b, Location l);
250 
251   // Locate the iterator to the position specified by *crd*, this can only
252   // be done on an iterator that supports randm access.
253   void locate(OpBuilder &b, Location l, Value crd);
254 
255   // Returns a boolean value that equals `!it.end()`
256   Value genNotEnd(OpBuilder &b, Location l);
257 
258   // Dereferences the iterator, loads the coordinate at the current position.
259   //
260   // The method assumes that the iterator is not currently exhausted (i.e.,
261   // it != it.end()).
262   Value deref(OpBuilder &b, Location l);
263 
264   // Actual Implementation provided by derived class.
265   virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *) = 0;
266   virtual ValueRange forwardImpl(OpBuilder &b, Location l) = 0;
267   virtual void locateImpl(OpBuilder &b, Location l, Value crd) {
268     llvm_unreachable("Unsupported");
269   }
270   virtual Value genNotEndImpl(OpBuilder &b, Location l) = 0;
271   virtual Value derefImpl(OpBuilder &b, Location l) = 0;
272   // Gets the ValueRange that together specifies the current position of the
273   // iterator. For a unique level, the position can be a single index points to
274   // the current coordinate being visited. For a non-unique level, an extra
275   // index for the `segment high` is needed to to specifies the range of
276   // duplicated coordinates. The ValueRange should be able to uniquely identify
277   // the sparse range for the next level. See SparseTensorLevel::peekRangeAt();
278   //
279   // Not every type of iterator supports the operation, e.g., non-empty
280   // subsection iterator does not because it represent a range of coordinates
281   // instead of just one.
282   virtual ValueRange getCurPosition() const { return getCursor(); };
283 
284   // Returns a pair of values for *upper*, *lower* bound respectively.
285   virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) {
286     assert(randomAccessible());
287     // Random-access iterator is traversed by coordinate, i.e., [curCrd, UB).
288     return {getCrd(), upperBound(b, l)};
289   }
290 
291   // Generates a bool value for scf::ConditionOp.
292   std::pair<Value, ValueRange> genWhileCond(OpBuilder &b, Location l,
293                                             ValueRange vs) {
294     ValueRange rem = linkNewScope(vs);
295     return std::make_pair(genNotEnd(b, l), rem);
296   }
297 
298   // Generate a conditional it.next() in the following form
299   //
300   // if (cond)
301   //    yield it.next
302   // else
303   //    yield it
304   //
305   // The function is virtual to allow alternative implementation. For example,
306   // if it.next() is trivial to compute, we can use a select operation instead.
307   // E.g.,
308   //
309   //  it = select cond ? it+1 : it
310   virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond);
311 
312   // Update the SSA value for the iterator after entering a new scope.
313   ValueRange linkNewScope(ValueRange pos) {
314     assert(!randomAccessible() && "random accessible iterators are traversed "
315                                   "by coordinate, call locate() instead.");
316     seek(pos.take_front(cursorValsCnt));
317     return pos.drop_front(cursorValsCnt);
318   };
319 
320 protected:
321   void updateCrd(Value crd) { this->crd = crd; }
322 
323   MutableArrayRef<Value> getMutCursorVals() {
324     MutableArrayRef<Value> ref = cursorValsStorageRef;
325     return ref.take_front(cursorValsCnt);
326   }
327 
328   void inherentBatch(const SparseIterator &parent) {
329     batchCrds = parent.batchCrds;
330   }
331 
332   SparseEmitStrategy emitStrategy;
333   SmallVector<Value> batchCrds;
334 
335 public:
336   const IterKind kind;     // For LLVM-style RTTI.
337   const unsigned tid, lvl; // tensor level identifier.
338 
339 private:
340   Value crd; // The sparse coordinate used to coiterate;
341 
342   // A range of value that together defines the current state of the
343   // iterator. Only loop variants should be included.
344   //
345   // For trivial iterators, it is the position; for dedup iterators, it consists
346   // of the positon and the segment high, for non-empty subsection iterator, it
347   // is the metadata that specifies the subsection.
348   // Note that the wrapped iterator shares the same storage to maintain itVals
349   // with it wrapper, which means the wrapped iterator might only own a subset
350   // of all the values stored in itValStorage.
351   const unsigned cursorValsCnt;
352   SmallVectorImpl<Value> &cursorValsStorageRef;
353 };
354 
/// Helper function to create a SparseTensorLevel object for level `lvl` of the
/// given tensor `t`, identified by tensor id `tid`.
std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
                                                         Location l, Value t,
                                                         unsigned tid,
                                                         Level lvl);

/// Helper function to create a SparseTensorLevel object from a level type
/// `lt`, a level size `sz`, and the level's `buffers` given as a ValueRange.
/// NOTE(review): presumably `buffers` has the layout returned by
/// `SparseTensorLevel::getLvlBuffers()` (as emitted by
/// SparseIterationSpace::toValues) — confirm against the implementation.
std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
                                                         ValueRange buffers,
                                                         unsigned tid, Level l);

/// Helper function to create a simple SparseIterator object that iterates
/// over the entire iteration space `iterSpace`.
std::unique_ptr<SparseIterator>
makeSimpleIterator(OpBuilder &b, Location l,
                   const SparseIterationSpace &iterSpace);

/// Helper function to create a simple SparseIterator object that iterates
/// over the sparse tensor level `stl`; `strategy` defaults to
/// SparseEmitStrategy::kFunctional.
/// TODO: switch to `SparseIterationSpace` (which supports N-D iterators) when
/// feature complete.
std::unique_ptr<SparseIterator> makeSimpleIterator(
    const SparseTensorLevel &stl,
    SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);

/// Helper function to create a synthetic SparseIterator object that iterates
/// over a dense space specified by [0, `sz`). Returns the synthetic level
/// together with the iterator.
/// NOTE(review): the level is presumably returned because the iterator refers
/// to it, so it must outlive the iterator — confirm.
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
                        SparseEmitStrategy strategy);

/// Helper function to create a SparseIterator object that iterates over a
/// sliced space (described by `offset`, `stride` and `size`); the original
/// space (before slicing) is traversed by `sit`.
std::unique_ptr<SparseIterator>
makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
                        Value stride, Value size, SparseEmitStrategy strategy);

/// Helper function to create a SparseIterator object that iterates over a
/// padded sparse level (the padded value must be zero). `sit` is the iterator
/// being padded; `padLow`/`padHigh` presumably give the low/high padding
/// sizes — confirm.
std::unique_ptr<SparseIterator>
makePaddedIterator(std::unique_ptr<SparseIterator> &&sit, Value padLow,
                   Value padHigh, SparseEmitStrategy strategy);

/// Helper function to create a SparseIterator object that iterates over the
/// set of non-empty subsections.
std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
    OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
    std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
    SparseEmitStrategy strategy);

/// Helper function to create a SparseIterator object that iterates over a
/// non-empty subsection created by NonEmptySubSectIterator.
std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
    OpBuilder &b, Location l, const SparseIterator &subsectIter,
    const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
    Value loopBound, unsigned stride, SparseEmitStrategy strategy);
411 
412 } // namespace sparse_tensor
413 } // namespace mlir
414 
415 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
416