xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (revision e095d978ba476c9624b4e72776089ea7301fa657)
1 //===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting rules that are specific to sparse tensors.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Utils/CodegenUtils.h"
14 #include "Utils/LoopEmitter.h"
15 
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
19 #include "mlir/Dialect/Linalg/IR/Linalg.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/IR/SCF.h"
23 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
24 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
25 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
26 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
27 #include "mlir/Dialect/Tensor/IR/Tensor.h"
28 #include "mlir/Dialect/Vector/IR/VectorOps.h"
29 #include "mlir/IR/AffineMap.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/Support/LLVM.h"
32 
33 using namespace mlir;
34 using namespace mlir::bufferization;
35 using namespace mlir::linalg;
36 using namespace mlir::sparse_tensor;
37 
38 //===---------------------------------------------------------------------===//
39 // Helper methods for the actual rewriting rules.
40 //===---------------------------------------------------------------------===//
41 
42 // Helper method to match any typed zero.
43 static bool isZeroValue(Value val) {
44   return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
45 }
46 
47 // Helper to detect a sparse tensor type operand.
48 static bool isSparseTensor(Value v) {
49   auto enc = getSparseTensorEncoding(v.getType());
50   return enc && !llvm::all_of(enc.getLvlTypes(),
51                               [](auto lt) { return lt == LevelFormat::Dense; });
52 }
53 static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
54 
55 // Helper method to find zero/uninitialized tensor materialization.
56 static bool isMaterializing(OpOperand *op, bool isZero) {
57   Value val = op->get();
58   // Check allocation, with zero alloc when required.
59   if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
60     Value copy = alloc.getCopy();
61     if (isZero)
62       return copy && isZeroValue(copy);
63     return !copy;
64   }
65   // Check for empty tensor materialization.
66   if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
67     return !isZero;
68   // Last resort for zero alloc: the whole value is zero.
69   return isZero && isZeroValue(val);
70 }
71 
72 // Helper to detect sampling operation.
73 static bool isSampling(GenericOp op) {
74   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
75   if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
76     if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
77       // Both scalar input arguments used exactly once.
78       Value s1 = op.getBlock()->getArgument(0);
79       Value s2 = op.getBlock()->getArgument(1);
80       return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
81              (def->getOperand(1) == s1 && def->getOperand(0) == s2);
82     }
83   }
84   return false;
85 }
86 
87 // Helper to detect chain of multiplications that do not involve x.
88 static bool isMulChain(Value val, Value x) {
89   if (auto arg = dyn_cast<BlockArgument>(val))
90     return arg != x;
91   if (auto *def = val.getDefiningOp()) {
92     if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
93       return isMulChain(def->getOperand(0), x) &&
94              isMulChain(def->getOperand(1), x);
95   }
96   return false;
97 }
98 
99 // Helper to detect x = x + <multiplications>.
100 static bool isSumOfMul(GenericOp op) {
101   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
102   if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
103     if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
104       Value x = op.getBlock()->getArguments().back();
105       return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
106              (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
107     }
108   }
109   return false;
110 }
111 
112 // Helper to detect direct yield of a zero value.
113 static bool isZeroYield(GenericOp op) {
114   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
115   if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
116     if (arg.getOwner()->getParentOp() == op) {
117       return isZeroValue(op->getOperand(arg.getArgNumber()));
118     }
119   }
120   return isZeroValue(yieldOp.getOperand(0));
121 }
122 
123 /// Populates given sizes array from type (for static sizes) and from
124 /// the tensor (for dynamic sizes).
125 static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
126                            Location loc, ShapedType stp, Value tensor) {
127   for (const auto &d : enumerate(stp.getShape())) {
128     Value dim;
129     if (d.value() == ShapedType::kDynamic)
130       dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
131     else
132       dim = constantIndex(builder, loc, d.value());
133     sizes.push_back(dim);
134   }
135 }
136 
137 static RankedTensorType getBufferType(const SparseTensorType &stt,
138                                       bool needTmpCOO) {
139   return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
140                     : stt.getRankedTensorType();
141 }
142 
143 /// Collects the dynamic dimension sizes for `tp` with the assumption that
144 /// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
145 /// sizes to dynSizes.
146 static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
147                             SmallVectorImpl<Value> &dynSizes) {
148   for (const auto &d : enumerate(tp.getShape())) {
149     if (d.value() == ShapedType::kDynamic)
150       dynSizes.push_back(sizes[d.index()]);
151   }
152 }
153 
154 static LogicalResult genForeachOnSparseConstant(ForeachOp op,
155                                                 RewriterBase &rewriter,
156                                                 SparseElementsAttr attr) {
157   auto loc = op.getLoc();
158   SmallVector<Value> reduc = op.getInitArgs();
159 
160   // Foreach on constant.
161   foreachInSparseConstant(
162       rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
163       [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
164         SmallVector<Value> args;
165         args.append(cvs.begin(), cvs.end());
166         args.push_back(v);
167         args.append(reduc);
168         // Clones the foreach op to get a copy of the loop body.
169         auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
170         assert(args.size() == cloned.getBody()->getNumArguments());
171         Operation *yield = cloned.getBody()->getTerminator();
172         rewriter.inlineBlockBefore(cloned.getBody(), op, args);
173         // Clean up the cloned op.
174         rewriter.eraseOp(cloned);
175         reduc = yield->getOperands();
176         rewriter.eraseOp(yield);
177       });
178 
179   rewriter.replaceOp(op, reduc);
180   return success();
181 }
182 
183 /// Populates the given sizes array for concatenation from types (for static
184 /// sizes) and from the source tensors (for dynamic sizes).
185 static void concatSizesFromInputs(OpBuilder &builder,
186                                   SmallVectorImpl<Value> &sizes, Location loc,
187                                   ShapedType dstTp, ValueRange srcs,
188                                   unsigned dim) {
189   auto dstShape = dstTp.getShape();
190   sizesFromSrc(builder, sizes, loc, srcs[0]);
191 
192   // Sum up on the `dim` if the dimension is dynamic.
193   if (dstShape[dim] != ShapedType::kDynamic) {
194     // Faithfully take the static size.
195     sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
196   } else {
197     // Else, compute the shape dynamically.
198     for (const auto &src : srcs.drop_front()) {
199       Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim);
200       // Sum up all the sizes.
201       sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
202     }
203   }
204 }
205 
206 //===---------------------------------------------------------------------===//
207 // The actual sparse tensor rewriting rules.
208 //===---------------------------------------------------------------------===//
209 
210 namespace {
211 
212 /// Rewriting rule that converts direct yield of zero with initial allocation.
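/// As an illustrative sketch (the #CSR encoding, value names, and elided
/// indexing maps below are hypothetical), a kernel such as
///
///   %0 = bufferization.alloc_tensor() : tensor<8x8xf64, #CSR>
///   %1 = linalg.generic ... outs(%0 : tensor<8x8xf64, #CSR>) {
///          ^bb0(%out: f64):
///            linalg.yield %f0 : f64   // always yields 0.0
///        } -> tensor<8x8xf64, #CSR>
///
/// folds to just the freshly allocated (all-zero) sparse tensor %0; a
/// static-shaped dense output folds to a constant zero tensor instead.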
213 struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
214 public:
215   using OpRewritePattern<GenericOp>::OpRewritePattern;
216 
217   LogicalResult matchAndRewrite(GenericOp op,
218                                 PatternRewriter &rewriter) const override {
219     if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
220         !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
221         !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
222       return failure();
223     auto outputType = getRankedTensorType(op.getResult(0));
224     // Yielding zero on newly materialized sparse tensor can be
225     // optimized directly (regardless of dynamic or static size).
226     if (getSparseTensorEncoding(outputType)) {
227       rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
228       return success();
229     }
230     // Use static zero value directly instead of materialization.
231     if (!outputType.hasStaticShape())
232       return failure();
233     Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
234     rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
235     rewriter.eraseOp(def);
236     return success();
237   }
238 };
239 
240 /// Rewriting rule that converts two kernels:
241 ///
242 ///      T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
243 ///      X(i,j) = S(i,j) * T(i,j)
244 ///
245 /// into a single kernel, using distributive law:
246 ///
247 ///      X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
248 ///
249 /// This kind of fusion (merging two ops into one but using arithmetic
250 /// equalities that may not hold for floating-point computations) would
251 /// be undesirable in the dense case, since we distribute the multiplication
252 /// into the reduction loop. However, for sparse sampling tensor S, such
253 /// a fusion may actually reduce the asymptotic complexity of the kernel,
254 /// since intermediate results may be nullified.
255 struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
256 public:
257   using OpRewritePattern<GenericOp>::OpRewritePattern;
258 
259   LogicalResult matchAndRewrite(GenericOp op,
260                                 PatternRewriter &rewriter) const override {
261     // Check consumer.
262     if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
263         op.getNumResults() != 1 ||
264         op.getNumParallelLoops() != op.getNumLoops() ||
265         !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
266         !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
267         !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
268       return failure();
269     // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
270     // operand can be sparse or dense, since the point of this rewriting rule
271     // is detecting a situation in which *more* sparsity is introduced into
272     // a computation, be it already sparse or still dense.
273     unsigned other = 0;
274     if (isSparseTensor(op.getDpsInputOperand(0)))
275       other = 1;
276     else if (!isSparseTensor(op.getDpsInputOperand(1)))
277       return failure();
278     // Check producer.
279     auto prod = dyn_cast_or_null<GenericOp>(
280         op.getDpsInputOperand(other)->get().getDefiningOp());
281     if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
282         !prod.getResult(0).hasOneUse())
283       return failure();
284     // Sampling consumer and sum of multiplication chain producer.
285     if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
286         !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
287         !isSampling(op) || !isSumOfMul(prod))
288       return failure();
289     // Modify operand structure of producer and consumer.
290     Location loc = prod.getLoc();
291     SmallVector<Value> inputOps = prod.getInputs();
292     SmallVector<Value> outputOps = op.getOutputs();
293     SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
294     inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
295     fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
296     // Fuse producer and consumer into a new generic op.
297     auto fusedOp = rewriter.create<GenericOp>(
298         loc, op.getResult(0).getType(), inputOps, outputOps,
299         rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
300         /*doc=*/nullptr, /*library_call=*/nullptr);
301     Block &prodBlock = prod.getRegion().front();
302     Block &consBlock = op.getRegion().front();
303     IRMapping mapper;
304     Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
305     unsigned num = prodBlock.getNumArguments();
306     for (unsigned i = 0; i < num - 1; i++)
307       addArg(mapper, fusedBlock, prodBlock.getArgument(i));
308     addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
309     addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
310     // Clone bodies of the producer and consumer in new evaluation order.
311     auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
312     auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
313     Value last;
314     for (auto &op : prodBlock.without_terminator())
315       if (&op != acc) {
316         last = op.getResult(0);
317         rewriter.clone(op, mapper);
318       }
319     mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
320     mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
321     last = rewriter.clone(*acc, mapper)->getResult(0);
322     rewriter.create<linalg::YieldOp>(loc, last);
323     // Force initial value on merged allocation for dense outputs.
324     // TODO: deal with non alloc tensor here one day
325     if (!getSparseTensorEncoding(op.getResult(0).getType())) {
326       Value init = prod.getDpsInitOperand(0)
327                        ->get()
328                        .getDefiningOp<AllocTensorOp>()
329                        .getCopy();
330       AllocTensorOp a =
331           op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
332       rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
333     }
334     // Replace consumer with fused operation. Old producer
335     // and consumer ops will be removed by DCE.
336     rewriter.replaceOp(op, fusedOp->getResults());
337     return success();
338   }
339 
340 private:
341   // Helper to add argument and record the mapping.
342   static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
343     mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
344   }
345 };
346 
347 // Fuse a tensor cast into its producing operation. Note that a tensor.cast
348 // should really not be used to convert between sparse encodings. Since
349 // the pattern currently appears as a result of some prior rewriting
350 // we make an attempt to repair very obvious cases.
351 // TODO: audit the pure tensor dialect rewriting rules
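// As a sketch of the repair case (the #CSR encoding is hypothetical),
//
//   %1 = tensor.cast %0 : tensor<8x8xf64, #CSR> to tensor<8x8xf64>
//
// is rewritten into the supported form
//
//   %1 = sparse_tensor.convert %0 : tensor<8x8xf64, #CSR> to tensor<8x8xf64>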
352 struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
353 public:
354   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
355 
356   LogicalResult matchAndRewrite(tensor::CastOp op,
357                                 PatternRewriter &rewriter) const override {
358     Type srcType = op.getSource().getType();
359     Type dstType = op.getDest().getType();
360     // A nop cast simply folds away.
361     if (srcType == dstType) {
362       rewriter.replaceOp(op, op->getResults());
363       return success();
364     }
365     // See if a sparsity changing cast can be fused into producer.
366     if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
367       if (Operation *def = op.getSource().getDefiningOp()) {
368         if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
369           rewriter.modifyOpInPlace(def, [&]() {
370             def->getResult(0).setType(op->getResultTypes()[0]);
371           });
372           rewriter.replaceOp(op, def->getResult(0));
373           return success();
374         }
375       }
376     }
377     // Repair tensor casts with at least one sparse operand into the
378     // properly supported sparse_tensor.convert.
379     if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
380       rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
381       return success();
382     }
383     // Fail otherwise.
384     return failure();
385   }
386 };
387 
388 /// Rewrites a sequence of operations for sparse tensor selections into
389 /// semi-ring operations such that they can be compiled correctly by the
390 /// sparsifier. E.g., transforming the following sequence
391 ///
392 /// %sel = arith.select %cond, %sp1, %sp2
393 ///
394 /// to
395 ///
396 /// %sel = binary %sp1, %sp2:
397 ///         both  (%l, %r) {yield select %cond, %l, %r}
398 ///         left  (%l)     {yield select %cond, %l,  0}
399 ///         right (%r)     {yield select %cond,  0, %r}
400 ///
401 /// TODO: We require the tensor used for extracting conditions to be dense
402 /// to sparsify the code. To support a sparse condition tensor, we need a
403 /// tri-nary operation.
404 struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
405 public:
406   using OpRewritePattern<GenericOp>::OpRewritePattern;
407   LogicalResult matchAndRewrite(GenericOp op,
408                                 PatternRewriter &rewriter) const override {
409     // Rejects non sparse kernels.
410     if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
411       return failure();
412 
413     Location loc = op.getLoc();
414     SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
415     for (Operation &inst : *op.getBody()) {
416       // Matches pattern.
417       auto matched = isRewritablePattern(op, &inst);
418       if (!matched.has_value())
419         continue;
420 
421       rewriter.setInsertionPoint(&inst);
422       auto [c, t, f] = matched.value();
423       assert(t.getType() == f.getType());
424       auto selTp = t.getType();
425       auto c0 = constantZero(rewriter, loc, selTp);
426       auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
427       // Initializes all the blocks.
428       rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
429                            {t.getLoc(), f.getLoc()});
430       rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
431       rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
432 
433       for (auto *r : binOp.getRegions()) {
434         Block *b = &r->front();
435         rewriter.setInsertionPointToStart(b);
436 
437         IRMapping irMap;
438         // Clones the cmp operations into the region to make the binary op
439         // admissible.
440         Value newC = c;
441         if (auto *def = c.getDefiningOp())
442           newC = rewriter.clone(*def, irMap)->getResult(0);
443 
444         irMap.map(c, newC);
445         if (r == &binOp.getLeftRegion()) {
446           irMap.map(t, b->getArgument(0));
447           irMap.map(f, c0);
448         } else if (r == &binOp.getRightRegion()) {
449           irMap.map(t, c0);
450           irMap.map(f, b->getArgument(0));
451         } else {
452           irMap.map(t, b->getArgument(0));
453           irMap.map(f, b->getArgument(1));
454         }
455         auto y = rewriter.clone(inst, irMap)->getResult(0);
456         rewriter.create<sparse_tensor::YieldOp>(loc, y);
457       }
458 
459       // We successfully rewrote an operation. We cannot do the replacement
460       // here because it would invalidate the iterator used by the current
461       // loop to traverse the instructions.
462       semiRings.emplace_back(&inst, binOp);
463     }
464 
465     // Finalizes the replacement.
466     for (auto [sel, semi] : semiRings)
467       rewriter.replaceOp(sel, semi->getResults());
468 
469     return success(!semiRings.empty());
470   }
471 
472 private:
473   static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
474   isRewritablePattern(GenericOp op, Operation *v) {
475     auto sel = dyn_cast<arith::SelectOp>(v);
476     if (!sel)
477       return std::nullopt;
478 
479     auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
480     auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
481     // TODO: For simplicity, we only handle cases where both true/false values
482     // are directly loaded from the input tensor. We can probably admit more cases
483     // in theory.
484     if (!tVal || !fVal)
485       return std::nullopt;
486 
487     // Helper lambda to determine whether the value is loaded from a dense input
488     // or is a loop invariant.
489     auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
490       if (auto bArg = dyn_cast<BlockArgument>(v);
491           bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
492         return true;
493       // If the value is defined outside the loop, it is a loop invariant.
494       return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
495     };
496 
497     // If the condition value is loaded directly from a dense tensor or is
498     // loop-invariant, we can sparsify the kernel.
499     auto cond = sel.getCondition();
500     if (isValFromDenseInputOrInvariant(cond))
501       return std::make_tuple(cond, tVal, fVal);
502 
503     Value cmpL, cmpR;
504     if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
505                                                matchers::m_Any(&cmpR))) ||
506         matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
507                                                matchers::m_Any(&cmpR)))) {
508       // TODO: we can do it recursively to check whether all the leaf values are
509       // loaded from dense tensors or are loop invariants.
510       if (isValFromDenseInputOrInvariant(cmpL) ||
511           isValFromDenseInputOrInvariant(cmpR))
512         return std::make_tuple(cond, tVal, fVal);
513     }
514 
515     return std::nullopt;
516   };
517 };
518 
519 /// Rewrites a sparse reduction that would not sparsify directly since
520 /// doing so would only iterate over the stored elements, ignoring the
521 /// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
522 /// (note that reductions like add/sub/or/xor can directly be sparsified
523 /// since the implicit zeros do not contribute to the final result).
524 /// Note that prod/and are still included since, even though they often
525 /// are nullified in sparse data, they may still occur for special
526 /// situations in which e.g. some rows in a sparse matrix are fully
527 /// dense. For min/max, including the implicit zeros is a much more
528 /// common situation.
529 ///
530 /// TODO: this essentially "densifies" the operation; we want to implement
531 ///       this much more efficiently by performing the reduction over the
532 ///       stored values, and feed in the zero once if there were *any*
533 ///       implicit zeros as well; but for now, at least we provide
534 ///       the functionality
535 ///
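/// As a schematic sketch (in the same informal notation used above), a
/// reduction body x = MIN(x, a) over a sparse input a is rewritten into
///
///    %u = unary a  { present -> a, absent -> 0 }
///    x  = reduce %u, x, identity  { x = MIN(x, y) }
///
/// where the identity value is read from the initialization tensor.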
536 struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
537 public:
538   using OpRewritePattern<GenericOp>::OpRewritePattern;
539 
540   LogicalResult matchAndRewrite(GenericOp op,
541                                 PatternRewriter &rewriter) const override {
542     // Reject non-reductions.
543     if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
544         op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
545       return failure();
546     auto *inp = op.getDpsInputOperand(0);
547     auto *init = op.getDpsInitOperand(0);
548     if (!isSparseTensor(inp))
549       return failure();
550     // Look for direct x = x OP y for semi-ring ready reductions.
551     auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
552                     .getOperand(0)
553                     .getDefiningOp();
554     if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
555              arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
556              arith::MaxUIOp>(red))
557       return failure();
558     Value s0 = op.getBlock()->getArgument(0);
559     Value s1 = op.getBlock()->getArgument(1);
560     if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
561         (red->getOperand(0) != s1 || red->getOperand(1) != s0))
562       return failure();
563     // Identity.
564     Location loc = op.getLoc();
565     Value identity =
566         rewriter.create<tensor::ExtractOp>(loc, init->get(), ValueRange());
567     // Unary {
568     //    present -> value
569     //    absent  -> zero.
570     // }
571     Type rtp = s0.getType();
572     rewriter.setInsertionPointToStart(&op.getRegion().front());
573     auto semiring = rewriter.create<sparse_tensor::UnaryOp>(loc, rtp, s0);
574     Block *present =
575         rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
576     rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
577     rewriter.create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
578     rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
579     rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
580     auto zero =
581         rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(rtp));
582     rewriter.create<sparse_tensor::YieldOp>(loc, zero);
583     rewriter.setInsertionPointAfter(semiring);
584     // CustomReduce {
585     //    x = x REDUC y, identity
586     // }
587     auto custom = rewriter.create<sparse_tensor::ReduceOp>(
588         loc, rtp, semiring.getResult(), s1, identity);
589     Block *region =
590         rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
591     rewriter.setInsertionPointToStart(&custom.getRegion().front());
592     IRMapping irMap;
593     irMap.map(red->getOperand(0), region->getArgument(0));
594     irMap.map(red->getOperand(1), region->getArgument(1));
595     auto *cloned = rewriter.clone(*red, irMap);
596     rewriter.create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
597     rewriter.setInsertionPointAfter(custom);
598     rewriter.replaceOp(red, custom.getResult());
599     return success();
600   }
601 };
602 
603 /// Sparse rewriting rule for the print operator. This operation is mainly used
604 /// for debugging and testing. As such, it lowers to the vector.print operation
605 /// which only requires very lightweight runtime support.
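/// As a rough sketch, printing a small 2-D CSR-like tensor generates output
/// of the following shape (exact spacing/punctuation comes from vector.print):
///
///   ---- Sparse Tensor ----
///   nse = 5
///   dim = ( 4, 8 )
///   lvl = ( 4, 8 )
///   pos[1] : ( 0, 2, 3, 3, 5 )
///   crd[1] : ( 0, 3, 2, 1, 7 )
///   values : ( 1, 2, 3, 4, 5 )
///   ----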
606 struct PrintRewriter : public OpRewritePattern<PrintOp> {
607 public:
608   using OpRewritePattern::OpRewritePattern;
609   LogicalResult matchAndRewrite(PrintOp op,
610                                 PatternRewriter &rewriter) const override {
611     Location loc = op.getLoc();
612     auto tensor = op.getTensor();
613     auto stt = getSparseTensorType(tensor);
614     // Header with NSE.
615     auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
616     rewriter.create<vector::PrintOp>(
617         loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
618     rewriter.create<vector::PrintOp>(loc, nse);
619     // Print run-time contents for dim/lvl sizes.
620     rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("dim = "));
621     printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
622     rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("lvl = "));
623     printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
624     // Use the "codegen" foreach loop construct to iterate over
625     // all typical sparse tensor components for printing.
626     foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
627                                             &stt](Type, FieldIndex,
628                                                   SparseTensorFieldKind kind,
629                                                   Level l, LevelType) {
630       switch (kind) {
631       case SparseTensorFieldKind::StorageSpec: {
632         break;
633       }
634       case SparseTensorFieldKind::PosMemRef: {
635         auto lvl = constantIndex(rewriter, loc, l);
636         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
637         rewriter.create<vector::PrintOp>(
638             loc, lvl, vector::PrintPunctuation::NoPunctuation);
639         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
640         auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
641         printContents(rewriter, loc, pos);
642         break;
643       }
644       case SparseTensorFieldKind::CrdMemRef: {
645         auto lvl = constantIndex(rewriter, loc, l);
646         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
647         rewriter.create<vector::PrintOp>(
648             loc, lvl, vector::PrintPunctuation::NoPunctuation);
649         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
650         Value crd = nullptr;
651         // For COO AoS storage, we want to print a single, linear view of
652         // the full coordinate storage at this level. For any other storage,
653         // we show the coordinate storage for every individual level.
654         if (stt.getAoSCOOStart() == l)
655           crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
656         else
657           crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
658         printContents(rewriter, loc, crd);
659         break;
660       }
661       case SparseTensorFieldKind::ValMemRef: {
662         rewriter.create<vector::PrintOp>(loc,
663                                          rewriter.getStringAttr("values : "));
664         auto val = rewriter.create<ToValuesOp>(loc, tensor);
665         printContents(rewriter, loc, val);
666         break;
667       }
668       }
669       return true;
670     });
671     rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
672     rewriter.eraseOp(op);
673     return success();
674   }
675 
676 private:
677   // Helper to print contents of a single memref. Note that for the "push_back"
678   // vectors, this prints the full capacity, not just the size. This is done
679   // on purpose, so that clients see how much storage has been allocated in
680   // total. Contents of the extra capacity in the buffer may be uninitialized
681   // (unless the flag enable-buffer-initialization is set to true).
682   //
683   // Generates code to print:
684   //    ( a0, a1, ... )
685   static void printContents(PatternRewriter &rewriter, Location loc,
686                             Value vec) {
687     // Open bracket.
688     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
689     // For loop over elements.
690     auto zero = constantIndex(rewriter, loc, 0);
691     auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
692     auto step = constantIndex(rewriter, loc, 1);
693     auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
694     rewriter.setInsertionPointToStart(forOp.getBody());
695     auto idx = forOp.getInductionVar();
696     auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
697     if (llvm::isa<ComplexType>(val.getType())) {
698       // Since the vector dialect does not support complex types in any op,
699       // we split those into (real, imag) pairs here.
700       Value real = rewriter.create<complex::ReOp>(loc, val);
701       Value imag = rewriter.create<complex::ImOp>(loc, val);
702       rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
703       rewriter.create<vector::PrintOp>(loc, real,
704                                        vector::PrintPunctuation::Comma);
705       rewriter.create<vector::PrintOp>(loc, imag,
706                                        vector::PrintPunctuation::Close);
707       rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
708     } else {
709       rewriter.create<vector::PrintOp>(loc, val,
710                                        vector::PrintPunctuation::Comma);
711     }
712     rewriter.setInsertionPointAfter(forOp);
713     // Close bracket and end of line.
714     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
715     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
716   }
717 
718   // Helper method to print run-time lvl/dim sizes.
719   static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
720                          unsigned size, bool isDim) {
721     // Open bracket.
722     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
723     // Print unrolled contents (dimop requires constant value).
724     for (unsigned i = 0; i < size; i++) {
725       auto idx = constantIndex(rewriter, loc, i);
726       Value val;
727       if (isDim)
728         val = rewriter.create<tensor::DimOp>(loc, tensor, idx);
729       else
730         val = rewriter.create<LvlOp>(loc, tensor, idx);
731       rewriter.create<vector::PrintOp>(
732           loc, val,
733           i != size - 1 ? vector::PrintPunctuation::Comma
734                         : vector::PrintPunctuation::NoPunctuation);
735     }
736     // Close bracket and end of line.
737     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
738     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
739   }
740 };
741 
742 /// Sparse rewriting rule for the sparse-to-sparse tensor.reshape operator.
743 struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
744 public:
745   using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
746 
747   LogicalResult matchAndRewrite(tensor::ReshapeOp op,
748                                 PatternRewriter &rewriter) const override {
749     Location loc = op.getLoc();
750     Value srcTensor = op.getSource();
751     const auto srcTp = getSparseTensorType(srcTensor);
752     const auto dstTp = getSparseTensorType(op.getResult());
753 
754     if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
755         !dstTp.hasStaticDimShape())
756       return failure();
757 
758     SmallVector<Value> srcSizes;
759     sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
760     SmallVector<Value> dstSizes;
761     for (Dimension d : dstTp.getDimShape())
762       dstSizes.push_back(constantIndex(rewriter, loc, d));
763 
764     Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
765     // Only need an unordered COO buffer if input and output are not sorted
766     // in the same way.
767     Type bufferTp = getBufferType(
768         dstTp.withoutDimToLvl(),
769         !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
770     SmallVector<Value> dynSizes;
771     Value buffer = rewriter
772                        .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
773                                               nnz, Attribute())
774                        .getResult();
775 
776     // Convert src coordinates to dst coordinates by first collapsing it to 1D
777     // and then expanding it to match the rank of the destination tensor.
778     // Implemented as follows:
779     //   foreach srcCoords %srcTensor
780     //     collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
781     //     expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
782     //     insert expandedCoords, %buffer
783     //
784     // followed by an optional
785     //   %t = sparse_tensor.cast %tmp
786     // depending on whether the input/output are sorted in the same way.
787     const auto encSrc = srcTp.getEncoding();
788     ForeachOp foreachOp = rewriter.create<ForeachOp>(
789         loc, srcTensor, buffer,
790         [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
791             ValueRange reduc) {
792           const Dimension srcRank = srcTp.getDimRank();
793           SmallVector<Value> srcDcvs;
794           srcDcvs.reserve(srcRank);
795           for (Dimension d = 0; d < srcRank; d++) {
796             Level lvl = toLvl(encSrc, d);
797             srcDcvs.push_back(srcLcvs[lvl]);
798           }
799 
800           Value collapseSize = constantIndex(builder, loc, 1);
801           for (Dimension d = 0; d < srcRank; d++)
802             collapseSize =
803                 builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
804           SmallVector<Value, 1> collapsedSizes = {collapseSize};
805 
806           ReassociationIndices collapseIdx;
807           for (Dimension i = 0; i < srcRank; i++)
808             collapseIdx.push_back(i);
809           SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
810           SmallVector<Value, 1> collapsedDcvs;
811           reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
812                      collapsedSizes, collapsedDcvs);
813 
814           ReassociationIndices expandIdx;
815           for (Dimension i = 0; i < dstTp.getDimRank(); i++)
816             expandIdx.push_back(i);
817           SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
818           SmallVector<Value> dstDcvs;
819           reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
820                      dstSizes, dstDcvs);
821 
822           auto t =
823               builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
824           builder.create<sparse_tensor::YieldOp>(loc, t);
825         });
826 
827     Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
828     if (bufferTp != dstTp) {
829       auto dstRTT = dstTp.getRankedTensorType();
830       Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
831       rewriter.create<DeallocTensorOp>(loc, t);
832       t = converted;
833     }
834     rewriter.replaceOp(op, t);
835     return success();
836   }
837 };
838 
839 /// Sparse rewriting rule for the sparse-to-sparse collapse/expand shape operators.
840 template <typename ReshapeOp>
841 struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
842 public:
843   using OpRewritePattern<ReshapeOp>::OpRewritePattern;
844 
845   LogicalResult matchAndRewrite(ReshapeOp op,
846                                 PatternRewriter &rewriter) const override {
847     Location loc = op.getLoc();
848     Value srcTensor = op.getSrc();
849     const auto srcTp = getSparseTensorType(srcTensor);
850     const auto dstTp = getSparseTensorType(op.getResult());
851     if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
852       return failure();
853 
854     // Generate code to represent the static dimension constants or compute
855     // the dynamic dimension values.
856     SmallVector<Value> srcSizes;
857     sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
858     SmallVector<Value> dstSizes;
859     SmallVector<Value> dstDynSizes;
860     if (dstTp.hasStaticDimShape()) {
861       for (Dimension d : dstTp.getDimShape())
862         dstSizes.push_back(constantIndex(rewriter, loc, d));
863     } else {
864       ArrayRef<Size> dstShape = dstTp.getDimShape();
865       genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
866                          op.getReassociationIndices());
867       for (auto [idx, shape] : llvm::enumerate(dstShape)) {
868         if (shape == ShapedType::kDynamic)
869           dstDynSizes.push_back(dstSizes[idx]);
870       }
871     }
872     Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
873     // Only need an unordered COO buffer if input and output are not sorted
874     // in the same way.
875     Type bufferTp = getBufferType(
876         dstTp.withoutDimToLvl(),
877         !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
878 
879     Value buffer =
880         rewriter
881             .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
882                                    /*sizeHint=*/nnz, Attribute())
883             .getResult();
884 
885     // Implement the sparse2sparse reshape as follows:
886     //   foreach srcCoords %srcTensor
887     //     insert reshapeCvs(srcCoords), %buffer
888     //
889     // followed by an optional
890     //   %t = sparse_tensor.cast %tmp
891     // depending on whether the input/output are sorted in the same way.
892     const auto encSrc = srcTp.getEncoding();
893     ForeachOp foreachOp = rewriter.create<ForeachOp>(
894         loc, srcTensor, buffer,
895         [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
896             ValueRange reduc) {
897           const Dimension dimRank = srcTp.getDimRank();
898           SmallVector<Value> srcDcvs;
899           srcDcvs.reserve(dimRank);
900           for (Dimension d = 0; d < dimRank; d++) {
901             Level lvl = toLvl(encSrc, d);
902             srcDcvs.push_back(srcLcvs[lvl]);
903           }
904           SmallVector<Value> dstDcvs;
905           reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
906                      srcDcvs, dstSizes, dstDcvs);
907           auto t =
908               builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
909           builder.create<sparse_tensor::YieldOp>(loc, t);
910         });
911 
912     Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
913     if (bufferTp != dstTp) {
914       auto dstRTT = dstTp.getRankedTensorType();
915       Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
916       rewriter.create<DeallocTensorOp>(loc, t);
917       t = converted;
918     }
919     rewriter.replaceOp(op, t);
920     return success();
921   }
922 };
923 
924 /// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
925 /// operator.
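/// As a sketch of the sparse-to-dense direction (the #CSR encoding is
/// hypothetical),
///
///   %1 = tensor.collapse_shape %0 [[0, 1]]
///          : tensor<3x4xf64, #CSR> into tensor<12xf64>
///
/// is unfused into a conversion followed by a purely dense reshape:
///
///   %d = sparse_tensor.convert %0 : tensor<3x4xf64, #CSR> to tensor<3x4xf64>
///   %1 = tensor.collapse_shape %d [[0, 1]]
///          : tensor<3x4xf64> into tensor<12xf64>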
926 template <typename ReshapeOp>
927 struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
928 public:
929   using OpRewritePattern<ReshapeOp>::OpRewritePattern;
930 
931   LogicalResult matchAndRewrite(ReshapeOp op,
932                                 PatternRewriter &rewriter) const override {
933     Location loc = op->getLoc();
934     auto encDst = getSparseTensorEncoding(op.getResult().getType());
935     auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
936     // Since a pure dense expansion is very cheap (change of view), for
937     // a sparse2dense or dense2sparse, we can simply unfuse a sparse
938     // conversion from the reshape operation itself.
939     // All other cases are handled elsewhere.
940     if (encDst && encSrc) {
941       return failure();
942     }
943     if (encSrc) {
944       auto rtp = getRankedTensorType(op.getSrc());
945       auto denseTp =
946           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
947       auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
948       rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
949       return success();
950     }
951     if (encDst) {
952       auto rtp = getRankedTensorType(op.getResult());
953       auto denseTp =
954           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
955       ReshapeOp reshape;
956       if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
957         reshape = rewriter.create<ReshapeOp>(
958             loc, denseTp, op.getSrc(), op.getReassociation(),
959             op.getOutputShape(), op.getStaticOutputShape());
960       } else {
961         reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
962                                              op.getReassociation());
963       }
964       Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
965       rewriter.replaceOp(op, convert);
966       return success();
967     }
968     return failure();
969   }
970 };
971 
972 // A trivial wrapper to help generate different operations for dense/sparse
973 // tensors.
974 struct TensorLike {
975   TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
976              ValueRange sizes) {
977     SmallVector<Value> dynSzs;
978     getDynamicSizes(rtt, sizes, dynSzs);
979 
980     val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
981     if (!isSparse()) {
982       Value c0 = constantZero(builder, loc, rtt.getElementType());
983       val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
984     }
985   }
986 
987   void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
988     val = builder.create<tensor::InsertOp>(loc, v, val, crds);
989   }
990 
991   Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
992     if (isSparse())
993       return builder.create<LoadOp>(loc, val, true);
994     return val;
995   }
996 
997   bool isSparse() const {
998     return getSparseTensorEncoding(val.getType()) != nullptr;
999   }
1000 
1001   Value val;
1002 };
1003 
1004 struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
1005   using OpRewritePattern::OpRewritePattern;
1006   LogicalResult matchAndRewrite(tensor::DimOp op,
1007                                 PatternRewriter &rewriter) const override {
1008     std::optional<int64_t> dim = op.getConstantIndex();
1009     auto stt = getSparseTensorType(op.getSource());
1010     if (!dim || !stt.hasEncoding())
1011       return failure();
1012 
1013     if (stt.isPermutation()) {
1014       rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
1015                                          toLvl(stt.getEncoding(), *dim));
1016       return success();
1017     }
1018 
1019     // Non-permutation dim2lvl/lvl2dim maps.
1020     // Compute as follows:
1021     // affine.apply #map (l0 - 1, l1 - 1, ...) + 1
1022     // Note that this is not the most efficient way (but a more general one)
1023     // for the lvl-to-dim translation; e.g., for BSR, the dimension size can
1024     // be computed simply as lvl_size * block_size.
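    // As a worked example (sketch), for a BSR-like lvl-to-dim map
    //   (l0, l1, l2, l3) -> (2 * l0 + l2, 2 * l1 + l3)
    // the size of dimension 0 is computed as
    //   2 * (lvlsize(0) - 1) + (lvlsize(2) - 1) + 1
    // which indeed equals 2 * lvlsize(0) when lvlsize(2) == 2.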
1025     Location loc = op.getLoc();
1026     SmallVector<Value> maxLvlCrds;
1027     for (Level l = 0; l < stt.getLvlRank(); l++) {
1028       Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
1029       Value maxLvlCrd = rewriter.create<arith::SubIOp>(
1030           loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
1031       maxLvlCrds.push_back(maxLvlCrd);
1032     }
1033 
1034     AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
1035     Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
1036         op.getLoc(), AffineMap::get(stt.getLvlRank(), 0, lvl2DimExp),
1037         maxLvlCrds);
1038 
1039     Value dimSz = rewriter.create<arith::AddIOp>(
1040         loc, maxDimCrd, constantOne(rewriter, loc, rewriter.getIndexType()));
1041     rewriter.replaceOp(op, dimSz);
1042     return success();
1043   }
1044 };
1045 
1046 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
1047   using OpRewritePattern::OpRewritePattern;
1048   LogicalResult matchAndRewrite(ConcatenateOp op,
1049                                 PatternRewriter &rewriter) const override {
1050     if (op.needsExtraSort())
1051       return op.emitError("ConcatenateOp not staged");
1052 
1053     const Location loc = op.getLoc();
1054     const auto dstTp = getSparseTensorType(op);
1055     const Dimension conDim = op.getDimension();
1056     SmallVector<Value> sizes;
1057     concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
1058 
1059     // %t = concatenate %s1, %s2, %s3 {dim = 1}
1060     // ==>
1061     // if (isSparseDst)
1062     //   if (allDense)
1063     //     %tmp = bufferization.alloc_tensor dstTp
1064     //   else
1065     //     %tmp = bufferization.alloc_tensor : unordered COO
1066     // else
1067     //   %tmp = memref.alloc : dense tensor
1068     // foreach in %s1 : insert d0, d1, %tmp
1069     // foreach in %s2 : insert d0, d1 + size(s1), %tmp
1070     // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
1071 
1072     TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
1073     Value offset = constantIndex(rewriter, loc, 0);
1074     Value iterArg = dstBuf.val;
1075 
1076     ForeachOp foreachOp;
1077     for (Value input : op.getInputs()) {
1078       // Builds a for op for each input tensor to append new values into the
1079       // output tensor.
1080       foreachOp = rewriter.create<ForeachOp>(
1081           loc, input, iterArg,
1082           [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1083               ValueRange reduc) {
1084             SmallVector<Value> offDimCrd(dcvs);
1085             offDimCrd[conDim] =
1086                 builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);
1087 
1088             // Enters foreach, updates the SSA chain.
1089             dstBuf.val = reduc.front();
1090             if (!dstTp.isAllDense()) {
1091               Value cond = genIsNonzero(builder, loc, v);
1092               auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1093                                                     /*else*/ true);
1094               builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1095               builder.create<scf::YieldOp>(loc, dstBuf.val);
1096 
1097               builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1098               dstBuf.insert(builder, loc, v, offDimCrd);
1099               builder.create<scf::YieldOp>(loc, dstBuf.val);
1100 
1101               // Exits the ifOp, update the sparse tensor SSA value.
1102               builder.setInsertionPointAfter(ifOp);
1103               dstBuf.val = ifOp.getResult(0);
1104             } else {
1105               dstBuf.insert(builder, loc, v, offDimCrd);
1106             }
1107             builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1108           });
1109       // Accumulates the offset. Note that only static-shaped inputs are allowed
1110       // by concatenate op verifier, which saves us from computing the offset
1111       // dynamically.
1112       const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
1113       assert(!ShapedType::isDynamic(sz));
1114       offset = rewriter.create<arith::AddIOp>(loc, offset,
1115                                               constantIndex(rewriter, loc, sz));
1116       iterArg = foreachOp.getResult(0);
1117       dstBuf.val = iterArg;
1118     }
1119 
1120     dstBuf.val = iterArg;
1121     Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
1122     rewriter.replaceOp(op, ret);
1123     return success();
1124   }
1125 };
1126 
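/// Sparse rewriting rule for the convert operator. As a rough sketch (with a
/// hypothetical #CSR destination), a staged
///
///   %1 = sparse_tensor.convert %0 : tensor<?x?xf64> to tensor<?x?xf64, #CSR>
///
/// is expanded into a foreach loop that inserts every (nonzero) source element
/// into a freshly allocated destination, followed by a final load:
///
///   %tmp = bufferization.alloc_tensor(...) : tensor<?x?xf64, #CSR>
///   %l   = sparse_tensor.foreach in %0 init(%tmp) { ... tensor.insert ... }
///   %1   = sparse_tensor.load %l hasInserts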
1127 struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
1128   using OpRewritePattern::OpRewritePattern;
1129   LogicalResult matchAndRewrite(ConvertOp op,
1130                                 PatternRewriter &rewriter) const override {
1131     if (op.needsExtraSort())
1132       return op.emitError("ConvertOp not staged.");
1133 
1134     // TODO: Maybe we want a different operation for this too.
1135     auto encDst = getSparseTensorEncoding(op.getType());
1136     auto encSrc = getSparseTensorEncoding(op.getSource().getType());
1137     if (encDst && encSrc && !encSrc.isSlice() &&
1138         encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
1139       // Trivial tensor conversion and simple element type conversion are handled
1140       // in codegen.
1141       return failure();
1142     }
1143 
1144     Location loc = op.getLoc();
1145     Value src = op.getSource();
1146 
1147     SparseTensorType srcStt = getSparseTensorType(op.getSource());
1148     SparseTensorType dstStt = getSparseTensorType(op.getDest());
1149 
1150     bool fromSparseConst = false;
1151     if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
1152       if (dyn_cast<SparseElementsAttr>(constOp.getValue()))
1153         fromSparseConst = true;
1154 
1155     const AffineMapAttr foreachOrder =
1156         (!dstStt.isIdentity() && fromSparseConst)
1157             ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
1158             : nullptr;
1159 
1160     bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;
1161 
1162     SmallVector<Value> sizes;
1163     sizesFromSrc(rewriter, sizes, loc, src);
1164     ValueRange vs;
1165     TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
1166 
1167     auto foreachOp = rewriter.create<ForeachOp>(
1168         loc, src, dstBuf.val, foreachOrder,
1169         [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1170             ValueRange reduc) {
1171           // Enters the loop, update the SSA value for insertion chain.
1172           dstBuf.val = reduc.front();
1173           if (!skipZeroCheck) {
1174             Value cond = genIsNonzero(builder, loc, v);
1175             auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1176                                                   /*else*/ true);
1177             builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1178             builder.create<scf::YieldOp>(loc, dstBuf.val);
1179 
1180             builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1181             dstBuf.insert(builder, loc, v, dcvs);
1182             builder.create<scf::YieldOp>(loc, dstBuf.val);
1183 
1184             // Exits the ifOp, update the sparse tensor SSA value.
1185             builder.setInsertionPointAfter(ifOp);
1186             dstBuf.val = ifOp.getResult(0);
1187           } else {
1188             dstBuf.insert(builder, loc, v, dcvs);
1189           }
1190           builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1191         });
1192 
1193     rewriter.setInsertionPointAfter(foreachOp);
1194 
1195     // Exits the for loop, links the SSA chain.
1196     dstBuf.val = foreachOp.getResult(0);
1197 
1198     Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
1199     rewriter.replaceOp(op, ret);
1200     return success();
1201   }
1202 };
1203 
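/// Sparse rewriting rule for coordinate translation. Each result expression of
/// the dim-to-lvl (or lvl-to-dim) map is lowered to its own affine.apply. As a
/// sketch, for a BSR-like dim-to-lvl map with 2x2 blocks, the dim coordinates
/// (%i, %j) translate to level coordinates roughly as
///
///   %l0 = affine.apply affine_map<(d0, d1) -> (d0 floordiv 2)>(%i, %j)
///   %l2 = affine.apply affine_map<(d0, d1) -> (d0 mod 2)>(%i, %j)
///   ...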
1204 struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
1205   using OpRewritePattern::OpRewritePattern;
1206   LogicalResult matchAndRewrite(CrdTranslateOp op,
1207                                 PatternRewriter &rewriter) const override {
1208     AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
1209                         ? op.getEncoder().getDimToLvl()
1210                         : op.getEncoder().getLvlToDim();
1211 
1212     SmallVector<Value> outCrds;
1213     for (AffineExpr result : map.getResults()) {
1214       // TODO: we should probably expand the affine map to IR using our own
1215       // rules, since affine.apply assumes signed values, while the coordinates
1216       // we provide must always be signless.
1217       Value trans = rewriter.create<affine::AffineApplyOp>(
1218           op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
1219           op.getInCrds());
1220       outCrds.push_back(trans);
1221     }
1222     rewriter.replaceOp(op, outCrds);
1223     return success();
1224   }
1225 };
1226 
1227 /// Sparse rewriting rule for the foreach operator.
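/// As a conceptual sketch, a foreach over a CSR matrix is lowered (via the
/// loop emitter) into a loop nest over the stored elements only, roughly:
///
///   for (i = 0; i < rows; i++)
///     for (p = pos[i]; p < pos[i+1]; p++)
///       j = crd[p]; v = values[p];
///       <inlined foreach body>(i, j, v, reduc...)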
1228 struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
1229 public:
1230   using OpRewritePattern::OpRewritePattern;
1231 
1232   LogicalResult matchAndRewrite(ForeachOp op,
1233                                 PatternRewriter &rewriter) const override {
1234 
1235     auto loc = op.getLoc();
1236     Value input = op.getTensor();
1237     SmallVector<Value> reduc = op.getInitArgs();
1238     const auto stt = getSparseTensorType(input);
1239     const Level lvlRank = stt.getLvlRank();
1240 
1241     // Special case: a foreach over a sparse constant uses its own rewriting
1242     // rule.
1243     if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
1244       if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
1245         return genForeachOnSparseConstant(op, rewriter, attr);
1246       }
1247     }
1248 
1249     // Otherwise, use the loop emitter to generate loops.
1250     const auto enc = stt.getEncoding();
1251 
1252     // 1. Generates loop for the sparse input.
1253     LoopEmitter loopEmitter(
1254         ValueRange{input},
1255         StringAttr::get(getContext(), ForeachOp::getOperationName()));
1256     loopEmitter.initializeLoopEmit(rewriter, loc);
1257     for (Level l = 0; l < lvlRank; l++) {
1258       // TODO: provide a utility function for loop sequences that only
1259       // contain a single for loop?
1260       const SmallVector<TensorLevel, 1> tidLvls{
1261           loopEmitter.makeTensorLevel(0, l)};
1262       loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
1263       // Note that reduc will be taken care of by the loop emitter and gets
1264       // updated in place.
1265       loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls,
1266                                                     reduc);
1267     }
1268 
1269     SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
1270     if (op.getOrder()) {
1271       // TODO: Support the level order so we can directly convert CSR->BSR.
1272       llvm_unreachable(
1273           "Level order not yet implemented on non-constant input tensors.");
1274     }
1275 
1276     Value vals = loopEmitter.getValBuffer()[0];
1277     SmallVector<Value> pos = loopEmitter.getValPosits(0);
1278     // Loads the value from a sparse tensor using the position index;
1279     // loads the value from a dense tensor using the coordinates.
1280     Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
1281                     : rewriter.create<memref::LoadOp>(loc, vals, lcvs);
1282 
1283     // 2. Inline the block in the foreach operator.
1284     Block *srcBlock = op.getBody();
1285 
1286     // Remap coordinates.
1287     SmallVector<Value> args =
1288         enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
1289 
1290     // Remap value.
1291     args.push_back(val);
1292     // Remap reduction variables.
1293     args.append(reduc);
1294 
1295     // Remove sparse_tensor.yield.
1296     SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
1297     rewriter.eraseOp(srcBlock->getTerminator());
1298 
1299     Operation &last = rewriter.getBlock()->back();
1300     if (llvm::isa<scf::YieldOp>(last)) {
1301       // Because `scf.for` inserts an implicit yield op when there is no
1302       // reduction variable upon creation, we reset the insertion point such
1303       // that the block is inlined *before* the yield op.
1304       rewriter.setInsertionPoint(&last);
1305     }
1306 
1307     rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
1308                                rewriter.getInsertionPoint(), args);
1309     rewriter.setInsertionPointToEnd(rewriter.getBlock());
1310     for (Level l = 0; l < lvlRank; l++) {
1311       // Link the reduction chain. Note that the loop emitter updates
1312       // reducValue in place.
1313       loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
1314       loopEmitter.exitCurrentLoopSeq(rewriter, loc);
1315     }
1316 
1317     // Replace the foreach operator with the value returned by the outermost
1318     // for loop.
1319     rewriter.replaceOp(op, reducValue);
1320     return success();
1321   }
1322 };
1323 
1324 /// Sparse rewriting rule for the new operator.
1325 struct NewRewriter : public OpRewritePattern<NewOp> {
1326   using OpRewritePattern::OpRewritePattern;
1327   LogicalResult matchAndRewrite(NewOp op,
1328                                 PatternRewriter &rewriter) const override {
1329     Location loc = op.getLoc();
1330     auto stt = getSparseTensorType(op.getResult());
1331     if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
1332       return failure();
1333 
1334     // Implement the NewOp as follows:
1335     //   %orderedCoo = sparse_tensor.new %filename
1336     //   %t = sparse_tensor.convert %orderedCoo
1337     // with enveloping reinterpret_map ops for non-permutations.
1338     RankedTensorType dstTp = stt.getRankedTensorType();
1339     RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
1340     Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
1341     Value convert = cooTensor;
1342     auto enc = stt.getEncoding();
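         // For a non-permutation (e.g. block) dimToLvl map, the conversion is
         // done on the demapped level types; the original encoding is restored
         // with a reinterpret_map afterwards.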
1343     if (!stt.isPermutation()) { // demap coo, demap dstTp
1344       auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
1345       convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
1346       dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
1347     }
1348     convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
1349     if (!stt.isPermutation()) // remap to original enc
1350       convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
1351     rewriter.replaceOp(op, convert);
1352 
1353     // Release the temporary ordered COO tensor.
1354     rewriter.setInsertionPointAfterValue(convert);
1355     rewriter.create<DeallocTensorOp>(loc, cooTensor);
1356 
1357     return success();
1358   }
1359 };
1360 
1361 /// Sparse rewriting rule for the out operator.
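     /// Lowers the operator to runtime library calls: a writer is created, the
     /// metadata (rank, number of entries, dimension sizes) is emitted, every
     /// stored element is written out via a foreach loop, and the writer is
     /// released again.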
1362 struct OutRewriter : public OpRewritePattern<OutOp> {
1363   using OpRewritePattern::OpRewritePattern;
1364   LogicalResult matchAndRewrite(OutOp op,
1365                                 PatternRewriter &rewriter) const override {
1366     Location loc = op.getLoc();
1367     // Calculate NNZ.
1368     Value src = op.getTensor();
1369     Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
1370 
1371     // Allocate a temporary buffer for storing dimension-sizes/coordinates.
1372     const auto srcTp = getSparseTensorType(src);
1373     const Dimension dimRank = srcTp.getDimRank();
1374     Type indexTp = rewriter.getIndexType();
1375     Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);
1376 
1377     // Generate code to calculate the dimension sizes and store them in the
1378     // buffer.
1379     SmallVector<Value> dims;
1380     sizesForTensor(rewriter, dims, loc, srcTp, src);
1381     for (Dimension d = 0; d < dimRank; d++) {
1382       rewriter.create<memref::StoreOp>(loc, dims[d], dimSizes,
1383                                        constantIndex(rewriter, loc, d));
1384     }
1385 
1386     // Create a sparse tensor writer and output the metadata.
1387     Type opaqueTp = getOpaquePointerType(rewriter);
1388     Value writer =
1389         createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
1390                        {op.getDest()}, EmitCInterface::Off)
1391             .getResult(0);
1392     Value rankValue = constantIndex(rewriter, loc, dimRank);
1393     createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
1394                    {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
1395 
1396     Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords.
1397     Type eltTp = srcTp.getElementType();
1398     SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
1399                                     primaryTypeFunctionSuffix(eltTp)};
1400     Value value = genAllocaScalar(rewriter, loc, eltTp);
1401     ModuleOp module = op->getParentOfType<ModuleOp>();
1402 
1403     // For each element in the source tensor, output the element.
1404     rewriter.create<ForeachOp>(
1405         loc, src, std::nullopt,
1406         [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1407             ValueRange reduc) {
1408           for (Dimension d = 0; d < dimRank; d++) {
1409             rewriter.create<memref::StoreOp>(loc, dcvs[d], dimCoords,
1410                                              constantIndex(builder, loc, d));
1411           }
1412           rewriter.create<memref::StoreOp>(loc, v, value);
1413           SmallVector<Value> operands{writer, rankValue, dimCoords, value};
1414           FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
1415                                          EmitCInterface::On);
1416           builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
1417           builder.create<sparse_tensor::YieldOp>(loc);
1418         });
1419 
1420     // Release the writer.
1421     createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
1422                    EmitCInterface::Off);
1423 
1424     rewriter.eraseOp(op);
1425     return success();
1426   }
1427 };
1428 
1429 } // namespace
1430 
1431 //===---------------------------------------------------------------------===//
1432 // Methods that add patterns described in this file to a pattern list.
1433 //===---------------------------------------------------------------------===//
1434 
1435 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
1436   patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
1437                GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
1438       patterns.getContext());
1439 }
1440 
1441 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
1442                                                    bool enableRT,
1443                                                    bool enableConvert) {
1444   patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
1445                ReshapeRewriter<tensor::CollapseShapeOp>,
1446                Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
1447                Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
1448                SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
1449       patterns.getContext());
1450 
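       // The conversion and NewOp rewriters are optional, gated by
       // enableConvert and enableRT, respectively.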
1451   if (enableConvert)
1452     patterns.add<DirectConvertRewriter>(patterns.getContext());
1453   if (!enableRT)
1454     patterns.add<NewRewriter>(patterns.getContext());
1455 }
1456 
1457 void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
1458   // Run CrdTranslateRewriter later in the pipeline so that the operation can
1459   // be folded before it is lowered to affine.apply.
1460   patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
1461 }
1462