//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensors.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::linalg;
using namespace mlir::sparse_tensor;

//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//

// Helper method to match any typed zero.
static bool isZeroValue(Value val) {
  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
}

// Helper to detect a sparse tensor type operand.
static bool isSparseTensor(Value v) {
  auto enc = getSparseTensorEncoding(v.getType());
  return enc && !llvm::all_of(enc.getLvlTypes(),
                              [](auto lt) { return lt == LevelFormat::Dense; });
}
static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }

// Helper method to find zero/uninitialized tensor materialization.
static bool isMaterializing(OpOperand *op, bool isZero) {
  Value val = op->get();
  // Check allocation, with zero alloc when required.
  if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
    Value copy = alloc.getCopy();
    if (isZero)
      return copy && isZeroValue(copy);
    return !copy;
  }
  // Check for empty tensor materialization.
  if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
    return !isZero;
  // Last resort for zero alloc: the whole value is zero.
  return isZero && isZeroValue(val);
}

// Helper to detect sampling operation.
static bool isSampling(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
      // Both scalar input arguments used exactly once.
      Value s1 = op.getBlock()->getArgument(0);
      Value s2 = op.getBlock()->getArgument(1);
      return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
             (def->getOperand(1) == s1 && def->getOperand(0) == s2);
    }
  }
  return false;
}

// Helper to detect chain of multiplications that do not involve x.
static bool isMulChain(Value val, Value x) {
  if (auto arg = dyn_cast<BlockArgument>(val))
    return arg != x;
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
      return isMulChain(def->getOperand(0), x) &&
             isMulChain(def->getOperand(1), x);
  }
  return false;
}

// Helper to detect x = x + <multiplications>.
static bool isSumOfMul(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
      Value x = op.getBlock()->getArguments().back();
      return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
             (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
    }
  }
  return false;
}

// Helper to detect direct yield of a zero value.
static bool isZeroYield(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
    if (arg.getOwner()->getParentOp() == op) {
      return isZeroValue(op->getOperand(arg.getArgNumber()));
    }
  }
  return isZeroValue(yieldOp.getOperand(0));
}

/// Populates given sizes array from type (for static sizes) and from
/// the tensor (for dynamic sizes).
static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
                           Location loc, ShapedType stp, Value tensor) {
  for (const auto &d : enumerate(stp.getShape())) {
    Value dim;
    if (d.value() == ShapedType::kDynamic)
      dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
    else
      dim = constantIndex(builder, loc, d.value());
    sizes.push_back(dim);
  }
}

static RankedTensorType getBufferType(const SparseTensorType &stt,
                                      bool needTmpCOO) {
  return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
                    : stt.getRankedTensorType();
}

/// Collects the dynamic dimension sizes for `tp` with the assumption that
/// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
/// sizes to dynSizes.
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
                            SmallVectorImpl<Value> &dynSizes) {
  for (const auto &d : enumerate(tp.getShape())) {
    if (d.value() == ShapedType::kDynamic)
      dynSizes.push_back(sizes[d.index()]);
  }
}

static LogicalResult genForeachOnSparseConstant(ForeachOp op,
                                                RewriterBase &rewriter,
                                                SparseElementsAttr attr) {
  auto loc = op.getLoc();
  SmallVector<Value> reduc = op.getInitArgs();

  // Foreach on constant.
  foreachInSparseConstant(
      rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
      [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
        SmallVector<Value> args;
        args.append(cvs.begin(), cvs.end());
        args.push_back(v);
        args.append(reduc);
        // Clones the foreach op to get a copy of the loop body.
        auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
        assert(args.size() == cloned.getBody()->getNumArguments());
        Operation *yield = cloned.getBody()->getTerminator();
        rewriter.inlineBlockBefore(cloned.getBody(), op, args);
        // clean up
        rewriter.eraseOp(cloned);
        reduc = yield->getOperands();
        rewriter.eraseOp(yield);
      });

  rewriter.replaceOp(op, reduc);
  return success();
}

/// Populates the given sizes array for concatenation from types (for static
/// sizes) and from the source tensors (for dynamic sizes).
static void concatSizesFromInputs(OpBuilder &builder,
                                  SmallVectorImpl<Value> &sizes, Location loc,
                                  ShapedType dstTp, ValueRange srcs,
                                  unsigned dim) {
  auto dstShape = dstTp.getShape();
  sizesFromSrc(builder, sizes, loc, srcs[0]);

  // Sum up on the `dim` if the dimension is dynamic.
  if (dstShape[dim] != ShapedType::kDynamic) {
    // Faithfully take the static size.
    sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
  } else {
    // Else, compute the shape dynamically.
    for (const auto &src : srcs.drop_front()) {
      Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim);
      // Sum up all the sizes.
      sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
    }
  }
}

//===---------------------------------------------------------------------===//
// The actual sparse tensor rewriting rules.
//===---------------------------------------------------------------------===//

namespace {

/// TODO: move it to tensor dialect instead.
///
/// Fold `tensor.concat` and `tensor.extract_slice`
///
/// %concat = tensor.concat dim(2) %t0, %t1
///   : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
/// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
/// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
///
/// Becomes
///
/// %extract0, %extract1 = %t0, %t1
struct FuseExtractSliceWithConcat
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
    if (!concatOp)
      return failure();

    Location loc = extractOp.getLoc();
    int64_t dim = concatOp.getDim();
    int64_t rank = extractOp.getResultType().getRank();

    SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));

    // Compute the partial sums for the slice offsets.
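    // (Illustrative: for inputs with sizes s0, s1, s2 along `dim`, the
    // candidate slice offsets computed below are 0, s0, and s0 + s1.)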
    AffineExpr sum = rewriter.getAffineDimExpr(0);
    SmallVector<AffineExpr> partialSums = {sum};
    SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
    for (auto [idx, input] :
         llvm::enumerate(concatOp.getInputs().drop_back())) {
      sum = sum + rewriter.getAffineDimExpr(idx + 1);
      partialSums.push_back(sum);
      offsetStrides.push_back(
          rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
    }
    auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
                                        partialSums, rewriter.getContext());
    SmallVector<OpFoldResult> dimOffsets =
        affine::makeComposedFoldedMultiResultAffineApply(
            rewriter, loc, partialSumMap, offsetStrides);

    auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
      for (auto [l, r] : llvm::zip(lhs, rhs)) {
        std::optional<int64_t> staticVal = getConstantIntValue(l);
        if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
          return false;
      }
      return lhs.size() == rhs.size();
    };

    for (auto [i, input, offset] :
         llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
      SmallVector<OpFoldResult> srcSizes =
          tensor::getMixedSizes(rewriter, loc, input);
      srcOffsets[dim] = offset;

      SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
      SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
      SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();

      if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
          allEqual(srcStrides, dstStrides)) {
        Value operand = concatOp.getOperand(i);
        if (operand.getType() == extractOp.getResultType())
          rewriter.replaceOp(extractOp, operand);
        break;
      }
    }

    return success();
  }
};

/// Rewriting rule that fuses sparse_tensor.convert into producer.
struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ConvertOp op,
                                PatternRewriter &rewriter) const override {
    auto producer = op.getSource().getDefiningOp<GenericOp>();
    if (!producer || producer.getDpsInits().size() != 1 ||
        !isMaterializing(producer.getDpsInitOperand(0), false) ||
        !producer.getResult(0).hasOneUse()) {
      return failure();
    }
    // Clone the materialization operation, but update the result to sparse.
    rewriter.setInsertionPoint(producer);
    Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
    Operation *cloned = rewriter.clone(*init);
    cloned->getResult(0).setType(op.getResult().getType());

    rewriter.modifyOpInPlace(producer, [&]() {
      producer.getDpsInitsMutable().assign(cloned->getResults());
      producer.getResult(0).setType(op.getResult().getType());
    });

    rewriter.replaceAllOpUsesWith(op, producer);
    op->erase();

    return success();
  }
};

/// Rewriting rule that converts direct yield of zero with initial allocation.
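/// For example (illustrative sketch, types abbreviated), a kernel that only
/// yields the constant zero into a freshly materialized sparse output
///
///   %0 = bufferization.alloc_tensor() : tensor<?x?xf64, #CSR>
///   %1 = linalg.generic ... outs(%0 : ...) { ... linalg.yield %cst_0 }
///
/// folds to just reusing %0, since the empty sparse tensor already
/// represents an all-zero tensor.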
struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
        !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
        !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
      return failure();
    auto outputType = getRankedTensorType(op.getResult(0));
    // Yielding zero on newly materialized sparse tensor can be
    // optimized directly (regardless of dynamic or static size).
    if (getSparseTensorEncoding(outputType)) {
      rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
      return success();
    }
    // Use static zero value directly instead of materialization.
    if (!outputType.hasStaticShape())
      return failure();
    Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
    rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
    rewriter.eraseOp(def);
    return success();
  }
};

/// Rewriting rule that converts two kernels:
///
///      T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
///      X(i,j) = S(i,j) * T(i,j)
///
/// into a single kernel, using distributive law:
///
///      X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
///
/// This kind of fusion (merging two ops into one but using arithmetic
/// equalities that may not hold for floating-point computations) would
/// be undesirable in the dense case, since we distribute the multiplication
/// into the reduction loop. However, for sparse sampling tensor S, such
/// a fusion may actually reduce the asymptotic complexity of the kernel,
/// since intermediate results may be nullified.
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Check consumer.
    if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
        op.getNumResults() != 1 ||
        op.getNumParallelLoops() != op.getNumLoops() ||
        !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
        !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
        !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
      return failure();
    // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
    // operand can be sparse or dense, since the point of this rewriting rule
    // is detecting a situation in which *more* sparsity is introduced into
    // a computation, be it already sparse or still dense.
    unsigned other = 0;
    if (isSparseTensor(op.getDpsInputOperand(0)))
      other = 1;
    else if (!isSparseTensor(op.getDpsInputOperand(1)))
      return failure();
    // Check producer.
    auto prod = dyn_cast_or_null<GenericOp>(
        op.getDpsInputOperand(other)->get().getDefiningOp());
    if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
        !prod.getResult(0).hasOneUse())
      return failure();
    // Sampling consumer and sum of multiplication chain producer.
    if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
        !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
        !isSampling(op) || !isSumOfMul(prod))
      return failure();
    // Modify operand structure of producer and consumer.
    Location loc = prod.getLoc();
    SmallVector<Value> inputOps = prod.getInputs();
    SmallVector<Value> outputOps = op.getOutputs();
    SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
    inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
    fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
    // Fuse producer and consumer into a new generic op.
    auto fusedOp = rewriter.create<GenericOp>(
        loc, op.getResult(0).getType(), inputOps, outputOps,
        rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
        /*doc=*/nullptr, /*library_call=*/nullptr);
    Block &prodBlock = prod.getRegion().front();
    Block &consBlock = op.getRegion().front();
    IRMapping mapper;
    Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
    unsigned num = prodBlock.getNumArguments();
    for (unsigned i = 0; i < num - 1; i++)
      addArg(mapper, fusedBlock, prodBlock.getArgument(i));
    addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
    addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
    // Clone bodies of the producer and consumer in new evaluation order.
    auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
    auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
    Value last;
    for (auto &op : prodBlock.without_terminator())
      if (&op != acc) {
        last = op.getResult(0);
        rewriter.clone(op, mapper);
      }
    mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
    mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
    last = rewriter.clone(*acc, mapper)->getResult(0);
    rewriter.create<linalg::YieldOp>(loc, last);
    // Force initial value on merged allocation for dense outputs.
    // TODO: deal with non alloc tensor here one day
    if (!getSparseTensorEncoding(op.getResult(0).getType())) {
      Value init = prod.getDpsInitOperand(0)
                       ->get()
                       .getDefiningOp<AllocTensorOp>()
                       .getCopy();
      AllocTensorOp a =
          op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
      rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
    }
    // Replace consumer with fused operation. Old producer
    // and consumer ops will be removed by DCE.
    rewriter.replaceOp(op, fusedOp->getResults());
    return success();
  }

private:
  // Helper to add argument and record the mapping.
  static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
    mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
  }
};

// Fuse a tensor cast into producing operation. Note that a tensor.cast
// should really not be used to convert between sparse encodings. Since
// the pattern currently appears as a result of some prior rewriting
// we make an attempt to repair very obvious cases.
// TODO: audit the pure tensor dialect rewriting rules
struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
public:
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    Type srcType = op.getSource().getType();
    Type dstType = op.getDest().getType();
    // A nop cast simply folds away.
    if (srcType == dstType) {
      rewriter.replaceOp(op, op->getResults());
      return success();
    }
    // See if a sparsity changing cast can be fused into producer.
    if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
      if (Operation *def = op.getSource().getDefiningOp()) {
        if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
          rewriter.modifyOpInPlace(def, [&]() {
            def->getResult(0).setType(op->getResultTypes()[0]);
          });
          rewriter.replaceOp(op, def->getResult(0));
          return success();
        }
      }
    }
    // Repair tensor casts with at least one sparse operand into the
    // the properly supported sparse_tensor.convert.
    if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
      rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
      return success();
    }
    // Fail otherwise.
    return failure();
  }
};

/// Rewrites a sequence of operations for sparse tensor selections into
/// semi-ring operations such that they can be compiled correctly by the
/// sparsifier. E.g., transforming the following sequence
///
/// %sel = arith.select %cond, %sp1, %sp2
///
/// to
///
/// %sel = binary %sp1, %sp2:
///          both  (%l, %r) {yield select %cond, %l, %r}
///          left  (%l)     {yield select %cond, %l, 0}
///          right (%r)     {yield select %cond, 0, %r}
///
/// TODO: We require that the tensor used for extracting conditions to be dense
/// to sparsify the code. To support a sparse condition tensor, we need a
/// ternary operation.
struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Rejects non-sparse kernels.
    if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
      return failure();

    Location loc = op.getLoc();
    SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
    for (Operation &inst : *op.getBody()) {
      // Matches pattern.
      auto matched = isRewritablePattern(op, &inst);
      if (!matched.has_value())
        continue;

      rewriter.setInsertionPoint(&inst);
      auto [c, t, f] = matched.value();
      assert(t.getType() == f.getType());
      auto selTp = t.getType();
      auto c0 = constantZero(rewriter, loc, selTp);
      auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
      // Initializes all the blocks.
      rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
                           {t.getLoc(), f.getLoc()});
      rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
      rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());

      for (auto *r : binOp.getRegions()) {
        Block *b = &r->front();
        rewriter.setInsertionPointToStart(b);

        IRMapping irMap;
        // Clones the cmp operations into the region to make the binary op
        // admissible.
        Value newC = c;
        if (auto *def = c.getDefiningOp())
          newC = rewriter.clone(*def, irMap)->getResult(0);

        irMap.map(c, newC);
        if (r == &binOp.getLeftRegion()) {
          irMap.map(t, b->getArgument(0));
          irMap.map(f, c0);
        } else if (r == &binOp.getRightRegion()) {
          irMap.map(t, c0);
          irMap.map(f, b->getArgument(0));
        } else {
          irMap.map(t, b->getArgument(0));
          irMap.map(f, b->getArgument(1));
        }
        auto y = rewriter.clone(inst, irMap)->getResult(0);
        rewriter.create<sparse_tensor::YieldOp>(loc, y);
      }

      // We successfully rewrote an operation. We cannot do the replacement
      // here because it would invalidate the iterator for the current loop
      // over the instructions.
      semiRings.emplace_back(&inst, binOp);
    }

    // Finalizes the replacement.
    for (auto [sel, semi] : semiRings)
      rewriter.replaceOp(sel, semi->getResults());

    return success(!semiRings.empty());
  }

private:
  static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
  isRewritablePattern(GenericOp op, Operation *v) {
    auto sel = dyn_cast<arith::SelectOp>(v);
    if (!sel)
      return std::nullopt;

    auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
    auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
    // TODO: For simplicity, we only handle cases where both the true/false
    // values are directly loaded from the input tensor. We can probably admit
    // more cases in theory.
    if (!tVal || !fVal)
      return std::nullopt;

    // Helper lambda to determine whether the value is loaded from a dense
    // input or is a loop invariant.
    auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
      if (auto bArg = dyn_cast<BlockArgument>(v);
          bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
        return true;
      // If the value is defined outside the loop, it is a loop invariant.
      return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
    };

    // If the condition value is loaded directly from a dense tensor or is a
    // loop invariant, we can sparsify the kernel.
    auto cond = sel.getCondition();
    if (isValFromDenseInputOrInvariant(cond))
      return std::make_tuple(cond, tVal, fVal);

    Value cmpL, cmpR;
    if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
                                               matchers::m_Any(&cmpR))) ||
        matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
                                               matchers::m_Any(&cmpR)))) {
      // TODO: we can do it recursively to check whether all the leaf values are
      // loaded from dense tensors or are loop invariants.
      if (isValFromDenseInputOrInvariant(cmpL) ||
          isValFromDenseInputOrInvariant(cmpR))
        return std::make_tuple(cond, tVal, fVal);
    }

    return std::nullopt;
  };
};

/// Rewrites a sparse reduction that would not sparsify directly since
/// doing so would only iterate over the stored elements, ignoring the
/// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
/// (note that reductions like add/sub/or/xor can directly be sparsified
/// since the implicit zeros do not contribute to the final result).
/// Note that prod/and are still included since, even though they often
/// are nullified in sparse data, they may still occur for special
/// situations in which e.g. some rows in a sparse matrix are fully
/// dense. For min/max, including the implicit zeros is a much more
/// common situation.
///
/// TODO: this essentially "densifies" the operation; we want to implement
///       this much more efficiently by performing the reduction over the
///       stored values, and feed in the zero once if there were *any*
///       implicit zeros as well; but for now, at least we provide
///       the functionality
///
struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Reject non-reductions.
    if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
        op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
      return failure();
    auto *inp = op.getDpsInputOperand(0);
    auto *init = op.getDpsInitOperand(0);
    if (!isSparseTensor(inp))
      return failure();
    // Look for direct x = x OP y for semi-ring ready reductions.
    auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
                    .getOperand(0)
                    .getDefiningOp();
    if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
             arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
             arith::MaxUIOp>(red))
      return failure();
    Value s0 = op.getBlock()->getArgument(0);
    Value s1 = op.getBlock()->getArgument(1);
    if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
        (red->getOperand(0) != s1 || red->getOperand(1) != s0))
      return failure();
    // Identity.
    Location loc = op.getLoc();
    Value identity =
        rewriter.create<tensor::ExtractOp>(loc, init->get(), ValueRange());
    // Unary {
    //    present -> value
    //    absent  -> zero.
    // }
    Type rtp = s0.getType();
    rewriter.setInsertionPointToStart(&op.getRegion().front());
    auto semiring = rewriter.create<sparse_tensor::UnaryOp>(loc, rtp, s0);
    Block *present =
        rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
    rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
    rewriter.create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
    rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
    rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(rtp));
    rewriter.create<sparse_tensor::YieldOp>(loc, zero);
    rewriter.setInsertionPointAfter(semiring);
    // CustomReduce {
    //    x = x REDUC y, identity
    // }
    auto custom = rewriter.create<sparse_tensor::ReduceOp>(
        loc, rtp, semiring.getResult(), s1, identity);
    Block *region =
        rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
    rewriter.setInsertionPointToStart(&custom.getRegion().front());
    IRMapping irMap;
    irMap.map(red->getOperand(0), region->getArgument(0));
    irMap.map(red->getOperand(1), region->getArgument(1));
    auto *cloned = rewriter.clone(*red, irMap);
    rewriter.create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
    rewriter.setInsertionPointAfter(custom);
    rewriter.replaceOp(red, custom.getResult());
    return success();
  }
};

/// Sparse rewriting rule for the print operator. This operation is mainly used
/// for debugging and testing. As such, it lowers to the vector.print operation
/// which only requires very lightweight runtime support.
struct PrintRewriter : public OpRewritePattern<PrintOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(PrintOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto tensor = op.getTensor();
    auto stt = getSparseTensorType(tensor);
    // Header with NSE.
    auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
    rewriter.create<vector::PrintOp>(
        loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
    rewriter.create<vector::PrintOp>(loc, nse);
    // Print run-time contents for dim/lvl sizes.
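    // (The emitted output looks roughly like "dim = ( 4, 8 )" for a 4x8
    // tensor, since printSizes below brackets the comma-separated sizes.)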
    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("dim = "));
    printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("lvl = "));
    printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
    // Use the "codegen" foreach loop construct to iterate over
    // all typical sparse tensor components for printing.
    foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
                                            &stt](Type, FieldIndex,
                                                  SparseTensorFieldKind kind,
                                                  Level l, LevelType) {
      switch (kind) {
      case SparseTensorFieldKind::StorageSpec: {
        break;
      }
      case SparseTensorFieldKind::PosMemRef: {
        auto lvl = constantIndex(rewriter, loc, l);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
        rewriter.create<vector::PrintOp>(
            loc, lvl, vector::PrintPunctuation::NoPunctuation);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
        auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
        printContents(rewriter, loc, pos);
        break;
      }
      case SparseTensorFieldKind::CrdMemRef: {
        auto lvl = constantIndex(rewriter, loc, l);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
        rewriter.create<vector::PrintOp>(
            loc, lvl, vector::PrintPunctuation::NoPunctuation);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
        Value crd = nullptr;
        // For COO AoS storage, we want to print a single, linear view of
        // the full coordinate storage at this level. For any other storage,
        // we show the coordinate storage for every individual level.
        if (stt.getAoSCOOStart() == l)
          crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
        else
          crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
        printContents(rewriter, loc, crd);
        break;
      }
      case SparseTensorFieldKind::ValMemRef: {
        rewriter.create<vector::PrintOp>(loc,
                                         rewriter.getStringAttr("values : "));
        auto val = rewriter.create<ToValuesOp>(loc, tensor);
        printContents(rewriter, loc, val);
        break;
      }
      }
      return true;
    });
    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
    rewriter.eraseOp(op);
    return success();
  }

private:
  // Helper to print contents of a single memref. For "push_back" vectors,
  // we assume that the previous getters for pos/crd/val have added a
  // slice-to-size view to make sure we just print the size and not the
  // full capacity.
  //
  // Generates code to print (1-dim or higher):
  //    ( a0, a1, ... )
  static void printContents(PatternRewriter &rewriter, Location loc,
                            Value vec) {
    auto shape = cast<ShapedType>(vec.getType()).getShape();
    SmallVector<Value> idxs;
    printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
  }

  // Helper to the helper.
  static void printContentsLevel(PatternRewriter &rewriter, Location loc,
                                 Value vec, unsigned i, ArrayRef<int64_t> shape,
                                 SmallVectorImpl<Value> &idxs) {
    // Open bracket.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
    // Generate for loop.
    auto zero = constantIndex(rewriter, loc, 0);
    auto index = constantIndex(rewriter, loc, i);
    auto size = rewriter.create<memref::DimOp>(loc, vec, index);
    auto step = constantIndex(rewriter, loc, 1);
    auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
    idxs.push_back(forOp.getInductionVar());
    rewriter.setInsertionPointToStart(forOp.getBody());
    if (i < shape.size() - 1) {
      // Enter deeper loop nest.
      printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
    } else {
      // Actual contents printing.
      auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
      if (llvm::isa<ComplexType>(val.getType())) {
        // Since the vector dialect does not support complex types in any op,
        // we split those into (real, imag) pairs here.
        Value real = rewriter.create<complex::ReOp>(loc, val);
        Value imag = rewriter.create<complex::ImOp>(loc, val);
        rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
        rewriter.create<vector::PrintOp>(loc, real,
                                         vector::PrintPunctuation::Comma);
        rewriter.create<vector::PrintOp>(loc, imag,
                                         vector::PrintPunctuation::Close);
      } else {
        rewriter.create<vector::PrintOp>(
            loc, val, vector::PrintPunctuation::NoPunctuation);
      }
      // Terminating comma (except at end).
      auto bound = rewriter.create<arith::AddIOp>(loc, idxs.back(), step);
      Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                                  bound, size);
      scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
    }
    idxs.pop_back();
    rewriter.setInsertionPointAfter(forOp);
    // Close bracket.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
  }

  // Helper method to print run-time lvl/dim sizes.
  static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
                         unsigned size, bool isDim) {
    // Open bracket.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
    // Print unrolled contents (dimop requires constant value).
    for (unsigned i = 0; i < size; i++) {
      auto idx = constantIndex(rewriter, loc, i);
      Value val;
      if (isDim)
        val = rewriter.create<tensor::DimOp>(loc, tensor, idx);
      else
        val = rewriter.create<LvlOp>(loc, tensor, idx);
      rewriter.create<vector::PrintOp>(
          loc, val,
          i != size - 1 ? vector::PrintPunctuation::Comma
                        : vector::PrintPunctuation::NoPunctuation);
    }
    // Close bracket and end of line.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
  }
};

/// Sparse rewriting rule for sparse-to-sparse reshape operator.
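/// Roughly, the tensor.reshape is expanded into a sparse_tensor.foreach that
/// remaps the coordinates of every stored entry (collapsing them into a
/// linearized index and expanding that index to the destination shape) while
/// inserting into a freshly allocated buffer, followed by an optional
/// sparse_tensor.convert when the source and destination orders differ.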
struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
public:
  using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value srcTensor = op.getSource();
    const auto srcTp = tryGetSparseTensorType(srcTensor);
    const auto dstTp = tryGetSparseTensorType(op.getResult());
    if (!srcTp || !dstTp)
      return failure();

    if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
        !dstTp->hasStaticDimShape())
      return failure();

    SmallVector<Value> srcSizes;
    sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
    SmallVector<Value> dstSizes;
    for (Dimension d : dstTp->getDimShape())
      dstSizes.push_back(constantIndex(rewriter, loc, d));

    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
    // Only need an unordered COO buffer if input and output are not sorted
    // in the same way.
    Type bufferTp = getBufferType(
        dstTp->withoutDimToLvl(),
        !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
    SmallVector<Value> dynSizes;
    Value buffer = rewriter
                       .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
                                              nnz, Attribute())
                       .getResult();

    // Convert src coordinates to dst coordinates by first collapsing them to
    // 1-D and then expanding them to match the rank of the destination tensor.
    // Implemented as follows:
    //   foreach srcCoords %srcTensor
    //     collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
    //     expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
    //     insert expandedCoords, %buffer
    //
    // followed by an optional
    //   %t = sparse_tensor.cast %tmp
    // depending on whether the input/output are sorted in the same way.
    const auto encSrc = srcTp->getEncoding();
    ForeachOp foreachOp = rewriter.create<ForeachOp>(
        loc, srcTensor, buffer,
        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
            ValueRange reduc) {
          const Dimension srcRank = srcTp->getDimRank();
          SmallVector<Value> srcDcvs;
          srcDcvs.reserve(srcRank);
          for (Dimension d = 0; d < srcRank; d++) {
            Level lvl = toLvl(encSrc, d);
            srcDcvs.push_back(srcLcvs[lvl]);
          }

          Value collapseSize = constantIndex(builder, loc, 1);
          for (Dimension d = 0; d < srcRank; d++)
            collapseSize =
                builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
          SmallVector<Value, 1> collapsedSizes = {collapseSize};

          ReassociationIndices collapseIdx;
          for (Dimension i = 0; i < srcRank; i++)
            collapseIdx.push_back(i);
          SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
          SmallVector<Value, 1> collapsedDcvs;
          reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
                     collapsedSizes, collapsedDcvs);

          ReassociationIndices expandIdx;
          for (Dimension i = 0; i < dstTp->getDimRank(); i++)
            expandIdx.push_back(i);
          SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
          SmallVector<Value> dstDcvs;
          reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
                     dstSizes, dstDcvs);

          auto t =
              builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
          builder.create<sparse_tensor::YieldOp>(loc, t);
        });

    Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
    if (bufferTp != *dstTp) {
      auto dstRTT = dstTp->getRankedTensorType();
      Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
      rewriter.create<DeallocTensorOp>(loc, t);
      t = converted;
    }
    rewriter.replaceOp(op, t);
    return success();
  }
};

/// Sparse rewriting rule for sparse-to-sparse reshape operator.
template <typename ReshapeOp>
struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value srcTensor = op.getSrc();
    const auto srcTp = getSparseTensorType(srcTensor);
    const auto dstTp = getSparseTensorType(op.getResult());
    if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
      return failure();

    // Generate code to represent the static dimension constants or compute
    // the dynamic dimension values.
    SmallVector<Value> srcSizes;
    sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
    SmallVector<Value> dstSizes;
    SmallVector<Value> dstDynSizes;
    if (dstTp.hasStaticDimShape()) {
      for (Dimension d : dstTp.getDimShape())
        dstSizes.push_back(constantIndex(rewriter, loc, d));
    } else {
      ArrayRef<Size> dstShape = dstTp.getDimShape();
      genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
                         op.getReassociationIndices());
      for (auto [idx, shape] : llvm::enumerate(dstShape)) {
        if (shape == ShapedType::kDynamic)
          dstDynSizes.push_back(dstSizes[idx]);
      }
    }
    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
    // Only need an unordered COO buffer if input and output are not sorted
    // in the same way.
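    // (For instance, when both sides use identity, all-ordered encodings,
    // the insertions performed by the foreach below already arrive in an
    // order that is valid for the destination, so the destination type can
    // be used directly as the buffer type.)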
    Type bufferTp = getBufferType(
        dstTp.withoutDimToLvl(),
        !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());

    Value buffer =
        rewriter
            .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
                                   /*sizeHint=*/nnz, Attribute())
            .getResult();

    // Implement the sparse2sparse reshape as follows:
    //   foreach srcCoords %srcTensor
    //     insert reshapeCvs(srcCoords), %buffer
    //
    // followed by an optional
    //   %t = sparse_tensor.cast %tmp
    // depending on whether the input/output are sorted in the same way.
    const auto encSrc = srcTp.getEncoding();
    ForeachOp foreachOp = rewriter.create<ForeachOp>(
        loc, srcTensor, buffer,
        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
            ValueRange reduc) {
          const Dimension dimRank = srcTp.getDimRank();
          SmallVector<Value> srcDcvs;
          srcDcvs.reserve(dimRank);
          for (Dimension d = 0; d < dimRank; d++) {
            Level lvl = toLvl(encSrc, d);
            srcDcvs.push_back(srcLcvs[lvl]);
          }
          SmallVector<Value> dstDcvs;
          reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
                     srcDcvs, dstSizes, dstDcvs);
          auto t =
              builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
          builder.create<sparse_tensor::YieldOp>(loc, t);
        });

    Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
    if (bufferTp != dstTp) {
      auto dstRTT = dstTp.getRankedTensorType();
      Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
      rewriter.create<DeallocTensorOp>(loc, t);
      t = converted;
    }
    rewriter.replaceOp(op, t);
    return success();
  }
};

/// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
/// operator.
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto encDst = getSparseTensorEncoding(op.getResult().getType());
    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
    // Since a pure dense expansion is very cheap (change of view), for
    // a sparse2dense or dense2sparse, we can simply unfuse a sparse
    // conversion from the reshape operation itself.
    // All other cases are handled elsewhere.
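    // (For example, a dense2sparse expand_shape becomes a dense expand_shape
    // followed by a sparse_tensor.convert of the expanded result, while the
    // sparse2dense case converts to dense first and then reshapes.)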
    if (encDst && encSrc) {
      return failure();
    }
    if (encSrc) {
      auto rtp = getRankedTensorType(op.getSrc());
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
      rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
      return success();
    }
    if (encDst) {
      auto rtp = getRankedTensorType(op.getResult());
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      ReshapeOp reshape;
      if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
        reshape = rewriter.create<ReshapeOp>(
            loc, denseTp, op.getSrc(), op.getReassociation(),
            op.getOutputShape(), op.getStaticOutputShape());
      } else {
        reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
                                             op.getReassociation());
      }
      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
      rewriter.replaceOp(op, convert);
      return success();
    }
    return failure();
  }
};

// A trivial wrapper to help generate different operations for dense/sparse
// tensors.
struct TensorLike {
  TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
             ValueRange sizes) {
    SmallVector<Value> dynSzs;
    getDynamicSizes(rtt, sizes, dynSzs);

    val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
    if (!isSparse()) {
      Value c0 = constantZero(builder, loc, rtt.getElementType());
      val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
    }
  }

  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
    val = builder.create<tensor::InsertOp>(loc, v, val, crds);
  }

  Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
    if (isSparse())
      return builder.create<LoadOp>(loc, val, true);
    return val;
  }

  bool isSparse() const {
    return getSparseTensorEncoding(val.getType()) != nullptr;
  }

  Value val;
};

struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::DimOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> dim = op.getConstantIndex();
    auto stt = tryGetSparseTensorType(op.getSource());
    if (!dim || !stt || !stt->hasEncoding())
      return failure();

    if (stt->isPermutation()) {
      rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
                                         toLvl(stt->getEncoding(), *dim));
      return success();
    }

    // Non-permutation dim2lvl/lvl2dim maps.
    // Compute as follows:
    //   affine.apply #map (l0 - 1, l1 - 1, ...) + 1
    // Note that it is not the most efficient way (but a more general one) for
    // the lvl to dim translation, e.g., for BSR, the dimension size can be
    // computed simply by lvl_size * block_size.
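    // (Illustrative sketch: assuming a 2x2 BSR encoding whose lvl2dim map is
    // (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 2 + l3), dimension size 0 is
    // obtained below as ((lvl_size(0) - 1) * 2 + (lvl_size(2) - 1)) + 1.)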
    Location loc = op.getLoc();
    SmallVector<Value> maxLvlCrds;
    for (Level l = 0; l < stt->getLvlRank(); l++) {
      Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
      Value maxLvlCrd = rewriter.create<arith::SubIOp>(
          loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
      maxLvlCrds.push_back(maxLvlCrd);
    }

    AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
    Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
        op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
        maxLvlCrds);

    Value dimSz = rewriter.create<arith::AddIOp>(
        loc, maxDimCrd, constantOne(rewriter, loc, rewriter.getIndexType()));
    rewriter.replaceOp(op, dimSz);
    return success();
  }
};

struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter &rewriter) const override {
    if (op.needsExtraSort())
      op.emitError("ConcatenateOp not staged");

    const Location loc = op.getLoc();
    const auto dstTp = getSparseTensorType(op);
    const Dimension conDim = op.getDimension();
    SmallVector<Value> sizes;
    concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);

    // %t = concatenate %s1, %s2, %s3 {dim = 1}
    // ==>
    // if (isSparseDst)
    //   if (allDense)
    //     %tmp = bufferization.alloc_tensor dstTp
    //   else
    //     %tmp = bufferization.alloc_tensor : unordered COO
    // else
    //   %tmp = memref.alloc : dense tensor
    // foreach in %s1 : insert d0, d1, %tmp
    // foreach in %s2 : insert d0, d1 + size(s1), %tmp
    // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp

    TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
    Value offset = constantIndex(rewriter, loc, 0);
    Value iterArg = dstBuf.val;

    ForeachOp foreachOp;
    for (Value input : op.getInputs()) {
      // Builds a for op for each input tensor to append new values into the
      // output tensor.
      foreachOp = rewriter.create<ForeachOp>(
          loc, input, iterArg,
          [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
              ValueRange reduc) {
            SmallVector<Value> offDimCrd(dcvs);
            offDimCrd[conDim] =
                builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);

            // Enters foreach, updates the SSA chain.
            dstBuf.val = reduc.front();
            if (!dstTp.isAllDense()) {
              Value cond = genIsNonzero(builder, loc, v);
              auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                    /*else*/ true);
              builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
              builder.create<scf::YieldOp>(loc, dstBuf.val);

              builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
              dstBuf.insert(builder, loc, v, offDimCrd);
              builder.create<scf::YieldOp>(loc, dstBuf.val);

              // Exits the ifOp, update the sparse tensor SSA value.
              builder.setInsertionPointAfter(ifOp);
              dstBuf.val = ifOp.getResult(0);
            } else {
              dstBuf.insert(builder, loc, v, offDimCrd);
            }
            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
          });
      // Accumulates the offset. Note that only static-shaped inputs are allowed
      // by concatenate op verifier, which saves us from computing the offset
      // dynamically.
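      // (E.g., concatenating inputs of sizes 2, 3, and 4 along conDim uses
      // the offsets 0, 2, and 5, respectively.)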
      const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
      assert(!ShapedType::isDynamic(sz));
      offset = rewriter.create<arith::AddIOp>(loc, offset,
                                              constantIndex(rewriter, loc, sz));
      iterArg = foreachOp.getResult(0);
      dstBuf.val = iterArg;
    }

    dstBuf.val = iterArg;
    Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
    rewriter.replaceOp(op, ret);
    return success();
  }
};

struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ConvertOp op,
                                PatternRewriter &rewriter) const override {
    if (op.needsExtraSort())
      return op.emitError("ConvertOp not staged.");

    // TODO: Maybe we want a different operation for this too.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (encDst && encSrc && !encSrc.isSlice() &&
        encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
      // Trivial tensor conversion and simple element type conversion is handled
      // in codegen.
      return failure();
    }

    Location loc = op.getLoc();
    Value src = op.getSource();

    SparseTensorType srcStt = getSparseTensorType(op.getSource());
    SparseTensorType dstStt = getSparseTensorType(op.getDest());

    bool fromSparseConst = false;
    if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
      if (dyn_cast<SparseElementsAttr>(constOp.getValue()))
        fromSparseConst = true;

    const AffineMapAttr foreachOrder =
        (!dstStt.isIdentity() && fromSparseConst)
            ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
            : nullptr;

    bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;

    SmallVector<Value> sizes;
    sizesFromSrc(rewriter, sizes, loc, src);
    ValueRange vs;
    TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);

    auto foreachOp = rewriter.create<ForeachOp>(
        loc, src, dstBuf.val, foreachOrder,
        [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
            ValueRange reduc) {
          // Enters the loop, update the SSA value for insertion chain.
          dstBuf.val = reduc.front();
          if (!skipZeroCheck) {
            Value cond = genIsNonzero(builder, loc, v);
            auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                  /*else*/ true);
            builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
            builder.create<scf::YieldOp>(loc, dstBuf.val);

            builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
            dstBuf.insert(builder, loc, v, dcvs);
            builder.create<scf::YieldOp>(loc, dstBuf.val);

            // Exits the ifOp, update the sparse tensor SSA value.
            builder.setInsertionPointAfter(ifOp);
            dstBuf.val = ifOp.getResult(0);
          } else {
            dstBuf.insert(builder, loc, v, dcvs);
          }
          builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
        });

    rewriter.setInsertionPointAfter(foreachOp);

    // Exits the for loop, links the SSA chain.
    dstBuf.val = foreachOp.getResult(0);

    Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
    rewriter.replaceOp(op, ret);
    return success();
  }
};

struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(CrdTranslateOp op,
                                PatternRewriter &rewriter) const override {
    AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
                        ? op.getEncoder().getDimToLvl()
                        : op.getEncoder().getLvlToDim();

    SmallVector<Value> outCrds;
    for (AffineExpr result : map.getResults()) {
      // TODO: we should probably expand the affine map to IR using our own
      // rules, since affine.apply assumes signed values, while the coordinates
      // we provide must always be signless.
      Value trans = rewriter.create<affine::AffineApplyOp>(
          op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
          op.getInCrds());
      outCrds.push_back(trans);
    }
    rewriter.replaceOp(op, outCrds);
    return success();
  }
};

/// Sparse rewriting rule for the foreach operator.
struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ForeachOp op,
                                PatternRewriter &rewriter) const override {

    auto loc = op.getLoc();
    Value input = op.getTensor();
    SmallVector<Value> reduc = op.getInitArgs();
    const auto stt = getSparseTensorType(input);
    const Level lvlRank = stt.getLvlRank();

    // Special-case: for each over a sparse constant uses its own rewriting
    // rule.
    if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
        return genForeachOnSparseConstant(op, rewriter, attr);
      }
    }

    // Otherwise, use loop emitter to generate loops.
    const auto enc = stt.getEncoding();

    // 1. Generates loop for the sparse input.
    LoopEmitter loopEmitter(
        ValueRange{input},
        StringAttr::get(getContext(), ForeachOp::getOperationName()));
    loopEmitter.initializeLoopEmit(rewriter, loc);
    for (Level l = 0; l < lvlRank; l++) {
      // TODO: provide utility function for loop sequences that only contains
      // one for loop?
      const SmallVector<TensorLevel, 1> tidLvls{
          loopEmitter.makeTensorLevel(0, l)};
      loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
      // Note that reduc will be taken care of by loop emitter and get updated
      // in place.
      loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
                                                    reduc);
    }

    SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
    if (op.getOrder()) {
      // TODO: Support it so that we can do direct conversion from CSR->BSR.
      llvm_unreachable(
          "Level order not yet implemented on non-constant input tensors.");
    }

    Value vals = loopEmitter.getValBuffer()[0];
    SmallVector<Value> pos = loopEmitter.getValPosits(0);
    // Loads the value from sparse tensor using position-index;
    // loads the value from dense tensor using coords.
    Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
                    : rewriter.create<memref::LoadOp>(loc, vals, lcvs);

    // 2. Inline the block in the foreach operator.
    Block *srcBlock = op.getBody();

    // Remap coordinates.
    SmallVector<Value> args =
        enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);

    // Remap value.
    args.push_back(val);
    // Remap reduction variables.
    args.append(reduc);

    // Remove sparse_tensor.yield.
    SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
    rewriter.eraseOp(srcBlock->getTerminator());

    Operation &last = rewriter.getBlock()->back();
    if (llvm::isa<scf::YieldOp>(last)) {
      // Because `scf.for` inserts an implicit yield op when there is no
      // reduction variable upon creation, we reset the insertion point such
      // that the block is inlined right before the yield op.
      rewriter.setInsertionPoint(&last);
    }

    rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
                               rewriter.getInsertionPoint(), args);
    rewriter.setInsertionPointToEnd(rewriter.getBlock());
    for (Level l = 0; l < lvlRank; l++) {
      // Link the reduction chain. Note that the loop emitter updates
      // reducValue in place.
      loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
      loopEmitter.exitCurrentLoopSeq(rewriter, loc);
    }

    // Replace the foreach operator with the value returned by the outermost
    // for loop.
    rewriter.replaceOp(op, reducValue);
    return success();
  }
};

/// Sparse rewriting rule for the new operator.
struct NewRewriter : public OpRewritePattern<NewOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(NewOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
      return failure();

    // Implement the NewOp as follows:
    //   %orderedCoo = sparse_tensor.new %filename
    //   %t = sparse_tensor.convert %orderedCoo
    // with enveloping reinterpret_map ops for non-permutations.
    RankedTensorType dstTp = stt.getRankedTensorType();
    RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
    Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
    Value convert = cooTensor;
    auto enc = stt.getEncoding();
    if (!stt.isPermutation()) { // demap coo, demap dstTp
      auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
      convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
      dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
    }
    convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
    if (!stt.isPermutation()) // remap to original enc
      convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
    rewriter.replaceOp(op, convert);

    // Release the temporary ordered COO tensor.
    rewriter.setInsertionPointAfterValue(convert);
    rewriter.create<DeallocTensorOp>(loc, cooTensor);

    return success();
  }
};

/// Sparse rewriting rule for the out operator.
struct OutRewriter : public OpRewritePattern<OutOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(OutOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Calculate NNZ.
    Value src = op.getTensor();
    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);

    // Allocate a temporary buffer for storing dimension-sizes/coordinates.
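    // A single index buffer of length `dimRank` suffices: it is first filled
    // with the dimension sizes for the metadata call and later reused (as
    // `dimCoords`) to stage the coordinates of each element written out.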
    const auto srcTp = getSparseTensorType(src);
    const Dimension dimRank = srcTp.getDimRank();
    Type indexTp = rewriter.getIndexType();
    Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);

    // Generate code to calculate dimension size values and store the values to
    // the buffer.
    SmallVector<Value> dims;
    sizesForTensor(rewriter, dims, loc, srcTp, src);
    for (Dimension d = 0; d < dimRank; d++) {
      rewriter.create<memref::StoreOp>(loc, dims[d], dimSizes,
                                       constantIndex(rewriter, loc, d));
    }

    // Create a sparse tensor writer and output the metadata.
    Type opaqueTp = getOpaquePointerType(rewriter);
    Value writer =
        createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
                       {op.getDest()}, EmitCInterface::Off)
            .getResult(0);
    Value rankValue = constantIndex(rewriter, loc, dimRank);
    createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
                   {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);

    Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords.
    Type eltTp = srcTp.getElementType();
    SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
                                    primaryTypeFunctionSuffix(eltTp)};
    Value value = genAllocaScalar(rewriter, loc, eltTp);
    ModuleOp module = op->getParentOfType<ModuleOp>();

    // For each element in the source tensor, output the element.
    rewriter.create<ForeachOp>(
        loc, src, std::nullopt,
        [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
            ValueRange reduc) {
          for (Dimension d = 0; d < dimRank; d++) {
            builder.create<memref::StoreOp>(loc, dcvs[d], dimCoords,
                                            constantIndex(builder, loc, d));
          }
          builder.create<memref::StoreOp>(loc, v, value);
          SmallVector<Value> operands{writer, rankValue, dimCoords, value};
          FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
                                         EmitCInterface::On);
          builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
          builder.create<sparse_tensor::YieldOp>(loc);
        });

    // Release the writer.
    createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
                   EmitCInterface::Off);

    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
  patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
               FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
      patterns.getContext());
}

void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
                                                   bool enableRT,
                                                   bool enableConvert) {
  patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
               ReshapeRewriter<tensor::CollapseShapeOp>,
               Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
               Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
               SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
      patterns.getContext());

  if (enableConvert)
    patterns.add<DirectConvertRewriter>(patterns.getContext());
  if (!enableRT)
    patterns.add<NewRewriter>(patterns.getContext());
}

void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
  // Run CrdTranslateRewriter later in the pipeline so that the operation can
  // be folded before lowering to affine.apply.
  patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
}
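//===---------------------------------------------------------------------===//
// Usage sketch (illustrative only; the pass class below is hypothetical and
// not defined in this file). Each populate* entry point above is typically
// driven by its own pass through the greedy pattern rewrite driver, e.g.:
//
//   void LowerForeachToSCFExamplePass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     populateLowerForeachToSCFPatterns(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }
//===---------------------------------------------------------------------===//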