1 //===- SparseTensorIterator.h ---------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_ 10 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_ 11 12 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 13 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 14 15 namespace mlir { 16 namespace sparse_tensor { 17 18 // Forward declaration. 19 class SparseIterator; 20 21 /// The base class for all types of sparse tensor levels. It provides interfaces 22 /// to query the loop range (see `peekRangeAt`) and look up the coordinates (see 23 /// `peekCrdAt`). 24 class SparseTensorLevel { 25 SparseTensorLevel(SparseTensorLevel &&) = delete; 26 SparseTensorLevel(const SparseTensorLevel &) = delete; 27 SparseTensorLevel &operator=(SparseTensorLevel &&) = delete; 28 SparseTensorLevel &operator=(const SparseTensorLevel &) = delete; 29 30 public: 31 virtual ~SparseTensorLevel() = default; 32 33 std::string toString() const { 34 return std::string(toMLIRString(lt)) + "[" + std::to_string(tid) + "," + 35 std::to_string(lvl) + "]"; 36 } 37 38 virtual Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix, 39 Value iv) const = 0; 40 41 /// Peeks the lower and upper bound to *fully* traverse the level with 42 /// the given position `parentPos`, see SparseTensorIterator::getCurPostion(), 43 /// that the immediate parent level is current at. Returns a pair of values 44 /// for *posLo* and *loopHi* respectively. 
45 /// 46 /// For a dense level, the *posLo* is the linearized position at beginning, 47 /// while *loopHi* is the largest *coordinate*, it also implies that the 48 /// smallest *coordinate* to start the loop is 0. 49 /// 50 /// For a sparse level, [posLo, loopHi) specifies the range of index pointer 51 /// to load coordinate from the coordinate buffer. 52 virtual std::pair<Value, Value> 53 peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix, 54 ValueRange parentPos, Value inPadZone = nullptr) const = 0; 55 56 virtual std::pair<Value, Value> 57 collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix, 58 std::pair<Value, Value> parentRange) const { 59 llvm_unreachable("Not Implemented"); 60 }; 61 62 Level getLevel() const { return lvl; } 63 LevelType getLT() const { return lt; } 64 Value getSize() const { return lvlSize; } 65 virtual ValueRange getLvlBuffers() const = 0; 66 67 // 68 // Level properties 69 // 70 bool isUnique() const { return isUniqueLT(lt); } 71 72 protected: 73 SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize) 74 : tid(tid), lvl(lvl), lt(lt), lvlSize(lvlSize) {}; 75 76 public: 77 const unsigned tid, lvl; 78 const LevelType lt; 79 const Value lvlSize; 80 }; 81 82 enum class IterKind : uint8_t { 83 kTrivial, 84 kDedup, 85 kSubSect, 86 kNonEmptySubSect, 87 kFilter, 88 kPad, 89 }; 90 91 /// A `SparseIterationSpace` represents a sparse set of coordinates defined by 92 /// (possibly multiple) levels of a specific sparse tensor. 93 /// TODO: remove `SparseTensorLevel` and switch to SparseIterationSpace when 94 /// feature complete. 95 class SparseIterationSpace { 96 public: 97 SparseIterationSpace() = default; 98 SparseIterationSpace(SparseIterationSpace &) = delete; 99 SparseIterationSpace(SparseIterationSpace &&) = default; 100 101 // Constructs a N-D iteration space. 
102 SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid, 103 std::pair<Level, Level> lvlRange, ValueRange parentPos); 104 105 // Constructs a 1-D iteration space. 106 SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid, 107 Level lvl, ValueRange parentPos) 108 : SparseIterationSpace(loc, b, t, tid, {lvl, lvl + 1}, parentPos) {}; 109 110 bool isUnique() const { return lvls.back()->isUnique(); } 111 112 unsigned getSpaceDim() const { return lvls.size(); } 113 114 // Reconstructs a iteration space directly from the provided ValueRange. 115 static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values, 116 unsigned tid); 117 118 // The inverse operation of `fromValues`. 119 SmallVector<Value> toValues() const { 120 SmallVector<Value> vals; 121 for (auto &stl : lvls) { 122 llvm::append_range(vals, stl->getLvlBuffers()); 123 vals.push_back(stl->getSize()); 124 } 125 vals.append({bound.first, bound.second}); 126 return vals; 127 } 128 129 const SparseTensorLevel &getLastLvl() const { return *lvls.back(); } 130 ArrayRef<std::unique_ptr<SparseTensorLevel>> getLvlRef() const { 131 return lvls; 132 } 133 134 Value getBoundLo() const { return bound.first; } 135 Value getBoundHi() const { return bound.second; } 136 137 // Extract an iterator to iterate over the sparse iteration space. 138 std::unique_ptr<SparseIterator> extractIterator(OpBuilder &b, 139 Location l) const; 140 141 private: 142 SmallVector<std::unique_ptr<SparseTensorLevel>> lvls; 143 std::pair<Value, Value> bound; 144 }; 145 146 /// Helper class that generates loop conditions, etc, to traverse a 147 /// sparse tensor level. 
148 class SparseIterator { 149 SparseIterator(SparseIterator &&) = delete; 150 SparseIterator(const SparseIterator &) = delete; 151 SparseIterator &operator=(SparseIterator &&) = delete; 152 SparseIterator &operator=(const SparseIterator &) = delete; 153 154 protected: 155 SparseIterator(IterKind kind, unsigned tid, unsigned lvl, 156 unsigned cursorValsCnt, 157 SmallVectorImpl<Value> &cursorValStorage) 158 : batchCrds(0), kind(kind), tid(tid), lvl(lvl), crd(nullptr), 159 cursorValsCnt(cursorValsCnt), cursorValsStorageRef(cursorValStorage) {}; 160 161 SparseIterator(IterKind kind, unsigned cursorValsCnt, 162 SmallVectorImpl<Value> &cursorValStorage, 163 const SparseIterator &delegate) 164 : SparseIterator(kind, delegate.tid, delegate.lvl, cursorValsCnt, 165 cursorValStorage) {}; 166 167 SparseIterator(IterKind kind, const SparseIterator &wrap, 168 unsigned extraCursorCnt = 0) 169 : SparseIterator(kind, wrap.tid, wrap.lvl, 170 extraCursorCnt + wrap.cursorValsCnt, 171 wrap.cursorValsStorageRef) { 172 assert(wrap.cursorValsCnt == wrap.cursorValsStorageRef.size()); 173 cursorValsStorageRef.append(extraCursorCnt, nullptr); 174 assert(cursorValsStorageRef.size() == wrap.cursorValsCnt + extraCursorCnt); 175 }; 176 177 public: 178 virtual ~SparseIterator() = default; 179 180 void setSparseEmitStrategy(SparseEmitStrategy strategy) { 181 emitStrategy = strategy; 182 } 183 184 virtual std::string getDebugInterfacePrefix() const = 0; 185 virtual SmallVector<Type> getCursorValTypes(OpBuilder &b) const = 0; 186 187 Value getCrd() const { return crd; } 188 ValueRange getBatchCrds() const { return batchCrds; } 189 ValueRange getCursor() const { 190 return ValueRange(cursorValsStorageRef).take_front(cursorValsCnt); 191 }; 192 193 // Sets the iterate to the specified position. 
194 void seek(ValueRange vals) { 195 assert(vals.size() == cursorValsCnt); 196 std::copy(vals.begin(), vals.end(), cursorValsStorageRef.begin()); 197 // Now that the iterator is re-positioned, the coordinate becomes invalid. 198 crd = nullptr; 199 } 200 201 // Reconstructs a iteration space directly from the provided ValueRange. 202 static std::unique_ptr<SparseIterator> 203 fromValues(IteratorType dstTp, ValueRange values, unsigned tid); 204 205 // The inverse operation of `fromValues`. 206 SmallVector<Value> toValues() const { llvm_unreachable("Not implemented"); } 207 208 // 209 // Iterator properties. 210 // 211 212 // Whether the iterator is a iterator over a batch level. 213 virtual bool isBatchIterator() const = 0; 214 215 // Whether the iterator support random access (i.e., support look up by 216 // *coordinate*). A random access iterator must also traverses a dense space. 217 virtual bool randomAccessible() const = 0; 218 219 // Whether the iterator can simply traversed by a for loop. 220 virtual bool iteratableByFor() const { return false; }; 221 222 // Get the upper bound of the sparse space that the iterator might visited. A 223 // sparse space is a subset of a dense space [0, bound), this function returns 224 // *bound*. 225 virtual Value upperBound(OpBuilder &b, Location l) const = 0; 226 227 // Serializes and deserializes the current status to/from a set of values. The 228 // ValueRange should contain values that are sufficient to recover the current 229 // iterating postion (i.e., itVals) as well as loop bound. 230 // 231 // Not every type of iterator supports the operations, e.g., non-empty 232 // subsection iterator does not because the the number of non-empty 233 // subsections can not be determined easily. 234 // 235 // NOTE: All the values should have index type. 
236 virtual SmallVector<Value> serialize() const { 237 llvm_unreachable("unsupported"); 238 }; 239 virtual void deserialize(ValueRange vs) { llvm_unreachable("unsupported"); }; 240 241 // 242 // Core functions. 243 // 244 245 // Initializes the iterator according to the parent iterator's state. 246 void genInit(OpBuilder &b, Location l, const SparseIterator *p); 247 248 // Forwards the iterator to the next element. 249 ValueRange forward(OpBuilder &b, Location l); 250 251 // Locate the iterator to the position specified by *crd*, this can only 252 // be done on an iterator that supports randm access. 253 void locate(OpBuilder &b, Location l, Value crd); 254 255 // Returns a boolean value that equals `!it.end()` 256 Value genNotEnd(OpBuilder &b, Location l); 257 258 // Dereferences the iterator, loads the coordinate at the current position. 259 // 260 // The method assumes that the iterator is not currently exhausted (i.e., 261 // it != it.end()). 262 Value deref(OpBuilder &b, Location l); 263 264 // Actual Implementation provided by derived class. 265 virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *) = 0; 266 virtual ValueRange forwardImpl(OpBuilder &b, Location l) = 0; 267 virtual void locateImpl(OpBuilder &b, Location l, Value crd) { 268 llvm_unreachable("Unsupported"); 269 } 270 virtual Value genNotEndImpl(OpBuilder &b, Location l) = 0; 271 virtual Value derefImpl(OpBuilder &b, Location l) = 0; 272 // Gets the ValueRange that together specifies the current position of the 273 // iterator. For a unique level, the position can be a single index points to 274 // the current coordinate being visited. For a non-unique level, an extra 275 // index for the `segment high` is needed to to specifies the range of 276 // duplicated coordinates. The ValueRange should be able to uniquely identify 277 // the sparse range for the next level. 
See SparseTensorLevel::peekRangeAt(); 278 // 279 // Not every type of iterator supports the operation, e.g., non-empty 280 // subsection iterator does not because it represent a range of coordinates 281 // instead of just one. 282 virtual ValueRange getCurPosition() const { return getCursor(); }; 283 284 // Returns a pair of values for *upper*, *lower* bound respectively. 285 virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) { 286 assert(randomAccessible()); 287 // Random-access iterator is traversed by coordinate, i.e., [curCrd, UB). 288 return {getCrd(), upperBound(b, l)}; 289 } 290 291 // Generates a bool value for scf::ConditionOp. 292 std::pair<Value, ValueRange> genWhileCond(OpBuilder &b, Location l, 293 ValueRange vs) { 294 ValueRange rem = linkNewScope(vs); 295 return std::make_pair(genNotEnd(b, l), rem); 296 } 297 298 // Generate a conditional it.next() in the following form 299 // 300 // if (cond) 301 // yield it.next 302 // else 303 // yield it 304 // 305 // The function is virtual to allow alternative implementation. For example, 306 // if it.next() is trivial to compute, we can use a select operation instead. 307 // E.g., 308 // 309 // it = select cond ? it+1 : it 310 virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond); 311 312 // Update the SSA value for the iterator after entering a new scope. 
313 ValueRange linkNewScope(ValueRange pos) { 314 assert(!randomAccessible() && "random accessible iterators are traversed " 315 "by coordinate, call locate() instead."); 316 seek(pos.take_front(cursorValsCnt)); 317 return pos.drop_front(cursorValsCnt); 318 }; 319 320 protected: 321 void updateCrd(Value crd) { this->crd = crd; } 322 323 MutableArrayRef<Value> getMutCursorVals() { 324 MutableArrayRef<Value> ref = cursorValsStorageRef; 325 return ref.take_front(cursorValsCnt); 326 } 327 328 void inherentBatch(const SparseIterator &parent) { 329 batchCrds = parent.batchCrds; 330 } 331 332 SparseEmitStrategy emitStrategy; 333 SmallVector<Value> batchCrds; 334 335 public: 336 const IterKind kind; // For LLVM-style RTTI. 337 const unsigned tid, lvl; // tensor level identifier. 338 339 private: 340 Value crd; // The sparse coordinate used to coiterate; 341 342 // A range of value that together defines the current state of the 343 // iterator. Only loop variants should be included. 344 // 345 // For trivial iterators, it is the position; for dedup iterators, it consists 346 // of the positon and the segment high, for non-empty subsection iterator, it 347 // is the metadata that specifies the subsection. 348 // Note that the wrapped iterator shares the same storage to maintain itVals 349 // with it wrapper, which means the wrapped iterator might only own a subset 350 // of all the values stored in itValStorage. 351 const unsigned cursorValsCnt; 352 SmallVectorImpl<Value> &cursorValsStorageRef; 353 }; 354 355 /// Helper function to create a TensorLevel object from given `tensor`. 356 std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b, 357 Location l, Value t, 358 unsigned tid, 359 Level lvl); 360 361 /// Helper function to create a TensorLevel object from given ValueRange. 
std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
                                                         ValueRange buffers,
                                                         unsigned tid, Level l);

/// Helper function to create a simple SparseIterator object that iterates
/// over the entire iteration space `iterSpace`.
std::unique_ptr<SparseIterator>
makeSimpleIterator(OpBuilder &b, Location l,
                   const SparseIterationSpace &iterSpace);

/// Helper function to create a simple SparseIterator object that iterates
/// over the sparse tensor level `stl`, using `strategy` (kFunctional by
/// default) as the emit strategy.
/// TODO: switch to `SparseIterationSpace` (which supports N-D iterators) when
/// feature complete.
std::unique_ptr<SparseIterator> makeSimpleIterator(
    const SparseTensorLevel &stl,
    SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);

/// Helper function to create a synthetic level, together with a SparseIterator
/// over it, that covers the dense space [0, `sz`).
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
                        SparseEmitStrategy strategy);

/// Helper function to create a SparseIterator object that iterates over a
/// sliced space (described by `offset`, `stride` and `size`); the original
/// space (before slicing) is traversed by `sit`. `sit` is taken by rvalue
/// reference — presumably ownership moves into the returned iterator; confirm
/// against the implementation.
std::unique_ptr<SparseIterator>
makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
                        Value stride, Value size, SparseEmitStrategy strategy);

/// Helper function to create a SparseIterator object that iterates over a
/// padded sparse level (the padded value must be zero). `padLow`/`padHigh`
/// are the padding amounts before/after the space traversed by `sit`.
std::unique_ptr<SparseIterator>
makePaddedIterator(std::unique_ptr<SparseIterator> &&sit, Value padLow,
                   Value padHigh, SparseEmitStrategy strategy);

/// Helper function to create a SparseIterator object that iterates over the
/// set of non-empty subsections.
std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
    OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
    std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
    SparseEmitStrategy strategy);

/// Helper function to create a SparseIterator object that iterates over a
/// non-empty subsection created by NonEmptySubSectIterator (`subsectIter`).
std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
    OpBuilder &b, Location l, const SparseIterator &subsectIter,
    const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
    Value loopBound, unsigned stride, SparseEmitStrategy strategy);

} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_