//===- SparseTensorIterator.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "SparseTensorIterator.h"
#include "CodegenUtils.h"

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::sparse_tensor;
using ValuePair = std::pair<Value, Value>;
using ValueTuple = std::tuple<Value, Value, Value>;

//===----------------------------------------------------------------------===//
// File local helper functions/macros.
//===----------------------------------------------------------------------===//
#define CMPI(p, lhs, rhs)                                                      \
  (b.create<arith::CmpIOp>(l, arith::CmpIPredicate::p, (lhs), (rhs))           \
       .getResult())

#define C_FALSE (constantI1(b, l, false))
#define C_TRUE (constantI1(b, l, true))
#define C_IDX(v) (constantIndex(b, l, (v)))
#define YIELD(vs) (b.create<scf::YieldOp>(l, (vs)))
#define ADDI(lhs, rhs) (b.create<arith::AddIOp>(l, (lhs), (rhs)).getResult())
#define ORI(lhs, rhs) (b.create<arith::OrIOp>(l, (lhs), (rhs)).getResult())
#define ANDI(lhs, rhs) (b.create<arith::AndIOp>(l, (lhs), (rhs)).getResult())
#define SUBI(lhs, rhs) (b.create<arith::SubIOp>(l, (lhs), (rhs)).getResult())
#define MULI(lhs, rhs) (b.create<arith::MulIOp>(l, (lhs), (rhs)).getResult())
#define MINUI(lhs, rhs) (b.create<arith::MinUIOp>(l, (lhs), (rhs)).getResult())
#define REMUI(lhs, rhs) (b.create<arith::RemUIOp>(l, (lhs), (rhs)).getResult())
#define DIVUI(lhs, rhs) (b.create<arith::DivUIOp>(l, (lhs), (rhs)).getResult())
#define SELECT(c, lhs, rhs)                                                    \
  (b.create<arith::SelectOp>(l, (c), (lhs), (rhs)).getResult())
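
// As an illustrative sketch (not code from this file), a guarded increment
// written with these helpers,
//   Value nxt = SELECT(cond, ADDI(pos, C_IDX(1)), pos);
// builds an arith.constant, an arith.addi, and an arith.select at location
// `l` through the ambient builder `b`.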

//===----------------------------------------------------------------------===//
// SparseTensorLevel derived classes.
//===----------------------------------------------------------------------===//

namespace {

template <bool hasPosBuffer>
class SparseLevel : public SparseTensorLevel {
  // It is either an array of size 2 or size 1 depending on whether the sparse
  // level requires a position array.
  using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
                                     std::array<Value, 1>>;

public:
  SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
              BufferT buffers)
      : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}

  ValueRange getLvlBuffers() const override { return buffers; }

  Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                  Value iv) const override {
    SmallVector<Value> memCrd(batchPrefix);
    memCrd.push_back(iv);
    return genIndexLoad(b, l, getCrdBuf(), memCrd);
  }

protected:
  template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
  Value getPosBuf() const {
    return buffers[0];
  }

  Value getCrdBuf() const {
    if constexpr (hasPosBuffer)
      return buffers[1];
    else
      return buffers[0];
  }

  const BufferT buffers;
};

class DenseLevel : public SparseTensorLevel {
public:
  DenseLevel(unsigned tid, Level lvl, Value lvlSize)
      : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize) {}

  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
    llvm_unreachable("locate random-accessible level instead");
  }

  ValueRange getLvlBuffers() const override { return {}; }

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(parentPos.size() == 1 && "Dense level cannot be non-unique.");
    assert(!inPadZone && "Not implemented");
    Value p = parentPos.front();
    Value posLo = MULI(p, lvlSize);
    return {posLo, lvlSize};
  }
};

class BatchLevel : public SparseTensorLevel {
public:
  BatchLevel(unsigned tid, Level lvl, Value lvlSize)
      : SparseTensorLevel(tid, lvl, LevelFormat::Batch, lvlSize) {}

  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
    llvm_unreachable("locate random-accessible level instead");
  }

  ValueRange getLvlBuffers() const override { return {}; }

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(!inPadZone && "Not implemented");
    assert(parentPos.size() == 1 && "Batch level cannot be non-unique.");
    // No need to linearize the position for non-annotated tensors.
    return {C_IDX(0), lvlSize};
  }
};

class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
public:
  CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                  Value posBuffer, Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {

    assert(parentPos.size() == 1 &&
           "compressed level must be the first non-unique level.");

    auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
      Value p = parentPos.front();
      SmallVector<Value> memCrd(batchPrefix);
      memCrd.push_back(p);
      Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
      memCrd.back() = ADDI(p, C_IDX(1));
      Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
      return {pLo, pHi};
    };

    if (inPadZone == nullptr)
      return loadRange();

    SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
    scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
    // True branch, returns a "fake" empty range [0, 0) if the parent
    // iterator is in the pad zone.
    b.setInsertionPointToStart(posRangeIf.thenBlock());

    SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
    b.create<scf::YieldOp>(l, emptyRange);

    // False branch, returns the actual range.
    b.setInsertionPointToStart(posRangeIf.elseBlock());
    auto [pLo, pHi] = loadRange();
    SmallVector<Value, 2> loadedRange{pLo, pHi};
    b.create<scf::YieldOp>(l, loadedRange);

    b.setInsertionPointAfter(posRangeIf);
    ValueRange posRange = posRangeIf.getResults();
    return {posRange.front(), posRange.back()};
  }
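
  // Roughly, the guarded case above produces IR of the following shape
  // (an illustrative sketch with types and operands abbreviated):
  //   %range:2 = scf.if %inPadZone -> (index, index) {
  //     scf.yield %c0, %c0            // "fake" empty range [0, 0)
  //   } else {
  //     %pLo = memref.load %pos[..., %p]
  //     %pHi = memref.load %pos[..., %p + 1]
  //     scf.yield %pLo, %pHi
  //   }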
};

class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
public:
  LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                       Value posBuffer, Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(parentPos.size() == 1 &&
           "loose-compressed level must be the first non-unique level.");
    assert(!inPadZone && "Not implemented");
    SmallVector<Value> memCrd(batchPrefix);
    Value p = parentPos.front();
    p = MULI(p, C_IDX(2));
    memCrd.push_back(p);
    Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
    memCrd.back() = ADDI(p, C_IDX(1));
    Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
    return {pLo, pHi};
  }
};

class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
public:
  SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                 Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(parentPos.size() == 1 || parentPos.size() == 2);
    assert(!inPadZone && "Not implemented");
    Value p = parentPos.front();
    Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;

    if (segHi == nullptr)
      return {p, ADDI(p, C_IDX(1))};
    // Use segHi as the loop upper bound.
    return {p, segHi};
  }

  ValuePair
  collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
                       std::pair<Value, Value> parentRange) const override {
    // A singleton level keeps the same range after collapsing.
    return parentRange;
  };
};

class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
public:
  NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
               Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(parentPos.size() == 1 && isUnique() &&
           "n:m level cannot be non-unique.");
    assert(!inPadZone && "Not implemented");
    // Each n:m block has exactly n specified elements.
    auto n = getN(lt);
    Value posLo = MULI(parentPos.front(), C_IDX(n));
    return {posLo, ADDI(posLo, C_IDX(n))};
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// File local helpers
//===----------------------------------------------------------------------===//

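/// Calls `builder` with the iterator's current coordinate while the iterator
/// has not reached its end, and yields `elseRet` otherwise. Roughly generates
/// (an illustrative sketch):
///   %r:N = scf.if %notEnd -> (...) {
///     %crd = <deref it>
///     scf.yield <builder(b, l, %crd)>
///   } else {
///     scf.yield %elseRet
///   }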
static scf::ValueVector genWhenInBound(
    OpBuilder &b, Location l, SparseIterator &it, ValueRange elseRet,
    llvm::function_ref<scf::ValueVector(OpBuilder &, Location, Value)>
        builder) {
  TypeRange ifRetTypes = elseRet.getTypes();
  auto ifOp = b.create<scf::IfOp>(l, ifRetTypes, it.genNotEnd(b, l), true);

  b.setInsertionPointToStart(ifOp.thenBlock());
  Value crd = it.deref(b, l);
  scf::ValueVector ret = builder(b, l, crd);
  YIELD(ret);

  b.setInsertionPointToStart(ifOp.elseBlock());
  YIELD(elseRet);

  b.setInsertionPointAfter(ifOp);
  return ifOp.getResults();
}

/// Generates code to compute the *absolute* offset of the slice based on the
/// provided minimum coordinate in the slice.
/// E.g., when reducing d0 + d1 + d2, we need two slices to fully reduce the
/// expression, i.e., s1 = slice(T, d0), s2 = slice(s1, d1). The *absolute*
/// offset is the offset computed relative to the initial tensor T.
///
/// When the slice is empty (i.e., isNonEmpty == false), the computed offset
/// is meaningless and should not be used during runtime; the method currently
/// generates code that returns 0 in that case.
///
/// offset = minCrd >= size ? minCrd - size + 1 : 0;
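///
/// For example, with size = 4: minCrd = 6 gives offset = 6 - 4 + 1 = 3,
/// while minCrd = 2 (< size) gives offset = 0.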
static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
                              Value size) {
  Value geSize = CMPI(uge, minCrd, size);
  // Compute minCrd - size + 1.
  Value mms = SUBI(ADDI(minCrd, C_IDX(1)), size);
  // This is the absolute offset related to the actual tensor.
  return SELECT(geSize, mms, C_IDX(0));
}

//===----------------------------------------------------------------------===//
// SparseIterator derived classes.
//===----------------------------------------------------------------------===//

namespace {

// An iterator that traverses a concrete sparse tensor level. High-level
// abstract iterators wrap it to achieve more complex goals (such as
// collapsing several levels). It also owns the common storage that holds the
// mlir::Values for itself as well as for its wrappers.
class ConcreteIterator : public SparseIterator {
protected:
  ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
                   unsigned cursorValCnt)
      : SparseIterator(kind, stl.tid, stl.lvl, cursorValCnt, cursorValsStorage),
        stl(stl), cursorValsStorage(cursorValCnt, nullptr) {
    assert(getCursor().size() == cursorValCnt);
  };

public:
  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kTrivial;
  }

  bool isBatchIterator() const override {
    return stl.getLT().isa<LevelFormat::Batch>();
  }
  bool randomAccessible() const override {
    return stl.getLT().hasDenseSemantic();
  };
  bool iteratableByFor() const override { return kind != IterKind::kDedup; };
  Value upperBound(OpBuilder &b, Location l) const override {
    return stl.getSize();
  };

protected:
  const SparseTensorLevel &stl;
  // Owner of the storage; all wrappers built on top of a concrete iterator
  // share the same storage so that the iterator values are always
  // synchronized.
  SmallVector<Value> cursorValsStorage;
};

class TrivialIterator : public ConcreteIterator {
public:
  TrivialIterator(const SparseTensorLevel &stl)
      : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}

  TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
                  Value posLo, Value posHi)
      : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1), posLo(posLo),
        posHi(posHi) {
    seek(posLo);
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("trivial<") + stl.toString() + ">";
  }
  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    return {b.getIndexType()};
  }

  SmallVector<Value> serialize() const override {
    SmallVector<Value> ret;
    ret.push_back(getItPos());
    if (randomAccessible()) {
      // The loop upper bound is implicit (defined by `upperBound()`) for a
      // random-access iterator, but we need to memorize posLo for
      // linearization.
      ret.push_back(posLo);
    } else {
      ret.push_back(posHi);
    }
    return ret;
  };

  void deserialize(ValueRange vs) override {
    assert(vs.size() == 2);
    seek(vs.front());
    if (randomAccessible())
      posLo = vs.back();
    else
      posHi = vs.back();
  };

  void genInitImpl(OpBuilder &b, Location l,
                   const SparseIterator *parent) override;

  ValuePair genForCond(OpBuilder &b, Location l) override {
    if (randomAccessible())
      return {deref(b, l), upperBound(b, l)};
    return std::make_pair(getItPos(), posHi);
  }

  Value genNotEndImpl(OpBuilder &b, Location l) override {
    // We use the first level's bound as the bound for the collapsed set of
    // levels.
    return CMPI(ult, getItPos(), posHi);
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    if (randomAccessible()) {
      updateCrd(SUBI(getItPos(), posLo));
    } else {
      updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getItPos()));
    }
    return getCrd();
  };

  ValueRange forwardImpl(OpBuilder &b, Location l) override {
    seek(ADDI(getItPos(), C_IDX(1)));
    return getCursor();
  }

  ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override {
    Value curPos = getCursor().front();
    Value nxPos = forward(b, l).front();
    seek(SELECT(cond, nxPos, curPos));
    return getCursor();
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    assert(randomAccessible());
    // Seek to the linearized position.
    seek(ADDI(crd, posLo));
    updateCrd(crd);
    if (isBatchIterator()) {
      // If this is a batch iterator, also update the batch coordinate.
      assert(batchCrds.size() > lvl);
      batchCrds[lvl] = crd;
    }
  }

  Value getItPos() const { return getCursor().front(); }
  Value posLo, posHi;
};

class DedupIterator : public ConcreteIterator {
private:
  Value genSegmentHigh(OpBuilder &b, Location l, Value pos);

public:
  DedupIterator(const SparseTensorLevel &stl)
      : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
    assert(!stl.isUnique());
  }

  DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
                Value posLo, Value posHi)
      : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) {
    assert(!stl.isUnique());
    seek({posLo, genSegmentHigh(b, l, posLo)});
  }

  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kDedup;
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("dedup<") + stl.toString() + ">";
  }
  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    return {b.getIndexType(), b.getIndexType()};
  }

  void genInitImpl(OpBuilder &b, Location l,
                   const SparseIterator *parent) override {
    Value c0 = C_IDX(0);
    ValueRange pPos = c0;

    // If the parent iterator is a batch iterator, we also start from 0 (but
    // on a different batch).
    if (parent && !parent->isBatchIterator())
      pPos = parent->getCurPosition();

    Value posLo;
    ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);

    seek({posLo, genSegmentHigh(b, l, posLo)});
  }

  SmallVector<Value> serialize() const override {
    SmallVector<Value> ret;
    ret.append(getCursor().begin(), getCursor().end());
    ret.push_back(posHi);
    return ret;
  };
  void deserialize(ValueRange vs) override {
    assert(vs.size() == 3);
    seek(vs.take_front(getCursor().size()));
    posHi = vs.back();
  };

  Value genNotEndImpl(OpBuilder &b, Location l) override {
    return CMPI(ult, getPos(), posHi);
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getPos()));
    return getCrd();
  };

  ValueRange forwardImpl(OpBuilder &b, Location l) override {
    Value nxPos = getSegHi(); // forward the position to the next segment.
    seek({nxPos, genSegmentHigh(b, l, nxPos)});
    return getCursor();
  }

  Value getPos() const { return getCursor()[0]; }
  Value getSegHi() const { return getCursor()[1]; }

  Value posHi;
};

// A utility base iterator that delegates all methods to the wrapped iterator.
class SimpleWrapIterator : public SparseIterator {
public:
  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
                     unsigned extraCursorVal = 0)
      : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}

  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    return wrap->getCursorValTypes(b);
  }
  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
  bool randomAccessible() const override { return wrap->randomAccessible(); };
  bool iteratableByFor() const override { return wrap->iteratableByFor(); };

  SmallVector<Value> serialize() const override { return wrap->serialize(); };
  void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
  ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
  void genInitImpl(OpBuilder &b, Location l,
                   const SparseIterator *parent) override {
    wrap->genInit(b, l, parent);
  }
  Value genNotEndImpl(OpBuilder &b, Location l) override {
    return wrap->genNotEndImpl(b, l);
  }
  ValueRange forwardImpl(OpBuilder &b, Location l) override {
    return wrap->forward(b, l);
  };
  Value upperBound(OpBuilder &b, Location l) const override {
    return wrap->upperBound(b, l);
  };

  Value derefImpl(OpBuilder &b, Location l) override {
    return wrap->derefImpl(b, l);
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    return wrap->locate(b, l, crd);
  }

  SparseIterator &getWrappedIterator() const { return *wrap; }

protected:
  std::unique_ptr<SparseIterator> wrap;
};

//
// A filter iterator wrapped around another iterator. The filter iterator
// updates the wrapped iterator *in-place*.
//
class FilterIterator : public SimpleWrapIterator {
  // Coordinate translation between the coordinate loaded from the wrapped
  // iterator and the filter iterator's coordinate.
  Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) const {
    // crd = (wrapCrd - offset) / stride
    return DIVUI(SUBI(wrapCrd, offset), stride);
  }
  Value toWrapCrd(OpBuilder &b, Location l, Value crd) const {
    // wrapCrd = crd * stride + offset
    return ADDI(MULI(crd, stride), offset);
  }
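
  // For example, with offset = 2 and stride = 3, wrapCrd = 8 maps to
  // crd = (8 - 2) / 3 = 2 and toWrapCrd(2) = 2 * 3 + 2 = 8 (on stride),
  // whereas wrapCrd = 7 gives crd = (7 - 2) / 3 = 1 but toWrapCrd(1) = 5 != 7,
  // so that coordinate is not legitimate and gets filtered out.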

  Value genCrdNotLegitPredicate(OpBuilder &b, Location l, Value wrapCrd);

  Value genShouldFilter(OpBuilder &b, Location l);

public:
  // TODO: avoid unnecessary check when offset == 0 and/or when stride == 1
  // and/or when crd always < size.
  FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset,
                 Value stride, Value size)
      : SimpleWrapIterator(std::move(wrap), IterKind::kFilter), offset(offset),
        stride(stride), size(size) {}

  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kFilter;
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("filter<") + wrap->getDebugInterfacePrefix() + ">";
  }

  bool iteratableByFor() const override { return randomAccessible(); };
  Value upperBound(OpBuilder &b, Location l) const override { return size; };

  void genInitImpl(OpBuilder &b, Location l,
                   const SparseIterator *parent) override {
    wrap->genInit(b, l, parent);
    if (!randomAccessible()) {
      // TODO: we can skip this when stride == 1 and offset == 0; we can also
      // use binary search here.
      forwardIf(b, l, genShouldFilter(b, l));
    } else {
      // Else, locate to slice.offset, which is the first coordinate
      // included by the slice.
      wrap->locate(b, l, offset);
    }
  }

  Value genNotEndImpl(OpBuilder &b, Location l) override;

  Value derefImpl(OpBuilder &b, Location l) override {
    updateCrd(fromWrapCrd(b, l, wrap->deref(b, l)));
    return getCrd();
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    assert(randomAccessible());
    wrap->locate(b, l, toWrapCrd(b, l, crd));
    updateCrd(crd);
  }

  ValueRange forwardImpl(OpBuilder &b, Location l) override;

  Value offset, stride, size;
};

//
// A pad iterator wrapped around another iterator. The pad iterator updates
// the wrapped iterator *in-place*.
//
class PadIterator : public SimpleWrapIterator {

public:
  PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
              Value padHigh)
      : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
                           wrap->randomAccessible() ? 1 : 0),
        padLow(padLow), padHigh(padHigh) {}

  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kPad;
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
  }

  // Returns a pair of values for the *lower* and *upper* bound, respectively.
  ValuePair genForCond(OpBuilder &b, Location l) override {
    if (randomAccessible())
      return {getCrd(), upperBound(b, l)};
    return wrap->genForCond(b, l);
  }

  // For a padded dense iterator, we append an `inPadZone: bool` in addition
  // to the values used by the wrapped iterator.
  ValueRange getCurPosition() const override { return getCursor(); }

  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    SmallVector<Type> ret = wrap->getCursorValTypes(b);
    // Need an extra boolean value `inPadZone` for the padded dense iterator.
    if (randomAccessible())
      ret.push_back(b.getI1Type());

    return ret;
  }

  // The upper bound after padding becomes `size + padLow + padHigh`.
  Value upperBound(OpBuilder &b, Location l) const override {
    return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
  };

  // The padded coordinate is `coord + padLow`.
  Value derefImpl(OpBuilder &b, Location l) override {
    updateCrd(ADDI(wrap->deref(b, l), padLow));
    return getCrd();
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    assert(randomAccessible());
    wrap->locate(b, l, SUBI(crd, padLow));

    // inPadZone = crd < padLow || crd >= size + padLow.
    Value inPadLow = CMPI(ult, crd, padLow);
    Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
    getMutCursorVals().back() = ORI(inPadLow, inPadHigh);

    updateCrd(crd);
  }
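
  // For example, with padLow = 2 and a wrapped upper bound of 5, padded
  // coordinates 2..6 map to wrapped coordinates 0..4, while coordinates
  // below 2 or at/above 7 fall in the pad zone.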

  Value padLow, padHigh;
};

class NonEmptySubSectIterator : public SparseIterator {
public:
  using TraverseBuilder = llvm::function_ref<scf::ValueVector(
      OpBuilder &, Location, const SparseIterator *, ValueRange)>;

  NonEmptySubSectIterator(OpBuilder &b, Location l,
                          const SparseIterator *parent,
                          std::unique_ptr<SparseIterator> &&delegate,
                          Value subSectSz)
      : SparseIterator(IterKind::kNonEmptySubSect, 3, subSectMeta, *delegate),
        parent(parent), delegate(std::move(delegate)),
        tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) {
    auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
    if (p == nullptr) {
      // Extract subsections along the root level.
      maxTupleCnt = C_IDX(1);
    } else if (p->lvl == lvl) {
      // Extract subsections along the same level.
      maxTupleCnt = p->maxTupleCnt;
      assert(false && "Not implemented.");
    } else {
      // Extract subsections along the previous level.
      assert(p->lvl + 1 == lvl);
      maxTupleCnt = MULI(p->maxTupleCnt, p->subSectSz);
    }
    // We don't need an extra buffer to find subsections on random-accessible
    // levels.
    if (randomAccessible())
      return;
    subSectPosBuf = allocSubSectPosBuf(b, l);
  }

  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kNonEmptySubSect;
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("ne_sub<") + delegate->getDebugInterfacePrefix() + ">";
  }
  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    // minCrd, absolute offset, notEnd
    return {b.getIndexType(), b.getIndexType(), b.getI1Type()};
  }

  // The sliced pointer buffer is organized as:
  //   [[itVal0, itVal1, ..., pNx0],
  //    [itVal0, itVal1, ..., pNx0],
  //    ...]
  Value allocSubSectPosBuf(OpBuilder &b, Location l) {
    return b.create<memref::AllocaOp>(
        l,
        MemRefType::get({ShapedType::kDynamic, tupleSz + 1}, b.getIndexType()),
        maxTupleCnt);
  }
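
  // That is, each row caches one serialized wrapped iterator (tupleSz values)
  // followed by the starting tupleId of the next level, so the buffer has
  // type memref<?x(tupleSz + 1)xindex> with at most maxTupleCnt rows.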

  void storeNxLvlStart(OpBuilder &b, Location l, Value tupleId,
                       Value start) const {
    b.create<memref::StoreOp>(l, start, subSectPosBuf,
                              ValueRange{tupleId, C_IDX(tupleSz)});
  }

  Value loadNxLvlStart(OpBuilder &b, Location l, Value tupleId) const {
    return b.create<memref::LoadOp>(l, subSectPosBuf,
                                    ValueRange{tupleId, C_IDX(tupleSz)});
  }

  void storeCursorVals(OpBuilder &b, Location l, Value tupleId,
                       ValueRange itVals) const {
    assert(itVals.size() == tupleSz);
    for (unsigned i = 0; i < tupleSz; i++) {
      b.create<memref::StoreOp>(l, itVals[i], subSectPosBuf,
                                ValueRange{tupleId, C_IDX(i)});
    }
  }

  SmallVector<Value> loadCursorVals(OpBuilder &b, Location l,
                                    Value tupleId) const {
    SmallVector<Value> ret;
    for (unsigned i = 0; i < tupleSz; i++) {
      Value v = b.create<memref::LoadOp>(l, subSectPosBuf,
                                         ValueRange{tupleId, C_IDX(i)});
      ret.push_back(v);
    }
    return ret;
  }

  bool isSubSectRoot() const {
    return !parent || !llvm::isa<NonEmptySubSectIterator>(parent);
  }

  // Generates code that inflates the current subsection tree up to the
  // current level such that every leaf node is visited.
  ValueRange inflateSubSectTree(OpBuilder &b, Location l, ValueRange reduc,
                                TraverseBuilder builder) const;

  bool isBatchIterator() const override { return delegate->isBatchIterator(); }
  bool randomAccessible() const override {
    return delegate->randomAccessible();
  };
  bool iteratableByFor() const override { return randomAccessible(); };
  Value upperBound(OpBuilder &b, Location l) const override {
    auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
    Value parentUB =
        p && p->lvl == lvl ? p->upperBound(b, l) : delegate->upperBound(b, l);
    return ADDI(SUBI(parentUB, subSectSz), C_IDX(1));
  };
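
  // For example, with a parent upper bound of 10 and a subsection size of 4,
  // the (exclusive) upper bound is 10 - 4 + 1 = 7, i.e., valid subsection
  // offsets are 0..6.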

  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override;

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    Value absOff = crd;

    if (isSubSectRoot())
      delegate->locate(b, l, absOff);
    else
      assert(parent->lvl + 1 == lvl);

    seek(ValueRange{absOff, absOff, C_TRUE});
    updateCrd(crd);
  }

  Value toSubSectCrd(OpBuilder &b, Location l, Value wrapCrd) const {
    return SUBI(wrapCrd, getAbsOff());
  }

  Value genNotEndImpl(OpBuilder &b, Location l) override {
    return getNotEnd();
  };

  Value derefImpl(OpBuilder &b, Location l) override {
    // Use the relative offset to coiterate.
    Value crd;
    auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
    if (p && p->lvl == lvl)
      crd = SUBI(getAbsOff(), p->getAbsOff());
    crd = getAbsOff();

    updateCrd(crd);
    return crd;
  };

  ValueRange forwardImpl(OpBuilder &b, Location l) override;

  Value getMinCrd() const { return subSectMeta[0]; }
  Value getAbsOff() const { return subSectMeta[1]; }
  Value getNotEnd() const { return subSectMeta[2]; }

  const SparseIterator *parent;
  std::unique_ptr<SparseIterator> delegate;

  // Number of values required to serialize the wrapped iterator.
  const unsigned tupleSz;
  // Max number of tuples, and the actual number of tuples.
  Value maxTupleCnt, tupleCnt;
  // The memory used to cache the tuples serialized from the wrapped iterator.
  Value subSectPosBuf;

  const Value subSectSz;

  // minCrd, absolute offset, notEnd
  SmallVector<Value, 3> subSectMeta{nullptr, nullptr, nullptr};
};

class SubSectIterator;

// A wrapper that helps generate code to traverse a subsection, used
// by both `NonEmptySubSectIterator` and `SubSectIterator`.
struct SubSectIterHelper {
  explicit SubSectIterHelper(const SubSectIterator &iter);
  explicit SubSectIterHelper(const NonEmptySubSectIterator &subSect);

  // Delegate methods.
  void deserializeFromTupleId(OpBuilder &b, Location l, Value tupleId);
  void locate(OpBuilder &b, Location l, Value crd);
  Value genNotEnd(OpBuilder &b, Location l);
  Value deref(OpBuilder &b, Location l);
  ValueRange forward(OpBuilder &b, Location l);

  const NonEmptySubSectIterator &subSect;
  SparseIterator &wrap;
};

class SubSectIterator : public SparseIterator {
public:
  SubSectIterator(const NonEmptySubSectIterator &subSect,
                  const SparseIterator &parent,
                  std::unique_ptr<SparseIterator> &&wrap)
      : SparseIterator(IterKind::kSubSect, *wrap,
                       /*extraCursorCnt=*/wrap->randomAccessible() ? 0 : 1),
        subSect(subSect), wrap(std::move(wrap)), parent(parent), helper(*this) {
    assert(subSect.tid == tid && subSect.lvl == lvl);
    assert(parent.kind != IterKind::kSubSect || parent.lvl + 1 == lvl);
  };

  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kSubSect;
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("subsect<") + wrap->getDebugInterfacePrefix() + ">";
  }
  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    SmallVector<Type> ret = wrap->getCursorValTypes(b);
    if (!randomAccessible())
      ret.push_back(b.getIndexType()); // The extra counter.
    return ret;
  }

  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
  bool randomAccessible() const override { return wrap->randomAccessible(); };
  bool iteratableByFor() const override { return randomAccessible(); };
  Value upperBound(OpBuilder &b, Location l) const override {
    return subSect.subSectSz;
  }

  ValueRange getCurPosition() const override { return wrap->getCurPosition(); };

  Value getNxLvlTupleId(OpBuilder &b, Location l) const {
    if (randomAccessible()) {
      return ADDI(getCrd(), nxLvlTupleStart);
    };
    return ADDI(getCursor().back(), nxLvlTupleStart);
  }

  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override {
    if (randomAccessible()) {
      if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
        assert(p->lvl + 1 == lvl);
        wrap->genInit(b, l, p);
        // Linearize the dense subsection index.
        nxLvlTupleStart = MULI(subSect.subSectSz, p->getNxLvlTupleId(b, l));
      } else {
        assert(subSect.lvl == lvl && subSect.isSubSectRoot());
        wrap->deserialize(subSect.delegate->serialize());
        nxLvlTupleStart = C_IDX(0);
      }
      return;
    }
    assert(!randomAccessible());
    assert(getCursor().size() == wrap->getCursor().size() + 1);
    // Extra counter that counts the number of actually visited coordinates in
    // the sparse subsection.
    getMutCursorVals().back() = C_IDX(0);
    Value tupleId;
    if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
      assert(p->lvl + 1 == lvl);
      tupleId = p->getNxLvlTupleId(b, l);
    } else {
      assert(subSect.lvl == lvl && subSect.isSubSectRoot());
      tupleId = C_IDX(0);
    }
    nxLvlTupleStart = subSect.loadNxLvlStart(b, l, tupleId);
    helper.deserializeFromTupleId(b, l, tupleId);
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    helper.locate(b, l, crd);
    updateCrd(crd);
  }

  Value genNotEndImpl(OpBuilder &b, Location l) override {
    return helper.genNotEnd(b, l);
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    Value crd = helper.deref(b, l);
    updateCrd(crd);
    return crd;
  };

  ValueRange forwardImpl(OpBuilder &b, Location l) override {
    helper.forward(b, l);
    assert(!randomAccessible());
    assert(getCursor().size() == wrap->getCursor().size() + 1);
    getMutCursorVals().back() = ADDI(getCursor().back(), C_IDX(1));
    return getCursor();
  };

  Value nxLvlTupleStart;

  const NonEmptySubSectIterator &subSect;
  std::unique_ptr<SparseIterator> wrap;
  const SparseIterator &parent;

  SubSectIterHelper helper;
};

} // namespace

//===----------------------------------------------------------------------===//
// SparseIterator derived classes implementation.
//===----------------------------------------------------------------------===//

void SparseIterator::genInit(OpBuilder &b, Location l,
                             const SparseIterator *p) {
  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
    std::string prefix = getDebugInterfacePrefix();
    Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {},
                                getCursorValTypes(b));
    seek(begin->getResults());
    return;
  }
  // Inherit batch coordinates from the parent.
  if (p)
    inherentBatch(*p);
  // TODO: support lowering to function call.
  return genInitImpl(b, l, p);
}

Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
    std::string prefix = getDebugInterfacePrefix();
    Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"),
                                 getCursor(), b.getI1Type());
    return notEnd->getResult(0);
  }
  // TODO: support lowering to function call.
  return genNotEndImpl(b, l);
}

void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
    std::string prefix = getDebugInterfacePrefix();
    SmallVector<Value> args = getCursor();
    args.push_back(crd);
    Operation *locate = b.create(l, b.getStringAttr(prefix + ".locate"), args,
                                 getCursorValTypes(b));
    seek(locate->getResults());
    updateCrd(crd);
    return;
  }
  return locateImpl(b, l, crd);
}

Value SparseIterator::deref(OpBuilder &b, Location l) {
  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
    std::string prefix = getDebugInterfacePrefix();
    SmallVector<Value> args = getCursor();
    Operation *deref = b.create(l, b.getStringAttr(prefix + ".deref"),
                                getCursor(), b.getIndexType());
    updateCrd(deref->getResult(0));
    return getCrd();
  }
  return derefImpl(b, l);
}

ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
  assert(!randomAccessible());
  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
    std::string prefix = getDebugInterfacePrefix();
    Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
                               getCursor(), getCursorValTypes(b));
    seek(next->getResults());
    return getCursor();
  }
  return forwardImpl(b, l);
}

ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) {
  auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), cond, true);
  // Generate the else branch first; otherwise the iterator values will be
  // updated by `forward()`.
  b.setInsertionPointToStart(ifOp.elseBlock());
  YIELD(getCursor());

  b.setInsertionPointToStart(ifOp.thenBlock());
  YIELD(forward(b, l));

  b.setInsertionPointAfter(ifOp);
  seek(ifOp.getResults());
  return getCursor();
}

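/// Generates a while loop that scans forward from `pos` for as long as the
/// loaded coordinate equals the coordinate at `pos`, returning the exclusive
/// end of the duplicate segment (clamped by `posHi`).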
genSegmentHigh(OpBuilder & b,Location l,Value pos)1060 Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
1061 auto whileOp = b.create<scf::WhileOp>(
1062 l, pos.getType(), pos,
1063 /*beforeBuilder=*/
1064 [this, pos](OpBuilder &b, Location l, ValueRange ivs) {
1065 Value inBound = CMPI(ult, ivs.front(), posHi);
1066 auto ifInBound = b.create<scf::IfOp>(l, b.getI1Type(), inBound, true);
1067 {
1068 OpBuilder::InsertionGuard guard(b);
1069 // If in bound, load the next coordinates and check duplication.
1070 b.setInsertionPointToStart(ifInBound.thenBlock());
1071 Value headCrd = stl.peekCrdAt(b, l, getBatchCrds(), pos);
1072 Value tailCrd = stl.peekCrdAt(b, l, getBatchCrds(), ivs.front());
1073 Value isDup = CMPI(eq, headCrd, tailCrd);
1074 YIELD(isDup);
1075 // Else, the position is out of bound, yield false.
1076 b.setInsertionPointToStart(ifInBound.elseBlock());
1077 YIELD(constantI1(b, l, false));
1078 }
1079 b.create<scf::ConditionOp>(l, ifInBound.getResults()[0], ivs);
1080 },
1081 /*afterBuilder=*/
1082 [](OpBuilder &b, Location l, ValueRange ivs) {
1083 Value nxPos = ADDI(ivs[0], C_IDX(1));
1084 YIELD(nxPos);
1085 });
1086 // Return the segment high.
1087 return whileOp.getResult(0);
1088 }
1089
genCrdNotLegitPredicate(OpBuilder & b,Location l,Value wrapCrd)1090 Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l,
1091 Value wrapCrd) {
1092 Value crd = fromWrapCrd(b, l, wrapCrd);
1093 // Test whether the coordinate is on stride.
1094 Value notlegit = CMPI(ne, toWrapCrd(b, l, crd), wrapCrd);
1095 // Test wrapCrd < offset
1096 notlegit = ORI(CMPI(ult, wrapCrd, offset), notlegit);
1097 // Test crd >= length
1098 notlegit = ORI(CMPI(uge, crd, size), notlegit);
1099 return notlegit;
1100 }
1101
genShouldFilter(OpBuilder & b,Location l)1102 Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
1103 auto r = genWhenInBound(
1104 b, l, *wrap, C_FALSE,
1105 [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1106 Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
1107 return {notLegit};
1108 });
1109
1110 assert(r.size() == 1);
1111 return r.front();
1112 }
1113
genNotEndImpl(OpBuilder & b,Location l)1114 Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
1115 assert(!wrap->randomAccessible());
1116 auto r = genWhenInBound(
1117 b, l, *wrap, C_FALSE,
1118 [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1119 Value crd = fromWrapCrd(b, l, wrapCrd);
1120 // crd < size
1121 return {CMPI(ult, crd, size)};
1122 });
1123 assert(r.size() == 1);
1124 return r.front();
1125 }
1126
forwardImpl(OpBuilder & b,Location l)1127 ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
1128 assert(!randomAccessible());
1129 // Generates
1130 //
1131 // bool isFirst = true;
1132 // while !it.end() && (!legit(*it) || isFirst)
1133 // wrap ++;
1134 // isFirst = false;
1135 //
1136 // We do not hoist the first `wrap++` outside the loop but use a `isFirst`
1137 // flag here because `wrap++` might have a complex implementation (e.g., to
1138 // forward a subsection).
1139 Value isFirst = constantI1(b, l, true);
1140
1141 SmallVector<Value> whileArgs(getCursor().begin(), getCursor().end());
1142 whileArgs.push_back(isFirst);
1143 auto whileOp = b.create<scf::WhileOp>(
1144 l, ValueRange(whileArgs).getTypes(), whileArgs,
1145 /*beforeBuilder=*/
1146 [this](OpBuilder &b, Location l, ValueRange ivs) {
1147 ValueRange isFirst = linkNewScope(ivs);
1148 assert(isFirst.size() == 1);
1149 scf::ValueVector cont =
1150 genWhenInBound(b, l, *wrap, C_FALSE,
1151 [this, isFirst](OpBuilder &b, Location l,
1152 Value wrapCrd) -> scf::ValueVector {
1153 // crd < size && !legit();
1154 Value notLegit =
1155 genCrdNotLegitPredicate(b, l, wrapCrd);
1156 Value crd = fromWrapCrd(b, l, wrapCrd);
1157 Value ret = ANDI(CMPI(ult, crd, size), notLegit);
1158 ret = ORI(ret, isFirst.front());
1159 return {ret};
1160 });
1161 b.create<scf::ConditionOp>(l, cont.front(), ivs);
1162 },
1163 /*afterBuilder=*/
1164 [this](OpBuilder &b, Location l, ValueRange ivs) {
1165 linkNewScope(ivs);
1166 wrap->forward(b, l);
1167 SmallVector<Value> yieldVals(getCursor().begin(), getCursor().end());
1168 yieldVals.push_back(constantI1(b, l, false));
1169 YIELD(yieldVals);
1170 });
1171
1172 b.setInsertionPointAfter(whileOp);
1173 linkNewScope(whileOp.getResults());
1174 return getCursor();
1175 }
1176
SubSectIterHelper(const NonEmptySubSectIterator & subSect)1177 SubSectIterHelper::SubSectIterHelper(const NonEmptySubSectIterator &subSect)
1178 : subSect(subSect), wrap(*subSect.delegate) {}
1179
SubSectIterHelper(const SubSectIterator & iter)1180 SubSectIterHelper::SubSectIterHelper(const SubSectIterator &iter)
1181 : subSect(iter.subSect), wrap(*iter.wrap) {}
1182
deserializeFromTupleId(OpBuilder & b,Location l,Value tupleId)1183 void SubSectIterHelper::deserializeFromTupleId(OpBuilder &b, Location l,
1184 Value tupleId) {
1185 assert(!subSect.randomAccessible());
1186 wrap.deserialize(subSect.loadCursorVals(b, l, tupleId));
1187 }
1188
locate(OpBuilder & b,Location l,Value crd)1189 void SubSectIterHelper::locate(OpBuilder &b, Location l, Value crd) {
1190 Value absCrd = ADDI(crd, subSect.getAbsOff());
1191 wrap.locate(b, l, absCrd);
1192 }
1193
genNotEnd(OpBuilder & b,Location l)1194 Value SubSectIterHelper::genNotEnd(OpBuilder &b, Location l) {
1195 assert(!wrap.randomAccessible());
1196 auto r = genWhenInBound(
1197 b, l, wrap, C_FALSE,
1198 [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1199 Value crd = SUBI(wrapCrd, subSect.getAbsOff());
1200 // crd < size
1201 return {CMPI(ult, crd, subSect.subSectSz)};
1202 });
1203 assert(r.size() == 1);
1204 return r.front();
1205 }
1206
deref(OpBuilder & b,Location l)1207 Value SubSectIterHelper::deref(OpBuilder &b, Location l) {
1208 Value wrapCrd = wrap.deref(b, l);
1209 Value crd = subSect.toSubSectCrd(b, l, wrapCrd);
1210 return crd;
1211 }
1212
forward(OpBuilder & b,Location l)1213 ValueRange SubSectIterHelper::forward(OpBuilder &b, Location l) {
1214 return wrap.forward(b, l);
1215 }
1216
inflateSubSectTree(OpBuilder & b,Location l,ValueRange reduc,TraverseBuilder builder) const1217 ValueRange NonEmptySubSectIterator::inflateSubSectTree(
1218 OpBuilder &b, Location l, ValueRange reduc, TraverseBuilder builder) const {
1219 // Set up the helper to help traverse a sparse subsection.
1220 SubSectIterHelper helper(*this);
1221 if (!randomAccessible()) {
1222 // The subsection tree have been expanded till the level and cached,
1223 // traverse all the leaves and expanded to the next level.
1224 SmallVector<Value> iterArgs;
1225 iterArgs.push_back(C_IDX(0));
1226 iterArgs.append(reduc.begin(), reduc.end());
1227 auto forEachLeaf = b.create<scf::ForOp>(
1228 l, /*lb=*/C_IDX(0), /*ub=*/tupleCnt, /*step=*/C_IDX(1), iterArgs,
1229 [&helper, &builder](OpBuilder &b, Location l, Value tupleId,
1230 ValueRange iterArgs) {
1231 // Deserialize the iterator at the cached position (tupleId).
1232 helper.deserializeFromTupleId(b, l, tupleId);
1233
1234 Value cnt = iterArgs.front();
1235 // Record the number of leaf nodes included in the subsection.
1236 // The number indicates the starting tupleId for the next level that
1237 // is corresponding to the current node.
1238 helper.subSect.storeNxLvlStart(b, l, tupleId, cnt);
1239
1240 SmallVector<Value> whileArgs(helper.wrap.getCursor());
1241 whileArgs.append(iterArgs.begin(), iterArgs.end());
1242
1243 auto whileOp = b.create<scf::WhileOp>(
1244 l, ValueRange(whileArgs).getTypes(), whileArgs,
1245 /*beforeBuilder=*/
1246 [&helper](OpBuilder &b, Location l, ValueRange ivs) {
1247 helper.wrap.linkNewScope(ivs);
1248 b.create<scf::ConditionOp>(l, helper.genNotEnd(b, l), ivs);
1249 },
1250 /*afterBuilder=*/
1251 [&helper, &builder](OpBuilder &b, Location l, ValueRange ivs) {
1252 ValueRange remIter = helper.wrap.linkNewScope(ivs);
1253 Value cnt = remIter.front();
1254 ValueRange userIter = remIter.drop_front();
1255 scf::ValueVector userNx = builder(b, l, &helper.wrap, userIter);
1256
1257 SmallVector<Value> nxIter = helper.forward(b, l);
1258 nxIter.push_back(ADDI(cnt, C_IDX(1)));
1259 nxIter.append(userNx.begin(), userNx.end());
1260 YIELD(nxIter);
1261 });
1262 ValueRange res = helper.wrap.linkNewScope(whileOp.getResults());
1263 YIELD(res);
1264 });
1265 return forEachLeaf.getResults().drop_front();
1266 }
1267
1268 assert(randomAccessible());
1269 // Helper lambda that traverse the current dense subsection range.
1270 auto visitDenseSubSect = [&, this](OpBuilder &b, Location l,
1271 const SparseIterator *parent,
1272 ValueRange reduc) {
1273 assert(!parent || parent->lvl + 1 == lvl);
1274 delegate->genInit(b, l, parent);
1275 auto forOp = b.create<scf::ForOp>(
1276 l, /*lb=*/C_IDX(0), /*ub=*/subSectSz, /*step=*/C_IDX(1), reduc,
1277 [&](OpBuilder &b, Location l, Value crd, ValueRange iterArgs) {
1278 helper.locate(b, l, crd);
1279 scf::ValueVector nx = builder(b, l, &helper.wrap, iterArgs);
1280 YIELD(nx);
1281 });
1282 return forOp.getResults();
1283 };
1284
1285 if (isSubSectRoot()) {
1286 return visitDenseSubSect(b, l, parent, reduc);
1287 }
1288 // Else, this is not the root, recurse until root.
1289 auto *p = llvm::cast<NonEmptySubSectIterator>(parent);
1290 assert(p->lvl + 1 == lvl);
1291 return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
1292 }
1293
genInitImpl(OpBuilder & b,Location l,const SparseIterator * parent)1294 void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
1295 const SparseIterator *parent) {
1296
1297 if (isBatchIterator() && batchCrds.size() <= stl.lvl)
1298 batchCrds.resize(stl.lvl + 1, nullptr);
1299
1300 Value c0 = C_IDX(0);
1301 ValueRange pPos = c0;
1302 Value inPadZone = nullptr;
1303 // If the parent iterator is a batch iterator, we also start from 0 (but
1304 // on a different batch).
1305 if (parent && !parent->isBatchIterator()) {
1306 pPos = parent->getCurPosition();
1307 if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
1308 // A padded dense iterator create "sparse" padded zone, which need to be
1309 // handled specially.
1310 inPadZone = pPos.back();
1311 pPos = pPos.drop_back();
1312 }
1313 }
1314
1315 ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
1316 std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
1317 // Seek to the lowest position.
1318 seek(posLo);
1319 }
1320
void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
                                          const SparseIterator *) {
  Value c0 = C_IDX(0);
  if (!isSubSectRoot()) {
    assert(parent->lvl + 1 == lvl);
    if (randomAccessible()) {
      // We cannot call wrap->genInit() here to initialize the wrapped
      // iterator, because the parent of the current iterator is still
      // unresolved.
      seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
      return;
    }

    auto *p = cast<NonEmptySubSectIterator>(parent);
    SmallVector<Value, 3> reduc = {
        C_IDX(-1), // minCrd (max signless integer)
        c0,        // tupleId
    };

    // Expand the subsection tree from the parent level to the current level.
    ValueRange result = p->inflateSubSectTree(
        b, l, reduc,
        [this](OpBuilder &b, Location l, const SparseIterator *parent,
               ValueRange reduc) -> scf::ValueVector {
          assert(parent->lvl + 1 == lvl && reduc.size() == 2);
          Value minCrd = reduc.front();
          Value tupleId = reduc.back();

          // Initialize the subsection range.
          SubSectIterHelper helper(*this);
          helper.wrap.genInit(b, l, parent);

          // Update minCrd.
          minCrd = genWhenInBound(b, l, helper.wrap, minCrd,
                                  [minCrd](OpBuilder &b, Location l,
                                           Value crd) -> scf::ValueVector {
                                    Value min = MINUI(crd, minCrd);
                                    return {min};
                                  })
                       .front();

          // Cache the sparse range.
          storeCursorVals(b, l, tupleId, helper.wrap.serialize());
          tupleId = ADDI(tupleId, C_IDX(1));
          return {minCrd, tupleId};
        });
    assert(result.size() == 2);
    tupleCnt = result.back();

    Value minCrd = result.front();
    Value absOff = offsetFromMinCrd(b, l, minCrd, subSectSz);
    Value notEnd = CMPI(ne, minCrd, C_IDX(-1));
    seek({minCrd, absOff, notEnd});
    return;
  }

  // This is the root level of the subsection, which means that it is resolved
  // to one node.
  assert(isSubSectRoot());

  // Initialize the position; the position marks the *lower bound* of the
  // subrange. The upper bound is determined by the size of the subsection.
  delegate->genInit(b, l, parent);
  if (randomAccessible()) {
    seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
    return;
  }

  // Only have one root node.
  tupleCnt = C_IDX(1);
  // Cache the sparse range.
  storeCursorVals(b, l, c0, delegate->serialize());
  SmallVector<Value> elseRet{c0, c0, /*notEnd=*/C_FALSE};
  auto meta = genWhenInBound(
      b, l, *delegate, elseRet,
      [this](OpBuilder &b, Location l, Value crd) -> scf::ValueVector {
        Value offset = offsetFromMinCrd(b, l, crd, subSectSz);
        return {crd, offset, C_TRUE};
      });

  seek(meta);
}

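// Forwards this non-empty subsection iterator to the next non-empty slice.
// The fast path bumps the offset while minCrd is still ahead of it; the slow
// path re-scans all cached tuples to find the next minimum coordinate and
// recomputes the offset from it. The offset is then forced to advance by at
// least one, and notEnd is cleared once the dereferenced coordinate reaches
// the iteration space's upper bound.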
ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
  assert(!randomAccessible());
  Value c0 = C_IDX(0), c1 = C_IDX(1);
  // Forward to the next non-empty slice by generating
  //
  // if (minCrd > offset) {
  //   offset += 1
  // } else {
  //   minCrd = nextMinInSlice();
  //   offset = minCrd - size + 1;
  // }
  //
  // if (offset + size > parents.size)
  //   isNonEmpty = false;
  Value fastPathP = CMPI(ugt, getMinCrd(), getAbsOff());
  auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), fastPathP, true);
  {
    OpBuilder::InsertionGuard guard(b);
    // Take the fast path:
    // if (minCrd > offset)
    //   offset += 1
    b.setInsertionPointToStart(&ifOp.getThenRegion().front());
    Value nxOffset = ADDI(getAbsOff(), c1);
    YIELD((ValueRange{getMinCrd(), nxOffset, getNotEnd()}));

    // else /*minCrd == offset*/ {
    //   for (i = 0; i < tupleCnt; i++) {
    //     wrap->deserialize(pos[i]);
    //     minCrd = min(minCrd, *wrap);
    //   }
    //   offset = minCrd - size + 1;
    // }
    b.setInsertionPointToStart(&ifOp.getElseRegion().front());
    SmallVector<Value, 2> loopArgs{C_IDX(-1), // nextMinCrd
                                   C_FALSE};  // isNotEnd
    auto loopNest = scf::buildLoopNest(
        b, l, c0, tupleCnt, c1, loopArgs,
        [this](OpBuilder &b, Location l, ValueRange ivs,
               ValueRange iterArgs) -> scf::ValueVector {
          Value tupleId = ivs.front();
          SubSectIterHelper helper(*this);
          helper.deserializeFromTupleId(b, l, tupleId);

          return genWhenInBound(
              b, l, *delegate, /*elseRet=*/iterArgs,
              [this, iterArgs, tupleId](OpBuilder &b, Location l,
                                        Value crd) -> scf::ValueVector {
                // if coord == minCrd
                //   wrap->forward();
                Value isMin = CMPI(eq, crd, getMinCrd());
                delegate->forwardIf(b, l, isMin);
                // Update the forwarded iterator values if needed.
                auto ifIsMin = b.create<scf::IfOp>(l, isMin, false);
                b.setInsertionPointToStart(&ifIsMin.getThenRegion().front());
                storeCursorVals(b, l, tupleId, delegate->serialize());
                b.setInsertionPointAfter(ifIsMin);
                // if (!wrap.end())
                //   yield(min(nxMinCrd, *wrap), true)
                Value nxMin = iterArgs[0];
                return genWhenInBound(b, l, *delegate, /*elseRet=*/iterArgs,
                                      [nxMin](OpBuilder &b, Location l,
                                              Value crd) -> scf::ValueVector {
                                        Value nx = b.create<arith::MinUIOp>(
                                            l, crd, nxMin);
                                        return {nx, C_TRUE};
                                      });
              });
        });

    scf::ForOp forOp = loopNest.loops.front();
    b.setInsertionPointAfter(forOp);

    Value nxMinCrd = forOp.getResult(0);
    Value nxNotEnd = forOp.getResult(1);
    Value nxAbsOff = offsetFromMinCrd(b, l, nxMinCrd, subSectSz);
    YIELD((ValueRange{nxMinCrd, nxAbsOff, nxNotEnd}));
  }

  Value nxMinCrd = ifOp.getResult(0);
  Value nxAbsOff = ifOp.getResult(1);
  Value nxNotEnd = ifOp.getResult(2);

  // We should at least forward the offset by one.
  Value minAbsOff = ADDI(getAbsOff(), c1);
  nxAbsOff = b.create<arith::MaxUIOp>(l, minAbsOff, nxAbsOff);

  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
  // The coordinate should not exceed the space's upper bound.
  Value crd = deref(b, l);
  nxNotEnd = ANDI(nxNotEnd, CMPI(ult, crd, upperBound(b, l)));

  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
  return getCursor();
}

//===----------------------------------------------------------------------===//
// SparseIterationSpace Implementation
//===----------------------------------------------------------------------===//

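// Constructs an iteration space that spans levels [lvlLo, lvlHi) of tensor
// `t`: each level is materialized as a SparseTensorLevel, the position range
// of the first level is anchored at `parentPos` (defaulting to 0), and that
// range is collapsed through the remaining levels to obtain the overall
// bound.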
mlir::sparse_tensor::SparseIterationSpace::SparseIterationSpace(
    Location l, OpBuilder &b, Value t, unsigned tid,
    std::pair<Level, Level> lvlRange, ValueRange parentPos)
    : lvls() {
  auto [lvlLo, lvlHi] = lvlRange;

  Value c0 = C_IDX(0);
  if (parentPos.empty())
    parentPos = c0;

  for (Level lvl = lvlLo; lvl < lvlHi; lvl++)
    lvls.emplace_back(makeSparseTensorLevel(b, l, t, tid, lvl));

  bound = lvls.front()->peekRangeAt(b, l, /*batchPrefix=*/{}, parentPos);
  for (auto &lvl : getLvlRef().drop_front())
    bound = lvl->collapseRangeBetween(b, l, /*batchPrefix=*/{}, bound);
}

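// Rebuilds a SparseIterationSpace from its flattened SSA representation: for
// each level the expected buffers (positions and/or coordinates, depending on
// the level type) are consumed, followed by the level size; the final two
// values carry the lower and upper iteration bounds.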
SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
    IterSpaceType dstTp, ValueRange values, unsigned int tid) {
  // Reconstruct every sparse tensor level.
  SparseIterationSpace space;
  for (auto [i, lt] : llvm::enumerate(dstTp.getLvlTypes())) {
    unsigned bufferCnt = 0;
    if (lt.isWithPosLT())
      bufferCnt++;
    if (lt.isWithCrdLT())
      bufferCnt++;
    // Sparse tensor buffers.
    ValueRange buffers = values.take_front(bufferCnt);
    values = values.drop_front(bufferCnt);

    // Level size.
    Value sz = values.front();
    values = values.drop_front();
    space.lvls.push_back(
        makeSparseTensorLevel(lt, sz, buffers, tid, i + dstTp.getLoLvl()));
  }
  // Two bounds.
  space.bound = std::make_pair(values[0], values[1]);
  values = values.drop_front(2);

  // Must have consumed all values.
  assert(values.empty());
  return space;
}

std::unique_ptr<SparseIterator>
SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const {
  return makeSimpleIterator(b, l, *this);
}

//===----------------------------------------------------------------------===//
// SparseIterator factory functions.
//===----------------------------------------------------------------------===//

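// A rough usage sketch of how the factories below compose (illustrative only;
// the builder `b`, location `loc`, and tensor value `tensor` are assumed to
// be provided by the surrounding lowering code):
//
//   // Wrap level 0 of `tensor` and iterate over its stored entries.
//   std::unique_ptr<SparseTensorLevel> stl =
//       makeSparseTensorLevel(b, loc, tensor, /*tid=*/0, /*lvl=*/0);
//   std::unique_ptr<SparseIterator> it =
//       makeSimpleIterator(*stl, SparseEmitStrategy::kFunctional);
//   it->genInit(b, loc, /*parent=*/nullptr);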
/// Helper function to create a TensorLevel object from the given `tensor`.
std::unique_ptr<SparseTensorLevel>
sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b,
                                     unsigned t, Level l) {
  assert(lt.getNumBuffer() == b.size());
  switch (lt.getLvlFmt()) {
  case LevelFormat::Dense:
    return std::make_unique<DenseLevel>(t, l, sz);
  case LevelFormat::Batch:
    return std::make_unique<BatchLevel>(t, l, sz);
  case LevelFormat::Compressed:
    return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
  case LevelFormat::LooseCompressed:
    return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
  case LevelFormat::Singleton:
    return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
  case LevelFormat::NOutOfM:
    return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
  case LevelFormat::Undef:
    llvm_unreachable("undefined level format");
  }
  llvm_unreachable("unrecognizable level format");
}

std::unique_ptr<SparseTensorLevel>
sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
                                     unsigned tid, Level lvl) {
  auto stt = getSparseTensorType(t);

  LevelType lt = stt.getLvlType(lvl);
  Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
                               : b.create<tensor::DimOp>(l, t, lvl).getResult();

  SmallVector<Value, 2> buffers;
  if (lt.isWithPosLT()) {
    Value pos = b.create<ToPositionsOp>(l, t, lvl);
    buffers.push_back(pos);
  }
  if (lt.isWithCrdLT()) {
    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
    buffers.push_back(crd);
  }
  return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
}

std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
                                       SparseEmitStrategy strategy) {
  auto stl = std::make_unique<BatchLevel>(tid, lvl, sz);
  auto it = std::make_unique<TrivialIterator>(*stl);
  it->setSparseEmitStrategy(strategy);
  return std::make_pair(std::move(stl), std::move(it));
}

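// The two makeSimpleIterator overloads below construct a plain iterator over
// either an iteration space or a single storage level. When the level is not
// unique, the result is wrapped in a DedupIterator so that duplicate
// coordinates are collapsed before being handed to the consumer.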
std::unique_ptr<SparseIterator>
sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l,
                                  const SparseIterationSpace &iterSpace) {
  // assert(iterSpace.getSpaceDim() == 1);
  std::unique_ptr<SparseIterator> ret;
  if (!iterSpace.isUnique()) {
    // We always deduplicate the non-unique level, but we should optimize it
    // away if possible.
    ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(),
                                          iterSpace.getBoundLo(),
                                          iterSpace.getBoundHi());
  } else {
    ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(),
                                            iterSpace.getBoundLo(),
                                            iterSpace.getBoundHi());
  }
  ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional);
  return ret;
}

std::unique_ptr<SparseIterator>
sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
                                  SparseEmitStrategy strategy) {
  std::unique_ptr<SparseIterator> ret;
  if (!isUniqueLT(stl.getLT())) {
    // We always deduplicate the non-unique level, but we should optimize it
    // away if possible.
    ret = std::make_unique<DedupIterator>(stl);
  } else {
    ret = std::make_unique<TrivialIterator>(stl);
  }
  ret->setSparseEmitStrategy(strategy);
  return ret;
}

std::unique_ptr<SparseIterator>
sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
                                       Value offset, Value stride, Value size,
                                       SparseEmitStrategy strategy) {
  auto ret =
      std::make_unique<FilterIterator>(std::move(sit), offset, stride, size);
  ret->setSparseEmitStrategy(strategy);
  return ret;
}

std::unique_ptr<SparseIterator>
sparse_tensor::makePaddedIterator(std::unique_ptr<SparseIterator> &&sit,
                                  Value padLow, Value padHigh,
                                  SparseEmitStrategy strategy) {
  auto ret = std::make_unique<PadIterator>(std::move(sit), padLow, padHigh);
  ret->setSparseEmitStrategy(strategy);
  return ret;
}

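// Strips a FilterIterator wrapper (if any) so the subsection machinery below
// can operate on the underlying iterator directly; strided accesses are
// re-established afterwards by wrapping the result in a new FilterIterator.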
static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
  auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
  if (filter)
    return &filter->getWrappedIterator();
  return it;
}

std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
    OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
    std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
    SparseEmitStrategy strategy) {

  // Try to unwrap the NonEmptySubSectIterator from a filter parent.
  parent = tryUnwrapFilter(parent);
  std::unique_ptr<SparseIterator> it =
      std::make_unique<NonEmptySubSectIterator>(b, l, parent,
                                                std::move(delegate), size);

  if (stride != 1) {
    // TODO: We can safely skip bound checking on sparse levels, but for a
    // dense iteration space, we need the bound to infer the dense loop range.
    it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
                                          C_IDX(stride), /*size=*/loopBound);
  }
  it->setSparseEmitStrategy(strategy);
  return it;
}

std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
    OpBuilder &b, Location l, const SparseIterator &subSectIter,
    const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
    Value loopBound, unsigned stride, SparseEmitStrategy strategy) {

  // This must be a subsection iterator or a filtered subsection iterator.
  auto &subSect =
      llvm::cast<NonEmptySubSectIterator>(*tryUnwrapFilter(&subSectIter));

  std::unique_ptr<SparseIterator> it = std::make_unique<SubSectIterator>(
      subSect, *tryUnwrapFilter(&parent), std::move(wrap));

  if (stride != 1) {
    it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
                                          C_IDX(stride), /*size=*/loopBound);
  }
  it->setSparseEmitStrategy(strategy);
  return it;
}

#undef CMPI
#undef C_IDX
#undef YIELD
#undef ADDI
#undef ANDI
#undef SUBI
#undef MULI
#undef SELECT