//===- LoopEmitter.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "LoopEmitter.h"
#include "CodegenUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// File local shorthand macros
//===----------------------------------------------------------------------===//

#define CMPI(p, l, r)                                                          \
  (builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::p, (l), (r))       \
       .getResult())

#define C_IDX(v) (constantIndex(builder, loc, (v)))
#define YIELD(vs) (builder.create<scf::YieldOp>(loc, (vs)))
#define ADDI(lhs, rhs) (builder.create<arith::AddIOp>(loc, (lhs), (rhs)))
#define ANDI(lhs, rhs) (builder.create<arith::AndIOp>(loc, (lhs), (rhs)))
#define SUBI(lhs, rhs) (builder.create<arith::SubIOp>(loc, (lhs), (rhs)))
#define MULI(lhs, rhs) (builder.create<arith::MulIOp>(loc, (lhs), (rhs)))
#define REMUI(lhs, rhs) (builder.create<arith::RemUIOp>(loc, (lhs), (rhs)))
#define DIVUI(lhs, rhs) (builder.create<arith::DivUIOp>(loc, (lhs), (rhs)))
#define SELECT(c, l, r) (builder.create<arith::SelectOp>(loc, (c), (l), (r)))
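// For example (illustration only), ADDI(lo, hi) expands to
//   builder.create<arith::AddIOp>(loc, lo, hi)
// and therefore requires an in-scope `builder` and `loc` at every use site.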

//===----------------------------------------------------------------------===//
// Debugging utils
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
                                                  Location loc, Value memref) {
  memref = builder.create<memref::CastOp>(
      loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref);
  createFuncCall(builder, loc, "printMemrefInd", TypeRange{},
                 ValueRange{memref}, EmitCInterface::On);
}
#endif

//===----------------------------------------------------------------------===//
// File local helper functions.
//===----------------------------------------------------------------------===//
// For index reduction loops, since the tensors are sliced into non-contiguous
// fragments, we need a triple [pLo, pHi, pPtr], in which the pair (pLo, pHi)
// specifies the range of the fragment, and pPtr specifies the index of the
// corresponding fragment in the child level (i.e., a pointer to the sliced
// position array).
static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
                            Level lvl) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  return createOrFoldSliceOffsetOp(builder, loc, tensor, toDim(enc, lvl));
}

static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
                            Level lvl) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
}

static bool isIntOrFPZero(Attribute attr) {
  if (auto f = llvm::dyn_cast<FloatAttr>(attr); f && f.getValue().isZero())
    return true;
  if (auto i = llvm::dyn_cast<IntegerAttr>(attr); i && i.getValue().isZero())
    return true;
  return false;
}

static Value unFoldOpIntResult(OpBuilder &builder, Location loc,
                               OpFoldResult ofr) {
  if (std::optional<int64_t> i = getConstantIntValue(ofr); i.has_value())
    return constantIndex(builder, loc, *i);
  return cast<Value>(ofr);
}
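// Returns the source of a `tensor.pad` whose padding value is a constant zero
// when the padded tensor is sparse with an identity dimension-to-level map;
// otherwise returns the value unchanged. For example (sketch):
//   %p = tensor.pad %t low[...] high[...] { tensor.yield %zero }
// is folded back to %t here; the padding itself is handled later by wrapping
// the level iterator into a padded iterator (see makeLevelIterator).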
static Value tryFoldTensors(Value t) {
  // TODO: this should be done through a folding pass after switching to
  // `sparse_tensor.iterate`-based sparsification.
  auto stt = tryGetSparseTensorType(t);
  auto padOp = t.getDefiningOp<tensor::PadOp>();
  if (padOp && stt.has_value() && stt->hasEncoding() &&
      padOp.getSourceType().getEncoding() == stt->getEncoding() &&
      stt->getEncoding().isIdentity()) {
    // Try fusing padOp with zeros.
    Attribute padCst;
    if (matchPattern(padOp.getBody()->getTerminator(),
                     m_Op<tensor::YieldOp>(m_Constant(&padCst))) &&
        isIntOrFPZero(padCst)) {
      return padOp.getSource();
    }
  }
  return t;
}

//===----------------------------------------------------------------------===//
// Sparse tensor loop emitter class implementations
//===----------------------------------------------------------------------===//

LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
                         bool isSparseOut, unsigned numLoops,
                         DependentLvlGetter dimGetter,
                         SparseEmitStrategy emitStrategy) {
  initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter,
             emitStrategy);
}

void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
                             bool isSparseOut, unsigned numLoops,
                             DependentLvlGetter dimGetter,
                             SparseEmitStrategy emitStrategy) {
  // First initialize the top-level type of the fields.
  this->loopTag = loopTag;
  this->hasOutput = hasOutput;
  this->isSparseOut = isSparseOut;
  this->emitStrategy = emitStrategy;

  const unsigned numManifestTensors = ts.size();
  const unsigned synTensorId = numManifestTensors;
  const unsigned numTensors = numManifestTensors + 1;
  // The tensors array (len == numManifestTensors).
  this->tensors.assign(ts.begin(), ts.end());
  // Arrays with len == numTensors.
  this->valBuffer.assign(numTensors, nullptr);
  this->lvls.resize(numTensors);
  this->iters.resize(numTensors);
  this->spIterVals.resize(numTensors);

  // Reserve space for the loop and loop-sequence stacks; entries are pushed
  // and popped as loops are generated.
  this->loopStack.reserve(numLoops);
  this->loopSeqStack.reserve(numLoops);

  // Index-reduction related fields.
  this->dependentLvlMap.assign(
      numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
  this->sliceMeta.assign(
      numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
  this->levelReducedDep.assign(numTensors, std::vector<unsigned>());

  // Initialize nested types of `TensorId`-indexed fields.
  for (TensorId tid = 0; tid < numTensors; tid++) {
    Level lvlRank;
    if (tid == synTensorId) {
      // The synthetic tensor is (conceptually) an all-dense tensor with a rank
      // equal to the total number of loops (each level can potentially be
      // mapped to one of the loops being generated).
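      // For example (illustration only), for a kernel with three loops
      // (i, j, k), the synthetic tensor behaves like a dense 3-D tensor
      // whose level l spans [0, loopHighs[l]), so a loop that is not driven
      // by any manifest tensor can still be generated over it.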
      lvlRank = numLoops;
    } else {
      const Value t = tensors[tid];
      // A scalar or a zero-ranked tensor.
      if (isZeroRankedTensorOrScalar(t.getType()))
        continue;

      auto rtp = getRankedTensorType(t);
      const SparseTensorType stt(rtp);
      lvlRank = stt.getLvlRank();
    }

    lvls[tid].resize(lvlRank);
    iters[tid].resize(lvlRank);
    spIterVals[tid].resize(lvlRank);
    loopHighs.assign(numLoops, nullptr);

    // Initialization related to slice-driven loops.
    levelReducedDep[tid].assign(lvlRank, 0);
    dependentLvlMap[tid].assign(
        lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
    sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
    if (dimGetter && !isSynTensor(tid)) {
      for (Level l = 0; l < lvlRank; l++) {
        std::vector<std::pair<LoopId, unsigned>> deps = dimGetter(tid, l);
        // Sort the dependent loops by order.
        llvm::sort(deps, llvm::less_first());

        dependentLvlMap[tid][l] = std::move(deps);
        unsigned depends = dependentLvlMap[tid][l].size();
        if (depends == 0)
          continue;
        sliceMeta[tid][l].reserve(depends);
      }
    }
  }
}

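// Creates the iterator over a tensor level, wrapping it where needed. For
// instance (sketch), for a level coming from
//   tensor.pad %src low[2] high[3] { tensor.yield %zero }
// the simple iterator over %src is wrapped into a padded iterator that
// logically extends the coordinate range by 2 + 3; similarly, a slice
// encoding wraps the simple iterator with the slice offset and stride.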
std::unique_ptr<SparseIterator>
LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
                               Level l) {
  Value tensor = tensors[t];
  auto stt = getSparseTensorType(tensor);
  auto it = makeSimpleIterator(*lvls[t][l], emitStrategy);

  Value folded = tryFoldTensors(tensor);
  if (folded != tensor) {
    auto padOp = tensor.getDefiningOp<tensor::PadOp>();
    assert(padOp);
    if (padOp.getPaddedDims().test(l)) {
      Value low = unFoldOpIntResult(builder, loc, padOp.getMixedLowPad()[l]);
      Value high = unFoldOpIntResult(builder, loc, padOp.getMixedHighPad()[l]);
      auto padIt = makePaddedIterator(std::move(it), low, high, emitStrategy);
      return padIt;
    }
  }

  if (stt.hasEncoding() && stt.getEncoding().isSlice()) {
    Value offset = genSliceOffset(builder, loc, tensor, l);
    Value stride = genSliceStride(builder, loc, tensor, l);
    auto slicedIt = makeSlicedLevelIterator(
        std::move(it), offset, stride, lvls[t][l]->getSize(), emitStrategy);
    return slicedIt;
  }

  return it;
}

void LoopEmitter::initializeLoopEmit(
    OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
    LoopEmitter::SynTensorBoundSetter synSetter) {

  // For every manifest tensor, set up the values buffer.
  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
       t++) {
    // TODO: this should be done through a folding pass after switching to
    // `sparse_tensor.iterate`-based sparsification.
    const Value tensor = tryFoldTensors(tensors[t]);
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    // Skip only scalars; zero-ranked tensors still need to be bufferized and
    // (probably) filled with zeros by users.
    if (!rtp)
      continue;

    auto stt = getSparseTensorType(tensor);
    const auto shape = rtp.getShape();

    // Perform the required bufferization. Dense inputs materialize from the
    // input tensors. Sparse inputs use sparse primitives to obtain the values.
    // Delegates extra output initialization to clients.
    bool isOutput = isOutputTensor(t);
    Type elementType = stt.getElementType();
    if (!stt.hasEncoding()) {
      // Non-annotated dense tensors.
      BaseMemRefType denseTp = MemRefType::get(shape, elementType);

      // TODO: if we unconditionally use a fully dynamic layout here, it breaks
      // some vectorization passes which require static stride == 1.
      // Is it possible to run the vectorization pass after bufferization?
      if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(tensor.getDefiningOp()))
        denseTp = bufferization::getMemRefTypeWithFullyDynamicLayout(rtp);

      Value denseVal =
          builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
      // Dense outputs need special handling.
      if (isOutput && updater)
        denseVal = updater(builder, loc, denseVal, tensor);

      valBuffer[t] = denseVal;
    } else {
      // Annotated sparse tensors.
      // We also need the value buffer for all-dense annotated "sparse"
      // tensors.
      valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
    }
  }

  // The sparse iterator values will only be available after the loop is
  // constructed.
  if (emitStrategy == SparseEmitStrategy::kSparseIterator)
    return;

  // For the synthetic tensor, set the high bound of each level by calling the
  // callback.
  if (synSetter) {
    TensorId synId = getSynTensorId();
    for (unsigned i = 0, e = loopHighs.size(); i < e; i++) {
      Value sz = loopHighs[i] = synSetter(builder, loc, i);
      auto [stl, it] = makeSynLevelAndIterator(sz, synId, i, emitStrategy);
      lvls[synId][i] = std::move(stl);
      iters[synId][i].emplace_back(std::move(it));
    }
  }

  // For every manifest tensor:
  // * For every level:
  //   * get the positions and coordinates buffers
  //   * get/compute the level-size, which is also used as the upper-bound
  //     on positions.
  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
       t++) {
    // TODO: this should be done through a folding pass after switching to
    // `sparse_tensor.iterate`-based sparsification.
    const Value tensor = tryFoldTensors(tensors[t]);
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    if (!rtp)
      // Skip only scalars; zero-ranked tensors still need to be bufferized and
      // (probably) filled with zeros by users.
      continue;

    auto stt = getSparseTensorType(tensor);
    const Level lvlRank = stt.getLvlRank();

    // Scan all levels of the current tensor.
    for (Level l = 0; l < lvlRank; l++) {
      // Find the upper bound in the current dimension.
      lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, t, l);
      if (!dependentLvlMap[t][l].empty())
        continue;

      auto it = makeLevelIterator(builder, loc, t, l);
      iters[t][l].emplace_back(std::move(it));
    }
    // NOTE: we could also prepare for level 0 here in advance. This would
    // hoist some loop preparation out of tensor iteration, but would also
    // (undesirably) hoist the code outside of if-conditions.
  }
  // TODO: avoid treating subsection iterator as a special case.
  initSubSectIterator(builder, loc);
}

void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
  Value c0 = C_IDX(0);
  for (TensorId t = 0, e = tensors.size(); t < e; t++) {
    auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
    if (!rtp)
      continue;

    Level lvlRank = SparseTensorType(rtp).getLvlRank();

    // Compute the dependency reduction order.
    auto remDepStack = dependentLvlMap;
    std::vector<std::tuple<LoopId, TensorId, Level>> depRedOrder;
    for (Level lvl = 0; lvl < lvlRank; lvl++) {
      // Reverse queue into a stack.
      std::reverse(remDepStack[t][lvl].begin(), remDepStack[t][lvl].end());
      for (auto [loop, coeff] : dependentLvlMap[t][lvl])
        depRedOrder.emplace_back(std::make_tuple(loop, t, lvl));
    }

    if (depRedOrder.empty())
      continue;

    std::sort(depRedOrder.begin(), depRedOrder.end(),
              [](auto &l, auto &r) { return std::get<0>(l) < std::get<0>(r); });

    SmallVector<SparseIterator *> lastIter(tensors.size(), nullptr);
    for (auto [loop, t, lvl] : depRedOrder) {
      std::pair<LoopId, unsigned> curDep = remDepStack[t][lvl].back();
      assert(curDep.first == loop);
      remDepStack[t][lvl].pop_back();

      auto lvlIt = makeLevelIterator(builder, loc, t, lvl);
      const SparseIterator *parent = lastIter[t];
      if (!parent && lvl > 0) {
        if (dependentLvlMap[t][lvl - 1].empty()) {
          parent = iters[t][lvl - 1].back().get();
        }
      }

      std::unique_ptr<SparseIterator> it;
      if (!remDepStack[t][lvl].empty()) {
        // Compute the subsection size.
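        // For example (illustration only), with remaining dependences
        // {(i, stride 2), (j, stride 1)} and loop bounds hi_i and hi_j,
        // the size evaluates to ((hi_i - 1) * 2 + 1) + ((hi_j - 1) * 1 + 1).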
        Value size = c0;
        for (auto [loop, stride] : remDepStack[t][lvl]) {
          Value idxMax = SUBI(loopHighs[loop], C_IDX(1));
          size = ADDI(size, ADDI(MULI(idxMax, C_IDX(stride)), C_IDX(1)));
        }
        it = makeNonEmptySubSectIterator(builder, loc, parent, loopHighs[loop],
                                         std::move(lvlIt), size, curDep.second,
                                         emitStrategy);
      } else {
        const SparseIterator &subSectIter = *iters[t][lvl].back();
        it = makeTraverseSubSectIterator(builder, loc, subSectIter, *parent,
                                         std::move(lvlIt), loopHighs[loop],
                                         curDep.second, emitStrategy);
      }
      lastIter[t] = it.get();
      iters[t][lvl].emplace_back(std::move(it));
    }
  }
}

void LoopEmitter::categorizeIterators(
    ArrayRef<TensorLevel> tidLvls, SmallVectorImpl<SparseIterator *> &raIters,
    SmallVectorImpl<SparseIterator *> &spIters) {
  // Find out the tensor levels that we should use to generate loops. Among all
  // the tensor levels, there is at most one sparse tensor level.
  for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
    SparseIterator *it = &getCurIterator(t, l);
    if (it->randomAccessible())
      raIters.push_back(it);
    else
      spIters.push_back(it);
  }

  std::stable_sort(spIters.begin(), spIters.end(), [](auto lhs, auto rhs) {
    // AffineUnRed > Affine > Slice > Trivial
    return static_cast<uint8_t>(lhs->kind) > static_cast<uint8_t>(rhs->kind);
  });
}

void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
                                  ArrayRef<TensorLevel> tidLvls) {
  // TODO: sort
  assert(loopSeqStack.size() == loopStack.size());

  if (emitStrategy != SparseEmitStrategy::kSparseIterator) {
    // Prepares for all the tensors used in the current loop sequence.
    for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
      levelReducedDep[tid][lvl]++;
      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
    }
  }

  // Universal Index starts from 0.
  loopSeqStack.emplace_back(C_IDX(0), tidLvls.vec());
}

void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) {
  assert(loopSeqStack.size() == loopStack.size() + 1);
  // Depending on whether the slices are resolved or not at the current loop
  // sequence, end them in different ways.
  for (auto [tid, lvl] : unpackTensorLevelRange(loopSeqStack.back().second))
    levelReducedDep[tid][lvl]--;

  loopSeqStack.pop_back();
}

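// For example (illustration only), with loop induction variables %i and %j
// bound to d0 and d1, the affine expression d0 * 2 + d1 is materialized as
//   %0 = arith.muli %i, %c2 : index
//   %1 = arith.addi %0, %j : index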
Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    // FIXME: since the one callsite in Sparsification passes in a
    // level-expression, the `getPosition` must in fact be a `Dimension`.
    // However, elsewhere we have been led to expect that `loopIdToOrd`
    // should be indexed by `LoopId`...
    const auto loopId = cast<AffineDimExpr>(a).getPosition();
    return loopStack[loopId].iv;
  }
  case AffineExprKind::Add: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return ADDI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Mul: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return MULI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Constant: {
    int64_t c = cast<AffineConstantExpr>(a).getValue();
    return C_IDX(c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
    OpBuilder &builder, Location loc, SparseIterator &iter,
    MutableArrayRef<Value> reduc, bool isParallel) {

  // TODO: support dynamic slices.
  // Uses the first dimension here to build the loop bound (which is also the
  // biggest range).
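  // The emitted loop roughly looks like (sketch, sequential case over a
  // sparse level with position range [%lo, %hi)):
  //   scf.for %p = %lo to %hi step %c1 iter_args(...) {
  //     <%crd = coordinate dereferenced from the iterator at %p>
  //     ...
  //   }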

  Value step = C_IDX(1);
  auto [lo, hi] = iter.genForCond(builder, loc);
  Operation *loop = nullptr;
  Value iv;
  if (isParallel) {
    scf::ParallelOp parOp =
        builder.create<scf::ParallelOp>(loc, lo, hi, step, reduc);
    builder.setInsertionPointToStart(parOp.getBody());
    assert(parOp.getNumReductions() == reduc.size());
    iv = parOp.getInductionVars()[0];

    // In-place update on the reduction variable vector.
    // Note that the init vals are not the actual reduction variables but are
    // instead used as "special handles" to (temporarily) represent them. The
    // expressions on the init vals will be moved into scf.reduce and replaced
    // with the block arguments when exiting the loop (see exitForLoop). This is
    // needed because we cannot build the actual reduction block and get the
    // actual reduction variable before users fill the parallel loop body.
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = parOp.getInitVals()[i];
    loop = parOp;
  } else {
    scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
    builder.setInsertionPointToStart(forOp.getBody());
    iv = forOp.getInductionVar();

    // In-place update on the reduction variable vector.
    assert(forOp.getNumRegionIterArgs() == reduc.size());
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = forOp.getRegionIterArg(i);
    loop = forOp;
  }
  assert(loop && iv);

  Value crd = iv;
  if (!iter.randomAccessible()) {
    iter.linkNewScope(iv);
    crd = iter.deref(builder, loc);
  } else {
    iter.locate(builder, loc, iv);
  }

  return {loop, crd};
}

std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
    MutableArrayRef<Value> reduc, bool needsUniv) {
  return genCoIteration(builder, loc, spIters, reduc,
                        needsUniv ? loopSeqStack.back().first : nullptr);
}

bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
  // If we need to co-iterate over two or more sparse tensors, we need a while
  // loop.
  if (spIters.size() > 1)
    return false;

  if (spIters.size() == 1)
    return spIters.front()->iteratableByFor();

  return true;
}

Region *LoopEmitter::enterCurrentCoIterationCase(OpBuilder &builder,
                                                 Location loc,
                                                 I64BitSet caseBit,
                                                 unsigned caseIdx,
                                                 MutableArrayRef<Value> reduc) {
  auto coIterOp = cast<CoIterateOp>(loopStack.back().loop);
  SmallVector<Attribute> cases(coIterOp.getCases().getAsRange<Attribute>());
  cases[caseIdx] = builder.getI64IntegerAttr(caseBit);

  coIterOp.setCasesAttr(builder.getArrayAttr(cases));
  Region &caseRegion = coIterOp.getRegion(caseIdx);
  assert(caseRegion.getBlocks().empty() &&
         "re-initialize the same coiteration case region.");

  // Each block starts with a list of user-provided iteration arguments.
  TypeRange iterArgsTps = coIterOp.getInitArgs().getTypes();
  // Followed by a list of used coordinates of index type.
  SmallVector<Type> blockArgTps(coIterOp.getCrdUsedLvls().count(),
                                builder.getIndexType());

  blockArgTps.append(iterArgsTps.begin(), iterArgsTps.end());
  // Ends with a set of iterators that defines the actual iteration space.
  for (auto i : caseBit.bits()) {
    blockArgTps.push_back(
        cast<IterSpaceType>(coIterOp.getIterSpaces()[i].getType())
            .getIteratorType());
  }
  SmallVector<Location> locs(blockArgTps.size(), loc);
  caseRegion.emplaceBlock().addArguments(blockArgTps, locs);

  // Entering the new region scope, updating the SSA chain.
  builder.setInsertionPointToStart(&caseRegion.front());
  // Update the coordinates.
  loopStack.back().iv = coIterOp.getCrds(caseIdx).front();
  // Updates loop iteration arguments.
  ValueRange iterArgs = coIterOp.getRegionIterArgs(caseIdx);
  llvm::copy(iterArgs, reduc.begin());
  // Updates sparse iterator values.
  ValueRange iters = coIterOp.getRegionIterators(caseIdx);
  ArrayRef<TensorLevel> tidLvls = loopStack.back().tidLvls;
  for (auto [i, tl] : llvm::enumerate(unpackTensorLevelRange(tidLvls))) {
    if (caseBit[i]) {
      spIterVals[tl.first][tl.second] = iters.front();
      iters = iters.drop_front();
    } else {
      spIterVals[tl.first][tl.second] = nullptr;
    }
  }
  // Must have consumed all iterator SSA values.
  assert(iters.empty());
  return &caseRegion;
}

Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
    unsigned numCases, MutableArrayRef<Value> reduc, bool tryParallel,
    bool needsUniv) {
  // TODO: Argument `numCases` is only used when generating iterator-based
  // sparse loops. Simplify the code upon feature completion.
  // TODO: handle coiteration with sparse iterator.
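  // For a single tensor level, the iterator-based path below emits roughly
  // (sketch; the exact syntax of the ops is abbreviated):
  //   %space = sparse_tensor.extract_iteration_space %t ...
  //   %r = sparse_tensor.iterate %it in %space iter_args(%arg = %init) {
  //     ...
  //   }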
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    if (tidLvls.size() == 1) {
      auto [tid, lvl] = unpackTensorLevel(tidLvls.front());
      Value t = tensors[tid];

      // Extract and iterate over the iteration space.
      ExtractIterSpaceOp extractSpaceOp =
          lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
                   : builder.create<ExtractIterSpaceOp>(
                         loc, t, spIterVals[tid][lvl - 1], lvl);

      IterateOp iterOp = builder.create<IterateOp>(
          loc, extractSpaceOp.getExtractedSpace(), reduc);
      spIterVals[tid][lvl] = iterOp.getIterator();

      // Update the reduction variables.
      llvm::copy(iterOp.getRegionIterArgs(), reduc.begin());
      // Set the insertion point to loop body.
      builder.setInsertionPointToStart(iterOp.getBody());
      loopStack.emplace_back(tidLvls, iterOp, builder.getInsertionBlock(),
                             iterOp.getCrds().front(), loopTag);
      return iterOp;
    }

    // CoIteration Loops.
    SmallVector<Value> spaces;
    for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
      Value t = tensors[tid];
      ExtractIterSpaceOp extractSpaceOp =
          lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
                   : builder.create<ExtractIterSpaceOp>(
                         loc, t, spIterVals[tid][lvl - 1], lvl);
      spaces.push_back(extractSpaceOp.getExtractedSpace());
    }
    auto coIterOp = builder.create<CoIterateOp>(loc, spaces, reduc, numCases);
    // The CoIterateOp has neither an insertion block nor an induction variable.
    // TODO: the `struct LoopInfo` should be simplified after full migration.
    loopStack.emplace_back(tidLvls, coIterOp, /*insertion block*/ nullptr,
                           /*induction variable*/ nullptr, loopTag);
    return coIterOp;
  }

  // TODO: support multiple return on parallel for?
  tryParallel = tryParallel && reduc.size() <= 1;

  SmallVector<SparseIterator *> raIters;
  SmallVector<SparseIterator *> spIters;
  categorizeIterators(tidLvls, raIters, spIters);

  // Only when there is at least one sparse condition do we really need the
  // universal index.
  // TODO: Maybe we should instead require the merger to pass in a valid value
  // in the first place instead of adjusting it in the LoopEmitter?
  needsUniv = !spIters.empty() && needsUniv;
  // The TensorLevel used for loop conditions.
  // If there is any sparse level, we need to use the sparse condition.
  // If all levels are dense, we can pick an arbitrary one (a dense
  // slice-driven loop can be generated using a simple ForOp as well).
  Operation *l = nullptr;
  Value iv = nullptr;
  SmallVector<TensorLevel> tls;

  // Generates loops differently depending on whether we need a slice-driven
  // loop or a simple level traversal loop.
  if (shouldIteratedByForLoop(spIters) && !needsUniv) {
    assert(spIters.size() <= 1);
    SparseIterator &it = spIters.empty() ? *raIters.front() : *spIters.front();
    std::tie(l, iv) =
        emitForLoopOverTensorAtLvl(builder, loc, it, reduc, tryParallel);
    tls.push_back(makeTensorLevel(it.tid, it.lvl));
  } else {
    for (auto *it : spIters) {
      tls.push_back(makeTensorLevel(it->tid, it->lvl));
    }

    if (needsUniv)
      for (auto *it : raIters)
        tls.push_back(makeTensorLevel(it->tid, it->lvl));

    std::tie(l, iv) =
        emitWhileLoopOverTensorsAtLvls(builder, loc, spIters, reduc, needsUniv);
  }

  // Enter dense tensor levels.
  for (SparseIterator *it : raIters)
    it->locate(builder, loc, iv);

  // NOTE: we could also prepare for the next dim here in advance.
  // Push the loop onto the stack.
  loopStack.emplace_back(tls, l, builder.getInsertionBlock(), iv, loopTag);
  return l;
}

void LoopEmitter::locateLvlAtAffineAddress(OpBuilder &builder, Location loc,
                                           TensorLevel tidLvl,
                                           AffineExpr lvlExpr) {
  auto [tid, lvl] = unpackTensorLevel(tidLvl);

  const SparseIterator *parent =
      lvl == 0 ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);

  assert(it.kind == IterKind::kTrivial && it.randomAccessible());
  Value lvlCrd = genAffine(builder, loc, lvlExpr);
  it.locate(builder, loc, lvlCrd);
}

void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                                             TensorId tid, Level lvl) {
  // If this is the first level, there is no parent iterator for the current
  // iterator.
  // If the current iterator is a subsection-based iterator, the parent
  // iterator is memorized by the iterator itself.
  bool hasParent = lvl == 0 || !dependentLvlMap[tid][lvl].empty();

  const SparseIterator *parent =
      hasParent ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);

  // Locate the randomly accessible iterator at position 0.
  if (it.randomAccessible())
    it.locate(builder, loc, C_IDX(0));
}

void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
                              MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    auto iterateOp = llvm::cast<IterateOp>(loopInfo.loop);
    assert(reduc.size() == iterateOp.getNumResults());
    rewriter.create<sparse_tensor::YieldOp>(loc, reduc);
    // Exit the loop.
    rewriter.setInsertionPointAfter(iterateOp);
    // In-place update reduction variables.
    llvm::copy(iterateOp.getResults(), reduc.begin());
    return;
  }
  if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
    if (!reduc.empty()) {
      assert(reduc.size() == forOp.getNumResults());
      rewriter.create<scf::YieldOp>(loc, reduc);
    }
    // Exit the loop.
    rewriter.setInsertionPointAfter(forOp);
    // In-place update reduction variables.
    llvm::copy(forOp.getResults(), reduc.begin());
  } else {
    auto parOp = llvm::cast<scf::ParallelOp>(loopInfo.loop);
    if (!reduc.empty()) {
      assert(reduc.size() == parOp.getInitVals().size() && reduc.size() == 1);
      Operation *redExp = reduc.front().getDefiningOp();
      // The reduction expression should have no uses.
      assert(redExp->getUses().empty());
      // This must be a binary operation.
      // NOTE: It is the users' responsibility to ensure the operation is
      // commutative.
      assert(redExp->getNumOperands() == 2 && redExp->getNumResults() == 1);

      Value redVal = parOp.getInitVals().front();
      Value curVal;
      if (redExp->getOperand(0) == redVal)
        curVal = redExp->getOperand(1);
      else if (redExp->getOperand(1) == redVal)
        curVal = redExp->getOperand(0);
      // One of the operands must be the init value (which is also the
      // previous reduction value).
      assert(curVal);
#ifndef NDEBUG
      // The reduction expression should be the only user of the reduction val
      // inside the parallel for.
      unsigned numUsers = 0;
      for (Operation *op : redVal.getUsers()) {
        if (op->getParentOp() == parOp)
          numUsers++;
      }
      assert(numUsers == 1);
#endif // NDEBUG
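      // Sketch of the rewrite performed below: a reduction expression such as
      //   %red = arith.addf %init, %val
      // is cloned into the scf.reduce region and rewired to the block
      // arguments, roughly yielding
      //   scf.reduce(%val) {
      //   ^bb0(%lhs, %rhs):
      //     %0 = arith.addf %lhs, %rhs
      //     scf.reduce.return %0
      //   }
      // while the original expression is erased.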

      rewriter.setInsertionPointAfter(redExp);
      auto redOp = rewriter.create<scf::ReduceOp>(loc, curVal);
      // Attach to the reduction op.
      Block *redBlock = &redOp.getReductions().front().front();
      rewriter.setInsertionPointToEnd(redBlock);
      Operation *newRed = rewriter.clone(*redExp);
      // Replaces arguments of the reduction expression by using the block
      // arguments from scf.reduce.
      rewriter.modifyOpInPlace(
          newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
      // Erase the outdated reduction expression.
      rewriter.eraseOp(redExp);
      rewriter.setInsertionPointToEnd(redBlock);
      rewriter.create<scf::ReduceReturnOp>(loc, newRed->getResult(0));
    }
    rewriter.setInsertionPointAfter(parOp);
    // In-place update reduction variables.
    for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
      reduc[i] = parOp.getResult(i);
  }
}

void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
                                MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
  Value iv = loopInfo.iv;
  Value one = C_IDX(1);

  // Finalize the induction. Note that the induction could be performed
  // in the individual if-branches to avoid re-evaluating the conditions.
  // However, that would result in a rather elaborate forest of yield
  // instructions during code generation. Moreover, performing the induction
  // after the if-statements more closely resembles code generated by TACO.
  SmallVector<Value> operands;
  ValueRange whileRes = whileOp.getResults();

  for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
    SparseIterator &it = getCurIterator(tid, lvl);
    if (!it.randomAccessible()) {
      // Forward the sparse iterator.
      Value cmp = CMPI(eq, it.getCrd(), iv);
      it.forwardIf(builder, loc, cmp);
      operands.append(it.getCursor().begin(), it.getCursor().end());
      // const Value newPos = whileOp->getResult(o++);
      // Following loops continue iteration from the break point of the
      // current while loop.
      whileRes = it.linkNewScope(whileRes);
    } else {
      // Make sure randomly accessible (dense) iterator is set to the right
      // position according to the universal index.
      Value uniIdx = whileOp.getResults().back();
      it.locate(builder, loc, uniIdx);
    }
  }

  // Reduction values from users.
  for (auto &i : reduc) {
    operands.push_back(i);
    // Update user reduction variables.
    i = whileRes.front();
    whileRes = whileRes.drop_front();
  }

  // An (optional) universal index.
  if (operands.size() < whileOp.getNumResults()) {
    assert(operands.size() + 1 == whileOp.getNumResults());
    // The last one is the universal index.
    operands.push_back(ADDI(iv, one));
    // Update the loop starting point of the current loop sequence.
    loopSeqStack.back().first = whileOp->getResults().back();
  }

  if (!operands.empty())
    YIELD(operands);

  builder.setInsertionPointAfter(whileOp);
}

void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
                                  MutableArrayRef<Value> reduc) {
  // Clean up the values; this helps us discover potential bugs at an earlier
  // stage (instead of silently using a wrong value).
  const LoopInfo &loopInfo = loopStack.back();
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    Operation *p = loopInfo.loop;
    if (isa<IterateOp>(p))
      rewriter.create<sparse_tensor::YieldOp>(loc, reduc);

    // Exit the loop.
    rewriter.setInsertionPointAfter(p);
    // In-place update reduction variables.
    llvm::copy(p->getResults(), reduc.begin());
    loopStack.pop_back();
    return;
  }

  // Sets the insertion point to the right position.
  rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
  if (!loopInfo.userCodeBlock->empty() &&
      llvm::isa<scf::YieldOp>(&loopInfo.userCodeBlock->back())) {
    // scf::While/For inserts an implicit yield op when there are no loop
    // iter args. In this case, we need to insert the code before the yield.
    assert(loopInfo.userCodeBlock->back().getNumResults() == 0);
    rewriter.setInsertionPoint(&loopInfo.userCodeBlock->back());
  }

  if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
    exitWhileLoop(rewriter, loc, reduc);
  } else {
    exitForLoop(rewriter, loc, reduc);
  }

  assert(loopStack.size() == loopSeqStack.size());
  loopStack.pop_back();
}

//===----------------------------------------------------------------------===//
// Loop generation utils
//===----------------------------------------------------------------------===//

std::pair<Operation *, Value> sparse_tensor::genCoIteration(
    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
    MutableArrayRef<Value> reduc, Value uniIdx, bool userReducFirst) {
  // NOTE: the slice-driven tensor-related reduction variables must
  // appear before those of normal tensors.

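  // The emitted co-iteration loop roughly looks like (sketch, two sparse
  // iterators plus an optional universal index %uni):
  //   %res:N = scf.while (%p0 = %cursor0, %p1 = %cursor1,
  //                       %r = %init, %u = %uni) {
  //     %cond = <%p0 in range> && <%p1 in range>
  //     scf.condition(%cond) %p0, %p1, %r, %u
  //   } do {
  //   ^bb0(%p0, %p1, %r, %u):
  //     <deref iterators, user body, forward iterators>
  //     scf.yield ...
  //   }
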
  // The set of induction variables for the while loop.
  SmallVector<Value> ivs;

  // TODO: remove the flag after full migration. Currently the
  // `sparse_tensor.coiterate` operation (must) puts user-provided reduction
  // values at the front of the block argument list, while direct
  // sparsification to scf loops puts them at the end.
  if (userReducFirst)
    ivs.append(reduc.begin(), reduc.end());

  // Construct the while-loop with a parameter for each coordinate.
  for (SparseIterator *it : spIters) {
    ValueRange itVals = it->getCursor();
    ivs.append(itVals.begin(), itVals.end());
  }

  if (!userReducFirst)
    ivs.append(reduc.begin(), reduc.end());

  // Update universal index.
  if (uniIdx)
    ivs.push_back(uniIdx);

  // Ensures all operands are valid.
  assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
  TypeRange types = ValueRange(ivs).getTypes();
  auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);

  SmallVector<Location> locs(types.size(), loc);
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);

  // Generates loop conditions.
  builder.setInsertionPointToStart(before);
  ValueRange bArgs = before->getArguments();
  Value whileCond = nullptr; // bool values for loop condition.

  for (SparseIterator *it : spIters) {
    auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
    whileCond = !whileCond ? cond : ANDI(whileCond, cond);
    bArgs = remArgs;
  }
  // The remaining block arguments are user-provided reduction values and an
  // optional universal index. Make sure their sizes match.
  assert(bArgs.size() == reduc.size() + (uniIdx ? 1 : 0));
  builder.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

  // Generates loop body.
  builder.setInsertionPointToStart(after);
  ValueRange aArgs = after->getArguments();
  // Since some LoopCondKind might need extra checks to filter out invalid
  // iterations, we maintain another array to hold the iteration arguments to
  // yield if the checks fail.
  SmallVector<Value> nextArgs(aArgs.begin(), aArgs.end());

  for (SparseIterator *it : spIters) {
    aArgs = it->linkNewScope(aArgs);
    // Dereference the iterator to cache the coordinate.
    it->deref(builder, loc);
  }

  // In-place update on reduction variable.
  for (unsigned i = 0, e = reduc.size(); i < e; i++)
    reduc[i] = aArgs[i];

  Value min;
  // Find the minimum coordinate.
  if (!uniIdx) {
    for (SparseIterator *it : spIters) {
      if (min) {
        Value cmp = CMPI(ult, it->getCrd(), min);
        min = SELECT(cmp, it->getCrd(), min);
      } else {
        min = it->getCrd();
      }
    }
  } else {
    // Otherwise, the universal index is the minimum position.
    min = whileOp.getAfterArguments().back();
  }

  return {whileOp, min};
}

#undef CMPI
#undef C_IDX
#undef YIELD
#undef ADDI
#undef ANDI
#undef SUBI
#undef MULI
#undef REMUI
#undef DIVUI
#undef SELECT