xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (revision c44202574ff9a8c0632aba30c2765b134557435f)
1 //===- Sparsification.cpp - Implementation of sparsification --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements converting sparse tensor types to actual sparse code.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Utils/CodegenEnv.h"
14 #include "Utils/CodegenUtils.h"
15 #include "Utils/LoopEmitter.h"
16 
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Arith/IR/Arith.h"
19 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
20 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"
22 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23 #include "mlir/Dialect/Linalg/IR/Linalg.h"
24 #include "mlir/Dialect/Linalg/Utils/Utils.h"
25 #include "mlir/Dialect/MemRef/IR/MemRef.h"
26 #include "mlir/Dialect/SCF/IR/SCF.h"
27 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
28 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
29 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
30 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
31 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
32 #include "mlir/Dialect/Tensor/IR/Tensor.h"
33 #include "mlir/IR/AffineExprVisitor.h"
34 #include "mlir/IR/Matchers.h"
35 #include "mlir/IR/TensorEncoding.h"
36 #include "llvm/ADT/SmallBitVector.h"
37 
38 #include <optional>
39 
40 using namespace mlir;
41 using namespace mlir::sparse_tensor;
42 
43 //===----------------------------------------------------------------------===//
44 // Sparsifier analysis methods.
45 //===----------------------------------------------------------------------===//
46 
47 /// Returns true iff the affine expression is invariant. Sets the
48 /// parameter `isCurrentLoop` when the expression just became invariant.
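/// For example, when generating loop 2 (`curr` == 2), the expression `d0 + d1`
/// is invariant and `isCurrentLoop` is set, since `d1` (loop 1) just became
/// invariant while `d0` (loop 0) was already invariant.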
49 static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) {
50   switch (a.getKind()) {
51   case AffineExprKind::DimId: {
52     const LoopId i = cast<AffineDimExpr>(a).getPosition();
53     if (i + 1 == curr) {
54       isCurrentLoop = true;
55       return true; // becomes invariant at current loop
56     }
57     return i < curr; // invariant when already generated
58   }
59   case AffineExprKind::Add:
60   case AffineExprKind::Mul: {
61     auto binOp = cast<AffineBinaryOpExpr>(a);
62     return isInvariantAffine(binOp.getLHS(), curr, isCurrentLoop) &&
63            isInvariantAffine(binOp.getRHS(), curr, isCurrentLoop);
64   }
65   default: {
66     assert(isa<AffineConstantExpr>(a));
67     return true;
68   }
69   }
70 }
71 
72 /// Helper method to inspect affine expressions. Rejects cases where the
73 /// same index is used more than once. Also rejects compound affine
74 /// expressions in sparse dimensions.
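/// For example, `A[i][i]` is rejected because loop index `i` is used more than
/// once, while for `A[i+j]` on a dense level the recursion merely checks
/// admissibility without recording a level format for `i` or `j`.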
75 static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
76                        LevelType lt, bool setLvlFormat = true) {
77   switch (a.getKind()) {
78   case AffineExprKind::DimId: {
79     const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
80     if (!isUndefLT(merger.getLvlType(tid, idx)))
81       return false; // used more than once
82     if (setLvlFormat)
83       merger.setLevelAndType(tid, idx, lvl, lt);
84     return true;
85   }
86   case AffineExprKind::Add:
87   case AffineExprKind::Mul:
88   case AffineExprKind::Constant: {
89     assert(lt.hasDenseSemantic());
90     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
91       // We do not set the dim level format for affine expressions like
92       // d0 + d1 on either loop index d0 or d1. We continue the recursion
93       // merely to check whether the current affine expression is admissible.
94       return findAffine(merger, tid, lvl, binOp.getLHS(), lt, false) &&
95              findAffine(merger, tid, lvl, binOp.getRHS(), lt, false);
96     }
97     // Falls through when it is a constant affine expression.
98     return true;
99   }
100   default:
101     return false;
102   }
103 }
104 
105 /// Helper method to inspect affine expressions for index variable reduction
106 /// based codegen. It finds the dependent index set for all tensor levels in the
107 /// current expression we are generating.
108 ///
109 /// For example, when handling A[i+j][j+k], we build the two way mapping in
110 /// merger between (tensor, level) pairs and their dependent index variable set:
111 /// A_0 <=> [i, j] and A_1 <=> [j, k]
112 ///
113 /// It rejects cases (returns false):
114 /// 1st, when the same index is used more than once, e.g., A[i+j][i];
115 /// 2nd, when multiplication is used in the non-trivial index expression;
116 /// 3rd, when a constant operand is used in the non-trivial index expression.
117 ///
118 /// TODO: constant should be easy to handle.
119 static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
120                           AffineExpr a, LevelType lt, bool isSubExp = false,
121                           int64_t coefficient = 1) {
122   switch (a.getKind()) {
123   case AffineExprKind::DimId: {
124     // Only allow positive coefficients on AffineDimExpr.
125     if (coefficient <= 0)
126       return false;
127 
128     const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
129     if (!isUndefLT(merger.getLvlType(tensor, idx)))
130       return false; // used more than once, e.g., A[i][i]
131 
132     // TODO: Generalize the following two cases. A[i] (with a trivial index
133     // expression) can be treated as a special affine index expression. We do
134     // not necessarily need to differentiate them.
135     if (!isSubExp) {
136       assert(coefficient == 1);
137       merger.setLevelAndType(tensor, idx, lvl, lt);
138     }
139 
140     if (isSubExp) {
141       // The current loop appears in more than one affine expression on the
142       // same tensor. We cannot handle this case, e.g., A[i+j][i+k], where
143       // `i` is used twice.
144       if (merger.hasDependentLvl(idx, tensor)) {
145         // TODO: This can be supported by coiterating slices if the loop index
146         // appears in an affine index for a different tensor, or by taking a
147         // slice along multiple dimensions when it is on the same tensor.
148         // E.g.,
149         // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
150         // d0_1 = getNextSliceOffset t0 along lvl0
151         // d0_2 = getNextSliceOffset t1 along lvl0
152         // if d0_1 == d0_2 then d0 = d0_1 = d0_2
153         // else increase min(d0_1, d0_2).
154         return false;
155       }
156       merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
157     }
158     return true;
159   }
160   case AffineExprKind::Constant:
161   case AffineExprKind::Mul: {
162     // TODO: Support index expressions like `2 * d0`; for now we only support
163     // more complicated cases like `2 * d0 + d1`.
164     if (!isSubExp)
165       return false;
166 
167     // TODO: Support constant AffineExpr for slice-based codegen.
168     if (isa<AffineConstantExpr>(a))
169       llvm_unreachable("Not yet implemented");
170 
171     auto binOp = cast<AffineBinaryOpExpr>(a);
172     auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
173     if (isa<AffineConstantExpr>(rhs))
174       std::swap(lhs, rhs);
175     // Must be in form of `constant * d`.
176     assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
177     int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
178     return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient);
179   }
180   case AffineExprKind::Add: {
181     auto binOp = cast<AffineBinaryOpExpr>(a);
182     return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, true) &&
183            findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, true);
184   }
185   default:
186     return false;
187   }
188 }
189 
190 /// Gets the total number of compound affine expressions in the
191 /// `getMatchingIndexingMap` for the given tensor.  For the following inputs:
192 ///
193 /// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
194 ///
195 /// Returns 1 (because the first level is compressed and its corresponding
196 /// indexing-expression is `d0 + d1`)
197 static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
198                                                    Value tensor) {
199   // The `tensor` is not guaranteed to have `RankedTensorType`, therefore
200   // we can't use `getRankedTensorType`/`getSparseTensorType` here.
201   // However, we don't need to handle `StorageSpecifierType`, so we
202   // can use `SparseTensorType` once we guard against non-tensors.
203   const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
204   if (!rtp)
205     return 0;
206   const SparseTensorType stt(rtp);
207 
208   const Level lvlRank = stt.getLvlRank();
209   const auto exprs = map.getResults();
210   assert(static_cast<Dimension>(exprs.size()) == lvlRank &&
211          "AffineMap does not have dimension-rank many results");
212   unsigned num = 0;
213   for (Level l = 0; l < lvlRank; l++) {
214     if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic())
215       num++;
216   }
217   return num;
218 }
219 
220 /// Gets the total number of sparse levels with compound affine
221 /// expressions, summed over all operands of the `GenericOp`.
222 static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
223   unsigned num = 0;
224   for (OpOperand &t : op->getOpOperands())
225     num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t),
226                                               t.get());
227   return num;
228 }
229 
230 // Returns true iff output has nontrivial affine indices.
231 static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
232   OpOperand *out = op.getDpsInitOperand(0);
233   if (getSparseTensorType(out->get()).isAllDense())
234     return false;
235   return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out),
236                                             out->get());
237 }
238 
239 /// Helper method to inspect sparse encodings in the tensor types.
240 /// Fills the per-dimension sparsity information for all tensors.
241 /// Returns true if the sparse annotations and affine subscript
242 /// expressions of all tensors are admissible. Returns false if
243 /// no annotations are found or inadmissible constructs occur.
244 /// We currently support two different ways to handle non-trivial index
245 /// expressions on sparse tensors, and they accept different affine expressions.
246 /// The dependent-index-reduction-based approach currently only supports
247 /// affine addition index expressions.
248 static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
249   bool annotated = false;
250   for (OpOperand &t : env.op()->getOpOperands()) {
251     const TensorId tid = env.makeTensorId(t.getOperandNumber());
252     const auto map = env.op().getMatchingIndexingMap(&t);
253     const auto enc = getSparseTensorEncoding(t.get().getType());
254     if (enc)
255       annotated = true;
256     const Level lvlRank = map.getNumResults();
257     assert(!enc || lvlRank == enc.getLvlRank());
258     assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
259     // We only need to do index reduction if there is at least one
260     // non-trivial index expression on sparse levels. If all non-trivial
261     // index expressions are on dense levels, we can efficiently rely on
262     // random access to locate the elements.
263     bool needIdxReduc =
264         enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
265     // If the current tensor being inspected requires an affine index, it
266     // needs to be sliced.
267     for (Level l = 0; l < lvlRank; l++) {
268       const AffineExpr a = map.getResult(l);
269       const LevelType lt = enc.getLvlType(l);
270       if (idxReducBased && needIdxReduc) {
271         if (!findDepIdxSet(env.merger(), tid, l, a, lt))
272           return false; // inadmissible affine expression
273       } else {
274         if (!findAffine(env.merger(), tid, l, a, lt))
275           return false; // inadmissible affine expression
276       }
277     }
278   }
279   return annotated;
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // Sparsifier synthesis methods (statements and expressions).
284 //===----------------------------------------------------------------------===//
285 
286 /// Local bufferization of all dense and sparse data structures.
287 static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
288   linalg::GenericOp op = env.op();
289   Location loc = op.getLoc();
290   assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
291 
292   SmallVector<Range, 4> loopRange =
293       llvm::cast<linalg::LinalgOp>(op.getOperation())
294           .createLoopRanges(builder, loc);
295 
296   env.emitter().initializeLoopEmit(
297       builder, loc,
298       /// Generates buffer for the output tensor.
299       /// Note that all sparse kernels assume that when all elements are written
300       /// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized
301       /// to all zeroes and only nonzero values are computed and written out.
302       /// For updates (viz. x(i) += y(i) * z(i)), only nonzero values are used
303       /// for the updates and no assumption on the original contents of the
304       /// output buffer is necessary.
305       [&op](OpBuilder &builder, Location loc, Value memref,
306             Value tensor) -> Value {
307         // Must not be a sparse tensor.
308         assert(!getSparseTensorEncoding(tensor.getType()));
309         // Two output tensor references should point to the same object.
310         OpOperand *lhs = op.getDpsInitOperand(0);
311         assert(lhs->get() == tensor);
312         // An output tensor can simply materialize from the buffer of the tensor
313         // that appears in the outs() clause. For updates, this has the
314         // advantage that only the nonzero values are involved in the
315         // computation, keeping the operation O(nnz). In all other cases, we are
316         // forced to zero out the buffer to enforce the assumption above, which
317         // may negatively impact running complexity (viz. O(n^2 + nnz) vs.
318         // O(nnz) for matrices).
319         // TODO: use better analysis to avoid zeroing out the buffer?
320         bool isInit = op.isInitTensor(lhs);
321         Value init = memref;
322         if (!isInit) {
323           Value zero = constantZero(builder, loc,
324                                     getElementTypeOrSelf(tensor.getType()));
325           builder.create<linalg::FillOp>(loc, ValueRange{zero},
326                                          ValueRange{init});
327         }
328         return init;
329       },
330       [&loopRange](OpBuilder &b, Location loc, Level l) {
331         assert(l < loopRange.size());
332         return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
333       });
334 }
335 
336 /// Generates index for load/store on sparse tensor.
337 static Value genIndex(CodegenEnv &env, OpOperand *t) {
338   const auto map = env.op().getMatchingIndexingMap(t);
339   const auto stt = getSparseTensorType(t->get());
340   const Level lvlRank = stt.getLvlRank();
341   assert(static_cast<Level>(map.getNumResults()) == lvlRank);
342   const AffineExpr a = map.getResult(lvlRank - 1);
343   assert(a.getKind() == AffineExprKind::DimId);
344   const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition());
345   return env.getLoopVar(idx);
346 }
347 
348 /// Generates subscript for load/store on a dense or sparse tensor.
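/// For a dense tensor, all level coordinates are pushed onto `args`; for a
/// sparse tensor, only the last level's position is pushed, which suffices to
/// index the flat values buffer.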
349 static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
350                           SmallVectorImpl<Value> &args) {
351   const Location loc = env.op().getLoc();
352   const TensorId tid = env.makeTensorId(t->getOperandNumber());
353   const auto map = env.op().getMatchingIndexingMap(t);
354   const auto stt = getSparseTensorType(t->get());
355   if (stt.hasEncoding()) {
356     // For sparse tensors we only push the last-level's position onto `args`.
357     const auto pos = env.emitter().getValPosits(tid);
358     assert(!pos.empty());
359     args.append(pos);
360     // Simply returns the tensor to extract value using iterators.
361     if (env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator)
362       return t->get();
363   } else {
364     // For dense tensors we push all level's coordinates onto `args`.
365     const Level lvlRank = stt.getLvlRank();
366     assert(static_cast<Level>(map.getNumResults()) == lvlRank);
367     for (Level l = 0; l < lvlRank; l++) {
368       const auto lvlExpr = map.getResult(l);
369       const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr);
370       args.push_back(lvlCrd);
371     }
372   }
373   return env.emitter().getValBuffer()[tid];
374 }
375 
376 /// Generates insertion code to implement dynamic tensor load.
377 static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
378                               OpOperand *t) {
379   linalg::GenericOp op = env.op();
380   Location loc = op.getLoc();
381   // Direct lexicographic coordinate order, tensor loads as zero.
382   if (!env.isExpand()) {
383     Type tp = getElementTypeOrSelf(t->get().getType());
384     return constantZero(builder, loc, tp);
385   }
386   // Load from expanded access pattern.
387   Value index = genIndex(env, t);
388   return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index);
389 }
390 
391 /// Generates insertion code to implement dynamic tensor load for reduction.
392 static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
393                                     OpOperand *t) {
394   linalg::GenericOp op = env.op();
395   Location loc = op.getLoc();
396   Value identity = env.getCustomRedId();
397   // Direct lexicographic coordinate order, tensor loads as identity.
398   if (!env.isExpand())
399     return identity;
400   // Load from expanded access pattern if filled, identity otherwise.
401   Value values = env.getExpandValues();
402   Value filled = env.getExpandFilled();
403   Value index = genIndex(env, t);
404   Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
405   Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index);
406   return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
407 }
408 
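/// Generates a conditional insertion: if `cond` holds, inserts `v` into
/// `sparseOut` at coordinates `ivs` and yields the updated insertion chain;
/// otherwise yields `sparseOut` unchanged.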
409 static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
410                                   Value sparseOut, ValueRange ivs, Value v) {
411   scf::IfOp condInsert =
412       builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true);
413   // True branch.
414   builder.setInsertionPointToStart(condInsert.thenBlock());
415   Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
416   builder.create<scf::YieldOp>(loc, res);
417   // False branch.
418   builder.setInsertionPointToStart(condInsert.elseBlock());
419   builder.create<scf::YieldOp>(loc, sparseOut);
420   // Value assignment.
421   builder.setInsertionPointAfter(condInsert);
422   return condInsert.getResult(0);
423 }
424 
425 /// Generates insertion code to implement dynamic tensor store.
426 static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
427                               Value rhs) {
428   linalg::GenericOp op = env.op();
429   Location loc = op.getLoc();
430   // Direct insertion in lexicographic coordinate order.
431   if (!env.isExpand()) {
432     const LoopId numLoops = op.getRank(t);
433     // Retrieves the first `numLoops` induction variables.
434     SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end(
435         env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
436     Value chain = env.getInsertionChain();
437     if (env.isValidLexInsert()) {
438       // Generates runtime check for a valid lex during reduction,
439       // to avoid inserting the identity value for empty reductions.
440       //   if (validLexInsert) then
441       //     insert(rhs) into chain
442       //     return updated chain
443       //   else
444       //     return unmodified chain
445       Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
446                                        chain, ivs, rhs);
447       env.updateInsertionChain(out);
448     } else {
449       Value sparseOut;
450       if (!hasAnySparseType(env.op().getInputs().getTypes())) {
451         // This is an all-dense -> sparse kernel, test rhs != 0 before
452         // insertion.
453         Value nz = genIsNonzero(builder, loc, rhs);
454         sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
455       } else {
456         sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
457       }
458       // Generates regular insertion chain.
459       env.updateInsertionChain(sparseOut);
460     }
461     return;
462   }
463   // Generates insertion code along expanded access pattern.
464   //   if (!expFilled[i]) then
465   //     expFilled[i] = true
466   //     expAdded[inserts++] = i
467   //   endif
468   //   values[i] = rhs
469   Value values = env.getExpandValues();
470   Value filled = env.getExpandFilled();
471   Value added = env.getExpandAdded();
472   Value count = env.getExpandCount();
473   Value index = genIndex(env, t);
474   Value fval = constantI1(builder, loc, false);
475   Value tval = constantI1(builder, loc, true);
476   // If statement.
477   Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
478   Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
479                                              isFilled, fval);
480   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
481                                              /*else=*/true);
482   // True branch.
483   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
484   builder.create<memref::StoreOp>(loc, tval, filled, index);
485   builder.create<memref::StoreOp>(loc, index, added, count);
486   Value one = constantIndex(builder, loc, 1);
487   Value add = builder.create<arith::AddIOp>(loc, count, one);
488   builder.create<scf::YieldOp>(loc, add);
489   // False branch.
490   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
491   builder.create<scf::YieldOp>(loc, count);
492   builder.setInsertionPointAfter(ifOp);
493   // Value assignment.
494   env.updateExpandCount(ifOp.getResult(0));
495   builder.create<memref::StoreOp>(loc, rhs, values, index);
496 }
497 
498 /// Generates a load on a dense or sparse tensor.
499 static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
500   // Test if the load was hoisted to a higher loop nest.
501   Value val = env.exp(exp).val;
502   if (val)
503     return val;
504   // Get tensor operand.
505   linalg::GenericOp op = env.op();
506   Location loc = op.getLoc();
507   OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
508   // Fold binary-valued tensor into explicit value.
509   const auto stt = getSparseTensorType(t->get());
510   if (auto explVal = stt.getExplicitVal())
511     return genValFromAttr(builder, loc, explVal);
512   // Load during insertion.
513   if (env.isSparseOutput(t)) {
514     if (env.isCustomReduc())
515       return genInsertionLoadReduce(env, builder, t);
516     return genInsertionLoad(env, builder, t);
517   }
518 
519   // Actual load.
520   SmallVector<Value> args;
521   Value ptr = genSubscript(env, builder, t, args);
522   if (llvm::isa<TensorType>(ptr.getType())) {
523     assert(env.options().sparseEmitStrategy ==
524                SparseEmitStrategy::kSparseIterator &&
525            args.size() == 1);
526     return builder.create<ExtractValOp>(loc, ptr, args.front());
527   }
528   return builder.create<memref::LoadOp>(loc, ptr, args);
529 }
530 
531 /// Generates a store on a dense or sparse tensor.
532 static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
533                            Value rhs) {
534   // Only unary and binary ops are allowed to return an uninitialized rhs
535   // to indicate a missing output, or otherwise a custom reduction that
536   // received no value to accumulate.
537   if (!rhs) {
538     assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
539            env.exp(exp).kind == TensorExp::Kind::kBinary ||
540            env.exp(exp).kind == TensorExp::Kind::kReduce);
541     return;
542   }
543   // Test if this is a scalarized reduction.
544   if (env.isReduc()) {
545     env.updateReduc(rhs);
546     return;
547   }
548   // Regular store.
549   linalg::GenericOp op = env.op();
550   Location loc = op.getLoc();
551   OpOperand *t = op.getDpsInitOperand(0);
552   if (!env.isSparseOutput(t)) {
553     SmallVector<Value> args;
554     Value ptr = genSubscript(env, builder, t, args);
555     builder.create<memref::StoreOp>(loc, rhs, ptr, args);
556     return;
557   }
558   // Store during sparse insertion.
559   if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
560     genInsertionStore(env, builder, t, rhs);
561     return;
562   }
563   // Select operation insertion.
564   Value chain = env.getInsertionChain();
565   scf::IfOp ifOp =
566       builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
567   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
568   // Existing value was preserved to be used here.
569   assert(env.exp(exp).val);
570   Value v0 = env.exp(exp).val;
571   genInsertionStore(env, builder, t, v0);
572   env.merger().clearExprValue(exp);
573   // Yield modified insertion chain along true branch.
574   Value mchain = env.getInsertionChain();
575   builder.create<scf::YieldOp>(op.getLoc(), mchain);
576   // Yield original insertion chain along false branch.
577   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
578   builder.create<scf::YieldOp>(loc, chain);
579   // Done with if statement.
580   env.updateInsertionChain(ifOp->getResult(0));
581   builder.setInsertionPointAfter(ifOp);
582 }
583 
584 /// Generates an invariant value.
585 inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
586   return env.exp(exp).val;
587 }
588 
589 /// Semi-ring branches are simply inlined by the sparsifier. Prior
590 /// analysis has verified that all computations are "local" to the inlined
591 /// branch or otherwise invariantly defined outside the loop nest, with the
592 /// exception of index computations, which need to be relinked to actual
593 /// inlined cloned code.
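/// For example, a `linalg.index` operation inside an inlined branch is replaced
/// by the corresponding loop induction variable, and block arguments of the
/// original linalg op are turned into dense tensor loads.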
594 static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
595                           Value e) {
596   if (auto arg = dyn_cast<BlockArgument>(e)) {
597     // Direct arguments of the original linalg op must be converted
598     // into dense tensor loads. Note that we should not encounter
599     // anything else. This needs to be verified by semi-ring ops.
600     linalg::GenericOp op = env.op();
601     if (arg.getOwner()->getParentOp() == op) {
602       const TensorId tid = env.makeTensorId(arg.getArgNumber());
603       OpOperand *t = &op->getOpOperand(tid);
604       assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
605       SmallVector<Value> args;
606       Value ptr = genSubscript(env, rewriter, t, args);
607       return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
608     }
609   } else if (Operation *def = e.getDefiningOp()) {
610     // Handle index computation.
611     if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
612       return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
613     // When still defined in new body, recurse into operands.
614     if (def->getBlock() == block) {
615       rewriter.setInsertionPoint(def);
616       for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
617         rewriter.modifyOpInPlace(def, [&]() {
618           def->setOperand(
619               i, relinkBranch(env, rewriter, block, def->getOperand(i)));
620         });
621       }
622     }
623   }
624   return e;
625 }
626 
627 /// Recursively generates tensor expression.
628 static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
629   if (e == ::mlir::sparse_tensor::detail::kInvalidId)
630     return Value();
631 
632   linalg::GenericOp op = env.op();
633   Location loc = op.getLoc();
634   const TensorExp &exp = env.exp(e);
635   const auto kind = exp.kind;
636   if (kind == TensorExp::Kind::kTensor)
637     return genTensorLoad(env, rewriter, e);
638   if (kind == TensorExp::Kind::kInvariant)
639     return genInvariantValue(env, e);
640   if (kind == TensorExp::Kind::kLoopVar)
641     return env.getLoopVar(exp.loop);
642 
643   if (kind == TensorExp::Kind::kReduce)
644     env.startCustomReduc(e); // enter custom
645 
646   // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
647   // based on the type of the other operand.
648   Value v0, v1;
649   if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
650       env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
651     v1 = genExp(env, rewriter, exp.children.e1);
652     v0 = constantZero(rewriter, loc, v1.getType());
653   } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
654              env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
655     v0 = genExp(env, rewriter, exp.children.e0);
656     v1 = constantZero(rewriter, loc, v0.getType());
657   } else {
658     v0 = genExp(env, rewriter, exp.children.e0);
659     v1 = genExp(env, rewriter, exp.children.e1);
660   }
661 
662   Value ee;
663   if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
664     // custom reduce did not receive a value
665   } else {
666     ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
667     if (ee &&
668         (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
669          kind == TensorExp::Kind::kBinaryBranch ||
670          kind == TensorExp::Kind::kReduce ||
671          kind == TensorExp::Kind::kSelect)) {
672       OpBuilder::InsertionGuard guard(rewriter);
673       ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
674     }
675   }
676 
677   if (kind == TensorExp::Kind::kReduce)
678     env.endCustomReduc(); // exit custom
679 
680   if (kind == TensorExp::Kind::kSelect)
681     env.merger().setExprValue(e, v0); // Preserve value for later use.
682 
683   return ee;
684 }
685 
686 /// Hoists loop invariant tensor loads for which indices have been exhausted.
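/// For example, in x(i) = a(i) * b with a scalar tensor b, the load of b is
/// generated once before the i-loop rather than on every iteration.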
687 static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
688                           LoopId curr, bool isStart) {
689   if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
690     return;
691   if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
692     // Inspect tensor indices.
693     linalg::GenericOp op = env.op();
694     OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
695     const auto map = op.getMatchingIndexingMap(&t);
696     const auto stt = getSparseTensorType(t.get());
697     const Level lvlRank = stt.getLvlRank();
698     assert(static_cast<Level>(map.getNumResults()) == lvlRank);
699     bool isCurrentLoop = curr == 0; // for scalar tensors
700     for (Level l = 0; l < lvlRank; l++) {
701       const AffineExpr a = map.getResult(l);
702       if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop))
703         return; // still in play
704     }
705     // All exhausted at current level.
706     if (!isCurrentLoop)
707       return;
708     // Generate code for a scalarized reduction or invariant. Note that
709     // because the custom reduction lhs may occur several times in the IR,
710     // we have a built-in safeguard to initialize and wrap up the scalarized
711     // reduction only once.
712     OpOperand *lhs = op.getDpsInitOperand(0);
713     if (lhs == &t) {
714       // Start or end a scalarized reduction.
715       if (isStart) {
716         if (env.isCustomReduc()) {
717           if (!env.isReduc())
718             env.startReduc(exp, env.getCustomRedId());
719         } else {
720           env.startReduc(exp, genTensorLoad(env, builder, exp));
721         }
722         if (env.hasSparseOutput())
723           env.startValidLexInsert(
724               constantI1(builder, env.op().getLoc(), false));
725       } else {
726         if (!env.isCustomReduc() || env.isReduc())
727           genTensorStore(env, builder, exp, env.endReduc());
728         if (env.hasSparseOutput())
729           env.endValidLexInsert();
730       }
731     } else {
732       // Start or end loop invariant hoisting of a tensor load.
733       if (isStart) {
734         env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
735       } else {
736         env.merger().clearExprValue(exp);
737       }
738     }
739   } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
740              env.exp(exp).kind != TensorExp::Kind::kLoopVar &&
741              env.exp(exp).kind != TensorExp::Kind::kSynZero) {
742     // Traverse into the binary operations. Note that we only hoist
743     // tensor loads, since subsequent MLIR/LLVM passes know how to
744     // deal with all other kinds of derived loop invariants.
745     if (env.exp(exp).kind == TensorExp::Kind::kReduce)
746       env.startCustomReduc(exp); // enter custom
747     const ExprId e0 = env.exp(exp).children.e0;
748     const ExprId e1 = env.exp(exp).children.e1;
749     genInvariants(env, builder, e0, curr, isStart);
750     genInvariants(env, builder, e1, curr, isStart);
751     if (env.exp(exp).kind == TensorExp::Kind::kReduce)
752       env.endCustomReduc(); // exit custom
753   }
754 }
755 
756 /// Generates an expanded access pattern in the innermost dimension.
757 static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
758                       bool isStart) {
759   linalg::GenericOp op = env.op();
760   OpOperand *lhs = op.getDpsInitOperand(0);
761   if (!env.atExpandLevel(lhs, op.getRank(lhs), curr))
762     return; // not needed at current level
763   assert(!env.isReduc());
764   // Generate start or end of an expanded access pattern. Note that because
765   // an expansion does not rely on the ongoing contents of the sparse storage
766   // scheme, we can use the original tensor as incoming SSA value (which
767   // simplifies codegen a bit). If expansion on the actual contents is ever
768   // needed, we will need to use the SSA value in the insertion chain instead.
769   Value tensor = lhs->get();
770   Location loc = op.getLoc();
771   if (isStart) {
772     auto dynShape = {ShapedType::kDynamic};
773     Type etp = cast<ShapedType>(tensor.getType()).getElementType();
774     Type t1 = MemRefType::get(dynShape, etp);
775     Type t2 = MemRefType::get(dynShape, builder.getI1Type());
776     Type t3 = MemRefType::get(dynShape, builder.getIndexType());
777     Type t4 = builder.getIndexType();
778     auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
779     assert(r.getNumResults() == 4);
780     env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
781                     r.getResult(3));
782   } else {
783     SmallVector<Value> indices;
784     for (LoopId i = 0; i < curr; i++)
785       indices.push_back(env.emitter().getLoopIV(i));
786     Value values = env.getExpandValues();
787     Value filled = env.getExpandFilled();
788     Value added = env.getExpandAdded();
789     Value count = env.getExpandCount();
790     Value chain = env.getInsertionChain();
791     Value compress = builder.create<CompressOp>(loc, values, filled, added,
792                                                 count, chain, indices);
793     env.updateInsertionChain(compress);
794     env.endExpand();
795   }
796 }
797 
798 /// Determines whether the loop may be parallelized. Any implicit loop in the
799 /// Linalg operation marked "parallel" is a candidate. Whether it is actually
800 /// converted to a parallel operation depends on the requested strategy.
801 static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
802   // Reject parallelization of sparse output.
803   if (env.hasSparseOutput())
804     return false;
805   // Parallel loops on tensor expansion can cause data races.
806   if (env.isExpand())
807     return false;
808   // Inspect strategy.
809   switch (env.options().parallelizationStrategy) {
810   case SparseParallelizationStrategy::kNone:
811     return false;
812   case SparseParallelizationStrategy::kDenseOuterLoop:
813     return isOuter && !isSparse;
814   case SparseParallelizationStrategy::kAnyStorageOuterLoop:
815     return isOuter;
816   case SparseParallelizationStrategy::kDenseAnyLoop:
817     return !isSparse;
818   case SparseParallelizationStrategy::kAnyStorageAnyLoop:
819     return true;
820   }
821   llvm_unreachable("unexpected parallelization strategy");
822 }
823 
824 /// Whether or not the current loop being generated should be parallelized
825 /// (if possible) according to the configuration.
826 static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
827                                ArrayRef<TensorLevel> tidLvls) {
828   linalg::GenericOp op = env.op();
829   auto iteratorTypes = op.getIteratorTypesArray();
830   bool isSparse = llvm::any_of(tidLvls, [curr, &env](TensorLevel tidLvl) {
831     // Queries the LT based on the tensor and loop id, as requested by
832     // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
833     // should be consistent with the LT indexed by <TensorId, Level>.
834     const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr);
835     return lt.hasSparseSemantic();
836   });
837   return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
838 }
839 
840 /// Emits a loop to coiterate over the list of tensor levels. The generated
841 /// loop can either be a for-loop or a while-loop, depending on whether there
842 /// is at most one sparse level in the list.
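/// E.g., a single compressed level can be traversed with a for-loop over its
/// positions, whereas coiterating two compressed levels requires a while-loop
/// that advances both iterators.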
843 static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
844                                  ArrayRef<TensorLevel> tidLvls,
845                                  unsigned numCases, bool tryParallel,
846                                  bool needsUniv) {
847   Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
848     // Construct while-loop with a parameter for each index.
849     return env.emitter().enterCoIterationOverTensorsAtLvls(
850         builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel,
851         needsUniv);
852   });
853   assert(loop);
854   return loop;
855 }
856 
857 /// Generates a for-loop or a while-loop, depending on whether it implements
858 /// singleton iteration or co-iteration over the given conjunction.
859 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
860                           unsigned numCases, bool needsUniv,
861                           ArrayRef<TensorLevel> tidLvls) {
862   bool tryParallel = shouldTryParallize(env, curr, tidLvls);
863   return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
864                         needsUniv);
865 }
866 
867 /// Generates the induction structure for a while-loop.
868 static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
869                             bool needsUniv) {
870   Location loc = env.op().getLoc();
871   // Finalize each else branch of all if statements.
872   if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
873     while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
874                builder.getInsertionBlock()->getParentOp())) {
875       // Break on IfOps used for slice filtering.
876       if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
877           StringAttr::get(ifOp->getContext(), "slice"))
878         break;
879 
880       unsigned y = 0;
881       SmallVector<Value> yields;
882       if (env.isReduc()) {
883         yields.push_back(env.getReduc());
884         env.updateReduc(ifOp.getResult(y++));
885         if (env.isValidLexInsert()) {
886           yields.push_back(env.getValidLexInsert());
887           env.updateValidLexInsert(ifOp.getResult(y++));
888         }
889       }
890       if (env.isExpand()) {
891         yields.push_back(env.getExpandCount());
892         env.updateExpandCount(ifOp->getResult(y++));
893       }
894       if (env.getInsertionChain()) {
895         yields.push_back(env.getInsertionChain());
896         env.updateInsertionChain(ifOp->getResult(y++));
897       }
898       assert(y == yields.size());
899       builder.create<scf::YieldOp>(loc, yields);
900       builder.setInsertionPointAfter(ifOp);
901     }
902   }
903   // No need to set the insertion point here as LoopEmitter keeps track of the
904   // basic block where scf::Yield should be inserted.
905 }
906 
907 /// Generates a case region in the coiterate operation.
908 static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
909                                unsigned caseIdx, LatPointId allCase,
910                                LatPointId curCase,
911                                MutableArrayRef<Value> reduc) {
912   assert(allCase == curCase || env.merger().latGT(allCase, curCase));
913   const BitVector &allCaseBits = env.merger().lat(allCase).simple;
914   const BitVector &curCaseBits = env.merger().lat(curCase).simple;
915 
916   /// Computes the subset of iterators that are valid in the current case being
917   /// generated.
918   I64BitSet caseBit(0);
919   for (auto [idx, set] : llvm::enumerate(allCaseBits.set_bits()))
920     if (curCaseBits.test(set))
921       caseBit.set(idx);
922 
923   env.emitter().enterCurrentCoIterationCase(builder, env.op().getLoc(), caseBit,
924                                             caseIdx, reduc);
925 }
926 
927 /// Generates a single if-statement within a while-loop.
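/// The condition conjoins, for each sparse tensor level involved, a test that
/// the level's current coordinate equals the loop variable; dense levels
/// contribute a trivially true clause.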
928 static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
929                        LatPointId p) {
930   Location loc = env.op().getLoc();
931   SmallVector<Type> types;
932   Value cond;
933   env.merger().foreachTensorLoopId(
934       p, /*simple=*/true,
935       [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
936           bool isIdxRed) {
937         if (isIdxRed) {
938           // Since there is no 1:1 mapping from loop to level (multiple loops
939           // are required to resolve one level with a non-trivial index
940           // expression), we need to reconstruct the tensor level types if this
941           // loop requires an index reduction condition.
942           assert(lvl.has_value() && isUndefLT(lt));
943           auto stt = getSparseTensorType(env.op().getInputs()[tid]);
944           lt = stt.getLvlType(*lvl);
945         }
946         assert(curr == env.merger().loop(b));
947         Value clause;
948         if (lt.hasSparseSemantic()) {
949           assert(lvl.has_value());
950           const Value crd = env.emitter().getCoord(tid, *lvl);
951           const Value lvar = env.getLoopVar(curr);
952           clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
953                                                  crd, lvar);
954         } else {
955           assert(lt.hasDenseSemantic() || isUndefLT(lt));
956           clause = constantI1(builder, loc, true);
957         }
958         cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
959       });
960   if (env.isReduc()) {
961     types.push_back(env.getReduc().getType());
962     if (env.isValidLexInsert())
963       types.push_back(env.getValidLexInsert().getType());
964   }
965   if (env.isExpand())
966     types.push_back(builder.getIndexType());
967   if (env.getInsertionChain())
968     types.push_back(env.getInsertionChain().getType());
969   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
970   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
971   return ifOp;
972 }
973 
974 /// Generates end of true branch of if-statement within a while-loop.
975 static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
976                   Value redInput, Value cntInput, Value insInput,
977                   Value validIns) {
978   SmallVector<Value> operands;
979   if (env.isReduc()) {
980     operands.push_back(env.getReduc());
981     env.updateReduc(redInput);
982     if (env.isValidLexInsert()) {
983       // Any overlapping indices during a reduction create a valid lex insert.
984       operands.push_back(constantI1(builder, env.op().getLoc(), true));
985       env.updateValidLexInsert(validIns);
986     }
987   }
988   if (env.isExpand()) {
989     operands.push_back(env.getExpandCount());
990     env.updateExpandCount(cntInput);
991   }
992   if (env.getInsertionChain()) {
993     operands.push_back(env.getInsertionChain());
994     env.updateInsertionChain(insInput);
995   }
996   if (!operands.empty())
997     builder.create<scf::YieldOp>(env.op().getLoc(), operands);
998   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
999 }
1000 
1001 //===----------------------------------------------------------------------===//
1002 // Sparsifier synthesis methods (loop sequence).
1003 //===----------------------------------------------------------------------===//
1004 
1005 static bool getAllTidLvlsInLatPoints(
1006     CodegenEnv &env, LatPointId li, LoopId curr,
1007     llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
1008   const BitVector &simple = env.lat(li).simple;
1009   const TensorId outTid = env.merger().getOutTensorID();
1010   const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
1011 
1012   unsigned numloopCond = 0;
1013   bool hasNonUnique = false;
1014   env.merger().foreachTensorLoopId(
1015       li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
1016                     LevelType lt, bool isIdxReduc) {
1017         if (simple[b]) {
1018           if (isIdxReduc) {
1019             callback(env.makeTensorLevel(tid, *lvl), nullptr);
1020             numloopCond++;
1021             return;
1022           }
1023           if (isUndefLT(lt)) {
1024             // For an undefined lt in the lattices, we probably mean to
1025             // generate a dense loop according to the synthetic tensor (for
1026             // invariants and the sparse output tensor).
1027             if (env.merger().getSynTensorID() == tid) {
1028               // Coiterating with an invariant
1029               // e.g., out = prod(in[i][j] op invariant);
1030               // or a broadcast
1031               // e.g., out[i][j] = in[i] (j is undef for input)
1032               //
1033               // The level of the synthetic tensor is the current loop depth;
1034               // the rank of the synthetic tensor equals the number of loops.
1035               assert(curr == env.getCurrentDepth());
1036               lvl = curr;
1037             } else if (!lvl) {
1038               // Skips invalid lvl (e.g., when this is a zero ranked tensor).
1039               return;
1040             }
1041           }
1042           hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
1043           callback(env.makeTensorLevel(tid, *lvl), nullptr);
1044           numloopCond++;
1045         } else if (lt.hasDenseSemantic() || isIdxReduc) {
1046           callback(env.makeTensorLevel(tid, *lvl), nullptr);
1047         } else {
1048           assert(isUndefLT(lt));
1049           linalg::GenericOp op = env.op();
1050           if (tid >= op.getNumDpsInputs())
1051             // We only handle affine expression on input tensors (for now).
1052             return;
1053           OpOperand *operand = &op->getOpOperand(tid);
1054           const auto stt = getSparseTensorType(operand->get());
1055           // Non-annotated dense tensors require no special handling.
1056           if (!stt.hasEncoding())
1057             return;
1058 
1059           ArrayRef<AffineExpr> affines =
1060               op.getMatchingIndexingMap(operand).getResults();
1061           const Level lvlRank = stt.getLvlRank();
1062           assert(affines.size() == static_cast<size_t>(lvlRank));
1063           for (Level l = 0; l < lvlRank; l++) {
1064             AffineExpr exp = affines[l];
1065             // Skip simple affine expressions and non-dense levels (which
1066             // have their own filter loop).
1067             LevelType lt = stt.getLvlType(l);
1068             if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic())
1069               continue;
1070 
1071             // Constant affine expressions are handled in genLoop.
1072             if (!isa<AffineConstantExpr>(exp)) {
1073               bool isCurrentLoop = false;
1074               assert(curr == env.getCurrentDepth());
1075               if (isInvariantAffine(exp, curr + 1, /*out*/ isCurrentLoop) &&
1076                   isCurrentLoop) {
1077                 // If the compound affine expression is invariant and we are
1078                 // right at the level, we need to generate the address according
1079                 // to the affine expression. This is also the best place to do
1080                 // it, to avoid putting it inside inner loops.
1081                 callback(env.makeTensorLevel(tid, l), exp);
1082               }
1083             }
1084           }
1085         }
1086       });
1087 
1088   if (isDenseLT(env.lt(outTid, curr))) {
1089     auto stt = getSparseTensorType(env.op().getOutputs().front());
1090     // Note that we generate dense indices of the output tensor unconditionally,
1091     // since they may not appear in the lattice, but may be needed for
1092     // the linearized environment.
1093     // TODO: we should avoid introducing corner cases for all-dense sparse
1094     // tensors.
1095     if (stt.hasEncoding() && stt.isAllDense())
1096       callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
1097   }
1098 
1099   if (numloopCond == 0) {
1100     // Corner case where the loop bound is defined by an *unused* operand; in
1101     // this case, we just generate a dense "fake" loop by iterating over the
1102     // synthetic tensor.
1103     callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
1104     numloopCond++;
1105   }
1106   // If we only need one loop condition and the condition is not imposed on a
1107   // non-unique level, the loop can be generated by a for-loop.
1108   // Or, if we are generating sparse-iterator-based loops, we always generate
1109   // `sparse_tensor.iterate` regardless of whether the level is unique or not.
1110   return numloopCond == 1 &&
1111          (!hasNonUnique || env.options().sparseEmitStrategy ==
1112                                SparseEmitStrategy::kSparseIterator);
1113 }
1114 
1115 /// Starts a loop sequence at given level. Returns true if
1116 /// the universal loop index must be maintained at this level.
1117 static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
1118                          LoopId curr, LatSetId lts) {
1119   assert(!env.getLoopVar(curr));
1120   // Emit invariants at this loop sequence level.
1121   genInvariants(env, builder, exp, curr, /*isStart=*/true);
1122   // Emit access pattern expansion for sparse tensor output.
1123   genExpand(env, builder, curr, /*isStart=*/true);
1124   // Emit further initialization at this loop sequence level.
1125   const LatPointId l0 = env.set(lts)[0];
1126 
1127   SmallVector<TensorLevel> tidLvls;
1128   getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
1129     // TODO: remove this! The same tensor level might be added multiple
1130     // times due to the special handling of the all-dense "sparse" output
1131     // tensor (see L1038).
1132     if (llvm::find(tidLvls, tl) != tidLvls.end())
1133       return;
1134     tidLvls.emplace_back(tl);
1135   });
1136 
1137   env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
1138 
1139   // Maintain the universal index only if it is actually
1140   // consumed by a subsequent lattice point.
1141   for (const LatPointId li : env.set(lts).drop_front())
1142     if (!env.merger().hasAnySparse(env.lat(li).simple))
1143       return true;
1144 
1145   return false;
1146 }
1147 
1148 // Generates dense affine address for encoding.
1149 static void genConstantDenseAddressFromLevel(CodegenEnv &env,
1150                                              OpBuilder &builder, TensorId tid,
1151                                              Level startLvl) {
1152   // TODO: Handle affine expression on output tensor.
1153   linalg::GenericOp op = env.op();
1154   assert(tid < op.getNumDpsInputs());
1155   OpOperand *input = op.getDpsInputOperands()[tid];
1156   const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
1157   const auto enc = getSparseTensorEncoding(input->get().getType());
1158   if (enc) {
1159     const Location loc = op.getLoc();
1160     const TensorId tid = env.makeTensorId(input->getOperandNumber());
1161     const Level lvlRank = enc.getLvlRank();
1162     assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
1163     for (Level l = startLvl; l < lvlRank; l++) {
1164       AffineExpr lvlExpr = lvlExprs[l];
1165       if (enc.getLvlType(l).hasDenseSemantic() &&
1166           isa<AffineConstantExpr>(lvlExpr))
1167         env.emitter().locateLvlAtAffineAddress(
1168             builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
1169       else
1170         return; // break on first non-dense non-constant level
1171     }
1172   }
1173 }
1174 
1175 // We can generate addresses for constant affine expressions before any loops,
1176 // starting from the first level, as they do not depend on anything.
1177 // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
1178 // levels can be determined before loops.
1179 static void genInitConstantDenseAddress(CodegenEnv &env,
1180                                         RewriterBase &rewriter) {
1181   for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
1182     genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
1183 }
1184 
1185 /// Returns true if the lattice bit can be iterated by a for loop.
1186 static bool translateBitsToTidLvlPairs(
1187     CodegenEnv &env, LatPointId li, LoopId curr,
1188     SmallVectorImpl<TensorLevel> &tidLvls,
1189     SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
1190   return getAllTidLvlsInLatPoints(env, li, curr,
1191                                   [&](TensorLevel tl, AffineExpr exp) {
1192                                     if (exp)
1193                                       affineTidLvls.emplace_back(tl, exp);
1194                                     else
1195                                       tidLvls.emplace_back(tl);
1196                                   });
1197 }
1198 
1199 /// Starts a single loop in current sequence.
1200 static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
1201                                               OpBuilder &builder, LoopId curr,
1202                                               LatPointId li, unsigned numCases,
1203                                               bool needsUniv) {
1204   // TODO: numCases is only used when generating iterator-based loops. Clean
1205   // up after the migration is complete.
1206   // The set of tensors + lvls to generate loops on
1207   SmallVector<TensorLevel> tidLvls;
1208 
1209   // The set of dense tensors with a non-trivial affine expression that just
1210   // became invariant; their addresses are generated at the current level.
1211   SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
1212   bool isSingleCond =
1213       translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);
1214 
1215   // Emit the for/while-loop control.
1216   Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
1217   Location loc = env.op().getLoc();
1218   for (auto [tidLvl, exp] : affineTidLvls) {
1219     env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
1220   }
1221 
1222   // By now, we have entered every <tid, lvl> pair in {cond, extra,
1223   // affine}Tids/Lvls. The addresses of the upcoming levels that depend on
1224   // constant affine expressions may now be determined.
1225   auto allTidLvls =
1226       llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
1227   for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
1228     if (tid != env.merger().getOutTensorID() &&
1229         tid != env.merger().getSynTensorID())
1230       genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
1231   }
1232 
1233   return std::make_pair(loop, isSingleCond);
1234 }
1235 
1236 /// Ends a single loop in the current sequence. Returns the updated needsUniv.
1237 static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
1238                     LatPointId li, bool needsUniv, bool isSingleCond) {
1239   // Either a for-loop or a while-loop that iterates over a slice.
1240   if (isSingleCond) {
1241     // Any iteration creates a valid lex insert.
1242     if (env.isReduc() && env.isValidLexInsert())
1243       env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
1244   } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1245     // End a while-loop.
1246     finalizeWhileOp(env, rewriter, needsUniv);
1247   } else {
1248     needsUniv = false;
1249   }
1250   env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
1251     env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
1252     return std::nullopt;
1253   });
1254   return needsUniv;
1255 }
1256 
1257 /// Ends a loop sequence at the given level.
1258 static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
1259                        unsigned at) {
1260   assert(!env.getLoopVar(at));
1261   env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
1262   // Unmark bookkeeping of invariants and loop index.
1263   genInvariants(env, builder, exp, at, /*isStart=*/false);
1264   // Finalize access pattern expansion for sparse tensor output.
1265   genExpand(env, builder, at, /*isStart=*/false);
1266 }
1267 
1268 /// Recursively generates code while computing iteration lattices in order
1269 /// to manage the complexity of implementing co-iteration over unions
1270 /// and intersections of sparse iteration spaces.
1271 static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
1272                     LoopId curr) {
1273   assert(curr == env.getCurrentDepth());
1274 
1275   // At each leaf, assign remaining tensor (sub)expression to output tensor.
1276   if (curr == env.getLoopNum()) {
1277     Value rhs = genExp(env, rewriter, exp);
1278     genTensorStore(env, rewriter, exp, rhs);
1279     return;
1280   }
1281 
1282   // Construct iteration lattices for current loop index.
1283   const LatSetId lts =
1284       env.merger().optimizeSet(env.merger().buildLattices(exp, curr));
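  // Illustration: for a kernel like x(i) = a(i) + b(i) with both inputs
  // sparse, the optimized lattice for loop i contains a point for the
  // intersection (both a(i) and b(i) contribute) as well as points for
  // a(i) only and b(i) only, so the loop sequence below co-iterates over
  // the union of the two iteration spaces.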
1285 
1286   // Start a loop sequence.
1287   bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);
1288 
1289   // When using sparse-iterator-based loops, we only need one loop, as
1290   // opposed to a loop sequence, to cover all the iteration spaces.
1291   const unsigned lsize = env.set(lts).size();
1292   if (env.generatingSparseIterator()) {
1293     // Get the largest lattice point and start a loop.
1294     const LatPointId li = env.set(lts)[0];
1295     auto [loop, isSingleCond] =
1296         startLoop(env, rewriter, curr, li, lsize, needsUniv);
1297     assert(isSingleCond == llvm::isa<IterateOp>(loop));
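    // A single condition corresponds to a single `sparse_tensor.iterate` op;
    // otherwise each lattice point is handled as a separate co-iteration case
    // via genCoIterationCase below.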
1298     // We cannot change this to `for (const LatPointId lj : env.set(lts))`
1299     // because the loop body causes data movement, which invalidates
1300     // the iterator.
1301     for (unsigned j = 0; j < lsize; j++) {
1302       const LatPointId lj = env.set(lts)[j];
1303       const ExprId ej = env.lat(lj).exp;
1304       // Recurse into body of each branch.
1305       if (!isSingleCond) {
1306         env.genLoopBoundary([&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
1307           genCoIterationCase(env, rewriter, /*caseIdx*/ j, li, lj, reduc);
1308           genStmt(env, rewriter, ej, curr + 1);
1309           // TODO: handle yield values.
1310           assert(reduc.empty() && "Not Implemented");
1311           rewriter.create<sparse_tensor::YieldOp>(env.op().getLoc());
1312           return std::nullopt;
1313         });
1314         // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1315       } else {
1316         genStmt(env, rewriter, ej, curr + 1);
1317       }
1318     }
1319     // End a loop.
1320     needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1321   } else {
1322     // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1323     for (unsigned i = 0; i < lsize; i++) {
1324       const LatPointId li = env.set(lts)[i];
1325       // Start a loop.
1326       auto [loop, isSingleCond] =
1327           startLoop(env, rewriter, curr, li, lsize, needsUniv);
1328 
1329       // Visit all lattice points Lj such that Li >= Lj to generate the
1330       // loop-body, possibly with if statements for co-iteration.
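      // Snapshot the values that nested if-branches may update, so that endIf
      // can merge the branch results back into the environment.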
1331       Value redInput = env.getReduc();
1332       Value cntInput = env.getExpandCount();
1333       Value insInput = env.getInsertionChain();
1334       Value validIns = env.getValidLexInsert();
1335       // We cannot change this to `for (const LatPointId lj : env.set(lts))`
1336       // because the loop body causes data movement, which invalidates the
1337       // iterator.
1338       for (unsigned j = 0; j < lsize; j++) {
1339         const LatPointId lj = env.set(lts)[j];
1340         const ExprId ej = env.lat(lj).exp;
1341         if (li == lj || env.merger().latGT(li, lj)) {
1342           // Recurse into body of each branch.
1343           if (!isSingleCond) {
1344             scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
1345             genStmt(env, rewriter, ej, curr + 1);
1346             endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1347           } else {
1348             genStmt(env, rewriter, ej, curr + 1);
1349           }
1350         }
1351       }
1352 
1353       // End a loop.
1354       needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1355     }
1356   }
1357 
1358   // End a loop sequence.
1359   endLoopSeq(env, rewriter, exp, curr);
1360   assert(curr == env.getCurrentDepth());
1361 }
1362 
1363 /// Converts the result computed by the sparse kernel into the required form.
1364 static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
1365   linalg::GenericOp op = env.op();
1366   OpOperand *lhs = op.getDpsInitOperand(0);
1367   Value tensor = lhs->get();
1368   Type resType = tensor.getType();
1369   if (getSparseTensorEncoding(resType)) {
1370     // The sparse tensor rematerializes from the original sparse tensor's
1371     // underlying sparse storage format. For an insertion chain, the
1372     // tensor materializes from the chain with 'hasInserts' enabled.
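    // That is, the generic op is replaced by a `sparse_tensor.load`, with the
    // `hasInserts` flag set whenever an insertion chain must be finalized.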
1373     bool hasInserts = false;
1374     if (Value chain = env.getInsertionChain()) {
1375       hasInserts = true;
1376       tensor = chain;
1377     }
1378     rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
1379   } else {
1380     // To rematerialize a non-annotated tensor, simply load it
1381     // from the bufferized value.
1382     Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
1383     rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
1384   }
1385 }
1386 
1387 //===----------------------------------------------------------------------===//
1388 // Sparsifier rewriting methods.
1389 //===----------------------------------------------------------------------===//
1390 
1391 namespace {
1392 
1393 /// Sparse rewriting rule for a generic Linalg operation.
1394 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1395 public:
1396   GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1397       : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1398 
1399   LogicalResult matchAndRewrite(linalg::GenericOp op,
1400                                 PatternRewriter &rewriter) const override {
1401     // Only accept single output operations with pure tensor semantics.
1402     if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
1403       return failure();
1404 
1405     // Only accept trivial affine indices.
1406     if (hasNonTrivialAffineOnSparseOut(op))
1407       return failure();
1408 
1409     // Only accept scheduled loops.
1410     if (!op->hasAttr("sorted")) {
1411       return rewriter.notifyMatchFailure(
1412           op, "Loops not yet scheduled; try running --sparse-reinterpret-map "
1413               "before sparsification.");
1414     }
1415 
1416     // Must have been demapped as well if the generic op is sorted.
1417     assert(!hasAnyNonIdentityOperandsOrResults(op));
1418 
1419     // Sets up a code generation environment.
1420     const unsigned numTensors = op->getNumOperands();
1421     const unsigned numLoops = op.getNumLoops();
1422     bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
1423     // If we have an indexing map like (d0) -> (0, d0), there might be more
1424     // levels than loops because of the constant index, which means we cannot
1425     // use numLoops as the upper bound for the ranks of all tensors.
1426     // TODO: Constant indices are currently not supported on sparse tensors,
1427     // but are allowed on non-annotated dense tensors. Supporting them would
1428     // also be required for rank-reducing sparse tensor slices.
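    // E.g., with the map (d0) -> (0, d0) above, the tensor has two levels but
    // the kernel has only one loop, so maxLvlRank (2 in that case), rather
    // than numLoops, bounds the per-tensor level ranks.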
1429     Level maxLvlRank = 0;
1430     for (auto operand : op.getOperands()) {
1431       if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
1432         maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
1433       }
1434     }
1435 
1436     // Detects sparse annotations and translates the per-level sparsity
1437     // information for all tensors to loop indices in the kernel.
1438     CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
1439     if (!findSparseAnnotations(env, needIdxRed))
1440       return failure();
1441 
1442     // Only standard reduction operations (add, sub, or, xor) that can be
1443     // sparsified by merely reducing the stored values are admissible. More
1444     // elaborate reduction operations (such as mul, and, min, max) would need
1445     // to know whether implicit zeros occur as well. They can still be
1446     // implemented with a custom reduction operation, accepted here as well.
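    // For example, a mul-reduction over only the stored values of a sparse
    // operand would silently skip implicit zeros, which should drive the
    // product to zero; such reductions must therefore be expressed through an
    // explicit sparse_tensor.reduce op instead.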
1447     if (op.getNumReductionLoops() > 0) {
1448       Operation *yield = op.getRegion().front().getTerminator();
1449       assert(isa<linalg::YieldOp>(yield));
1450       Operation *redop = yield->getOperand(0).getDefiningOp();
1451       if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) &&
1452           !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) &&
1453           !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) &&
1454           !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) &&
1455           !isa<ReduceOp>(redop)) {
1456         return failure();
1457       }
1458     }
1459 
1460     // Constructs the tensor expression tree from `op`; returns failure if the
1461     // tree cannot be built or the tensor expression is inadmissible.
1462     if (failed(env.initTensorExp()))
1463       return failure();
1464 
1465     // Recursively generates code if admissible.
1466     env.startEmit(options.sparseEmitStrategy);
1467     genBuffers(env, rewriter);
1468     // TODO: Constant affine expressions should be handled differently when
1469     // using slice-based codegen; it does not matter now because we already
1470     // reject constant expressions at an earlier stage.
1471     genInitConstantDenseAddress(env, rewriter);
1472     genStmt(env, rewriter, env.getExprId(), 0);
1473     genResult(env, rewriter);
1474     return success();
1475   }
1476 
1477 private:
1478   /// Options to control sparse code generation.
1479   SparsificationOptions options;
1480 };
1481 
1482 } // namespace
1483 
1484 /// Populates the given patterns list with rewriting rules required for
1485 /// the sparsification of linear algebra operations.
1486 void mlir::populateSparsificationPatterns(
1487     RewritePatternSet &patterns, const SparsificationOptions &options) {
1488   patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1489 }
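
// Typical usage from a pass (sketch; `ctx`, `module`, and `options` are
// placeholders, and the standard greedy pattern rewrite driver is assumed):
//
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));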
1490