xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (revision c44202574ff9a8c0632aba30c2765b134557435f)
196a23911SAart Bik //===- Sparsification.cpp - Implementation of sparsification --------------===//
2a2c9d4bbSAart Bik //
3a2c9d4bbSAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a2c9d4bbSAart Bik // See https://llvm.org/LICENSE.txt for license information.
5a2c9d4bbSAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a2c9d4bbSAart Bik //
7a2c9d4bbSAart Bik //===----------------------------------------------------------------------===//
8a2c9d4bbSAart Bik //
9160399c7SAart Bik // This file implements converting sparse tensor types to actual sparse code.
10a2c9d4bbSAart Bik //
11a2c9d4bbSAart Bik //===----------------------------------------------------------------------===//
12a2c9d4bbSAart Bik 
13365777ecSAart Bik #include "Utils/CodegenEnv.h"
14365777ecSAart Bik #include "Utils/CodegenUtils.h"
15365777ecSAart Bik #include "Utils/LoopEmitter.h"
1653cc3a06SAart Bik 
1776a18618SMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h"
18abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
197a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
2057470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
2136550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
2263015742SJavier Setoain #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
24a2c9d4bbSAart Bik #include "mlir/Dialect/Linalg/Utils/Utils.h"
2566f878ceSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
268b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
278b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
28a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
29f708a549Swren romano #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
30a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
31744146f6SGus Smith #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
326d8e2f1eSAart Bik #include "mlir/Dialect/Tensor/IR/Tensor.h"
33fb287335SPeiming Liu #include "mlir/IR/AffineExprVisitor.h"
34a2c9d4bbSAart Bik #include "mlir/IR/Matchers.h"
3596a23911SAart Bik #include "mlir/IR/TensorEncoding.h"
36a2c9d4bbSAart Bik #include "llvm/ADT/SmallBitVector.h"
37067bebb5SAart Bik 
38a1fe1f5fSKazu Hirata #include <optional>
39a2c9d4bbSAart Bik 
40a2c9d4bbSAart Bik using namespace mlir;
4196a23911SAart Bik using namespace mlir::sparse_tensor;
42a2c9d4bbSAart Bik 
435da21338SAart Bik //===----------------------------------------------------------------------===//
44c43e6274STim Harvey // Sparsifier analysis methods.
455da21338SAart Bik //===----------------------------------------------------------------------===//
465da21338SAart Bik 
4798ce2debSAart Bik /// Returns true iff affine expression is invariant. Sets the
48c5a1732cSAart Bik /// parameter `isCurrentLoop` when expression just became invariant.
49c5a1732cSAart Bik static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) {
5027aabca0SPeiming Liu   switch (a.getKind()) {
5127aabca0SPeiming Liu   case AffineExprKind::DimId: {
521609f1c2Slong.chen     const LoopId i = cast<AffineDimExpr>(a).getPosition();
53c5a1732cSAart Bik     if (i + 1 == curr) {
54c5a1732cSAart Bik       isCurrentLoop = true;
5598ce2debSAart Bik       return true; // becomes invariant at current loop
5627aabca0SPeiming Liu     }
57c5a1732cSAart Bik     return i < curr; // invariant when already generated
5827aabca0SPeiming Liu   }
5927aabca0SPeiming Liu   case AffineExprKind::Add:
6027aabca0SPeiming Liu   case AffineExprKind::Mul: {
611609f1c2Slong.chen     auto binOp = cast<AffineBinaryOpExpr>(a);
62c5a1732cSAart Bik     return isInvariantAffine(binOp.getLHS(), curr, isCurrentLoop) &&
63c5a1732cSAart Bik            isInvariantAffine(binOp.getRHS(), curr, isCurrentLoop);
6427aabca0SPeiming Liu   }
6527aabca0SPeiming Liu   default: {
661609f1c2Slong.chen     assert(isa<AffineConstantExpr>(a));
6727aabca0SPeiming Liu     return true;
6827aabca0SPeiming Liu   }
6927aabca0SPeiming Liu   }
7027aabca0SPeiming Liu }
7127aabca0SPeiming Liu 
/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects compound affine
/// expressions in sparse dimensions.
///
/// On success, records the (tensor, loop) -> (level, level-type) mapping in
/// the merger for every trivial `DimId` expression; `setLvlFormat` is false
/// while recursing into a compound expression, where the recursion merely
/// validates admissibility without recording anything.
/// Returns true iff the expression is admissible.
static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
                       LevelType lt, bool setLvlFormat = true) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
    if (!isUndefLT(merger.getLvlType(tid, idx)))
      return false; // used more than once
    if (setLvlFormat)
      merger.setLevelAndType(tid, idx, lvl, lt);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul:
  case AffineExprKind::Constant: {
    // Compound expressions are only admissible on levels with dense
    // semantics (sparse levels with compound indices take a different path).
    assert(lt.hasDenseSemantic());
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
      // We do not set dim level format for affine expression like d0 + d1 on
      // either loop index at d0 or d1. We continue the recursion merely to
      // check whether current affine is admissible or not.
      return findAffine(merger, tid, lvl, binOp.getLHS(), lt, false) &&
             findAffine(merger, tid, lvl, binOp.getRHS(), lt, false);
    }
    // Falls through when it is a constant Affine
    return true;
  }
  default:
    // Mod/div and symbolic expressions are not supported.
    return false;
  }
}
104b1d44e59SAart Bik 
/// Helper method to inspect affine expressions for index variable reduction
/// based codegen. It finds the dependent index set for all tensor levels in the
/// current expression we are generating.
///
/// For example, when handling A[i+j][j+k], we build the two way mapping in
/// merger between (tensor, level) pairs and their dependent index variable set:
/// A_0 <=> [i, j] and A_1 <=> [j, k]
///
/// It rejects cases (returns false)
/// 1st, when the same index is used more than once, e.g., A[i+j][i]
/// 2nd, when multiplication is used in the non-trivial index expression.
/// 3rd, when a constant operand is used in the non-trivial index expression.
///
/// `isSubExp` is true while recursing inside a compound expression;
/// `coefficient` carries the constant multiplier collected on the way down
/// (1 for a plain dimension; only positive values are admissible).
///
/// TODO: constant should be easy to handle.
static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
                          AffineExpr a, LevelType lt, bool isSubExp = false,
                          int64_t coefficient = 1) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    // Only allow positive coefficients on AffineDimExpr.
    if (coefficient <= 0)
      return false;

    const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
    if (!isUndefLT(merger.getLvlType(tensor, idx)))
      return false; // used more than once, e.g., A[i][i]

    // TODO: Generalizes the following two cases. A[i] (with trivial index
    // expression) can be treated as a special affine index expression. We do
    // not necessarily need to differentiate them.
    if (!isSubExp) {
      // Trivial index expression: record the plain (loop, level) mapping.
      assert(coefficient == 1);
      merger.setLevelAndType(tensor, idx, lvl, lt);
    }

    if (isSubExp) {
      // The current loops appears in more than one affine expressions on the
      // same tensor. We can not handle this case. e.g., A[i+j][i+k], `i` is
      // used twice.
      if (merger.hasDependentLvl(idx, tensor)) {
        // TODO: This can be supported by coiterate slices if the loop idx is
        // appeared on affine index for different tensor, or take slice on
        // multiple dimensions when it is on the same tensor.
        // E.g.,
        // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
        // d0_1 = getNextSliceOffset t0 along lvl0
        // d0_2 = getNextSliceOffset t1 along lvl0
        // if d0_1 == d0_2 then d0 = d0_1 = d0_1
        // else increase min(d0_1, d0_2).
        return false;
      }
      // Record the dependency of this level on the loop (with multiplier).
      merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
    }
    return true;
  }
  case AffineExprKind::Constant:
  case AffineExprKind::Mul: {
    // TODO: Support index expression like `2 * d0`, we now only support more
    // complicated cases like `2 * d0 + d1`.
    if (!isSubExp)
      return false;

    // TODO: Support Constant AffineExp for slice-based codegen
    if (isa<AffineConstantExpr>(a))
      llvm_unreachable("Not yet implemented");

    auto binOp = cast<AffineBinaryOpExpr>(a);
    auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
    // Canonicalize so that the constant ends up on the left.
    if (isa<AffineConstantExpr>(rhs))
      std::swap(lhs, rhs);
    // Must be in form of `constant * d`.
    assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
    // NOTE: this local intentionally shadows the `coefficient` parameter;
    // the extracted multiplier is forwarded into the DimId case above.
    int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
    return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient);
  }
  case AffineExprKind::Add: {
    // Both addends must be admissible sub-expressions.
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, true) &&
           findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, true);
  }
  default:
    return false;
  }
}
189d03805f2SPeiming Liu 
190067bebb5SAart Bik /// Gets the total number of compound affine expressions in the
191b8cf7af9Swren romano /// `getMatchingIndexingMap` for the given tensor.  For the following inputs:
192372e7939SPeiming Liu ///
1932a07f0fdSYinying Li /// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
194372e7939SPeiming Liu ///
195372e7939SPeiming Liu /// Returns 1 (because the first level is compressed and its corresponding
196b8cf7af9Swren romano /// indexing-expression is `d0 + d1`)
197d03805f2SPeiming Liu static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
198d03805f2SPeiming Liu                                                    Value tensor) {
1992cb99df6SYinying Li   // The `tensor` is not guaranteed to have `RankedTensorType`, therefore
200b8cf7af9Swren romano   // we can't use `getRankedTensorType`/`getSparseTensorType` here.
201b8cf7af9Swren romano   // However, we don't need to handle `StorageSpecifierType`, so we
202b8cf7af9Swren romano   // can use `SparseTensorType` once we guard against non-tensors.
2035550c821STres Popp   const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
204b8cf7af9Swren romano   if (!rtp)
205b8cf7af9Swren romano     return 0;
206b8cf7af9Swren romano   const SparseTensorType stt(rtp);
207b8cf7af9Swren romano 
208b8cf7af9Swren romano   const Level lvlRank = stt.getLvlRank();
209b8cf7af9Swren romano   const auto exprs = map.getResults();
210ccd923e3SPeiming Liu   assert(static_cast<Dimension>(exprs.size()) == lvlRank &&
211b8cf7af9Swren romano          "AffineMap does not have dimension-rank many results");
212372e7939SPeiming Liu   unsigned num = 0;
213f708a549Swren romano   for (Level l = 0; l < lvlRank; l++) {
21452b69aa3SPeiming Liu     if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic())
215372e7939SPeiming Liu       num++;
216372e7939SPeiming Liu   }
217372e7939SPeiming Liu   return num;
218372e7939SPeiming Liu }
219372e7939SPeiming Liu 
220067bebb5SAart Bik /// Gets the total number of sparse levels with compound affine
221b8cf7af9Swren romano /// expressions, summed over all operands of the `GenericOp`.
222d03805f2SPeiming Liu static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
223372e7939SPeiming Liu   unsigned num = 0;
224372e7939SPeiming Liu   for (OpOperand &t : op->getOpOperands())
225d03805f2SPeiming Liu     num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t),
226372e7939SPeiming Liu                                               t.get());
227372e7939SPeiming Liu   return num;
228372e7939SPeiming Liu }
229372e7939SPeiming Liu 
230067bebb5SAart Bik // Returns true iff output has nontrivial affine indices.
231d03805f2SPeiming Liu static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
232a7bf2e55SPeiming Liu   OpOperand *out = op.getDpsInitOperand(0);
233f708a549Swren romano   if (getSparseTensorType(out->get()).isAllDense())
234a7bf2e55SPeiming Liu     return false;
235d03805f2SPeiming Liu   return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out),
236a7bf2e55SPeiming Liu                                             out->get());
237a7bf2e55SPeiming Liu }
238a7bf2e55SPeiming Liu 
/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
/// We currently support two different ways to handle non-trivial index
/// expression on sparse tensors, and they accept different affine expressions.
/// When using the dependent index reduction-based approach, it currently only
/// supports affine addition index expressions.
static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
  bool annotated = false;
  for (OpOperand &t : env.op()->getOpOperands()) {
    const TensorId tid = env.makeTensorId(t.getOperandNumber());
    const auto map = env.op().getMatchingIndexingMap(&t);
    const auto enc = getSparseTensorEncoding(t.get().getType());
    if (enc)
      annotated = true;
    const Level lvlRank = map.getNumResults();
    assert(!enc || lvlRank == enc.getLvlRank());
    assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
    // We only need to do index reduction if there is at least one
    // non-trivial index expression on sparse levels. If all non-trivial
    // index expression is on dense levels, we can efficiently rely on
    // the random access to locate the element.
    bool needIdxReduc =
        enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
    // If the current tensor being inspected requires affine index, it needs
    // to be sliced.
    for (Level l = 0; l < lvlRank; l++) {
      const AffineExpr a = map.getResult(l);
      // NOTE(review): `enc` may be a null attribute for unannotated (dense)
      // tensors here; presumably `getLvlType` on a null encoding yields a
      // dense level type — confirm against the SparseTensor dialect.
      const LevelType lt = enc.getLvlType(l);
      if (idxReducBased && needIdxReduc) {
        if (!findDepIdxSet(env.merger(), tid, l, a, lt))
          return false; // inadmissible affine expression
      } else {
        if (!findAffine(env.merger(), tid, l, a, lt))
          return false; // inadmissible affine expression
      }
    }
  }
  return annotated;
}
281a2c9d4bbSAart Bik 
2825da21338SAart Bik //===----------------------------------------------------------------------===//
283c43e6274STim Harvey // Sparsifier synthesis methods (statements and expressions).
2847373cabcSAart Bik //===----------------------------------------------------------------------===//
2857373cabcSAart Bik 
/// Local bufferization of all dense and sparse data structures.
/// Initializes the loop emitter with the loop ranges of the kernel and
/// with a callback that materializes (and, when necessary, zero-fills)
/// the output buffer.
static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Exactly one output operand is expected besides the inputs.
  assert(op.getNumOperands() == op.getNumDpsInputs() + 1);

  // Upper bounds for all loops, derived from the op's iteration domain.
  SmallVector<Range, 4> loopRange =
      llvm::cast<linalg::LinalgOp>(op.getOperation())
          .createLoopRanges(builder, loc);

  env.emitter().initializeLoopEmit(
      builder, loc,
      /// Generates buffer for the output tensor.
      /// Note that all sparse kernels assume that when all elements are written
      /// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized
      /// to all zeroes and only nonzeroes values are computed and written out.
      /// For updates (viz. x(i) += y(i) * z(i)), only nonzeroes values are used
      /// for the updates and no assumption on the original contents of the
      /// output buffer is necessary.
      [&op](OpBuilder &builder, Location loc, Value memref,
            Value tensor) -> Value {
        // Must not be a sparse tensor.
        assert(!getSparseTensorEncoding(tensor.getType()));
        // Two output tensor references should point to the same object.
        OpOperand *lhs = op.getDpsInitOperand(0);
        assert(lhs->get() == tensor);
        // An output tensor can simply materialize from the buffer of the tensor
        // that appears in the outs() clause. For updates, this has the
        // advantage that only the nonzero value are involved in the
        // computation, keeping the operation O(nnz). In all other cases, we are
        // forced to zero out the buffer to enforce the assumption above, which
        // may negatively impact running complexity (viz. O(n^2 + nnz) vs.
        // O(nnz) for matrices).
        // TODO: use better analysis to avoid zeroing out the buffer?
        bool isInit = op.isInitTensor(lhs);
        Value init = memref;
        if (!isInit) {
          Value zero = constantZero(builder, loc,
                                    getElementTypeOrSelf(tensor.getType()));
          builder.create<linalg::FillOp>(loc, ValueRange{zero},
                                         ValueRange{init});
        }
        return init;
      },
      // Callback mapping a loop level to its (index-typed) upper bound.
      [&loopRange](OpBuilder &b, Location loc, Level l) {
        assert(l < loopRange.size());
        return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
      });
}
335a2c9d4bbSAart Bik 
3364f2ec7f9SAart Bik /// Generates index for load/store on sparse tensor.
337fbe61130SAart Bik static Value genIndex(CodegenEnv &env, OpOperand *t) {
33846a384dfSwren romano   const auto map = env.op().getMatchingIndexingMap(t);
339f708a549Swren romano   const auto stt = getSparseTensorType(t->get());
340f708a549Swren romano   const Level lvlRank = stt.getLvlRank();
341f708a549Swren romano   assert(static_cast<Level>(map.getNumResults()) == lvlRank);
342ccd923e3SPeiming Liu   const AffineExpr a = map.getResult(lvlRank - 1);
3434f2ec7f9SAart Bik   assert(a.getKind() == AffineExprKind::DimId);
3441609f1c2Slong.chen   const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition());
345b8cf7af9Swren romano   return env.getLoopVar(idx);
3464f2ec7f9SAart Bik }
3474f2ec7f9SAart Bik 
/// Generates subscript for load/store on a dense or sparse tensor.
/// Appends the subscript values to `args` and returns the buffer (or, for
/// the iterator-based emit strategy, the tensor itself) to index into.
static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                          SmallVectorImpl<Value> &args) {
  const Location loc = env.op().getLoc();
  const TensorId tid = env.makeTensorId(t->getOperandNumber());
  const auto map = env.op().getMatchingIndexingMap(t);
  const auto stt = getSparseTensorType(t->get());
  if (stt.hasEncoding()) {
    // For sparse tensors we only push the last-level's position onto `args`.
    const auto pos = env.emitter().getValPosits(tid);
    assert(!pos.empty());
    args.append(pos);
    // Simply returns the tensor to extract value using iterators.
    if (env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator)
      return t->get();
  } else {
    // For dense tensors we push all level's coordinates onto `args`.
    const Level lvlRank = stt.getLvlRank();
    assert(static_cast<Level>(map.getNumResults()) == lvlRank);
    for (Level l = 0; l < lvlRank; l++) {
      const auto lvlExpr = map.getResult(l);
      const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr);
      args.push_back(lvlCrd);
    }
  }
  // Default: index into the materialized value buffer of the tensor.
  return env.emitter().getValBuffer()[tid];
}
375b1d44e59SAart Bik 
3764f2ec7f9SAart Bik /// Generates insertion code to implement dynamic tensor load.
377fbe61130SAart Bik static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
37898f93e3bSAart Bik                               OpOperand *t) {
379384049a7SAart Bik   linalg::GenericOp op = env.op();
3804f2ec7f9SAart Bik   Location loc = op.getLoc();
38184cd51bbSwren romano   // Direct lexicographic coordinate order, tensor loads as zero.
382384049a7SAart Bik   if (!env.isExpand()) {
3834f2ec7f9SAart Bik     Type tp = getElementTypeOrSelf(t->get().getType());
384e9fa5590SMatthias Springer     return constantZero(builder, loc, tp);
3854f2ec7f9SAart Bik   }
3864f2ec7f9SAart Bik   // Load from expanded access pattern.
38798f93e3bSAart Bik   Value index = genIndex(env, t);
388384049a7SAart Bik   return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index);
3894f2ec7f9SAart Bik }
3904f2ec7f9SAart Bik 
391c8bb2354SJim Kitchen /// Generates insertion code to implement dynamic tensor load for reduction.
392fbe61130SAart Bik static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
393c8bb2354SJim Kitchen                                     OpOperand *t) {
394384049a7SAart Bik   linalg::GenericOp op = env.op();
395c8bb2354SJim Kitchen   Location loc = op.getLoc();
396fbe61130SAart Bik   Value identity = env.getCustomRedId();
39784cd51bbSwren romano   // Direct lexicographic coordinate order, tensor loads as identity.
398384049a7SAart Bik   if (!env.isExpand())
399c8bb2354SJim Kitchen     return identity;
400c8bb2354SJim Kitchen   // Load from expanded access pattern if filled, identity otherwise.
401384049a7SAart Bik   Value values = env.getExpandValues();
402384049a7SAart Bik   Value filled = env.getExpandFilled();
40398f93e3bSAart Bik   Value index = genIndex(env, t);
404384049a7SAart Bik   Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
405384049a7SAart Bik   Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index);
406c8bb2354SJim Kitchen   return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
407c8bb2354SJim Kitchen }
408c8bb2354SJim Kitchen 
4093aeb28b9SPeiming Liu static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
4103aeb28b9SPeiming Liu                                   Value sparseOut, ValueRange ivs, Value v) {
4113aeb28b9SPeiming Liu   scf::IfOp condInsert =
4123aeb28b9SPeiming Liu       builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true);
4133aeb28b9SPeiming Liu   // True branch.
4143aeb28b9SPeiming Liu   builder.setInsertionPointToStart(condInsert.thenBlock());
4153aeb28b9SPeiming Liu   Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
4163aeb28b9SPeiming Liu   builder.create<scf::YieldOp>(loc, res);
4173aeb28b9SPeiming Liu   // False branch.
4183aeb28b9SPeiming Liu   builder.setInsertionPointToStart(condInsert.elseBlock());
4193aeb28b9SPeiming Liu   builder.create<scf::YieldOp>(loc, sparseOut);
4203aeb28b9SPeiming Liu   // Value assignment.
4213aeb28b9SPeiming Liu   builder.setInsertionPointAfter(condInsert);
4223aeb28b9SPeiming Liu   return condInsert.getResult(0);
4233aeb28b9SPeiming Liu }
4243aeb28b9SPeiming Liu 
/// Generates insertion code to implement dynamic tensor store.
/// Either threads the inserted value through the insertion chain (direct
/// lexicographic order) or writes it into the expanded access pattern.
static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                              Value rhs) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Direct insertion in lexicographic coordinate order.
  if (!env.isExpand()) {
    const LoopId numLoops = op.getRank(t);
    // Retrieves the first `numLoop` induction variables.
    SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end(
        env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
    Value chain = env.getInsertionChain();
    if (env.isValidLexInsert()) {
      // Generates runtime check for a valid lex during reduction,
      // to avoid inserting the identity value for empty reductions.
      //   if (validLexInsert) then
      //     insert(rhs) into chain
      //     return updated chain
      //   else
      //     return unmodified chain
      Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
                                       chain, ivs, rhs);
      env.updateInsertionChain(out);
    } else {
      Value sparseOut;
      if (!hasAnySparseType(env.op().getInputs().getTypes())) {
        // This is an all-dense -> sparse kernel, test rhs != 0 before
        // insertion.
        Value nz = genIsNonzero(builder, loc, rhs);
        sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
      } else {
        sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
      }
      // Generates regular insertion chain.
      env.updateInsertionChain(sparseOut);
    }
    return;
  }
  // Generates insertion code along expanded access pattern.
  //   if (!expFilled[i]) then
  //     expFilled[i] = true
  //     expAdded[inserts++] = i
  //   endif
  //   values[i] = rhs
  Value values = env.getExpandValues();
  Value filled = env.getExpandFilled();
  Value added = env.getExpandAdded();
  Value count = env.getExpandCount();
  Value index = genIndex(env, t);
  Value fval = constantI1(builder, loc, false);
  Value tval = constantI1(builder, loc, true);
  // If statement testing whether this position was already filled.
  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                             isFilled, fval);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
                                             /*else=*/true);
  // True branch: mark as filled, record the index, bump the insert count.
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  builder.create<memref::StoreOp>(loc, tval, filled, index);
  builder.create<memref::StoreOp>(loc, index, added, count);
  Value one = constantIndex(builder, loc, 1);
  Value add = builder.create<arith::AddIOp>(loc, count, one);
  builder.create<scf::YieldOp>(loc, add);
  // False branch: count remains unchanged.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, count);
  builder.setInsertionPointAfter(ifOp);
  // Value assignment (unconditionally overwrites the expanded value).
  env.updateExpandCount(ifOp.getResult(0));
  builder.create<memref::StoreOp>(loc, rhs, values, index);
}
4974f2ec7f9SAart Bik 
498a2c9d4bbSAart Bik /// Generates a load on a dense or sparse tensor.
499b8cf7af9Swren romano static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
500a2c9d4bbSAart Bik   // Test if the load was hoisted to a higher loop nest.
50198f93e3bSAart Bik   Value val = env.exp(exp).val;
50226eb2c6bSPeiming Liu   if (val)
503a2c9d4bbSAart Bik     return val;
50465ee8f10SAart Bik   // Get tensor operand.
505384049a7SAart Bik   linalg::GenericOp op = env.op();
50665ee8f10SAart Bik   Location loc = op.getLoc();
507384049a7SAart Bik   OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
50865ee8f10SAart Bik   // Fold binary-valued tensor into explicit value.
50965ee8f10SAart Bik   const auto stt = getSparseTensorType(t->get());
51065ee8f10SAart Bik   if (auto explVal = stt.getExplicitVal())
51165ee8f10SAart Bik     return genValFromAttr(builder, loc, explVal);
51265ee8f10SAart Bik   // Load during insertion.
513384049a7SAart Bik   if (env.isSparseOutput(t)) {
514fbe61130SAart Bik     if (env.isCustomReduc())
515384049a7SAart Bik       return genInsertionLoadReduce(env, builder, t);
516384049a7SAart Bik     return genInsertionLoad(env, builder, t);
517c8bb2354SJim Kitchen   }
518951a3630SPeiming Liu 
519a2c9d4bbSAart Bik   // Actual load.
5200e1708ffSAart Bik   SmallVector<Value> args;
521384049a7SAart Bik   Value ptr = genSubscript(env, builder, t, args);
522951a3630SPeiming Liu   if (llvm::isa<TensorType>(ptr.getType())) {
523951a3630SPeiming Liu     assert(env.options().sparseEmitStrategy ==
524951a3630SPeiming Liu                SparseEmitStrategy::kSparseIterator &&
525951a3630SPeiming Liu            args.size() == 1);
526951a3630SPeiming Liu     return builder.create<ExtractValOp>(loc, ptr, args.front());
527951a3630SPeiming Liu   }
52865ee8f10SAart Bik   return builder.create<memref::LoadOp>(loc, ptr, args);
529a2c9d4bbSAart Bik }
530a2c9d4bbSAart Bik 
/// Generates a store on a dense or sparse tensor. Dense destinations take a
/// plain memref store; sparse destinations route the value through the
/// insertion chain. A kSelect expression stores conditionally, using `rhs`
/// as the selection predicate.
static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                           Value rhs) {
  // Only unary and binary are allowed to return an uninitialized rhs
  // to indicate missing output. Or otherwise a custom reduction that
  // received no value to accumulate.
  if (!rhs) {
    assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
           env.exp(exp).kind == TensorExp::Kind::kBinary ||
           env.exp(exp).kind == TensorExp::Kind::kReduce);
    return;
  }
  // Test if this is a scalarized reduction: accumulate in the environment
  // rather than emitting a store.
  if (env.isReduc()) {
    env.updateReduc(rhs);
    return;
  }
  // Regular store into the (single) destination operand.
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  OpOperand *t = op.getDpsInitOperand(0);
  if (!env.isSparseOutput(t)) {
    SmallVector<Value> args;
    Value ptr = genSubscript(env, builder, t, args);
    builder.create<memref::StoreOp>(loc, rhs, ptr, args);
    return;
  }
  // Store during sparse insertion.
  if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
    genInsertionStore(env, builder, t, rhs);
    return;
  }
  // Select operation insertion: only insert when the predicate `rhs` holds.
  Value chain = env.getInsertionChain();
  scf::IfOp ifOp =
      builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  // Existing value was preserved to be used here.
  assert(env.exp(exp).val);
  Value v0 = env.exp(exp).val;
  genInsertionStore(env, builder, t, v0);
  // The preserved value is consumed exactly once; clear it to avoid reuse.
  env.merger().clearExprValue(exp);
  // Yield modified insertion chain along true branch.
  Value mchain = env.getInsertionChain();
  builder.create<scf::YieldOp>(op.getLoc(), mchain);
  // Yield original insertion chain along false branch.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, chain);
  // Done with if statement.
  env.updateInsertionChain(ifOp->getResult(0));
  builder.setInsertionPointAfter(ifOp);
}
583a2c9d4bbSAart Bik 
584a2c9d4bbSAart Bik /// Generates an invariant value.
585b8cf7af9Swren romano inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
58698f93e3bSAart Bik   return env.exp(exp).val;
587a2c9d4bbSAart Bik }
588a2c9d4bbSAart Bik 
/// Semi-ring branches are simply inlined by the sparsifier. Prior
/// analysis has verified that all computations are "local" to the inlined
/// branch or otherwise invariantly defined outside the loop nest, with the
/// exception of index computations, which need to be relinked to actual
/// inlined cloned code. Recursively rewrites value `e` (and, transitively,
/// the operands of any defining op still inside `block`) and returns the
/// relinked replacement value.
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
                          Value e) {
  if (auto arg = dyn_cast<BlockArgument>(e)) {
    // Direct arguments of the original linalg op must be converted
    // into dense tensor loads. Note that we should not encounter
    // anything else. This needs to be verified by semi-ring ops.
    linalg::GenericOp op = env.op();
    if (arg.getOwner()->getParentOp() == op) {
      const TensorId tid = env.makeTensorId(arg.getArgNumber());
      OpOperand *t = &op->getOpOperand(tid);
      assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
      SmallVector<Value> args;
      Value ptr = genSubscript(env, rewriter, t, args);
      return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
    }
  } else if (Operation *def = e.getDefiningOp()) {
    // Handle index computation.
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
    // When still defined in new body, recurse into operands.
    if (def->getBlock() == block) {
      rewriter.setInsertionPoint(def);
      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
        // Rewrite each operand in place so the relinked value replaces the
        // stale one without re-creating the operation.
        rewriter.modifyOpInPlace(def, [&]() {
          def->setOperand(
              i, relinkBranch(env, rewriter, block, def->getOperand(i)));
        });
      }
    }
  }
  return e;
}
6262a288616SAart Bik 
/// Recursively generates tensor expression. Returns the value of the
/// expression, or a null Value when a custom reduction received no value
/// to accumulate.
static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
  if (e == ::mlir::sparse_tensor::detail::kInvalidId)
    return Value();

  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  const TensorExp &exp = env.exp(e);
  const auto kind = exp.kind;
  // Leaf cases: tensor loads, invariants, and loop variables.
  if (kind == TensorExp::Kind::kTensor)
    return genTensorLoad(env, rewriter, e);
  if (kind == TensorExp::Kind::kInvariant)
    return genInvariantValue(env, e);
  if (kind == TensorExp::Kind::kLoopVar)
    return env.getLoopVar(exp.loop);

  if (kind == TensorExp::Kind::kReduce)
    env.startCustomReduc(e); // enter custom

  // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
  // based on the type of the other operand.
  Value v0, v1;
  if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
      env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
    v1 = genExp(env, rewriter, exp.children.e1);
    v0 = constantZero(rewriter, loc, v1.getType());
  } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
             env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
    v0 = genExp(env, rewriter, exp.children.e0);
    v1 = constantZero(rewriter, loc, v0.getType());
  } else {
    v0 = genExp(env, rewriter, exp.children.e0);
    v1 = genExp(env, rewriter, exp.children.e1);
  }

  Value ee;
  if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
    // custom reduce did not receive a value
  } else {
    ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
    // Kinds with inlined semi-ring regions may reference values that must
    // be relinked to the freshly cloned code (e.g. index computations).
    if (ee &&
        (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
         kind == TensorExp::Kind::kBinaryBranch ||
         kind == TensorExp::Kind::kReduce ||
         kind == TensorExp::Kind::kSelect)) {
      OpBuilder::InsertionGuard guard(rewriter);
      ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
    }
  }

  if (kind == TensorExp::Kind::kReduce)
    env.endCustomReduc(); // exit custom

  if (kind == TensorExp::Kind::kSelect)
    env.merger().setExprValue(e, v0); // Preserve value for later use.

  return ee;
}
685a2c9d4bbSAart Bik 
/// Hoists loop invariant tensor loads for which indices have been exhausted.
/// Called with isStart=true when entering loop `curr` (initialize/hoist) and
/// isStart=false when exiting it (store back/clear).
static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                          LoopId curr, bool isStart) {
  if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
    return;
  if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
    // Inspect tensor indices.
    linalg::GenericOp op = env.op();
    OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
    const auto map = op.getMatchingIndexingMap(&t);
    const auto stt = getSparseTensorType(t.get());
    const Level lvlRank = stt.getLvlRank();
    assert(static_cast<Level>(map.getNumResults()) == lvlRank);
    bool isCurrentLoop = curr == 0; // for scalar tensors
    for (Level l = 0; l < lvlRank; l++) {
      const AffineExpr a = map.getResult(l);
      // Bail out as long as any index of the tensor is not yet exhausted.
      if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop))
        return; // still in play
    }
    // All exhausted at current level.
    if (!isCurrentLoop)
      return;
    // Generate code for a scalarized reduction or invariant. Note that
    // because custom reduction lhs may occur several times in the IR,
    // we have a built-in safety for only initializing and wrapping-up
    // the scalarized reduction once.
    OpOperand *lhs = op.getDpsInitOperand(0);
    if (lhs == &t) {
      // Start or end a scalarized reduction.
      if (isStart) {
        if (env.isCustomReduc()) {
          // Only initialize the custom reduction once (see note above).
          if (!env.isReduc())
            env.startReduc(exp, env.getCustomRedId());
        } else {
          env.startReduc(exp, genTensorLoad(env, builder, exp));
        }
        if (env.hasSparseOutput())
          env.startValidLexInsert(
              constantI1(builder, env.op().getLoc(), false));
      } else {
        // Only wrap up the custom reduction if it was initialized.
        if (!env.isCustomReduc() || env.isReduc())
          genTensorStore(env, builder, exp, env.endReduc());
        if (env.hasSparseOutput())
          env.endValidLexInsert();
      }
    } else {
      // Start or end loop invariant hoisting of a tensor load.
      if (isStart) {
        env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
      } else {
        env.merger().clearExprValue(exp);
      }
    }
  } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
             env.exp(exp).kind != TensorExp::Kind::kLoopVar &&
             env.exp(exp).kind != TensorExp::Kind::kSynZero) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
      env.startCustomReduc(exp); // enter custom
    const ExprId e0 = env.exp(exp).children.e0;
    const ExprId e1 = env.exp(exp).children.e1;
    genInvariants(env, builder, e0, curr, isStart);
    genInvariants(env, builder, e1, curr, isStart);
    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
      env.endCustomReduc(); // exit custom
  }
}
755a2c9d4bbSAart Bik 
/// Generates an expanded access pattern in innermost dimension. On start,
/// emits an ExpandOp yielding the (values, filled, added, count) buffers;
/// on end, emits the matching CompressOp that folds the expanded contents
/// back into the insertion chain.
static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                      bool isStart) {
  linalg::GenericOp op = env.op();
  OpOperand *lhs = op.getDpsInitOperand(0);
  if (!env.atExpandLevel(lhs, op.getRank(lhs), curr))
    return; // not needed at current level
  assert(!env.isReduc());
  // Generate start or end of an expanded access pattern. Note that because
  // an expansion does not rely on the ongoing contents of the sparse storage
  // scheme, we can use the original tensor as incoming SSA value (which
  // simplifies codegen a bit). If expansion on the actual contents is ever
  // needed, we will need to use the SSA value in the insertion chain instead.
  Value tensor = lhs->get();
  Location loc = op.getLoc();
  if (isStart) {
    auto dynShape = {ShapedType::kDynamic};
    Type etp = cast<ShapedType>(tensor.getType()).getElementType();
    // Result types of ExpandOp: values, filled flags, added indices, count.
    Type t1 = MemRefType::get(dynShape, etp);
    Type t2 = MemRefType::get(dynShape, builder.getI1Type());
    Type t3 = MemRefType::get(dynShape, builder.getIndexType());
    Type t4 = builder.getIndexType();
    auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
    assert(r.getNumResults() == 4);
    env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
                    r.getResult(3));
  } else {
    // Collect the loop induction variables of all enclosing loops as the
    // insertion coordinates for the compress operation.
    SmallVector<Value> indices;
    for (LoopId i = 0; i < curr; i++)
      indices.push_back(env.emitter().getLoopIV(i));
    Value values = env.getExpandValues();
    Value filled = env.getExpandFilled();
    Value added = env.getExpandAdded();
    Value count = env.getExpandCount();
    Value chain = env.getInsertionChain();
    Value compress = builder.create<CompressOp>(loc, values, filled, added,
                                                count, chain, indices);
    env.updateInsertionChain(compress);
    env.endExpand();
  }
}
7974f2ec7f9SAart Bik 
798b0f8057eSPeiming Liu /// Returns parallelization strategy. Any implicit loop in the Linalg
799b0f8057eSPeiming Liu /// operation that is marked "parallel" is a candidate. Whether it is actually
800b0f8057eSPeiming Liu /// converted to a parallel operation depends on the requested strategy.
801fbe61130SAart Bik static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
80253cc3a06SAart Bik   // Reject parallelization of sparse output.
803384049a7SAart Bik   if (env.hasSparseOutput())
80453cc3a06SAart Bik     return false;
80575ac294bSPeiming Liu   // Parallel loops on tensor expansion can cause data races.
806384049a7SAart Bik   if (env.isExpand())
80775ac294bSPeiming Liu     return false;
80853cc3a06SAart Bik   // Inspect strategy.
809384049a7SAart Bik   switch (env.options().parallelizationStrategy) {
810a2c9d4bbSAart Bik   case SparseParallelizationStrategy::kNone:
811a2c9d4bbSAart Bik     return false;
812a2c9d4bbSAart Bik   case SparseParallelizationStrategy::kDenseOuterLoop:
81375ac294bSPeiming Liu     return isOuter && !isSparse;
814a2c9d4bbSAart Bik   case SparseParallelizationStrategy::kAnyStorageOuterLoop:
81575ac294bSPeiming Liu     return isOuter;
816a2c9d4bbSAart Bik   case SparseParallelizationStrategy::kDenseAnyLoop:
81775ac294bSPeiming Liu     return !isSparse;
818a2c9d4bbSAart Bik   case SparseParallelizationStrategy::kAnyStorageAnyLoop:
81975ac294bSPeiming Liu     return true;
820a2c9d4bbSAart Bik   }
821a2c9d4bbSAart Bik   llvm_unreachable("unexpected parallelization strategy");
822a2c9d4bbSAart Bik }
823a2c9d4bbSAart Bik 
824fd68d361SPeiming Liu /// Whether or not the current loop being generated should be parallized (if
825fd68d361SPeiming Liu /// possible) according to the configuration.
826c5a1732cSAart Bik static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
82736c95ee7SPeiming Liu                                ArrayRef<TensorLevel> tidLvls) {
828384049a7SAart Bik   linalg::GenericOp op = env.op();
82998f93e3bSAart Bik   auto iteratorTypes = op.getIteratorTypesArray();
830c5a1732cSAart Bik   bool isSparse = llvm::any_of(tidLvls, [curr, &env](TensorLevel tidLvl) {
83198ce2debSAart Bik     // Queries the LT based on the tensor and loop id, as requested by
83298ce2debSAart Bik     // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
8331dd387e1SAart Bik     // should be consistent with the LT indexed by <TensorId, Level>.
834c5a1732cSAart Bik     const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr);
83552b69aa3SPeiming Liu     return lt.hasSparseSemantic();
83649be68b8SPeiming Liu   });
837c5a1732cSAart Bik   return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
838fd68d361SPeiming Liu }
839a2c9d4bbSAart Bik 
840fd68d361SPeiming Liu /// Emit a loop to coiterate over the list of tensor levels. The generated loop
841fd68d361SPeiming Liu /// can either be a for loop or while loop depending on whether there is at most
842fd68d361SPeiming Liu /// one sparse level in the list.
843fd68d361SPeiming Liu static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
84498ce2debSAart Bik                                  ArrayRef<TensorLevel> tidLvls,
845*c4420257SPeiming Liu                                  unsigned numCases, bool tryParallel,
846*c4420257SPeiming Liu                                  bool needsUniv) {
8478109d5e9SAart Bik   Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
84898ce2debSAart Bik     // Construct while-loop with a parameter for each index.
849b8cf7af9Swren romano     return env.emitter().enterCoIterationOverTensorsAtLvls(
850*c4420257SPeiming Liu         builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel,
851*c4420257SPeiming Liu         needsUniv);
85276b11003SFangrui Song   });
85375ac294bSPeiming Liu   assert(loop);
854b0f8057eSPeiming Liu   return loop;
855a2c9d4bbSAart Bik }
856a2c9d4bbSAart Bik 
857a2c9d4bbSAart Bik /// Generates a for-loop or a while-loop, depending on whether it implements
858a2c9d4bbSAart Bik /// singleton iteration or co-iteration over the given conjunction.
859c5a1732cSAart Bik static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
860*c4420257SPeiming Liu                           unsigned numCases, bool needsUniv,
861*c4420257SPeiming Liu                           ArrayRef<TensorLevel> tidLvls) {
862c5a1732cSAart Bik   bool tryParallel = shouldTryParallize(env, curr, tidLvls);
863*c4420257SPeiming Liu   return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
864*c4420257SPeiming Liu                         needsUniv);
865a2c9d4bbSAart Bik }
866a2c9d4bbSAart Bik 
/// Generates the induction structure for a while-loop: walks outward over
/// the enclosing if statements, yielding the locally updated reduction,
/// expansion count, and insertion chain values and rebinding the environment
/// to the corresponding if results.
/// NOTE(review): the `needsUniv` parameter is not used in this body.
static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
                            bool needsUniv) {
  Location loc = env.op().getLoc();
  // Finalize each else branch of all if statements.
  if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
               builder.getInsertionBlock()->getParentOp())) {
      // Break on IfOp for slicing filtering.
      if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
          StringAttr::get(ifOp->getContext(), "slice"))
        break;

      // The yield order below must match the result-type order set up
      // in genIf: reduction (+ lex insert flag), expand count, chain.
      unsigned y = 0;
      SmallVector<Value> yields;
      if (env.isReduc()) {
        yields.push_back(env.getReduc());
        env.updateReduc(ifOp.getResult(y++));
        if (env.isValidLexInsert()) {
          yields.push_back(env.getValidLexInsert());
          env.updateValidLexInsert(ifOp.getResult(y++));
        }
      }
      if (env.isExpand()) {
        yields.push_back(env.getExpandCount());
        env.updateExpandCount(ifOp->getResult(y++));
      }
      if (env.getInsertionChain()) {
        yields.push_back(env.getInsertionChain());
        env.updateInsertionChain(ifOp->getResult(y++));
      }
      assert(y == yields.size());
      builder.create<scf::YieldOp>(loc, yields);
      builder.setInsertionPointAfter(ifOp);
    }
  }
  // No need to set the insertion point here as LoopEmitter keeps track of the
  // basic block where scf::Yield should be inserted.
}
906a2c9d4bbSAart Bik 
907*c4420257SPeiming Liu /// Generates a case region in the coiterate operation.
908*c4420257SPeiming Liu static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
909*c4420257SPeiming Liu                                unsigned caseIdx, LatPointId allCase,
910*c4420257SPeiming Liu                                LatPointId curCase,
911*c4420257SPeiming Liu                                MutableArrayRef<Value> reduc) {
912*c4420257SPeiming Liu   assert(allCase == curCase || env.merger().latGT(allCase, curCase));
913*c4420257SPeiming Liu   const BitVector &allCaseBits = env.merger().lat(allCase).simple;
914*c4420257SPeiming Liu   const BitVector &curCaseBits = env.merger().lat(curCase).simple;
915*c4420257SPeiming Liu 
916*c4420257SPeiming Liu   /// Computes the subset of iterators that are valid in the current case being
917*c4420257SPeiming Liu   /// generated.
918*c4420257SPeiming Liu   I64BitSet caseBit(0);
919*c4420257SPeiming Liu   for (auto [idx, set] : llvm::enumerate(allCaseBits.set_bits()))
920*c4420257SPeiming Liu     if (curCaseBits.test(set))
921*c4420257SPeiming Liu       caseBit.set(idx);
922*c4420257SPeiming Liu 
923*c4420257SPeiming Liu   env.emitter().enterCurrentCoIterationCase(builder, env.op().getLoc(), caseBit,
924*c4420257SPeiming Liu                                             caseIdx, reduc);
925*c4420257SPeiming Liu }
926*c4420257SPeiming Liu 
/// Generates a single if-statement within a while-loop, guarding the body
/// for lattice point `p`. The condition conjoins one clause per tensor loop
/// id: sparse levels compare their coordinate against the loop variable,
/// while dense/undefined levels contribute a trivially true clause.
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                       LatPointId p) {
  Location loc = env.op().getLoc();
  SmallVector<Type> types;
  Value cond;
  env.merger().foreachTensorLoopId(
      p, /*simple=*/true,
      [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
          bool isIdxRed) {
        if (isIdxRed) {
          // Since there is no 1:1 mapping from loop to level (multiple loops
          // are required to resolve one level with non-trivial index
          // expression), we need to reconstruct the tensor level types if this
          // loop requires index reduction condition.
          assert(lvl.has_value() && isUndefLT(lt));
          auto stt = getSparseTensorType(env.op().getInputs()[tid]);
          lt = stt.getLvlType(*lvl);
        }
        assert(curr == env.merger().loop(b));
        Value clause;
        if (lt.hasSparseSemantic()) {
          assert(lvl.has_value());
          // Sparse clause: coordinate of this level equals loop variable.
          const Value crd = env.emitter().getCoord(tid, *lvl);
          const Value lvar = env.getLoopVar(curr);
          clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 crd, lvar);
        } else {
          assert(lt.hasDenseSemantic() || isUndefLT(lt));
          clause = constantI1(builder, loc, true);
        }
        cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
      });
  // The if-statement carries, as results, the scalarized reduction (plus
  // lex insert flag), expansion count, and insertion chain, when present.
  if (env.isReduc()) {
    types.push_back(env.getReduc().getType());
    if (env.isValidLexInsert())
      types.push_back(env.getValidLexInsert().getType());
  }
  if (env.isExpand())
    types.push_back(builder.getIndexType());
  if (env.getInsertionChain())
    types.push_back(env.getInsertionChain().getType());
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  return ifOp;
}
973a2c9d4bbSAart Bik 
/// Generates end of true branch of if-statement within a while-loop.
/// Yields the locally updated reduction, expansion count, and insertion
/// chain, then restores the environment to the incoming values
/// (redInput/cntInput/insInput/validIns) before moving into the else branch.
static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
                  Value redInput, Value cntInput, Value insInput,
                  Value validIns) {
  SmallVector<Value> operands;
  if (env.isReduc()) {
    operands.push_back(env.getReduc());
    env.updateReduc(redInput);
    if (env.isValidLexInsert()) {
      // Any overlapping indices during a reduction creates a valid lex insert.
      operands.push_back(constantI1(builder, env.op().getLoc(), true));
      env.updateValidLexInsert(validIns);
    }
  }
  if (env.isExpand()) {
    operands.push_back(env.getExpandCount());
    env.updateExpandCount(cntInput);
  }
  if (env.getInsertionChain()) {
    operands.push_back(env.getInsertionChain());
    env.updateInsertionChain(insInput);
  }
  // Only emit a yield when the if-statement actually carries values.
  if (!operands.empty())
    builder.create<scf::YieldOp>(env.op().getLoc(), operands);
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
}
10007373cabcSAart Bik 
1001c8d5dcb0SAart Bik //===----------------------------------------------------------------------===//
1002c43e6274STim Harvey // Sparsifier synthesis methods (loop sequence).
1003c8d5dcb0SAart Bik //===----------------------------------------------------------------------===//
1004c8d5dcb0SAart Bik 
1005d933b88bSPeiming Liu static bool getAllTidLvlsInLatPoints(
1006d933b88bSPeiming Liu     CodegenEnv &env, LatPointId li, LoopId curr,
1007d933b88bSPeiming Liu     llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
1008d933b88bSPeiming Liu   const BitVector &simple = env.lat(li).simple;
1009d933b88bSPeiming Liu   const TensorId outTid = env.merger().getOutTensorID();
1010d933b88bSPeiming Liu   const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
1011d933b88bSPeiming Liu 
1012d933b88bSPeiming Liu   unsigned numloopCond = 0;
1013d933b88bSPeiming Liu   bool hasNonUnique = false;
1014d933b88bSPeiming Liu   env.merger().foreachTensorLoopId(
1015d933b88bSPeiming Liu       li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
1016d933b88bSPeiming Liu                     LevelType lt, bool isIdxReduc) {
1017d933b88bSPeiming Liu         if (simple[b]) {
1018d933b88bSPeiming Liu           if (isIdxReduc) {
1019d933b88bSPeiming Liu             callback(env.makeTensorLevel(tid, *lvl), nullptr);
1020d933b88bSPeiming Liu             numloopCond++;
1021d933b88bSPeiming Liu             return;
1022d933b88bSPeiming Liu           }
1023d933b88bSPeiming Liu           if (isUndefLT(lt)) {
1024d933b88bSPeiming Liu             // An undefined lt in the lattices, we probably mean to
1025d933b88bSPeiming Liu             // generate a dense loop according to the synthetic tensor (for
1026d933b88bSPeiming Liu             // invariants and sparse output tensor).
1027d933b88bSPeiming Liu             if (env.merger().getSynTensorID() == tid) {
1028d933b88bSPeiming Liu               // Coiterating with an invariant
1029d933b88bSPeiming Liu               // e.g., out = prod(in[i][j] op invariant);
1030d933b88bSPeiming Liu               // or a broadcast
1031d933b88bSPeiming Liu               // e.g., out[i][j] = in[i] (j is undef for input)
1032d933b88bSPeiming Liu               //
1033d933b88bSPeiming Liu               // The level of the synthetic tensor is the current loop depth;
1034d933b88bSPeiming Liu               // the rank of the synthetic tensor equals to number of loops.
1035d933b88bSPeiming Liu               assert(curr == env.getCurrentDepth());
1036d933b88bSPeiming Liu               lvl = curr;
1037d933b88bSPeiming Liu             } else if (!lvl) {
1038d933b88bSPeiming Liu               // Skips invalid lvl (e.g., when this is a zero ranked tensor).
1039d933b88bSPeiming Liu               return;
1040d933b88bSPeiming Liu             }
1041d933b88bSPeiming Liu           }
1042d933b88bSPeiming Liu           hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
1043d933b88bSPeiming Liu           callback(env.makeTensorLevel(tid, *lvl), nullptr);
1044d933b88bSPeiming Liu           numloopCond++;
104552b69aa3SPeiming Liu         } else if (lt.hasDenseSemantic() || isIdxReduc) {
1046d933b88bSPeiming Liu           callback(env.makeTensorLevel(tid, *lvl), nullptr);
1047d933b88bSPeiming Liu         } else {
1048d933b88bSPeiming Liu           assert(isUndefLT(lt));
1049d933b88bSPeiming Liu           linalg::GenericOp op = env.op();
1050d933b88bSPeiming Liu           if (tid >= op.getNumDpsInputs())
1051d933b88bSPeiming Liu             // We only handle affine expression on input tensors (for now).
1052d933b88bSPeiming Liu             return;
1053d933b88bSPeiming Liu           OpOperand *operand = &op->getOpOperand(tid);
1054d933b88bSPeiming Liu           const auto stt = getSparseTensorType(operand->get());
1055d933b88bSPeiming Liu           // Non-annotated dense tensors requires no special handling.
1056d933b88bSPeiming Liu           if (!stt.hasEncoding())
1057d933b88bSPeiming Liu             return;
1058d933b88bSPeiming Liu 
1059d933b88bSPeiming Liu           ArrayRef<AffineExpr> affines =
1060d933b88bSPeiming Liu               op.getMatchingIndexingMap(operand).getResults();
1061d933b88bSPeiming Liu           const Level lvlRank = stt.getLvlRank();
1062d933b88bSPeiming Liu           assert(affines.size() == static_cast<size_t>(lvlRank));
1063d933b88bSPeiming Liu           for (Level l = 0; l < lvlRank; l++) {
1064d933b88bSPeiming Liu             AffineExpr exp = affines[l];
1065d933b88bSPeiming Liu             // Skip simple affine expression and non-dense levels (which
1066d933b88bSPeiming Liu             // have their own filter loop).
106752b69aa3SPeiming Liu             LevelType lt = stt.getLvlType(l);
106852b69aa3SPeiming Liu             if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic())
1069d933b88bSPeiming Liu               continue;
1070d933b88bSPeiming Liu 
1071d933b88bSPeiming Liu             // Constant affine expression are handled in genLoop.
1072d933b88bSPeiming Liu             if (!isa<AffineConstantExpr>(exp)) {
1073d933b88bSPeiming Liu               bool isCurrentLoop = false;
1074d933b88bSPeiming Liu               assert(curr == env.getCurrentDepth());
1075d933b88bSPeiming Liu               if (isInvariantAffine(exp, curr + 1, /*out*/ isCurrentLoop) &&
1076d933b88bSPeiming Liu                   isCurrentLoop) {
1077d933b88bSPeiming Liu                 // If the compound affine is invariant and we are right at the
1078d933b88bSPeiming Liu                 // level. We need to generate the address according to the
1079d933b88bSPeiming Liu                 // affine expression. This is also the best place we can do it
1080d933b88bSPeiming Liu                 // to avoid putting it inside inner loops.
1081d933b88bSPeiming Liu                 callback(env.makeTensorLevel(tid, l), exp);
1082d933b88bSPeiming Liu               }
1083d933b88bSPeiming Liu             }
1084d933b88bSPeiming Liu           }
1085d933b88bSPeiming Liu         }
1086d933b88bSPeiming Liu       });
1087d933b88bSPeiming Liu 
1088d933b88bSPeiming Liu   if (isDenseLT(env.lt(outTid, curr))) {
1089298412b5SPeiming Liu     auto stt = getSparseTensorType(env.op().getOutputs().front());
1090298412b5SPeiming Liu     // Note that we generate dense indices of the output tensor unconditionally,
1091298412b5SPeiming Liu     // since they may not appear in the lattice, but may be needed for
1092298412b5SPeiming Liu     // linearized env.
1093298412b5SPeiming Liu     // TODO: we should avoid introducing corner cases for all-dense sparse
1094298412b5SPeiming Liu     // tensors.
1095298412b5SPeiming Liu     if (stt.hasEncoding() && stt.isAllDense())
1096d933b88bSPeiming Liu       callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
1097d933b88bSPeiming Liu   }
1098d933b88bSPeiming Liu 
1099d933b88bSPeiming Liu   if (numloopCond == 0) {
1100d933b88bSPeiming Liu     // Corner cases where the loop bound is defined by a *unused* operand, in
1101d933b88bSPeiming Liu     // this case, we just generate a dense "fake" loop by iterating over the
1102d933b88bSPeiming Liu     // synthetic tensor.
1103d933b88bSPeiming Liu     callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
1104d933b88bSPeiming Liu     numloopCond++;
1105d933b88bSPeiming Liu   }
1106d933b88bSPeiming Liu   // If we just need to one loop conditions and the conditions is not imposed on
1107d933b88bSPeiming Liu   // non-unique level, the loop can be generated by a for loop.
1108a02010b3SPeiming Liu   // Or, if we are generating sparse-iterator-based loops, we always generate
1109a02010b3SPeiming Liu   // `sparse_tensor.iterate` regardless whether the level is unique or not.
1110a02010b3SPeiming Liu   return numloopCond == 1 &&
1111a02010b3SPeiming Liu          (!hasNonUnique || env.options().sparseEmitStrategy ==
1112a02010b3SPeiming Liu                                SparseEmitStrategy::kSparseIterator);
1113d933b88bSPeiming Liu }
1114d933b88bSPeiming Liu 
1115c8d5dcb0SAart Bik /// Starts a loop sequence at given level. Returns true if
1116c8d5dcb0SAart Bik /// the universal loop index must be maintained at this level.
1117b8cf7af9Swren romano static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
1118c5a1732cSAart Bik                          LoopId curr, LatSetId lts) {
1119c5a1732cSAart Bik   assert(!env.getLoopVar(curr));
1120c8d5dcb0SAart Bik   // Emit invariants at this loop sequence level.
1121c5a1732cSAart Bik   genInvariants(env, builder, exp, curr, /*isStart=*/true);
11224f2ec7f9SAart Bik   // Emit access pattern expansion for sparse tensor output.
1123c5a1732cSAart Bik   genExpand(env, builder, curr, /*isStart=*/true);
1124d933b88bSPeiming Liu   // Emit further initialization at this loop sequence level.
1125b8cf7af9Swren romano   const LatPointId l0 = env.set(lts)[0];
1126b0f8057eSPeiming Liu 
112736c95ee7SPeiming Liu   SmallVector<TensorLevel> tidLvls;
1128d933b88bSPeiming Liu   getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
1129298412b5SPeiming Liu     // TODO: remove this! The same tensor level might be added for multiple
1130298412b5SPeiming Liu     // times due to the special handling for all-dense "sparse" output tensor
1131298412b5SPeiming Liu     // (see L1038).
1132298412b5SPeiming Liu     if (llvm::find(tidLvls, tl) != tidLvls.end())
1133298412b5SPeiming Liu       return;
1134d933b88bSPeiming Liu     tidLvls.emplace_back(tl);
113532c512e4SPeiming Liu   });
1136b0f8057eSPeiming Liu 
113736c95ee7SPeiming Liu   env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
1138b0f8057eSPeiming Liu 
1139c8d5dcb0SAart Bik   // Maintain the universal index only if it is actually
1140c8d5dcb0SAart Bik   // consumed by a subsequent lattice point.
11417c7c10a0Swren romano   for (const LatPointId li : env.set(lts).drop_front())
11425fd9d801SPeiming Liu     if (!env.merger().hasAnySparse(env.lat(li).simple))
1143c8d5dcb0SAart Bik       return true;
1144d933b88bSPeiming Liu 
1145c8d5dcb0SAart Bik   return false;
1146c8d5dcb0SAart Bik }
1147c8d5dcb0SAart Bik 
114898ce2debSAart Bik // Generates dense affine address for encoding.
1149fbe61130SAart Bik static void genConstantDenseAddressFromLevel(CodegenEnv &env,
1150b8cf7af9Swren romano                                              OpBuilder &builder, TensorId tid,
1151b8cf7af9Swren romano                                              Level startLvl) {
11524fd3a120SPeiming Liu   // TODO: Handle affine expression on output tensor.
1153384049a7SAart Bik   linalg::GenericOp op = env.op();
11544fd3a120SPeiming Liu   assert(tid < op.getNumDpsInputs());
11554fd3a120SPeiming Liu   OpOperand *input = op.getDpsInputOperands()[tid];
1156b8cf7af9Swren romano   const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
1157f708a549Swren romano   const auto enc = getSparseTensorEncoding(input->get().getType());
11584fd3a120SPeiming Liu   if (enc) {
1159b8cf7af9Swren romano     const Location loc = op.getLoc();
116046a384dfSwren romano     const TensorId tid = env.makeTensorId(input->getOperandNumber());
1161f708a549Swren romano     const Level lvlRank = enc.getLvlRank();
1162b8cf7af9Swren romano     assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
1163b8cf7af9Swren romano     for (Level l = startLvl; l < lvlRank; l++) {
1164ccd923e3SPeiming Liu       AffineExpr lvlExpr = lvlExprs[l];
116552b69aa3SPeiming Liu       if (enc.getLvlType(l).hasDenseSemantic() &&
116652b69aa3SPeiming Liu           isa<AffineConstantExpr>(lvlExpr))
1167298412b5SPeiming Liu         env.emitter().locateLvlAtAffineAddress(
116836c95ee7SPeiming Liu             builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
116998f93e3bSAart Bik       else
117098f93e3bSAart Bik         return; // break on first non-dense non-constant level
11714fd3a120SPeiming Liu     }
11724fd3a120SPeiming Liu   }
11734fd3a120SPeiming Liu }
11744fd3a120SPeiming Liu 
117598f93e3bSAart Bik // We can generate address for constant affine expression before any loops
11764fd3a120SPeiming Liu // starting from the first level as they do not depend on anything.
11774fd3a120SPeiming Liu // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
11784fd3a120SPeiming Liu // levels can be determined before loops.
117998ce2debSAart Bik static void genInitConstantDenseAddress(CodegenEnv &env,
118098ce2debSAart Bik                                         RewriterBase &rewriter) {
1181b8cf7af9Swren romano   for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
118298f93e3bSAart Bik     genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
11834fd3a120SPeiming Liu }
11844fd3a120SPeiming Liu 
1185d933b88bSPeiming Liu /// Returns true if the lattice bit can be iterated by a for loop.
1186b8cf7af9Swren romano static bool translateBitsToTidLvlPairs(
1187c5a1732cSAart Bik     CodegenEnv &env, LatPointId li, LoopId curr,
118836c95ee7SPeiming Liu     SmallVectorImpl<TensorLevel> &tidLvls,
118936c95ee7SPeiming Liu     SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
1190d933b88bSPeiming Liu   return getAllTidLvlsInLatPoints(env, li, curr,
1191d933b88bSPeiming Liu                                   [&](TensorLevel tl, AffineExpr exp) {
1192d933b88bSPeiming Liu                                     if (exp)
1193d933b88bSPeiming Liu                                       affineTidLvls.emplace_back(tl, exp);
1194d933b88bSPeiming Liu                                     else
1195d933b88bSPeiming Liu                                       tidLvls.emplace_back(tl);
119632c512e4SPeiming Liu                                   });
1197b0f8057eSPeiming Liu }
1198b0f8057eSPeiming Liu 
1199c8d5dcb0SAart Bik /// Starts a single loop in current sequence.
12001328bb6eSPeiming Liu static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
1201c5a1732cSAart Bik                                               OpBuilder &builder, LoopId curr,
1202*c4420257SPeiming Liu                                               LatPointId li, unsigned numCases,
1203*c4420257SPeiming Liu                                               bool needsUniv) {
1204*c4420257SPeiming Liu   // TODO: numCases only used when generating iterator-based loops. Cleanup
1205*c4420257SPeiming Liu   // after fully migration.
1206b8cf7af9Swren romano   // The set of tensors + lvls to generate loops on
120736c95ee7SPeiming Liu   SmallVector<TensorLevel> tidLvls;
120898ce2debSAart Bik 
1209fb287335SPeiming Liu   // The set of dense tensors with non-trivial affine expression that just
121098ce2debSAart Bik   // becomes invariant and the address are generated at the current level.
121136c95ee7SPeiming Liu   SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
1212ff8815e5SPeiming Liu   bool isSingleCond =
1213c5a1732cSAart Bik       translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);
1214b0f8057eSPeiming Liu 
1215c8d5dcb0SAart Bik   // Emit the for/while-loop control.
1216*c4420257SPeiming Liu   Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
1217b8cf7af9Swren romano   Location loc = env.op().getLoc();
121836c95ee7SPeiming Liu   for (auto [tidLvl, exp] : affineTidLvls) {
1219298412b5SPeiming Liu     env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
1220fb287335SPeiming Liu   }
12214fd3a120SPeiming Liu 
1222b8cf7af9Swren romano   // Until now, we have entered every <tid, lvl> pair in {cond, extra,
1223b8cf7af9Swren romano   // affine}Tids/Lvls. The addresses of the upcoming levels which are dependent
12244fd3a120SPeiming Liu   // on constant affines expression may now be determined.
122536c95ee7SPeiming Liu   auto allTidLvls =
122636c95ee7SPeiming Liu       llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
122736c95ee7SPeiming Liu   for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
1228e7b4c93fSPeiming Liu     if (tid != env.merger().getOutTensorID() &&
1229e7b4c93fSPeiming Liu         tid != env.merger().getSynTensorID())
1230b8cf7af9Swren romano       genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
12314fd3a120SPeiming Liu   }
12324fd3a120SPeiming Liu 
12331328bb6eSPeiming Liu   return std::make_pair(loop, isSingleCond);
1234c8d5dcb0SAart Bik }
1235c8d5dcb0SAart Bik 
1236c8d5dcb0SAart Bik /// Ends a single loop in current sequence. Returns new values for needsUniv.
1237fbe61130SAart Bik static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
123898ce2debSAart Bik                     LatPointId li, bool needsUniv, bool isSingleCond) {
12395fd9d801SPeiming Liu   // Either a for-loop or a while-loop that iterates over a slice.
124098ce2debSAart Bik   if (isSingleCond) {
12415fd9d801SPeiming Liu     // Any iteration creates a valid lex insert.
1242047399c2SAart Bik     if (env.isReduc() && env.isValidLexInsert())
1243047399c2SAart Bik       env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
12445fd9d801SPeiming Liu   } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
12455fd9d801SPeiming Liu     // End a while-loop.
124698ce2debSAart Bik     finalizeWhileOp(env, rewriter, needsUniv);
1247b0f8057eSPeiming Liu   } else {
1248b0f8057eSPeiming Liu     needsUniv = false;
1249c8d5dcb0SAart Bik   }
12508109d5e9SAart Bik   env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
125136fd2875SAart Bik     env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
12521a36588eSKazu Hirata     return std::nullopt;
125375ac294bSPeiming Liu   });
1254b0f8057eSPeiming Liu   return needsUniv;
1255c8d5dcb0SAart Bik }
1256c8d5dcb0SAart Bik 
1257c8d5dcb0SAart Bik /// Ends a loop sequence at given level.
12585fd9d801SPeiming Liu static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
125998ce2debSAart Bik                        unsigned at) {
126098ce2debSAart Bik   assert(!env.getLoopVar(at));
12615fd9d801SPeiming Liu   env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
12627373cabcSAart Bik   // Unmark bookkeeping of invariants and loop index.
1263c5a1732cSAart Bik   genInvariants(env, builder, exp, at, /*isStart=*/false);
12644f2ec7f9SAart Bik   // Finalize access pattern expansion for sparse tensor output.
1265c5a1732cSAart Bik   genExpand(env, builder, at, /*isStart=*/false);
1266c8d5dcb0SAart Bik }
1267c8d5dcb0SAart Bik 
1268a2c9d4bbSAart Bik /// Recursively generates code while computing iteration lattices in order
1269a2c9d4bbSAart Bik /// to manage the complexity of implementing co-iteration over unions
1270a2c9d4bbSAart Bik /// and intersections of sparse iterations spaces.
1271b8cf7af9Swren romano static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
1272c5a1732cSAart Bik                     LoopId curr) {
1273c5a1732cSAart Bik   assert(curr == env.getCurrentDepth());
127498ce2debSAart Bik 
1275a2c9d4bbSAart Bik   // At each leaf, assign remaining tensor (sub)expression to output tensor.
1276c5a1732cSAart Bik   if (curr == env.getLoopNum()) {
1277067bebb5SAart Bik     Value rhs = genExp(env, rewriter, exp);
127898f93e3bSAart Bik     genTensorStore(env, rewriter, exp, rhs);
1279a2c9d4bbSAart Bik     return;
1280a2c9d4bbSAart Bik   }
1281a2c9d4bbSAart Bik 
128298ce2debSAart Bik   // Construct iteration lattices for current loop index.
1283b8cf7af9Swren romano   const LatSetId lts =
1284c5a1732cSAart Bik       env.merger().optimizeSet(env.merger().buildLattices(exp, curr));
1285a2c9d4bbSAart Bik 
1286c8d5dcb0SAart Bik   // Start a loop sequence.
1287c5a1732cSAart Bik   bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);
1288c8d5dcb0SAart Bik 
1289*c4420257SPeiming Liu   // When using sparse-iterator-based loops, we only need one loops, as
1290*c4420257SPeiming Liu   // opposed to a loop sequence, to cover all the iterator spaces.
1291*c4420257SPeiming Liu   const unsigned lsize = env.set(lts).size();
1292*c4420257SPeiming Liu   if (env.generatingSparseIterator()) {
1293*c4420257SPeiming Liu     // Get the largest lattice point and start a loop.
1294*c4420257SPeiming Liu     const LatPointId li = env.set(lts)[0];
1295*c4420257SPeiming Liu     auto [loop, isSingleCond] =
1296*c4420257SPeiming Liu         startLoop(env, rewriter, curr, li, lsize, needsUniv);
1297*c4420257SPeiming Liu     assert(isSingleCond == llvm::isa<IterateOp>(loop));
1298067bebb5SAart Bik     // We cannot change this to `for (const LatPointId li : env.set(lts))`
12997c7c10a0Swren romano     // because the loop body causes data-movement which invalidates
13007c7c10a0Swren romano     // the iterator.
1301*c4420257SPeiming Liu     for (unsigned j = 0; j < lsize; j++) {
1302*c4420257SPeiming Liu       const LatPointId lj = env.set(lts)[j];
1303*c4420257SPeiming Liu       const ExprId ej = env.lat(lj).exp;
1304*c4420257SPeiming Liu       // Recurse into body of each branch.
1305*c4420257SPeiming Liu       if (!isSingleCond) {
1306*c4420257SPeiming Liu         env.genLoopBoundary([&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
1307*c4420257SPeiming Liu           genCoIterationCase(env, rewriter, /*caseIdx*/ j, li, lj, reduc);
1308*c4420257SPeiming Liu           genStmt(env, rewriter, ej, curr + 1);
1309*c4420257SPeiming Liu           // TODO: handle yield values.
1310*c4420257SPeiming Liu           assert(reduc.empty() && "Not Implemented");
1311*c4420257SPeiming Liu           rewriter.create<sparse_tensor::YieldOp>(env.op().getLoc());
1312*c4420257SPeiming Liu           return std::nullopt;
1313*c4420257SPeiming Liu         });
1314*c4420257SPeiming Liu         // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1315*c4420257SPeiming Liu       } else {
1316*c4420257SPeiming Liu         genStmt(env, rewriter, ej, curr + 1);
1317*c4420257SPeiming Liu       }
1318*c4420257SPeiming Liu     }
1319*c4420257SPeiming Liu     // End a loop.
1320*c4420257SPeiming Liu     needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1321*c4420257SPeiming Liu   } else {
1322*c4420257SPeiming Liu     // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1323a2c9d4bbSAart Bik     for (unsigned i = 0; i < lsize; i++) {
1324b8cf7af9Swren romano       const LatPointId li = env.set(lts)[i];
13257c7c10a0Swren romano       // Start a loop.
1326*c4420257SPeiming Liu       auto [loop, isSingleCond] =
1327*c4420257SPeiming Liu           startLoop(env, rewriter, curr, li, lsize, needsUniv);
1328a2c9d4bbSAart Bik 
1329a2c9d4bbSAart Bik       // Visit all lattices points with Li >= Lj to generate the
1330a2c9d4bbSAart Bik       // loop-body, possibly with if statements for coiteration.
1331fbe61130SAart Bik       Value redInput = env.getReduc();
1332384049a7SAart Bik       Value cntInput = env.getExpandCount();
1333384049a7SAart Bik       Value insInput = env.getInsertionChain();
1334fc5d8fceSPeiming Liu       Value validIns = env.getValidLexInsert();
1335067bebb5SAart Bik       // We cannot change this to `for (const LatPointId lj : env.set(lts))`
13367c7c10a0Swren romano       // because the loop body causes data-movement which invalidates the
13377c7c10a0Swren romano       // iterator.
1338a2c9d4bbSAart Bik       for (unsigned j = 0; j < lsize; j++) {
1339b8cf7af9Swren romano         const LatPointId lj = env.set(lts)[j];
1340b8cf7af9Swren romano         const ExprId ej = env.lat(lj).exp;
1341384049a7SAart Bik         if (li == lj || env.merger().latGT(li, lj)) {
1342a2c9d4bbSAart Bik           // Recurse into body of each branch.
13431328bb6eSPeiming Liu           if (!isSingleCond) {
1344c5a1732cSAart Bik             scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
1345c5a1732cSAart Bik             genStmt(env, rewriter, ej, curr + 1);
1346372d88b0SPeiming Liu             endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1347a2c9d4bbSAart Bik           } else {
1348c5a1732cSAart Bik             genStmt(env, rewriter, ej, curr + 1);
1349a2c9d4bbSAart Bik           }
1350a2c9d4bbSAart Bik         }
1351a2c9d4bbSAart Bik       }
1352a2c9d4bbSAart Bik 
1353c8d5dcb0SAart Bik       // End a loop.
1354c5a1732cSAart Bik       needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1355a2c9d4bbSAart Bik     }
1356*c4420257SPeiming Liu   }
1357a2c9d4bbSAart Bik 
1358c8d5dcb0SAart Bik   // End a loop sequence.
1359c5a1732cSAart Bik   endLoopSeq(env, rewriter, exp, curr);
1360c5a1732cSAart Bik   assert(curr == env.getCurrentDepth());
1361a2c9d4bbSAart Bik }
1362a2c9d4bbSAart Bik 
1363727a63e0SAart Bik /// Converts the result computed by the sparse kernel into the required form.
1364fbe61130SAart Bik static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
1365384049a7SAart Bik   linalg::GenericOp op = env.op();
1366b4db15a9SAlexander Belyaev   OpOperand *lhs = op.getDpsInitOperand(0);
13675661647eSAart Bik   Value tensor = lhs->get();
13685661647eSAart Bik   Type resType = tensor.getType();
1369f66e5769SAart Bik   if (getSparseTensorEncoding(resType)) {
1370f66e5769SAart Bik     // The sparse tensor rematerializes from the original sparse tensor's
13715661647eSAart Bik     // underlying sparse storage format. For an insertion chain, the
13725661647eSAart Bik     // tensor materializes from the chain with 'hasInserts' enabled.
1373384049a7SAart Bik     bool hasInserts = false;
1374384049a7SAart Bik     if (Value chain = env.getInsertionChain()) {
1375384049a7SAart Bik       hasInserts = true;
1376384049a7SAart Bik       tensor = chain;
1377384049a7SAart Bik     }
13785661647eSAart Bik     rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
137936b66ab9SAart Bik   } else {
1380f66e5769SAart Bik     // To rematerialize an non-annotated tensor, simply load it
138136b66ab9SAart Bik     // from the bufferized value.
1382e7b4c93fSPeiming Liu     Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
138357470abcSAlexander Belyaev     rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
138436b66ab9SAart Bik   }
1385727a63e0SAart Bik }
1386727a63e0SAart Bik 
13875da21338SAart Bik //===----------------------------------------------------------------------===//
1388c43e6274STim Harvey // Sparsifier rewriting methods.
13895da21338SAart Bik //===----------------------------------------------------------------------===//
13905da21338SAart Bik 
1391a2c9d4bbSAart Bik namespace {
139298f93e3bSAart Bik 
1393a2c9d4bbSAart Bik /// Sparse rewriting rule for generic Lingalg operation.
1394a2c9d4bbSAart Bik struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1395a2c9d4bbSAart Bik public:
1396a2c9d4bbSAart Bik   GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1397a2c9d4bbSAart Bik       : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1398a2c9d4bbSAart Bik 
1399a2c9d4bbSAart Bik   LogicalResult matchAndRewrite(linalg::GenericOp op,
1400a2c9d4bbSAart Bik                                 PatternRewriter &rewriter) const override {
1401740582faSAart Bik     // Only accept single output operations with pure tensor semantics.
14020a8e3dd4SMatthias Springer     if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
1403740582faSAart Bik       return failure();
1404740582faSAart Bik 
1405740582faSAart Bik     // Only accept trivial affine indices.
1406740582faSAart Bik     if (hasNonTrivialAffineOnSparseOut(op))
140711d75076SBenjamin Kramer       return failure();
140898f93e3bSAart Bik 
1409067bebb5SAart Bik     // Only accept scheduled loops.
141006a65ce5SPeiming Liu     if (!op->hasAttr("sorted")) {
141106a65ce5SPeiming Liu       return rewriter.notifyMatchFailure(
141206a65ce5SPeiming Liu           op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
141306a65ce5SPeiming Liu               "before sparsification.");
141406a65ce5SPeiming Liu     }
141598ce2debSAart Bik 
1416ccd923e3SPeiming Liu     // Must have been demapped as well if the generic op is sorted.
1417ccd923e3SPeiming Liu     assert(!hasAnyNonIdentityOperandsOrResults(op));
141806a65ce5SPeiming Liu 
141998f93e3bSAart Bik     // Sets up a code generation environment.
1420b8cf7af9Swren romano     const unsigned numTensors = op->getNumOperands();
1421b8cf7af9Swren romano     const unsigned numLoops = op.getNumLoops();
1422ccd923e3SPeiming Liu     bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
14232b21327fSPeiming Liu     // If we have indexing map like (d0) -> (0, d0), there might be more
14242b21327fSPeiming Liu     // levels then loops because of the constant index, that means we can not
14252b21327fSPeiming Liu     // use numLoops as the upper bound for ranks of all tensors.
14262b21327fSPeiming Liu     // TODO: Constant indices are currently not support on sparse tensor, but
14272b21327fSPeiming Liu     // are allowed in non-annotated dense tensor. Support it, it would be
14282b21327fSPeiming Liu     // required for sparse tensor slice rank reducing too.
14292b21327fSPeiming Liu     Level maxLvlRank = 0;
14302b21327fSPeiming Liu     for (auto operand : op.getOperands()) {
14315550c821STres Popp       if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
14322b21327fSPeiming Liu         maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
14332b21327fSPeiming Liu       }
14342b21327fSPeiming Liu     }
143506a65ce5SPeiming Liu 
1436b8cf7af9Swren romano     // Detects sparse annotations and translates the per-level sparsity
143798f93e3bSAart Bik     // information for all tensors to loop indices in the kernel.
1438067bebb5SAart Bik     CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
1439ccd923e3SPeiming Liu     if (!findSparseAnnotations(env, needIdxRed))
1440bf9ef3efSAart Bik       return failure();
1441a2c9d4bbSAart Bik 
1442378f1885SAart Bik     // Only standard reduction operations (add, sub, or, xor) that can be
1443378f1885SAart Bik     // sparsified by merely reducing the stored values are admissible. More
1444378f1885SAart Bik     // elaborate reduction operations (such as mul, and, min, max) would need
1445378f1885SAart Bik     // to know whether implicit zeros occur as well. They can still be
1446378f1885SAart Bik     // implemented with a custom reduction operation, accepted here as well.
1447378f1885SAart Bik     if (op.getNumReductionLoops() > 0) {
1448378f1885SAart Bik       Operation *yield = op.getRegion().front().getTerminator();
1449378f1885SAart Bik       assert(isa<linalg::YieldOp>(yield));
1450378f1885SAart Bik       Operation *redop = yield->getOperand(0).getDefiningOp();
1451378f1885SAart Bik       if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) &&
1452378f1885SAart Bik           !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) &&
1453378f1885SAart Bik           !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) &&
1454378f1885SAart Bik           !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) &&
1455378f1885SAart Bik           !isa<ReduceOp>(redop)) {
1456378f1885SAart Bik         return failure();
1457378f1885SAart Bik       }
1458378f1885SAart Bik     }
1459378f1885SAart Bik 
1460a7bf2e55SPeiming Liu     // Constructs the tensor expressions tree from `op`, returns failure if the
1461a7bf2e55SPeiming Liu     // tree can not be built or the tensor expression is inadmissible.
1462a7bf2e55SPeiming Liu     if (failed(env.initTensorExp()))
146355a1d50fSPeiming Liu       return failure();
146498f93e3bSAart Bik 
1465b0f8057eSPeiming Liu     // Recursively generates code if admissible.
14664a653b4dSPeiming Liu     env.startEmit(options.sparseEmitStrategy);
146798f93e3bSAart Bik     genBuffers(env, rewriter);
14681328bb6eSPeiming Liu     // TODO: Constant affine expression should be handled differently when using
14692cb99df6SYinying Li     // slice-based codegen, it does not matter now because we already reject the
1470c5a1732cSAart Bik     // constant expression at an earlier stage.
147198f93e3bSAart Bik     genInitConstantDenseAddress(env, rewriter);
1472b8cf7af9Swren romano     genStmt(env, rewriter, env.getExprId(), 0);
147398f93e3bSAart Bik     genResult(env, rewriter);
1474a2c9d4bbSAart Bik     return success();
1475a2c9d4bbSAart Bik   }
1476a2c9d4bbSAart Bik 
1477a2c9d4bbSAart Bik private:
1478a2c9d4bbSAart Bik   /// Options to control sparse code generation.
1479a2c9d4bbSAart Bik   SparsificationOptions options;
1480a2c9d4bbSAart Bik };
1481a2c9d4bbSAart Bik 
1482a2c9d4bbSAart Bik } // namespace
1483a2c9d4bbSAart Bik 
1484a2c9d4bbSAart Bik /// Populates the given patterns list with rewriting rules required for
1485a2c9d4bbSAart Bik /// the sparsification of linear algebra operations.
1486a2c9d4bbSAart Bik void mlir::populateSparsificationPatterns(
1487a2c9d4bbSAart Bik     RewritePatternSet &patterns, const SparsificationOptions &options) {
1488a2c9d4bbSAart Bik   patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1489a2c9d4bbSAart Bik }
1490