//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenEnv.h"
#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Sparsifier analysis methods.
//===----------------------------------------------------------------------===//

/// Returns true iff the affine expression is invariant. Sets the
/// parameter `isCurrentLoop` when the expression just became invariant.
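///
/// For example, with `curr == 2`, dimension `d0` is already invariant,
/// `d1` just becomes invariant (setting `isCurrentLoop`), and `d2` is
/// not invariant yet.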
static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    const LoopId i = cast<AffineDimExpr>(a).getPosition();
    if (i + 1 == curr) {
      isCurrentLoop = true;
      return true; // becomes invariant at current loop
    }
    return i < curr; // invariant when already generated
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return isInvariantAffine(binOp.getLHS(), curr, isCurrentLoop) &&
           isInvariantAffine(binOp.getRHS(), curr, isCurrentLoop);
  }
  default: {
    assert(isa<AffineConstantExpr>(a));
    return true;
  }
  }
}

/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects compound affine
/// expressions in sparse dimensions.
static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
                       LevelType lt, bool setLvlFormat = true) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
    if (!isUndefLT(merger.getLvlType(tid, idx)))
      return false; // used more than once
    if (setLvlFormat)
      merger.setLevelAndType(tid, idx, lvl, lt);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul:
  case AffineExprKind::Constant: {
    assert(lt.hasDenseSemantic());
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
      // We do not set the dim level format for an affine expression like
      // d0 + d1 on either loop index at d0 or d1. We continue the recursion
      // merely to check whether the current affine expression is admissible.
      return findAffine(merger, tid, lvl, binOp.getLHS(), lt, false) &&
             findAffine(merger, tid, lvl, binOp.getRHS(), lt, false);
    }
    // Falls through when it is a constant affine expression.
    return true;
  }
  default:
    return false;
  }
}

/// Helper method to inspect affine expressions for index variable reduction
/// based codegen. It finds the dependent index set for all tensor levels in
/// the current expression we are generating.
///
/// For example, when handling A[i+j][j+k], we build the two-way mapping in
/// the merger between (tensor, level) pairs and their dependent index
/// variable set: A_0 <=> [i, j] and A_1 <=> [j, k].
///
/// It rejects cases (returns false)
/// 1st, when the same index is used more than once, e.g., A[i+j][i]
/// 2nd, when multiplication is used in the non-trivial index expression.
/// 3rd, when a constant operand is used in the non-trivial index expression.
///
/// TODO: constants should be easy to handle.
static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
                          AffineExpr a, LevelType lt, bool isSubExp = false,
                          int64_t coefficient = 1) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    // Only allow positive coefficients on AffineDimExpr.
    if (coefficient <= 0)
      return false;

    const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
    if (!isUndefLT(merger.getLvlType(tensor, idx)))
      return false; // used more than once, e.g., A[i][i]

    // TODO: Generalize the following two cases. A[i] (with a trivial index
    // expression) can be treated as a special affine index expression. We do
    // not necessarily need to differentiate them.
    if (!isSubExp) {
      assert(coefficient == 1);
      merger.setLevelAndType(tensor, idx, lvl, lt);
    }

    if (isSubExp) {
      // The current loop appears in more than one affine expression on the
      // same tensor. We cannot handle this case, e.g., A[i+j][i+k], where
      // `i` is used twice.
      if (merger.hasDependentLvl(idx, tensor)) {
        // TODO: This can be supported by coiterating slices if the loop idx
        // appears in the affine indices of different tensors, or by taking
        // slices along multiple dimensions when it is on the same tensor.
        // E.g.,
        // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
        // d0_1 = getNextSliceOffset t0 along lvl0
        // d0_2 = getNextSliceOffset t1 along lvl0
        // if d0_1 == d0_2 then d0 = d0_1 = d0_2
        // else increase min(d0_1, d0_2).
        return false;
      }
      merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
    }
    return true;
  }
  case AffineExprKind::Constant:
  case AffineExprKind::Mul: {
    // TODO: Support index expressions like `2 * d0`; we currently only
    // support more complicated cases like `2 * d0 + d1`.
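    // Note that a bare `2 * d0` reaches this point with `isSubExp == false`
    // and is therefore rejected below.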
    if (!isSubExp)
      return false;

    // TODO: Support constant AffineExpr for slice-based codegen.
    if (isa<AffineConstantExpr>(a))
      llvm_unreachable("Not yet implemented");

    auto binOp = cast<AffineBinaryOpExpr>(a);
    auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
    if (isa<AffineConstantExpr>(rhs))
      std::swap(lhs, rhs);
    // Must be in the form `constant * d`.
    assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
    int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
    return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient);
  }
  case AffineExprKind::Add: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, true) &&
           findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, true);
  }
  default:
    return false;
  }
}

/// Gets the total number of compound affine expressions in the
/// `getMatchingIndexingMap` for the given tensor. For the following inputs:
///
///   map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
///
/// Returns 1 (because the first level is compressed and its corresponding
/// indexing-expression is `d0 + d1`).
static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
                                                   Value tensor) {
  // The `tensor` is not guaranteed to have `RankedTensorType`, therefore
  // we can't use `getRankedTensorType`/`getSparseTensorType` here.
  // However, we don't need to handle `StorageSpecifierType`, so we
  // can use `SparseTensorType` once we guard against non-tensors.
  const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
  if (!rtp)
    return 0;
  const SparseTensorType stt(rtp);

  const Level lvlRank = stt.getLvlRank();
  const auto exprs = map.getResults();
  assert(static_cast<Dimension>(exprs.size()) == lvlRank &&
         "AffineMap does not have dimension-rank many results");
  unsigned num = 0;
  for (Level l = 0; l < lvlRank; l++) {
    if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic())
      num++;
  }
  return num;
}

/// Gets the total number of sparse levels with compound affine
/// expressions, summed over all operands of the `GenericOp`.
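///
/// For example, assuming a kernel C[i+j] = A[i+j][k] * B[k] with all levels
/// compressed, this returns 2: one non-trivial sparse level on A and one on
/// the output C, while B only uses trivial indexing.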
static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
  unsigned num = 0;
  for (OpOperand &t : op->getOpOperands())
    num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t),
                                              t.get());
  return num;
}

// Returns true iff output has nontrivial affine indices.
static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
  OpOperand *out = op.getDpsInitOperand(0);
  if (getSparseTensorType(out->get()).isAllDense())
    return false;
  return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out),
                                            out->get());
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
/// We currently support two different ways to handle non-trivial index
/// expressions on sparse tensors, and they accept different affine
/// expressions. When using the dependent index reduction-based approach,
/// it currently only supports affine addition index expressions.
static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
  bool annotated = false;
  for (OpOperand &t : env.op()->getOpOperands()) {
    const TensorId tid = env.makeTensorId(t.getOperandNumber());
    const auto map = env.op().getMatchingIndexingMap(&t);
    const auto enc = getSparseTensorEncoding(t.get().getType());
    if (enc)
      annotated = true;
    const Level lvlRank = map.getNumResults();
    assert(!enc || lvlRank == enc.getLvlRank());
    assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
    // We only need to do index reduction if there is at least one
    // non-trivial index expression on sparse levels. If all non-trivial
    // index expressions are on dense levels, we can efficiently rely on
    // random access to locate the element.
    bool needIdxReduc =
        enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
    // If the current tensor being inspected requires affine indexing, it
    // needs to be sliced.
    for (Level l = 0; l < lvlRank; l++) {
      const AffineExpr a = map.getResult(l);
      const LevelType lt = enc.getLvlType(l);
      if (idxReducBased && needIdxReduc) {
        if (!findDepIdxSet(env.merger(), tid, l, a, lt))
          return false; // inadmissible affine expression
      } else {
        if (!findAffine(env.merger(), tid, l, a, lt))
          return false; // inadmissible affine expression
      }
    }
  }
  return annotated;
}

//===----------------------------------------------------------------------===//
// Sparsifier synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Local bufferization of all dense and sparse data structures.
static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  assert(op.getNumOperands() == op.getNumDpsInputs() + 1);

  SmallVector<Range, 4> loopRange =
      llvm::cast<linalg::LinalgOp>(op.getOperation())
          .createLoopRanges(builder, loc);

  env.emitter().initializeLoopEmit(
      builder, loc,
      /// Generates buffer for the output tensor.
      /// Note that all sparse kernels assume that when all elements are
      /// written to (viz. x(i) = y(i) * z(i)), the output buffer is already
      /// initialized to all zeroes and only nonzero values are computed and
      /// written out. For updates (viz. x(i) += y(i) * z(i)), only nonzero
      /// values are used for the updates and no assumption on the original
      /// contents of the output buffer is necessary.
      [&op](OpBuilder &builder, Location loc, Value memref,
            Value tensor) -> Value {
        // Must not be a sparse tensor.
        assert(!getSparseTensorEncoding(tensor.getType()));
        // Two output tensor references should point to the same object.
        OpOperand *lhs = op.getDpsInitOperand(0);
        assert(lhs->get() == tensor);
        // An output tensor can simply materialize from the buffer of the
        // tensor that appears in the outs() clause. For updates, this has
        // the advantage that only the nonzero values are involved in the
        // computation, keeping the operation O(nnz). In all other cases, we
        // are forced to zero out the buffer to enforce the assumption above,
        // which may negatively impact running complexity (viz. O(n^2 + nnz)
        // vs. O(nnz) for matrices).
        // TODO: use better analysis to avoid zeroing out the buffer?
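        // Below, an output buffer that is not an `init` operand is zeroed
        // with an explicit linalg.fill to establish the assumption above.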
        bool isInit = op.isInitTensor(lhs);
        Value init = memref;
        if (!isInit) {
          Value zero = constantZero(builder, loc,
                                    getElementTypeOrSelf(tensor.getType()));
          builder.create<linalg::FillOp>(loc, ValueRange{zero},
                                         ValueRange{init});
        }
        return init;
      },
      [&loopRange](OpBuilder &b, Location loc, Level l) {
        assert(l < loopRange.size());
        return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
      });
}

/// Generates index for load/store on a sparse tensor.
static Value genIndex(CodegenEnv &env, OpOperand *t) {
  const auto map = env.op().getMatchingIndexingMap(t);
  const auto stt = getSparseTensorType(t->get());
  const Level lvlRank = stt.getLvlRank();
  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
  const AffineExpr a = map.getResult(lvlRank - 1);
  assert(a.getKind() == AffineExprKind::DimId);
  const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition());
  return env.getLoopVar(idx);
}

/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                          SmallVectorImpl<Value> &args) {
  const Location loc = env.op().getLoc();
  const TensorId tid = env.makeTensorId(t->getOperandNumber());
  const auto map = env.op().getMatchingIndexingMap(t);
  const auto stt = getSparseTensorType(t->get());
  if (stt.hasEncoding()) {
    // For sparse tensors we only push the last level's position onto `args`.
    const auto pos = env.emitter().getValPosits(tid);
    assert(!pos.empty());
    args.append(pos);
    // Simply return the tensor to extract values using iterators.
    if (env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator)
      return t->get();
  } else {
    // For dense tensors we push all levels' coordinates onto `args`.
    const Level lvlRank = stt.getLvlRank();
    assert(static_cast<Level>(map.getNumResults()) == lvlRank);
    for (Level l = 0; l < lvlRank; l++) {
      const auto lvlExpr = map.getResult(l);
      const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr);
      args.push_back(lvlCrd);
    }
  }
  return env.emitter().getValBuffer()[tid];
}

/// Generates insertion code to implement dynamic tensor load.
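///
/// A sketch of the generated value:
///   no expansion:    x = 0
///   with expansion:  x = expValues[genIndex(t)]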
static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
                              OpOperand *t) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Direct lexicographic coordinate order, tensor loads as zero.
  if (!env.isExpand()) {
    Type tp = getElementTypeOrSelf(t->get().getType());
    return constantZero(builder, loc, tp);
  }
  // Load from expanded access pattern.
  Value index = genIndex(env, t);
  return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index);
}

/// Generates insertion code to implement dynamic tensor load for reduction.
static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
                                    OpOperand *t) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  Value identity = env.getCustomRedId();
  // Direct lexicographic coordinate order, tensor loads as identity.
  if (!env.isExpand())
    return identity;
  // Load from expanded access pattern if filled, identity otherwise.
  Value values = env.getExpandValues();
  Value filled = env.getExpandFilled();
  Value index = genIndex(env, t);
  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
  Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index);
  return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
}

static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
                                  Value sparseOut, ValueRange ivs, Value v) {
  scf::IfOp condInsert =
      builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true);
  // True branch.
  builder.setInsertionPointToStart(condInsert.thenBlock());
  Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
  builder.create<scf::YieldOp>(loc, res);
  // False branch.
  builder.setInsertionPointToStart(condInsert.elseBlock());
  builder.create<scf::YieldOp>(loc, sparseOut);
  // Value assignment.
  builder.setInsertionPointAfter(condInsert);
  return condInsert.getResult(0);
}

/// Generates insertion code to implement dynamic tensor store.
static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                              Value rhs) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Direct insertion in lexicographic coordinate order.
  if (!env.isExpand()) {
    const LoopId numLoops = op.getRank(t);
    // Retrieves the first `numLoops` induction variables.
    SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end(
        env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
    Value chain = env.getInsertionChain();
    if (env.isValidLexInsert()) {
      // Generates runtime check for a valid lex during reduction,
      // to avoid inserting the identity value for empty reductions.
      //   if (validLexInsert) then
      //     insert(rhs) into chain
      //     return updated chain
      //   else
      //     return unmodified chain
      Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
                                       chain, ivs, rhs);
      env.updateInsertionChain(out);
    } else {
      Value sparseOut;
      if (!hasAnySparseType(env.op().getInputs().getTypes())) {
        // This is an all-dense -> sparse kernel, test rhs != 0 before
        // insertion.
        Value nz = genIsNonzero(builder, loc, rhs);
        sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
      } else {
        sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
      }
      // Generates regular insertion chain.
      env.updateInsertionChain(sparseOut);
    }
    return;
  }
  // Generates insertion code along expanded access pattern.
  //   if (!expFilled[i]) then
  //     expFilled[i] = true
  //     expAdded[inserts++] = i
  //   endif
  //   values[i] = rhs
  Value values = env.getExpandValues();
  Value filled = env.getExpandFilled();
  Value added = env.getExpandAdded();
  Value count = env.getExpandCount();
  Value index = genIndex(env, t);
  Value fval = constantI1(builder, loc, false);
  Value tval = constantI1(builder, loc, true);
  // If statement.
  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                             isFilled, fval);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
                                             /*else=*/true);
  // True branch.
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  builder.create<memref::StoreOp>(loc, tval, filled, index);
  builder.create<memref::StoreOp>(loc, index, added, count);
  Value one = constantIndex(builder, loc, 1);
  Value add = builder.create<arith::AddIOp>(loc, count, one);
  builder.create<scf::YieldOp>(loc, add);
  // False branch.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, count);
  builder.setInsertionPointAfter(ifOp);
  // Value assignment.
  env.updateExpandCount(ifOp.getResult(0));
  builder.create<memref::StoreOp>(loc, rhs, values, index);
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = env.exp(exp).val;
  if (val)
    return val;
  // Get tensor operand.
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
  // Fold a binary-valued tensor into its explicit value.
  const auto stt = getSparseTensorType(t->get());
  if (auto explVal = stt.getExplicitVal())
    return genValFromAttr(builder, loc, explVal);
  // Load during insertion.
  if (env.isSparseOutput(t)) {
    if (env.isCustomReduc())
      return genInsertionLoadReduce(env, builder, t);
    return genInsertionLoad(env, builder, t);
  }

  // Actual load.
  SmallVector<Value> args;
  Value ptr = genSubscript(env, builder, t, args);
  if (llvm::isa<TensorType>(ptr.getType())) {
    assert(env.options().sparseEmitStrategy ==
               SparseEmitStrategy::kSparseIterator &&
           args.size() == 1);
    return builder.create<ExtractValOp>(loc, ptr, args.front());
  }
  return builder.create<memref::LoadOp>(loc, ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                           Value rhs) {
  // Only unary and binary are allowed to return an uninitialized rhs
  // to indicate missing output, or otherwise a custom reduction that
  // received no value to accumulate.
  if (!rhs) {
    assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
           env.exp(exp).kind == TensorExp::Kind::kBinary ||
           env.exp(exp).kind == TensorExp::Kind::kReduce);
    return;
  }
  // Test if this is a scalarized reduction.
  if (env.isReduc()) {
    env.updateReduc(rhs);
    return;
  }
  // Regular store.
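  // A dense output lowers to a direct memref.store; a sparse output goes
  // through the insertion chain logic below.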
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  OpOperand *t = op.getDpsInitOperand(0);
  if (!env.isSparseOutput(t)) {
    SmallVector<Value> args;
    Value ptr = genSubscript(env, builder, t, args);
    builder.create<memref::StoreOp>(loc, rhs, ptr, args);
    return;
  }
  // Store during sparse insertion.
  if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
    genInsertionStore(env, builder, t, rhs);
    return;
  }
  // Select operation insertion.
  Value chain = env.getInsertionChain();
  scf::IfOp ifOp =
      builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  // Existing value was preserved to be used here.
  assert(env.exp(exp).val);
  Value v0 = env.exp(exp).val;
  genInsertionStore(env, builder, t, v0);
  env.merger().clearExprValue(exp);
  // Yield modified insertion chain along true branch.
  Value mchain = env.getInsertionChain();
  builder.create<scf::YieldOp>(op.getLoc(), mchain);
  // Yield original insertion chain along false branch.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, chain);
  // Done with if statement.
  env.updateInsertionChain(ifOp->getResult(0));
  builder.setInsertionPointAfter(ifOp);
}

/// Generates an invariant value.
inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
  return env.exp(exp).val;
}

/// Semi-ring branches are simply inlined by the sparsifier. Prior
/// analysis has verified that all computations are "local" to the inlined
/// branch or otherwise invariantly defined outside the loop nest, with the
/// exception of index computations, which need to be relinked to actual
/// inlined cloned code.
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
                          Value e) {
  if (auto arg = dyn_cast<BlockArgument>(e)) {
    // Direct arguments of the original linalg op must be converted
    // into dense tensor loads. Note that we should not encounter
    // anything else. This needs to be verified by semi-ring ops.
    linalg::GenericOp op = env.op();
    if (arg.getOwner()->getParentOp() == op) {
      const TensorId tid = env.makeTensorId(arg.getArgNumber());
      OpOperand *t = &op->getOpOperand(tid);
      assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
      SmallVector<Value> args;
      Value ptr = genSubscript(env, rewriter, t, args);
      return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
    }
  } else if (Operation *def = e.getDefiningOp()) {
    // Handle index computation.
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
    // When still defined in new body, recurse into operands.
    if (def->getBlock() == block) {
      rewriter.setInsertionPoint(def);
      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
        rewriter.modifyOpInPlace(def, [&]() {
          def->setOperand(
              i, relinkBranch(env, rewriter, block, def->getOperand(i)));
        });
      }
    }
  }
  return e;
}

/// Recursively generates tensor expression.
static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
  if (e == ::mlir::sparse_tensor::detail::kInvalidId)
    return Value();

  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  const TensorExp &exp = env.exp(e);
  const auto kind = exp.kind;
  if (kind == TensorExp::Kind::kTensor)
    return genTensorLoad(env, rewriter, e);
  if (kind == TensorExp::Kind::kInvariant)
    return genInvariantValue(env, e);
  if (kind == TensorExp::Kind::kLoopVar)
    return env.getLoopVar(exp.loop);

  if (kind == TensorExp::Kind::kReduce)
    env.startCustomReduc(e); // enter custom

  // If either lhs/rhs is a synthetic zero, we infer the type for the zero
  // value based on the type of the other operand.
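  // E.g., with f32 operands, the synthetic zero simply materializes as
  // `arith.constant 0.0 : f32` through constantZero.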
  Value v0, v1;
  if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
      env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
    v1 = genExp(env, rewriter, exp.children.e1);
    v0 = constantZero(rewriter, loc, v1.getType());
  } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
             env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
    v0 = genExp(env, rewriter, exp.children.e0);
    v1 = constantZero(rewriter, loc, v0.getType());
  } else {
    v0 = genExp(env, rewriter, exp.children.e0);
    v1 = genExp(env, rewriter, exp.children.e1);
  }

  Value ee;
  if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
    // A custom reduce did not receive a value to accumulate.
  } else {
    ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
    if (ee &&
        (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
         kind == TensorExp::Kind::kBinaryBranch ||
         kind == TensorExp::Kind::kReduce ||
         kind == TensorExp::Kind::kSelect)) {
      OpBuilder::InsertionGuard guard(rewriter);
      ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
    }
  }

  if (kind == TensorExp::Kind::kReduce)
    env.endCustomReduc(); // exit custom

  if (kind == TensorExp::Kind::kSelect)
    env.merger().setExprValue(e, v0); // Preserve value for later use.

  return ee;
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                          LoopId curr, bool isStart) {
  if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
    return;
  if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
    // Inspect tensor indices.
    linalg::GenericOp op = env.op();
    OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
    const auto map = op.getMatchingIndexingMap(&t);
    const auto stt = getSparseTensorType(t.get());
    const Level lvlRank = stt.getLvlRank();
    assert(static_cast<Level>(map.getNumResults()) == lvlRank);
    bool isCurrentLoop = curr == 0; // for scalar tensors
    for (Level l = 0; l < lvlRank; l++) {
      const AffineExpr a = map.getResult(l);
      if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop))
        return; // still in play
    }
    // All exhausted at current level.
    if (!isCurrentLoop)
      return;
    // Generate code for a scalarized reduction or invariant. Note that
    // because custom reduction lhs may occur several times in the IR,
    // we have a built-in safety for only initializing and wrapping-up
    // the scalarized reduction once.
    OpOperand *lhs = op.getDpsInitOperand(0);
    if (lhs == &t) {
      // Start or end a scalarized reduction.
      if (isStart) {
        if (env.isCustomReduc()) {
          if (!env.isReduc())
            env.startReduc(exp, env.getCustomRedId());
        } else {
          env.startReduc(exp, genTensorLoad(env, builder, exp));
        }
        if (env.hasSparseOutput())
          env.startValidLexInsert(
              constantI1(builder, env.op().getLoc(), false));
      } else {
        if (!env.isCustomReduc() || env.isReduc())
          genTensorStore(env, builder, exp, env.endReduc());
        if (env.hasSparseOutput())
          env.endValidLexInsert();
      }
    } else {
      // Start or end loop invariant hoisting of a tensor load.
      if (isStart) {
        env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
      } else {
        env.merger().clearExprValue(exp);
      }
    }
  } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
             env.exp(exp).kind != TensorExp::Kind::kLoopVar &&
             env.exp(exp).kind != TensorExp::Kind::kSynZero) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
      env.startCustomReduc(exp); // enter custom
    const ExprId e0 = env.exp(exp).children.e0;
    const ExprId e1 = env.exp(exp).children.e1;
    genInvariants(env, builder, e0, curr, isStart);
    genInvariants(env, builder, e1, curr, isStart);
    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
      env.endCustomReduc(); // exit custom
  }
}

/// Generates an expanded access pattern in the innermost dimension.
static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                      bool isStart) {
  linalg::GenericOp op = env.op();
  OpOperand *lhs = op.getDpsInitOperand(0);
  if (!env.atExpandLevel(lhs, op.getRank(lhs), curr))
    return; // not needed at current level
  assert(!env.isReduc());
  // Generate start or end of an expanded access pattern. Note that because
  // an expansion does not rely on the ongoing contents of the sparse storage
  // scheme, we can use the original tensor as incoming SSA value (which
  // simplifies codegen a bit). If expansion on the actual contents is ever
  // needed, we will need to use the SSA value in the insertion chain instead.
  Value tensor = lhs->get();
  Location loc = op.getLoc();
  if (isStart) {
    auto dynShape = {ShapedType::kDynamic};
    Type etp = cast<ShapedType>(tensor.getType()).getElementType();
    Type t1 = MemRefType::get(dynShape, etp);
    Type t2 = MemRefType::get(dynShape, builder.getI1Type());
    Type t3 = MemRefType::get(dynShape, builder.getIndexType());
    Type t4 = builder.getIndexType();
    auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
    assert(r.getNumResults() == 4);
    env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
                    r.getResult(3));
  } else {
    SmallVector<Value> indices;
    for (LoopId i = 0; i < curr; i++)
      indices.push_back(env.emitter().getLoopIV(i));
    Value values = env.getExpandValues();
    Value filled = env.getExpandFilled();
    Value added = env.getExpandAdded();
    Value count = env.getExpandCount();
    Value chain = env.getInsertionChain();
    Value compress = builder.create<CompressOp>(loc, values, filled, added,
                                                count, chain, indices);
    env.updateInsertionChain(compress);
    env.endExpand();
  }
}

/// Returns parallelization strategy. Any implicit loop in the Linalg
/// operation that is marked "parallel" is a candidate. Whether it is actually
/// converted to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
  // Reject parallelization of sparse output.
  if (env.hasSparseOutput())
    return false;
  // Parallel loops on tensor expansion can cause data races.
  if (env.isExpand())
    return false;
  // Inspect strategy.
  switch (env.options().parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return true;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Whether or not the current loop being generated should be parallelized
/// (if possible) according to the configuration.
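///
/// For example, under `kAnyStorageOuterLoop` only the outermost loop
/// (`curr == 0`) is a candidate, while `kDenseAnyLoop` admits any loop
/// without a sparse level in `tidLvls`.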
static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
                               ArrayRef<TensorLevel> tidLvls) {
  linalg::GenericOp op = env.op();
  auto iteratorTypes = op.getIteratorTypesArray();
  bool isSparse = llvm::any_of(tidLvls, [curr, &env](TensorLevel tidLvl) {
    // Queries the LT based on the tensor and loop id, as requested by
    // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
    // should be consistent with the LT indexed by <TensorId, Level>.
    const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr);
    return lt.hasSparseSemantic();
  });
  return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
}

/// Emits a loop to coiterate over the list of tensor levels. The generated
/// loop can be either a for-loop or a while-loop, depending on whether there
/// is at most one sparse level in the list.
static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
                                 ArrayRef<TensorLevel> tidLvls,
                                 unsigned numCases, bool tryParallel,
                                 bool needsUniv) {
  Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
    // Construct while-loop with a parameter for each index.
    return env.emitter().enterCoIterationOverTensorsAtLvls(
        builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel,
        needsUniv);
  });
  assert(loop);
  return loop;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                          unsigned numCases, bool needsUniv,
                          ArrayRef<TensorLevel> tidLvls) {
  bool tryParallel = shouldTryParallize(env, curr, tidLvls);
  return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
                        needsUniv);
}

/// Generates the induction structure for a while-loop.
static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
                            bool needsUniv) {
  Location loc = env.op().getLoc();
  // Finalize each else branch of all if statements.
  if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
               builder.getInsertionBlock()->getParentOp())) {
      // Break on IfOp for slicing filtering.
      if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
          StringAttr::get(ifOp->getContext(), "slice"))
        break;

      unsigned y = 0;
      SmallVector<Value> yields;
      if (env.isReduc()) {
        yields.push_back(env.getReduc());
        env.updateReduc(ifOp.getResult(y++));
        if (env.isValidLexInsert()) {
          yields.push_back(env.getValidLexInsert());
          env.updateValidLexInsert(ifOp.getResult(y++));
        }
      }
      if (env.isExpand()) {
        yields.push_back(env.getExpandCount());
        env.updateExpandCount(ifOp->getResult(y++));
      }
      if (env.getInsertionChain()) {
        yields.push_back(env.getInsertionChain());
        env.updateInsertionChain(ifOp->getResult(y++));
      }
      assert(y == yields.size());
      builder.create<scf::YieldOp>(loc, yields);
      builder.setInsertionPointAfter(ifOp);
    }
  }
  // No need to set the insertion point here as LoopEmitter keeps track of
  // the basic block where scf::Yield should be inserted.
}

/// Generates a case region in the coiterate operation.
static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
                               unsigned caseIdx, LatPointId allCase,
                               LatPointId curCase,
                               MutableArrayRef<Value> reduc) {
  assert(allCase == curCase || env.merger().latGT(allCase, curCase));
  const BitVector &allCaseBits = env.merger().lat(allCase).simple;
  const BitVector &curCaseBits = env.merger().lat(curCase).simple;

  /// Computes the subset of iterators that are valid in the current case
  /// being generated.
  I64BitSet caseBit(0);
  for (auto [idx, set] : llvm::enumerate(allCaseBits.set_bits()))
    if (curCaseBits.test(set))
      caseBit.set(idx);

  env.emitter().enterCurrentCoIterationCase(builder, env.op().getLoc(),
                                            caseBit, caseIdx, reduc);
}

/// Generates a single if-statement within a while-loop.
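///
/// A sketch of the emitted pattern (the actual conditions depend on the
/// lattice point, and results are added for reductions, expansion counts,
/// and insertion chains):
///   %c = arith.cmpi eq, %crd, %iv  // and-ed over all sparse conditions
///   scf.if %c -> (...) {
///     ...
///   } else {
///     ...
///   }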
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                       LatPointId p) {
  Location loc = env.op().getLoc();
  SmallVector<Type> types;
  Value cond;
  env.merger().foreachTensorLoopId(
      p, /*simple=*/true,
      [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
          bool isIdxRed) {
        if (isIdxRed) {
          // Since there is no 1:1 mapping from loop to level (multiple loops
          // are required to resolve one level with a non-trivial index
          // expression), we need to reconstruct the tensor level types if
          // this loop requires an index reduction condition.
          assert(lvl.has_value() && isUndefLT(lt));
          auto stt = getSparseTensorType(env.op().getInputs()[tid]);
          lt = stt.getLvlType(*lvl);
        }
        assert(curr == env.merger().loop(b));
        Value clause;
        if (lt.hasSparseSemantic()) {
          assert(lvl.has_value());
          const Value crd = env.emitter().getCoord(tid, *lvl);
          const Value lvar = env.getLoopVar(curr);
          clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 crd, lvar);
        } else {
          assert(lt.hasDenseSemantic() || isUndefLT(lt));
          clause = constantI1(builder, loc, true);
        }
        cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
      });
  if (env.isReduc()) {
    types.push_back(env.getReduc().getType());
    if (env.isValidLexInsert())
      types.push_back(env.getValidLexInsert().getType());
  }
  if (env.isExpand())
    types.push_back(builder.getIndexType());
  if (env.getInsertionChain())
    types.push_back(env.getInsertionChain().getType());
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  return ifOp;
}

/// Generates end of true branch of if-statement within a while-loop.
static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
                  Value redInput, Value cntInput, Value insInput,
                  Value validIns) {
  SmallVector<Value> operands;
  if (env.isReduc()) {
    operands.push_back(env.getReduc());
    env.updateReduc(redInput);
    if (env.isValidLexInsert()) {
      // Any overlapping indices during a reduction create a valid lex insert.
98481d0d2b2SJim Kitchen       operands.push_back(constantI1(builder, env.op().getLoc(), true));
985047399c2SAart Bik       env.updateValidLexInsert(validIns);
986fc5d8fceSPeiming Liu     }
9877373cabcSAart Bik   }
988384049a7SAart Bik   if (env.isExpand()) {
989384049a7SAart Bik     operands.push_back(env.getExpandCount());
990384049a7SAart Bik     env.updateExpandCount(cntInput);
9914f2ec7f9SAart Bik   }
992384049a7SAart Bik   if (env.getInsertionChain()) {
993384049a7SAart Bik     operands.push_back(env.getInsertionChain());
994384049a7SAart Bik     env.updateInsertionChain(insInput);
9955661647eSAart Bik   }
9964f2ec7f9SAart Bik   if (!operands.empty())
997384049a7SAart Bik     builder.create<scf::YieldOp>(env.op().getLoc(), operands);
998e9fa5590SMatthias Springer   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
9997373cabcSAart Bik }
10007373cabcSAart Bik
1001c8d5dcb0SAart Bik //===----------------------------------------------------------------------===//
1002c43e6274STim Harvey // Sparsifier synthesis methods (loop sequence).
1003c8d5dcb0SAart Bik //===----------------------------------------------------------------------===//
1004c8d5dcb0SAart Bik
1005d933b88bSPeiming Liu static bool getAllTidLvlsInLatPoints(
1006d933b88bSPeiming Liu     CodegenEnv &env, LatPointId li, LoopId curr,
1007d933b88bSPeiming Liu     llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
1008d933b88bSPeiming Liu   const BitVector &simple = env.lat(li).simple;
1009d933b88bSPeiming Liu   const TensorId outTid = env.merger().getOutTensorID();
1010d933b88bSPeiming Liu   const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
1011d933b88bSPeiming Liu
1012d933b88bSPeiming Liu   unsigned numloopCond = 0;
1013d933b88bSPeiming Liu   bool hasNonUnique = false;
1014d933b88bSPeiming Liu   env.merger().foreachTensorLoopId(
1015d933b88bSPeiming Liu       li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
1016d933b88bSPeiming Liu                     LevelType lt, bool isIdxReduc) {
1017d933b88bSPeiming Liu         if (simple[b]) {
1018d933b88bSPeiming Liu           if (isIdxReduc) {
1019d933b88bSPeiming Liu             callback(env.makeTensorLevel(tid, *lvl), nullptr);
1020d933b88bSPeiming Liu             numloopCond++;
1021d933b88bSPeiming Liu             return;
1022d933b88bSPeiming Liu           }
1023d933b88bSPeiming Liu           if (isUndefLT(lt)) {
1024d933b88bSPeiming Liu             // An undefined lt in the lattices means we probably intend to
1025d933b88bSPeiming Liu             // generate a dense loop according to the synthetic tensor (for
1026d933b88bSPeiming Liu             // invariants and the sparse output tensor).
1027d933b88bSPeiming Liu             if (env.merger().getSynTensorID() == tid) {
1028d933b88bSPeiming Liu               // Coiterating with an invariant
1029d933b88bSPeiming Liu               // e.g., out = prod(in[i][j] op invariant);
1030d933b88bSPeiming Liu               // or a broadcast
1031d933b88bSPeiming Liu               // e.g., out[i][j] = in[i] (j is undef for input)
1032d933b88bSPeiming Liu               //
1033d933b88bSPeiming Liu               // The level of the synthetic tensor is the current loop depth;
1034d933b88bSPeiming Liu               // the rank of the synthetic tensor equals the number of loops.
1035d933b88bSPeiming Liu               assert(curr == env.getCurrentDepth());
1036d933b88bSPeiming Liu               lvl = curr;
1037d933b88bSPeiming Liu             } else if (!lvl) {
1038d933b88bSPeiming Liu               // Skips an invalid lvl (e.g., when this is a zero-ranked tensor).
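              // For instance (an illustrative example), in out[i] = in[i] * s[]
              // the zero-ranked tensor s contributes no level at loop i.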
1039d933b88bSPeiming Liu               return;
1040d933b88bSPeiming Liu             }
1041d933b88bSPeiming Liu           }
1042d933b88bSPeiming Liu           hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
1043d933b88bSPeiming Liu           callback(env.makeTensorLevel(tid, *lvl), nullptr);
1044d933b88bSPeiming Liu           numloopCond++;
104552b69aa3SPeiming Liu         } else if (lt.hasDenseSemantic() || isIdxReduc) {
1046d933b88bSPeiming Liu           callback(env.makeTensorLevel(tid, *lvl), nullptr);
1047d933b88bSPeiming Liu         } else {
1048d933b88bSPeiming Liu           assert(isUndefLT(lt));
1049d933b88bSPeiming Liu           linalg::GenericOp op = env.op();
1050d933b88bSPeiming Liu           if (tid >= op.getNumDpsInputs())
1051d933b88bSPeiming Liu             // We only handle affine expressions on input tensors (for now).
1052d933b88bSPeiming Liu             return;
1053d933b88bSPeiming Liu           OpOperand *operand = &op->getOpOperand(tid);
1054d933b88bSPeiming Liu           const auto stt = getSparseTensorType(operand->get());
1055d933b88bSPeiming Liu           // Non-annotated dense tensors require no special handling.
1056d933b88bSPeiming Liu           if (!stt.hasEncoding())
1057d933b88bSPeiming Liu             return;
1058d933b88bSPeiming Liu
1059d933b88bSPeiming Liu           ArrayRef<AffineExpr> affines =
1060d933b88bSPeiming Liu               op.getMatchingIndexingMap(operand).getResults();
1061d933b88bSPeiming Liu           const Level lvlRank = stt.getLvlRank();
1062d933b88bSPeiming Liu           assert(affines.size() == static_cast<size_t>(lvlRank));
1063d933b88bSPeiming Liu           for (Level l = 0; l < lvlRank; l++) {
1064d933b88bSPeiming Liu             AffineExpr exp = affines[l];
1065d933b88bSPeiming Liu             // Skip simple affine expressions and non-dense levels (which
1066d933b88bSPeiming Liu             // have their own filter loop).
106752b69aa3SPeiming Liu             LevelType lt = stt.getLvlType(l);
106852b69aa3SPeiming Liu             if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic())
1069d933b88bSPeiming Liu               continue;
1070d933b88bSPeiming Liu
1071d933b88bSPeiming Liu             // Constant affine expressions are handled in genLoop.
1072d933b88bSPeiming Liu             if (!isa<AffineConstantExpr>(exp)) {
1073d933b88bSPeiming Liu               bool isCurrentLoop = false;
1074d933b88bSPeiming Liu               assert(curr == env.getCurrentDepth());
1075d933b88bSPeiming Liu               if (isInvariantAffine(exp, curr + 1, /*out*/ isCurrentLoop) &&
1076d933b88bSPeiming Liu                   isCurrentLoop) {
1077d933b88bSPeiming Liu                 // If the compound affine expression is invariant and we are
1078d933b88bSPeiming Liu                 // right at the level, we need to generate the address according
1079d933b88bSPeiming Liu                 // to the affine expression. This is also the best place to do
1080d933b88bSPeiming Liu                 // it, to avoid putting it inside inner loops.
1081d933b88bSPeiming Liu                 callback(env.makeTensorLevel(tid, l), exp);
1082d933b88bSPeiming Liu               }
1083d933b88bSPeiming Liu             }
1084d933b88bSPeiming Liu           }
1085d933b88bSPeiming Liu         }
1086d933b88bSPeiming Liu       });
1087d933b88bSPeiming Liu
1088d933b88bSPeiming Liu   if (isDenseLT(env.lt(outTid, curr))) {
1089298412b5SPeiming Liu     auto stt = getSparseTensorType(env.op().getOutputs().front());
1090298412b5SPeiming Liu     // Note that we generate dense indices of the output tensor unconditionally,
1091298412b5SPeiming Liu     // since they may not appear in the lattice, but may be needed for the
1092298412b5SPeiming Liu     // linearized environment.
1093298412b5SPeiming Liu     // TODO: we should avoid introducing corner cases for all-dense sparse
1094298412b5SPeiming Liu     // tensors.
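    // E.g. (illustrative), an all-dense encoding with level mapping
    // (d0, d1) -> (d0 : dense, d1 : dense) still carries an encoding and
    // hence takes the unconditional callback below.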
1095298412b5SPeiming Liu     if (stt.hasEncoding() && stt.isAllDense())
1096d933b88bSPeiming Liu       callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
1097d933b88bSPeiming Liu   }
1098d933b88bSPeiming Liu
1099d933b88bSPeiming Liu   if (numloopCond == 0) {
1100d933b88bSPeiming Liu     // Corner case: the loop bound is defined by an *unused* operand; in this
1101d933b88bSPeiming Liu     // case, we just generate a dense "fake" loop by iterating over the
1102d933b88bSPeiming Liu     // synthetic tensor.
1103d933b88bSPeiming Liu     callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
1104d933b88bSPeiming Liu     numloopCond++;
1105d933b88bSPeiming Liu   }
1106d933b88bSPeiming Liu   // If we just need one loop condition and the condition is not imposed on a
1107d933b88bSPeiming Liu   // non-unique level, the loop can be generated by a for loop.
1108a02010b3SPeiming Liu   // Or, if we are generating sparse-iterator-based loops, we always generate
1109a02010b3SPeiming Liu   // `sparse_tensor.iterate` regardless of whether the level is unique or not.
1110a02010b3SPeiming Liu   return numloopCond == 1 &&
1111a02010b3SPeiming Liu          (!hasNonUnique || env.options().sparseEmitStrategy ==
1112a02010b3SPeiming Liu                                SparseEmitStrategy::kSparseIterator);
1113d933b88bSPeiming Liu }
1114d933b88bSPeiming Liu
1115c8d5dcb0SAart Bik /// Starts a loop sequence at the given level. Returns true if
1116c8d5dcb0SAart Bik /// the universal loop index must be maintained at this level.
1117b8cf7af9Swren romano static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
1118c5a1732cSAart Bik                          LoopId curr, LatSetId lts) {
1119c5a1732cSAart Bik   assert(!env.getLoopVar(curr));
1120c8d5dcb0SAart Bik   // Emit invariants at this loop sequence level.
1121c5a1732cSAart Bik   genInvariants(env, builder, exp, curr, /*isStart=*/true);
11224f2ec7f9SAart Bik   // Emit access pattern expansion for sparse tensor output.
1123c5a1732cSAart Bik   genExpand(env, builder, curr, /*isStart=*/true);
1124d933b88bSPeiming Liu   // Emit further initialization at this loop sequence level.
1125b8cf7af9Swren romano   const LatPointId l0 = env.set(lts)[0];
1126b0f8057eSPeiming Liu
112736c95ee7SPeiming Liu   SmallVector<TensorLevel> tidLvls;
1128d933b88bSPeiming Liu   getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
1129298412b5SPeiming Liu     // TODO: remove this! The same tensor level might be added multiple
1130298412b5SPeiming Liu     // times due to the special handling for the all-dense "sparse" output
1131298412b5SPeiming Liu     // tensor (see L1038).
1132298412b5SPeiming Liu     if (llvm::find(tidLvls, tl) != tidLvls.end())
1133298412b5SPeiming Liu       return;
1134d933b88bSPeiming Liu     tidLvls.emplace_back(tl);
113532c512e4SPeiming Liu   });
1136b0f8057eSPeiming Liu
113736c95ee7SPeiming Liu   env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
1138b0f8057eSPeiming Liu
1139c8d5dcb0SAart Bik   // Maintain the universal index only if it is actually
1140c8d5dcb0SAart Bik   // consumed by a subsequent lattice point.
11417c7c10a0Swren romano   for (const LatPointId li : env.set(lts).drop_front())
11425fd9d801SPeiming Liu     if (!env.merger().hasAnySparse(env.lat(li).simple))
1143c8d5dcb0SAart Bik       return true;
1144d933b88bSPeiming Liu
1145c8d5dcb0SAart Bik   return false;
1146c8d5dcb0SAart Bik }
1147c8d5dcb0SAart Bik
114898ce2debSAart Bik // Generates dense affine addresses for an encoded tensor.
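// For example (an illustrative sketch), with lvlExprs = (2, d0) on levels
// (dense, compressed), the address of the leading dense level is fully fixed
// by the constant 2 and is located eagerly; the scan below stops at the first
// level that is not both dense and constant.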
1149fbe61130SAart Bik static void genConstantDenseAddressFromLevel(CodegenEnv &env,
1150b8cf7af9Swren romano                                              OpBuilder &builder, TensorId tid,
1151b8cf7af9Swren romano                                              Level startLvl) {
11524fd3a120SPeiming Liu   // TODO: Handle affine expressions on the output tensor.
1153384049a7SAart Bik   linalg::GenericOp op = env.op();
11544fd3a120SPeiming Liu   assert(tid < op.getNumDpsInputs());
11554fd3a120SPeiming Liu   OpOperand *input = op.getDpsInputOperands()[tid];
1156b8cf7af9Swren romano   const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
1157f708a549Swren romano   const auto enc = getSparseTensorEncoding(input->get().getType());
11584fd3a120SPeiming Liu   if (enc) {
1159b8cf7af9Swren romano     const Location loc = op.getLoc();
116046a384dfSwren romano     const TensorId tid = env.makeTensorId(input->getOperandNumber());
1161f708a549Swren romano     const Level lvlRank = enc.getLvlRank();
1162b8cf7af9Swren romano     assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
1163b8cf7af9Swren romano     for (Level l = startLvl; l < lvlRank; l++) {
1164ccd923e3SPeiming Liu       AffineExpr lvlExpr = lvlExprs[l];
116552b69aa3SPeiming Liu       if (enc.getLvlType(l).hasDenseSemantic() &&
116652b69aa3SPeiming Liu           isa<AffineConstantExpr>(lvlExpr))
1167298412b5SPeiming Liu         env.emitter().locateLvlAtAffineAddress(
116836c95ee7SPeiming Liu             builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
116998f93e3bSAart Bik       else
117098f93e3bSAart Bik         return; // break on first non-dense non-constant level
11714fd3a120SPeiming Liu     }
11724fd3a120SPeiming Liu   }
11734fd3a120SPeiming Liu }
11744fd3a120SPeiming Liu
117598f93e3bSAart Bik // We can generate addresses for constant affine expressions before any loops,
11764fd3a120SPeiming Liu // starting from the first level, as they do not depend on anything.
11774fd3a120SPeiming Liu // E.g., for [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first
11784fd3a120SPeiming Liu // two levels can be determined before the loops.
117998ce2debSAart Bik static void genInitConstantDenseAddress(CodegenEnv &env,
118098ce2debSAart Bik                                         RewriterBase &rewriter) {
1181b8cf7af9Swren romano   for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
118298f93e3bSAart Bik     genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
11834fd3a120SPeiming Liu }
11844fd3a120SPeiming Liu
1185d933b88bSPeiming Liu /// Returns true if the lattice bit can be iterated by a for loop.
1186b8cf7af9Swren romano static bool translateBitsToTidLvlPairs(
1187c5a1732cSAart Bik     CodegenEnv &env, LatPointId li, LoopId curr,
118836c95ee7SPeiming Liu     SmallVectorImpl<TensorLevel> &tidLvls,
118936c95ee7SPeiming Liu     SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
1190d933b88bSPeiming Liu   return getAllTidLvlsInLatPoints(env, li, curr,
1191d933b88bSPeiming Liu                                   [&](TensorLevel tl, AffineExpr exp) {
1192d933b88bSPeiming Liu                                     if (exp)
1193d933b88bSPeiming Liu                                       affineTidLvls.emplace_back(tl, exp);
1194d933b88bSPeiming Liu                                     else
1195d933b88bSPeiming Liu                                       tidLvls.emplace_back(tl);
119632c512e4SPeiming Liu                                   });
1197b0f8057eSPeiming Liu }
1198b0f8057eSPeiming Liu
1199c8d5dcb0SAart Bik /// Starts a single loop in the current sequence.
12001328bb6eSPeiming Liu static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
1201c5a1732cSAart Bik                                               OpBuilder &builder, LoopId curr,
1202*c4420257SPeiming Liu                                               LatPointId li, unsigned numCases,
1203*c4420257SPeiming Liu                                               bool needsUniv) {
1204*c4420257SPeiming Liu   // TODO: numCases is only used when generating iterator-based loops. Clean up
1205*c4420257SPeiming Liu   // after the migration is complete.
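  // An illustrative note (see genCoIterationCase above): under the
  // sparse-iterator emit strategy, a multi-case loop materializes as a
  // coiterate operation whose numCases case regions are populated later
  // during genStmt.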
1206b8cf7af9Swren romano   // The set of tensors + lvls to generate loops on.
120736c95ee7SPeiming Liu   SmallVector<TensorLevel> tidLvls;
120898ce2debSAart Bik
1209fb287335SPeiming Liu   // The set of dense tensors with non-trivial affine expressions that just
121098ce2debSAart Bik   // became invariant, whose addresses are generated at the current level.
121136c95ee7SPeiming Liu   SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
1212ff8815e5SPeiming Liu   bool isSingleCond =
1213c5a1732cSAart Bik       translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);
1214b0f8057eSPeiming Liu
1215c8d5dcb0SAart Bik   // Emit the for/while-loop control.
1216*c4420257SPeiming Liu   Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
1217b8cf7af9Swren romano   Location loc = env.op().getLoc();
121836c95ee7SPeiming Liu   for (auto [tidLvl, exp] : affineTidLvls) {
1219298412b5SPeiming Liu     env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
1220fb287335SPeiming Liu   }
12214fd3a120SPeiming Liu
1222b8cf7af9Swren romano   // Until now, we have entered every <tid, lvl> pair in {cond, extra,
1223b8cf7af9Swren romano   // affine}Tids/Lvls. The addresses of the upcoming levels that depend on
12244fd3a120SPeiming Liu   // constant affine expressions may now be determined.
122536c95ee7SPeiming Liu   auto allTidLvls =
122636c95ee7SPeiming Liu       llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
122736c95ee7SPeiming Liu   for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
1228e7b4c93fSPeiming Liu     if (tid != env.merger().getOutTensorID() &&
1229e7b4c93fSPeiming Liu         tid != env.merger().getSynTensorID())
1230b8cf7af9Swren romano       genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
12314fd3a120SPeiming Liu   }
12324fd3a120SPeiming Liu
12331328bb6eSPeiming Liu   return std::make_pair(loop, isSingleCond);
1234c8d5dcb0SAart Bik }
1235c8d5dcb0SAart Bik
1236c8d5dcb0SAart Bik /// Ends a single loop in the current sequence. Returns the new needsUniv.
1237fbe61130SAart Bik static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
123898ce2debSAart Bik                     LatPointId li, bool needsUniv, bool isSingleCond) {
12395fd9d801SPeiming Liu   // Either a for-loop or a while-loop that iterates over a slice.
124098ce2debSAart Bik   if (isSingleCond) {
12415fd9d801SPeiming Liu     // Any iteration creates a valid lex insert.
1242047399c2SAart Bik     if (env.isReduc() && env.isValidLexInsert())
1243047399c2SAart Bik       env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
12445fd9d801SPeiming Liu   } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
12455fd9d801SPeiming Liu     // End a while-loop.
124698ce2debSAart Bik     finalizeWhileOp(env, rewriter, needsUniv);
1247b0f8057eSPeiming Liu   } else {
1248b0f8057eSPeiming Liu     needsUniv = false;
1249c8d5dcb0SAart Bik   }
12508109d5e9SAart Bik   env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
125136fd2875SAart Bik     env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
12521a36588eSKazu Hirata     return std::nullopt;
125375ac294bSPeiming Liu   });
1254b0f8057eSPeiming Liu   return needsUniv;
1255c8d5dcb0SAart Bik }
1256c8d5dcb0SAart Bik
1257c8d5dcb0SAart Bik /// Ends a loop sequence at the given level.
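/// This mirrors startLoopSeq: it exits the emitter's current loop sequence
/// and then unwinds the invariant and access-pattern-expansion bookkeeping
/// (both with isStart=false).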
12585fd9d801SPeiming Liu static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
125998ce2debSAart Bik                        unsigned at) {
126098ce2debSAart Bik   assert(!env.getLoopVar(at));
12615fd9d801SPeiming Liu   env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
12627373cabcSAart Bik   // Unmark bookkeeping of invariants and loop index.
1263c5a1732cSAart Bik   genInvariants(env, builder, exp, at, /*isStart=*/false);
12644f2ec7f9SAart Bik   // Finalize access pattern expansion for sparse tensor output.
1265c5a1732cSAart Bik   genExpand(env, builder, at, /*isStart=*/false);
1266c8d5dcb0SAart Bik }
1267c8d5dcb0SAart Bik
1268a2c9d4bbSAart Bik /// Recursively generates code while computing iteration lattices in order
1269a2c9d4bbSAart Bik /// to manage the complexity of implementing co-iteration over unions
1270a2c9d4bbSAart Bik /// and intersections of sparse iteration spaces.
1271b8cf7af9Swren romano static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
1272c5a1732cSAart Bik                     LoopId curr) {
1273c5a1732cSAart Bik   assert(curr == env.getCurrentDepth());
127498ce2debSAart Bik
1275a2c9d4bbSAart Bik   // At each leaf, assign remaining tensor (sub)expression to output tensor.
1276c5a1732cSAart Bik   if (curr == env.getLoopNum()) {
1277067bebb5SAart Bik     Value rhs = genExp(env, rewriter, exp);
127898f93e3bSAart Bik     genTensorStore(env, rewriter, exp, rhs);
1279a2c9d4bbSAart Bik     return;
1280a2c9d4bbSAart Bik   }
1281a2c9d4bbSAart Bik
128298ce2debSAart Bik   // Construct iteration lattices for current loop index.
1283b8cf7af9Swren romano   const LatSetId lts =
1284c5a1732cSAart Bik       env.merger().optimizeSet(env.merger().buildLattices(exp, curr));
1285a2c9d4bbSAart Bik
1286c8d5dcb0SAart Bik   // Start a loop sequence.
1287c5a1732cSAart Bik   bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);
1288c8d5dcb0SAart Bik
1289*c4420257SPeiming Liu   // When using sparse-iterator-based loops, we only need one loop, as
1290*c4420257SPeiming Liu   // opposed to a loop sequence, to cover all the iteration spaces.
1291*c4420257SPeiming Liu   const unsigned lsize = env.set(lts).size();
1292*c4420257SPeiming Liu   if (env.generatingSparseIterator()) {
1293*c4420257SPeiming Liu     // Get the largest lattice point and start a loop.
1294*c4420257SPeiming Liu     const LatPointId li = env.set(lts)[0];
1295*c4420257SPeiming Liu     auto [loop, isSingleCond] =
1296*c4420257SPeiming Liu         startLoop(env, rewriter, curr, li, lsize, needsUniv);
1297*c4420257SPeiming Liu     assert(isSingleCond == llvm::isa<IterateOp>(loop));
1298067bebb5SAart Bik     // We cannot change this to `for (const LatPointId li : env.set(lts))`
12997c7c10a0Swren romano     // because the loop body causes data-movement which invalidates
13007c7c10a0Swren romano     // the iterator.
1301*c4420257SPeiming Liu     for (unsigned j = 0; j < lsize; j++) {
1302*c4420257SPeiming Liu       const LatPointId lj = env.set(lts)[j];
1303*c4420257SPeiming Liu       const ExprId ej = env.lat(lj).exp;
1304*c4420257SPeiming Liu       // Recurse into body of each branch.
1305*c4420257SPeiming Liu       if (!isSingleCond) {
1306*c4420257SPeiming Liu         env.genLoopBoundary([&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
1307*c4420257SPeiming Liu           genCoIterationCase(env, rewriter, /*caseIdx*/ j, li, lj, reduc);
1308*c4420257SPeiming Liu           genStmt(env, rewriter, ej, curr + 1);
1309*c4420257SPeiming Liu           // TODO: handle yield values.
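          // Reductions carried through coiterate case regions are not
          // implemented yet, hence the assert below.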
1310*c4420257SPeiming Liu           assert(reduc.empty() && "Not Implemented");
1311*c4420257SPeiming Liu           rewriter.create<sparse_tensor::YieldOp>(env.op().getLoc());
1312*c4420257SPeiming Liu           return std::nullopt;
1313*c4420257SPeiming Liu         });
1314*c4420257SPeiming Liu         // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1315*c4420257SPeiming Liu       } else {
1316*c4420257SPeiming Liu         genStmt(env, rewriter, ej, curr + 1);
1317*c4420257SPeiming Liu       }
1318*c4420257SPeiming Liu     }
1319*c4420257SPeiming Liu     // End a loop.
1320*c4420257SPeiming Liu     needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1321*c4420257SPeiming Liu   } else {
1322*c4420257SPeiming Liu     // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1323a2c9d4bbSAart Bik     for (unsigned i = 0; i < lsize; i++) {
1324b8cf7af9Swren romano       const LatPointId li = env.set(lts)[i];
13257c7c10a0Swren romano       // Start a loop.
1326*c4420257SPeiming Liu       auto [loop, isSingleCond] =
1327*c4420257SPeiming Liu           startLoop(env, rewriter, curr, li, lsize, needsUniv);
1328a2c9d4bbSAart Bik
1329a2c9d4bbSAart Bik       // Visit all lattice points with Li >= Lj to generate the
1330a2c9d4bbSAart Bik       // loop-body, possibly with if statements for coiteration.
1331fbe61130SAart Bik       Value redInput = env.getReduc();
1332384049a7SAart Bik       Value cntInput = env.getExpandCount();
1333384049a7SAart Bik       Value insInput = env.getInsertionChain();
1334fc5d8fceSPeiming Liu       Value validIns = env.getValidLexInsert();
1335067bebb5SAart Bik       // We cannot change this to `for (const LatPointId lj : env.set(lts))`
13367c7c10a0Swren romano       // because the loop body causes data-movement which invalidates the
13377c7c10a0Swren romano       // iterator.
1338a2c9d4bbSAart Bik       for (unsigned j = 0; j < lsize; j++) {
1339b8cf7af9Swren romano         const LatPointId lj = env.set(lts)[j];
1340b8cf7af9Swren romano         const ExprId ej = env.lat(lj).exp;
1341384049a7SAart Bik         if (li == lj || env.merger().latGT(li, lj)) {
1342a2c9d4bbSAart Bik           // Recurse into body of each branch.
13431328bb6eSPeiming Liu           if (!isSingleCond) {
1344c5a1732cSAart Bik             scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
1345c5a1732cSAart Bik             genStmt(env, rewriter, ej, curr + 1);
1346372d88b0SPeiming Liu             endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1347a2c9d4bbSAart Bik           } else {
1348c5a1732cSAart Bik             genStmt(env, rewriter, ej, curr + 1);
1349a2c9d4bbSAart Bik           }
1350a2c9d4bbSAart Bik         }
1351a2c9d4bbSAart Bik       }
1352a2c9d4bbSAart Bik
1353c8d5dcb0SAart Bik       // End a loop.
1354c5a1732cSAart Bik       needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1355a2c9d4bbSAart Bik     }
1356*c4420257SPeiming Liu   }
1357a2c9d4bbSAart Bik
1358c8d5dcb0SAart Bik   // End a loop sequence.
1359c5a1732cSAart Bik   endLoopSeq(env, rewriter, exp, curr);
1360c5a1732cSAart Bik   assert(curr == env.getCurrentDepth());
1361a2c9d4bbSAart Bik }
1362a2c9d4bbSAart Bik
1363727a63e0SAart Bik /// Converts the result computed by the sparse kernel into the required form.
1364fbe61130SAart Bik static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
1365384049a7SAart Bik   linalg::GenericOp op = env.op();
1366b4db15a9SAlexander Belyaev   OpOperand *lhs = op.getDpsInitOperand(0);
13675661647eSAart Bik   Value tensor = lhs->get();
13685661647eSAart Bik   Type resType = tensor.getType();
1369f66e5769SAart Bik   if (getSparseTensorEncoding(resType)) {
1370f66e5769SAart Bik     // The sparse tensor rematerializes from the original sparse tensor's
13715661647eSAart Bik     // underlying sparse storage format. For an insertion chain, the
13725661647eSAart Bik     // tensor materializes from the chain with 'hasInserts' enabled.
1373384049a7SAart Bik     bool hasInserts = false;
1374384049a7SAart Bik     if (Value chain = env.getInsertionChain()) {
1375384049a7SAart Bik       hasInserts = true;
1376384049a7SAart Bik       tensor = chain;
1377384049a7SAart Bik     }
13785661647eSAart Bik     rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
137936b66ab9SAart Bik   } else {
1380f66e5769SAart Bik     // To rematerialize a non-annotated tensor, simply load it
138136b66ab9SAart Bik     // from the bufferized value.
1382e7b4c93fSPeiming Liu     Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
138357470abcSAlexander Belyaev     rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
138436b66ab9SAart Bik   }
1385727a63e0SAart Bik }
1386727a63e0SAart Bik
13875da21338SAart Bik //===----------------------------------------------------------------------===//
1388c43e6274STim Harvey // Sparsifier rewriting methods.
13895da21338SAart Bik //===----------------------------------------------------------------------===//
13905da21338SAart Bik
1391a2c9d4bbSAart Bik namespace {
139298f93e3bSAart Bik
1393a2c9d4bbSAart Bik /// Sparse rewriting rule for generic Linalg operation.
1394a2c9d4bbSAart Bik struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1395a2c9d4bbSAart Bik public:
1396a2c9d4bbSAart Bik   GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1397a2c9d4bbSAart Bik       : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1398a2c9d4bbSAart Bik
1399a2c9d4bbSAart Bik   LogicalResult matchAndRewrite(linalg::GenericOp op,
1400a2c9d4bbSAart Bik                                 PatternRewriter &rewriter) const override {
1401740582faSAart Bik     // Only accept single-output operations with pure tensor semantics.
14020a8e3dd4SMatthias Springer     if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
1403740582faSAart Bik       return failure();
1404740582faSAart Bik
1405740582faSAart Bik     // Only accept trivial affine indices.
1406740582faSAart Bik     if (hasNonTrivialAffineOnSparseOut(op))
140711d75076SBenjamin Kramer       return failure();
140898f93e3bSAart Bik
1409067bebb5SAart Bik     // Only accept scheduled loops.
141006a65ce5SPeiming Liu     if (!op->hasAttr("sorted")) {
141106a65ce5SPeiming Liu       return rewriter.notifyMatchFailure(
141206a65ce5SPeiming Liu           op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
141306a65ce5SPeiming Liu               "before sparsification.");
141406a65ce5SPeiming Liu     }
141598ce2debSAart Bik
1416ccd923e3SPeiming Liu     // Must have been demapped as well if the generic op is sorted.
1417ccd923e3SPeiming Liu     assert(!hasAnyNonIdentityOperandsOrResults(op));
141806a65ce5SPeiming Liu
141998f93e3bSAart Bik     // Sets up a code generation environment.
1420b8cf7af9Swren romano     const unsigned numTensors = op->getNumOperands();
1421b8cf7af9Swren romano     const unsigned numLoops = op.getNumLoops();
1422ccd923e3SPeiming Liu     bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
14232b21327fSPeiming Liu     // If we have an indexing map like (d0) -> (0, d0), there might be more
14242b21327fSPeiming Liu     // levels than loops because of the constant index; that means we cannot
14252b21327fSPeiming Liu     // use numLoops as the upper bound for the ranks of all tensors.
14262b21327fSPeiming Liu     // TODO: Constant indices are currently not supported on sparse tensors,
14272b21327fSPeiming Liu     // but are allowed on non-annotated dense tensors. Supporting them would
14282b21327fSPeiming Liu     // also be required for sparse tensor slice rank reduction.
14292b21327fSPeiming Liu     Level maxLvlRank = 0;
14302b21327fSPeiming Liu     for (auto operand : op.getOperands()) {
14315550c821STres Popp       if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
14322b21327fSPeiming Liu         maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
14332b21327fSPeiming Liu       }
14342b21327fSPeiming Liu     }
143506a65ce5SPeiming Liu
1436b8cf7af9Swren romano     // Detects sparse annotations and translates the per-level sparsity
143798f93e3bSAart Bik     // information for all tensors to loop indices in the kernel.
1438067bebb5SAart Bik     CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
1439ccd923e3SPeiming Liu     if (!findSparseAnnotations(env, needIdxRed))
1440bf9ef3efSAart Bik       return failure();
1441a2c9d4bbSAart Bik
1442378f1885SAart Bik     // Only standard reduction operations (add, sub, or, xor) that can be
1443378f1885SAart Bik     // sparsified by merely reducing the stored values are admissible. More
1444378f1885SAart Bik     // elaborate reduction operations (such as mul, and, min, max) would need
1445378f1885SAart Bik     // to know whether implicit zeros occur as well. They can still be
1446378f1885SAart Bik     // implemented with a custom reduction operation, accepted here as well.
1447378f1885SAart Bik     if (op.getNumReductionLoops() > 0) {
1448378f1885SAart Bik       Operation *yield = op.getRegion().front().getTerminator();
1449378f1885SAart Bik       assert(isa<linalg::YieldOp>(yield));
1450378f1885SAart Bik       Operation *redop = yield->getOperand(0).getDefiningOp();
1451378f1885SAart Bik       if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) &&
1452378f1885SAart Bik           !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) &&
1453378f1885SAart Bik           !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) &&
1454378f1885SAart Bik           !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) &&
1455378f1885SAart Bik           !isa<ReduceOp>(redop)) {
1456378f1885SAart Bik         return failure();
1457378f1885SAart Bik       }
1458378f1885SAart Bik     }
1459378f1885SAart Bik
1460a7bf2e55SPeiming Liu     // Constructs the tensor expression tree from `op`; returns failure if the
1461a7bf2e55SPeiming Liu     // tree cannot be built or the tensor expression is inadmissible.
1462a7bf2e55SPeiming Liu     if (failed(env.initTensorExp()))
146355a1d50fSPeiming Liu       return failure();
146498f93e3bSAart Bik
1465b0f8057eSPeiming Liu     // Recursively generates code if admissible.
14664a653b4dSPeiming Liu     env.startEmit(options.sparseEmitStrategy);
146798f93e3bSAart Bik     genBuffers(env, rewriter);
14681328bb6eSPeiming Liu     // TODO: Constant affine expressions should be handled differently when
14692cb99df6SYinying Li     // using slice-based codegen; it does not matter now because we already
1470c5a1732cSAart Bik     // reject constant expressions at an earlier stage.
147198f93e3bSAart Bik     genInitConstantDenseAddress(env, rewriter);
1472b8cf7af9Swren romano     genStmt(env, rewriter, env.getExprId(), 0);
147398f93e3bSAart Bik     genResult(env, rewriter);
1474a2c9d4bbSAart Bik     return success();
1475a2c9d4bbSAart Bik   }
1476a2c9d4bbSAart Bik
1477a2c9d4bbSAart Bik private:
1478a2c9d4bbSAart Bik   /// Options to control sparse code generation.
1479a2c9d4bbSAart Bik   SparsificationOptions options;
1480a2c9d4bbSAart Bik };
1481a2c9d4bbSAart Bik
1482a2c9d4bbSAart Bik } // namespace
1483a2c9d4bbSAart Bik
1484a2c9d4bbSAart Bik /// Populates the given patterns list with rewriting rules required for
1485a2c9d4bbSAart Bik /// the sparsification of linear algebra operations.
1486a2c9d4bbSAart Bik void mlir::populateSparsificationPatterns(
1487a2c9d4bbSAart Bik     RewritePatternSet &patterns, const SparsificationOptions &options) {
1488a2c9d4bbSAart Bik   patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1489a2c9d4bbSAart Bik }
1490
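// A minimal usage sketch (illustrative only, not part of this file's API
// surface): a pass would typically wire these patterns into the greedy
// rewrite driver, e.g.
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));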