xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (revision 129ade21bdad5f09206b773cd0591a9616ad0ca4)
1 //===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting rules that are specific to sparse tensors.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Utils/CodegenUtils.h"
14 #include "Utils/LoopEmitter.h"
15 
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
19 #include "mlir/Dialect/Linalg/IR/Linalg.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/IR/SCF.h"
23 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
24 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
25 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
26 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
27 #include "mlir/Dialect/Tensor/IR/Tensor.h"
28 #include "mlir/Dialect/Vector/IR/VectorOps.h"
29 #include "mlir/IR/AffineMap.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/Support/LLVM.h"
32 
33 using namespace mlir;
34 using namespace mlir::bufferization;
35 using namespace mlir::linalg;
36 using namespace mlir::sparse_tensor;
37 
38 //===---------------------------------------------------------------------===//
39 // Helper methods for the actual rewriting rules.
40 //===---------------------------------------------------------------------===//
41 
42 // Helper method to match any typed zero.
43 static bool isZeroValue(Value val) {
44   return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
45 }
46 
47 // Helper to detect a sparse tensor type operand.
48 static bool isSparseTensor(Value v) {
49   auto enc = getSparseTensorEncoding(v.getType());
50   return enc && !llvm::all_of(enc.getLvlTypes(),
51                               [](auto lt) { return lt == LevelFormat::Dense; });
52 }
53 static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
54 
55 // Helper method to find zero/uninitialized tensor materialization.
56 static bool isMaterializing(OpOperand *op, bool isZero) {
57   Value val = op->get();
58   // Check allocation, with zero alloc when required.
59   if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
60     Value copy = alloc.getCopy();
61     if (isZero)
62       return copy && isZeroValue(copy);
63     return !copy;
64   }
65   // Check for empty tensor materialization.
66   if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
67     return !isZero;
68   // Last resort for zero alloc: the whole value is zero.
69   return isZero && isZeroValue(val);
70 }
71 
72 // Helper to detect sampling operation.
73 static bool isSampling(GenericOp op) {
74   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
75   if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
76     if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
77       // Both scalar input arguments used exactly once.
78       Value s1 = op.getBlock()->getArgument(0);
79       Value s2 = op.getBlock()->getArgument(1);
80       return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
81              (def->getOperand(1) == s1 && def->getOperand(0) == s2);
82     }
83   }
84   return false;
85 }
86 
87 // Helper to detect chain of multiplications that do not involve x.
88 static bool isMulChain(Value val, Value x) {
89   if (auto arg = dyn_cast<BlockArgument>(val))
90     return arg != x;
91   if (auto *def = val.getDefiningOp()) {
92     if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
93       return isMulChain(def->getOperand(0), x) &&
94              isMulChain(def->getOperand(1), x);
95   }
96   return false;
97 }
98 
99 // Helper to detect x = x + <multiplications>.
100 static bool isSumOfMul(GenericOp op) {
101   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
102   if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
103     if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
104       Value x = op.getBlock()->getArguments().back();
105       return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
106              (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
107     }
108   }
109   return false;
110 }
111 
112 // Helper to detect direct yield of a zero value.
113 static bool isZeroYield(GenericOp op) {
114   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
115   if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
116     if (arg.getOwner()->getParentOp() == op) {
117       return isZeroValue(op->getOperand(arg.getArgNumber()));
118     }
119   }
120   return isZeroValue(yieldOp.getOperand(0));
121 }
122 
123 /// Populates given sizes array from type (for static sizes) and from
124 /// the tensor (for dynamic sizes).
125 static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
126                            Location loc, ShapedType stp, Value tensor) {
127   for (const auto &d : enumerate(stp.getShape())) {
128     Value dim;
129     if (d.value() == ShapedType::kDynamic)
130       dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
131     else
132       dim = constantIndex(builder, loc, d.value());
133     sizes.push_back(dim);
134   }
135 }
136 
137 static RankedTensorType getBufferType(const SparseTensorType &stt,
138                                       bool needTmpCOO) {
139   return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
140                     : stt.getRankedTensorType();
141 }
142 
143 /// Collects the dynamic dimension sizes for `tp` with the assumption that
144 /// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
145 /// sizes to dynSizes.
146 static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
147                             SmallVectorImpl<Value> &dynSizes) {
148   for (const auto &d : enumerate(tp.getShape())) {
149     if (d.value() == ShapedType::kDynamic)
150       dynSizes.push_back(sizes[d.index()]);
151   }
152 }
153 
154 static LogicalResult genForeachOnSparseConstant(ForeachOp op,
155                                                 RewriterBase &rewriter,
156                                                 SparseElementsAttr attr) {
157   auto loc = op.getLoc();
158   SmallVector<Value> reduc = op.getInitArgs();
159 
160   // Foreach on constant.
161   foreachInSparseConstant(
162       rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
163       [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
164         SmallVector<Value> args;
165         args.append(cvs.begin(), cvs.end());
166         args.push_back(v);
167         args.append(reduc);
168         // Clones the foreach op to get a copy of the loop body.
169         auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
170         assert(args.size() == cloned.getBody()->getNumArguments());
171         Operation *yield = cloned.getBody()->getTerminator();
172         rewriter.inlineBlockBefore(cloned.getBody(), op, args);
173         // clean up
174         rewriter.eraseOp(cloned);
175         reduc = yield->getOperands();
176         rewriter.eraseOp(yield);
177       });
178 
179   rewriter.replaceOp(op, reduc);
180   return success();
181 }
182 
183 /// Populates the given sizes array for concatenation from types (for static
184 /// sizes) and from the source tensors (for dynamic sizes).
185 static void concatSizesFromInputs(OpBuilder &builder,
186                                   SmallVectorImpl<Value> &sizes, Location loc,
187                                   ShapedType dstTp, ValueRange srcs,
188                                   unsigned dim) {
189   auto dstShape = dstTp.getShape();
190   sizesFromSrc(builder, sizes, loc, srcs[0]);
191 
192   // Sum up on the `dim` if the dimension is dynamic.
193   if (dstShape[dim] != ShapedType::kDynamic) {
194     // Faithfully take the static size.
195     sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
196   } else {
197     // Else, compute the shape dynamically.
198     for (const auto &src : srcs.drop_front()) {
199       Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim);
200       // Sum up all the sizes.
201       sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
202     }
203   }
204 }
205 
206 //===---------------------------------------------------------------------===//
207 // The actual sparse tensor rewriting rules.
208 //===---------------------------------------------------------------------===//
209 
210 namespace {
211 
212 /// TODO: move it to tensor dialect instead.
213 ///
214 /// Fold `tensor.concat` and `tensor.extract_slice`
215 ///
216 /// %concat = tensor.concat dim(2) %t0, %t1
217 ///   : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
218 /// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
219 ///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
220 /// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
221 ///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
222 ///
223 /// Becomes
224 ///
225 /// %extract0, %extract1 = %t0, %t1
226 struct FuseExtractSliceWithConcat
227     : public OpRewritePattern<tensor::ExtractSliceOp> {
228   using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
229 
230   LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
231                                 PatternRewriter &rewriter) const override {
232     auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
233     if (!concatOp)
234       return failure();
235 
236     Location loc = extractOp.getLoc();
237     int64_t dim = concatOp.getDim();
238     int64_t rank = extractOp.getResultType().getRank();
239 
240     SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
241     SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));
242 
243     // Compute the partial sums for the slice offsets.
244     AffineExpr sum = rewriter.getAffineDimExpr(0);
245     SmallVector<AffineExpr> partialSums = {sum};
246     SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
247     for (auto [idx, input] :
248          llvm::enumerate(concatOp.getInputs().drop_back())) {
249       sum = sum + rewriter.getAffineDimExpr(idx + 1);
250       partialSums.push_back(sum);
251       offsetStrides.push_back(
252           rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
253     }
254     auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
255                                         partialSums, rewriter.getContext());
256     SmallVector<OpFoldResult> dimOffsets =
257         affine::makeComposedFoldedMultiResultAffineApply(
258             rewriter, loc, partialSumMap, offsetStrides);
259 
260     auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
261       for (auto [l, r] : llvm::zip(lhs, rhs)) {
262         std::optional<int64_t> staticVal = getConstantIntValue(l);
263         if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
264           return false;
265       }
266       return lhs.size() == rhs.size();
267     };
268 
269     for (auto [i, input, offset] :
270          llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
271       SmallVector<OpFoldResult> srcSizes =
272           tensor::getMixedSizes(rewriter, loc, input);
273       srcOffsets[dim] = offset;
274 
275       SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
276       SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
277       SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
278 
279       if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
280           allEqual(srcStrides, dstStrides)) {
281         Value operand = concatOp.getOperand(i);
282         if (operand.getType() == extractOp.getResultType())
283           rewriter.replaceOp(extractOp, operand);
284         break;
285       }
286     }
287 
288     return success();
289   }
290 };
291 
292 /// Rewriting rule that fuses sparse_tensor.convert into producer.
293 struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
294 public:
295   using OpRewritePattern::OpRewritePattern;
296 
297   LogicalResult matchAndRewrite(ConvertOp op,
298                                 PatternRewriter &rewriter) const override {
299     auto producer = op.getSource().getDefiningOp<GenericOp>();
300     if (!producer || producer.getDpsInits().size() != 1 ||
301         !isMaterializing(producer.getDpsInitOperand(0), false) ||
302         !producer.getResult(0).hasOneUse()) {
303       return failure();
304     }
305     // Clone the materialization operation, but update the result to sparse.
306     rewriter.setInsertionPoint(producer);
307     Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
308     Operation *cloned = rewriter.clone(*init);
309     cloned->getResult(0).setType(op.getResult().getType());
310 
311     rewriter.modifyOpInPlace(producer, [&]() {
312       producer.getDpsInitsMutable().assign(cloned->getResults());
313       producer.getResult(0).setType(op.getResult().getType());
314     });
315 
316     rewriter.replaceAllOpUsesWith(op, producer);
317     op->erase();
318 
319     return success();
320   }
321 };
322 
323 /// Rewriting rule that converts direct yield of zero with initial allocation.
324 struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
325 public:
326   using OpRewritePattern<GenericOp>::OpRewritePattern;
327 
328   LogicalResult matchAndRewrite(GenericOp op,
329                                 PatternRewriter &rewriter) const override {
330     if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
331         !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
332         !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
333       return failure();
334     auto outputType = getRankedTensorType(op.getResult(0));
335     // Yielding zero on newly materialized sparse tensor can be
336     // optimized directly (regardless of dynamic or static size).
337     if (getSparseTensorEncoding(outputType)) {
338       rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
339       return success();
340     }
341     // Use static zero value directly instead of materialization.
342     if (!outputType.hasStaticShape())
343       return failure();
344     Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
345     rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
346     rewriter.eraseOp(def);
347     return success();
348   }
349 };
350 
351 /// Rewriting rule that converts two kernels:
352 ///
353 ///      T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
354 ///      X(i,j) = S(i,j) * T(i,j)
355 ///
356 /// into a single kernel, using distributive law:
357 ///
358 ///      X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
359 ///
360 /// This kind of fusion (merging two ops into one but using arithmetic
361 /// equalities that may not hold for floating-point computations) would
362 /// be undesirable in the dense case, since we distribute the multiplication
363 /// into the reduction loop. However, for sparse sampling tensor S, such
364 /// a fusion may actually reduce the asymptotic complexity of the kernel,
365 /// since intermediate results may be nullified.
366 struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
367 public:
368   using OpRewritePattern<GenericOp>::OpRewritePattern;
369 
370   LogicalResult matchAndRewrite(GenericOp op,
371                                 PatternRewriter &rewriter) const override {
372     // Check consumer.
373     if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
374         op.getNumResults() != 1 ||
375         op.getNumParallelLoops() != op.getNumLoops() ||
376         !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
377         !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
378         !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
379       return failure();
380     // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
381     // operand can be sparse or dense, since the point of this rewriting rule
382     // is detecting a situation in which *more* sparsity is introduced into
383     // a computation, be it already sparse or still dense.
384     unsigned other = 0;
385     if (isSparseTensor(op.getDpsInputOperand(0)))
386       other = 1;
387     else if (!isSparseTensor(op.getDpsInputOperand(1)))
388       return failure();
389     // Check producer.
390     auto prod = dyn_cast_or_null<GenericOp>(
391         op.getDpsInputOperand(other)->get().getDefiningOp());
392     if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
393         !prod.getResult(0).hasOneUse())
394       return failure();
395     // Sampling consumer and sum of multiplication chain producer.
396     if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
397         !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
398         !isSampling(op) || !isSumOfMul(prod))
399       return failure();
400     // Modify operand structure of producer and consumer.
401     Location loc = prod.getLoc();
402     SmallVector<Value> inputOps = prod.getInputs();
403     SmallVector<Value> outputOps = op.getOutputs();
404     SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
405     inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
406     fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
407     // Fuse producer and consumer into a new generic op.
408     auto fusedOp = rewriter.create<GenericOp>(
409         loc, op.getResult(0).getType(), inputOps, outputOps,
410         rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
411         /*doc=*/nullptr, /*library_call=*/nullptr);
412     Block &prodBlock = prod.getRegion().front();
413     Block &consBlock = op.getRegion().front();
414     IRMapping mapper;
415     Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
416     unsigned num = prodBlock.getNumArguments();
417     for (unsigned i = 0; i < num - 1; i++)
418       addArg(mapper, fusedBlock, prodBlock.getArgument(i));
419     addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
420     addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
421     // Clone bodies of the producer and consumer in new evaluation order.
422     auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
423     auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
424     Value last;
425     for (auto &op : prodBlock.without_terminator())
426       if (&op != acc) {
427         last = op.getResult(0);
428         rewriter.clone(op, mapper);
429       }
430     mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
431     mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
432     last = rewriter.clone(*acc, mapper)->getResult(0);
433     rewriter.create<linalg::YieldOp>(loc, last);
434     // Force initial value on merged allocation for dense outputs.
435     // TODO: deal with non alloc tensor here one day
436     if (!getSparseTensorEncoding(op.getResult(0).getType())) {
437       Value init = prod.getDpsInitOperand(0)
438                        ->get()
439                        .getDefiningOp<AllocTensorOp>()
440                        .getCopy();
441       AllocTensorOp a =
442           op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
443       rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
444     }
445     // Replace consumer with fused operation. Old producer
446     // and consumer ops will be removed by DCE.
447     rewriter.replaceOp(op, fusedOp->getResults());
448     return success();
449   }
450 
451 private:
452   // Helper to add argument and record the mapping.
453   static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
454     mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
455   }
456 };
457 
458 // Fuse a tensor cast into producing operation. Note that a tensor.cast
459 // should really not be used to convert between sparse encodings. Since
460 // the pattern currently appears as a result of some prior rewriting
461 // we make an attempt to repair very obvious cases.
462 // TODO: audit the pure tensor dialect rewriting rules
463 struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
464 public:
465   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
466 
467   LogicalResult matchAndRewrite(tensor::CastOp op,
468                                 PatternRewriter &rewriter) const override {
469     Type srcType = op.getSource().getType();
470     Type dstType = op.getDest().getType();
471     // A nop cast simply folds away.
472     if (srcType == dstType) {
473       rewriter.replaceOp(op, op->getResults());
474       return success();
475     }
476     // See if a sparsity changing cast can be fused into producer.
477     if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
478       if (Operation *def = op.getSource().getDefiningOp()) {
479         if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
480           rewriter.modifyOpInPlace(def, [&]() {
481             def->getResult(0).setType(op->getResultTypes()[0]);
482           });
483           rewriter.replaceOp(op, def->getResult(0));
484           return success();
485         }
486       }
487     }
488     // Repair tensor casts with at least one sparse operand into the
489     // the properly supported sparse_tensor.convert.
490     if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
491       rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
492       return success();
493     }
494     // Fail otherwise.
495     return failure();
496   }
497 };
498 
499 /// Rewrites a sequence of operations for sparse tensor selections in to
500 /// semi-ring operations such that they can be compiled correctly by the
501 /// sparsifier. E.g., transforming the following sequence
502 ///
503 /// %sel = arith.select %cond, %sp1, %sp2
504 ///
505 /// to
506 ///
507 /// %sel = binary %sp1, %sp2:
508 ///         both  (%l, %r) {yield select %cond, %l, %r}
509 ///         left  (%l)     {yield select %cond, %l,  0}
510 ///         right (%r)     {yield select %cond,  0, %r}
511 ///
512 /// TODO: We require that the tensor used for extracting conditions to be dense
513 /// to sparsify the code. To support a sparse condition tensor, we need a
514 /// tri-nary operation.
515 struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
516 public:
517   using OpRewritePattern<GenericOp>::OpRewritePattern;
518   LogicalResult matchAndRewrite(GenericOp op,
519                                 PatternRewriter &rewriter) const override {
520     // Rejects non sparse kernels.
521     if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
522       return failure();
523 
524     Location loc = op.getLoc();
525     SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
526     for (Operation &inst : *op.getBody()) {
527       // Matches pattern.
528       auto matched = isRewritablePattern(op, &inst);
529       if (!matched.has_value())
530         continue;
531 
532       rewriter.setInsertionPoint(&inst);
533       auto [c, t, f] = matched.value();
534       assert(t.getType() == f.getType());
535       auto selTp = t.getType();
536       auto c0 = constantZero(rewriter, loc, selTp);
537       auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
538       // Initializes all the blocks.
539       rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
540                            {t.getLoc(), f.getLoc()});
541       rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
542       rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
543 
544       for (auto *r : binOp.getRegions()) {
545         Block *b = &r->front();
546         rewriter.setInsertionPointToStart(b);
547 
548         IRMapping irMap;
549         // Clones the cmp operations into the region to make the binary op
550         // admissible.
551         Value newC = c;
552         if (auto *def = c.getDefiningOp())
553           newC = rewriter.clone(*def, irMap)->getResult(0);
554 
555         irMap.map(c, newC);
556         if (r == &binOp.getLeftRegion()) {
557           irMap.map(t, b->getArgument(0));
558           irMap.map(f, c0);
559         } else if (r == &binOp.getRightRegion()) {
560           irMap.map(t, c0);
561           irMap.map(f, b->getArgument(0));
562         } else {
563           irMap.map(t, b->getArgument(0));
564           irMap.map(f, b->getArgument(1));
565         }
566         auto y = rewriter.clone(inst, irMap)->getResult(0);
567         rewriter.create<sparse_tensor::YieldOp>(loc, y);
568       }
569 
570       // We successfully rewrited a operation. We can not do replacement here
571       // becuase it invalidate the iterator for the current loop to traverse
572       // the instructions.
573       semiRings.emplace_back(&inst, binOp);
574     }
575 
576     // Finalizes the replacement.
577     for (auto [sel, semi] : semiRings)
578       rewriter.replaceOp(sel, semi->getResults());
579 
580     return success(!semiRings.empty());
581   }
582 
583 private:
584   static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
585   isRewritablePattern(GenericOp op, Operation *v) {
586     auto sel = dyn_cast<arith::SelectOp>(v);
587     if (!sel)
588       return std::nullopt;
589 
590     auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
591     auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
592     // TODO: For simplicity, we only handle cases where both true/false value
593     // are directly loaded the input tensor. We can probably admit more cases
594     // in theory.
595     if (!tVal || !fVal)
596       return std::nullopt;
597 
598     // Helper lambda to determine whether the value is loaded from a dense input
599     // or is a loop invariant.
600     auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
601       if (auto bArg = dyn_cast<BlockArgument>(v);
602           bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
603         return true;
604       // If the value is defined outside the loop, it is a loop invariant.
605       return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
606     };
607 
608     // If the condition value is load directly from a dense tensor or
609     // loop-invariants, we can sparsify the kernel.
610     auto cond = sel.getCondition();
611     if (isValFromDenseInputOrInvariant(cond))
612       return std::make_tuple(cond, tVal, fVal);
613 
614     Value cmpL, cmpR;
615     if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
616                                                matchers::m_Any(&cmpR))) ||
617         matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
618                                                matchers::m_Any(&cmpR)))) {
619       // TODO: we can do it recursively to check whether all the leaf values are
620       // loaded from dense tensors or are loop invariants.
621       if (isValFromDenseInputOrInvariant(cmpL) ||
622           isValFromDenseInputOrInvariant(cmpR))
623         return std::make_tuple(cond, tVal, fVal);
624     }
625 
626     return std::nullopt;
627   };
628 };
629 
630 /// Rewrites a sparse reduction that would not sparsify directly since
631 /// doing so would only iterate over the stored elements, ignoring the
632 /// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
633 /// (note that reductions like add/sub/or/xor can directly be sparsified
634 /// since the implicit zeros do not contribute to the final result).
635 /// Note that prod/and are still included since, even though they often
636 /// are nullified in sparse data, they may still occur for special
637 /// situations in which e.g. some rows in a sparse matrix are fully
638 /// dense. For min/max, including the implicit zeros is a much more
639 /// common situation.
640 ///
641 /// TODO: this essentially "densifies" the operation; we want to implement
642 ///       this much more efficiently by performing the reduction over the
643 ///       stored values, and feed in the zero once if there were *any*
644 ///       implicit zeros as well; but for now, at least we provide
645 ///       the functionality
646 ///
647 struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
648 public:
649   using OpRewritePattern<GenericOp>::OpRewritePattern;
650 
651   LogicalResult matchAndRewrite(GenericOp op,
652                                 PatternRewriter &rewriter) const override {
653     // Reject non-reductions.
654     if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
655         op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
656       return failure();
657     auto *inp = op.getDpsInputOperand(0);
658     auto *init = op.getDpsInitOperand(0);
659     if (!isSparseTensor(inp))
660       return failure();
661     // Look for direct x = x OP y for semi-ring ready reductions.
662     auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
663                     .getOperand(0)
664                     .getDefiningOp();
665     if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
666              arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
667              arith::MaxUIOp>(red))
668       return failure();
669     Value s0 = op.getBlock()->getArgument(0);
670     Value s1 = op.getBlock()->getArgument(1);
671     if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
672         (red->getOperand(0) != s1 || red->getOperand(1) != s0))
673       return failure();
674     // Identity.
675     Location loc = op.getLoc();
676     Value identity =
677         rewriter.create<tensor::ExtractOp>(loc, init->get(), ValueRange());
678     // Unary {
679     //    present -> value
680     //    absent  -> zero.
681     // }
682     Type rtp = s0.getType();
683     rewriter.setInsertionPointToStart(&op.getRegion().front());
684     auto semiring = rewriter.create<sparse_tensor::UnaryOp>(loc, rtp, s0);
685     Block *present =
686         rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
687     rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
688     rewriter.create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
689     rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
690     rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
691     auto zero =
692         rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(rtp));
693     rewriter.create<sparse_tensor::YieldOp>(loc, zero);
694     rewriter.setInsertionPointAfter(semiring);
695     // CustomReduce {
696     //    x = x REDUC y, identity
697     // }
698     auto custom = rewriter.create<sparse_tensor::ReduceOp>(
699         loc, rtp, semiring.getResult(), s1, identity);
700     Block *region =
701         rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
702     rewriter.setInsertionPointToStart(&custom.getRegion().front());
703     IRMapping irMap;
704     irMap.map(red->getOperand(0), region->getArgument(0));
705     irMap.map(red->getOperand(1), region->getArgument(1));
706     auto *cloned = rewriter.clone(*red, irMap);
707     rewriter.create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
708     rewriter.setInsertionPointAfter(custom);
709     rewriter.replaceOp(red, custom.getResult());
710     return success();
711   }
712 };
713 
714 /// Sparse rewriting rule for the print operator. This operation is mainly used
715 /// for debugging and testing. As such, it lowers to the vector.print operation
716 /// which only require very light-weight runtime support.
717 struct PrintRewriter : public OpRewritePattern<PrintOp> {
718 public:
719   using OpRewritePattern::OpRewritePattern;
720   LogicalResult matchAndRewrite(PrintOp op,
721                                 PatternRewriter &rewriter) const override {
722     Location loc = op.getLoc();
723     auto tensor = op.getTensor();
724     auto stt = getSparseTensorType(tensor);
725     // Header with NSE.
726     auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
727     rewriter.create<vector::PrintOp>(
728         loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
729     rewriter.create<vector::PrintOp>(loc, nse);
730     // Print run-time contents for dim/lvl sizes.
731     rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("dim = "));
732     printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
733     rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("lvl = "));
734     printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
735     // Use the "codegen" foreach loop construct to iterate over
736     // all typical sparse tensor components for printing.
737     foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
738                                             &stt](Type, FieldIndex,
739                                                   SparseTensorFieldKind kind,
740                                                   Level l, LevelType) {
741       switch (kind) {
742       case SparseTensorFieldKind::StorageSpec: {
743         break;
744       }
745       case SparseTensorFieldKind::PosMemRef: {
746         auto lvl = constantIndex(rewriter, loc, l);
747         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
748         rewriter.create<vector::PrintOp>(
749             loc, lvl, vector::PrintPunctuation::NoPunctuation);
750         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
751         auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
752         printContents(rewriter, loc, pos);
753         break;
754       }
755       case SparseTensorFieldKind::CrdMemRef: {
756         auto lvl = constantIndex(rewriter, loc, l);
757         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
758         rewriter.create<vector::PrintOp>(
759             loc, lvl, vector::PrintPunctuation::NoPunctuation);
760         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
761         Value crd = nullptr;
762         // For COO AoS storage, we want to print a single, linear view of
763         // the full coordinate storage at this level. For any other storage,
764         // we show the coordinate storage for every indivual level.
765         if (stt.getAoSCOOStart() == l)
766           crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
767         else
768           crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
769         printContents(rewriter, loc, crd);
770         break;
771       }
772       case SparseTensorFieldKind::ValMemRef: {
773         rewriter.create<vector::PrintOp>(loc,
774                                          rewriter.getStringAttr("values : "));
775         auto val = rewriter.create<ToValuesOp>(loc, tensor);
776         printContents(rewriter, loc, val);
777         break;
778       }
779       }
780       return true;
781     });
782     rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
783     rewriter.eraseOp(op);
784     return success();
785   }
786 
787 private:
788   // Helper to print contents of a single memref. For "push_back" vectors,
789   // we assume that the previous getters for pos/crd/val have added a
790   // slice-to-size view to make sure we just print the size and not the
791   // full capacity.
792   //
793   // Generates code to print (1-dim or higher):
794   //    ( a0, a1, ... )
795   static void printContents(PatternRewriter &rewriter, Location loc,
796                             Value vec) {
797     auto shape = cast<ShapedType>(vec.getType()).getShape();
798     SmallVector<Value> idxs;
799     printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
800     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
801   }
802 
803   // Helper to the helper.
804   static void printContentsLevel(PatternRewriter &rewriter, Location loc,
805                                  Value vec, unsigned i, ArrayRef<int64_t> shape,
806                                  SmallVectorImpl<Value> &idxs) {
807     // Open bracket.
808     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
809     // Generate for loop.
810     auto zero = constantIndex(rewriter, loc, 0);
811     auto index = constantIndex(rewriter, loc, i);
812     auto size = rewriter.create<memref::DimOp>(loc, vec, index);
813     auto step = constantIndex(rewriter, loc, 1);
814     auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
815     idxs.push_back(forOp.getInductionVar());
816     rewriter.setInsertionPointToStart(forOp.getBody());
817     if (i < shape.size() - 1) {
818       // Enter deeper loop nest.
819       printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
820     } else {
821       // Actual contents printing.
822       auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
823       if (llvm::isa<ComplexType>(val.getType())) {
824         // Since the vector dialect does not support complex types in any op,
825         // we split those into (real, imag) pairs here.
826         Value real = rewriter.create<complex::ReOp>(loc, val);
827         Value imag = rewriter.create<complex::ImOp>(loc, val);
828         rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
829         rewriter.create<vector::PrintOp>(loc, real,
830                                          vector::PrintPunctuation::Comma);
831         rewriter.create<vector::PrintOp>(loc, imag,
832                                          vector::PrintPunctuation::Close);
833       } else {
834         rewriter.create<vector::PrintOp>(
835             loc, val, vector::PrintPunctuation::NoPunctuation);
836       }
837       // Terminating comma (except at end).
838       auto bound = rewriter.create<arith::AddIOp>(loc, idxs.back(), step);
839       Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
840                                                   bound, size);
841       scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
842       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
843       rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
844     }
845     idxs.pop_back();
846     rewriter.setInsertionPointAfter(forOp);
847     // Close bracket.
848     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
849   }
850 
851   // Helper method to print run-time lvl/dim sizes.
852   static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
853                          unsigned size, bool isDim) {
854     // Open bracket.
855     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
856     // Print unrolled contents (dimop requires constant value).
857     for (unsigned i = 0; i < size; i++) {
858       auto idx = constantIndex(rewriter, loc, i);
859       Value val;
860       if (isDim)
861         val = rewriter.create<tensor::DimOp>(loc, tensor, idx);
862       else
863         val = rewriter.create<LvlOp>(loc, tensor, idx);
864       rewriter.create<vector::PrintOp>(
865           loc, val,
866           i != size - 1 ? vector::PrintPunctuation::Comma
867                         : vector::PrintPunctuation::NoPunctuation);
868     }
869     // Close bracket and end of line.
870     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
871     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
872   }
873 };
874 
875 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
876 struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
877 public:
878   using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
879 
880   LogicalResult matchAndRewrite(tensor::ReshapeOp op,
881                                 PatternRewriter &rewriter) const override {
882     Location loc = op.getLoc();
883     Value srcTensor = op.getSource();
884     const auto srcTp = tryGetSparseTensorType(srcTensor);
885     const auto dstTp = tryGetSparseTensorType(op.getResult());
886     if (!srcTp || !dstTp)
887       return failure();
888 
889     if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
890         !dstTp->hasStaticDimShape())
891       return failure();
892 
893     SmallVector<Value> srcSizes;
894     sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
895     SmallVector<Value> dstSizes;
896     for (Dimension d : dstTp->getDimShape())
897       dstSizes.push_back(constantIndex(rewriter, loc, d));
898 
899     Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
900     // Only need an unordered COO buffer if input and output are not sorted
901     // in the same way.
902     Type bufferTp = getBufferType(
903         dstTp->withoutDimToLvl(),
904         !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
905     SmallVector<Value> dynSizes;
906     Value buffer = rewriter
907                        .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
908                                               nnz, Attribute())
909                        .getResult();
910 
911     // Convert src coordinates to dst coordinates by first collapsing it to 1D
912     // and then expand it to the match the rank of the destination tensor.
913     // Implemented as follows:
914     //   foreach srcCoords %srcTensor
915     //     collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
916     //     expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
917     //     insert expandedCoords, %buffer
918     //
919     // followed by an optional
920     //   %t = sparse_tensor.cast %tmp
921     // depending on whether the input/output are sorted in the same way.
922     const auto encSrc = srcTp->getEncoding();
923     ForeachOp foreachOp = rewriter.create<ForeachOp>(
924         loc, srcTensor, buffer,
925         [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
926             ValueRange reduc) {
927           const Dimension srcRank = srcTp->getDimRank();
928           SmallVector<Value> srcDcvs;
929           srcDcvs.reserve(srcRank);
930           for (Dimension d = 0; d < srcRank; d++) {
931             Level lvl = toLvl(encSrc, d);
932             srcDcvs.push_back(srcLcvs[lvl]);
933           }
934 
935           Value collapseSize = constantIndex(builder, loc, 1);
936           for (Dimension d = 0; d < srcRank; d++)
937             collapseSize =
938                 builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
939           SmallVector<Value, 1> collapsedSizes = {collapseSize};
940 
941           ReassociationIndices collapseIdx;
942           for (Dimension i = 0; i < srcRank; i++)
943             collapseIdx.push_back(i);
944           SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
945           SmallVector<Value, 1> collapsedDcvs;
946           reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
947                      collapsedSizes, collapsedDcvs);
948 
949           ReassociationIndices expandIdx;
950           for (Dimension i = 0; i < dstTp->getDimRank(); i++)
951             expandIdx.push_back(i);
952           SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
953           SmallVector<Value> dstDcvs;
954           reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
955                      dstSizes, dstDcvs);
956 
957           auto t =
958               builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
959           builder.create<sparse_tensor::YieldOp>(loc, t);
960         });
961 
962     Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
963     if (bufferTp != *dstTp) {
964       auto dstRTT = dstTp->getRankedTensorType();
965       Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
966       rewriter.create<DeallocTensorOp>(loc, t);
967       t = converted;
968     }
969     rewriter.replaceOp(op, t);
970     return success();
971   }
972 };
973 
974 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
975 template <typename ReshapeOp>
976 struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
977 public:
978   using OpRewritePattern<ReshapeOp>::OpRewritePattern;
979 
980   LogicalResult matchAndRewrite(ReshapeOp op,
981                                 PatternRewriter &rewriter) const override {
982     Location loc = op.getLoc();
983     Value srcTensor = op.getSrc();
984     const auto srcTp = getSparseTensorType(srcTensor);
985     const auto dstTp = getSparseTensorType(op.getResult());
986     if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
987       return failure();
988 
989     // Generate code to represent the static dimension constants or compute
990     // the dynamic dimension values.
991     SmallVector<Value> srcSizes;
992     sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
993     SmallVector<Value> dstSizes;
994     SmallVector<Value> dstDynSizes;
995     if (dstTp.hasStaticDimShape()) {
996       for (Dimension d : dstTp.getDimShape())
997         dstSizes.push_back(constantIndex(rewriter, loc, d));
998     } else {
999       ArrayRef<Size> dstShape = dstTp.getDimShape();
1000       genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
1001                          op.getReassociationIndices());
1002       for (auto [idx, shape] : llvm::enumerate(dstShape)) {
1003         if (shape == ShapedType::kDynamic)
1004           dstDynSizes.push_back(dstSizes[idx]);
1005       }
1006     }
1007     Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
1008     // Only need a unordered COO buffer if input and output are not sorted
1009     // in the same way.
1010     Type bufferTp = getBufferType(
1011         dstTp.withoutDimToLvl(),
1012         !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
1013 
1014     Value buffer =
1015         rewriter
1016             .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
1017                                    /*sizeHint=*/nnz, Attribute())
1018             .getResult();
1019 
1020     // Implement the sparse2sparse reshape as follows:
1021     //   foreach srcCoords %srcTensor
1022     //     insert reshapeCvs(srcCoords), %buffer
1023     //
1024     // followed by an optional
1025     //   %t = sparse_tensor.cast %tmp
1026     // depending on whether the input/output are sorted in the same way.
1027     const auto encSrc = srcTp.getEncoding();
1028     ForeachOp foreachOp = rewriter.create<ForeachOp>(
1029         loc, srcTensor, buffer,
1030         [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
1031             ValueRange reduc) {
1032           const Dimension dimRank = srcTp.getDimRank();
1033           SmallVector<Value> srcDcvs;
1034           srcDcvs.reserve(dimRank);
1035           for (Dimension d = 0; d < dimRank; d++) {
1036             Level lvl = toLvl(encSrc, d);
1037             srcDcvs.push_back(srcLcvs[lvl]);
1038           }
1039           SmallVector<Value> dstDcvs;
1040           reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
1041                      srcDcvs, dstSizes, dstDcvs);
1042           auto t =
1043               builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
1044           builder.create<sparse_tensor::YieldOp>(loc, t);
1045         });
1046 
1047     Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
1048     if (bufferTp != dstTp) {
1049       auto dstRTT = dstTp.getRankedTensorType();
1050       Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
1051       rewriter.create<DeallocTensorOp>(loc, t);
1052       t = converted;
1053     }
1054     rewriter.replaceOp(op, t);
1055     return success();
1056   }
1057 };
1058 
1059 /// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
1060 /// operator.
1061 template <typename ReshapeOp>
1062 struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
1063 public:
1064   using OpRewritePattern<ReshapeOp>::OpRewritePattern;
1065 
1066   LogicalResult matchAndRewrite(ReshapeOp op,
1067                                 PatternRewriter &rewriter) const override {
1068     Location loc = op->getLoc();
1069     auto encDst = getSparseTensorEncoding(op.getResult().getType());
1070     auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
1071     // Since a pure dense expansion is very cheap (change of view), for
1072     // a sparse2dense or dense2sparse, we can simply unfuse a sparse
1073     // conversion from the reshape operation itself.
1074     // All other cases are handled elsewhere.
1075     if (encDst && encSrc) {
1076       return failure();
1077     }
1078     if (encSrc) {
1079       auto rtp = getRankedTensorType(op.getSrc());
1080       auto denseTp =
1081           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
1082       auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
1083       rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
1084       return success();
1085     }
1086     if (encDst) {
1087       auto rtp = getRankedTensorType(op.getResult());
1088       auto denseTp =
1089           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
1090       ReshapeOp reshape;
1091       if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
1092         reshape = rewriter.create<ReshapeOp>(
1093             loc, denseTp, op.getSrc(), op.getReassociation(),
1094             op.getOutputShape(), op.getStaticOutputShape());
1095       } else {
1096         reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
1097                                              op.getReassociation());
1098       }
1099       Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
1100       rewriter.replaceOp(op, convert);
1101       return success();
1102     }
1103     return failure();
1104   }
1105 };
1106 
1107 // A trivial wrapper to help generate different operations for dense/sparse
1108 // tensors.
1109 struct TensorLike {
1110   TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
1111              ValueRange sizes) {
1112     SmallVector<Value> dynSzs;
1113     getDynamicSizes(rtt, sizes, dynSzs);
1114 
1115     val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
1116     if (!isSparse()) {
1117       Value c0 = constantZero(builder, loc, rtt.getElementType());
1118       val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
1119     }
1120   }
1121 
1122   void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
1123     val = builder.create<tensor::InsertOp>(loc, v, val, crds);
1124   }
1125 
1126   Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
1127     if (isSparse())
1128       return builder.create<LoadOp>(loc, val, true);
1129     return val;
1130   }
1131 
1132   bool isSparse() const {
1133     return getSparseTensorEncoding(val.getType()) != nullptr;
1134   }
1135 
1136   Value val;
1137 };
1138 
1139 struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
1140   using OpRewritePattern::OpRewritePattern;
1141   LogicalResult matchAndRewrite(tensor::DimOp op,
1142                                 PatternRewriter &rewriter) const override {
1143     std::optional<int64_t> dim = op.getConstantIndex();
1144     auto stt = tryGetSparseTensorType(op.getSource());
1145     if (!dim || !stt || !stt->hasEncoding())
1146       return failure();
1147 
1148     if (stt->isPermutation()) {
1149       rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
1150                                          toLvl(stt->getEncoding(), *dim));
1151       return success();
1152     }
1153 
1154     // Non-permutation dim2lvl/lvl2dim maps.
1155     // Compute as follows:
1156     // affine.apply #map (l0 - 1, l1 - 1, ...) + 1
1157     // Note that it is not the most efficient way (but a more general one) for
1158     // the lvl to dim translation, e.g., for BSR, the dimension size for can be
1159     // computed simply by lvl_size * block_size.
1160     Location loc = op.getLoc();
1161     SmallVector<Value> maxLvlCrds;
1162     for (Level l = 0; l < stt->getLvlRank(); l++) {
1163       Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
1164       Value maxLvlCrd = rewriter.create<arith::SubIOp>(
1165           loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
1166       maxLvlCrds.push_back(maxLvlCrd);
1167     }
1168 
1169     AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
1170     Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
1171         op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
1172         maxLvlCrds);
1173 
1174     Value dimSz = rewriter.create<arith::AddIOp>(
1175         loc, maxDimCrd, constantOne(rewriter, loc, rewriter.getIndexType()));
1176     rewriter.replaceOp(op, dimSz);
1177     return success();
1178   }
1179 };
1180 
1181 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
1182   using OpRewritePattern::OpRewritePattern;
1183   LogicalResult matchAndRewrite(ConcatenateOp op,
1184                                 PatternRewriter &rewriter) const override {
1185     if (op.needsExtraSort())
1186       op.emitError("ConcatenateOp not staged");
1187 
1188     const Location loc = op.getLoc();
1189     const auto dstTp = getSparseTensorType(op);
1190     const Dimension conDim = op.getDimension();
1191     SmallVector<Value> sizes;
1192     concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
1193 
1194     // %t = concatenate %s1, %s2, %s3 {dim = 1}
1195     // ==>
1196     // if (isSparseDst)
1197     //   if (allDense)
1198     //     %tmp = bufferization.alloc_tensor dstTp
1199     //   else
1200     //     %tmp = bufferization.alloc_tensor : unordered COO
1201     // else
1202     //   %tmp = memref.alloc : dense tensor
1203     // foreach in %s1 : insert d0, d1, %tmp
1204     // foreach in %s2 : insert d0, d1 + size(s1), %tmp
1205     // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
1206 
1207     TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
1208     Value offset = constantIndex(rewriter, loc, 0);
1209     Value iterArg = dstBuf.val;
1210 
1211     ForeachOp foreachOp;
1212     for (Value input : op.getInputs()) {
1213       // Builds a for op for each input tensor to append new values into the
1214       // output tensor.
1215       foreachOp = rewriter.create<ForeachOp>(
1216           loc, input, iterArg,
1217           [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1218               ValueRange reduc) {
1219             SmallVector<Value> offDimCrd(dcvs);
1220             offDimCrd[conDim] =
1221                 builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);
1222 
1223             // Enters foreach, updates the SSA chain.
1224             dstBuf.val = reduc.front();
1225             if (!dstTp.isAllDense()) {
1226               Value cond = genIsNonzero(builder, loc, v);
1227               auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1228                                                     /*else*/ true);
1229               builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1230               builder.create<scf::YieldOp>(loc, dstBuf.val);
1231 
1232               builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1233               dstBuf.insert(builder, loc, v, offDimCrd);
1234               builder.create<scf::YieldOp>(loc, dstBuf.val);
1235 
1236               // Exits the ifOp, update the sparse tensor SSA value.
1237               builder.setInsertionPointAfter(ifOp);
1238               dstBuf.val = ifOp.getResult(0);
1239             } else {
1240               dstBuf.insert(builder, loc, v, offDimCrd);
1241             }
1242             builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1243           });
1244       // Accumulates the offset. Note that only static-shaped inputs are allowed
1245       // by concatenate op verifier, which saves us from computing the offset
1246       // dynamically.
1247       const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
1248       assert(!ShapedType::isDynamic(sz));
1249       offset = rewriter.create<arith::AddIOp>(loc, offset,
1250                                               constantIndex(rewriter, loc, sz));
1251       iterArg = foreachOp.getResult(0);
1252       dstBuf.val = iterArg;
1253     }
1254 
1255     dstBuf.val = iterArg;
1256     Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
1257     rewriter.replaceOp(op, ret);
1258     return success();
1259   }
1260 };
1261 
1262 struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
1263   using OpRewritePattern::OpRewritePattern;
1264   LogicalResult matchAndRewrite(ConvertOp op,
1265                                 PatternRewriter &rewriter) const override {
1266     if (op.needsExtraSort())
1267       return op.emitError("ConvertOp not staged.");
1268 
1269     // TODO: Maybe we want a different operation for this too.
1270     auto encDst = getSparseTensorEncoding(op.getType());
1271     auto encSrc = getSparseTensorEncoding(op.getSource().getType());
1272     if (encDst && encSrc && !encSrc.isSlice() &&
1273         encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
1274       // Trivial tensor conversion and simple element type conversion is handled
1275       // in codegen.
1276       return failure();
1277     }
1278 
1279     Location loc = op.getLoc();
1280     Value src = op.getSource();
1281 
1282     SparseTensorType srcStt = getSparseTensorType(op.getSource());
1283     SparseTensorType dstStt = getSparseTensorType(op.getDest());
1284 
1285     bool fromSparseConst = false;
1286     if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
1287       if (dyn_cast<SparseElementsAttr>(constOp.getValue()))
1288         fromSparseConst = true;
1289 
1290     const AffineMapAttr foreachOrder =
1291         (!dstStt.isIdentity() && fromSparseConst)
1292             ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
1293             : nullptr;
1294 
1295     bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;
1296 
1297     SmallVector<Value> sizes;
1298     sizesFromSrc(rewriter, sizes, loc, src);
1299     ValueRange vs;
1300     TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
1301 
1302     auto foreachOp = rewriter.create<ForeachOp>(
1303         loc, src, dstBuf.val, foreachOrder,
1304         [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1305             ValueRange reduc) {
1306           // Enters the loop, update the SSA value for insertion chain.
1307           dstBuf.val = reduc.front();
1308           if (!skipZeroCheck) {
1309             Value cond = genIsNonzero(builder, loc, v);
1310             auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1311                                                   /*else*/ true);
1312             builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1313             builder.create<scf::YieldOp>(loc, dstBuf.val);
1314 
1315             builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1316             dstBuf.insert(builder, loc, v, dcvs);
1317             builder.create<scf::YieldOp>(loc, dstBuf.val);
1318 
1319             // Exits the ifOp, update the sparse tensor SSA value.
1320             builder.setInsertionPointAfter(ifOp);
1321             dstBuf.val = ifOp.getResult(0);
1322           } else {
1323             dstBuf.insert(builder, loc, v, dcvs);
1324           }
1325           builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1326         });
1327 
1328     rewriter.setInsertionPointAfter(foreachOp);
1329 
1330     // Exits the for loop, links the SSA chain.
1331     dstBuf.val = foreachOp.getResult(0);
1332 
1333     Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
1334     rewriter.replaceOp(op, ret);
1335     return success();
1336   }
1337 };
1338 
1339 struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
1340   using OpRewritePattern::OpRewritePattern;
1341   LogicalResult matchAndRewrite(CrdTranslateOp op,
1342                                 PatternRewriter &rewriter) const override {
1343     AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
1344                         ? op.getEncoder().getDimToLvl()
1345                         : op.getEncoder().getLvlToDim();
1346 
1347     SmallVector<Value> outCrds;
1348     for (AffineExpr result : map.getResults()) {
1349       // TODO: we should probably expand the affine map to IR using our own
1350       // rules, since affine.apply assume signed value, while the cooridinates
1351       // we provided must always be signless.
1352       Value trans = rewriter.create<affine::AffineApplyOp>(
1353           op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
1354           op.getInCrds());
1355       outCrds.push_back(trans);
1356     }
1357     rewriter.replaceOp(op, outCrds);
1358     return success();
1359   }
1360 };
1361 
1362 /// Sparse rewriting rule for the foreach operator.
1363 struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
1364 public:
1365   using OpRewritePattern::OpRewritePattern;
1366 
1367   LogicalResult matchAndRewrite(ForeachOp op,
1368                                 PatternRewriter &rewriter) const override {
1369 
1370     auto loc = op.getLoc();
1371     Value input = op.getTensor();
1372     SmallVector<Value> reduc = op.getInitArgs();
1373     const auto stt = getSparseTensorType(input);
1374     const Level lvlRank = stt.getLvlRank();
1375 
1376     // Special-case: for each over a sparse constant uses its own rewriting
1377     // rule.
1378     if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
1379       if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
1380         return genForeachOnSparseConstant(op, rewriter, attr);
1381       }
1382     }
1383 
1384     // Otherwise, use loop emitter to generate loops.
1385     const auto enc = stt.getEncoding();
1386 
1387     // 1. Generates loop for the sparse input.
1388     LoopEmitter loopEmitter(
1389         ValueRange{input},
1390         StringAttr::get(getContext(), ForeachOp::getOperationName()));
1391     loopEmitter.initializeLoopEmit(rewriter, loc);
1392     for (Level l = 0; l < lvlRank; l++) {
1393       // TODO: provide utility function for loop sequences that only contains
1394       // one for loop?
1395       const SmallVector<TensorLevel, 1> tidLvls{
1396           loopEmitter.makeTensorLevel(0, l)};
1397       loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
1398       // Note that reduc will be taken care of by loop emitter and get updated
1399       // in place.
1400       loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
1401                                                     reduc);
1402     }
1403 
1404     SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
1405     if (op.getOrder()) {
1406       // TODO: Support it so that we can do direct conversion from CSR->BSR.
1407       llvm_unreachable(
1408           "Level order not yet implemented on non-constant input tensors.");
1409     }
1410 
1411     Value vals = loopEmitter.getValBuffer()[0];
1412     SmallVector<Value> pos = loopEmitter.getValPosits(0);
1413     // Loads the value from sparse tensor using position-index;
1414     // loads the value from dense tensor using coords.
1415     Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
1416                     : rewriter.create<memref::LoadOp>(loc, vals, lcvs);
1417 
1418     // 2. Inline the block in the foreach operator.
1419     Block *srcBlock = op.getBody();
1420 
1421     // Remap coordinates.
1422     SmallVector<Value> args =
1423         enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
1424 
1425     // Remap value.
1426     args.push_back(val);
1427     // Remap reduction variables.
1428     args.append(reduc);
1429 
1430     // Remove sparse_tensor.yield.
1431     SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
1432     rewriter.eraseOp(srcBlock->getTerminator());
1433 
1434     Operation &last = rewriter.getBlock()->back();
1435     if (llvm::isa<scf::YieldOp>(last)) {
1436       // Because `scf.for` inserts an implicit yield op when there is no
1437       // reduction variable upon creation, we reset the insertion point such
1438       // that the block is inlined before *before* the yield op.
1439       rewriter.setInsertionPoint(&last);
1440     }
1441 
1442     rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
1443                                rewriter.getInsertionPoint(), args);
1444     rewriter.setInsertionPointToEnd(rewriter.getBlock());
1445     for (Level l = 0; l < lvlRank; l++) {
1446       // Link the reduction chain. Note that loop emitter update the reducValue
1447       // in place.
1448       loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
1449       loopEmitter.exitCurrentLoopSeq(rewriter, loc);
1450     }
1451 
1452     // Replace the foreach operator with the value returned by the outtermost
1453     // for loop.
1454     rewriter.replaceOp(op, reducValue);
1455     return success();
1456   }
1457 };
1458 
1459 /// Sparse rewriting rule for the new operator.
1460 struct NewRewriter : public OpRewritePattern<NewOp> {
1461   using OpRewritePattern::OpRewritePattern;
1462   LogicalResult matchAndRewrite(NewOp op,
1463                                 PatternRewriter &rewriter) const override {
1464     Location loc = op.getLoc();
1465     auto stt = getSparseTensorType(op.getResult());
1466     if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
1467       return failure();
1468 
1469     // Implement the NewOp as follows:
1470     //   %orderedCoo = sparse_tensor.new %filename
1471     //   %t = sparse_tensor.convert %orderedCoo
1472     // with enveloping reinterpreted_map ops for non-permutations.
1473     RankedTensorType dstTp = stt.getRankedTensorType();
1474     RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
1475     Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
1476     Value convert = cooTensor;
1477     auto enc = stt.getEncoding();
1478     if (!stt.isPermutation()) { // demap coo, demap dstTp
1479       auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
1480       convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
1481       dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
1482     }
1483     convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
1484     if (!stt.isPermutation()) // remap to original enc
1485       convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
1486     rewriter.replaceOp(op, convert);
1487 
1488     // Release the temporary ordered COO tensor.
1489     rewriter.setInsertionPointAfterValue(convert);
1490     rewriter.create<DeallocTensorOp>(loc, cooTensor);
1491 
1492     return success();
1493   }
1494 };
1495 
1496 /// Sparse rewriting rule for the out operator.
1497 struct OutRewriter : public OpRewritePattern<OutOp> {
1498   using OpRewritePattern::OpRewritePattern;
1499   LogicalResult matchAndRewrite(OutOp op,
1500                                 PatternRewriter &rewriter) const override {
1501     Location loc = op.getLoc();
1502     // Calculate NNZ.
1503     Value src = op.getTensor();
1504     Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
1505 
1506     // Allocate a temporary buffer for storing dimension-sizes/coordinates.
1507     const auto srcTp = getSparseTensorType(src);
1508     const Dimension dimRank = srcTp.getDimRank();
1509     Type indexTp = rewriter.getIndexType();
1510     Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);
1511 
1512     // Generate code to calculate dimension size values and store the values to
1513     // the buffer.
1514     SmallVector<Value> dims;
1515     sizesForTensor(rewriter, dims, loc, srcTp, src);
1516     for (Dimension d = 0; d < dimRank; d++) {
1517       rewriter.create<memref::StoreOp>(loc, dims[d], dimSizes,
1518                                        constantIndex(rewriter, loc, d));
1519     }
1520 
1521     // Create a sparse tensor writer and output meta data.
1522     Type opaqueTp = getOpaquePointerType(rewriter);
1523     Value writer =
1524         createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
1525                        {op.getDest()}, EmitCInterface::Off)
1526             .getResult(0);
1527     Value rankValue = constantIndex(rewriter, loc, dimRank);
1528     createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
1529                    {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
1530 
1531     Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords.
1532     Type eltTp = srcTp.getElementType();
1533     SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
1534                                     primaryTypeFunctionSuffix(eltTp)};
1535     Value value = genAllocaScalar(rewriter, loc, eltTp);
1536     ModuleOp module = op->getParentOfType<ModuleOp>();
1537 
1538     // For each element in the source tensor, output the element.
1539     rewriter.create<ForeachOp>(
1540         loc, src, std::nullopt,
1541         [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1542             ValueRange reduc) {
1543           for (Dimension d = 0; d < dimRank; d++) {
1544             rewriter.create<memref::StoreOp>(loc, dcvs[d], dimCoords,
1545                                              constantIndex(builder, loc, d));
1546           }
1547           rewriter.create<memref::StoreOp>(loc, v, value);
1548           SmallVector<Value> operands{writer, rankValue, dimCoords, value};
1549           FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
1550                                          EmitCInterface::On);
1551           builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
1552           builder.create<sparse_tensor::YieldOp>(loc);
1553         });
1554 
1555     // Release the writer.
1556     createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
1557                    EmitCInterface::Off);
1558 
1559     rewriter.eraseOp(op);
1560     return success();
1561   }
1562 };
1563 
1564 } // namespace
1565 
1566 //===---------------------------------------------------------------------===//
1567 // Methods that add patterns described in this file to a pattern list.
1568 //===---------------------------------------------------------------------===//
1569 
1570 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
1571   patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
1572                FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
1573                GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
1574       patterns.getContext());
1575 }
1576 
1577 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
1578                                                    bool enableRT,
1579                                                    bool enableConvert) {
1580   patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
1581                ReshapeRewriter<tensor::CollapseShapeOp>,
1582                Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
1583                Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
1584                SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
1585       patterns.getContext());
1586 
1587   if (enableConvert)
1588     patterns.add<DirectConvertRewriter>(patterns.getContext());
1589   if (!enableRT)
1590     patterns.add<NewRewriter>(patterns.getContext());
1591 }
1592 
1593 void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
1594   // Run CrdTranslateRewriter later in the pipeline so that operation can be
1595   // folded before lowering to affine.apply
1596   patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
1597 }
1598