//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensors.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::linalg;
using namespace mlir::sparse_tensor;

//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//

// Helper method to match any typed zero.
static bool isZeroValue(Value val) {
  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
}

// Helper to detect a sparse tensor type operand.
static bool isSparseTensor(Value v) {
  auto enc = getSparseTensorEncoding(v.getType());
  return enc && !llvm::all_of(enc.getLvlTypes(),
                              [](auto lt) { return lt == LevelFormat::Dense; });
}
static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }

// Helper method to find zero/uninitialized tensor materialization.
static bool isMaterializing(OpOperand *op, bool isZero) {
  Value val = op->get();
  // Check allocation, with zero alloc when required.
  if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
    Value copy = alloc.getCopy();
    if (isZero)
      return copy && isZeroValue(copy);
    return !copy;
  }
  // Check for empty tensor materialization.
  if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
    return !isZero;
  // Last resort for zero alloc: the whole value is zero.
  return isZero && isZeroValue(val);
}

// Helper to detect sampling operation.
static bool isSampling(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
      // Both scalar input arguments used exactly once.
      Value s1 = op.getBlock()->getArgument(0);
      Value s2 = op.getBlock()->getArgument(1);
      return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
             (def->getOperand(1) == s1 && def->getOperand(0) == s2);
    }
  }
  return false;
}

// Helper to detect chain of multiplications that do not involve x.
static bool isMulChain(Value val, Value x) {
  if (auto arg = dyn_cast<BlockArgument>(val))
    return arg != x;
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
      return isMulChain(def->getOperand(0), x) &&
             isMulChain(def->getOperand(1), x);
  }
  return false;
}

// Helper to detect x = x + <multiplications>.
static bool isSumOfMul(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
      Value x = op.getBlock()->getArguments().back();
      return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
             (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
    }
  }
  return false;
}

// Helper to detect direct yield of a zero value.
static bool isZeroYield(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
    if (arg.getOwner()->getParentOp() == op) {
      return isZeroValue(op->getOperand(arg.getArgNumber()));
    }
  }
  return isZeroValue(yieldOp.getOperand(0));
}

/// Populates given sizes array from type (for static sizes) and from
/// the tensor (for dynamic sizes).
static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
                           Location loc, ShapedType stp, Value tensor) {
  for (const auto &d : enumerate(stp.getShape())) {
    Value dim;
    if (d.value() == ShapedType::kDynamic)
      dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
    else
      dim = constantIndex(builder, loc, d.value());
    sizes.push_back(dim);
  }
}

static RankedTensorType getBufferType(const SparseTensorType &stt,
                                      bool needTmpCOO) {
  return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
                    : stt.getRankedTensorType();
}

/// Collects the dynamic dimension sizes for `tp` with the assumption that
/// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
/// sizes to dynSizes.
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
                            SmallVectorImpl<Value> &dynSizes) {
  for (const auto &d : enumerate(tp.getShape())) {
    if (d.value() == ShapedType::kDynamic)
      dynSizes.push_back(sizes[d.index()]);
  }
}

static LogicalResult genForeachOnSparseConstant(ForeachOp op,
                                                RewriterBase &rewriter,
                                                SparseElementsAttr attr) {
  auto loc = op.getLoc();
  SmallVector<Value> reduc = op.getInitArgs();

  // Foreach on constant.
  foreachInSparseConstant(
      rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
      [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
        SmallVector<Value> args;
        args.append(cvs.begin(), cvs.end());
        args.push_back(v);
        args.append(reduc);
        // Clones the foreach op to get a copy of the loop body.
        auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
        assert(args.size() == cloned.getBody()->getNumArguments());
        Operation *yield = cloned.getBody()->getTerminator();
        rewriter.inlineBlockBefore(cloned.getBody(), op, args);
        // Clean up.
        rewriter.eraseOp(cloned);
        reduc = yield->getOperands();
        rewriter.eraseOp(yield);
      });

  rewriter.replaceOp(op, reduc);
  return success();
}

/// Populates the given sizes array for concatenation from types (for static
/// sizes) and from the source tensors (for dynamic sizes).
static void concatSizesFromInputs(OpBuilder &builder,
                                  SmallVectorImpl<Value> &sizes, Location loc,
                                  ShapedType dstTp, ValueRange srcs,
                                  unsigned dim) {
  auto dstShape = dstTp.getShape();
  sizesFromSrc(builder, sizes, loc, srcs[0]);

  // Sum up on the `dim` if the dimension is dynamic.
  if (dstShape[dim] != ShapedType::kDynamic) {
    // Faithfully take the static size.
    sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
  } else {
    // Else, compute the shape dynamically.
    for (const auto &src : srcs.drop_front()) {
      Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim);
      // Sum up all the sizes.
      sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
    }
  }
}

//===---------------------------------------------------------------------===//
// The actual sparse tensor rewriting rules.
//===---------------------------------------------------------------------===//

namespace {

/// Rewriting rule that folds a direct yield of zero into the initial
/// allocation.
struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
        !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
        !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
      return failure();
    auto outputType = getRankedTensorType(op.getResult(0));
    // Yielding zero on a newly materialized sparse tensor can be
    // optimized away directly (regardless of dynamic or static size).
    if (getSparseTensorEncoding(outputType)) {
      rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
      return success();
    }
    // Use static zero value directly instead of materialization.
    if (!outputType.hasStaticShape())
      return failure();
    Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
    rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
    rewriter.eraseOp(def);
    return success();
  }
};

/// Rewriting rule that converts two kernels:
///
///   T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
///   X(i,j) = S(i,j) * T(i,j)
///
/// into a single kernel, using distributive law:
///
///   X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
///
/// This kind of fusion (merging two ops into one but using arithmetic
/// equalities that may not hold for floating-point computations) would
/// be undesirable in the dense case, since we distribute the multiplication
/// into the reduction loop. However, for a sparse sampling tensor S, such
/// a fusion may actually reduce the asymptotic complexity of the kernel,
/// since intermediate results may be nullified.
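///
/// For illustration only (a rough sketch, not taken from an actual test),
/// the two kernels above correspond to a zero-initialized reduction
/// linalg.generic producing T and an elementwise sampling linalg.generic
/// consuming it:
///
///   %t = linalg.generic { reduction over k }
///          ins(%a, %b : ...) outs(%zero_init : ...) { multiply, accumulate }
///   %x = linalg.generic { parallel loops only }
///          ins(%s, %t : ...) outs(%alloc : ...) { multiply }
///
/// The pattern fuses these into a single linalg.generic that multiplies the
/// sampling tensor S inside the reduction loop, so that work is only done
/// for entries actually stored in a sparse S.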
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Check consumer.
    if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
        op.getNumResults() != 1 ||
        op.getNumParallelLoops() != op.getNumLoops() ||
        !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
        !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
        !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
      return failure();
    // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
    // operand can be sparse or dense, since the point of this rewriting rule
    // is detecting a situation in which *more* sparsity is introduced into
    // a computation, be it already sparse or still dense.
    unsigned other = 0;
    if (isSparseTensor(op.getDpsInputOperand(0)))
      other = 1;
    else if (!isSparseTensor(op.getDpsInputOperand(1)))
      return failure();
    // Check producer.
    auto prod = dyn_cast_or_null<GenericOp>(
        op.getDpsInputOperand(other)->get().getDefiningOp());
    if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
        !prod.getResult(0).hasOneUse())
      return failure();
    // Sampling consumer and sum of multiplication chain producer.
    if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
        !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
        !isSampling(op) || !isSumOfMul(prod))
      return failure();
    // Modify operand structure of producer and consumer.
    Location loc = prod.getLoc();
    SmallVector<Value> inputOps = prod.getInputs();
    SmallVector<Value> outputOps = op.getOutputs();
    SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
    inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
    fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
    // Fuse producer and consumer into a new generic op.
    auto fusedOp = rewriter.create<GenericOp>(
        loc, op.getResult(0).getType(), inputOps, outputOps,
        rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
        /*doc=*/nullptr, /*library_call=*/nullptr);
    Block &prodBlock = prod.getRegion().front();
    Block &consBlock = op.getRegion().front();
    IRMapping mapper;
    Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
    unsigned num = prodBlock.getNumArguments();
    for (unsigned i = 0; i < num - 1; i++)
      addArg(mapper, fusedBlock, prodBlock.getArgument(i));
    addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
    addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
    // Clone bodies of the producer and consumer in new evaluation order.
    auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
    auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
    Value last;
    for (auto &op : prodBlock.without_terminator())
      if (&op != acc) {
        last = op.getResult(0);
        rewriter.clone(op, mapper);
      }
    mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
    mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
    last = rewriter.clone(*acc, mapper)->getResult(0);
    rewriter.create<linalg::YieldOp>(loc, last);
    // Force initial value on merged allocation for dense outputs.
    // TODO: deal with non-alloc tensor here one day
    if (!getSparseTensorEncoding(op.getResult(0).getType())) {
      Value init = prod.getDpsInitOperand(0)
                       ->get()
                       .getDefiningOp<AllocTensorOp>()
                       .getCopy();
      AllocTensorOp a =
          op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
      rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
    }
    // Replace consumer with fused operation. Old producer
    // and consumer ops will be removed by DCE.
    rewriter.replaceOp(op, fusedOp->getResults());
    return success();
  }

private:
  // Helper to add argument and record the mapping.
  static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
    mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
  }
};

// Fuse a tensor cast into the producing operation. Note that a tensor.cast
// should really not be used to convert between sparse encodings. Since
// the pattern currently appears as a result of some prior rewriting,
// we make an attempt to repair very obvious cases.
// TODO: audit the pure tensor dialect rewriting rules
struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
public:
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    Type srcType = op.getSource().getType();
    Type dstType = op.getDest().getType();
    // A nop cast simply folds away.
    if (srcType == dstType) {
      rewriter.replaceOp(op, op->getResults());
      return success();
    }
    // See if a sparsity changing cast can be fused into the producer.
    if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
      if (Operation *def = op.getSource().getDefiningOp()) {
        if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
          rewriter.modifyOpInPlace(def, [&]() {
            def->getResult(0).setType(op->getResultTypes()[0]);
          });
          rewriter.replaceOp(op, def->getResult(0));
          return success();
        }
      }
    }
    // Repair tensor casts with at least one sparse operand into the
    // properly supported sparse_tensor.convert.
    if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
      rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
      return success();
    }
    // Fail otherwise.
    return failure();
  }
};

/// Rewrites a sequence of operations for sparse tensor selections into
/// semi-ring operations such that they can be compiled correctly by the
/// sparsifier. E.g., transforming the following sequence
///
///   %sel = arith.select %cond, %sp1, %sp2
///
/// to
///
///   %sel = binary %sp1, %sp2:
///            both  (%l, %r) {yield select %cond, %l, %r}
///            left  (%l)     {yield select %cond, %l, 0}
///            right (%r)     {yield select %cond, 0, %r}
///
/// TODO: We require that the tensor used for extracting conditions be dense
/// in order to sparsify the code. To support a sparse condition tensor, we
/// need a ternary operation.
struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Rejects non-sparse kernels.
    if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
      return failure();

    Location loc = op.getLoc();
    SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
    for (Operation &inst : *op.getBody()) {
      // Matches pattern.
      auto matched = isRewritablePattern(op, &inst);
      if (!matched.has_value())
        continue;

      rewriter.setInsertionPoint(&inst);
      auto [c, t, f] = matched.value();
      assert(t.getType() == f.getType());
      auto selTp = t.getType();
      auto c0 = constantZero(rewriter, loc, selTp);
      auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
      // Initializes all the blocks.
      rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
                           {t.getLoc(), f.getLoc()});
      rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
      rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());

      for (auto *r : binOp.getRegions()) {
        Block *b = &r->front();
        rewriter.setInsertionPointToStart(b);

        IRMapping irMap;
        // Clones the cmp operations into the region to make the binary op
        // admissible.
        Value newC = c;
        if (auto *def = c.getDefiningOp())
          newC = rewriter.clone(*def, irMap)->getResult(0);

        irMap.map(c, newC);
        if (r == &binOp.getLeftRegion()) {
          irMap.map(t, b->getArgument(0));
          irMap.map(f, c0);
        } else if (r == &binOp.getRightRegion()) {
          irMap.map(t, c0);
          irMap.map(f, b->getArgument(0));
        } else {
          irMap.map(t, b->getArgument(0));
          irMap.map(f, b->getArgument(1));
        }
        auto y = rewriter.clone(inst, irMap)->getResult(0);
        rewriter.create<sparse_tensor::YieldOp>(loc, y);
      }

      // We successfully rewrote an operation. We cannot do the replacement
      // here because it would invalidate the iterator used by the current
      // loop to traverse the instructions.
      semiRings.emplace_back(&inst, binOp);
    }

    // Finalizes the replacement.
    for (auto [sel, semi] : semiRings)
      rewriter.replaceOp(sel, semi->getResults());

    return success(!semiRings.empty());
  }

private:
  static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
  isRewritablePattern(GenericOp op, Operation *v) {
    auto sel = dyn_cast<arith::SelectOp>(v);
    if (!sel)
      return std::nullopt;

    auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
    auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
    // TODO: For simplicity, we only handle cases where both the true/false
    // values are loaded directly from the input tensor. We can probably admit
    // more cases in theory.
    if (!tVal || !fVal)
      return std::nullopt;

    // Helper lambda to determine whether the value is loaded from a dense
    // input or is a loop invariant.
    auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
      if (auto bArg = dyn_cast<BlockArgument>(v);
          bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
        return true;
      // If the value is defined outside the loop, it is a loop invariant.
      return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
    };

    // If the condition value is loaded directly from a dense tensor or is
    // a loop invariant, we can sparsify the kernel.
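    // For example (illustrative sketch only), a condition such as
    //   %cond = arith.cmpf olt, %d, %threshold : f64
    // is admissible when %d is loaded from a dense input operand or when
    // %threshold is defined outside the linalg.generic body (loop invariant).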
    auto cond = sel.getCondition();
    if (isValFromDenseInputOrInvariant(cond))
      return std::make_tuple(cond, tVal, fVal);

    Value cmpL, cmpR;
    if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
                                               matchers::m_Any(&cmpR))) ||
        matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
                                               matchers::m_Any(&cmpR)))) {
      // TODO: we can do it recursively to check whether all the leaf values
      // are loaded from dense tensors or are loop invariants.
      if (isValFromDenseInputOrInvariant(cmpL) ||
          isValFromDenseInputOrInvariant(cmpR))
        return std::make_tuple(cond, tVal, fVal);
    }

    return std::nullopt;
  };
};

/// Rewrites a sparse reduction that would not sparsify directly since
/// doing so would only iterate over the stored elements, ignoring the
/// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
/// (note that reductions like add/sub/or/xor can directly be sparsified
/// since the implicit zeros do not contribute to the final result).
/// Note that prod/and are still included since, even though they often
/// are nullified in sparse data, they may still occur for special
/// situations in which e.g. some rows in a sparse matrix are fully
/// dense. For min/max, including the implicit zeros is a much more
/// common situation.
///
/// TODO: this essentially "densifies" the operation; we want to implement
///       this much more efficiently by performing the reduction over the
///       stored values, and feed in the zero once if there were *any*
///       implicit zeros as well; but for now, at least we provide
///       the functionality
///
struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Reject non-reductions.
    if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
        op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
      return failure();
    auto *inp = op.getDpsInputOperand(0);
    auto *init = op.getDpsInitOperand(0);
    if (!isSparseTensor(inp))
      return failure();
    // Look for direct x = x OP y for semi-ring ready reductions.
    auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
                    .getOperand(0)
                    .getDefiningOp();
    if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
             arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
             arith::MaxUIOp>(red))
      return failure();
    Value s0 = op.getBlock()->getArgument(0);
    Value s1 = op.getBlock()->getArgument(1);
    if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
        (red->getOperand(0) != s1 || red->getOperand(1) != s0))
      return failure();
    // Identity.
    Location loc = op.getLoc();
    Value identity =
        rewriter.create<tensor::ExtractOp>(loc, init->get(), ValueRange());
    // Unary {
    //    present -> value
    //    absent -> zero.
    // }
    Type rtp = s0.getType();
    rewriter.setInsertionPointToStart(&op.getRegion().front());
    auto semiring = rewriter.create<sparse_tensor::UnaryOp>(loc, rtp, s0);
    Block *present =
        rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
    rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
    rewriter.create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
    rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
    rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(rtp));
    rewriter.create<sparse_tensor::YieldOp>(loc, zero);
    rewriter.setInsertionPointAfter(semiring);
    // CustomReduce {
    //    x = x REDUC y, identity
    // }
    auto custom = rewriter.create<sparse_tensor::ReduceOp>(
        loc, rtp, semiring.getResult(), s1, identity);
    Block *region =
        rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
    rewriter.setInsertionPointToStart(&custom.getRegion().front());
    IRMapping irMap;
    irMap.map(red->getOperand(0), region->getArgument(0));
    irMap.map(red->getOperand(1), region->getArgument(1));
    auto *cloned = rewriter.clone(*red, irMap);
    rewriter.create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
    rewriter.setInsertionPointAfter(custom);
    rewriter.replaceOp(red, custom.getResult());
    return success();
  }
};

/// Sparse rewriting rule for the print operator. This operation is mainly
/// used for debugging and testing. As such, it lowers to the vector.print
/// operation which only requires very lightweight runtime support.
struct PrintRewriter : public OpRewritePattern<PrintOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(PrintOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto tensor = op.getTensor();
    auto stt = getSparseTensorType(tensor);
    // Header with NSE.
    auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
    rewriter.create<vector::PrintOp>(
        loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
    rewriter.create<vector::PrintOp>(loc, nse);
    // Print run-time contents for dim/lvl sizes.
    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("dim = "));
    printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("lvl = "));
    printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
    // Use the "codegen" foreach loop construct to iterate over
    // all typical sparse tensor components for printing.
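    // For a small CSR matrix, the generated IR prints output roughly of the
    // form below (illustrative sketch only; actual values depend on the
    // tensor contents, and buffers are printed up to their full capacity):
    //
    //   ---- Sparse Tensor ----
    //   nse = 3
    //   dim = (4, 8)
    //   lvl = (4, 8)
    //   pos[1] : (0, 1, 2, 2, 3)
    //   crd[1] : (2, 5, 7)
    //   values : (1.1, 2.2, 3.3)
    //   ----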
    foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
                                            &stt](Type, FieldIndex,
                                                  SparseTensorFieldKind kind,
                                                  Level l, LevelType) {
      switch (kind) {
      case SparseTensorFieldKind::StorageSpec: {
        break;
      }
      case SparseTensorFieldKind::PosMemRef: {
        auto lvl = constantIndex(rewriter, loc, l);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
        rewriter.create<vector::PrintOp>(
            loc, lvl, vector::PrintPunctuation::NoPunctuation);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
        auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
        printContents(rewriter, loc, pos);
        break;
      }
      case SparseTensorFieldKind::CrdMemRef: {
        auto lvl = constantIndex(rewriter, loc, l);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
        rewriter.create<vector::PrintOp>(
            loc, lvl, vector::PrintPunctuation::NoPunctuation);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
        Value crd = nullptr;
        // For COO AoS storage, we want to print a single, linear view of
        // the full coordinate storage at this level. For any other storage,
        // we show the coordinate storage for every individual level.
        if (stt.getAoSCOOStart() == l)
          crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
        else
          crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
        printContents(rewriter, loc, crd);
        break;
      }
      case SparseTensorFieldKind::ValMemRef: {
        rewriter.create<vector::PrintOp>(loc,
                                         rewriter.getStringAttr("values : "));
        auto val = rewriter.create<ToValuesOp>(loc, tensor);
        printContents(rewriter, loc, val);
        break;
      }
      }
      return true;
    });
    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
    rewriter.eraseOp(op);
    return success();
  }

private:
  // Helper to print contents of a single memref. Note that for the
  // "push_back" vectors, this prints the full capacity, not just the size.
  // This is done on purpose, so that clients see how much storage has been
  // allocated in total. Contents of the extra capacity in the buffer may be
  // uninitialized (unless the flag enable-buffer-initialization is set to
  // true).
  //
  // Generates code to print:
  //    ( a0, a1, ... )
  static void printContents(PatternRewriter &rewriter, Location loc,
                            Value vec) {
    // Open bracket.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
    // For loop over elements.
    auto zero = constantIndex(rewriter, loc, 0);
    auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
    auto step = constantIndex(rewriter, loc, 1);
    auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
    rewriter.setInsertionPointToStart(forOp.getBody());
    auto idx = forOp.getInductionVar();
    auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
    if (llvm::isa<ComplexType>(val.getType())) {
      // Since the vector dialect does not support complex types in any op,
      // we split those into (real, imag) pairs here.
      Value real = rewriter.create<complex::ReOp>(loc, val);
      Value imag = rewriter.create<complex::ImOp>(loc, val);
      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
      rewriter.create<vector::PrintOp>(loc, real,
                                       vector::PrintPunctuation::Comma);
      rewriter.create<vector::PrintOp>(loc, imag,
                                       vector::PrintPunctuation::Close);
      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
    } else {
      rewriter.create<vector::PrintOp>(loc, val,
                                       vector::PrintPunctuation::Comma);
    }
    rewriter.setInsertionPointAfter(forOp);
    // Close bracket and end of line.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
  }

  // Helper method to print run-time lvl/dim sizes.
  static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
                         unsigned size, bool isDim) {
    // Open bracket.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
    // Print unrolled contents (dimop requires constant value).
    for (unsigned i = 0; i < size; i++) {
      auto idx = constantIndex(rewriter, loc, i);
      Value val;
      if (isDim)
        val = rewriter.create<tensor::DimOp>(loc, tensor, idx);
      else
        val = rewriter.create<LvlOp>(loc, tensor, idx);
      rewriter.create<vector::PrintOp>(
          loc, val,
          i != size - 1 ? vector::PrintPunctuation::Comma
                        : vector::PrintPunctuation::NoPunctuation);
    }
    // Close bracket and end of line.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
  }
};

/// Sparse rewriting rule for the sparse-to-sparse reshape operator.
struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
public:
  using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value srcTensor = op.getSource();
    const auto srcTp = getSparseTensorType(srcTensor);
    const auto dstTp = getSparseTensorType(op.getResult());

    if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
        !dstTp.hasStaticDimShape())
      return failure();

    SmallVector<Value> srcSizes;
    sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
    SmallVector<Value> dstSizes;
    for (Dimension d : dstTp.getDimShape())
      dstSizes.push_back(constantIndex(rewriter, loc, d));

    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
    // Only need an unordered COO buffer if input and output are not sorted
    // in the same way.
    Type bufferTp = getBufferType(
        dstTp.withoutDimToLvl(),
        !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
    SmallVector<Value> dynSizes;
    Value buffer = rewriter
                       .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
                                              nnz, Attribute())
                       .getResult();

    // Convert src coordinates to dst coordinates by first collapsing them
    // to 1D and then expanding them to match the rank of the destination
    // tensor.
    // Implemented as follows:
    //   foreach srcCoords %srcTensor
    //     collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
    //     expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
    //     insert expandedCoords, %buffer
    //
    // followed by an optional
    //   %t = sparse_tensor.cast %tmp
    // depending on whether the input/output are sorted in the same way.
    const auto encSrc = srcTp.getEncoding();
    ForeachOp foreachOp = rewriter.create<ForeachOp>(
        loc, srcTensor, buffer,
        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
            ValueRange reduc) {
          const Dimension srcRank = srcTp.getDimRank();
          SmallVector<Value> srcDcvs;
          srcDcvs.reserve(srcRank);
          for (Dimension d = 0; d < srcRank; d++) {
            Level lvl = toLvl(encSrc, d);
            srcDcvs.push_back(srcLcvs[lvl]);
          }

          Value collapseSize = constantIndex(builder, loc, 1);
          for (Dimension d = 0; d < srcRank; d++)
            collapseSize =
                builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
          SmallVector<Value, 1> collapsedSizes = {collapseSize};

          ReassociationIndices collapseIdx;
          for (Dimension i = 0; i < srcRank; i++)
            collapseIdx.push_back(i);
          SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
          SmallVector<Value, 1> collapsedDcvs;
          reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
                     collapsedSizes, collapsedDcvs);

          ReassociationIndices expandIdx;
          for (Dimension i = 0; i < dstTp.getDimRank(); i++)
            expandIdx.push_back(i);
          SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
          SmallVector<Value> dstDcvs;
          reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
                     dstSizes, dstDcvs);

          auto t =
              builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
          builder.create<sparse_tensor::YieldOp>(loc, t);
        });

    Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
    if (bufferTp != dstTp) {
      auto dstRTT = dstTp.getRankedTensorType();
      Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
      rewriter.create<DeallocTensorOp>(loc, t);
      t = converted;
    }
    rewriter.replaceOp(op, t);
    return success();
  }
};

/// Sparse rewriting rule for sparse-to-sparse reshape operator.
template <typename ReshapeOp>
struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value srcTensor = op.getSrc();
    const auto srcTp = getSparseTensorType(srcTensor);
    const auto dstTp = getSparseTensorType(op.getResult());
    if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
      return failure();

    // Generate code to represent the static dimension constants or compute
    // the dynamic dimension values.
    SmallVector<Value> srcSizes;
    sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
    SmallVector<Value> dstSizes;
    SmallVector<Value> dstDynSizes;
    if (dstTp.hasStaticDimShape()) {
      for (Dimension d : dstTp.getDimShape())
        dstSizes.push_back(constantIndex(rewriter, loc, d));
    } else {
      ArrayRef<Size> dstShape = dstTp.getDimShape();
      genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
                         op.getReassociationIndices());
      for (auto [idx, shape] : llvm::enumerate(dstShape)) {
        if (shape == ShapedType::kDynamic)
          dstDynSizes.push_back(dstSizes[idx]);
      }
    }
    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
    // Only need an unordered COO buffer if input and output are not sorted
    // in the same way.
    Type bufferTp = getBufferType(
        dstTp.withoutDimToLvl(),
        !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());

    Value buffer =
        rewriter
            .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
                                   /*sizeHint=*/nnz, Attribute())
            .getResult();

    // Implement the sparse2sparse reshape as follows:
    //   foreach srcCoords %srcTensor
    //     insert reshapeCvs(srcCoords), %buffer
    //
    // followed by an optional
    //   %t = sparse_tensor.cast %tmp
    // depending on whether the input/output are sorted in the same way.
    const auto encSrc = srcTp.getEncoding();
    ForeachOp foreachOp = rewriter.create<ForeachOp>(
        loc, srcTensor, buffer,
        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
            ValueRange reduc) {
          const Dimension dimRank = srcTp.getDimRank();
          SmallVector<Value> srcDcvs;
          srcDcvs.reserve(dimRank);
          for (Dimension d = 0; d < dimRank; d++) {
            Level lvl = toLvl(encSrc, d);
            srcDcvs.push_back(srcLcvs[lvl]);
          }
          SmallVector<Value> dstDcvs;
          reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
                     srcDcvs, dstSizes, dstDcvs);
          auto t =
              builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
          builder.create<sparse_tensor::YieldOp>(loc, t);
        });

    Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
    if (bufferTp != dstTp) {
      auto dstRTT = dstTp.getRankedTensorType();
      Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
      rewriter.create<DeallocTensorOp>(loc, t);
      t = converted;
    }
    rewriter.replaceOp(op, t);
    return success();
  }
};

/// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
/// operator.
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto encDst = getSparseTensorEncoding(op.getResult().getType());
    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
    // Since a pure dense expansion is very cheap (change of view), for
    // a sparse2dense or dense2sparse, we can simply unfuse a sparse
    // conversion from the reshape operation itself.
    // All other cases are handled elsewhere.
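    // For example (illustrative sketch only; #CSR stands for some sparse
    // encoding), a dense-to-sparse expansion
    //   %1 = tensor.expand_shape %d ...
    //          : tensor<12xf64> into tensor<3x4xf64, #CSR>
    // is rewritten into a purely dense expand_shape followed by an unfused
    // conversion:
    //   %0 = tensor.expand_shape %d ...
    //          : tensor<12xf64> into tensor<3x4xf64>
    //   %1 = sparse_tensor.convert %0
    //          : tensor<3x4xf64> to tensor<3x4xf64, #CSR>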
    if (encDst && encSrc) {
      return failure();
    }
    if (encSrc) {
      auto rtp = getRankedTensorType(op.getSrc());
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
      rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
      return success();
    }
    if (encDst) {
      auto rtp = getRankedTensorType(op.getResult());
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      ReshapeOp reshape;
      if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
        reshape = rewriter.create<ReshapeOp>(
            loc, denseTp, op.getSrc(), op.getReassociation(),
            op.getOutputShape(), op.getStaticOutputShape());
      } else {
        reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
                                             op.getReassociation());
      }
      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
      rewriter.replaceOp(op, convert);
      return success();
    }
    return failure();
  }
};

// A trivial wrapper to help generate different operations for dense/sparse
// tensors.
struct TensorLike {
  TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
             ValueRange sizes) {
    SmallVector<Value> dynSzs;
    getDynamicSizes(rtt, sizes, dynSzs);

    val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
    if (!isSparse()) {
      Value c0 = constantZero(builder, loc, rtt.getElementType());
      val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
    }
  }

  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
    val = builder.create<tensor::InsertOp>(loc, v, val, crds);
  }

  Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
    if (isSparse())
      return builder.create<LoadOp>(loc, val, true);
    return val;
  }

  bool isSparse() const {
    return getSparseTensorEncoding(val.getType()) != nullptr;
  }

  Value val;
};

struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::DimOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> dim = op.getConstantIndex();
    auto stt = getSparseTensorType(op.getSource());
    if (!dim || !stt.hasEncoding())
      return failure();

    if (stt.isPermutation()) {
      rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
                                         toLvl(stt.getEncoding(), *dim));
      return success();
    }

    // Non-permutation dim2lvl/lvl2dim maps.
    // Compute as follows:
    //   affine.apply #map (l0 - 1, l1 - 1, ...) + 1
    // Note that this is not the most efficient way (but a more general one)
    // for the lvl-to-dim translation; e.g., for BSR, the dimension size can
    // be computed simply as lvl_size * block_size.
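    // For example (illustrative sketch; the 2x2 block map below is just a
    // hypothetical BSR-like lvlToDim map), given
    //   lvlToDim = (i, j, ii, jj) -> (i * 2 + ii, j * 2 + jj)
    // the size of dimension 0 is obtained by applying
    //   (d0, d1, d2, d3) -> (d0 * 2 + d2)
    // to the maximal level coordinates (each level size minus one) and then
    // adding one to the result.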
    Location loc = op.getLoc();
    SmallVector<Value> maxLvlCrds;
    for (Level l = 0; l < stt.getLvlRank(); l++) {
      Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
      Value maxLvlCrd = rewriter.create<arith::SubIOp>(
          loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
      maxLvlCrds.push_back(maxLvlCrd);
    }

    AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
    Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
        op.getLoc(), AffineMap::get(stt.getLvlRank(), 0, lvl2DimExp),
        maxLvlCrds);

    Value dimSz = rewriter.create<arith::AddIOp>(
        loc, maxDimCrd, constantOne(rewriter, loc, rewriter.getIndexType()));
    rewriter.replaceOp(op, dimSz);
    return success();
  }
};

struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter &rewriter) const override {
    if (op.needsExtraSort())
      op.emitError("ConcatenateOp not staged");

    const Location loc = op.getLoc();
    const auto dstTp = getSparseTensorType(op);
    const Dimension conDim = op.getDimension();
    SmallVector<Value> sizes;
    concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);

    // %t = concatenate %s1, %s2, %s3 {dim = 1}
    // ==>
    // if (isSparseDst)
    //   if (allDense)
    //     %tmp = bufferization.alloc_tensor dstTp
    //   else
    //     %tmp = bufferization.alloc_tensor : unordered COO
    // else
    //   %tmp = memref.alloc : dense tensor
    // foreach in %s1 : insert d0, d1, %tmp
    // foreach in %s2 : insert d0, d1 + size(s1), %tmp
    // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp

    TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
    Value offset = constantIndex(rewriter, loc, 0);
    Value iterArg = dstBuf.val;

    ForeachOp foreachOp;
    for (Value input : op.getInputs()) {
      // Builds a for op for each input tensor to append new values into the
      // output tensor.
      foreachOp = rewriter.create<ForeachOp>(
          loc, input, iterArg,
          [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
              ValueRange reduc) {
            SmallVector<Value> offDimCrd(dcvs);
            offDimCrd[conDim] =
                builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);

            // Enters foreach, updates the SSA chain.
            dstBuf.val = reduc.front();
            if (!dstTp.isAllDense()) {
              Value cond = genIsNonzero(builder, loc, v);
              auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                    /*else*/ true);
              builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
              builder.create<scf::YieldOp>(loc, dstBuf.val);

              builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
              dstBuf.insert(builder, loc, v, offDimCrd);
              builder.create<scf::YieldOp>(loc, dstBuf.val);

              // Exits the ifOp, update the sparse tensor SSA value.
              builder.setInsertionPointAfter(ifOp);
              dstBuf.val = ifOp.getResult(0);
            } else {
              dstBuf.insert(builder, loc, v, offDimCrd);
            }
            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
          });
      // Accumulates the offset. Note that only static-shaped inputs are
      // allowed by concatenate op verifier, which saves us from computing
      // the offset dynamically.
      const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
      assert(!ShapedType::isDynamic(sz));
      offset = rewriter.create<arith::AddIOp>(loc, offset,
                                              constantIndex(rewriter, loc, sz));
      iterArg = foreachOp.getResult(0);
      dstBuf.val = iterArg;
    }

    dstBuf.val = iterArg;
    Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
    rewriter.replaceOp(op, ret);
    return success();
  }
};

struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ConvertOp op,
                                PatternRewriter &rewriter) const override {
    if (op.needsExtraSort())
      return op.emitError("ConvertOp not staged.");

    // TODO: Maybe we want a different operation for this too.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (encDst && encSrc && !encSrc.isSlice() &&
        encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
      // Trivial tensor conversion and simple element type conversion are
      // handled in codegen.
      return failure();
    }

    Location loc = op.getLoc();
    Value src = op.getSource();

    SparseTensorType srcStt = getSparseTensorType(op.getSource());
    SparseTensorType dstStt = getSparseTensorType(op.getDest());

    bool fromSparseConst = false;
    if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
      if (dyn_cast<SparseElementsAttr>(constOp.getValue()))
        fromSparseConst = true;

    const AffineMapAttr foreachOrder =
        (!dstStt.isIdentity() && fromSparseConst)
            ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
            : nullptr;

    bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;

    SmallVector<Value> sizes;
    sizesFromSrc(rewriter, sizes, loc, src);
    ValueRange vs;
    TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);

    auto foreachOp = rewriter.create<ForeachOp>(
        loc, src, dstBuf.val, foreachOrder,
        [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
            ValueRange reduc) {
          // Enters the loop, update the SSA value for insertion chain.
          dstBuf.val = reduc.front();
          if (!skipZeroCheck) {
            Value cond = genIsNonzero(builder, loc, v);
            auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                  /*else*/ true);
            builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
            builder.create<scf::YieldOp>(loc, dstBuf.val);

            builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
            dstBuf.insert(builder, loc, v, dcvs);
            builder.create<scf::YieldOp>(loc, dstBuf.val);

            // Exits the ifOp, update the sparse tensor SSA value.
            builder.setInsertionPointAfter(ifOp);
            dstBuf.val = ifOp.getResult(0);
          } else {
            dstBuf.insert(builder, loc, v, dcvs);
          }
          builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
        });

    rewriter.setInsertionPointAfter(foreachOp);

    // Exits the for loop, links the SSA chain.
    dstBuf.val = foreachOp.getResult(0);

    Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
    rewriter.replaceOp(op, ret);
    return success();
  }
};

struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(CrdTranslateOp op,
                                PatternRewriter &rewriter) const override {
    AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
                        ? op.getEncoder().getDimToLvl()
                        : op.getEncoder().getLvlToDim();

    SmallVector<Value> outCrds;
    for (AffineExpr result : map.getResults()) {
      // TODO: we should probably expand the affine map to IR using our own
      // rules, since affine.apply assumes signed values, while the
      // coordinates we provide must always be signless.
      Value trans = rewriter.create<affine::AffineApplyOp>(
          op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
          op.getInCrds());
      outCrds.push_back(trans);
    }
    rewriter.replaceOp(op, outCrds);
    return success();
  }
};

/// Sparse rewriting rule for the foreach operator.
struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ForeachOp op,
                                PatternRewriter &rewriter) const override {

    auto loc = op.getLoc();
    Value input = op.getTensor();
    SmallVector<Value> reduc = op.getInitArgs();
    const auto stt = getSparseTensorType(input);
    const Level lvlRank = stt.getLvlRank();

    // Special case: foreach over a sparse constant uses its own rewriting
    // rule.
    if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
        return genForeachOnSparseConstant(op, rewriter, attr);
      }
    }

    // Otherwise, use the loop emitter to generate loops.
    const auto enc = stt.getEncoding();

    // 1. Generates loop for the sparse input.
    LoopEmitter loopEmitter(
        ValueRange{input},
        StringAttr::get(getContext(), ForeachOp::getOperationName()));
    loopEmitter.initializeLoopEmit(rewriter, loc);
    for (Level l = 0; l < lvlRank; l++) {
      // TODO: provide utility function for loop sequences that only contain
      // one for loop?
      const SmallVector<TensorLevel, 1> tidLvls{
          loopEmitter.makeTensorLevel(0, l)};
      loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
      // Note that reduc will be taken care of by the loop emitter and gets
      // updated in place.
      loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls,
                                                    reduc);
    }

    SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
    if (op.getOrder()) {
      // TODO: Support it so that we can do direct conversion from CSR->BSR.
      llvm_unreachable(
          "Level order not yet implemented on non-constant input tensors.");
    }

    Value vals = loopEmitter.getValBuffer()[0];
    SmallVector<Value> pos = loopEmitter.getValPosits(0);
    // Loads the value from the sparse tensor using the position index;
    // loads the value from the dense tensor using the coordinates.
    Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
                    : rewriter.create<memref::LoadOp>(loc, vals, lcvs);

    // 2. Inline the block in the foreach operator.
    Block *srcBlock = op.getBody();

    // Remap coordinates.
    SmallVector<Value> args =
        enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);

    // Remap value.
    args.push_back(val);
    // Remap reduction variables.
    args.append(reduc);

    // Remove sparse_tensor.yield.
    SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
    rewriter.eraseOp(srcBlock->getTerminator());

    Operation &last = rewriter.getBlock()->back();
    if (llvm::isa<scf::YieldOp>(last)) {
      // Because `scf.for` inserts an implicit yield op when there is no
      // reduction variable upon creation, we reset the insertion point such
      // that the block is inlined before the yield op.
      rewriter.setInsertionPoint(&last);
    }

    rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
                               rewriter.getInsertionPoint(), args);
    rewriter.setInsertionPointToEnd(rewriter.getBlock());
    for (Level l = 0; l < lvlRank; l++) {
      // Link the reduction chain. Note that the loop emitter updates
      // reducValue in place.
      loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
      loopEmitter.exitCurrentLoopSeq(rewriter, loc);
    }

    // Replace the foreach operator with the value returned by the outermost
    // for loop.
    rewriter.replaceOp(op, reducValue);
    return success();
  }
};

/// Sparse rewriting rule for the new operator.
struct NewRewriter : public OpRewritePattern<NewOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(NewOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
      return failure();

    // Implement the NewOp as follows:
    //   %orderedCoo = sparse_tensor.new %filename
    //   %t = sparse_tensor.convert %orderedCoo
    // with enveloping reinterpreted_map ops for non-permutations.
    RankedTensorType dstTp = stt.getRankedTensorType();
    RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
    Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
    Value convert = cooTensor;
    auto enc = stt.getEncoding();
    if (!stt.isPermutation()) { // demap coo, demap dstTp
      auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
      convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
      dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
    }
    convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
    if (!stt.isPermutation()) // remap to original enc
      convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
    rewriter.replaceOp(op, convert);

    // Release the temporary ordered COO tensor.
    rewriter.setInsertionPointAfterValue(convert);
    rewriter.create<DeallocTensorOp>(loc, cooTensor);

    return success();
  }
};

/// Sparse rewriting rule for the out operator.
struct OutRewriter : public OpRewritePattern<OutOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(OutOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Calculate NNZ.
    Value src = op.getTensor();
    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);

    // Allocate a temporary buffer for storing dimension-sizes/coordinates.
    const auto srcTp = getSparseTensorType(src);
    const Dimension dimRank = srcTp.getDimRank();
    Type indexTp = rewriter.getIndexType();
    Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);

    // Generate code to calculate dimension size values and store the values
    // to the buffer.
    SmallVector<Value> dims;
    sizesForTensor(rewriter, dims, loc, srcTp, src);
    for (Dimension d = 0; d < dimRank; d++) {
      rewriter.create<memref::StoreOp>(loc, dims[d], dimSizes,
                                       constantIndex(rewriter, loc, d));
    }

    // Create a sparse tensor writer and output meta data.
    Type opaqueTp = getOpaquePointerType(rewriter);
    Value writer =
        createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
                       {op.getDest()}, EmitCInterface::Off)
            .getResult(0);
    Value rankValue = constantIndex(rewriter, loc, dimRank);
    createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
                   {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);

    Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords.
    Type eltTp = srcTp.getElementType();
    SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
                                    primaryTypeFunctionSuffix(eltTp)};
    Value value = genAllocaScalar(rewriter, loc, eltTp);
    ModuleOp module = op->getParentOfType<ModuleOp>();

    // For each element in the source tensor, output the element.
    rewriter.create<ForeachOp>(
        loc, src, std::nullopt,
        [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
            ValueRange reduc) {
          for (Dimension d = 0; d < dimRank; d++) {
            rewriter.create<memref::StoreOp>(loc, dcvs[d], dimCoords,
                                             constantIndex(builder, loc, d));
          }
          rewriter.create<memref::StoreOp>(loc, v, value);
          SmallVector<Value> operands{writer, rankValue, dimCoords, value};
          FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
                                         EmitCInterface::On);
          builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
          builder.create<sparse_tensor::YieldOp>(loc);
        });

    // Release the writer.
    createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
                   EmitCInterface::Off);

    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
      patterns.getContext());
}

void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
                                                   bool enableRT,
                                                   bool enableConvert) {
  patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
               ReshapeRewriter<tensor::CollapseShapeOp>,
               Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
               Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
               SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
      patterns.getContext());

  if (enableConvert)
    patterns.add<DirectConvertRewriter>(patterns.getContext());
  if (!enableRT)
    patterns.add<NewRewriter>(patterns.getContext());
}

void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
  // Run CrdTranslateRewriter later in the pipeline so that the operation can
  // be folded before lowering to affine.apply.
  patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
}
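
// For reference, a minimal sketch of how a pass might apply the pattern sets
// populated above (illustrative only; upstream passes wire these patterns up
// with additional options and logic, and the option values shown here are
// just placeholders):
//
//   void runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     populatePreSparsificationRewriting(patterns);
//     populateLowerSparseOpsToForeachPatterns(patterns, /*enableRT=*/false,
//                                             /*enableConvert=*/true);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }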