xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (revision 129ade21bdad5f09206b773cd0591a9616ad0ca4)
128ebb0b6SAart Bik //===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
228ebb0b6SAart Bik //
328ebb0b6SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
428ebb0b6SAart Bik // See https://llvm.org/LICENSE.txt for license information.
528ebb0b6SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
628ebb0b6SAart Bik //
728ebb0b6SAart Bik //===----------------------------------------------------------------------===//
828ebb0b6SAart Bik //
928ebb0b6SAart Bik // This file implements rewriting rules that are specific to sparse tensors.
1028ebb0b6SAart Bik //
1128ebb0b6SAart Bik //===----------------------------------------------------------------------===//
1228ebb0b6SAart Bik 
13365777ecSAart Bik #include "Utils/CodegenUtils.h"
14365777ecSAart Bik #include "Utils/LoopEmitter.h"
15c7bb69bcSAart Bik 
166456e0bbSPeiming Liu #include "mlir/Dialect/Affine/IR/AffineOps.h"
17abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1828ebb0b6SAart Bik #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1928ebb0b6SAart Bik #include "mlir/Dialect/Linalg/IR/Linalg.h"
20555e7835Sbixia1 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21550288cbSPeiming Liu #include "mlir/Dialect/MemRef/IR/MemRef.h"
2267f61b08Sbixia1 #include "mlir/Dialect/SCF/IR/SCF.h"
2328ebb0b6SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
24d37affb0SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
25f708a549Swren romano #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
2628ebb0b6SAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
2728ebb0b6SAart Bik #include "mlir/Dialect/Tensor/IR/Tensor.h"
28d37affb0SAart Bik #include "mlir/Dialect/Vector/IR/VectorOps.h"
2928ebb0b6SAart Bik #include "mlir/IR/AffineMap.h"
3028ebb0b6SAart Bik #include "mlir/IR/Matchers.h"
3128ebb0b6SAart Bik #include "mlir/Support/LLVM.h"
3228ebb0b6SAart Bik 
3328ebb0b6SAart Bik using namespace mlir;
3428ebb0b6SAart Bik using namespace mlir::bufferization;
3528ebb0b6SAart Bik using namespace mlir::linalg;
3628ebb0b6SAart Bik using namespace mlir::sparse_tensor;
3728ebb0b6SAart Bik 
3828ebb0b6SAart Bik //===---------------------------------------------------------------------===//
3928ebb0b6SAart Bik // Helper methods for the actual rewriting rules.
4028ebb0b6SAart Bik //===---------------------------------------------------------------------===//
4128ebb0b6SAart Bik 
420d4e7fbaSAart Bik // Helper method to match any typed zero.
430d4e7fbaSAart Bik static bool isZeroValue(Value val) {
440d4e7fbaSAart Bik   return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
450d4e7fbaSAart Bik }
460d4e7fbaSAart Bik 
4728ebb0b6SAart Bik // Helper to detect a sparse tensor type operand.
48e7df8281SPeiming Liu static bool isSparseTensor(Value v) {
49e7df8281SPeiming Liu   auto enc = getSparseTensorEncoding(v.getType());
501944c4f7SAart Bik   return enc && !llvm::all_of(enc.getLvlTypes(),
51aaf91645SPeiming Liu                               [](auto lt) { return lt == LevelFormat::Dense; });
52b4e2b7f9SPeiming Liu }
53e7df8281SPeiming Liu static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
54b4e2b7f9SPeiming Liu 
556a45339bSAart Bik // Helper method to find zero/uninitialized tensor materialization.
566a45339bSAart Bik static bool isMaterializing(OpOperand *op, bool isZero) {
5728ebb0b6SAart Bik   Value val = op->get();
5865074179SAart Bik   // Check allocation, with zero alloc when required.
59ce3d0e87SAart Bik   if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
60ce3d0e87SAart Bik     Value copy = alloc.getCopy();
61ce3d0e87SAart Bik     if (isZero)
620d4e7fbaSAart Bik       return copy && isZeroValue(copy);
63ce3d0e87SAart Bik     return !copy;
64ce3d0e87SAart Bik   }
656a45339bSAart Bik   // Check for empty tensor materialization.
666a45339bSAart Bik   if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
676a45339bSAart Bik     return !isZero;
6865074179SAart Bik   // Last resort for zero alloc: the whole value is zero.
6965074179SAart Bik   return isZero && isZeroValue(val);
7028ebb0b6SAart Bik }
7128ebb0b6SAart Bik 
7228ebb0b6SAart Bik // Helper to detect sampling operation.
7328ebb0b6SAart Bik static bool isSampling(GenericOp op) {
74d3b3f765SJacques Pienaar   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
7528ebb0b6SAart Bik   if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
7628ebb0b6SAart Bik     if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
7728ebb0b6SAart Bik       // Both scalar input arguments used exactly once.
7828ebb0b6SAart Bik       Value s1 = op.getBlock()->getArgument(0);
7928ebb0b6SAart Bik       Value s2 = op.getBlock()->getArgument(1);
8028ebb0b6SAart Bik       return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
8128ebb0b6SAart Bik              (def->getOperand(1) == s1 && def->getOperand(0) == s2);
8228ebb0b6SAart Bik     }
8328ebb0b6SAart Bik   }
8428ebb0b6SAart Bik   return false;
8528ebb0b6SAart Bik }
8628ebb0b6SAart Bik 
8728ebb0b6SAart Bik // Helper to detect chain of multiplications that do not involve x.
8828ebb0b6SAart Bik static bool isMulChain(Value val, Value x) {
895550c821STres Popp   if (auto arg = dyn_cast<BlockArgument>(val))
9028ebb0b6SAart Bik     return arg != x;
9128ebb0b6SAart Bik   if (auto *def = val.getDefiningOp()) {
9228ebb0b6SAart Bik     if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
9328ebb0b6SAart Bik       return isMulChain(def->getOperand(0), x) &&
9428ebb0b6SAart Bik              isMulChain(def->getOperand(1), x);
9528ebb0b6SAart Bik   }
9628ebb0b6SAart Bik   return false;
9728ebb0b6SAart Bik }
9828ebb0b6SAart Bik 
9928ebb0b6SAart Bik // Helper to detect x = x + <multiplications>.
10028ebb0b6SAart Bik static bool isSumOfMul(GenericOp op) {
101d3b3f765SJacques Pienaar   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
10228ebb0b6SAart Bik   if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
10328ebb0b6SAart Bik     if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
10428ebb0b6SAart Bik       Value x = op.getBlock()->getArguments().back();
10528ebb0b6SAart Bik       return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
10628ebb0b6SAart Bik              (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
10728ebb0b6SAart Bik     }
10828ebb0b6SAart Bik   }
10928ebb0b6SAart Bik   return false;
11028ebb0b6SAart Bik }
11128ebb0b6SAart Bik 
112c7bb69bcSAart Bik // Helper to detect direct yield of a zero value.
113c7bb69bcSAart Bik static bool isZeroYield(GenericOp op) {
114d3b3f765SJacques Pienaar   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
1155550c821STres Popp   if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
116c7bb69bcSAart Bik     if (arg.getOwner()->getParentOp() == op) {
117a7cccb9cSAlexander Belyaev       return isZeroValue(op->getOperand(arg.getArgNumber()));
118c7bb69bcSAart Bik     }
119c7bb69bcSAart Bik   }
1200d4e7fbaSAart Bik   return isZeroValue(yieldOp.getOperand(0));
121c7bb69bcSAart Bik }
122c7bb69bcSAart Bik 
123330d48c4Sbixia1 /// Populates the given sizes array from the type (for static sizes) and from
124330d48c4Sbixia1 /// the tensor (for dynamic sizes).
1250e1708ffSAart Bik static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
126330d48c4Sbixia1                            Location loc, ShapedType stp, Value tensor) {
127330d48c4Sbixia1   for (const auto &d : enumerate(stp.getShape())) {
128330d48c4Sbixia1     Value dim;
129399638f9SAliia Khasanova     if (d.value() == ShapedType::kDynamic)
130330d48c4Sbixia1       dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
131330d48c4Sbixia1     else
132330d48c4Sbixia1       dim = constantIndex(builder, loc, d.value());
133330d48c4Sbixia1     sizes.push_back(dim);
134330d48c4Sbixia1   }
135330d48c4Sbixia1 }
136330d48c4Sbixia1 
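/// Returns the buffer type to use for an intermediate result: an unordered
/// COO type when a temporary COO buffer is needed, otherwise the ranked
/// tensor type of `stt` itself.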
13776647fceSwren romano static RankedTensorType getBufferType(const SparseTensorType &stt,
13876647fceSwren romano                                       bool needTmpCOO) {
13945288085SAart Bik   return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
14076647fceSwren romano                     : stt.getRankedTensorType();
14181e3079dSbixia1 }
14281e3079dSbixia1 
143eb877006Sbixia1 /// Collects the dynamic dimension sizes for `tp` with the assumption that
144eb877006Sbixia1 /// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
145eb877006Sbixia1 /// sizes to dynSizes.
146dda3dc5eSPeiming Liu static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
147eb877006Sbixia1                             SmallVectorImpl<Value> &dynSizes) {
148eb877006Sbixia1   for (const auto &d : enumerate(tp.getShape())) {
149399638f9SAliia Khasanova     if (d.value() == ShapedType::kDynamic)
150eb877006Sbixia1       dynSizes.push_back(sizes[d.index()]);
151eb877006Sbixia1   }
152eb877006Sbixia1 }
153eb877006Sbixia1 
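/// Unrolls a sparse_tensor.foreach over a sparse constant: the loop body is
/// cloned and inlined once per stored element of `attr`, threading the
/// reduction values through, and the original op is replaced by the final
/// reduction results.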
1548d615a23SPeiming Liu static LogicalResult genForeachOnSparseConstant(ForeachOp op,
1558d615a23SPeiming Liu                                                 RewriterBase &rewriter,
1568d615a23SPeiming Liu                                                 SparseElementsAttr attr) {
1578d615a23SPeiming Liu   auto loc = op.getLoc();
1588d615a23SPeiming Liu   SmallVector<Value> reduc = op.getInitArgs();
1598d615a23SPeiming Liu 
1608d615a23SPeiming Liu   // Foreach on constant.
1618d615a23SPeiming Liu   foreachInSparseConstant(
1629d4df97fSwren romano       rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
16384cd51bbSwren romano       [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
1648d615a23SPeiming Liu         SmallVector<Value> args;
16584cd51bbSwren romano         args.append(cvs.begin(), cvs.end());
1668d615a23SPeiming Liu         args.push_back(v);
1678d615a23SPeiming Liu         args.append(reduc);
1688d615a23SPeiming Liu         // Clones the foreach op to get a copy of the loop body.
1698d615a23SPeiming Liu         auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
1708d615a23SPeiming Liu         assert(args.size() == cloned.getBody()->getNumArguments());
1718d615a23SPeiming Liu         Operation *yield = cloned.getBody()->getTerminator();
17242c31d83SMatthias Springer         rewriter.inlineBlockBefore(cloned.getBody(), op, args);
1738d615a23SPeiming Liu         // Clean up.
1748d615a23SPeiming Liu         rewriter.eraseOp(cloned);
1758d615a23SPeiming Liu         reduc = yield->getOperands();
1768d615a23SPeiming Liu         rewriter.eraseOp(yield);
1778d615a23SPeiming Liu       });
1788d615a23SPeiming Liu 
1798d615a23SPeiming Liu   rewriter.replaceOp(op, reduc);
1808d615a23SPeiming Liu   return success();
1818d615a23SPeiming Liu }
1828d615a23SPeiming Liu 
183aedf5d58Sbixia1 /// Populates the given sizes array for concatenation from types (for static
184aedf5d58Sbixia1 /// sizes) and from the source tensors (for dynamic sizes).
185aedf5d58Sbixia1 static void concatSizesFromInputs(OpBuilder &builder,
186aedf5d58Sbixia1                                   SmallVectorImpl<Value> &sizes, Location loc,
187aedf5d58Sbixia1                                   ShapedType dstTp, ValueRange srcs,
188aedf5d58Sbixia1                                   unsigned dim) {
189aedf5d58Sbixia1   auto dstShape = dstTp.getShape();
190aedf5d58Sbixia1   sizesFromSrc(builder, sizes, loc, srcs[0]);
191aedf5d58Sbixia1 
192aedf5d58Sbixia1   // Sum up on the `dim` if the dimension is dynamic.
193aedf5d58Sbixia1   if (dstShape[dim] != ShapedType::kDynamic) {
194aedf5d58Sbixia1     // Faithfully take the static size.
195aedf5d58Sbixia1     sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
196aedf5d58Sbixia1   } else {
197aedf5d58Sbixia1     // Else, compute the shape dynamically.
198aedf5d58Sbixia1     for (const auto &src : srcs.drop_front()) {
199aedf5d58Sbixia1       Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim);
200aedf5d58Sbixia1       // Sum up all the sizes.
201aedf5d58Sbixia1       sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
202aedf5d58Sbixia1     }
203aedf5d58Sbixia1   }
204aedf5d58Sbixia1 }
205aedf5d58Sbixia1 
20628ebb0b6SAart Bik //===---------------------------------------------------------------------===//
20728ebb0b6SAart Bik // The actual sparse tensor rewriting rules.
20828ebb0b6SAart Bik //===---------------------------------------------------------------------===//
20928ebb0b6SAart Bik 
21028ebb0b6SAart Bik namespace {
21128ebb0b6SAart Bik 
212ea3eeb48SPeiming Liu /// TODO: move it to tensor dialect instead.
213ea3eeb48SPeiming Liu ///
214ea3eeb48SPeiming Liu /// Fold `tensor.concat` and `tensor.extract_slice`
215ea3eeb48SPeiming Liu ///
216ea3eeb48SPeiming Liu /// %concat = tensor.concat dim(2) %t0, %t1
217ea3eeb48SPeiming Liu ///   : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
218ea3eeb48SPeiming Liu /// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
219ea3eeb48SPeiming Liu ///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
220ea3eeb48SPeiming Liu /// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
221ea3eeb48SPeiming Liu ///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
222ea3eeb48SPeiming Liu ///
223ea3eeb48SPeiming Liu /// Becomes
224ea3eeb48SPeiming Liu ///
225ea3eeb48SPeiming Liu /// %extract0, %extract1 = %t0, %t1
226ea3eeb48SPeiming Liu struct FuseExtractSliceWithConcat
227ea3eeb48SPeiming Liu     : public OpRewritePattern<tensor::ExtractSliceOp> {
228ea3eeb48SPeiming Liu   using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
229ea3eeb48SPeiming Liu 
230ea3eeb48SPeiming Liu   LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
231ea3eeb48SPeiming Liu                                 PatternRewriter &rewriter) const override {
232ea3eeb48SPeiming Liu     auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
233ea3eeb48SPeiming Liu     if (!concatOp)
234ea3eeb48SPeiming Liu       return failure();
235ea3eeb48SPeiming Liu 
236ea3eeb48SPeiming Liu     Location loc = extractOp.getLoc();
237ea3eeb48SPeiming Liu     int64_t dim = concatOp.getDim();
238ea3eeb48SPeiming Liu     int64_t rank = extractOp.getResultType().getRank();
239ea3eeb48SPeiming Liu 
240ea3eeb48SPeiming Liu     SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
241ea3eeb48SPeiming Liu     SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));
242ea3eeb48SPeiming Liu 
243ea3eeb48SPeiming Liu     // Compute the partial sums for the slice offsets.
244ea3eeb48SPeiming Liu     AffineExpr sum = rewriter.getAffineDimExpr(0);
245ea3eeb48SPeiming Liu     SmallVector<AffineExpr> partialSums = {sum};
246ea3eeb48SPeiming Liu     SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
247ea3eeb48SPeiming Liu     for (auto [idx, input] :
248ea3eeb48SPeiming Liu          llvm::enumerate(concatOp.getInputs().drop_back())) {
249ea3eeb48SPeiming Liu       sum = sum + rewriter.getAffineDimExpr(idx + 1);
250ea3eeb48SPeiming Liu       partialSums.push_back(sum);
251ea3eeb48SPeiming Liu       offsetStrides.push_back(
252ea3eeb48SPeiming Liu           rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
253ea3eeb48SPeiming Liu     }
254ea3eeb48SPeiming Liu     auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
255ea3eeb48SPeiming Liu                                         partialSums, rewriter.getContext());
256ea3eeb48SPeiming Liu     SmallVector<OpFoldResult> dimOffsets =
257ea3eeb48SPeiming Liu         affine::makeComposedFoldedMultiResultAffineApply(
258ea3eeb48SPeiming Liu             rewriter, loc, partialSumMap, offsetStrides);
259ea3eeb48SPeiming Liu 
260ea3eeb48SPeiming Liu     auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
261ea3eeb48SPeiming Liu       for (auto [l, r] : llvm::zip(lhs, rhs)) {
262ea3eeb48SPeiming Liu         std::optional<int64_t> staticVal = getConstantIntValue(l);
263ea3eeb48SPeiming Liu         if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
264ea3eeb48SPeiming Liu           return false;
265ea3eeb48SPeiming Liu       }
266ea3eeb48SPeiming Liu       return lhs.size() == rhs.size();
267ea3eeb48SPeiming Liu     };
268ea3eeb48SPeiming Liu 
269ea3eeb48SPeiming Liu     for (auto [i, input, offset] :
270ea3eeb48SPeiming Liu          llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
271ea3eeb48SPeiming Liu       SmallVector<OpFoldResult> srcSizes =
272ea3eeb48SPeiming Liu           tensor::getMixedSizes(rewriter, loc, input);
273ea3eeb48SPeiming Liu       srcOffsets[dim] = offset;
274ea3eeb48SPeiming Liu 
275ea3eeb48SPeiming Liu       SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
276ea3eeb48SPeiming Liu       SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
277ea3eeb48SPeiming Liu       SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
278ea3eeb48SPeiming Liu 
279ea3eeb48SPeiming Liu       if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
280ea3eeb48SPeiming Liu           allEqual(srcStrides, dstStrides)) {
281ea3eeb48SPeiming Liu         Value operand = concatOp.getOperand(i);
282ea3eeb48SPeiming Liu         if (operand.getType() == extractOp.getResultType())
283ea3eeb48SPeiming Liu           rewriter.replaceOp(extractOp, operand);
284ea3eeb48SPeiming Liu         break;
285ea3eeb48SPeiming Liu       }
286ea3eeb48SPeiming Liu     }
287ea3eeb48SPeiming Liu 
288ea3eeb48SPeiming Liu     return success();
289ea3eeb48SPeiming Liu   }
290ea3eeb48SPeiming Liu };
291ea3eeb48SPeiming Liu 
2923aeb28b9SPeiming Liu /// Rewriting rule that fuses sparse_tensor.convert into its producer.
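///
/// As an illustration (a hypothetical, abbreviated sketch; #CSR stands for
/// any sparse encoding):
///
///   %0 = tensor.empty() : tensor<8x8xf64>
///   %1 = linalg.generic ... outs(%0 : tensor<8x8xf64>) -> tensor<8x8xf64>
///   %2 = sparse_tensor.convert %1
///      : tensor<8x8xf64> to tensor<8x8xf64, #CSR>
///
/// becomes a linalg.generic that materializes its result directly into a
/// sparse tensor.empty, making the convert redundant.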
2933aeb28b9SPeiming Liu struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
2943aeb28b9SPeiming Liu public:
2953aeb28b9SPeiming Liu   using OpRewritePattern::OpRewritePattern;
2963aeb28b9SPeiming Liu 
2973aeb28b9SPeiming Liu   LogicalResult matchAndRewrite(ConvertOp op,
2983aeb28b9SPeiming Liu                                 PatternRewriter &rewriter) const override {
2993aeb28b9SPeiming Liu     auto producer = op.getSource().getDefiningOp<GenericOp>();
3003aeb28b9SPeiming Liu     if (!producer || producer.getDpsInits().size() != 1 ||
3013aeb28b9SPeiming Liu         !isMaterializing(producer.getDpsInitOperand(0), false) ||
3023aeb28b9SPeiming Liu         !producer.getResult(0).hasOneUse()) {
3033aeb28b9SPeiming Liu       return failure();
3043aeb28b9SPeiming Liu     }
305fb8f492aSPeiming Liu     // Clone the materialization operation, but update the result to sparse.
306fb8f492aSPeiming Liu     rewriter.setInsertionPoint(producer);
307fb8f492aSPeiming Liu     Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
308fb8f492aSPeiming Liu     Operation *cloned = rewriter.clone(*init);
309fb8f492aSPeiming Liu     cloned->getResult(0).setType(op.getResult().getType());
310fb8f492aSPeiming Liu 
3113aeb28b9SPeiming Liu     rewriter.modifyOpInPlace(producer, [&]() {
312fb8f492aSPeiming Liu       producer.getDpsInitsMutable().assign(cloned->getResults());
3133aeb28b9SPeiming Liu       producer.getResult(0).setType(op.getResult().getType());
3143aeb28b9SPeiming Liu     });
3153aeb28b9SPeiming Liu 
3163aeb28b9SPeiming Liu     rewriter.replaceAllOpUsesWith(op, producer);
3173aeb28b9SPeiming Liu     op->erase();
3183aeb28b9SPeiming Liu 
3193aeb28b9SPeiming Liu     return success();
3203aeb28b9SPeiming Liu   }
3213aeb28b9SPeiming Liu };
3223aeb28b9SPeiming Liu 
323c7bb69bcSAart Bik /// Rewriting rule that folds a kernel which directly yields zero into its
323c7bb69bcSAart Bik /// initial allocation (or, for static-shaped dense outputs, a zero constant).
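///
/// As an illustration (a hypothetical, abbreviated sketch):
///
///   %init = tensor.empty() : tensor<10x20xf64, #CSR>
///   %0 = linalg.generic ... outs(%init : tensor<10x20xf64, #CSR>) {
///     ...
///     linalg.yield %zero : f64
///   } -> tensor<10x20xf64, #CSR>
///
/// folds to just %init, since a newly materialized sparse tensor already
/// represents all-zero contents.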
324c7bb69bcSAart Bik struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
325c7bb69bcSAart Bik public:
326c7bb69bcSAart Bik   using OpRewritePattern<GenericOp>::OpRewritePattern;
327c7bb69bcSAart Bik 
328c7bb69bcSAart Bik   LogicalResult matchAndRewrite(GenericOp op,
329c7bb69bcSAart Bik                                 PatternRewriter &rewriter) const override {
3300a8e3dd4SMatthias Springer     if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
3316a45339bSAart Bik         !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
33263baab8bSPeiming Liu         !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
333c7bb69bcSAart Bik       return failure();
334255c3f11Swren romano     auto outputType = getRankedTensorType(op.getResult(0));
3356a45339bSAart Bik     // Yielding zero on newly materialized sparse tensor can be
3366a45339bSAart Bik     // optimized directly (regardless of dynamic or static size).
337ec495b53SPeiming Liu     if (getSparseTensorEncoding(outputType)) {
338b4db15a9SAlexander Belyaev       rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
339ec495b53SPeiming Liu       return success();
340ec495b53SPeiming Liu     }
3416a45339bSAart Bik     // Use static zero value directly instead of materialization.
342ec495b53SPeiming Liu     if (!outputType.hasStaticShape())
343ec495b53SPeiming Liu       return failure();
3446a45339bSAart Bik     Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
3456a45339bSAart Bik     rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
3466a45339bSAart Bik     rewriter.eraseOp(def);
347c7bb69bcSAart Bik     return success();
348c7bb69bcSAart Bik   }
349c7bb69bcSAart Bik };
350c7bb69bcSAart Bik 
35128ebb0b6SAart Bik /// Rewriting rule that converts two kernels:
35228ebb0b6SAart Bik ///
35328ebb0b6SAart Bik ///      T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
35428ebb0b6SAart Bik ///      X(i,j) = S(i,j) * T(i,j)
35528ebb0b6SAart Bik ///
35628ebb0b6SAart Bik /// into a single kernel, using distributive law:
35728ebb0b6SAart Bik ///
35828ebb0b6SAart Bik ///      X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
35928ebb0b6SAart Bik ///
36028ebb0b6SAart Bik /// This kind of fusion (merging two ops into one but using arithmetic
36128ebb0b6SAart Bik /// equalities that may not hold for floating-point computations) would
36228ebb0b6SAart Bik /// be undesirable in the dense case, since we distribute the multiplication
36328ebb0b6SAart Bik /// into the reduction loop. However, for sparse sampling tensor S, such
36428ebb0b6SAart Bik /// a fusion may actually reduce the asymptotic complexity of the kernel,
36528ebb0b6SAart Bik /// since intermediate results may be nullified.
36628ebb0b6SAart Bik struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
36728ebb0b6SAart Bik public:
36828ebb0b6SAart Bik   using OpRewritePattern<GenericOp>::OpRewritePattern;
36928ebb0b6SAart Bik 
37028ebb0b6SAart Bik   LogicalResult matchAndRewrite(GenericOp op,
37128ebb0b6SAart Bik                                 PatternRewriter &rewriter) const override {
37228ebb0b6SAart Bik     // Check consumer.
3730a8e3dd4SMatthias Springer     if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
37428ebb0b6SAart Bik         op.getNumResults() != 1 ||
37528ebb0b6SAart Bik         op.getNumParallelLoops() != op.getNumLoops() ||
376b4db15a9SAlexander Belyaev         !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
377b4db15a9SAlexander Belyaev         !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
378b4db15a9SAlexander Belyaev         !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
37928ebb0b6SAart Bik       return failure();
38028ebb0b6SAart Bik     // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
38128ebb0b6SAart Bik     // operand can be sparse or dense, since the point of this rewriting rule
38228ebb0b6SAart Bik     // is detecting a situation in which *more* sparsity is introduced into
38328ebb0b6SAart Bik     // a computation, be it already sparse or still dense.
38428ebb0b6SAart Bik     unsigned other = 0;
385b4db15a9SAlexander Belyaev     if (isSparseTensor(op.getDpsInputOperand(0)))
38628ebb0b6SAart Bik       other = 1;
387b4db15a9SAlexander Belyaev     else if (!isSparseTensor(op.getDpsInputOperand(1)))
38828ebb0b6SAart Bik       return failure();
38928ebb0b6SAart Bik     // Check producer.
39028ebb0b6SAart Bik     auto prod = dyn_cast_or_null<GenericOp>(
391b4db15a9SAlexander Belyaev         op.getDpsInputOperand(other)->get().getDefiningOp());
3920a8e3dd4SMatthias Springer     if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
39328ebb0b6SAart Bik         !prod.getResult(0).hasOneUse())
39428ebb0b6SAart Bik       return failure();
39528ebb0b6SAart Bik     // Sampling consumer and sum of multiplication chain producer.
3966a45339bSAart Bik     if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
3976a45339bSAart Bik         !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
398ce3d0e87SAart Bik         !isSampling(op) || !isSumOfMul(prod))
39928ebb0b6SAart Bik       return failure();
40028ebb0b6SAart Bik     // Modify operand structure of producer and consumer.
40128ebb0b6SAart Bik     Location loc = prod.getLoc();
402a7cccb9cSAlexander Belyaev     SmallVector<Value> inputOps = prod.getInputs();
403a7cccb9cSAlexander Belyaev     SmallVector<Value> outputOps = op.getOutputs();
404d2c0572bSJacques Pienaar     SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
405b4db15a9SAlexander Belyaev     inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
40628ebb0b6SAart Bik     fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
40728ebb0b6SAart Bik     // Fuse producer and consumer into a new generic op.
40828ebb0b6SAart Bik     auto fusedOp = rewriter.create<GenericOp>(
40928ebb0b6SAart Bik         loc, op.getResult(0).getType(), inputOps, outputOps,
410c38d9cf2SOleg Shyshkov         rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
41128ebb0b6SAart Bik         /*doc=*/nullptr, /*library_call=*/nullptr);
412d3b3f765SJacques Pienaar     Block &prodBlock = prod.getRegion().front();
413d3b3f765SJacques Pienaar     Block &consBlock = op.getRegion().front();
4144d67b278SJeff Niu     IRMapping mapper;
41591d5653eSMatthias Springer     Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
41628ebb0b6SAart Bik     unsigned num = prodBlock.getNumArguments();
41728ebb0b6SAart Bik     for (unsigned i = 0; i < num - 1; i++)
41828ebb0b6SAart Bik       addArg(mapper, fusedBlock, prodBlock.getArgument(i));
41928ebb0b6SAart Bik     addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
42028ebb0b6SAart Bik     addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
42128ebb0b6SAart Bik     // Clone bodies of the producer and consumer in new evaluation order.
42228ebb0b6SAart Bik     auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
42328ebb0b6SAart Bik     auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
42428ebb0b6SAart Bik     Value last;
42528ebb0b6SAart Bik     for (auto &op : prodBlock.without_terminator())
42628ebb0b6SAart Bik       if (&op != acc) {
42728ebb0b6SAart Bik         last = op.getResult(0);
42828ebb0b6SAart Bik         rewriter.clone(op, mapper);
42928ebb0b6SAart Bik       }
43028ebb0b6SAart Bik     mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
43128ebb0b6SAart Bik     mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
43228ebb0b6SAart Bik     last = rewriter.clone(*acc, mapper)->getResult(0);
43328ebb0b6SAart Bik     rewriter.create<linalg::YieldOp>(loc, last);
434ce3d0e87SAart Bik     // Force initial value on merged allocation for dense outputs.
4356a45339bSAart Bik     // TODO: deal with non-alloc tensors here one day.
436ce3d0e87SAart Bik     if (!getSparseTensorEncoding(op.getResult(0).getType())) {
437b4db15a9SAlexander Belyaev       Value init = prod.getDpsInitOperand(0)
438c7bb69bcSAart Bik                        ->get()
439c7bb69bcSAart Bik                        .getDefiningOp<AllocTensorOp>()
440c7bb69bcSAart Bik                        .getCopy();
441c7bb69bcSAart Bik       AllocTensorOp a =
442b4db15a9SAlexander Belyaev           op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
4435fcf907bSMatthias Springer       rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
444ce3d0e87SAart Bik     }
44528ebb0b6SAart Bik     // Replace consumer with fused operation. Old producer
44628ebb0b6SAart Bik     // and consumer ops will be removed by DCE.
44728ebb0b6SAart Bik     rewriter.replaceOp(op, fusedOp->getResults());
44828ebb0b6SAart Bik     return success();
44928ebb0b6SAart Bik   }
45028ebb0b6SAart Bik 
45128ebb0b6SAart Bik private:
45228ebb0b6SAart Bik   // Helper to add argument and record the mapping.
4534d67b278SJeff Niu   static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
45428ebb0b6SAart Bik     mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
45528ebb0b6SAart Bik   }
45628ebb0b6SAart Bik };
45728ebb0b6SAart Bik 
4589a018a7bSAart Bik // Fuse a tensor cast into its producing operation. Note that a tensor.cast
4599a018a7bSAart Bik // should really not be used to convert between sparse encodings. Since
4609a018a7bSAart Bik // the pattern currently appears as a result of some prior rewriting,
4619a018a7bSAart Bik // we make an attempt to repair very obvious cases.
4629a018a7bSAart Bik // TODO: audit the pure tensor dialect rewriting rules
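//
// As an illustration (hypothetical types), a sparsity-changing cast such as
//
//   %0 = tensor.cast %t : tensor<?x?xf64> to tensor<?x?xf64, #CSR>
//
// is rewritten into the properly supported
//
//   %0 = sparse_tensor.convert %t : tensor<?x?xf64> to tensor<?x?xf64, #CSR>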
4639a018a7bSAart Bik struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
4649a018a7bSAart Bik public:
4659a018a7bSAart Bik   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
4669a018a7bSAart Bik 
4679a018a7bSAart Bik   LogicalResult matchAndRewrite(tensor::CastOp op,
4689a018a7bSAart Bik                                 PatternRewriter &rewriter) const override {
4699a018a7bSAart Bik     Type srcType = op.getSource().getType();
4709a018a7bSAart Bik     Type dstType = op.getDest().getType();
4719a018a7bSAart Bik     // A nop cast simply folds away.
4729a018a7bSAart Bik     if (srcType == dstType) {
4739a018a7bSAart Bik       rewriter.replaceOp(op, op->getResults());
4749a018a7bSAart Bik       return success();
4759a018a7bSAart Bik     }
4769a018a7bSAart Bik     // See if a sparsity changing cast can be fused into producer.
4779a018a7bSAart Bik     if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
4789a018a7bSAart Bik       if (Operation *def = op.getSource().getDefiningOp()) {
4799a018a7bSAart Bik         if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
4805fcf907bSMatthias Springer           rewriter.modifyOpInPlace(def, [&]() {
4819a018a7bSAart Bik             def->getResult(0).setType(op->getResultTypes()[0]);
4823d90c812SMatthias Springer           });
4839a018a7bSAart Bik           rewriter.replaceOp(op, def->getResult(0));
4849a018a7bSAart Bik           return success();
4859a018a7bSAart Bik         }
4869a018a7bSAart Bik       }
4879a018a7bSAart Bik     }
4889a018a7bSAart Bik     // Repair tensor casts with at least one sparse operand into the
4899a018a7bSAart Bik     // properly supported sparse_tensor.convert.
4909a018a7bSAart Bik     if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
4919a018a7bSAart Bik       rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
4929a018a7bSAart Bik       return success();
4939a018a7bSAart Bik     }
4949a018a7bSAart Bik     // Fail otherwise.
4959a018a7bSAart Bik     return failure();
4969a018a7bSAart Bik   }
4979a018a7bSAart Bik };
4989a018a7bSAart Bik 
499e7df8281SPeiming Liu /// Rewrites a sequence of operations for sparse tensor selections into
500c43e6274STim Harvey /// semi-ring operations such that they can be compiled correctly by the
501c43e6274STim Harvey /// sparsifier. E.g., transforming the following sequence
502e7df8281SPeiming Liu ///
503e7df8281SPeiming Liu /// %sel = arith.select %cond, %sp1, %sp2
504e7df8281SPeiming Liu ///
505e7df8281SPeiming Liu /// to
506e7df8281SPeiming Liu ///
507e7df8281SPeiming Liu /// %sel = binary %sp1, %sp2:
508e7df8281SPeiming Liu ///         both  (%l, %r) {yield select %cond, %l, %r}
509e7df8281SPeiming Liu ///         left  (%l)     {yield select %cond, %l,  0}
510e7df8281SPeiming Liu ///         right (%r)     {yield select %cond,  0, %r}
511e7df8281SPeiming Liu ///
512e7df8281SPeiming Liu /// TODO: We require that the tensor used for extracting conditions to be dense
513e7df8281SPeiming Liu /// to sparsify the code. To support a sparse condition tensor, we need a
514e7df8281SPeiming Liu /// tri-nary operation.
515e7df8281SPeiming Liu struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
516e7df8281SPeiming Liu public:
517e7df8281SPeiming Liu   using OpRewritePattern<GenericOp>::OpRewritePattern;
518e7df8281SPeiming Liu   LogicalResult matchAndRewrite(GenericOp op,
519e7df8281SPeiming Liu                                 PatternRewriter &rewriter) const override {
520e7df8281SPeiming Liu     // Rejects non sparse kernels.
5210a8e3dd4SMatthias Springer     if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
522e7df8281SPeiming Liu       return failure();
523e7df8281SPeiming Liu 
524e7df8281SPeiming Liu     Location loc = op.getLoc();
525e7df8281SPeiming Liu     SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
526e7df8281SPeiming Liu     for (Operation &inst : *op.getBody()) {
527e7df8281SPeiming Liu       // Matches pattern.
528e7df8281SPeiming Liu       auto matched = isRewritablePattern(op, &inst);
529e7df8281SPeiming Liu       if (!matched.has_value())
530e7df8281SPeiming Liu         continue;
531e7df8281SPeiming Liu 
532e7df8281SPeiming Liu       rewriter.setInsertionPoint(&inst);
533e7df8281SPeiming Liu       auto [c, t, f] = matched.value();
534e7df8281SPeiming Liu       assert(t.getType() == f.getType());
535e7df8281SPeiming Liu       auto selTp = t.getType();
536e7df8281SPeiming Liu       auto c0 = constantZero(rewriter, loc, selTp);
537e7df8281SPeiming Liu       auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
538e7df8281SPeiming Liu       // Initializes all the blocks.
539e7df8281SPeiming Liu       rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
540e7df8281SPeiming Liu                            {t.getLoc(), f.getLoc()});
541e7df8281SPeiming Liu       rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
542e7df8281SPeiming Liu       rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
543e7df8281SPeiming Liu 
544e7df8281SPeiming Liu       for (auto *r : binOp.getRegions()) {
545e7df8281SPeiming Liu         Block *b = &r->front();
546e7df8281SPeiming Liu         rewriter.setInsertionPointToStart(b);
547e7df8281SPeiming Liu 
548e7df8281SPeiming Liu         IRMapping irMap;
549e7df8281SPeiming Liu         // Clones the cmp operations into the region to make the binary op
550e7df8281SPeiming Liu         // admissible.
551e7df8281SPeiming Liu         Value newC = c;
552e7df8281SPeiming Liu         if (auto *def = c.getDefiningOp())
553e7df8281SPeiming Liu           newC = rewriter.clone(*def, irMap)->getResult(0);
554e7df8281SPeiming Liu 
555e7df8281SPeiming Liu         irMap.map(c, newC);
556e7df8281SPeiming Liu         if (r == &binOp.getLeftRegion()) {
557e7df8281SPeiming Liu           irMap.map(t, b->getArgument(0));
558e7df8281SPeiming Liu           irMap.map(f, c0);
559e7df8281SPeiming Liu         } else if (r == &binOp.getRightRegion()) {
560e7df8281SPeiming Liu           irMap.map(t, c0);
561e7df8281SPeiming Liu           irMap.map(f, b->getArgument(0));
562e7df8281SPeiming Liu         } else {
563e7df8281SPeiming Liu           irMap.map(t, b->getArgument(0));
564e7df8281SPeiming Liu           irMap.map(f, b->getArgument(1));
565e7df8281SPeiming Liu         }
566e7df8281SPeiming Liu         auto y = rewriter.clone(inst, irMap)->getResult(0);
567e7df8281SPeiming Liu         rewriter.create<sparse_tensor::YieldOp>(loc, y);
568e7df8281SPeiming Liu       }
569e7df8281SPeiming Liu 
570e7df8281SPeiming Liu       // We successfully rewrote an operation. We cannot do the replacement
571e7df8281SPeiming Liu       // here because it would invalidate the iterator that the current loop
572e7df8281SPeiming Liu       // uses to traverse the instructions.
573e7df8281SPeiming Liu       semiRings.emplace_back(&inst, binOp);
574e7df8281SPeiming Liu     }
575e7df8281SPeiming Liu 
576e7df8281SPeiming Liu     // Finalizes the replacement.
577e7df8281SPeiming Liu     for (auto [sel, semi] : semiRings)
578e7df8281SPeiming Liu       rewriter.replaceOp(sel, semi->getResults());
579e7df8281SPeiming Liu 
580e7df8281SPeiming Liu     return success(!semiRings.empty());
581e7df8281SPeiming Liu   }
582e7df8281SPeiming Liu 
583e7df8281SPeiming Liu private:
584e7df8281SPeiming Liu   static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
585e7df8281SPeiming Liu   isRewritablePattern(GenericOp op, Operation *v) {
586e7df8281SPeiming Liu     auto sel = dyn_cast<arith::SelectOp>(v);
587e7df8281SPeiming Liu     if (!sel)
588e7df8281SPeiming Liu       return std::nullopt;
589e7df8281SPeiming Liu 
590a5757c5bSChristian Sigg     auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
591a5757c5bSChristian Sigg     auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
592e7df8281SPeiming Liu     // TODO: For simplicity, we only handle cases where both the true and false
593e7df8281SPeiming Liu     // values are loaded directly from the input tensor. We can probably admit
594e7df8281SPeiming Liu     // more cases in theory.
595e7df8281SPeiming Liu     if (!tVal || !fVal)
596e7df8281SPeiming Liu       return std::nullopt;
597e7df8281SPeiming Liu 
598e7df8281SPeiming Liu     // Helper lambda to determine whether the value is loaded from a dense input
599e7df8281SPeiming Liu     // or is a loop invariant.
600e7df8281SPeiming Liu     auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
601a5757c5bSChristian Sigg       if (auto bArg = dyn_cast<BlockArgument>(v);
602e7df8281SPeiming Liu           bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
603e7df8281SPeiming Liu         return true;
604e7df8281SPeiming Liu       // If the value is defined outside the loop, it is a loop invariant.
605e7df8281SPeiming Liu       return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
606e7df8281SPeiming Liu     };
607e7df8281SPeiming Liu 
608e7df8281SPeiming Liu     // If the condition value is loaded directly from a dense tensor or is a
609e7df8281SPeiming Liu     // loop invariant, we can sparsify the kernel.
610e7df8281SPeiming Liu     auto cond = sel.getCondition();
611e7df8281SPeiming Liu     if (isValFromDenseInputOrInvariant(cond))
612e7df8281SPeiming Liu       return std::make_tuple(cond, tVal, fVal);
613e7df8281SPeiming Liu 
614e7df8281SPeiming Liu     Value cmpL, cmpR;
615e7df8281SPeiming Liu     if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
616e7df8281SPeiming Liu                                                matchers::m_Any(&cmpR))) ||
617e7df8281SPeiming Liu         matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
618e7df8281SPeiming Liu                                                matchers::m_Any(&cmpR)))) {
619e7df8281SPeiming Liu       // TODO: we can do it recursively to check whether all the leaf values are
620e7df8281SPeiming Liu       // loaded from dense tensors or are loop invariants.
621e7df8281SPeiming Liu       if (isValFromDenseInputOrInvariant(cmpL) ||
622e7df8281SPeiming Liu           isValFromDenseInputOrInvariant(cmpR))
623e7df8281SPeiming Liu         return std::make_tuple(cond, tVal, fVal);
624e7df8281SPeiming Liu     }
625e7df8281SPeiming Liu 
626e7df8281SPeiming Liu     return std::nullopt;
627e7df8281SPeiming Liu   };
628e7df8281SPeiming Liu };
629e7df8281SPeiming Liu 
63080fe3168SAart Bik /// Rewrites a sparse reduction that would not sparsify directly since
63180fe3168SAart Bik /// doing so would only iterate over the stored elements, ignoring the
63280fe3168SAart Bik /// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
63380fe3168SAart Bik /// (note that reductions like add/sub/or/xor can directly be sparsified
63480fe3168SAart Bik /// since the implicit zeros do not contribute to the final result).
63580fe3168SAart Bik /// Note that prod/and are still included since, even though they often
63680fe3168SAart Bik /// are nullified in sparse data, they may still occur for special
63780fe3168SAart Bik /// situations in which e.g. some rows in a sparse matrix are fully
63880fe3168SAart Bik /// dense. For min/max, including the implicit zeros is a much more
63980fe3168SAart Bik /// common situation.
64080fe3168SAart Bik ///
64180fe3168SAart Bik /// TODO: this essentially "densifies" the operation; we want to implement
64280fe3168SAart Bik ///       this much more efficiently by performing the reduction over the
64380fe3168SAart Bik ///       stored values, and feeding in the zero once if there were *any*
64480fe3168SAart Bik ///       implicit zeros as well; but for now, at least we provide
64580fe3168SAart Bik ///       the functionality.
64680fe3168SAart Bik ///
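/// As a rough illustration (a hypothetical sketch with abbreviated syntax),
/// the body of a MIN-reduction
///
///   %r = arith.minimumf %x, %a : f64
///
/// is rewritten into
///
///   %u = sparse_tensor.unary %a : f64 to f64
///          present = { yield the value }  absent = { yield 0.0 }
///   %r = sparse_tensor.reduce %u, %x, %identity : f64 { arith.minimumf ... }
///
/// where %identity is read from the output tensor's initial value.
///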
64780fe3168SAart Bik struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
64880fe3168SAart Bik public:
64980fe3168SAart Bik   using OpRewritePattern<GenericOp>::OpRewritePattern;
65080fe3168SAart Bik 
65180fe3168SAart Bik   LogicalResult matchAndRewrite(GenericOp op,
65280fe3168SAart Bik                                 PatternRewriter &rewriter) const override {
65380fe3168SAart Bik     // Reject non-reductions.
6540a8e3dd4SMatthias Springer     if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
65580fe3168SAart Bik         op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
65680fe3168SAart Bik       return failure();
65761f64d1cSMehdi Amini     auto *inp = op.getDpsInputOperand(0);
65861f64d1cSMehdi Amini     auto *init = op.getDpsInitOperand(0);
65980fe3168SAart Bik     if (!isSparseTensor(inp))
66080fe3168SAart Bik       return failure();
66180fe3168SAart Bik     // Look for direct x = x OP y for semi-ring ready reductions.
66261f64d1cSMehdi Amini     auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
66380fe3168SAart Bik                     .getOperand(0)
66480fe3168SAart Bik                     .getDefiningOp();
6658a6e54c9SDaniil Dudkin     if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
6668a6e54c9SDaniil Dudkin              arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
66780fe3168SAart Bik              arith::MaxUIOp>(red))
66880fe3168SAart Bik       return failure();
66980fe3168SAart Bik     Value s0 = op.getBlock()->getArgument(0);
67080fe3168SAart Bik     Value s1 = op.getBlock()->getArgument(1);
67180fe3168SAart Bik     if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
67280fe3168SAart Bik         (red->getOperand(0) != s1 || red->getOperand(1) != s0))
67380fe3168SAart Bik       return failure();
67480fe3168SAart Bik     // Identity.
67580fe3168SAart Bik     Location loc = op.getLoc();
67680fe3168SAart Bik     Value identity =
67780fe3168SAart Bik         rewriter.create<tensor::ExtractOp>(loc, init->get(), ValueRange());
67880fe3168SAart Bik     // Unary {
67980fe3168SAart Bik     //    present -> value
68080fe3168SAart Bik     //    absent  -> zero.
68180fe3168SAart Bik     // }
68280fe3168SAart Bik     Type rtp = s0.getType();
68380fe3168SAart Bik     rewriter.setInsertionPointToStart(&op.getRegion().front());
68480fe3168SAart Bik     auto semiring = rewriter.create<sparse_tensor::UnaryOp>(loc, rtp, s0);
68580fe3168SAart Bik     Block *present =
68680fe3168SAart Bik         rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
68780fe3168SAart Bik     rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
68880fe3168SAart Bik     rewriter.create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
68980fe3168SAart Bik     rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
69080fe3168SAart Bik     rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
69180fe3168SAart Bik     auto zero =
69280fe3168SAart Bik         rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(rtp));
69380fe3168SAart Bik     rewriter.create<sparse_tensor::YieldOp>(loc, zero);
69480fe3168SAart Bik     rewriter.setInsertionPointAfter(semiring);
69580fe3168SAart Bik     // CustomReduce {
69680fe3168SAart Bik     //    x = x REDUC y, identity
69780fe3168SAart Bik     // }
69880fe3168SAart Bik     auto custom = rewriter.create<sparse_tensor::ReduceOp>(
69980fe3168SAart Bik         loc, rtp, semiring.getResult(), s1, identity);
70080fe3168SAart Bik     Block *region =
70180fe3168SAart Bik         rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
70280fe3168SAart Bik     rewriter.setInsertionPointToStart(&custom.getRegion().front());
70380fe3168SAart Bik     IRMapping irMap;
70480fe3168SAart Bik     irMap.map(red->getOperand(0), region->getArgument(0));
70580fe3168SAart Bik     irMap.map(red->getOperand(1), region->getArgument(1));
70661f64d1cSMehdi Amini     auto *cloned = rewriter.clone(*red, irMap);
70780fe3168SAart Bik     rewriter.create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
70880fe3168SAart Bik     rewriter.setInsertionPointAfter(custom);
70980fe3168SAart Bik     rewriter.replaceOp(red, custom.getResult());
71080fe3168SAart Bik     return success();
71180fe3168SAart Bik   }
71280fe3168SAart Bik };
71380fe3168SAart Bik 
714d37affb0SAart Bik /// Sparse rewriting rule for the print operator. This operation is mainly used
715d37affb0SAart Bik /// for debugging and testing. As such, it lowers to the vector.print operation,
716d37affb0SAart Bik /// which requires only very light-weight runtime support.
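///
/// A rough sketch of the emitted output for a small CSR matrix (the actual
/// numbers here are hypothetical):
///
///   ---- Sparse Tensor ----
///   nse = 3
///   dim = ( 4, 8 )
///   lvl = ( 4, 8 )
///   pos[1] : ( 0, 1, 1, 3, 3 )
///   crd[1] : ( 2, 0, 7 )
///   values : ( 1, 2, 3 )
///   ----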
717d37affb0SAart Bik struct PrintRewriter : public OpRewritePattern<PrintOp> {
718d37affb0SAart Bik public:
719d37affb0SAart Bik   using OpRewritePattern::OpRewritePattern;
720d37affb0SAart Bik   LogicalResult matchAndRewrite(PrintOp op,
721d37affb0SAart Bik                                 PatternRewriter &rewriter) const override {
722d37affb0SAart Bik     Location loc = op.getLoc();
723d37affb0SAart Bik     auto tensor = op.getTensor();
724d37affb0SAart Bik     auto stt = getSparseTensorType(tensor);
725d37affb0SAart Bik     // Header with NSE.
726d37affb0SAart Bik     auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
727d37affb0SAart Bik     rewriter.create<vector::PrintOp>(
728d37affb0SAart Bik         loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
729d37affb0SAart Bik     rewriter.create<vector::PrintOp>(loc, nse);
730691fc7cdSAart Bik     // Print run-time contents for dim/lvl sizes.
731691fc7cdSAart Bik     rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("dim = "));
732691fc7cdSAart Bik     printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
733691fc7cdSAart Bik     rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("lvl = "));
734691fc7cdSAart Bik     printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
735d37affb0SAart Bik     // Use the "codegen" foreach loop construct to iterate over
736d37affb0SAart Bik     // all typical sparse tensor components for printing.
7376bc7c9dfSPeiming Liu     foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
7386bc7c9dfSPeiming Liu                                             &stt](Type, FieldIndex,
739d37affb0SAart Bik                                                   SparseTensorFieldKind kind,
740d37affb0SAart Bik                                                   Level l, LevelType) {
741d37affb0SAart Bik       switch (kind) {
742d37affb0SAart Bik       case SparseTensorFieldKind::StorageSpec: {
743d37affb0SAart Bik         break;
744d37affb0SAart Bik       }
745d37affb0SAart Bik       case SparseTensorFieldKind::PosMemRef: {
746d37affb0SAart Bik         auto lvl = constantIndex(rewriter, loc, l);
747d37affb0SAart Bik         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
748d37affb0SAart Bik         rewriter.create<vector::PrintOp>(
749d37affb0SAart Bik             loc, lvl, vector::PrintPunctuation::NoPunctuation);
750d37affb0SAart Bik         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
7516bc7c9dfSPeiming Liu         auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
7526bc7c9dfSPeiming Liu         printContents(rewriter, loc, pos);
753d37affb0SAart Bik         break;
754d37affb0SAart Bik       }
755d37affb0SAart Bik       case SparseTensorFieldKind::CrdMemRef: {
756d37affb0SAart Bik         auto lvl = constantIndex(rewriter, loc, l);
757d37affb0SAart Bik         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
758d37affb0SAart Bik         rewriter.create<vector::PrintOp>(
759d37affb0SAart Bik             loc, lvl, vector::PrintPunctuation::NoPunctuation);
760d37affb0SAart Bik         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
7616bc7c9dfSPeiming Liu         Value crd = nullptr;
762dc4cfdbbSAart Bik         // For COO AoS storage, we want to print a single, linear view of
763dc4cfdbbSAart Bik         // the full coordinate storage at this level. For any other storage,
764dc4cfdbbSAart Bik         // we show the coordinate storage for every individual level.
7656bc7c9dfSPeiming Liu         if (stt.getAoSCOOStart() == l)
7666bc7c9dfSPeiming Liu           crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
7676bc7c9dfSPeiming Liu         else
7686bc7c9dfSPeiming Liu           crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
7696bc7c9dfSPeiming Liu         printContents(rewriter, loc, crd);
770d37affb0SAart Bik         break;
771d37affb0SAart Bik       }
772d37affb0SAart Bik       case SparseTensorFieldKind::ValMemRef: {
773d37affb0SAart Bik         rewriter.create<vector::PrintOp>(loc,
774d37affb0SAart Bik                                          rewriter.getStringAttr("values : "));
7756bc7c9dfSPeiming Liu         auto val = rewriter.create<ToValuesOp>(loc, tensor);
7766bc7c9dfSPeiming Liu         printContents(rewriter, loc, val);
777d37affb0SAart Bik         break;
778d37affb0SAart Bik       }
779d37affb0SAart Bik       }
780d37affb0SAart Bik       return true;
781d37affb0SAart Bik     });
782d37affb0SAart Bik     rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
783d37affb0SAart Bik     rewriter.eraseOp(op);
784d37affb0SAart Bik     return success();
785d37affb0SAart Bik   }
786d37affb0SAart Bik 
787d37affb0SAart Bik private:
788c4e5a8a4SAart Bik   // Helper to print contents of a single memref. For "push_back" vectors,
789c4e5a8a4SAart Bik   // we assume that the previous getters for pos/crd/val have added a
790c4e5a8a4SAart Bik   // slice-to-size view to make sure we print just the used size and not the
791c4e5a8a4SAart Bik   // full capacity.
792d37affb0SAart Bik   //
793c4e5a8a4SAart Bik   // Generates code to print (1-dim or higher):
794d37affb0SAart Bik   //    ( a0, a1, ... )
7956bc7c9dfSPeiming Liu   static void printContents(PatternRewriter &rewriter, Location loc,
796d37affb0SAart Bik                             Value vec) {
797c4e5a8a4SAart Bik     auto shape = cast<ShapedType>(vec.getType()).getShape();
798c4e5a8a4SAart Bik     SmallVector<Value> idxs;
799c4e5a8a4SAart Bik     printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
800c4e5a8a4SAart Bik     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
801c4e5a8a4SAart Bik   }
802c4e5a8a4SAart Bik 
803c4e5a8a4SAart Bik   // Recursive helper for printContents that handles one memref dimension at a time.
804c4e5a8a4SAart Bik   static void printContentsLevel(PatternRewriter &rewriter, Location loc,
805c4e5a8a4SAart Bik                                  Value vec, unsigned i, ArrayRef<int64_t> shape,
806c4e5a8a4SAart Bik                                  SmallVectorImpl<Value> &idxs) {
807d37affb0SAart Bik     // Open bracket.
808d37affb0SAart Bik     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
809c4e5a8a4SAart Bik     // Generate for loop.
810d37affb0SAart Bik     auto zero = constantIndex(rewriter, loc, 0);
811c4e5a8a4SAart Bik     auto index = constantIndex(rewriter, loc, i);
812c4e5a8a4SAart Bik     auto size = rewriter.create<memref::DimOp>(loc, vec, index);
813d37affb0SAart Bik     auto step = constantIndex(rewriter, loc, 1);
814d37affb0SAart Bik     auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
815c4e5a8a4SAart Bik     idxs.push_back(forOp.getInductionVar());
816d37affb0SAart Bik     rewriter.setInsertionPointToStart(forOp.getBody());
817c4e5a8a4SAart Bik     if (i < shape.size() - 1) {
818c4e5a8a4SAart Bik       // Enter deeper loop nest.
819c4e5a8a4SAart Bik       printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
820c4e5a8a4SAart Bik     } else {
821c4e5a8a4SAart Bik       // Actual contents printing.
822c4e5a8a4SAart Bik       auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
823275fe3aeSAart Bik       if (llvm::isa<ComplexType>(val.getType())) {
824275fe3aeSAart Bik         // Since the vector dialect does not support complex types in any op,
825275fe3aeSAart Bik         // we split those into (real, imag) pairs here.
826275fe3aeSAart Bik         Value real = rewriter.create<complex::ReOp>(loc, val);
827275fe3aeSAart Bik         Value imag = rewriter.create<complex::ImOp>(loc, val);
828275fe3aeSAart Bik         rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
829275fe3aeSAart Bik         rewriter.create<vector::PrintOp>(loc, real,
830275fe3aeSAart Bik                                          vector::PrintPunctuation::Comma);
831275fe3aeSAart Bik         rewriter.create<vector::PrintOp>(loc, imag,
832275fe3aeSAart Bik                                          vector::PrintPunctuation::Close);
833275fe3aeSAart Bik       } else {
834eb177803SYinying Li         rewriter.create<vector::PrintOp>(
835eb177803SYinying Li             loc, val, vector::PrintPunctuation::NoPunctuation);
836275fe3aeSAart Bik       }
837eb177803SYinying Li       // Terminating comma (except at end).
838eb177803SYinying Li       auto bound = rewriter.create<arith::AddIOp>(loc, idxs.back(), step);
839eb177803SYinying Li       Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
840eb177803SYinying Li                                                   bound, size);
841eb177803SYinying Li       scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
842eb177803SYinying Li       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
843eb177803SYinying Li       rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
844c4e5a8a4SAart Bik     }
845c4e5a8a4SAart Bik     idxs.pop_back();
846d37affb0SAart Bik     rewriter.setInsertionPointAfter(forOp);
847c4e5a8a4SAart Bik     // Close bracket.
848d37affb0SAart Bik     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
849d37affb0SAart Bik   }
850691fc7cdSAart Bik 
851691fc7cdSAart Bik   // Helper method to print run-time lvl/dim sizes.
852691fc7cdSAart Bik   static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
853691fc7cdSAart Bik                          unsigned size, bool isDim) {
854691fc7cdSAart Bik     // Open bracket.
855691fc7cdSAart Bik     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
856691fc7cdSAart Bik     // Print unrolled contents (dimop requires constant value).
857691fc7cdSAart Bik     for (unsigned i = 0; i < size; i++) {
858691fc7cdSAart Bik       auto idx = constantIndex(rewriter, loc, i);
859691fc7cdSAart Bik       Value val;
860691fc7cdSAart Bik       if (isDim)
861691fc7cdSAart Bik         val = rewriter.create<tensor::DimOp>(loc, tensor, idx);
862691fc7cdSAart Bik       else
863691fc7cdSAart Bik         val = rewriter.create<LvlOp>(loc, tensor, idx);
864691fc7cdSAart Bik       rewriter.create<vector::PrintOp>(
865691fc7cdSAart Bik           loc, val,
866691fc7cdSAart Bik           i != size - 1 ? vector::PrintPunctuation::Comma
867691fc7cdSAart Bik                         : vector::PrintPunctuation::NoPunctuation);
868691fc7cdSAart Bik     }
869691fc7cdSAart Bik     // Close bracket and end of line.
870691fc7cdSAart Bik     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
871691fc7cdSAart Bik     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
872691fc7cdSAart Bik   }
873d37affb0SAart Bik };
874d37affb0SAart Bik 
875330d48c4Sbixia1 /// Sparse rewriting rule for the sparse-to-sparse reshape operator.
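///
/// As an illustration (a hypothetical sketch):
///
///   %shape = arith.constant dense<[12]> : tensor<1xi32>
///   %0 = tensor.reshape %t(%shape)
///      : (tensor<3x4xf64, #COO>, tensor<1xi32>) -> tensor<12xf64, #SparseVector>
///
/// is lowered to a sparse_tensor.foreach over %t that linearizes each stored
/// coordinate, inserts the value into a freshly allocated destination buffer
/// (an unordered COO buffer when the source/destination orderings differ),
/// and, when such an intermediate buffer was used, converts it into the final
/// destination type.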
8766116ca67SAnlun Xu struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
8776116ca67SAnlun Xu public:
8786116ca67SAnlun Xu   using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
8796116ca67SAnlun Xu 
8806116ca67SAnlun Xu   LogicalResult matchAndRewrite(tensor::ReshapeOp op,
8816116ca67SAnlun Xu                                 PatternRewriter &rewriter) const override {
8826116ca67SAnlun Xu     Location loc = op.getLoc();
8836116ca67SAnlun Xu     Value srcTensor = op.getSource();
884*129ade21SLongsheng Mou     const auto srcTp = tryGetSparseTensorType(srcTensor);
885*129ade21SLongsheng Mou     const auto dstTp = tryGetSparseTensorType(op.getResult());
886*129ade21SLongsheng Mou     if (!srcTp || !dstTp)
887*129ade21SLongsheng Mou       return failure();
8886116ca67SAnlun Xu 
889*129ade21SLongsheng Mou     if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
890*129ade21SLongsheng Mou         !dstTp->hasStaticDimShape())
8916116ca67SAnlun Xu       return failure();
8926116ca67SAnlun Xu 
8936116ca67SAnlun Xu     SmallVector<Value> srcSizes;
894*129ade21SLongsheng Mou     sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
8956116ca67SAnlun Xu     SmallVector<Value> dstSizes;
896*129ade21SLongsheng Mou     for (Dimension d : dstTp->getDimShape())
8976116ca67SAnlun Xu       dstSizes.push_back(constantIndex(rewriter, loc, d));
8986116ca67SAnlun Xu 
8996116ca67SAnlun Xu     Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
9006116ca67SAnlun Xu     // Only need an unordered COO buffer if input and output are not sorted
9016116ca67SAnlun Xu     // in the same way.
90276647fceSwren romano     Type bufferTp = getBufferType(
903*129ade21SLongsheng Mou         dstTp->withoutDimToLvl(),
904*129ade21SLongsheng Mou         !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
9056116ca67SAnlun Xu     SmallVector<Value> dynSizes;
9066116ca67SAnlun Xu     Value buffer = rewriter
9076116ca67SAnlun Xu                        .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
9086116ca67SAnlun Xu                                               nnz, Attribute())
9096116ca67SAnlun Xu                        .getResult();
9106116ca67SAnlun Xu 
9116116ca67SAnlun Xu     // Convert src coordinates to dst coordinates by first collapsing them to
9126116ca67SAnlun Xu     // 1D and then expanding them to match the rank of the destination tensor.
9136116ca67SAnlun Xu     // Implemented as follows:
9146116ca67SAnlun Xu     //   foreach srcCoords %srcTensor
9156116ca67SAnlun Xu     //     collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
9166116ca67SAnlun Xu     //     expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
9176116ca67SAnlun Xu     //     insert expandedCoords, %buffer
9186116ca67SAnlun Xu     //
9196116ca67SAnlun Xu     // followed by an optional
9206116ca67SAnlun Xu     //   %t = sparse_tensor.cast %tmp
9216116ca67SAnlun Xu     // depending on whether the input/output are sorted in the same way.
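    //
    // As a hypothetical illustration (not taken from a test case), reshaping
    // a 2x3 CSR matrix into a 6-element sparse vector maps every stored
    // source coordinate (i, j) to the linearized destination coordinate
    // i * 3 + j before it is inserted into the COO buffer.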
922*129ade21SLongsheng Mou     const auto encSrc = srcTp->getEncoding();
9236116ca67SAnlun Xu     ForeachOp foreachOp = rewriter.create<ForeachOp>(
9246116ca67SAnlun Xu         loc, srcTensor, buffer,
9256116ca67SAnlun Xu         [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
9266116ca67SAnlun Xu             ValueRange reduc) {
927*129ade21SLongsheng Mou           const Dimension srcRank = srcTp->getDimRank();
9286116ca67SAnlun Xu           SmallVector<Value> srcDcvs;
9296116ca67SAnlun Xu           srcDcvs.reserve(srcRank);
9306116ca67SAnlun Xu           for (Dimension d = 0; d < srcRank; d++) {
9314e2f1521SPeiming Liu             Level lvl = toLvl(encSrc, d);
9326116ca67SAnlun Xu             srcDcvs.push_back(srcLcvs[lvl]);
9336116ca67SAnlun Xu           }
9346116ca67SAnlun Xu 
935eb14f47bSPeiming Liu           Value collapseSize = constantIndex(builder, loc, 1);
9366116ca67SAnlun Xu           for (Dimension d = 0; d < srcRank; d++)
937eb14f47bSPeiming Liu             collapseSize =
938eb14f47bSPeiming Liu                 builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
939eb14f47bSPeiming Liu           SmallVector<Value, 1> collapsedSizes = {collapseSize};
9406116ca67SAnlun Xu 
941eb14f47bSPeiming Liu           ReassociationIndices collapseIdx;
9426116ca67SAnlun Xu           for (Dimension i = 0; i < srcRank; i++)
943eb14f47bSPeiming Liu             collapseIdx.push_back(i);
944eb14f47bSPeiming Liu           SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
9456116ca67SAnlun Xu           SmallVector<Value, 1> collapsedDcvs;
946eb14f47bSPeiming Liu           reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
9476116ca67SAnlun Xu                      collapsedSizes, collapsedDcvs);
9486116ca67SAnlun Xu 
949eb14f47bSPeiming Liu           ReassociationIndices expandIdx;
950*129ade21SLongsheng Mou           for (Dimension i = 0; i < dstTp->getDimRank(); i++)
951eb14f47bSPeiming Liu             expandIdx.push_back(i);
952eb14f47bSPeiming Liu           SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
9536116ca67SAnlun Xu           SmallVector<Value> dstDcvs;
954eb14f47bSPeiming Liu           reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
955eb14f47bSPeiming Liu                      dstSizes, dstDcvs);
9566116ca67SAnlun Xu 
95794e27c26SPeiming Liu           auto t =
95894e27c26SPeiming Liu               builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
9596116ca67SAnlun Xu           builder.create<sparse_tensor::YieldOp>(loc, t);
9606116ca67SAnlun Xu         });
9616116ca67SAnlun Xu 
9626116ca67SAnlun Xu     Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
963*129ade21SLongsheng Mou     if (bufferTp != *dstTp) {
964*129ade21SLongsheng Mou       auto dstRTT = dstTp->getRankedTensorType();
9656116ca67SAnlun Xu       Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
9666116ca67SAnlun Xu       rewriter.create<DeallocTensorOp>(loc, t);
9676116ca67SAnlun Xu       t = converted;
9686116ca67SAnlun Xu     }
9696116ca67SAnlun Xu     rewriter.replaceOp(op, t);
9706116ca67SAnlun Xu     return success();
9716116ca67SAnlun Xu   }
9726116ca67SAnlun Xu };
9736116ca67SAnlun Xu 
9746116ca67SAnlun Xu /// Sparse rewriting rule for the sparse-to-sparse expand/collapse shape operators.
975330d48c4Sbixia1 template <typename ReshapeOp>
976330d48c4Sbixia1 struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
977330d48c4Sbixia1 public:
978330d48c4Sbixia1   using OpRewritePattern<ReshapeOp>::OpRewritePattern;
979330d48c4Sbixia1 
980330d48c4Sbixia1   LogicalResult matchAndRewrite(ReshapeOp op,
981330d48c4Sbixia1                                 PatternRewriter &rewriter) const override {
982330d48c4Sbixia1     Location loc = op.getLoc();
983330d48c4Sbixia1     Value srcTensor = op.getSrc();
984f2696e46Swren romano     const auto srcTp = getSparseTensorType(srcTensor);
985f2696e46Swren romano     const auto dstTp = getSparseTensorType(op.getResult());
986f2696e46Swren romano     if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
987330d48c4Sbixia1       return failure();
988330d48c4Sbixia1 
989330d48c4Sbixia1     // Generate code to represent the static dimension constants or compute
990330d48c4Sbixia1     // the dynamic dimension values.
9910e1708ffSAart Bik     SmallVector<Value> srcSizes;
992330d48c4Sbixia1     sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
9930e1708ffSAart Bik     SmallVector<Value> dstSizes;
9940e1708ffSAart Bik     SmallVector<Value> dstDynSizes;
995f2696e46Swren romano     if (dstTp.hasStaticDimShape()) {
996f2696e46Swren romano       for (Dimension d : dstTp.getDimShape())
997330d48c4Sbixia1         dstSizes.push_back(constantIndex(rewriter, loc, d));
998330d48c4Sbixia1     } else {
99922212ca7SAart Bik       ArrayRef<Size> dstShape = dstTp.getDimShape();
10009d4df97fSwren romano       genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
1001330d48c4Sbixia1                          op.getReassociationIndices());
10028c258fdaSJakub Kuderski       for (auto [idx, shape] : llvm::enumerate(dstShape)) {
10038c258fdaSJakub Kuderski         if (shape == ShapedType::kDynamic)
10048c258fdaSJakub Kuderski           dstDynSizes.push_back(dstSizes[idx]);
1005330d48c4Sbixia1       }
1006330d48c4Sbixia1     }
10073bd82f30SAart Bik     Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
1008c24547e9SPeiming Liu     // Only need an unordered COO buffer if input and output are not sorted
1009c24547e9SPeiming Liu     // in the same way.
101076647fceSwren romano     Type bufferTp = getBufferType(
101176647fceSwren romano         dstTp.withoutDimToLvl(),
101276647fceSwren romano         !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
1013c24547e9SPeiming Liu 
1014c24547e9SPeiming Liu     Value buffer =
10153bd82f30SAart Bik         rewriter
1016c24547e9SPeiming Liu             .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
10173bd82f30SAart Bik                                    /*sizeHint=*/nnz, Attribute())
10183bd82f30SAart Bik             .getResult();
10193bd82f30SAart Bik 
1020c24547e9SPeiming Liu     // Implement the sparse2sparse reshape as follows:
1021c24547e9SPeiming Liu     //   foreach srcCoords %srcTensor
1022c24547e9SPeiming Liu     //     insert reshapeCvs(srcCoords), %buffer
1023c24547e9SPeiming Liu     //
1024c24547e9SPeiming Liu     // followed by an optional
1025c24547e9SPeiming Liu     //   %t = sparse_tensor.cast %tmp
1026c24547e9SPeiming Liu     // depending on whether the input/output are sorted in the same way.
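    //
    // For example (hypothetical shapes, not from a test), collapsing a 3x4
    // CSR matrix into a 12-element sparse vector rewrites each stored
    // coordinate (i, j) to the destination coordinate i * 4 + j via
    // reshapeCvs, using the reassociation indices of the original op.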
1027f2696e46Swren romano     const auto encSrc = srcTp.getEncoding();
10284fa00ce1SPeiming Liu     ForeachOp foreachOp = rewriter.create<ForeachOp>(
1029c24547e9SPeiming Liu         loc, srcTensor, buffer,
103084cd51bbSwren romano         [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
10314fa00ce1SPeiming Liu             ValueRange reduc) {
1032f2696e46Swren romano           const Dimension dimRank = srcTp.getDimRank();
103384cd51bbSwren romano           SmallVector<Value> srcDcvs;
103484cd51bbSwren romano           srcDcvs.reserve(dimRank);
103584cd51bbSwren romano           for (Dimension d = 0; d < dimRank; d++) {
10364e2f1521SPeiming Liu             Level lvl = toLvl(encSrc, d);
103784cd51bbSwren romano             srcDcvs.push_back(srcLcvs[lvl]);
1038330d48c4Sbixia1           }
103984cd51bbSwren romano           SmallVector<Value> dstDcvs;
104084cd51bbSwren romano           reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
104184cd51bbSwren romano                      srcDcvs, dstSizes, dstDcvs);
104294e27c26SPeiming Liu           auto t =
104394e27c26SPeiming Liu               builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
10444fa00ce1SPeiming Liu           builder.create<sparse_tensor::YieldOp>(loc, t);
1045330d48c4Sbixia1         });
1046c24547e9SPeiming Liu 
1047c24547e9SPeiming Liu     Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
1048c24547e9SPeiming Liu     if (bufferTp != dstTp) {
1049f2696e46Swren romano       auto dstRTT = dstTp.getRankedTensorType();
1050f2696e46Swren romano       Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
10518ffdcc59SPeiming Liu       rewriter.create<DeallocTensorOp>(loc, t);
1052c24547e9SPeiming Liu       t = converted;
1053c24547e9SPeiming Liu     }
1054c24547e9SPeiming Liu     rewriter.replaceOp(op, t);
1055330d48c4Sbixia1     return success();
1056330d48c4Sbixia1   }
1057330d48c4Sbixia1 };
1058330d48c4Sbixia1 
1059330d48c4Sbixia1 /// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
1060330d48c4Sbixia1 /// operator.
106128ebb0b6SAart Bik template <typename ReshapeOp>
106228ebb0b6SAart Bik struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
106328ebb0b6SAart Bik public:
106428ebb0b6SAart Bik   using OpRewritePattern<ReshapeOp>::OpRewritePattern;
106528ebb0b6SAart Bik 
106628ebb0b6SAart Bik   LogicalResult matchAndRewrite(ReshapeOp op,
106728ebb0b6SAart Bik                                 PatternRewriter &rewriter) const override {
106828ebb0b6SAart Bik     Location loc = op->getLoc();
106928ebb0b6SAart Bik     auto encDst = getSparseTensorEncoding(op.getResult().getType());
107028ebb0b6SAart Bik     auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
107128ebb0b6SAart Bik     // Since a pure dense expansion is very cheap (change of view), for
107228ebb0b6SAart Bik     // a sparse2dense or dense2sparse, we can simply unfuse a sparse
107328ebb0b6SAart Bik     // conversion from the reshape operation itself.
107428ebb0b6SAart Bik     // All other cases are handled elsewhere.
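    //
    // Sketch (hypothetical types): expanding a dense tensor<12xf64> into a
    // sparse tensor<3x4xf64, #CSR> becomes a dense tensor.expand_shape to
    // tensor<3x4xf64> followed by sparse_tensor.convert into #CSR, while the
    // sparse2dense direction converts the sparse source to a dense tensor
    // first and reshapes that instead.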
107528ebb0b6SAart Bik     if (encDst && encSrc) {
107628ebb0b6SAart Bik       return failure();
10770449b6a0SMehdi Amini     }
10780449b6a0SMehdi Amini     if (encSrc) {
1079255c3f11Swren romano       auto rtp = getRankedTensorType(op.getSrc());
108028ebb0b6SAart Bik       auto denseTp =
108128ebb0b6SAart Bik           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
108228ebb0b6SAart Bik       auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
10835fcf907bSMatthias Springer       rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
108428ebb0b6SAart Bik       return success();
1085550288cbSPeiming Liu     }
1086550288cbSPeiming Liu     if (encDst) {
1087255c3f11Swren romano       auto rtp = getRankedTensorType(op.getResult());
108828ebb0b6SAart Bik       auto denseTp =
108928ebb0b6SAart Bik           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
109097069a86SGaurav Shukla       ReshapeOp reshape;
109197069a86SGaurav Shukla       if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
109297069a86SGaurav Shukla         reshape = rewriter.create<ReshapeOp>(
109397069a86SGaurav Shukla             loc, denseTp, op.getSrc(), op.getReassociation(),
109497069a86SGaurav Shukla             op.getOutputShape(), op.getStaticOutputShape());
109597069a86SGaurav Shukla       } else {
109697069a86SGaurav Shukla         reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
109728ebb0b6SAart Bik                                              op.getReassociation());
109897069a86SGaurav Shukla       }
109928ebb0b6SAart Bik       Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
110028ebb0b6SAart Bik       rewriter.replaceOp(op, convert);
110128ebb0b6SAart Bik       return success();
110228ebb0b6SAart Bik     }
110328ebb0b6SAart Bik     return failure();
110428ebb0b6SAart Bik   }
110528ebb0b6SAart Bik };
110628ebb0b6SAart Bik 
110771c97c73SPeiming Liu // A trivial wrapper to help generate different operations for dense/sparse
110871c97c73SPeiming Liu // tensors.
1109dda3dc5eSPeiming Liu struct TensorLike {
1110dda3dc5eSPeiming Liu   TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
111171c97c73SPeiming Liu              ValueRange sizes) {
1112dda3dc5eSPeiming Liu     SmallVector<Value> dynSzs;
1113dda3dc5eSPeiming Liu     getDynamicSizes(rtt, sizes, dynSzs);
1114dda3dc5eSPeiming Liu 
1115dda3dc5eSPeiming Liu     val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
111671c97c73SPeiming Liu     if (!isSparse()) {
111771c97c73SPeiming Liu       Value c0 = constantZero(builder, loc, rtt.getElementType());
111871c97c73SPeiming Liu       val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
111971c97c73SPeiming Liu     }
1120dda3dc5eSPeiming Liu   }
1121dda3dc5eSPeiming Liu 
112271c97c73SPeiming Liu   void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
112371c97c73SPeiming Liu     val = builder.create<tensor::InsertOp>(loc, v, val, crds);
1124dda3dc5eSPeiming Liu   }
1125dda3dc5eSPeiming Liu 
1126dda3dc5eSPeiming Liu   Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
112771c97c73SPeiming Liu     if (isSparse())
1128dda3dc5eSPeiming Liu       return builder.create<LoadOp>(loc, val, true);
112971c97c73SPeiming Liu     return val;
1130dda3dc5eSPeiming Liu   }
1131dda3dc5eSPeiming Liu 
113271c97c73SPeiming Liu   bool isSparse() const {
113371c97c73SPeiming Liu     return getSparseTensorEncoding(val.getType()) != nullptr;
1134dda3dc5eSPeiming Liu   }
1135dda3dc5eSPeiming Liu 
113671c97c73SPeiming Liu   Value val;
1137dda3dc5eSPeiming Liu };
1138dda3dc5eSPeiming Liu 
1139c780352dSPeiming Liu struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
1140c780352dSPeiming Liu   using OpRewritePattern::OpRewritePattern;
1141c780352dSPeiming Liu   LogicalResult matchAndRewrite(tensor::DimOp op,
1142c780352dSPeiming Liu                                 PatternRewriter &rewriter) const override {
1143c780352dSPeiming Liu     std::optional<int64_t> dim = op.getConstantIndex();
1144*129ade21SLongsheng Mou     auto stt = tryGetSparseTensorType(op.getSource());
1145*129ade21SLongsheng Mou     if (!dim || !stt || !stt->hasEncoding())
1146c780352dSPeiming Liu       return failure();
1147c780352dSPeiming Liu 
1148*129ade21SLongsheng Mou     if (stt->isPermutation()) {
1149c780352dSPeiming Liu       rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
1150*129ade21SLongsheng Mou                                          toLvl(stt->getEncoding(), *dim));
1151c780352dSPeiming Liu       return success();
1152c780352dSPeiming Liu     }
1153c780352dSPeiming Liu 
1154c780352dSPeiming Liu     // Non-permutation dim2lvl/lvl2dim maps.
1155c780352dSPeiming Liu     // Compute as follows:
1156c780352dSPeiming Liu     // affine.apply #map (l0 - 1, l1 - 1, ...) + 1
1157c780352dSPeiming Liu     // Note that it is not the most efficient way (but a more general one) to
1158c780352dSPeiming Liu     // do the lvl-to-dim translation; e.g., for BSR, the dimension size can be
1159c780352dSPeiming Liu     // computed simply as lvl_size * block_size.
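    //
    // Worked example (assuming a 2x2 BSR encoding with lvl-to-dim map
    // (i, j, ii, jj) -> (i * 2 + ii, j * 2 + jj)): dimension 0 evaluates to
    // ((lvlSz0 - 1) * 2 + (lvlSz2 - 1)) + 1 = 2 * lvlSz0, since lvlSz2 == 2,
    // which matches the lvl_size * block_size shortcut mentioned above.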
1160c780352dSPeiming Liu     Location loc = op.getLoc();
1161c780352dSPeiming Liu     SmallVector<Value> maxLvlCrds;
1162*129ade21SLongsheng Mou     for (Level l = 0; l < stt->getLvlRank(); l++) {
1163c780352dSPeiming Liu       Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
1164c780352dSPeiming Liu       Value maxLvlCrd = rewriter.create<arith::SubIOp>(
1165c780352dSPeiming Liu           loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
1166c780352dSPeiming Liu       maxLvlCrds.push_back(maxLvlCrd);
1167c780352dSPeiming Liu     }
1168c780352dSPeiming Liu 
1169*129ade21SLongsheng Mou     AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
1170c780352dSPeiming Liu     Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
1171*129ade21SLongsheng Mou         op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
1172c780352dSPeiming Liu         maxLvlCrds);
1173c780352dSPeiming Liu 
1174c780352dSPeiming Liu     Value dimSz = rewriter.create<arith::AddIOp>(
1175c780352dSPeiming Liu         loc, maxDimCrd, constantOne(rewriter, loc, rewriter.getIndexType()));
1176c780352dSPeiming Liu     rewriter.replaceOp(op, dimSz);
1177c780352dSPeiming Liu     return success();
1178c780352dSPeiming Liu   }
1179c780352dSPeiming Liu };
1180c780352dSPeiming Liu 
1181761c9dd9SPeiming Liu struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
1182761c9dd9SPeiming Liu   using OpRewritePattern::OpRewritePattern;
1183761c9dd9SPeiming Liu   LogicalResult matchAndRewrite(ConcatenateOp op,
1184761c9dd9SPeiming Liu                                 PatternRewriter &rewriter) const override {
1185761c9dd9SPeiming Liu     if (op.needsExtraSort())
1186761c9dd9SPeiming Liu       op.emitError("ConcatenateOp not staged");
1187761c9dd9SPeiming Liu 
1188761c9dd9SPeiming Liu     const Location loc = op.getLoc();
1189761c9dd9SPeiming Liu     const auto dstTp = getSparseTensorType(op);
1190761c9dd9SPeiming Liu     const Dimension conDim = op.getDimension();
1191761c9dd9SPeiming Liu     SmallVector<Value> sizes;
1192761c9dd9SPeiming Liu     concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
1193761c9dd9SPeiming Liu 
1194761c9dd9SPeiming Liu     // %t = concatenate %s1, %s2, %s3 {dim = 1}
1195761c9dd9SPeiming Liu     // ==>
1196761c9dd9SPeiming Liu     // if (isSparseDst)
1197761c9dd9SPeiming Liu     //   if (allDense)
1198761c9dd9SPeiming Liu     //     %tmp = bufferization.alloc_tensor dstTp
1199761c9dd9SPeiming Liu     //   else
1200761c9dd9SPeiming Liu     //     %tmp = bufferization.alloc_tensor : unordered COO
1201761c9dd9SPeiming Liu     // else
1202761c9dd9SPeiming Liu     //   %tmp = memref.alloc : dense tensor
1203761c9dd9SPeiming Liu     // foreach in %s1 : insert d0, d1, %tmp
1204761c9dd9SPeiming Liu     // foreach in %s2 : insert d0, d1 + size(s1), %tmp
1205761c9dd9SPeiming Liu     // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
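    //
    // For instance (hypothetical 4x2, 4x3, and 4x5 inputs concatenated along
    // dim = 1), values of %s2 land at column d1 + 2 and values of %s3 at
    // column d1 + 5, producing a 4x10 result.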
1206761c9dd9SPeiming Liu 
1207761c9dd9SPeiming Liu     TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
1208761c9dd9SPeiming Liu     Value offset = constantIndex(rewriter, loc, 0);
120971c97c73SPeiming Liu     Value iterArg = dstBuf.val;
1210761c9dd9SPeiming Liu 
1211761c9dd9SPeiming Liu     ForeachOp foreachOp;
1212761c9dd9SPeiming Liu     for (Value input : op.getInputs()) {
1213761c9dd9SPeiming Liu       // Builds a foreach op for each input tensor to append new values into
1214761c9dd9SPeiming Liu       // output tensor.
1215761c9dd9SPeiming Liu       foreachOp = rewriter.create<ForeachOp>(
121671c97c73SPeiming Liu           loc, input, iterArg,
1217761c9dd9SPeiming Liu           [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1218761c9dd9SPeiming Liu               ValueRange reduc) {
1219ef100c22SPeiming Liu             SmallVector<Value> offDimCrd(dcvs);
1220ef100c22SPeiming Liu             offDimCrd[conDim] =
1221ef100c22SPeiming Liu                 builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);
1222ef100c22SPeiming Liu 
122371c97c73SPeiming Liu             // Enters foreach, updates the SSA chain.
122471c97c73SPeiming Liu             dstBuf.val = reduc.front();
1225761c9dd9SPeiming Liu             if (!dstTp.isAllDense()) {
1226761c9dd9SPeiming Liu               Value cond = genIsNonzero(builder, loc, v);
1227761c9dd9SPeiming Liu               auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1228761c9dd9SPeiming Liu                                                     /*else*/ true);
1229761c9dd9SPeiming Liu               builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
123071c97c73SPeiming Liu               builder.create<scf::YieldOp>(loc, dstBuf.val);
1231761c9dd9SPeiming Liu 
1232761c9dd9SPeiming Liu               builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1233ef100c22SPeiming Liu               dstBuf.insert(builder, loc, v, offDimCrd);
123471c97c73SPeiming Liu               builder.create<scf::YieldOp>(loc, dstBuf.val);
1235761c9dd9SPeiming Liu 
1236761c9dd9SPeiming Liu               // Exits the ifOp, updates the sparse tensor SSA value.
1237761c9dd9SPeiming Liu               builder.setInsertionPointAfter(ifOp);
123871c97c73SPeiming Liu               dstBuf.val = ifOp.getResult(0);
1239761c9dd9SPeiming Liu             } else {
1240ef100c22SPeiming Liu               dstBuf.insert(builder, loc, v, offDimCrd);
1241761c9dd9SPeiming Liu             }
124271c97c73SPeiming Liu             builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1243761c9dd9SPeiming Liu           });
1244761c9dd9SPeiming Liu       // Accumulates the offset. Note that only static-shaped inputs are allowed
1245761c9dd9SPeiming Liu       // by concatenate op verifier, which saves us from computing the offset
1246761c9dd9SPeiming Liu       // dynamically.
124722212ca7SAart Bik       const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
124822212ca7SAart Bik       assert(!ShapedType::isDynamic(sz));
124922212ca7SAart Bik       offset = rewriter.create<arith::AddIOp>(loc, offset,
125022212ca7SAart Bik                                               constantIndex(rewriter, loc, sz));
1251761c9dd9SPeiming Liu       iterArg = foreachOp.getResult(0);
125271c97c73SPeiming Liu       dstBuf.val = iterArg;
1253761c9dd9SPeiming Liu     }
1254761c9dd9SPeiming Liu 
125571c97c73SPeiming Liu     dstBuf.val = iterArg;
1256761c9dd9SPeiming Liu     Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
1257761c9dd9SPeiming Liu     rewriter.replaceOp(op, ret);
1258761c9dd9SPeiming Liu     return success();
1259761c9dd9SPeiming Liu   }
1260761c9dd9SPeiming Liu };
1261761c9dd9SPeiming Liu 
1262dda3dc5eSPeiming Liu struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
1263eb877006Sbixia1   using OpRewritePattern::OpRewritePattern;
1264eb877006Sbixia1   LogicalResult matchAndRewrite(ConvertOp op,
1265eb877006Sbixia1                                 PatternRewriter &rewriter) const override {
1266761c9dd9SPeiming Liu     if (op.needsExtraSort())
1267f248d0b2SPeiming Liu       return op.emitError("ConvertOp not staged.");
1268dda3dc5eSPeiming Liu 
1269dda3dc5eSPeiming Liu     // TODO: Maybe we want a different operation for this too.
1270eb877006Sbixia1     auto encDst = getSparseTensorEncoding(op.getType());
1271eb877006Sbixia1     auto encSrc = getSparseTensorEncoding(op.getSource().getType());
127233267f40SPeiming Liu     if (encDst && encSrc && !encSrc.isSlice() &&
127385dbb3fcSPeiming Liu         encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
127485dbb3fcSPeiming Liu       // Trivial tensor conversion and simple element type conversion are handled
127585dbb3fcSPeiming Liu       // in codegen.
1276eb877006Sbixia1       return failure();
1277eb877006Sbixia1     }
1278eb877006Sbixia1 
1279eb877006Sbixia1     Location loc = op.getLoc();
1280eb877006Sbixia1     Value src = op.getSource();
1281dda3dc5eSPeiming Liu 
1282dda3dc5eSPeiming Liu     SparseTensorType srcStt = getSparseTensorType(op.getSource());
1283dda3dc5eSPeiming Liu     SparseTensorType dstStt = getSparseTensorType(op.getDest());
1284eb877006Sbixia1 
1285e6cbb914SAart Bik     bool fromSparseConst = false;
1286dda3dc5eSPeiming Liu     if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
1287dda3dc5eSPeiming Liu       if (dyn_cast<SparseElementsAttr>(constOp.getValue()))
1288e6cbb914SAart Bik         fromSparseConst = true;
1289e6cbb914SAart Bik 
129076647fceSwren romano     const AffineMapAttr foreachOrder =
1291dda3dc5eSPeiming Liu         (!dstStt.isIdentity() && fromSparseConst)
1292dda3dc5eSPeiming Liu             ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
129376647fceSwren romano             : nullptr;
129441089f86SPeiming Liu 
1295dda3dc5eSPeiming Liu     bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;
1296eb877006Sbixia1 
12970e1708ffSAart Bik     SmallVector<Value> sizes;
1298dda3dc5eSPeiming Liu     sizesFromSrc(rewriter, sizes, loc, src);
1299dda3dc5eSPeiming Liu     ValueRange vs;
1300dda3dc5eSPeiming Liu     TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
1301a61a9a70SAart Bik 
1302dda3dc5eSPeiming Liu     auto foreachOp = rewriter.create<ForeachOp>(
130371c97c73SPeiming Liu         loc, src, dstBuf.val, foreachOrder,
1304dda3dc5eSPeiming Liu         [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1305dda3dc5eSPeiming Liu             ValueRange reduc) {
1306dda3dc5eSPeiming Liu           // Enters the loop, updates the SSA value for the insertion chain.
130771c97c73SPeiming Liu           dstBuf.val = reduc.front();
1308dda3dc5eSPeiming Liu           if (!skipZeroCheck) {
1309dda3dc5eSPeiming Liu             Value cond = genIsNonzero(builder, loc, v);
1310dda3dc5eSPeiming Liu             auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1311dda3dc5eSPeiming Liu                                                   /*else*/ true);
1312dda3dc5eSPeiming Liu             builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
131371c97c73SPeiming Liu             builder.create<scf::YieldOp>(loc, dstBuf.val);
1314dda3dc5eSPeiming Liu 
1315dda3dc5eSPeiming Liu             builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1316ef100c22SPeiming Liu             dstBuf.insert(builder, loc, v, dcvs);
131771c97c73SPeiming Liu             builder.create<scf::YieldOp>(loc, dstBuf.val);
1318dda3dc5eSPeiming Liu 
1319dda3dc5eSPeiming Liu             // Exits the ifOp, updates the sparse tensor SSA value.
1320dda3dc5eSPeiming Liu             builder.setInsertionPointAfter(ifOp);
132171c97c73SPeiming Liu             dstBuf.val = ifOp.getResult(0);
1322dda3dc5eSPeiming Liu           } else {
1323ef100c22SPeiming Liu             dstBuf.insert(builder, loc, v, dcvs);
1324dda3dc5eSPeiming Liu           }
132571c97c73SPeiming Liu           builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1326eb877006Sbixia1         });
1327eb877006Sbixia1 
1328dda3dc5eSPeiming Liu     rewriter.setInsertionPointAfter(foreachOp);
1329eb877006Sbixia1 
1330dda3dc5eSPeiming Liu     // Exits the for loop, links the SSA chain.
133171c97c73SPeiming Liu     dstBuf.val = foreachOp.getResult(0);
1332eb877006Sbixia1 
1333dda3dc5eSPeiming Liu     Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
1334dda3dc5eSPeiming Liu     rewriter.replaceOp(op, ret);
1335eb877006Sbixia1     return success();
1336eb877006Sbixia1   }
1337eb877006Sbixia1 };
1338eb877006Sbixia1 
13393426d330SPeiming Liu struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
13403426d330SPeiming Liu   using OpRewritePattern::OpRewritePattern;
13413426d330SPeiming Liu   LogicalResult matchAndRewrite(CrdTranslateOp op,
13423426d330SPeiming Liu                                 PatternRewriter &rewriter) const override {
13433426d330SPeiming Liu     AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
13443426d330SPeiming Liu                         ? op.getEncoder().getDimToLvl()
13453426d330SPeiming Liu                         : op.getEncoder().getLvlToDim();
13463426d330SPeiming Liu 
13473426d330SPeiming Liu     SmallVector<Value> outCrds;
13483426d330SPeiming Liu     for (AffineExpr result : map.getResults()) {
13493426d330SPeiming Liu       // TODO: we should probably expand the affine map to IR using our own
13503426d330SPeiming Liu       // rules, since affine.apply assumes signed values, while the coordinates
13513426d330SPeiming Liu       // we provide must always be signless.
13523426d330SPeiming Liu       Value trans = rewriter.create<affine::AffineApplyOp>(
13533426d330SPeiming Liu           op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
13543426d330SPeiming Liu           op.getInCrds());
13553426d330SPeiming Liu       outCrds.push_back(trans);
13563426d330SPeiming Liu     }
13573426d330SPeiming Liu     rewriter.replaceOp(op, outCrds);
13583426d330SPeiming Liu     return success();
13593426d330SPeiming Liu   }
13603426d330SPeiming Liu };
13613426d330SPeiming Liu 
1362550288cbSPeiming Liu /// Sparse rewriting rule for the foreach operator.
1363550288cbSPeiming Liu struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
1364550288cbSPeiming Liu public:
1365550288cbSPeiming Liu   using OpRewritePattern::OpRewritePattern;
1366550288cbSPeiming Liu 
1367550288cbSPeiming Liu   LogicalResult matchAndRewrite(ForeachOp op,
1368550288cbSPeiming Liu                                 PatternRewriter &rewriter) const override {
1369550288cbSPeiming Liu 
1370550288cbSPeiming Liu     auto loc = op.getLoc();
1371550288cbSPeiming Liu     Value input = op.getTensor();
13727175f9ddSPeiming Liu     SmallVector<Value> reduc = op.getInitArgs();
1373f708a549Swren romano     const auto stt = getSparseTensorType(input);
137484cd51bbSwren romano     const Level lvlRank = stt.getLvlRank();
1375550288cbSPeiming Liu 
13767175f9ddSPeiming Liu     // Special-case: foreach over a sparse constant uses its own rewriting
13777175f9ddSPeiming Liu     // rule.
13787175f9ddSPeiming Liu     if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
13795550c821STres Popp       if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
13808d615a23SPeiming Liu         return genForeachOnSparseConstant(op, rewriter, attr);
13817175f9ddSPeiming Liu       }
13827175f9ddSPeiming Liu     }
13837175f9ddSPeiming Liu 
13847175f9ddSPeiming Liu     // Otherwise, use loop emitter to generate loops.
1385f708a549Swren romano     const auto enc = stt.getEncoding();
13864fa00ce1SPeiming Liu 
1387550288cbSPeiming Liu     // 1. Generates loop for the sparse input.
1388781eabebSPeiming Liu     LoopEmitter loopEmitter(
138991e7b9e5SPeiming Liu         ValueRange{input},
139091e7b9e5SPeiming Liu         StringAttr::get(getContext(), ForeachOp::getOperationName()));
1391550288cbSPeiming Liu     loopEmitter.initializeLoopEmit(rewriter, loc);
1392b8cf7af9Swren romano     for (Level l = 0; l < lvlRank; l++) {
1393b0f8057eSPeiming Liu       // TODO: provide utility function for loop sequences that only contain
1394b0f8057eSPeiming Liu       // one for loop?
139536c95ee7SPeiming Liu       const SmallVector<TensorLevel, 1> tidLvls{
139636c95ee7SPeiming Liu           loopEmitter.makeTensorLevel(0, l)};
139736c95ee7SPeiming Liu       loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
13984fa00ce1SPeiming Liu       // Note that reduc will be taken care of by the loop emitter and gets updated
13994fa00ce1SPeiming Liu       // in place.
1400c4420257SPeiming Liu       loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
1401fd68d361SPeiming Liu                                                     reduc);
1402b0f8057eSPeiming Liu     }
1403550288cbSPeiming Liu 
1404372d88b0SPeiming Liu     SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
14059e8d9316SPeiming Liu     if (op.getOrder()) {
140653ffafb2SPeiming Liu       // TODO: Support it so that we can do direct conversion from CSR->BSR.
140753ffafb2SPeiming Liu       llvm_unreachable(
140853ffafb2SPeiming Liu           "Level order not yet implemented on non-constant input tensors.");
14099e8d9316SPeiming Liu     }
141053ffafb2SPeiming Liu 
1411b0f8057eSPeiming Liu     Value vals = loopEmitter.getValBuffer()[0];
141252b69aa3SPeiming Liu     SmallVector<Value> pos = loopEmitter.getValPosits(0);
141384cd51bbSwren romano     // Loads the value from sparse tensor using position-index;
141484cd51bbSwren romano     // loads the value from dense tensor using coords.
1415b8cf7af9Swren romano     Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
141684cd51bbSwren romano                     : rewriter.create<memref::LoadOp>(loc, vals, lcvs);
141735b3a0ceSPeiming Liu 
1418550288cbSPeiming Liu     // 2. Inline the block in the foreach operator.
1419550288cbSPeiming Liu     Block *srcBlock = op.getBody();
1420b0f8057eSPeiming Liu 
1421550288cbSPeiming Liu     // Remap coordinates.
14226456e0bbSPeiming Liu     SmallVector<Value> args =
14236456e0bbSPeiming Liu         enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
14246456e0bbSPeiming Liu 
1425550288cbSPeiming Liu     // Remap value.
1426550288cbSPeiming Liu     args.push_back(val);
14274fa00ce1SPeiming Liu     // Remap reduction variables.
14284fa00ce1SPeiming Liu     args.append(reduc);
14294fa00ce1SPeiming Liu 
14304fa00ce1SPeiming Liu     // Remove sparse_tensor.yield.
14314fa00ce1SPeiming Liu     SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
14324fa00ce1SPeiming Liu     rewriter.eraseOp(srcBlock->getTerminator());
1433550288cbSPeiming Liu 
1434298412b5SPeiming Liu     Operation &last = rewriter.getBlock()->back();
1435298412b5SPeiming Liu     if (llvm::isa<scf::YieldOp>(last)) {
1436298412b5SPeiming Liu       // Because `scf.for` inserts an implicit yield op when there is no
1437298412b5SPeiming Liu       // reduction variable upon creation, we reset the insertion point such
1438298412b5SPeiming Liu       // that the block is inlined *before* the yield op.
1439298412b5SPeiming Liu       rewriter.setInsertionPoint(&last);
14404fa00ce1SPeiming Liu     }
14414fa00ce1SPeiming Liu 
1442298412b5SPeiming Liu     rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
1443298412b5SPeiming Liu                                rewriter.getInsertionPoint(), args);
1444298412b5SPeiming Liu     rewriter.setInsertionPointToEnd(rewriter.getBlock());
1445e9fa1fdeSPeiming Liu     for (Level l = 0; l < lvlRank; l++) {
14464fa00ce1SPeiming Liu       // Link the reduction chain. Note that the loop emitter updates reducValue
14474fa00ce1SPeiming Liu       // in place.
14484fa00ce1SPeiming Liu       loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
14495fd9d801SPeiming Liu       loopEmitter.exitCurrentLoopSeq(rewriter, loc);
14504fa00ce1SPeiming Liu     }
14514fa00ce1SPeiming Liu 
14524fa00ce1SPeiming Liu     // Replace the foreach operator with the value returned by the outermost
14534fa00ce1SPeiming Liu     // for loop.
14544fa00ce1SPeiming Liu     rewriter.replaceOp(op, reducValue);
1455550288cbSPeiming Liu     return success();
1456550288cbSPeiming Liu   }
1457550288cbSPeiming Liu };
1458550288cbSPeiming Liu 
145967f61b08Sbixia1 /// Sparse rewriting rule for the new operator.
146067f61b08Sbixia1 struct NewRewriter : public OpRewritePattern<NewOp> {
146167f61b08Sbixia1   using OpRewritePattern::OpRewritePattern;
146267f61b08Sbixia1   LogicalResult matchAndRewrite(NewOp op,
146367f61b08Sbixia1                                 PatternRewriter &rewriter) const override {
146467f61b08Sbixia1     Location loc = op.getLoc();
1465e8fc282fSAart Bik     auto stt = getSparseTensorType(op.getResult());
14665248a987SPeiming Liu     if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
146767f61b08Sbixia1       return failure();
146867f61b08Sbixia1 
146967f61b08Sbixia1     // Implement the NewOp as follows:
14702c81d432Sbixia1     //   %orderedCoo = sparse_tensor.new %filename
147184cd51bbSwren romano     //   %t = sparse_tensor.convert %orderedCoo
1472e8fc282fSAart Bik     // with enveloping reinterpreted_map ops for non-permutations.
1473e8fc282fSAart Bik     // with enveloping reinterpret_map ops for non-permutations.
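    //
    // For a non-permutation target such as a (hypothetical) BSR encoding,
    // the COO result is first demapped with sparse_tensor.reinterpret_map,
    // converted into the demapped destination encoding, and finally remapped
    // back to the original BSR encoding, as sketched in the code below.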
147445288085SAart Bik     RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
14752c81d432Sbixia1     Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
1476e8fc282fSAart Bik     Value convert = cooTensor;
14775b729503SAart Bik     auto enc = stt.getEncoding();
1478e8fc282fSAart Bik     if (!stt.isPermutation()) { // demap coo, demap dstTp
1479e8fc282fSAart Bik       auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
1480e8fc282fSAart Bik       convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
1481e8fc282fSAart Bik       dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
1482e8fc282fSAart Bik     }
1483e8fc282fSAart Bik     convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
1484e8fc282fSAart Bik     if (!stt.isPermutation()) // remap to original enc
1485e8fc282fSAart Bik       convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
1486e8fc282fSAart Bik     rewriter.replaceOp(op, convert);
14873bd82f30SAart Bik 
1488e8fc282fSAart Bik     // Release the temporary ordered COO tensor.
14892c81d432Sbixia1     rewriter.setInsertionPointAfterValue(convert);
14902c81d432Sbixia1     rewriter.create<DeallocTensorOp>(loc, cooTensor);
149167f61b08Sbixia1 
149267f61b08Sbixia1     return success();
149367f61b08Sbixia1   }
149467f61b08Sbixia1 };
149567f61b08Sbixia1 
1496e8fc282fSAart Bik /// Sparse rewriting rule for the out operator.
1497e445349dSbixia1 struct OutRewriter : public OpRewritePattern<OutOp> {
1498e445349dSbixia1   using OpRewritePattern::OpRewritePattern;
1499e445349dSbixia1   LogicalResult matchAndRewrite(OutOp op,
1500e445349dSbixia1                                 PatternRewriter &rewriter) const override {
1501e445349dSbixia1     Location loc = op.getLoc();
1502e445349dSbixia1     // Calculate NNZ.
1503e445349dSbixia1     Value src = op.getTensor();
1504e445349dSbixia1     Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
1505e445349dSbixia1 
150684cd51bbSwren romano     // Allocate a temporary buffer for storing dimension-sizes/coordinates.
1507f708a549Swren romano     const auto srcTp = getSparseTensorType(src);
1508f708a549Swren romano     const Dimension dimRank = srcTp.getDimRank();
1509e445349dSbixia1     Type indexTp = rewriter.getIndexType();
1510f708a549Swren romano     Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);
1511e445349dSbixia1 
1512e445349dSbixia1     // Generate code to calculate dimension size values and store the values to
1513e445349dSbixia1     // the buffer.
15140e1708ffSAart Bik     SmallVector<Value> dims;
1515e445349dSbixia1     sizesForTensor(rewriter, dims, loc, srcTp, src);
1516f708a549Swren romano     for (Dimension d = 0; d < dimRank; d++) {
1517f708a549Swren romano       rewriter.create<memref::StoreOp>(loc, dims[d], dimSizes,
1518f708a549Swren romano                                        constantIndex(rewriter, loc, d));
1519e445349dSbixia1     }
1520e445349dSbixia1 
1521e445349dSbixia1     // Create a sparse tensor writer and output meta data.
1522e445349dSbixia1     Type opaqueTp = getOpaquePointerType(rewriter);
1523e445349dSbixia1     Value writer =
1524e445349dSbixia1         createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
1525e445349dSbixia1                        {op.getDest()}, EmitCInterface::Off)
1526e445349dSbixia1             .getResult(0);
1527f708a549Swren romano     Value rankValue = constantIndex(rewriter, loc, dimRank);
1528e445349dSbixia1     createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
1529e445349dSbixia1                    {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
1530e445349dSbixia1 
153184cd51bbSwren romano     Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords.
1532e445349dSbixia1     Type eltTp = srcTp.getElementType();
15332af2e4dbSwren romano     SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
1534e445349dSbixia1                                     primaryTypeFunctionSuffix(eltTp)};
1535e445349dSbixia1     Value value = genAllocaScalar(rewriter, loc, eltTp);
1536e445349dSbixia1     ModuleOp module = op->getParentOfType<ModuleOp>();
1537e8fc282fSAart Bik 
1538e445349dSbixia1     // For each element in the source tensor, output the element.
1539e445349dSbixia1     rewriter.create<ForeachOp>(
15401a36588eSKazu Hirata         loc, src, std::nullopt,
154184cd51bbSwren romano         [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
15424fa00ce1SPeiming Liu             ValueRange reduc) {
1543f708a549Swren romano           for (Dimension d = 0; d < dimRank; d++) {
154484cd51bbSwren romano             rewriter.create<memref::StoreOp>(loc, dcvs[d], dimCoords,
1545f708a549Swren romano                                              constantIndex(builder, loc, d));
1546e445349dSbixia1           }
15474fa00ce1SPeiming Liu           rewriter.create<memref::StoreOp>(loc, v, value);
154884cd51bbSwren romano           SmallVector<Value> operands{writer, rankValue, dimCoords, value};
1549e445349dSbixia1           FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
1550e445349dSbixia1                                          EmitCInterface::On);
1551e445349dSbixia1           builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
1552e445349dSbixia1           builder.create<sparse_tensor::YieldOp>(loc);
1553e445349dSbixia1         });
1554e445349dSbixia1 
1555e445349dSbixia1     // Release the writer.
1556e445349dSbixia1     createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
1557e445349dSbixia1                    EmitCInterface::Off);
1558e445349dSbixia1 
1559e445349dSbixia1     rewriter.eraseOp(op);
1560e445349dSbixia1     return success();
1561e445349dSbixia1   }
1562e445349dSbixia1 };
1563e445349dSbixia1 
156428ebb0b6SAart Bik } // namespace
156528ebb0b6SAart Bik 
156628ebb0b6SAart Bik //===---------------------------------------------------------------------===//
156728ebb0b6SAart Bik // Methods that add patterns described in this file to a pattern list.
156828ebb0b6SAart Bik //===---------------------------------------------------------------------===//
1569f81f0cb7Sbixia1 
1570f81f0cb7Sbixia1 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
15713aeb28b9SPeiming Liu   patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
15723aeb28b9SPeiming Liu                FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
15733aeb28b9SPeiming Liu                GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
15743aeb28b9SPeiming Liu       patterns.getContext());
1575f81f0cb7Sbixia1 }
1576f81f0cb7Sbixia1 
1577f82bee13SPeiming Liu void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
1578f81f0cb7Sbixia1                                                    bool enableRT,
1579eb877006Sbixia1                                                    bool enableConvert) {
1580ef100c22SPeiming Liu   patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
1581bc878f70SPeiming Liu                ReshapeRewriter<tensor::CollapseShapeOp>,
1582bc878f70SPeiming Liu                Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
1583bc878f70SPeiming Liu                Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
15847d608ee2SPeiming Liu                SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
1585c780352dSPeiming Liu       patterns.getContext());
1586f82bee13SPeiming Liu 
1587eb877006Sbixia1   if (enableConvert)
1588dda3dc5eSPeiming Liu     patterns.add<DirectConvertRewriter>(patterns.getContext());
1589f248d0b2SPeiming Liu   if (!enableRT)
15907d608ee2SPeiming Liu     patterns.add<NewRewriter>(patterns.getContext());
159128ebb0b6SAart Bik }
1592f82bee13SPeiming Liu 
1593f82bee13SPeiming Liu void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
15943426d330SPeiming Liu   // Run CrdTranslateRewriter later in the pipeline so that the operation can
15953426d330SPeiming Liu   // be folded before lowering to affine.apply.
15963426d330SPeiming Liu   patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
1597f82bee13SPeiming Liu }
1598