//===- SparseTensorIterator.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "SparseTensorIterator.h"
#include "CodegenUtils.h"

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::sparse_tensor;
using ValuePair = std::pair<Value, Value>;
using ValueTuple = std::tuple<Value, Value, Value>;

//===----------------------------------------------------------------------===//
// File local helper functions/macros.
//===----------------------------------------------------------------------===//
#define CMPI(p, lhs, rhs)                                                      \
  (b.create<arith::CmpIOp>(l, arith::CmpIPredicate::p, (lhs), (rhs))           \
       .getResult())

#define C_FALSE (constantI1(b, l, false))
#define C_TRUE (constantI1(b, l, true))
#define C_IDX(v) (constantIndex(b, l, (v)))
#define YIELD(vs) (b.create<scf::YieldOp>(l, (vs)))
#define ADDI(lhs, rhs) (b.create<arith::AddIOp>(l, (lhs), (rhs)).getResult())
#define ORI(lhs, rhs) (b.create<arith::OrIOp>(l, (lhs), (rhs)).getResult())
#define ANDI(lhs, rhs) (b.create<arith::AndIOp>(l, (lhs), (rhs)).getResult())
#define SUBI(lhs, rhs) (b.create<arith::SubIOp>(l, (lhs), (rhs)).getResult())
#define MULI(lhs, rhs) (b.create<arith::MulIOp>(l, (lhs), (rhs)).getResult())
#define MINUI(lhs, rhs) (b.create<arith::MinUIOp>(l, (lhs), (rhs)).getResult())
#define REMUI(lhs, rhs) (b.create<arith::RemUIOp>(l, (lhs), (rhs)).getResult())
#define DIVUI(lhs, rhs) (b.create<arith::DivUIOp>(l, (lhs), (rhs)).getResult())
#define SELECT(c, lhs, rhs)                                                    \
  (b.create<arith::SelectOp>(l, (c), (lhs), (rhs)).getResult())
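// For example (illustrative): with an OpBuilder `b` and a Location `l` in
// scope,
//   SELECT(CMPI(uge, minCrd, size), SUBI(minCrd, size), C_IDX(0))
// expands into the corresponding arith::CmpIOp, arith::SubIOp, and
// arith::SelectOp builder calls.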

//===----------------------------------------------------------------------===//
// SparseTensorLevel derived classes.
//===----------------------------------------------------------------------===//

namespace {

template <bool hasPosBuffer>
class SparseLevel : public SparseTensorLevel {
  // It is either an array of size 2 or size 1 depending on whether the sparse
  // level requires a position array.
  using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
                                     std::array<Value, 1>>;

public:
  SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
              BufferT buffers)
      : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}

  ValueRange getLvlBuffers() const override { return buffers; }

  Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                  Value iv) const override {
    SmallVector<Value> memCrd(batchPrefix);
    memCrd.push_back(iv);
    return genIndexLoad(b, l, getCrdBuf(), memCrd);
  }

protected:
  template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
  Value getPosBuf() const {
    return buffers[0];
  }

  Value getCrdBuf() const {
    if constexpr (hasPosBuffer)
      return buffers[1];
    else
      return buffers[0];
  }

  const BufferT buffers;
};

class DenseLevel : public SparseTensorLevel {
public:
  DenseLevel(unsigned tid, Level lvl, Value lvlSize)
      : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize) {}

  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
    llvm_unreachable("locate random-accessible level instead");
  }

  ValueRange getLvlBuffers() const override { return {}; }

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
    assert(!inPadZone && "Not implemented");
    Value p = parentPos.front();
    Value posLo = MULI(p, lvlSize);
    return {posLo, lvlSize};
  }
};

class BatchLevel : public SparseTensorLevel {
public:
  BatchLevel(unsigned tid, Level lvl, Value lvlSize)
      : SparseTensorLevel(tid, lvl, LevelFormat::Batch, lvlSize) {}

  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
    llvm_unreachable("locate random-accessible level instead");
  }

  ValueRange getLvlBuffers() const override { return {}; }

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(!inPadZone && "Not implemented");
    assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
    // No need to linearize the position for non-annotated tensors.
    return {C_IDX(0), lvlSize};
  }
};

class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
public:
  CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                  Value posBuffer, Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {

    assert(parentPos.size() == 1 &&
           "compressed level must be the first non-unique level.");

    auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
      Value p = parentPos.front();
      SmallVector<Value> memCrd(batchPrefix);
      memCrd.push_back(p);
      Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
      memCrd.back() = ADDI(p, C_IDX(1));
      Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
      return {pLo, pHi};
    };

    if (inPadZone == nullptr)
      return loadRange();

    SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
    scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
    // True branch, returns a "fake" empty range [0, 0) if parent
    // iterator is in pad zone.
    b.setInsertionPointToStart(posRangeIf.thenBlock());

    SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
    b.create<scf::YieldOp>(l, emptyRange);

    // False branch, returns the actual range.
    b.setInsertionPointToStart(posRangeIf.elseBlock());
    auto [pLo, pHi] = loadRange();
    SmallVector<Value, 2> loadedRange{pLo, pHi};
    b.create<scf::YieldOp>(l, loadedRange);

    b.setInsertionPointAfter(posRangeIf);
    ValueRange posRange = posRangeIf.getResults();
    return {posRange.front(), posRange.back()};
  }
};

class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
public:
  LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                       Value posBuffer, Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(parentPos.size() == 1 &&
           "loose-compressed level must be the first non-unique level.");
    assert(!inPadZone && "Not implemented");
    SmallVector<Value> memCrd(batchPrefix);
    Value p = parentPos.front();
    p = MULI(p, C_IDX(2));
    memCrd.push_back(p);
    Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
    memCrd.back() = ADDI(p, C_IDX(1));
    Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
    return {pLo, pHi};
  }
};

class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
public:
  SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                 Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(parentPos.size() == 1 || parentPos.size() == 2);
    assert(!inPadZone && "Not implemented");
    Value p = parentPos.front();
    Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;

    if (segHi == nullptr)
      return {p, ADDI(p, C_IDX(1))};
    // Use the segHi as the loop upper bound.
    return {p, segHi};
  }

  ValuePair
  collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
                       std::pair<Value, Value> parentRange) const override {
    // Singleton level keeps the same range after collapsing.
    return parentRange;
  };
};

class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
public:
  NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
               Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(parentPos.size() == 1 && isUnique() &&
           "n:m level can not be non-unique.");
    assert(!inPadZone && "Not implemented");
    // Each n:m blk has exactly n specified elements.
    auto n = getN(lt);
    Value posLo = MULI(parentPos.front(), C_IDX(n));
    return {posLo, ADDI(posLo, C_IDX(n))};
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// File local helpers
//===----------------------------------------------------------------------===//

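// Generates code along the lines of (illustrative sketch):
//
//   %res:n = scf.if %it.genNotEnd() -> (...) {
//     %crd = <deref it>
//     scf.yield <builder(b, l, %crd)>
//   } else {
//     scf.yield <elseRet>
//   }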
static scf::ValueVector genWhenInBound(
    OpBuilder &b, Location l, SparseIterator &it, ValueRange elseRet,
    llvm::function_ref<scf::ValueVector(OpBuilder &, Location, Value)>
        builder) {
  TypeRange ifRetTypes = elseRet.getTypes();
  auto ifOp = b.create<scf::IfOp>(l, ifRetTypes, it.genNotEnd(b, l), true);

  b.setInsertionPointToStart(ifOp.thenBlock());
  Value crd = it.deref(b, l);
  scf::ValueVector ret = builder(b, l, crd);
  YIELD(ret);

  b.setInsertionPointToStart(ifOp.elseBlock());
  YIELD(elseRet);

  b.setInsertionPointAfter(ifOp);
  return ifOp.getResults();
}

/// Generates code to compute the *absolute* offset of the slice based on the
/// provided minimum coordinates in the slice.
/// E.g., when reducing d0 + d1 + d2, we need two slices to fully reduce the
/// expression, i.e., s1 = slice(T, d0), s2 = slice(s1, d1). The *absolute*
/// offset is the offset computed relative to the initial tensor T.
///
/// When isNonEmpty == true, the computed offset is meaningless and should not
/// be used at runtime; in that case the method currently generates code that
/// returns 0.
///
/// offset = minCrd >= size ? minCrd - size + 1 : 0;
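/// For example (illustrative): with size = 5, minCrd = 7 yields
/// offset = 7 - 5 + 1 = 3, while any minCrd < 5 yields offset = 0.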
static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
                              Value size) {
  Value geSize = CMPI(uge, minCrd, size);
  // Compute minCrd - size + 1.
  Value mms = SUBI(ADDI(minCrd, C_IDX(1)), size);
  // This is the absolute offset relative to the actual tensor.
  return SELECT(geSize, mms, C_IDX(0));
}

//===----------------------------------------------------------------------===//
// SparseIterator derived classes.
//===----------------------------------------------------------------------===//

namespace {

// The iterator that traverses a concrete sparse tensor level. High-level
// abstract iterators wrap it to achieve more complex goals (such as collapsing
// several levels). It also owns the common storage that holds the mlir::Values
// for itself as well as for its wrappers.
class ConcreteIterator : public SparseIterator {
protected:
  ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
                   unsigned cursorValCnt)
      : SparseIterator(kind, stl.tid, stl.lvl, cursorValCnt, cursorValsStorage),
        stl(stl), cursorValsStorage(cursorValCnt, nullptr) {
    assert(getCursor().size() == cursorValCnt);
  };

public:
  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kTrivial;
  }

  bool isBatchIterator() const override {
    return stl.getLT().isa<LevelFormat::Batch>();
  }
  bool randomAccessible() const override {
    return stl.getLT().hasDenseSemantic();
  };
  bool iteratableByFor() const override { return kind != IterKind::kDedup; };
  Value upperBound(OpBuilder &b, Location l) const override {
    return stl.getSize();
  };

protected:
  const SparseTensorLevel &stl;
  // Owner of the storage; all wrappers built on top of a concrete iterator
  // share the same storage so that the iterator values are always
  // synchronized.
  SmallVector<Value> cursorValsStorage;
};

class TrivialIterator : public ConcreteIterator {
public:
  TrivialIterator(const SparseTensorLevel &stl)
      : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}

  TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
                  Value posLo, Value posHi)
      : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1), posLo(posLo),
        posHi(posHi) {
    seek(posLo);
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("trivial<") + stl.toString() + ">";
  }
  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    return {b.getIndexType()};
  }

  SmallVector<Value> serialize() const override {
    SmallVector<Value> ret;
    ret.push_back(getItPos());
    if (randomAccessible()) {
      // The loop upper bound is implicit (defined by `upperBound()`) for a
      // random-access iterator, but we need to remember posLo for
      // linearization.
      ret.push_back(posLo);
    } else {
      ret.push_back(posHi);
    }
    return ret;
  };

  void deserialize(ValueRange vs) override {
    assert(vs.size() == 2);
    seek(vs.front());
    if (randomAccessible())
      posLo = vs.back();
    else
      posHi = vs.back();
  };

  void genInitImpl(OpBuilder &b, Location l,
                   const SparseIterator *parent) override;

  ValuePair genForCond(OpBuilder &b, Location l) override {
    if (randomAccessible())
      return {deref(b, l), upperBound(b, l)};
    return std::make_pair(getItPos(), posHi);
  }

  Value genNotEndImpl(OpBuilder &b, Location l) override {
    // We use the first level's bound as the bound for the collapsed levels.
    return CMPI(ult, getItPos(), posHi);
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    if (randomAccessible()) {
      updateCrd(SUBI(getItPos(), posLo));
    } else {
      updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getItPos()));
    }
    return getCrd();
  };

  ValueRange forwardImpl(OpBuilder &b, Location l) override {
    seek(ADDI(getItPos(), C_IDX(1)));
    return getCursor();
  }

  ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override {
    Value curPos = getCursor().front();
    Value nxPos = forward(b, l).front();
    seek(SELECT(cond, nxPos, curPos));
    return getCursor();
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    assert(randomAccessible());
    // Seek to the linearized position.
    seek(ADDI(crd, posLo));
    updateCrd(crd);
    if (isBatchIterator()) {
      // If this is a batch iterator, also update the batch coordinate.
      assert(batchCrds.size() > lvl);
      batchCrds[lvl] = crd;
    }
  }

  Value getItPos() const { return getCursor().front(); }
  Value posLo, posHi;
};

class DedupIterator : public ConcreteIterator {
private:
  Value genSegmentHigh(OpBuilder &b, Location l, Value pos);

public:
  DedupIterator(const SparseTensorLevel &stl)
      : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
    assert(!stl.isUnique());
  }

  DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
                Value posLo, Value posHi)
      : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) {
    assert(!stl.isUnique());
    seek({posLo, genSegmentHigh(b, l, posLo)});
  }

  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kDedup;
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("dedup<") + stl.toString() + ">";
  }
  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    return {b.getIndexType(), b.getIndexType()};
  }

  void genInitImpl(OpBuilder &b, Location l,
                   const SparseIterator *parent) override {
    Value c0 = C_IDX(0);
    ValueRange pPos = c0;

    // If the parent iterator is a batch iterator, we also start from 0 (but
    // on a different batch).
    if (parent && !parent->isBatchIterator())
      pPos = parent->getCurPosition();

    Value posLo;
    ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);

    seek({posLo, genSegmentHigh(b, l, posLo)});
  }

  SmallVector<Value> serialize() const override {
    SmallVector<Value> ret;
    ret.append(getCursor().begin(), getCursor().end());
    ret.push_back(posHi);
    return ret;
  };
  void deserialize(ValueRange vs) override {
    assert(vs.size() == 3);
    seek(vs.take_front(getCursor().size()));
    posHi = vs.back();
  };

  Value genNotEndImpl(OpBuilder &b, Location l) override {
    return CMPI(ult, getPos(), posHi);
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getPos()));
    return getCrd();
  };

  ValueRange forwardImpl(OpBuilder &b, Location l) override {
    Value nxPos = getSegHi(); // forward the position to the next segment.
    seek({nxPos, genSegmentHigh(b, l, nxPos)});
    return getCursor();
  }

  Value getPos() const { return getCursor()[0]; }
  Value getSegHi() const { return getCursor()[1]; }

  Value posHi;
};

// A utility base iterator that delegates all methods to the wrapped iterator.
class SimpleWrapIterator : public SparseIterator {
public:
  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
                     unsigned extraCursorVal = 0)
      : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}

  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    return wrap->getCursorValTypes(b);
  }
  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
  bool randomAccessible() const override { return wrap->randomAccessible(); };
  bool iteratableByFor() const override { return wrap->iteratableByFor(); };

  SmallVector<Value> serialize() const override { return wrap->serialize(); };
  void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
  ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
  void genInitImpl(OpBuilder &b, Location l,
                   const SparseIterator *parent) override {
    wrap->genInit(b, l, parent);
  }
  Value genNotEndImpl(OpBuilder &b, Location l) override {
    return wrap->genNotEndImpl(b, l);
  }
  ValueRange forwardImpl(OpBuilder &b, Location l) override {
    return wrap->forward(b, l);
  };
  Value upperBound(OpBuilder &b, Location l) const override {
    return wrap->upperBound(b, l);
  };

  Value derefImpl(OpBuilder &b, Location l) override {
    return wrap->derefImpl(b, l);
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    return wrap->locate(b, l, crd);
  }

  SparseIterator &getWrappedIterator() const { return *wrap; }

protected:
  std::unique_ptr<SparseIterator> wrap;
};

//
// A filter iterator wrapped from another iterator. The filter iterator
// updates the wrapped iterator *in-place*.
//
class FilterIterator : public SimpleWrapIterator {
  // Coordinate translation between the crd loaded from the wrapped iterator
  // and the filter iterator's crd.
  Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) const {
    // crd = (wrapCrd - offset) / stride
    return DIVUI(SUBI(wrapCrd, offset), stride);
  }
  Value toWrapCrd(OpBuilder &b, Location l, Value crd) const {
    // wrapCrd = crd * stride + offset
    return ADDI(MULI(crd, stride), offset);
  }
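  // For example (illustrative): with offset = 2 and stride = 3, wrapCrd = 8
  // maps to crd = (8 - 2) / 3 = 2, and toWrapCrd(2) = 2 * 3 + 2 = 8.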

  Value genCrdNotLegitPredicate(OpBuilder &b, Location l, Value wrapCrd);

  Value genShouldFilter(OpBuilder &b, Location l);

public:
  // TODO: avoid unnecessary checks when offset == 0 and/or when stride == 1
  // and/or when crd always < size.
  FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset,
                 Value stride, Value size)
      : SimpleWrapIterator(std::move(wrap), IterKind::kFilter), offset(offset),
        stride(stride), size(size) {}

  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kFilter;
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("filter<") + wrap->getDebugInterfacePrefix() + ">";
  }

  bool iteratableByFor() const override { return randomAccessible(); };
  Value upperBound(OpBuilder &b, Location l) const override { return size; };

  void genInitImpl(OpBuilder &b, Location l,
                   const SparseIterator *parent) override {
    wrap->genInit(b, l, parent);
    if (!randomAccessible()) {
      // TODO: we can skip this when stride == 1 and offset == 0, we can also
      // use binary search here.
      forwardIf(b, l, genShouldFilter(b, l));
    } else {
      // Else, locate to the slice.offset, which is the first coordinate
      // included by the slice.
      wrap->locate(b, l, offset);
    }
  }

  Value genNotEndImpl(OpBuilder &b, Location l) override;

  Value derefImpl(OpBuilder &b, Location l) override {
    updateCrd(fromWrapCrd(b, l, wrap->deref(b, l)));
    return getCrd();
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    assert(randomAccessible());
    wrap->locate(b, l, toWrapCrd(b, l, crd));
    updateCrd(crd);
  }

  ValueRange forwardImpl(OpBuilder &b, Location l) override;

  Value offset, stride, size;
};

//
// A pad iterator wrapped from another iterator. The pad iterator updates
// the wrapped iterator *in-place*.
//
class PadIterator : public SimpleWrapIterator {

public:
  PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
              Value padHigh)
      : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
                           wrap->randomAccessible() ? 1 : 0),
        padLow(padLow), padHigh(padHigh) {}

  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kPad;
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
  }

  // Returns a pair of values for the *lower* and *upper* bound respectively.
  ValuePair genForCond(OpBuilder &b, Location l) override {
    if (randomAccessible())
      return {getCrd(), upperBound(b, l)};
    return wrap->genForCond(b, l);
  }

  // For a padded dense iterator, we append an `inPadZone: bool` value in
  // addition to the values used by the wrapped iterator.
  ValueRange getCurPosition() const override { return getCursor(); }

  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    SmallVector<Type> ret = wrap->getCursorValTypes(b);
    // Need an extra boolean value `inPadZone` for a padded dense iterator.
    if (randomAccessible())
      ret.push_back(b.getI1Type());

    return ret;
  }

  // The upper bound after padding becomes `size + padLow + padHigh`.
  Value upperBound(OpBuilder &b, Location l) const override {
    return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
  };
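  // For example (illustrative): with a wrapped size of 4, padLow = 1 and
  // padHigh = 2, the padded upper bound is 7, and coordinates 0 and 5..6
  // fall in the pad zone.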

  // The pad_coord = coord + pad_lo
  Value derefImpl(OpBuilder &b, Location l) override {
    updateCrd(ADDI(wrap->deref(b, l), padLow));
    return getCrd();
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    assert(randomAccessible());
    wrap->locate(b, l, SUBI(crd, padLow));

    // inPadZone = crd < padLow || crd >= size + padLow.
    Value inPadLow = CMPI(ult, crd, padLow);
    Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
    getMutCursorVals().back() = ORI(inPadLow, inPadHigh);

    updateCrd(crd);
  }

  Value padLow, padHigh;
};

class NonEmptySubSectIterator : public SparseIterator {
public:
  using TraverseBuilder = llvm::function_ref<scf::ValueVector(
      OpBuilder &, Location, const SparseIterator *, ValueRange)>;

  NonEmptySubSectIterator(OpBuilder &b, Location l,
                          const SparseIterator *parent,
                          std::unique_ptr<SparseIterator> &&delegate,
                          Value subSectSz)
      : SparseIterator(IterKind::kNonEmptySubSect, 3, subSectMeta, *delegate),
        parent(parent), delegate(std::move(delegate)),
        tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) {
    auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
    if (p == nullptr) {
      // Extract subsections along the root level.
      maxTupleCnt = C_IDX(1);
    } else if (p->lvl == lvl) {
      // Extract subsections along the same level.
      maxTupleCnt = p->maxTupleCnt;
      assert(false && "Not implemented.");
    } else {
      // Extract subsections along the previous level.
      assert(p->lvl + 1 == lvl);
      maxTupleCnt = MULI(p->maxTupleCnt, p->subSectSz);
    }
    // We don't need an extra buffer to find subsections on random-accessible
    // levels.
    if (randomAccessible())
      return;
    subSectPosBuf = allocSubSectPosBuf(b, l);
  }

  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kNonEmptySubSect;
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("ne_sub<") + delegate->getDebugInterfacePrefix() + ">";
  }
  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    // minCrd, absolute offset, notEnd
    return {b.getIndexType(), b.getIndexType(), b.getI1Type()};
  }

  // The sliced pointer buffer is organized as:
  //     [[itVal0, itVal1, ..., pNx0],
  //      [itVal0, itVal1, ..., pNx0],
  //      ...]
  Value allocSubSectPosBuf(OpBuilder &b, Location l) {
    return b.create<memref::AllocaOp>(
        l,
        MemRefType::get({ShapedType::kDynamic, tupleSz + 1}, b.getIndexType()),
        maxTupleCnt);
  }
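  // For example (illustrative): with tupleSz == 2, each row of the buffer
  // holds [itVal0, itVal1, pNx0], i.e., the serialized cursor of the wrapped
  // iterator followed by the starting tupleId of the next level (see
  // storeNxLvlStart/loadNxLvlStart below).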

  void storeNxLvlStart(OpBuilder &b, Location l, Value tupleId,
                       Value start) const {
    b.create<memref::StoreOp>(l, start, subSectPosBuf,
                              ValueRange{tupleId, C_IDX(tupleSz)});
  }

  Value loadNxLvlStart(OpBuilder &b, Location l, Value tupleId) const {
    return b.create<memref::LoadOp>(l, subSectPosBuf,
                                    ValueRange{tupleId, C_IDX(tupleSz)});
  }

  void storeCursorVals(OpBuilder &b, Location l, Value tupleId,
                       ValueRange itVals) const {
    assert(itVals.size() == tupleSz);
    for (unsigned i = 0; i < tupleSz; i++) {
      b.create<memref::StoreOp>(l, itVals[i], subSectPosBuf,
                                ValueRange{tupleId, C_IDX(i)});
    }
  }

  SmallVector<Value> loadCursorVals(OpBuilder &b, Location l,
                                    Value tupleId) const {
    SmallVector<Value> ret;
    for (unsigned i = 0; i < tupleSz; i++) {
      Value v = b.create<memref::LoadOp>(l, subSectPosBuf,
                                         ValueRange{tupleId, C_IDX(i)});
      ret.push_back(v);
    }
    return ret;
  }

  bool isSubSectRoot() const {
    return !parent || !llvm::isa<NonEmptySubSectIterator>(parent);
  }

  // Generates code that inflates the current subsection tree up to the
  // current level such that every leaf node is visited.
  ValueRange inflateSubSectTree(OpBuilder &b, Location l, ValueRange reduc,
                                TraverseBuilder builder) const;

  bool isBatchIterator() const override { return delegate->isBatchIterator(); }
  bool randomAccessible() const override {
    return delegate->randomAccessible();
  };
  bool iteratableByFor() const override { return randomAccessible(); };
  Value upperBound(OpBuilder &b, Location l) const override {
    auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
    Value parentUB =
        p && p->lvl == lvl ? p->upperBound(b, l) : delegate->upperBound(b, l);
    return ADDI(SUBI(parentUB, subSectSz), C_IDX(1));
  };

  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override;

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    Value absOff = crd;

    if (isSubSectRoot())
      delegate->locate(b, l, absOff);
    else
      assert(parent->lvl + 1 == lvl);

    seek(ValueRange{absOff, absOff, C_TRUE});
    updateCrd(crd);
  }

  Value toSubSectCrd(OpBuilder &b, Location l, Value wrapCrd) const {
    return SUBI(wrapCrd, getAbsOff());
  }

  Value genNotEndImpl(OpBuilder &b, Location l) override {
    return getNotEnd();
  };

  Value derefImpl(OpBuilder &b, Location l) override {
    // Use the relative offset to coiterate.
    Value crd;
    auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
    if (p && p->lvl == lvl)
      crd = SUBI(getAbsOff(), p->getAbsOff());
    crd = getAbsOff();

    updateCrd(crd);
    return crd;
  };

  ValueRange forwardImpl(OpBuilder &b, Location l) override;

  Value getMinCrd() const { return subSectMeta[0]; }
  Value getAbsOff() const { return subSectMeta[1]; }
  Value getNotEnd() const { return subSectMeta[2]; }

  const SparseIterator *parent;
  std::unique_ptr<SparseIterator> delegate;

  // Number of values required to serialize the wrapped iterator.
  const unsigned tupleSz;
  // Max number of tuples, and the actual number of tuples.
  Value maxTupleCnt, tupleCnt;
  // The memory used to cache the tuple serialized from the wrapped iterator.
  Value subSectPosBuf;

  const Value subSectSz;

  // minCrd, absolute offset, notEnd
  SmallVector<Value, 3> subSectMeta{nullptr, nullptr, nullptr};
};

class SubSectIterator;

// A wrapper that helps generate code to traverse a subsection, used by both
// `NonEmptySubSectIterator` and `SubSectIterator`.
struct SubSectIterHelper {
  explicit SubSectIterHelper(const SubSectIterator &iter);
  explicit SubSectIterHelper(const NonEmptySubSectIterator &subSect);

  // Delegate methods.
  void deserializeFromTupleId(OpBuilder &b, Location l, Value tupleId);
  void locate(OpBuilder &b, Location l, Value crd);
  Value genNotEnd(OpBuilder &b, Location l);
  Value deref(OpBuilder &b, Location l);
  ValueRange forward(OpBuilder &b, Location l);

  const NonEmptySubSectIterator &subSect;
  SparseIterator &wrap;
};

class SubSectIterator : public SparseIterator {
public:
  SubSectIterator(const NonEmptySubSectIterator &subSect,
                  const SparseIterator &parent,
                  std::unique_ptr<SparseIterator> &&wrap)
      : SparseIterator(IterKind::kSubSect, *wrap,
                       /*extraCursorCnt=*/wrap->randomAccessible() ? 0 : 1),
        subSect(subSect), wrap(std::move(wrap)), parent(parent), helper(*this) {
    assert(subSect.tid == tid && subSect.lvl == lvl);
    assert(parent.kind != IterKind::kSubSect || parent.lvl + 1 == lvl);
  };

  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kSubSect;
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("subsect<") + wrap->getDebugInterfacePrefix() + ">";
  }
  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    SmallVector<Type> ret = wrap->getCursorValTypes(b);
    if (!randomAccessible())
      ret.push_back(b.getIndexType()); // The extra counter.
    return ret;
  }

  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
  bool randomAccessible() const override { return wrap->randomAccessible(); };
  bool iteratableByFor() const override { return randomAccessible(); };
  Value upperBound(OpBuilder &b, Location l) const override {
    return subSect.subSectSz;
  }

  ValueRange getCurPosition() const override { return wrap->getCurPosition(); };

  Value getNxLvlTupleId(OpBuilder &b, Location l) const {
    if (randomAccessible()) {
      return ADDI(getCrd(), nxLvlTupleStart);
    };
    return ADDI(getCursor().back(), nxLvlTupleStart);
  }

  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override {
    if (randomAccessible()) {
      if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
        assert(p->lvl + 1 == lvl);
        wrap->genInit(b, l, p);
        // Linearize the dense subsection index.
        nxLvlTupleStart = MULI(subSect.subSectSz, p->getNxLvlTupleId(b, l));
      } else {
        assert(subSect.lvl == lvl && subSect.isSubSectRoot());
        wrap->deserialize(subSect.delegate->serialize());
        nxLvlTupleStart = C_IDX(0);
      }
      return;
    }
    assert(!randomAccessible());
    assert(getCursor().size() == wrap->getCursor().size() + 1);
    // Extra counter that counts the number of actually visited coordinates in
    // the sparse subsection.
    getMutCursorVals().back() = C_IDX(0);
    Value tupleId;
    if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
      assert(p->lvl + 1 == lvl);
      tupleId = p->getNxLvlTupleId(b, l);
    } else {
      assert(subSect.lvl == lvl && subSect.isSubSectRoot());
      tupleId = C_IDX(0);
    }
    nxLvlTupleStart = subSect.loadNxLvlStart(b, l, tupleId);
    helper.deserializeFromTupleId(b, l, tupleId);
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    helper.locate(b, l, crd);
    updateCrd(crd);
  }

  Value genNotEndImpl(OpBuilder &b, Location l) override {
    return helper.genNotEnd(b, l);
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    Value crd = helper.deref(b, l);
    updateCrd(crd);
    return crd;
  };

  ValueRange forwardImpl(OpBuilder &b, Location l) override {
    helper.forward(b, l);
    assert(!randomAccessible());
    assert(getCursor().size() == wrap->getCursor().size() + 1);
    getMutCursorVals().back() = ADDI(getCursor().back(), C_IDX(1));
    return getCursor();
  };

  Value nxLvlTupleStart;

  const NonEmptySubSectIterator &subSect;
  std::unique_ptr<SparseIterator> wrap;
  const SparseIterator &parent;

  SubSectIterHelper helper;
};

} // namespace

//===----------------------------------------------------------------------===//
// SparseIterator derived classes implementation.
//===----------------------------------------------------------------------===//

void SparseIterator::genInit(OpBuilder &b, Location l,
                             const SparseIterator *p) {
  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
    std::string prefix = getDebugInterfacePrefix();
    Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {},
                                getCursorValTypes(b));
    seek(begin->getResults());
    return;
  }
  // Inherit batch coordinates from the parent.
  if (p)
    inherentBatch(*p);
  // TODO: support lowering to function call.
  return genInitImpl(b, l, p);
}

Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
    std::string prefix = getDebugInterfacePrefix();
    Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"),
                                 getCursor(), b.getI1Type());
    return notEnd->getResult(0);
  }
  // TODO: support lowering to function call.
  return genNotEndImpl(b, l);
}

void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
    std::string prefix = getDebugInterfacePrefix();
    SmallVector<Value> args = getCursor();
    args.push_back(crd);
    Operation *locate = b.create(l, b.getStringAttr(prefix + ".locate"), args,
                                 getCursorValTypes(b));
    seek(locate->getResults());
    updateCrd(crd);
    return;
  }
  return locateImpl(b, l, crd);
}

Value SparseIterator::deref(OpBuilder &b, Location l) {
  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
    std::string prefix = getDebugInterfacePrefix();
    SmallVector<Value> args = getCursor();
    Operation *deref = b.create(l, b.getStringAttr(prefix + ".deref"),
                                getCursor(), b.getIndexType());
    updateCrd(deref->getResult(0));
    return getCrd();
  }
  return derefImpl(b, l);
}

ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
  assert(!randomAccessible());
  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
    std::string prefix = getDebugInterfacePrefix();
    Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
                               getCursor(), getCursorValTypes(b));
    seek(next->getResults());
    return getCursor();
  }
  return forwardImpl(b, l);
}

ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) {
  auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), cond, true);
  // Generate else branch first, otherwise iterator values will be updated by
  // `forward()`.
  b.setInsertionPointToStart(ifOp.elseBlock());
  YIELD(getCursor());

  b.setInsertionPointToStart(ifOp.thenBlock());
  YIELD(forward(b, l));

  b.setInsertionPointAfter(ifOp);
  seek(ifOp.getResults());
  return getCursor();
}

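// Generates a loop that computes the segment high for position `pos`, i.e.,
// (illustrative) for a non-unique level with coordinates [2, 2, 2, 5] and
// pos == 0 the generated code yields 3, the first position whose coordinate
// differs from the one at `pos`.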
Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
  auto whileOp = b.create<scf::WhileOp>(
      l, pos.getType(), pos,
      /*beforeBuilder=*/
      [this, pos](OpBuilder &b, Location l, ValueRange ivs) {
        Value inBound = CMPI(ult, ivs.front(), posHi);
        auto ifInBound = b.create<scf::IfOp>(l, b.getI1Type(), inBound, true);
        {
          OpBuilder::InsertionGuard guard(b);
          // If in bound, load the next coordinates and check duplication.
          b.setInsertionPointToStart(ifInBound.thenBlock());
          Value headCrd = stl.peekCrdAt(b, l, getBatchCrds(), pos);
          Value tailCrd = stl.peekCrdAt(b, l, getBatchCrds(), ivs.front());
          Value isDup = CMPI(eq, headCrd, tailCrd);
          YIELD(isDup);
          // Else, the position is out of bound, yield false.
          b.setInsertionPointToStart(ifInBound.elseBlock());
          YIELD(constantI1(b, l, false));
        }
        b.create<scf::ConditionOp>(l, ifInBound.getResults()[0], ivs);
      },
      /*afterBuilder=*/
      [](OpBuilder &b, Location l, ValueRange ivs) {
        Value nxPos = ADDI(ivs[0], C_IDX(1));
        YIELD(nxPos);
      });
  // Return the segment high.
  return whileOp.getResult(0);
}

Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l,
                                              Value wrapCrd) {
  Value crd = fromWrapCrd(b, l, wrapCrd);
  // Test whether the coordinate is on stride.
  Value notlegit = CMPI(ne, toWrapCrd(b, l, crd), wrapCrd);
  // Test wrapCrd < offset
  notlegit = ORI(CMPI(ult, wrapCrd, offset), notlegit);
  // Test crd >= length
  notlegit = ORI(CMPI(uge, crd, size), notlegit);
  return notlegit;
}

Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
  auto r = genWhenInBound(
      b, l, *wrap, C_FALSE,
      [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
        Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
        return {notLegit};
      });

  assert(r.size() == 1);
  return r.front();
}

Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
  assert(!wrap->randomAccessible());
  auto r = genWhenInBound(
      b, l, *wrap, C_FALSE,
      [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
        Value crd = fromWrapCrd(b, l, wrapCrd);
        // crd < size
        return {CMPI(ult, crd, size)};
      });
  assert(r.size() == 1);
  return r.front();
}

ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
  assert(!randomAccessible());
  // Generates
  //
  // bool isFirst = true;
  // while !it.end() && (!legit(*it) || isFirst)
  //   wrap ++;
  //   isFirst = false;
  //
  // We do not hoist the first `wrap++` outside the loop but use an `isFirst`
  // flag here because `wrap++` might have a complex implementation (e.g., to
  // forward a subsection).
  Value isFirst = constantI1(b, l, true);

  SmallVector<Value> whileArgs(getCursor().begin(), getCursor().end());
  whileArgs.push_back(isFirst);
  auto whileOp = b.create<scf::WhileOp>(
      l, ValueRange(whileArgs).getTypes(), whileArgs,
      /*beforeBuilder=*/
      [this](OpBuilder &b, Location l, ValueRange ivs) {
        ValueRange isFirst = linkNewScope(ivs);
        assert(isFirst.size() == 1);
        scf::ValueVector cont =
            genWhenInBound(b, l, *wrap, C_FALSE,
                           [this, isFirst](OpBuilder &b, Location l,
                                           Value wrapCrd) -> scf::ValueVector {
                             // crd < size && !legit();
                             Value notLegit =
                                 genCrdNotLegitPredicate(b, l, wrapCrd);
                             Value crd = fromWrapCrd(b, l, wrapCrd);
                             Value ret = ANDI(CMPI(ult, crd, size), notLegit);
                             ret = ORI(ret, isFirst.front());
                             return {ret};
                           });
        b.create<scf::ConditionOp>(l, cont.front(), ivs);
      },
      /*afterBuilder=*/
      [this](OpBuilder &b, Location l, ValueRange ivs) {
        linkNewScope(ivs);
        wrap->forward(b, l);
        SmallVector<Value> yieldVals(getCursor().begin(), getCursor().end());
        yieldVals.push_back(constantI1(b, l, false));
        YIELD(yieldVals);
      });

  b.setInsertionPointAfter(whileOp);
  linkNewScope(whileOp.getResults());
  return getCursor();
}

SubSectIterHelper::SubSectIterHelper(const NonEmptySubSectIterator &subSect)
    : subSect(subSect), wrap(*subSect.delegate) {}

SubSectIterHelper::SubSectIterHelper(const SubSectIterator &iter)
    : subSect(iter.subSect), wrap(*iter.wrap) {}

void SubSectIterHelper::deserializeFromTupleId(OpBuilder &b, Location l,
                                               Value tupleId) {
  assert(!subSect.randomAccessible());
  wrap.deserialize(subSect.loadCursorVals(b, l, tupleId));
}

void SubSectIterHelper::locate(OpBuilder &b, Location l, Value crd) {
  Value absCrd = ADDI(crd, subSect.getAbsOff());
  wrap.locate(b, l, absCrd);
}

Value SubSectIterHelper::genNotEnd(OpBuilder &b, Location l) {
  assert(!wrap.randomAccessible());
  auto r = genWhenInBound(
      b, l, wrap, C_FALSE,
      [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
        Value crd = SUBI(wrapCrd, subSect.getAbsOff());
        // crd < size
        return {CMPI(ult, crd, subSect.subSectSz)};
      });
  assert(r.size() == 1);
  return r.front();
}

Value SubSectIterHelper::deref(OpBuilder &b, Location l) {
  Value wrapCrd = wrap.deref(b, l);
  Value crd = subSect.toSubSectCrd(b, l, wrapCrd);
  return crd;
}

ValueRange SubSectIterHelper::forward(OpBuilder &b, Location l) {
  return wrap.forward(b, l);
}

inflateSubSectTree(OpBuilder & b,Location l,ValueRange reduc,TraverseBuilder builder) const1217 ValueRange NonEmptySubSectIterator::inflateSubSectTree(
1218     OpBuilder &b, Location l, ValueRange reduc, TraverseBuilder builder) const {
1219   // Set up the helper to help traverse a sparse subsection.
1220   SubSectIterHelper helper(*this);
1221   if (!randomAccessible()) {
1222     // The subsection tree have been expanded till the level and cached,
1223     // traverse all the leaves and expanded to the next level.
1224     SmallVector<Value> iterArgs;
1225     iterArgs.push_back(C_IDX(0));
1226     iterArgs.append(reduc.begin(), reduc.end());
1227     auto forEachLeaf = b.create<scf::ForOp>(
1228         l, /*lb=*/C_IDX(0), /*ub=*/tupleCnt, /*step=*/C_IDX(1), iterArgs,
1229         [&helper, &builder](OpBuilder &b, Location l, Value tupleId,
1230                             ValueRange iterArgs) {
1231           // Deserialize the iterator at the cached position (tupleId).
1232           helper.deserializeFromTupleId(b, l, tupleId);
1233 
1234           Value cnt = iterArgs.front();
1235           // Record the number of leaf nodes included in the subsection.
1236           // The number indicates the starting tupleId for the next level that
1237           // is corresponding to the current node.
1238           helper.subSect.storeNxLvlStart(b, l, tupleId, cnt);
1239 
1240           SmallVector<Value> whileArgs(helper.wrap.getCursor());
1241           whileArgs.append(iterArgs.begin(), iterArgs.end());
1242 
1243           auto whileOp = b.create<scf::WhileOp>(
1244               l, ValueRange(whileArgs).getTypes(), whileArgs,
1245               /*beforeBuilder=*/
1246               [&helper](OpBuilder &b, Location l, ValueRange ivs) {
1247                 helper.wrap.linkNewScope(ivs);
1248                 b.create<scf::ConditionOp>(l, helper.genNotEnd(b, l), ivs);
1249               },
1250               /*afterBuilder=*/
1251               [&helper, &builder](OpBuilder &b, Location l, ValueRange ivs) {
1252                 ValueRange remIter = helper.wrap.linkNewScope(ivs);
1253                 Value cnt = remIter.front();
1254                 ValueRange userIter = remIter.drop_front();
1255                 scf::ValueVector userNx = builder(b, l, &helper.wrap, userIter);
1256 
1257                 SmallVector<Value> nxIter = helper.forward(b, l);
1258                 nxIter.push_back(ADDI(cnt, C_IDX(1)));
1259                 nxIter.append(userNx.begin(), userNx.end());
1260                 YIELD(nxIter);
1261               });
1262           ValueRange res = helper.wrap.linkNewScope(whileOp.getResults());
1263           YIELD(res);
1264         });
1265     return forEachLeaf.getResults().drop_front();
1266   }
1267 
1268   assert(randomAccessible());
1269   // Helper lambda that traverses the current dense subsection range.
1270   auto visitDenseSubSect = [&, this](OpBuilder &b, Location l,
1271                                      const SparseIterator *parent,
1272                                      ValueRange reduc) {
1273     assert(!parent || parent->lvl + 1 == lvl);
1274     delegate->genInit(b, l, parent);
1275     auto forOp = b.create<scf::ForOp>(
1276         l, /*lb=*/C_IDX(0), /*ub=*/subSectSz, /*step=*/C_IDX(1), reduc,
1277         [&](OpBuilder &b, Location l, Value crd, ValueRange iterArgs) {
1278           helper.locate(b, l, crd);
1279           scf::ValueVector nx = builder(b, l, &helper.wrap, iterArgs);
1280           YIELD(nx);
1281         });
1282     return forOp.getResults();
1283   };
1284 
1285   if (isSubSectRoot()) {
1286     return visitDenseSubSect(b, l, parent, reduc);
1287   }
1288   // Else, this is not the root, recurse until root.
1289   auto *p = llvm::cast<NonEmptySubSectIterator>(parent);
1290   assert(p->lvl + 1 == lvl);
1291   return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
1292 }
1293 
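// Computes the position range [posLo, posHi) of this level from the parent's
// current position (or from 0 when there is no parent or the parent is a batch
// iterator) and seeks to the lower bound. A random-accessible padded parent
// additionally passes down an "in pad zone" flag.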
1294 void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
1295                                   const SparseIterator *parent) {
1296 
1297   if (isBatchIterator() && batchCrds.size() <= stl.lvl)
1298     batchCrds.resize(stl.lvl + 1, nullptr);
1299 
1300   Value c0 = C_IDX(0);
1301   ValueRange pPos = c0;
1302   Value inPadZone = nullptr;
1303   // If the parent iterator is a batch iterator, we also start from 0 (but
1304   // on a different batch).
1305   if (parent && !parent->isBatchIterator()) {
1306     pPos = parent->getCurPosition();
1307     if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
1308       // A padded dense iterator creates a "sparse" padded zone, which needs to
1309       // be handled specially.
1310       inPadZone = pPos.back();
1311       pPos = pPos.drop_back();
1312     }
1313   }
1314 
1315   ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
1316   std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
1317   // Seek to the lowest position.
1318   seek(posLo);
1319 }
1320 
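// Initializes the non-empty subsection iterator. For a non-root level the
// parent's subsection tree is expanded first so that this level's ranges can be
// cached, and the minimal coordinate across the cached ranges determines the
// initial absolute offset. For the root level the wrapped iterator is
// initialized directly and a single cached range suffices.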
1321 void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
1322                                           const SparseIterator *) {
1323   Value c0 = C_IDX(0);
1324   if (!isSubSectRoot()) {
1325     assert(parent->lvl + 1 == lvl);
1326     if (randomAccessible()) {
1327       // We cannot call wrap->genInit() here to initialize the wrapped
1328       // iterator, because the parent of the current iterator is still
1329       // unresolved.
1330       seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
1331       return;
1332     }
1333 
1334     auto *p = cast<NonEmptySubSectIterator>(parent);
1335     SmallVector<Value, 3> reduc = {
1336         C_IDX(-1), // minCrd (max signless integer)
1337         c0,        // tupleId
1338     };
1339 
1340     // Expand the subsection tree from the parent level to the current level.
1341     ValueRange result = p->inflateSubSectTree(
1342         b, l, reduc,
1343         [this](OpBuilder &b, Location l, const SparseIterator *parent,
1344                ValueRange reduc) -> scf::ValueVector {
1345           assert(parent->lvl + 1 == lvl && reduc.size() == 2);
1346           Value minCrd = reduc.front();
1347           Value tupleId = reduc.back();
1348 
1349           // Initialize the subsection range.
1350           SubSectIterHelper helper(*this);
1351           helper.wrap.genInit(b, l, parent);
1352 
1353           // Update minCrd.
1354           minCrd = genWhenInBound(b, l, helper.wrap, minCrd,
1355                                   [minCrd](OpBuilder &b, Location l,
1356                                            Value crd) -> scf::ValueVector {
1357                                     Value min = MINUI(crd, minCrd);
1358                                     return {min};
1359                                   })
1360                        .front();
1361 
1362           // Cache the sparse range.
1363           storeCursorVals(b, l, tupleId, helper.wrap.serialize());
1364           tupleId = ADDI(tupleId, C_IDX(1));
1365           return {minCrd, tupleId};
1366         });
1367     assert(result.size() == 2);
1368     tupleCnt = result.back();
1369 
1370     Value minCrd = result.front();
1371     Value absOff = offsetFromMinCrd(b, l, minCrd, subSectSz);
1372     Value notEnd = CMPI(ne, minCrd, C_IDX(-1));
1373     seek({minCrd, absOff, notEnd});
1374     return;
1375   }
1376 
1377   // This is the root level of the subsection, which means that it is resolved
1378   // to one node.
1379   assert(isSubSectRoot());
1380 
1381   // Initialize the position; it marks the *lower bound* of the
1382   // subRange. The upper bound is determined by the size of the subsection.
1383   delegate->genInit(b, l, parent);
1384   if (randomAccessible()) {
1385     seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
1386     return;
1387   }
1388 
1389   // Only have one root node.
1390   tupleCnt = C_IDX(1);
1391   // Cache the sparse range.
1392   storeCursorVals(b, l, c0, delegate->serialize());
1393   SmallVector<Value> elseRet{c0, c0, /*notEnd=*/C_FALSE};
1394   auto meta = genWhenInBound(
1395       b, l, *delegate, elseRet,
1396       [this](OpBuilder &b, Location l, Value crd) -> scf::ValueVector {
1397         Value offset = offsetFromMinCrd(b, l, crd, subSectSz);
1398         return {crd, offset, C_TRUE};
1399       });
1400 
1401   seek(meta);
1402 }
1403 
1404 ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
1405   assert(!randomAccessible());
1406   Value c0 = C_IDX(0), c1 = C_IDX(1);
1407   // Forward to the next non-empty slice by generating
1408   //
1409   // if (minCrd > offset) {
1410   //   offset += 1
1411   // } else {
1412   //    minCrd = nextMinInSlice();
1413   //    offset = minCrd - size + 1;
1414   // }
1415   //
1416   // if (offset + size > parents.size)
1417   //   isNonEmpty = false;
1418   Value fastPathP = CMPI(ugt, getMinCrd(), getAbsOff());
1419   auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), fastPathP, true);
1420   {
1421     OpBuilder::InsertionGuard guard(b);
1422     // Take the fast path
1423     // if (minCrd > offset)
1424     //   offset += 1
1425     b.setInsertionPointToStart(&ifOp.getThenRegion().front());
1426     Value nxOffset = ADDI(getAbsOff(), c1);
1427     YIELD((ValueRange{getMinCrd(), nxOffset, getNotEnd()}));
1428 
1429     // else /*minCrd == offset*/ {
1430     //    for (i = 0; i < tupleCnt; i++) {
1431     //       wrap->deserialize(pos[i]);
1432     //       minCrd=min(minCrd, *wrap);
1433     //    }
1434     //    offset = minCrd - size + 1;
1435     // }
1436     b.setInsertionPointToStart(&ifOp.getElseRegion().front());
1437     SmallVector<Value, 2> loopArgs{C_IDX(-1), // nextMinCrd
1438                                    C_FALSE};  // isNotEnd
1439     auto loopNest = scf::buildLoopNest(
1440         b, l, c0, tupleCnt, c1, loopArgs,
1441         [this](OpBuilder &b, Location l, ValueRange ivs,
1442                ValueRange iterArgs) -> scf::ValueVector {
1443           Value tupleId = ivs.front();
1444           SubSectIterHelper helper(*this);
1445           helper.deserializeFromTupleId(b, l, tupleId);
1446 
1447           return genWhenInBound(
1448               b, l, *delegate, /*elseRet=*/iterArgs,
1449               [this, iterArgs, tupleId](OpBuilder &b, Location l,
1450                                         Value crd) -> scf::ValueVector {
1451                 // if coord == minCrd
1452                 //   wrap->forward();
1453                 Value isMin = CMPI(eq, crd, getMinCrd());
1454                 delegate->forwardIf(b, l, isMin);
1455                 // Update the forwarded iterator values if needed.
1456                 auto ifIsMin = b.create<scf::IfOp>(l, isMin, false);
1457                 b.setInsertionPointToStart(&ifIsMin.getThenRegion().front());
1458                 storeCursorVals(b, l, tupleId, delegate->serialize());
1459                 b.setInsertionPointAfter(ifIsMin);
1460                 // if (!wrap.end())
1461                 //  yield(min(nxMinCrd, *wrap), true)
1462                 Value nxMin = iterArgs[0];
1463                 return genWhenInBound(b, l, *delegate, /*elseRet=*/iterArgs,
1464                                       [nxMin](OpBuilder &b, Location l,
1465                                               Value crd) -> scf::ValueVector {
1466                                         Value nx = b.create<arith::MinUIOp>(
1467                                             l, crd, nxMin);
1468                                         return {nx, C_TRUE};
1469                                       });
1470               });
1471         });
1472 
1473     scf::ForOp forOp = loopNest.loops.front();
1474     b.setInsertionPointAfter(forOp);
1475 
1476     Value nxMinCrd = forOp.getResult(0);
1477     Value nxNotEnd = forOp.getResult(1);
1478     Value nxAbsOff = offsetFromMinCrd(b, l, nxMinCrd, subSectSz);
1479     YIELD((ValueRange{nxMinCrd, nxAbsOff, nxNotEnd}));
1480   }
1481 
1482   Value nxMinCrd = ifOp.getResult(0);
1483   Value nxAbsOff = ifOp.getResult(1);
1484   Value nxNotEnd = ifOp.getResult(2);
1485 
1486   // We should forward the offset by at least one.
1487   Value minAbsOff = ADDI(getAbsOff(), c1);
1488   nxAbsOff = b.create<arith::MaxUIOp>(l, minAbsOff, nxAbsOff);
1489 
1490   seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
1491   // The coordinate should not exceed the space's upper bound.
1492   Value crd = deref(b, l);
1493   nxNotEnd = ANDI(nxNotEnd, CMPI(ult, crd, upperBound(b, l)));
1494 
1495   seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
1496   return getCursor();
1497 }
1498 
1499 //===----------------------------------------------------------------------===//
1500 // SparseIterationSpace Implementation
1501 //===----------------------------------------------------------------------===//
1502 
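// Constructs an iteration space over levels [lvlLo, lvlHi) of tensor `t` by
// materializing a SparseTensorLevel for each level and collapsing their
// position ranges into a single [lo, hi) bound rooted at `parentPos`.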
1503 mlir::sparse_tensor::SparseIterationSpace::SparseIterationSpace(
1504     Location l, OpBuilder &b, Value t, unsigned tid,
1505     std::pair<Level, Level> lvlRange, ValueRange parentPos)
1506     : lvls() {
1507   auto [lvlLo, lvlHi] = lvlRange;
1508 
1509   Value c0 = C_IDX(0);
1510   if (parentPos.empty())
1511     parentPos = c0;
1512 
1513   for (Level lvl = lvlLo; lvl < lvlHi; lvl++)
1514     lvls.emplace_back(makeSparseTensorLevel(b, l, t, tid, lvl));
1515 
1516   bound = lvls.front()->peekRangeAt(b, l, /*batchPrefix=*/{}, parentPos);
1517   for (auto &lvl : getLvlRef().drop_front())
1518     bound = lvl->collapseRangeBetween(b, l, /*batchPrefix=*/{}, bound);
1519 }
1520 
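// Rebuilds an iteration space from its flattened SSA representation: for each
// level, the level buffers (position and/or coordinate, as dictated by the
// level type) followed by the level size, and finally the two loop bounds,
// e.g., a dense-then-compressed space is laid out as
//   [sz0, posBuf1, crdBuf1, sz1, boundLo, boundHi].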
1521 SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
1522     IterSpaceType dstTp, ValueRange values, unsigned int tid) {
1523   // Reconstruct every sparse tensor level.
1524   SparseIterationSpace space;
1525   for (auto [i, lt] : llvm::enumerate(dstTp.getLvlTypes())) {
1526     unsigned bufferCnt = 0;
1527     if (lt.isWithPosLT())
1528       bufferCnt++;
1529     if (lt.isWithCrdLT())
1530       bufferCnt++;
1531     // Sparse tensor buffers.
1532     ValueRange buffers = values.take_front(bufferCnt);
1533     values = values.drop_front(bufferCnt);
1534 
1535     // Level size.
1536     Value sz = values.front();
1537     values = values.drop_front();
1538     space.lvls.push_back(
1539         makeSparseTensorLevel(lt, sz, buffers, tid, i + dstTp.getLoLvl()));
1540   }
1541   // Two bounds.
1542   space.bound = std::make_pair(values[0], values[1]);
1543   values = values.drop_front(2);
1544 
1545   // Must have consumed all values.
1546   assert(values.empty());
1547   return space;
1548 }
1549 
1550 std::unique_ptr<SparseIterator>
1551 SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const {
1552   return makeSimpleIterator(b, l, *this);
1553 }
1554 
1555 //===----------------------------------------------------------------------===//
1556 // SparseIterator factory functions.
1557 //===----------------------------------------------------------------------===//
1558 
1559 /// Helper function to create a TensorLevel object from the given `tensor`.
1560 std::unique_ptr<SparseTensorLevel>
1561 sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b,
1562                                      unsigned t, Level l) {
1563   assert(lt.getNumBuffer() == b.size());
1564   switch (lt.getLvlFmt()) {
1565   case LevelFormat::Dense:
1566     return std::make_unique<DenseLevel>(t, l, sz);
1567   case LevelFormat::Batch:
1568     return std::make_unique<BatchLevel>(t, l, sz);
1569   case LevelFormat::Compressed:
1570     return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
1571   case LevelFormat::LooseCompressed:
1572     return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
1573   case LevelFormat::Singleton:
1574     return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
1575   case LevelFormat::NOutOfM:
1576     return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
1577   case LevelFormat::Undef:
1578     llvm_unreachable("undefined level format");
1579   }
1580   llvm_unreachable("unrecognizable level format");
1581 }
1582 
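// Builds a SparseTensorLevel for level `lvl` of tensor `t`: the level size is
// obtained via `LvlOp` for encoded tensors (or `tensor.dim` otherwise), and the
// position/coordinate buffers are extracted as required by the level type.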
1583 std::unique_ptr<SparseTensorLevel>
1584 sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
1585                                      unsigned tid, Level lvl) {
1586   auto stt = getSparseTensorType(t);
1587 
1588   LevelType lt = stt.getLvlType(lvl);
1589   Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
1590                                : b.create<tensor::DimOp>(l, t, lvl).getResult();
1591 
1592   SmallVector<Value, 2> buffers;
1593   if (lt.isWithPosLT()) {
1594     Value pos = b.create<ToPositionsOp>(l, t, lvl);
1595     buffers.push_back(pos);
1596   }
1597   if (lt.isWithCrdLT()) {
1598     Value pos = b.create<ToCoordinatesOp>(l, t, lvl);
1599     buffers.push_back(pos);
1600   }
1601   return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
1602 }
1603 
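// Creates a synthetic batch level of size `sz` together with a trivial
// iterator over it, applying the requested emit strategy.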
1604 std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
1605 sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
1606                                        SparseEmitStrategy strategy) {
1607   auto stl = std::make_unique<BatchLevel>(tid, lvl, sz);
1608   auto it = std::make_unique<TrivialIterator>(*stl);
1609   it->setSparseEmitStrategy(strategy);
1610   return std::make_pair(std::move(stl), std::move(it));
1611 }
1612 
1613 std::unique_ptr<SparseIterator>
1614 sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l,
1615                                   const SparseIterationSpace &iterSpace) {
1616   // assert(iterSpace.getSpaceDim() == 1);
1617   std::unique_ptr<SparseIterator> ret;
1618   if (!iterSpace.isUnique()) {
1619     // We always deduplicate the non-unique level, but we should optimize it away
1620     // if possible.
1621     ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(),
1622                                           iterSpace.getBoundLo(),
1623                                           iterSpace.getBoundHi());
1624   } else {
1625     ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(),
1626                                             iterSpace.getBoundLo(),
1627                                             iterSpace.getBoundHi());
1628   }
1629   ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional);
1630   return ret;
1631 }
1632 
1633 std::unique_ptr<SparseIterator>
1634 sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
1635                                   SparseEmitStrategy strategy) {
1636   std::unique_ptr<SparseIterator> ret;
1637   if (!isUniqueLT(stl.getLT())) {
1638     // We always deduplicate the non-unique level, but we should optimize it away
1639     // if possible.
1640     ret = std::make_unique<DedupIterator>(stl);
1641   } else {
1642     ret = std::make_unique<TrivialIterator>(stl);
1643   }
1644   ret->setSparseEmitStrategy(strategy);
1645   return ret;
1646 }
1647 
1648 std::unique_ptr<SparseIterator>
1649 sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
1650                                        Value offset, Value stride, Value size,
1651                                        SparseEmitStrategy strategy) {
1652 
1653   auto ret =
1654       std::make_unique<FilterIterator>(std::move(sit), offset, stride, size);
1655   ret->setSparseEmitStrategy(strategy);
1656   return ret;
1657 }
1658 
1659 std::unique_ptr<SparseIterator>
1660 sparse_tensor::makePaddedIterator(std::unique_ptr<SparseIterator> &&sit,
1661                                   Value padLow, Value padHigh,
1662                                   SparseEmitStrategy strategy) {
1663   auto ret = std::make_unique<PadIterator>(std::move(sit), padLow, padHigh);
1664   ret->setSparseEmitStrategy(strategy);
1665   return ret;
1666 }
1667 
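// Returns the iterator wrapped by `it` when `it` is a FilterIterator;
// otherwise returns `it` unchanged.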
1668 static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
1669   auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
1670   if (filter)
1671     return &filter->getWrappedIterator();
1672   return it;
1673 }
1674 
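// Constructs a NonEmptySubSectIterator around `delegate` (unwrapping a
// FilterIterator parent first); a non-unit stride is handled by wrapping the
// result in an additional FilterIterator bounded by `loopBound`.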
1675 std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
1676     OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
1677     std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
1678     SparseEmitStrategy strategy) {
1679 
1680   // Try to unwrap the NonEmptySubSectIterator from a filter parent.
1681   parent = tryUnwrapFilter(parent);
1682   std::unique_ptr<SparseIterator> it =
1683       std::make_unique<NonEmptySubSectIterator>(b, l, parent,
1684                                                 std::move(delegate), size);
1685 
1686   if (stride != 1) {
1687     // TODO: We can safely skip bound checking on sparse levels, but for a dense
1688     // iteration space, we need the bound to infer the dense loop range.
1689     it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1690                                           C_IDX(stride), /*size=*/loopBound);
1691   }
1692   it->setSparseEmitStrategy(strategy);
1693   return it;
1694 }
1695 
1696 std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
1697     OpBuilder &b, Location l, const SparseIterator &subSectIter,
1698     const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
1699     Value loopBound, unsigned stride, SparseEmitStrategy strategy) {
1700 
1701   // This must be a subsection iterator or a filtered subsection iterator.
1702   auto &subSect =
1703       llvm::cast<NonEmptySubSectIterator>(*tryUnwrapFilter(&subSectIter));
1704 
1705   std::unique_ptr<SparseIterator> it = std::make_unique<SubSectIterator>(
1706       subSect, *tryUnwrapFilter(&parent), std::move(wrap));
1707 
1708   if (stride != 1) {
1709     it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1710                                           C_IDX(stride), /*size=*/loopBound);
1711   }
1712   it->setSparseEmitStrategy(strategy);
1713   return it;
1714 }
1715 
1716 #undef CMPI
1717 #undef C_IDX
1718 #undef YIELD
1719 #undef ADDI
1720 #undef ANDI
1721 #undef SUBI
1722 #undef MULI
1723 #undef SELECT
1724