1 //===- Sparsification.cpp - Implementation of sparsification --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements converting sparse tensor types to actual sparse code. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Utils/CodegenEnv.h" 14 #include "Utils/CodegenUtils.h" 15 #include "Utils/LoopEmitter.h" 16 17 #include "mlir/Dialect/Affine/IR/AffineOps.h" 18 #include "mlir/Dialect/Arith/IR/Arith.h" 19 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 20 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 21 #include "mlir/Dialect/Func/IR/FuncOps.h" 22 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 23 #include "mlir/Dialect/Linalg/IR/Linalg.h" 24 #include "mlir/Dialect/Linalg/Utils/Utils.h" 25 #include "mlir/Dialect/MemRef/IR/MemRef.h" 26 #include "mlir/Dialect/SCF/IR/SCF.h" 27 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 28 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 29 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" 30 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 31 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 32 #include "mlir/Dialect/Tensor/IR/Tensor.h" 33 #include "mlir/IR/AffineExprVisitor.h" 34 #include "mlir/IR/Matchers.h" 35 #include "mlir/IR/TensorEncoding.h" 36 #include "llvm/ADT/SmallBitVector.h" 37 38 #include <optional> 39 40 using namespace mlir; 41 using namespace mlir::sparse_tensor; 42 43 //===----------------------------------------------------------------------===// 44 // Sparsifier analysis methods. 45 //===----------------------------------------------------------------------===// 46 47 /// Returns true iff affine expression is invariant. Sets the 48 /// parameter `isCurrentLoop` when expression just became invariant. 49 static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) { 50 switch (a.getKind()) { 51 case AffineExprKind::DimId: { 52 const LoopId i = cast<AffineDimExpr>(a).getPosition(); 53 if (i + 1 == curr) { 54 isCurrentLoop = true; 55 return true; // becomes invariant at current loop 56 } 57 return i < curr; // invariant when already generated 58 } 59 case AffineExprKind::Add: 60 case AffineExprKind::Mul: { 61 auto binOp = cast<AffineBinaryOpExpr>(a); 62 return isInvariantAffine(binOp.getLHS(), curr, isCurrentLoop) && 63 isInvariantAffine(binOp.getRHS(), curr, isCurrentLoop); 64 } 65 default: { 66 assert(isa<AffineConstantExpr>(a)); 67 return true; 68 } 69 } 70 } 71 72 /// Helper method to inspect affine expressions. Rejects cases where the 73 /// same index is used more than once. Also rejects compound affine 74 /// expressions in sparse dimensions. 
75 static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
76 LevelType lt, bool setLvlFormat = true) {
77 switch (a.getKind()) {
78 case AffineExprKind::DimId: {
79 const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
80 if (!isUndefLT(merger.getLvlType(tid, idx)))
81 return false; // used more than once
82 if (setLvlFormat)
83 merger.setLevelAndType(tid, idx, lvl, lt);
84 return true;
85 }
86 case AffineExprKind::Add:
87 case AffineExprKind::Mul:
88 case AffineExprKind::Constant: {
89 assert(lt.hasDenseSemantic());
90 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
91 // We do not set the dim level format for an affine expression like
92 // d0 + d1 on either loop index d0 or d1. We continue the recursion
93 // merely to check whether the current affine expression is admissible.
94 return findAffine(merger, tid, lvl, binOp.getLHS(), lt, false) &&
95 findAffine(merger, tid, lvl, binOp.getRHS(), lt, false);
96 }
97 // Falls through when it is a constant affine expression.
98 return true;
99 }
100 default:
101 return false;
102 }
103 }
104 
105 /// Helper method to inspect affine expressions for index variable reduction
106 /// based codegen. It finds the dependent index set for all tensor levels in the
107 /// current expression we are generating.
108 ///
109 /// For example, when handling A[i+j][j+k], we build the two way mapping in
110 /// merger between (tensor, level) pairs and their dependent index variable set:
111 /// A_0 <=> [i, j] and A_1 <=> [j, k]
112 ///
113 /// It rejects cases (returns false):
114 /// 1st, when the same index is used more than once, e.g., A[i+j][i]
115 /// 2nd, when multiplication is used in the non-trivial index expression.
116 /// 3rd, when a constant operand is used in the non-trivial index expression.
117 ///
118 /// TODO: constant should be easy to handle.
119 static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
120 AffineExpr a, LevelType lt, bool isSubExp = false,
121 int64_t coefficient = 1) {
122 switch (a.getKind()) {
123 case AffineExprKind::DimId: {
124 // Only allow positive coefficients on AffineDimExpr.
125 if (coefficient <= 0)
126 return false;
127 
128 const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
129 if (!isUndefLT(merger.getLvlType(tensor, idx)))
130 return false; // used more than once, e.g., A[i][i]
131 
132 // TODO: Generalize the following two cases. A[i] (with a trivial index
133 // expression) can be treated as a special affine index expression. We do
134 // not necessarily need to differentiate them.
135 if (!isSubExp) {
136 assert(coefficient == 1);
137 merger.setLevelAndType(tensor, idx, lvl, lt);
138 }
139 
140 if (isSubExp) {
141 // The current loop appears in more than one affine expression on the
142 // same tensor. We cannot handle this case, e.g., A[i+j][i+k], where `i`
143 // is used twice.
144 if (merger.hasDependentLvl(idx, tensor)) {
145 // TODO: This can be supported by coiterating slices if the loop idx
146 // appears in affine indices of different tensors, or by taking slices
147 // along multiple dimensions when it is on the same tensor.
148 // E.g.,
149 // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
150 // d0_1 = getNextSliceOffset t0 along lvl0
151 // d0_2 = getNextSliceOffset t1 along lvl0
152 // if d0_1 == d0_2 then d0 = d0_1 = d0_2
153 // else increase min(d0_1, d0_2).
154 return false; 155 } 156 merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient); 157 } 158 return true; 159 } 160 case AffineExprKind::Constant: 161 case AffineExprKind::Mul: { 162 // TODO: Support index expression like `2 * d0`, we now only support more 163 // complicated cases like `2 * d0 + d1`. 164 if (!isSubExp) 165 return false; 166 167 // TODO: Support Constant AffineExp for slice-based codegen 168 if (isa<AffineConstantExpr>(a)) 169 llvm_unreachable("Not yet implemented"); 170 171 auto binOp = cast<AffineBinaryOpExpr>(a); 172 auto lhs = binOp.getLHS(), rhs = binOp.getRHS(); 173 if (isa<AffineConstantExpr>(rhs)) 174 std::swap(lhs, rhs); 175 // Must be in form of `constant * d`. 176 assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs)); 177 int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue(); 178 return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient); 179 } 180 case AffineExprKind::Add: { 181 auto binOp = cast<AffineBinaryOpExpr>(a); 182 return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, true) && 183 findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, true); 184 } 185 default: 186 return false; 187 } 188 } 189 190 /// Gets the total number of compound affine expressions in the 191 /// `getMatchingIndexingMap` for the given tensor. For the following inputs: 192 /// 193 /// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed) 194 /// 195 /// Returns 1 (because the first level is compressed and its corresponding 196 /// indexing-expression is `d0 + d1`) 197 static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map, 198 Value tensor) { 199 // The `tensor` is not guaranteed to have `RankedTensorType`, therefore 200 // we can't use `getRankedTensorType`/`getSparseTensorType` here. 201 // However, we don't need to handle `StorageSpecifierType`, so we 202 // can use `SparseTensorType` once we guard against non-tensors. 203 const auto rtp = dyn_cast<RankedTensorType>(tensor.getType()); 204 if (!rtp) 205 return 0; 206 const SparseTensorType stt(rtp); 207 208 const Level lvlRank = stt.getLvlRank(); 209 const auto exprs = map.getResults(); 210 assert(static_cast<Dimension>(exprs.size()) == lvlRank && 211 "AffineMap does not have dimension-rank many results"); 212 unsigned num = 0; 213 for (Level l = 0; l < lvlRank; l++) { 214 if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic()) 215 num++; 216 } 217 return num; 218 } 219 220 /// Gets the total number of sparse levels with compound affine 221 /// expressions, summed over all operands of the `GenericOp`. 222 static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) { 223 unsigned num = 0; 224 for (OpOperand &t : op->getOpOperands()) 225 num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t), 226 t.get()); 227 return num; 228 } 229 230 // Returns true iff output has nontrivial affine indices. 231 static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) { 232 OpOperand *out = op.getDpsInitOperand(0); 233 if (getSparseTensorType(out->get()).isAllDense()) 234 return false; 235 return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out), 236 out->get()); 237 } 238 239 /// Helper method to inspect sparse encodings in the tensor types. 240 /// Fills the per-dimension sparsity information for all tensors. 241 /// Returns true if the sparse annotations and affine subscript 242 /// expressions of all tensors are admissible. 
Returns false if
243 /// no annotations are found or inadmissible constructs occur.
244 /// We currently support two different ways to handle non-trivial index
245 /// expressions on sparse tensors, and they accept different affine expressions.
246 /// The dependent index reduction-based approach currently only
247 /// supports affine addition index expressions.
248 static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
249 bool annotated = false;
250 for (OpOperand &t : env.op()->getOpOperands()) {
251 const TensorId tid = env.makeTensorId(t.getOperandNumber());
252 const auto map = env.op().getMatchingIndexingMap(&t);
253 const auto enc = getSparseTensorEncoding(t.get().getType());
254 if (enc)
255 annotated = true;
256 const Level lvlRank = map.getNumResults();
257 assert(!enc || lvlRank == enc.getLvlRank());
258 assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
259 // We only need to do index reduction if there is at least one
260 // non-trivial index expression on sparse levels. If all non-trivial
261 // index expressions are on dense levels, we can efficiently rely on
262 // random access to locate the elements.
263 bool needIdxReduc =
264 enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
265 // If the current tensor being inspected requires an affine index, it
266 // needs to be sliced.
267 for (Level l = 0; l < lvlRank; l++) {
268 const AffineExpr a = map.getResult(l);
269 const LevelType lt = enc.getLvlType(l);
270 if (idxReducBased && needIdxReduc) {
271 if (!findDepIdxSet(env.merger(), tid, l, a, lt))
272 return false; // inadmissible affine expression
273 } else {
274 if (!findAffine(env.merger(), tid, l, a, lt))
275 return false; // inadmissible affine expression
276 }
277 }
278 }
279 return annotated;
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // Sparsifier synthesis methods (statements and expressions).
284 //===----------------------------------------------------------------------===//
285 
286 /// Local bufferization of all dense and sparse data structures.
287 static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
288 linalg::GenericOp op = env.op();
289 Location loc = op.getLoc();
290 assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
291 
292 SmallVector<Range, 4> loopRange =
293 llvm::cast<linalg::LinalgOp>(op.getOperation())
294 .createLoopRanges(builder, loc);
295 
296 env.emitter().initializeLoopEmit(
297 builder, loc,
298 /// Generates buffer for the output tensor.
299 /// Note that all sparse kernels assume that when all elements are written
300 /// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized
301 /// to all zeroes and only nonzero values are computed and written out.
302 /// For updates (viz. x(i) += y(i) * z(i)), only nonzero values are used
303 /// for the updates and no assumption on the original contents of the
304 /// output buffer is necessary.
305 [&op](OpBuilder &builder, Location loc, Value memref,
306 Value tensor) -> Value {
307 // Must not be a sparse tensor.
308 assert(!getSparseTensorEncoding(tensor.getType()));
309 // Two output tensor references should point to the same object.
310 OpOperand *lhs = op.getDpsInitOperand(0);
311 assert(lhs->get() == tensor);
312 // An output tensor can simply materialize from the buffer of the tensor
313 // that appears in the outs() clause.
For updates, this has the 314 // advantage that only the nonzero value are involved in the 315 // computation, keeping the operation O(nnz). In all other cases, we are 316 // forced to zero out the buffer to enforce the assumption above, which 317 // may negatively impact running complexity (viz. O(n^2 + nnz) vs. 318 // O(nnz) for matrices). 319 // TODO: use better analysis to avoid zeroing out the buffer? 320 bool isInit = op.isInitTensor(lhs); 321 Value init = memref; 322 if (!isInit) { 323 Value zero = constantZero(builder, loc, 324 getElementTypeOrSelf(tensor.getType())); 325 builder.create<linalg::FillOp>(loc, ValueRange{zero}, 326 ValueRange{init}); 327 } 328 return init; 329 }, 330 [&loopRange](OpBuilder &b, Location loc, Level l) { 331 assert(l < loopRange.size()); 332 return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size); 333 }); 334 } 335 336 /// Generates index for load/store on sparse tensor. 337 static Value genIndex(CodegenEnv &env, OpOperand *t) { 338 const auto map = env.op().getMatchingIndexingMap(t); 339 const auto stt = getSparseTensorType(t->get()); 340 const Level lvlRank = stt.getLvlRank(); 341 assert(static_cast<Level>(map.getNumResults()) == lvlRank); 342 const AffineExpr a = map.getResult(lvlRank - 1); 343 assert(a.getKind() == AffineExprKind::DimId); 344 const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition()); 345 return env.getLoopVar(idx); 346 } 347 348 /// Generates subscript for load/store on a dense or sparse tensor. 349 static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t, 350 SmallVectorImpl<Value> &args) { 351 const Location loc = env.op().getLoc(); 352 const TensorId tid = env.makeTensorId(t->getOperandNumber()); 353 const auto map = env.op().getMatchingIndexingMap(t); 354 const auto stt = getSparseTensorType(t->get()); 355 if (stt.hasEncoding()) { 356 // For sparse tensors we only push the last-level's position onto `args`. 357 const auto pos = env.emitter().getValPosits(tid); 358 assert(!pos.empty()); 359 args.append(pos); 360 // Simply returns the tensor to extract value using iterators. 361 if (env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator) 362 return t->get(); 363 } else { 364 // For dense tensors we push all level's coordinates onto `args`. 365 const Level lvlRank = stt.getLvlRank(); 366 assert(static_cast<Level>(map.getNumResults()) == lvlRank); 367 for (Level l = 0; l < lvlRank; l++) { 368 const auto lvlExpr = map.getResult(l); 369 const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr); 370 args.push_back(lvlCrd); 371 } 372 } 373 return env.emitter().getValBuffer()[tid]; 374 } 375 376 /// Generates insertion code to implement dynamic tensor load. 377 static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder, 378 OpOperand *t) { 379 linalg::GenericOp op = env.op(); 380 Location loc = op.getLoc(); 381 // Direct lexicographic coordinate order, tensor loads as zero. 382 if (!env.isExpand()) { 383 Type tp = getElementTypeOrSelf(t->get().getType()); 384 return constantZero(builder, loc, tp); 385 } 386 // Load from expanded access pattern. 387 Value index = genIndex(env, t); 388 return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index); 389 } 390 391 /// Generates insertion code to implement dynamic tensor load for reduction. 
392 static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder, 393 OpOperand *t) { 394 linalg::GenericOp op = env.op(); 395 Location loc = op.getLoc(); 396 Value identity = env.getCustomRedId(); 397 // Direct lexicographic coordinate order, tensor loads as identity. 398 if (!env.isExpand()) 399 return identity; 400 // Load from expanded access pattern if filled, identity otherwise. 401 Value values = env.getExpandValues(); 402 Value filled = env.getExpandFilled(); 403 Value index = genIndex(env, t); 404 Value isFilled = builder.create<memref::LoadOp>(loc, filled, index); 405 Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index); 406 return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity); 407 } 408 409 static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond, 410 Value sparseOut, ValueRange ivs, Value v) { 411 scf::IfOp condInsert = 412 builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true); 413 // True branch. 414 builder.setInsertionPointToStart(condInsert.thenBlock()); 415 Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs); 416 builder.create<scf::YieldOp>(loc, res); 417 // False branch. 418 builder.setInsertionPointToStart(condInsert.elseBlock()); 419 builder.create<scf::YieldOp>(loc, sparseOut); 420 // Value assignment. 421 builder.setInsertionPointAfter(condInsert); 422 return condInsert.getResult(0); 423 } 424 425 /// Generates insertion code to implement dynamic tensor store. 426 static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t, 427 Value rhs) { 428 linalg::GenericOp op = env.op(); 429 Location loc = op.getLoc(); 430 // Direct insertion in lexicographic coordinate order. 431 if (!env.isExpand()) { 432 const LoopId numLoops = op.getRank(t); 433 // Retrieves the first `numLoop` induction variables. 434 SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end( 435 env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops)); 436 Value chain = env.getInsertionChain(); 437 if (env.isValidLexInsert()) { 438 // Generates runtime check for a valid lex during reduction, 439 // to avoid inserting the identity value for empty reductions. 440 // if (validLexInsert) then 441 // insert(rhs) into chain 442 // return updated chain 443 // else 444 // return unmodified chain 445 Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(), 446 chain, ivs, rhs); 447 env.updateInsertionChain(out); 448 } else { 449 Value sparseOut; 450 if (!hasAnySparseType(env.op().getInputs().getTypes())) { 451 // This is an all-dense -> sparse kernel, test rhs != 0 before 452 // insertion. 453 Value nz = genIsNonzero(builder, loc, rhs); 454 sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs); 455 } else { 456 sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs); 457 } 458 // Generates regular insertion chain. 459 env.updateInsertionChain(sparseOut); 460 } 461 return; 462 } 463 // Generates insertion code along expanded access pattern. 464 // if (!expFilled[i]) then 465 // expFilled[i] = true 466 // expAdded[inserts++] = i 467 // endif 468 // values[i] = rhs 469 Value values = env.getExpandValues(); 470 Value filled = env.getExpandFilled(); 471 Value added = env.getExpandAdded(); 472 Value count = env.getExpandCount(); 473 Value index = genIndex(env, t); 474 Value fval = constantI1(builder, loc, false); 475 Value tval = constantI1(builder, loc, true); 476 // If statement. 
477 Value isFilled = builder.create<memref::LoadOp>(loc, filled, index); 478 Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 479 isFilled, fval); 480 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond, 481 /*else=*/true); 482 // True branch. 483 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 484 builder.create<memref::StoreOp>(loc, tval, filled, index); 485 builder.create<memref::StoreOp>(loc, index, added, count); 486 Value one = constantIndex(builder, loc, 1); 487 Value add = builder.create<arith::AddIOp>(loc, count, one); 488 builder.create<scf::YieldOp>(loc, add); 489 // False branch. 490 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 491 builder.create<scf::YieldOp>(loc, count); 492 builder.setInsertionPointAfter(ifOp); 493 // Value assignment. 494 env.updateExpandCount(ifOp.getResult(0)); 495 builder.create<memref::StoreOp>(loc, rhs, values, index); 496 } 497 498 /// Generates a load on a dense or sparse tensor. 499 static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) { 500 // Test if the load was hoisted to a higher loop nest. 501 Value val = env.exp(exp).val; 502 if (val) 503 return val; 504 // Get tensor operand. 505 linalg::GenericOp op = env.op(); 506 Location loc = op.getLoc(); 507 OpOperand *t = &op->getOpOperand(env.exp(exp).tensor); 508 // Fold binary-valued tensor into explicit value. 509 const auto stt = getSparseTensorType(t->get()); 510 if (auto explVal = stt.getExplicitVal()) 511 return genValFromAttr(builder, loc, explVal); 512 // Load during insertion. 513 if (env.isSparseOutput(t)) { 514 if (env.isCustomReduc()) 515 return genInsertionLoadReduce(env, builder, t); 516 return genInsertionLoad(env, builder, t); 517 } 518 519 // Actual load. 520 SmallVector<Value> args; 521 Value ptr = genSubscript(env, builder, t, args); 522 if (llvm::isa<TensorType>(ptr.getType())) { 523 assert(env.options().sparseEmitStrategy == 524 SparseEmitStrategy::kSparseIterator && 525 args.size() == 1); 526 return builder.create<ExtractValOp>(loc, ptr, args.front()); 527 } 528 return builder.create<memref::LoadOp>(loc, ptr, args); 529 } 530 531 /// Generates a store on a dense or sparse tensor. 532 static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp, 533 Value rhs) { 534 // Only unary and binary are allowed to return an uninitialized rhs 535 // to indicate missing output. Or otherwise a custom reduction that 536 // received no value to accumulate. 537 if (!rhs) { 538 assert(env.exp(exp).kind == TensorExp::Kind::kUnary || 539 env.exp(exp).kind == TensorExp::Kind::kBinary || 540 env.exp(exp).kind == TensorExp::Kind::kReduce); 541 return; 542 } 543 // Test if this is a scalarized reduction. 544 if (env.isReduc()) { 545 env.updateReduc(rhs); 546 return; 547 } 548 // Regular store. 549 linalg::GenericOp op = env.op(); 550 Location loc = op.getLoc(); 551 OpOperand *t = op.getDpsInitOperand(0); 552 if (!env.isSparseOutput(t)) { 553 SmallVector<Value> args; 554 Value ptr = genSubscript(env, builder, t, args); 555 builder.create<memref::StoreOp>(loc, rhs, ptr, args); 556 return; 557 } 558 // Store during sparse insertion. 559 if (env.exp(exp).kind != TensorExp::Kind::kSelect) { 560 genInsertionStore(env, builder, t, rhs); 561 return; 562 } 563 // Select operation insertion. 
564 Value chain = env.getInsertionChain(); 565 scf::IfOp ifOp = 566 builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true); 567 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 568 // Existing value was preserved to be used here. 569 assert(env.exp(exp).val); 570 Value v0 = env.exp(exp).val; 571 genInsertionStore(env, builder, t, v0); 572 env.merger().clearExprValue(exp); 573 // Yield modified insertion chain along true branch. 574 Value mchain = env.getInsertionChain(); 575 builder.create<scf::YieldOp>(op.getLoc(), mchain); 576 // Yield original insertion chain along false branch. 577 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 578 builder.create<scf::YieldOp>(loc, chain); 579 // Done with if statement. 580 env.updateInsertionChain(ifOp->getResult(0)); 581 builder.setInsertionPointAfter(ifOp); 582 } 583 584 /// Generates an invariant value. 585 inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) { 586 return env.exp(exp).val; 587 } 588 589 /// Semi-ring branches are simply inlined by the sparsifier. Prior 590 /// analysis has verified that all computations are "local" to the inlined 591 /// branch or otherwise invariantly defined outside the loop nest, with the 592 /// exception of index computations, which need to be relinked to actual 593 /// inlined cloned code. 594 static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block, 595 Value e) { 596 if (auto arg = dyn_cast<BlockArgument>(e)) { 597 // Direct arguments of the original linalg op must be converted 598 // into dense tensor loads. Note that we should not encounter 599 // anything else. This needs to be verified by semi-ring ops. 600 linalg::GenericOp op = env.op(); 601 if (arg.getOwner()->getParentOp() == op) { 602 const TensorId tid = env.makeTensorId(arg.getArgNumber()); 603 OpOperand *t = &op->getOpOperand(tid); 604 assert(!getSparseTensorType(t->get()).hasEncoding()); // dense! 605 SmallVector<Value> args; 606 Value ptr = genSubscript(env, rewriter, t, args); 607 return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args); 608 } 609 } else if (Operation *def = e.getDefiningOp()) { 610 // Handle index computation. 611 if (auto indexOp = dyn_cast<linalg::IndexOp>(def)) 612 return env.getLoopVar(env.makeLoopId(indexOp.getDim())); 613 // When still defined in new body, recurse into operands. 614 if (def->getBlock() == block) { 615 rewriter.setInsertionPoint(def); 616 for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) { 617 rewriter.modifyOpInPlace(def, [&]() { 618 def->setOperand( 619 i, relinkBranch(env, rewriter, block, def->getOperand(i))); 620 }); 621 } 622 } 623 } 624 return e; 625 } 626 627 /// Recursively generates tensor expression. 628 static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) { 629 if (e == ::mlir::sparse_tensor::detail::kInvalidId) 630 return Value(); 631 632 linalg::GenericOp op = env.op(); 633 Location loc = op.getLoc(); 634 const TensorExp &exp = env.exp(e); 635 const auto kind = exp.kind; 636 if (kind == TensorExp::Kind::kTensor) 637 return genTensorLoad(env, rewriter, e); 638 if (kind == TensorExp::Kind::kInvariant) 639 return genInvariantValue(env, e); 640 if (kind == TensorExp::Kind::kLoopVar) 641 return env.getLoopVar(exp.loop); 642 643 if (kind == TensorExp::Kind::kReduce) 644 env.startCustomReduc(e); // enter custom 645 646 // If either lhs/rhs is a synthetic zero, we infer the type for the zero value 647 // based on the type of the other operand. 
648 Value v0, v1; 649 if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId && 650 env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) { 651 v1 = genExp(env, rewriter, exp.children.e1); 652 v0 = constantZero(rewriter, loc, v1.getType()); 653 } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId && 654 env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) { 655 v0 = genExp(env, rewriter, exp.children.e0); 656 v1 = constantZero(rewriter, loc, v0.getType()); 657 } else { 658 v0 = genExp(env, rewriter, exp.children.e0); 659 v1 = genExp(env, rewriter, exp.children.e1); 660 } 661 662 Value ee; 663 if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) { 664 // custom reduce did not receive a value 665 } else { 666 ee = env.merger().buildExp(rewriter, loc, e, v0, v1); 667 if (ee && 668 (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary || 669 kind == TensorExp::Kind::kBinaryBranch || 670 kind == TensorExp::Kind::kReduce || 671 kind == TensorExp::Kind::kSelect)) { 672 OpBuilder::InsertionGuard guard(rewriter); 673 ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee); 674 } 675 } 676 677 if (kind == TensorExp::Kind::kReduce) 678 env.endCustomReduc(); // exit custom 679 680 if (kind == TensorExp::Kind::kSelect) 681 env.merger().setExprValue(e, v0); // Preserve value for later use. 682 683 return ee; 684 } 685 686 /// Hoists loop invariant tensor loads for which indices have been exhausted. 687 static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp, 688 LoopId curr, bool isStart) { 689 if (exp == ::mlir::sparse_tensor::detail::kInvalidId) 690 return; 691 if (env.exp(exp).kind == TensorExp::Kind::kTensor) { 692 // Inspect tensor indices. 693 linalg::GenericOp op = env.op(); 694 OpOperand &t = op->getOpOperand(env.exp(exp).tensor); 695 const auto map = op.getMatchingIndexingMap(&t); 696 const auto stt = getSparseTensorType(t.get()); 697 const Level lvlRank = stt.getLvlRank(); 698 assert(static_cast<Level>(map.getNumResults()) == lvlRank); 699 bool isCurrentLoop = curr == 0; // for scalar tensors 700 for (Level l = 0; l < lvlRank; l++) { 701 const AffineExpr a = map.getResult(l); 702 if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop)) 703 return; // still in play 704 } 705 // All exhausted at current level. 706 if (!isCurrentLoop) 707 return; 708 // Generate code for a scalarized reduction or invariant. Note that 709 // because custom reduction lhs may occur several times in the IR, 710 // we have a built-in safety for only initializing and wrapping-up 711 // the scalarized reduction once. 712 OpOperand *lhs = op.getDpsInitOperand(0); 713 if (lhs == &t) { 714 // Start or end a scalarized reduction. 715 if (isStart) { 716 if (env.isCustomReduc()) { 717 if (!env.isReduc()) 718 env.startReduc(exp, env.getCustomRedId()); 719 } else { 720 env.startReduc(exp, genTensorLoad(env, builder, exp)); 721 } 722 if (env.hasSparseOutput()) 723 env.startValidLexInsert( 724 constantI1(builder, env.op().getLoc(), false)); 725 } else { 726 if (!env.isCustomReduc() || env.isReduc()) 727 genTensorStore(env, builder, exp, env.endReduc()); 728 if (env.hasSparseOutput()) 729 env.endValidLexInsert(); 730 } 731 } else { 732 // Start or end loop invariant hoisting of a tensor load. 
733 if (isStart) { 734 env.merger().setExprValue(exp, genTensorLoad(env, builder, exp)); 735 } else { 736 env.merger().clearExprValue(exp); 737 } 738 } 739 } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant && 740 env.exp(exp).kind != TensorExp::Kind::kLoopVar && 741 env.exp(exp).kind != TensorExp::Kind::kSynZero) { 742 // Traverse into the binary operations. Note that we only hoist 743 // tensor loads, since subsequent MLIR/LLVM passes know how to 744 // deal with all other kinds of derived loop invariants. 745 if (env.exp(exp).kind == TensorExp::Kind::kReduce) 746 env.startCustomReduc(exp); // enter custom 747 const ExprId e0 = env.exp(exp).children.e0; 748 const ExprId e1 = env.exp(exp).children.e1; 749 genInvariants(env, builder, e0, curr, isStart); 750 genInvariants(env, builder, e1, curr, isStart); 751 if (env.exp(exp).kind == TensorExp::Kind::kReduce) 752 env.endCustomReduc(); // exit custom 753 } 754 } 755 756 /// Generates an expanded access pattern in innermost dimension. 757 static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr, 758 bool isStart) { 759 linalg::GenericOp op = env.op(); 760 OpOperand *lhs = op.getDpsInitOperand(0); 761 if (!env.atExpandLevel(lhs, op.getRank(lhs), curr)) 762 return; // not needed at current level 763 assert(!env.isReduc()); 764 // Generate start or end of an expanded access pattern. Note that because 765 // an expansion does not rely on the ongoing contents of the sparse storage 766 // scheme, we can use the original tensor as incoming SSA value (which 767 // simplifies codegen a bit). If expansion on the actual contents is ever 768 // needed, we will need to use the SSA value in the insertion chain instead. 769 Value tensor = lhs->get(); 770 Location loc = op.getLoc(); 771 if (isStart) { 772 auto dynShape = {ShapedType::kDynamic}; 773 Type etp = cast<ShapedType>(tensor.getType()).getElementType(); 774 Type t1 = MemRefType::get(dynShape, etp); 775 Type t2 = MemRefType::get(dynShape, builder.getI1Type()); 776 Type t3 = MemRefType::get(dynShape, builder.getIndexType()); 777 Type t4 = builder.getIndexType(); 778 auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor); 779 assert(r.getNumResults() == 4); 780 env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2), 781 r.getResult(3)); 782 } else { 783 SmallVector<Value> indices; 784 for (LoopId i = 0; i < curr; i++) 785 indices.push_back(env.emitter().getLoopIV(i)); 786 Value values = env.getExpandValues(); 787 Value filled = env.getExpandFilled(); 788 Value added = env.getExpandAdded(); 789 Value count = env.getExpandCount(); 790 Value chain = env.getInsertionChain(); 791 Value compress = builder.create<CompressOp>(loc, values, filled, added, 792 count, chain, indices); 793 env.updateInsertionChain(compress); 794 env.endExpand(); 795 } 796 } 797 798 /// Returns parallelization strategy. Any implicit loop in the Linalg 799 /// operation that is marked "parallel" is a candidate. Whether it is actually 800 /// converted to a parallel operation depends on the requested strategy. 801 static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) { 802 // Reject parallelization of sparse output. 803 if (env.hasSparseOutput()) 804 return false; 805 // Parallel loops on tensor expansion can cause data races. 806 if (env.isExpand()) 807 return false; 808 // Inspect strategy. 
809 switch (env.options().parallelizationStrategy) { 810 case SparseParallelizationStrategy::kNone: 811 return false; 812 case SparseParallelizationStrategy::kDenseOuterLoop: 813 return isOuter && !isSparse; 814 case SparseParallelizationStrategy::kAnyStorageOuterLoop: 815 return isOuter; 816 case SparseParallelizationStrategy::kDenseAnyLoop: 817 return !isSparse; 818 case SparseParallelizationStrategy::kAnyStorageAnyLoop: 819 return true; 820 } 821 llvm_unreachable("unexpected parallelization strategy"); 822 } 823 824 /// Whether or not the current loop being generated should be parallized (if 825 /// possible) according to the configuration. 826 static bool shouldTryParallize(CodegenEnv &env, LoopId curr, 827 ArrayRef<TensorLevel> tidLvls) { 828 linalg::GenericOp op = env.op(); 829 auto iteratorTypes = op.getIteratorTypesArray(); 830 bool isSparse = llvm::any_of(tidLvls, [curr, &env](TensorLevel tidLvl) { 831 // Queries the LT based on the tensor and loop id, as requested by 832 // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv 833 // should be consistent with the LT indexed by <TensorId, Level>. 834 const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr); 835 return lt.hasSparseSemantic(); 836 }); 837 return isParallelFor(env, /*isOuter=*/curr == 0, isSparse); 838 } 839 840 /// Emit a loop to coiterate over the list of tensor levels. The generated loop 841 /// can either be a for loop or while loop depending on whether there is at most 842 /// one sparse level in the list. 843 static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder, 844 ArrayRef<TensorLevel> tidLvls, 845 unsigned numCases, bool tryParallel, 846 bool needsUniv) { 847 Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) { 848 // Construct while-loop with a parameter for each index. 849 return env.emitter().enterCoIterationOverTensorsAtLvls( 850 builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel, 851 needsUniv); 852 }); 853 assert(loop); 854 return loop; 855 } 856 857 /// Generates a for-loop or a while-loop, depending on whether it implements 858 /// singleton iteration or co-iteration over the given conjunction. 859 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, 860 unsigned numCases, bool needsUniv, 861 ArrayRef<TensorLevel> tidLvls) { 862 bool tryParallel = shouldTryParallize(env, curr, tidLvls); 863 return genCoIteration(env, builder, tidLvls, numCases, tryParallel, 864 needsUniv); 865 } 866 867 /// Generates the induction structure for a while-loop. 868 static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, 869 bool needsUniv) { 870 Location loc = env.op().getLoc(); 871 // Finalize each else branch of all if statements. 872 if (env.isReduc() || env.isExpand() || env.getInsertionChain()) { 873 while (auto ifOp = dyn_cast_or_null<scf::IfOp>( 874 builder.getInsertionBlock()->getParentOp())) { 875 // Break on IfOp for slicing filtering. 
876 if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) == 877 StringAttr::get(ifOp->getContext(), "slice")) 878 break; 879 880 unsigned y = 0; 881 SmallVector<Value> yields; 882 if (env.isReduc()) { 883 yields.push_back(env.getReduc()); 884 env.updateReduc(ifOp.getResult(y++)); 885 if (env.isValidLexInsert()) { 886 yields.push_back(env.getValidLexInsert()); 887 env.updateValidLexInsert(ifOp.getResult(y++)); 888 } 889 } 890 if (env.isExpand()) { 891 yields.push_back(env.getExpandCount()); 892 env.updateExpandCount(ifOp->getResult(y++)); 893 } 894 if (env.getInsertionChain()) { 895 yields.push_back(env.getInsertionChain()); 896 env.updateInsertionChain(ifOp->getResult(y++)); 897 } 898 assert(y == yields.size()); 899 builder.create<scf::YieldOp>(loc, yields); 900 builder.setInsertionPointAfter(ifOp); 901 } 902 } 903 // No need to set the insertion point here as LoopEmitter keeps track of the 904 // basic block where scf::Yield should be inserted. 905 } 906 907 /// Generates a case region in the coiterate operation. 908 static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder, 909 unsigned caseIdx, LatPointId allCase, 910 LatPointId curCase, 911 MutableArrayRef<Value> reduc) { 912 assert(allCase == curCase || env.merger().latGT(allCase, curCase)); 913 const BitVector &allCaseBits = env.merger().lat(allCase).simple; 914 const BitVector &curCaseBits = env.merger().lat(curCase).simple; 915 916 /// Computes the subset of iterators that are valid in the current case being 917 /// generated. 918 I64BitSet caseBit(0); 919 for (auto [idx, set] : llvm::enumerate(allCaseBits.set_bits())) 920 if (curCaseBits.test(set)) 921 caseBit.set(idx); 922 923 env.emitter().enterCurrentCoIterationCase(builder, env.op().getLoc(), caseBit, 924 caseIdx, reduc); 925 } 926 927 /// Generates a single if-statement within a while-loop. 928 static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr, 929 LatPointId p) { 930 Location loc = env.op().getLoc(); 931 SmallVector<Type> types; 932 Value cond; 933 env.merger().foreachTensorLoopId( 934 p, /*simple=*/true, 935 [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt, 936 bool isIdxRed) { 937 if (isIdxRed) { 938 // Since there is no 1:1 mapping from loop to level (multiple loops 939 // are required to resolve one level with non-trivial index 940 // expression), we need to reconstruct the tensor level types if this 941 // loop requires index reduction condition. 942 assert(lvl.has_value() && isUndefLT(lt)); 943 auto stt = getSparseTensorType(env.op().getInputs()[tid]); 944 lt = stt.getLvlType(*lvl); 945 } 946 assert(curr == env.merger().loop(b)); 947 Value clause; 948 if (lt.hasSparseSemantic()) { 949 assert(lvl.has_value()); 950 const Value crd = env.emitter().getCoord(tid, *lvl); 951 const Value lvar = env.getLoopVar(curr); 952 clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 953 crd, lvar); 954 } else { 955 assert(lt.hasDenseSemantic() || isUndefLT(lt)); 956 clause = constantI1(builder, loc, true); 957 } 958 cond = cond ? 
builder.create<arith::AndIOp>(loc, cond, clause) : clause;
959 });
960 if (env.isReduc()) {
961 types.push_back(env.getReduc().getType());
962 if (env.isValidLexInsert())
963 types.push_back(env.getValidLexInsert().getType());
964 }
965 if (env.isExpand())
966 types.push_back(builder.getIndexType());
967 if (env.getInsertionChain())
968 types.push_back(env.getInsertionChain().getType());
969 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
970 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
971 return ifOp;
972 }
973 
974 /// Generates end of true branch of if-statement within a while-loop.
975 static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
976 Value redInput, Value cntInput, Value insInput,
977 Value validIns) {
978 SmallVector<Value> operands;
979 if (env.isReduc()) {
980 operands.push_back(env.getReduc());
981 env.updateReduc(redInput);
982 if (env.isValidLexInsert()) {
983 // Overlapping indices during a reduction create a valid lex insert.
984 operands.push_back(constantI1(builder, env.op().getLoc(), true));
985 env.updateValidLexInsert(validIns);
986 }
987 }
988 if (env.isExpand()) {
989 operands.push_back(env.getExpandCount());
990 env.updateExpandCount(cntInput);
991 }
992 if (env.getInsertionChain()) {
993 operands.push_back(env.getInsertionChain());
994 env.updateInsertionChain(insInput);
995 }
996 if (!operands.empty())
997 builder.create<scf::YieldOp>(env.op().getLoc(), operands);
998 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
999 }
1000 
1001 //===----------------------------------------------------------------------===//
1002 // Sparsifier synthesis methods (loop sequence).
1003 //===----------------------------------------------------------------------===//
1004 
1005 static bool getAllTidLvlsInLatPoints(
1006 CodegenEnv &env, LatPointId li, LoopId curr,
1007 llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
1008 const BitVector &simple = env.lat(li).simple;
1009 const TensorId outTid = env.merger().getOutTensorID();
1010 const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
1011 
1012 unsigned numloopCond = 0;
1013 bool hasNonUnique = false;
1014 env.merger().foreachTensorLoopId(
1015 li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
1016 LevelType lt, bool isIdxReduc) {
1017 if (simple[b]) {
1018 if (isIdxReduc) {
1019 callback(env.makeTensorLevel(tid, *lvl), nullptr);
1020 numloopCond++;
1021 return;
1022 }
1023 if (isUndefLT(lt)) {
1024 // For an undefined lt in the lattices, we probably mean to
1025 // generate a dense loop according to the synthetic tensor (for
1026 // invariants and the sparse output tensor).
1027 if (env.merger().getSynTensorID() == tid) {
1028 // Coiterating with an invariant
1029 // e.g., out = prod(in[i][j] op invariant);
1030 // or a broadcast
1031 // e.g., out[i][j] = in[i] (j is undef for input)
1032 //
1033 // The level of the synthetic tensor is the current loop depth;
1034 // the rank of the synthetic tensor equals the number of loops.
1035 assert(curr == env.getCurrentDepth());
1036 lvl = curr;
1037 } else if (!lvl) {
1038 // Skips invalid lvl (e.g., when this is a zero ranked tensor).
1039 return;
1040 }
1041 }
1042 hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
1043 callback(env.makeTensorLevel(tid, *lvl), nullptr);
1044 numloopCond++;
1045 } else if (lt.hasDenseSemantic() || isIdxReduc) {
1046 callback(env.makeTensorLevel(tid, *lvl), nullptr);
1047 } else {
1048 assert(isUndefLT(lt));
1049 linalg::GenericOp op = env.op();
1050 if (tid >= op.getNumDpsInputs())
1051 // We only handle affine expressions on input tensors (for now).
1052 return;
1053 OpOperand *operand = &op->getOpOperand(tid);
1054 const auto stt = getSparseTensorType(operand->get());
1055 // Non-annotated dense tensors require no special handling.
1056 if (!stt.hasEncoding())
1057 return;
1058 
1059 ArrayRef<AffineExpr> affines =
1060 op.getMatchingIndexingMap(operand).getResults();
1061 const Level lvlRank = stt.getLvlRank();
1062 assert(affines.size() == static_cast<size_t>(lvlRank));
1063 for (Level l = 0; l < lvlRank; l++) {
1064 AffineExpr exp = affines[l];
1065 // Skip simple affine expressions and non-dense levels (which
1066 // have their own filter loop).
1067 LevelType lt = stt.getLvlType(l);
1068 if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic())
1069 continue;
1070 
1071 // Constant affine expressions are handled in genLoop.
1072 if (!isa<AffineConstantExpr>(exp)) {
1073 bool isCurrentLoop = false;
1074 assert(curr == env.getCurrentDepth());
1075 if (isInvariantAffine(exp, curr + 1, /*out*/ isCurrentLoop) &&
1076 isCurrentLoop) {
1077 // If the compound affine expression is invariant and we are right
1078 // at the level, we need to generate the address according to the
1079 // affine expression. This is also the best place to do it, to
1080 // avoid putting it inside inner loops.
1081 callback(env.makeTensorLevel(tid, l), exp);
1082 }
1083 }
1084 }
1085 }
1086 });
1087 
1088 if (isDenseLT(env.lt(outTid, curr))) {
1089 auto stt = getSparseTensorType(env.op().getOutputs().front());
1090 // Note that we generate dense indices of the output tensor unconditionally,
1091 // since they may not appear in the lattice, but may be needed for
1092 // linearized env.
1093 // TODO: we should avoid introducing corner cases for all-dense sparse
1094 // tensors.
1095 if (stt.hasEncoding() && stt.isAllDense())
1096 callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
1097 }
1098 
1099 if (numloopCond == 0) {
1100 // Corner case where the loop bound is defined by an *unused* operand; in
1101 // this case, we just generate a dense "fake" loop by iterating over the
1102 // synthetic tensor.
1103 callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
1104 numloopCond++;
1105 }
1106 // If we only need one loop condition and the condition is not imposed on a
1107 // non-unique level, the loop can be generated by a for loop.
1108 // Or, if we are generating sparse-iterator-based loops, we always generate
1109 // `sparse_tensor.iterate` regardless of whether the level is unique or not.
1110 return numloopCond == 1 &&
1111 (!hasNonUnique || env.options().sparseEmitStrategy ==
1112 SparseEmitStrategy::kSparseIterator);
1113 }
1114 
1115 /// Starts a loop sequence at given level. Returns true if
1116 /// the universal loop index must be maintained at this level.
1117 static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
1118 LoopId curr, LatSetId lts) {
1119 assert(!env.getLoopVar(curr));
1120 // Emit invariants at this loop sequence level.
1121 genInvariants(env, builder, exp, curr, /*isStart=*/true);
1122 // Emit access pattern expansion for sparse tensor output.
1123 genExpand(env, builder, curr, /*isStart=*/true);
1124 // Emit further initialization at this loop sequence level.
1125 const LatPointId l0 = env.set(lts)[0];
1126 
1127 SmallVector<TensorLevel> tidLvls;
1128 getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
1129 // TODO: remove this! The same tensor level might be added multiple
1130 // times due to the special handling for the all-dense "sparse" output
1131 // tensor (see L1038).
1132 if (llvm::find(tidLvls, tl) != tidLvls.end())
1133 return;
1134 tidLvls.emplace_back(tl);
1135 });
1136 
1137 env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
1138 
1139 // Maintain the universal index only if it is actually
1140 // consumed by a subsequent lattice point.
1141 for (const LatPointId li : env.set(lts).drop_front())
1142 if (!env.merger().hasAnySparse(env.lat(li).simple))
1143 return true;
1144 
1145 return false;
1146 }
1147 
1148 // Generates dense affine address for encoding.
1149 static void genConstantDenseAddressFromLevel(CodegenEnv &env,
1150 OpBuilder &builder, TensorId tid,
1151 Level startLvl) {
1152 // TODO: Handle affine expression on output tensor.
1153 linalg::GenericOp op = env.op();
1154 assert(tid < op.getNumDpsInputs());
1155 OpOperand *input = op.getDpsInputOperands()[tid];
1156 const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
1157 const auto enc = getSparseTensorEncoding(input->get().getType());
1158 if (enc) {
1159 const Location loc = op.getLoc();
1160 const TensorId tid = env.makeTensorId(input->getOperandNumber());
1161 const Level lvlRank = enc.getLvlRank();
1162 assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
1163 for (Level l = startLvl; l < lvlRank; l++) {
1164 AffineExpr lvlExpr = lvlExprs[l];
1165 if (enc.getLvlType(l).hasDenseSemantic() &&
1166 isa<AffineConstantExpr>(lvlExpr))
1167 env.emitter().locateLvlAtAffineAddress(
1168 builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
1169 else
1170 return; // break on first non-dense non-constant level
1171 }
1172 }
1173 }
1174 
1175 // We can generate the address for a constant affine expression before any
1176 // loops, starting from the first level, as it does not depend on anything.
1177 // E.g., for [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first
1178 // two levels can be determined before the loops.
1179 static void genInitConstantDenseAddress(CodegenEnv &env,
1180 RewriterBase &rewriter) {
1181 for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
1182 genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
1183 }
1184 
1185 /// Returns true if the lattice bit can be iterated by a for loop.
1186 static bool translateBitsToTidLvlPairs(
1187 CodegenEnv &env, LatPointId li, LoopId curr,
1188 SmallVectorImpl<TensorLevel> &tidLvls,
1189 SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
1190 return getAllTidLvlsInLatPoints(env, li, curr,
1191 [&](TensorLevel tl, AffineExpr exp) {
1192 if (exp)
1193 affineTidLvls.emplace_back(tl, exp);
1194 else
1195 tidLvls.emplace_back(tl);
1196 });
1197 }
1198 
1199 /// Starts a single loop in current sequence.
1200 static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
1201 OpBuilder &builder, LoopId curr,
1202 LatPointId li, unsigned numCases,
1203 bool needsUniv) {
1204 // TODO: numCases is only used when generating iterator-based loops. Clean up
1205 // after the migration is complete.
1206 // The set of tensors + lvls to generate loops on.
1207 SmallVector<TensorLevel> tidLvls;
1208 
1209 // The set of dense tensors with a non-trivial affine expression that just
1210 // became invariant; their addresses are generated at the current level.
1211 SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
1212 bool isSingleCond =
1213 translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);
1214 
1215 // Emit the for/while-loop control.
1216 Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
1217 Location loc = env.op().getLoc();
1218 for (auto [tidLvl, exp] : affineTidLvls) {
1219 env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
1220 }
1221 
1222 // Until now, we have entered every <tid, lvl> pair in {cond, extra,
1223 // affine}Tids/Lvls. The addresses of the upcoming levels which depend on
1224 // constant affine expressions may now be determined.
1225 auto allTidLvls =
1226 llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
1227 for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
1228 if (tid != env.merger().getOutTensorID() &&
1229 tid != env.merger().getSynTensorID())
1230 genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
1231 }
1232 
1233 return std::make_pair(loop, isSingleCond);
1234 }
1235 
1236 /// Ends a single loop in current sequence. Returns the new value for needsUniv.
1237 static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
1238 LatPointId li, bool needsUniv, bool isSingleCond) {
1239 // Either a for-loop or a while-loop that iterates over a slice.
1240 if (isSingleCond) {
1241 // Any iteration creates a valid lex insert.
1242 if (env.isReduc() && env.isValidLexInsert())
1243 env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
1244 } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1245 // End a while-loop.
1246 finalizeWhileOp(env, rewriter, needsUniv);
1247 } else {
1248 needsUniv = false;
1249 }
1250 env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
1251 env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
1252 return std::nullopt;
1253 });
1254 return needsUniv;
1255 }
1256 
1257 /// Ends a loop sequence at given level.
1258 static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
1259 unsigned at) {
1260 assert(!env.getLoopVar(at));
1261 env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
1262 // Unmark bookkeeping of invariants and loop index.
1263 genInvariants(env, builder, exp, at, /*isStart=*/false);
1264 // Finalize access pattern expansion for sparse tensor output.
1265 genExpand(env, builder, at, /*isStart=*/false);
1266 }
1267 
1268 /// Recursively generates code while computing iteration lattices in order
1269 /// to manage the complexity of implementing co-iteration over unions
1270 /// and intersections of sparse iteration spaces.
1271 static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
1272 LoopId curr) {
1273 assert(curr == env.getCurrentDepth());
1274 
1275 // At each leaf, assign remaining tensor (sub)expression to output tensor.
1276 if (curr == env.getLoopNum()) {
1277 Value rhs = genExp(env, rewriter, exp);
1278 genTensorStore(env, rewriter, exp, rhs);
1279 return;
1280 }
1281 
1282 // Construct iteration lattices for current loop index.
1283 const LatSetId lts =
1284 env.merger().optimizeSet(env.merger().buildLattices(exp, curr));
1285 
1286 // Start a loop sequence.
1287 bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts); 1288 1289 // When using sparse-iterator-based loops, we only need one loops, as 1290 // opposed to a loop sequence, to cover all the iterator spaces. 1291 const unsigned lsize = env.set(lts).size(); 1292 if (env.generatingSparseIterator()) { 1293 // Get the largest lattice point and start a loop. 1294 const LatPointId li = env.set(lts)[0]; 1295 auto [loop, isSingleCond] = 1296 startLoop(env, rewriter, curr, li, lsize, needsUniv); 1297 assert(isSingleCond == llvm::isa<IterateOp>(loop)); 1298 // We cannot change this to `for (const LatPointId li : env.set(lts))` 1299 // because the loop body causes data-movement which invalidates 1300 // the iterator. 1301 for (unsigned j = 0; j < lsize; j++) { 1302 const LatPointId lj = env.set(lts)[j]; 1303 const ExprId ej = env.lat(lj).exp; 1304 // Recurse into body of each branch. 1305 if (!isSingleCond) { 1306 env.genLoopBoundary([&, curr, j, li, lj](MutableArrayRef<Value> reduc) { 1307 genCoIterationCase(env, rewriter, /*caseIdx*/ j, li, lj, reduc); 1308 genStmt(env, rewriter, ej, curr + 1); 1309 // TODO: handle yield values. 1310 assert(reduc.empty() && "Not Implemented"); 1311 rewriter.create<sparse_tensor::YieldOp>(env.op().getLoc()); 1312 return std::nullopt; 1313 }); 1314 // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns); 1315 } else { 1316 genStmt(env, rewriter, ej, curr + 1); 1317 } 1318 } 1319 // End a loop. 1320 needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond); 1321 } else { 1322 // Emit a loop for every lattice point L0 >= Li in this loop sequence. 1323 for (unsigned i = 0; i < lsize; i++) { 1324 const LatPointId li = env.set(lts)[i]; 1325 // Start a loop. 1326 auto [loop, isSingleCond] = 1327 startLoop(env, rewriter, curr, li, lsize, needsUniv); 1328 1329 // Visit all lattices points with Li >= Lj to generate the 1330 // loop-body, possibly with if statements for coiteration. 1331 Value redInput = env.getReduc(); 1332 Value cntInput = env.getExpandCount(); 1333 Value insInput = env.getInsertionChain(); 1334 Value validIns = env.getValidLexInsert(); 1335 // We cannot change this to `for (const LatPointId lj : env.set(lts))` 1336 // because the loop body causes data-movement which invalidates the 1337 // iterator. 1338 for (unsigned j = 0; j < lsize; j++) { 1339 const LatPointId lj = env.set(lts)[j]; 1340 const ExprId ej = env.lat(lj).exp; 1341 if (li == lj || env.merger().latGT(li, lj)) { 1342 // Recurse into body of each branch. 1343 if (!isSingleCond) { 1344 scf::IfOp ifOp = genIf(env, rewriter, curr, lj); 1345 genStmt(env, rewriter, ej, curr + 1); 1346 endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns); 1347 } else { 1348 genStmt(env, rewriter, ej, curr + 1); 1349 } 1350 } 1351 } 1352 1353 // End a loop. 1354 needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond); 1355 } 1356 } 1357 1358 // End a loop sequence. 1359 endLoopSeq(env, rewriter, exp, curr); 1360 assert(curr == env.getCurrentDepth()); 1361 } 1362 1363 /// Converts the result computed by the sparse kernel into the required form. 1364 static void genResult(CodegenEnv &env, RewriterBase &rewriter) { 1365 linalg::GenericOp op = env.op(); 1366 OpOperand *lhs = op.getDpsInitOperand(0); 1367 Value tensor = lhs->get(); 1368 Type resType = tensor.getType(); 1369 if (getSparseTensorEncoding(resType)) { 1370 // The sparse tensor rematerializes from the original sparse tensor's 1371 // underlying sparse storage format. 
For an insertion chain, the
1372 // tensor materializes from the chain with 'hasInserts' enabled.
1373 bool hasInserts = false;
1374 if (Value chain = env.getInsertionChain()) {
1375 hasInserts = true;
1376 tensor = chain;
1377 }
1378 rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
1379 } else {
1380 // To rematerialize a non-annotated tensor, simply load it
1381 // from the bufferized value.
1382 Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
1383 rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
1384 }
1385 }
1386 
1387 //===----------------------------------------------------------------------===//
1388 // Sparsifier rewriting methods.
1389 //===----------------------------------------------------------------------===//
1390 
1391 namespace {
1392 
1393 /// Sparse rewriting rule for generic Linalg operation.
1394 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1395 public:
1396 GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1397 : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1398 
1399 LogicalResult matchAndRewrite(linalg::GenericOp op,
1400 PatternRewriter &rewriter) const override {
1401 // Only accept single output operations with pure tensor semantics.
1402 if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
1403 return failure();
1404 
1405 // Only accept trivial affine indices.
1406 if (hasNonTrivialAffineOnSparseOut(op))
1407 return failure();
1408 
1409 // Only accept scheduled loops.
1410 if (!op->hasAttr("sorted")) {
1411 return rewriter.notifyMatchFailure(
1412 op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
1413 "before sparsification.");
1414 }
1415 
1416 // Must have been demapped as well if the generic op is sorted.
1417 assert(!hasAnyNonIdentityOperandsOrResults(op));
1418 
1419 // Sets up a code generation environment.
1420 const unsigned numTensors = op->getNumOperands();
1421 const unsigned numLoops = op.getNumLoops();
1422 bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
1423 // If we have an indexing map like (d0) -> (0, d0), there might be more
1424 // levels than loops because of the constant index; that means we cannot
1425 // use numLoops as the upper bound for the ranks of all tensors.
1426 // TODO: Constant indices are currently not supported on sparse tensors, but
1427 // are allowed in non-annotated dense tensors. Supporting them would also be
1428 // required for rank-reducing sparse tensor slices.
1429 Level maxLvlRank = 0;
1430 for (auto operand : op.getOperands()) {
1431 if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
1432 maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
1433 }
1434 }
1435 
1436 // Detects sparse annotations and translates the per-level sparsity
1437 // information for all tensors to loop indices in the kernel.
1438 CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
1439 if (!findSparseAnnotations(env, needIdxRed))
1440 return failure();
1441 
1442 // Only standard reduction operations (add, sub, or, xor) that can be
1443 // sparsified by merely reducing the stored values are admissible. More
1444 // elaborate reduction operations (such as mul, and, min, max) would need
1445 // to know whether implicit zeros occur as well. They can still be
1446 // implemented with a custom reduction operation, accepted here as well.
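// For example (illustration only): a dot-product style kernel x += a(i) * b(i)
// whose region reduces with arith.addf is accepted directly by the check
// below, whereas a product reduction over a sparse operand would have to be
// expressed through a sparse_tensor.reduce (ReduceOp) region to be accepted.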
1447 if (op.getNumReductionLoops() > 0) {
1448 Operation *yield = op.getRegion().front().getTerminator();
1449 assert(isa<linalg::YieldOp>(yield));
1450 Operation *redop = yield->getOperand(0).getDefiningOp();
1451 if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) &&
1452 !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) &&
1453 !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) &&
1454 !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) &&
1455 !isa<ReduceOp>(redop)) {
1456 return failure();
1457 }
1458 }
1459 
1460 // Constructs the tensor expression tree from `op`; returns failure if the
1461 // tree cannot be built or the tensor expression is inadmissible.
1462 if (failed(env.initTensorExp()))
1463 return failure();
1464 
1465 // Recursively generates code if admissible.
1466 env.startEmit(options.sparseEmitStrategy);
1467 genBuffers(env, rewriter);
1468 // TODO: Constant affine expressions should be handled differently when using
1469 // slice-based codegen; it does not matter now because we already reject
1470 // constant expressions at an earlier stage.
1471 genInitConstantDenseAddress(env, rewriter);
1472 genStmt(env, rewriter, env.getExprId(), 0);
1473 genResult(env, rewriter);
1474 return success();
1475 }
1476 
1477 private:
1478 /// Options to control sparse code generation.
1479 SparsificationOptions options;
1480 };
1481 
1482 } // namespace
1483 
1484 /// Populates the given patterns list with rewriting rules required for
1485 /// the sparsification of linear algebra operations.
1486 void mlir::populateSparsificationPatterns(
1487 RewritePatternSet &patterns, const SparsificationOptions &options) {
1488 patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1489 }
1490
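// Note: a minimal sketch of how these patterns are typically exercised
// (assuming the module was already processed by --sparse-reinterpret-map so
// the "sorted" attribute is present on the scheduled linalg.generic ops):
//
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, SparsificationOptions());
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));
//
// In-tree, the sparsification pass created by createSparsificationPass()
// performs this wiring; the snippet above is only an illustration.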