128ebb0b6SAart Bik //===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===// 228ebb0b6SAart Bik // 328ebb0b6SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 428ebb0b6SAart Bik // See https://llvm.org/LICENSE.txt for license information. 528ebb0b6SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 628ebb0b6SAart Bik // 728ebb0b6SAart Bik //===----------------------------------------------------------------------===// 828ebb0b6SAart Bik // 928ebb0b6SAart Bik // This file implements rewriting rules that are specific to sparse tensors. 1028ebb0b6SAart Bik // 1128ebb0b6SAart Bik //===----------------------------------------------------------------------===// 1228ebb0b6SAart Bik 13365777ecSAart Bik #include "Utils/CodegenUtils.h" 14365777ecSAart Bik #include "Utils/LoopEmitter.h" 15c7bb69bcSAart Bik 166456e0bbSPeiming Liu #include "mlir/Dialect/Affine/IR/AffineOps.h" 17abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 1828ebb0b6SAart Bik #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 1928ebb0b6SAart Bik #include "mlir/Dialect/Linalg/IR/Linalg.h" 20555e7835Sbixia1 #include "mlir/Dialect/Linalg/Utils/Utils.h" 21550288cbSPeiming Liu #include "mlir/Dialect/MemRef/IR/MemRef.h" 2267f61b08Sbixia1 #include "mlir/Dialect/SCF/IR/SCF.h" 2328ebb0b6SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 24d37affb0SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" 25f708a549Swren romano #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" 2628ebb0b6SAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 2728ebb0b6SAart Bik #include "mlir/Dialect/Tensor/IR/Tensor.h" 28d37affb0SAart Bik #include "mlir/Dialect/Vector/IR/VectorOps.h" 2928ebb0b6SAart Bik #include "mlir/IR/AffineMap.h" 3028ebb0b6SAart Bik #include "mlir/IR/Matchers.h" 3128ebb0b6SAart Bik #include "mlir/Support/LLVM.h" 3228ebb0b6SAart Bik 3328ebb0b6SAart Bik using 
namespace mlir; 3428ebb0b6SAart Bik using namespace mlir::bufferization; 3528ebb0b6SAart Bik using namespace mlir::linalg; 3628ebb0b6SAart Bik using namespace mlir::sparse_tensor; 3728ebb0b6SAart Bik 3828ebb0b6SAart Bik //===---------------------------------------------------------------------===// 3928ebb0b6SAart Bik // Helper methods for the actual rewriting rules. 4028ebb0b6SAart Bik //===---------------------------------------------------------------------===// 4128ebb0b6SAart Bik 420d4e7fbaSAart Bik // Helper method to match any typed zero. 430d4e7fbaSAart Bik static bool isZeroValue(Value val) { 440d4e7fbaSAart Bik return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()); 450d4e7fbaSAart Bik } 460d4e7fbaSAart Bik 4728ebb0b6SAart Bik // Helper to detect a sparse tensor type operand. 48e7df8281SPeiming Liu static bool isSparseTensor(Value v) { 49e7df8281SPeiming Liu auto enc = getSparseTensorEncoding(v.getType()); 501944c4f7SAart Bik return enc && !llvm::all_of(enc.getLvlTypes(), 51aaf91645SPeiming Liu [](auto lt) { return lt == LevelFormat::Dense; }); 52b4e2b7f9SPeiming Liu } 53e7df8281SPeiming Liu static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); } 54b4e2b7f9SPeiming Liu 556a45339bSAart Bik // Helper method to find zero/uninitialized tensor materialization. 566a45339bSAart Bik static bool isMaterializing(OpOperand *op, bool isZero) { 5728ebb0b6SAart Bik Value val = op->get(); 5865074179SAart Bik // Check allocation, with zero alloc when required. 59ce3d0e87SAart Bik if (auto alloc = val.getDefiningOp<AllocTensorOp>()) { 60ce3d0e87SAart Bik Value copy = alloc.getCopy(); 61ce3d0e87SAart Bik if (isZero) 620d4e7fbaSAart Bik return copy && isZeroValue(copy); 63ce3d0e87SAart Bik return !copy; 64ce3d0e87SAart Bik } 656a45339bSAart Bik // Check for empty tensor materialization. 
666a45339bSAart Bik if (auto empty = val.getDefiningOp<tensor::EmptyOp>()) 676a45339bSAart Bik return !isZero; 6865074179SAart Bik // Last resort for zero alloc: the whole value is zero. 6965074179SAart Bik return isZero && isZeroValue(val); 7028ebb0b6SAart Bik } 7128ebb0b6SAart Bik 7228ebb0b6SAart Bik // Helper to detect sampling operation. 7328ebb0b6SAart Bik static bool isSampling(GenericOp op) { 74d3b3f765SJacques Pienaar auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator()); 7528ebb0b6SAart Bik if (auto *def = yieldOp.getOperand(0).getDefiningOp()) { 7628ebb0b6SAart Bik if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) { 7728ebb0b6SAart Bik // Both scalar input arguments used exactly once. 7828ebb0b6SAart Bik Value s1 = op.getBlock()->getArgument(0); 7928ebb0b6SAart Bik Value s2 = op.getBlock()->getArgument(1); 8028ebb0b6SAart Bik return (def->getOperand(0) == s1 && def->getOperand(1) == s2) || 8128ebb0b6SAart Bik (def->getOperand(1) == s1 && def->getOperand(0) == s2); 8228ebb0b6SAart Bik } 8328ebb0b6SAart Bik } 8428ebb0b6SAart Bik return false; 8528ebb0b6SAart Bik } 8628ebb0b6SAart Bik 8728ebb0b6SAart Bik // Helper to detect chain of multiplications that do not involve x. 8828ebb0b6SAart Bik static bool isMulChain(Value val, Value x) { 895550c821STres Popp if (auto arg = dyn_cast<BlockArgument>(val)) 9028ebb0b6SAart Bik return arg != x; 9128ebb0b6SAart Bik if (auto *def = val.getDefiningOp()) { 9228ebb0b6SAart Bik if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) 9328ebb0b6SAart Bik return isMulChain(def->getOperand(0), x) && 9428ebb0b6SAart Bik isMulChain(def->getOperand(1), x); 9528ebb0b6SAart Bik } 9628ebb0b6SAart Bik return false; 9728ebb0b6SAart Bik } 9828ebb0b6SAart Bik 9928ebb0b6SAart Bik // Helper to detect x = x + <multiplications>. 
10028ebb0b6SAart Bik static bool isSumOfMul(GenericOp op) { 101d3b3f765SJacques Pienaar auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator()); 10228ebb0b6SAart Bik if (auto *def = yieldOp.getOperand(0).getDefiningOp()) { 10328ebb0b6SAart Bik if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) { 10428ebb0b6SAart Bik Value x = op.getBlock()->getArguments().back(); 10528ebb0b6SAart Bik return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) || 10628ebb0b6SAart Bik (def->getOperand(1) == x && isMulChain(def->getOperand(0), x)); 10728ebb0b6SAart Bik } 10828ebb0b6SAart Bik } 10928ebb0b6SAart Bik return false; 11028ebb0b6SAart Bik } 11128ebb0b6SAart Bik 112c7bb69bcSAart Bik // Helper to detect direct yield of a zero value. 113c7bb69bcSAart Bik static bool isZeroYield(GenericOp op) { 114d3b3f765SJacques Pienaar auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator()); 1155550c821STres Popp if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) { 116c7bb69bcSAart Bik if (arg.getOwner()->getParentOp() == op) { 117a7cccb9cSAlexander Belyaev return isZeroValue(op->getOperand(arg.getArgNumber())); 118c7bb69bcSAart Bik } 119c7bb69bcSAart Bik } 1200d4e7fbaSAart Bik return isZeroValue(yieldOp.getOperand(0)); 121c7bb69bcSAart Bik } 122c7bb69bcSAart Bik 123330d48c4Sbixia1 /// Populates given sizes array from type (for static sizes) and from 124330d48c4Sbixia1 /// the tensor (for dynamic sizes). 
1250e1708ffSAart Bik static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes, 126330d48c4Sbixia1 Location loc, ShapedType stp, Value tensor) { 127330d48c4Sbixia1 for (const auto &d : enumerate(stp.getShape())) { 128330d48c4Sbixia1 Value dim; 129399638f9SAliia Khasanova if (d.value() == ShapedType::kDynamic) 130330d48c4Sbixia1 dim = builder.create<tensor::DimOp>(loc, tensor, d.index()); 131330d48c4Sbixia1 else 132330d48c4Sbixia1 dim = constantIndex(builder, loc, d.value()); 133330d48c4Sbixia1 sizes.push_back(dim); 134330d48c4Sbixia1 } 135330d48c4Sbixia1 } 136330d48c4Sbixia1 13776647fceSwren romano static RankedTensorType getBufferType(const SparseTensorType &stt, 13876647fceSwren romano bool needTmpCOO) { 13945288085SAart Bik return needTmpCOO ? stt.getCOOType(/*ordered=*/false) 14076647fceSwren romano : stt.getRankedTensorType(); 14181e3079dSbixia1 } 14281e3079dSbixia1 143eb877006Sbixia1 /// Collects the dynamic dimension sizes for `tp` with the assumption that 144eb877006Sbixia1 /// `sizes` are the dimension sizes for the type. Stores the dynamic dimension 145eb877006Sbixia1 /// sizes to dynSizes. 146dda3dc5eSPeiming Liu static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, 147eb877006Sbixia1 SmallVectorImpl<Value> &dynSizes) { 148eb877006Sbixia1 for (const auto &d : enumerate(tp.getShape())) { 149399638f9SAliia Khasanova if (d.value() == ShapedType::kDynamic) 150eb877006Sbixia1 dynSizes.push_back(sizes[d.index()]); 151eb877006Sbixia1 } 152eb877006Sbixia1 } 153eb877006Sbixia1 1548d615a23SPeiming Liu static LogicalResult genForeachOnSparseConstant(ForeachOp op, 1558d615a23SPeiming Liu RewriterBase &rewriter, 1568d615a23SPeiming Liu SparseElementsAttr attr) { 1578d615a23SPeiming Liu auto loc = op.getLoc(); 1588d615a23SPeiming Liu SmallVector<Value> reduc = op.getInitArgs(); 1598d615a23SPeiming Liu 1608d615a23SPeiming Liu // Foreach on constant. 
1618d615a23SPeiming Liu foreachInSparseConstant( 1629d4df97fSwren romano rewriter, loc, attr, op.getOrder().value_or(AffineMap()), 16384cd51bbSwren romano [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable { 1648d615a23SPeiming Liu SmallVector<Value> args; 16584cd51bbSwren romano args.append(cvs.begin(), cvs.end()); 1668d615a23SPeiming Liu args.push_back(v); 1678d615a23SPeiming Liu args.append(reduc); 1688d615a23SPeiming Liu // Clones the foreach op to get a copy of the loop body. 1698d615a23SPeiming Liu auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation())); 1708d615a23SPeiming Liu assert(args.size() == cloned.getBody()->getNumArguments()); 1718d615a23SPeiming Liu Operation *yield = cloned.getBody()->getTerminator(); 17242c31d83SMatthias Springer rewriter.inlineBlockBefore(cloned.getBody(), op, args); 1738d615a23SPeiming Liu // clean up 1748d615a23SPeiming Liu rewriter.eraseOp(cloned); 1758d615a23SPeiming Liu reduc = yield->getOperands(); 1768d615a23SPeiming Liu rewriter.eraseOp(yield); 1778d615a23SPeiming Liu }); 1788d615a23SPeiming Liu 1798d615a23SPeiming Liu rewriter.replaceOp(op, reduc); 1808d615a23SPeiming Liu return success(); 1818d615a23SPeiming Liu } 1828d615a23SPeiming Liu 183aedf5d58Sbixia1 /// Populates the given sizes array for concatenation from types (for static 184aedf5d58Sbixia1 /// sizes) and from the source tensors (for dynamic sizes). 185aedf5d58Sbixia1 static void concatSizesFromInputs(OpBuilder &builder, 186aedf5d58Sbixia1 SmallVectorImpl<Value> &sizes, Location loc, 187aedf5d58Sbixia1 ShapedType dstTp, ValueRange srcs, 188aedf5d58Sbixia1 unsigned dim) { 189aedf5d58Sbixia1 auto dstShape = dstTp.getShape(); 190aedf5d58Sbixia1 sizesFromSrc(builder, sizes, loc, srcs[0]); 191aedf5d58Sbixia1 192aedf5d58Sbixia1 // Sum up on the `dim` if the dimension is dynamic. 193aedf5d58Sbixia1 if (dstShape[dim] != ShapedType::kDynamic) { 194aedf5d58Sbixia1 // Faithfully take the static size. 
195aedf5d58Sbixia1 sizes[dim] = constantIndex(builder, loc, dstShape[dim]); 196aedf5d58Sbixia1 } else { 197aedf5d58Sbixia1 // Else, compute the shape dynamically. 198aedf5d58Sbixia1 for (const auto &src : srcs.drop_front()) { 199aedf5d58Sbixia1 Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim); 200aedf5d58Sbixia1 // Sum up all the sizes. 201aedf5d58Sbixia1 sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz); 202aedf5d58Sbixia1 } 203aedf5d58Sbixia1 } 204aedf5d58Sbixia1 } 205aedf5d58Sbixia1 20628ebb0b6SAart Bik //===---------------------------------------------------------------------===// 20728ebb0b6SAart Bik // The actual sparse tensor rewriting rules. 20828ebb0b6SAart Bik //===---------------------------------------------------------------------===// 20928ebb0b6SAart Bik 21028ebb0b6SAart Bik namespace { 21128ebb0b6SAart Bik 212ea3eeb48SPeiming Liu /// TODO: move it to tensor dialect instead. 213ea3eeb48SPeiming Liu /// 214ea3eeb48SPeiming Liu /// Fold `tensor.concat` and `tensor.extract_slice` 215ea3eeb48SPeiming Liu /// 216ea3eeb48SPeiming Liu /// %concat = tensor.concat dim(2) %t0, %t1 217ea3eeb48SPeiming Liu /// : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32> 218ea3eeb48SPeiming Liu /// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1] 219ea3eeb48SPeiming Liu /// : tensor<1x64x2xf32> to tensor<1x64x1xf32> 220ea3eeb48SPeiming Liu /// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1] 221ea3eeb48SPeiming Liu /// : tensor<1x64x2xf32> to tensor<1x64x1xf32> 222ea3eeb48SPeiming Liu /// 223ea3eeb48SPeiming Liu /// Becomes 224ea3eeb48SPeiming Liu /// 225ea3eeb48SPeiming Liu /// %extract0, %extract1 = %t0, %t1 226ea3eeb48SPeiming Liu struct FuseExtractSliceWithConcat 227ea3eeb48SPeiming Liu : public OpRewritePattern<tensor::ExtractSliceOp> { 228ea3eeb48SPeiming Liu using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern; 229ea3eeb48SPeiming Liu 230ea3eeb48SPeiming Liu 
LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp, 231ea3eeb48SPeiming Liu PatternRewriter &rewriter) const override { 232ea3eeb48SPeiming Liu auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>(); 233ea3eeb48SPeiming Liu if (!concatOp) 234ea3eeb48SPeiming Liu return failure(); 235ea3eeb48SPeiming Liu 236ea3eeb48SPeiming Liu Location loc = extractOp.getLoc(); 237ea3eeb48SPeiming Liu int64_t dim = concatOp.getDim(); 238ea3eeb48SPeiming Liu int64_t rank = extractOp.getResultType().getRank(); 239ea3eeb48SPeiming Liu 240ea3eeb48SPeiming Liu SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1)); 241ea3eeb48SPeiming Liu SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0)); 242ea3eeb48SPeiming Liu 243ea3eeb48SPeiming Liu // Compute the partial sums for the slice offsets. 244ea3eeb48SPeiming Liu AffineExpr sum = rewriter.getAffineDimExpr(0); 245ea3eeb48SPeiming Liu SmallVector<AffineExpr> partialSums = {sum}; 246ea3eeb48SPeiming Liu SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)}; 247ea3eeb48SPeiming Liu for (auto [idx, input] : 248ea3eeb48SPeiming Liu llvm::enumerate(concatOp.getInputs().drop_back())) { 249ea3eeb48SPeiming Liu sum = sum + rewriter.getAffineDimExpr(idx + 1); 250ea3eeb48SPeiming Liu partialSums.push_back(sum); 251ea3eeb48SPeiming Liu offsetStrides.push_back( 252ea3eeb48SPeiming Liu rewriter.createOrFold<tensor::DimOp>(loc, input, dim)); 253ea3eeb48SPeiming Liu } 254ea3eeb48SPeiming Liu auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0, 255ea3eeb48SPeiming Liu partialSums, rewriter.getContext()); 256ea3eeb48SPeiming Liu SmallVector<OpFoldResult> dimOffsets = 257ea3eeb48SPeiming Liu affine::makeComposedFoldedMultiResultAffineApply( 258ea3eeb48SPeiming Liu rewriter, loc, partialSumMap, offsetStrides); 259ea3eeb48SPeiming Liu 260ea3eeb48SPeiming Liu auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) { 261ea3eeb48SPeiming Liu for (auto [l, 
r] : llvm::zip(lhs, rhs)) { 262ea3eeb48SPeiming Liu std::optional<int64_t> staticVal = getConstantIntValue(l); 263ea3eeb48SPeiming Liu if (!staticVal.has_value() || staticVal != getConstantIntValue(r)) 264ea3eeb48SPeiming Liu return false; 265ea3eeb48SPeiming Liu } 266ea3eeb48SPeiming Liu return lhs.size() == rhs.size(); 267ea3eeb48SPeiming Liu }; 268ea3eeb48SPeiming Liu 269ea3eeb48SPeiming Liu for (auto [i, input, offset] : 270ea3eeb48SPeiming Liu llvm::enumerate(concatOp.getInputs(), dimOffsets)) { 271ea3eeb48SPeiming Liu SmallVector<OpFoldResult> srcSizes = 272ea3eeb48SPeiming Liu tensor::getMixedSizes(rewriter, loc, input); 273ea3eeb48SPeiming Liu srcOffsets[dim] = offset; 274ea3eeb48SPeiming Liu 275ea3eeb48SPeiming Liu SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes(); 276ea3eeb48SPeiming Liu SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets(); 277ea3eeb48SPeiming Liu SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides(); 278ea3eeb48SPeiming Liu 279ea3eeb48SPeiming Liu if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) && 280ea3eeb48SPeiming Liu allEqual(srcStrides, dstStrides)) { 281ea3eeb48SPeiming Liu Value operand = concatOp.getOperand(i); 282ea3eeb48SPeiming Liu if (operand.getType() == extractOp.getResultType()) 283ea3eeb48SPeiming Liu rewriter.replaceOp(extractOp, operand); 284ea3eeb48SPeiming Liu break; 285ea3eeb48SPeiming Liu } 286ea3eeb48SPeiming Liu } 287ea3eeb48SPeiming Liu 288ea3eeb48SPeiming Liu return success(); 289ea3eeb48SPeiming Liu } 290ea3eeb48SPeiming Liu }; 291ea3eeb48SPeiming Liu 2923aeb28b9SPeiming Liu /// Rewriting rule that fuses sparse_tensor.convert into producer. 
2933aeb28b9SPeiming Liu struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> { 2943aeb28b9SPeiming Liu public: 2953aeb28b9SPeiming Liu using OpRewritePattern::OpRewritePattern; 2963aeb28b9SPeiming Liu 2973aeb28b9SPeiming Liu LogicalResult matchAndRewrite(ConvertOp op, 2983aeb28b9SPeiming Liu PatternRewriter &rewriter) const override { 2993aeb28b9SPeiming Liu auto producer = op.getSource().getDefiningOp<GenericOp>(); 3003aeb28b9SPeiming Liu if (!producer || producer.getDpsInits().size() != 1 || 3013aeb28b9SPeiming Liu !isMaterializing(producer.getDpsInitOperand(0), false) || 3023aeb28b9SPeiming Liu !producer.getResult(0).hasOneUse()) { 3033aeb28b9SPeiming Liu return failure(); 3043aeb28b9SPeiming Liu } 305fb8f492aSPeiming Liu // Clone the materialization operation, but update the result to sparse. 306fb8f492aSPeiming Liu rewriter.setInsertionPoint(producer); 307fb8f492aSPeiming Liu Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp(); 308fb8f492aSPeiming Liu Operation *cloned = rewriter.clone(*init); 309fb8f492aSPeiming Liu cloned->getResult(0).setType(op.getResult().getType()); 310fb8f492aSPeiming Liu 3113aeb28b9SPeiming Liu rewriter.modifyOpInPlace(producer, [&]() { 312fb8f492aSPeiming Liu producer.getDpsInitsMutable().assign(cloned->getResults()); 3133aeb28b9SPeiming Liu producer.getResult(0).setType(op.getResult().getType()); 3143aeb28b9SPeiming Liu }); 3153aeb28b9SPeiming Liu 3163aeb28b9SPeiming Liu rewriter.replaceAllOpUsesWith(op, producer); 3173aeb28b9SPeiming Liu op->erase(); 3183aeb28b9SPeiming Liu 3193aeb28b9SPeiming Liu return success(); 3203aeb28b9SPeiming Liu } 3213aeb28b9SPeiming Liu }; 3223aeb28b9SPeiming Liu 323c7bb69bcSAart Bik /// Rewriting rule that converts direct yield of zero with initial allocation. 
324c7bb69bcSAart Bik struct FoldInvariantYield : public OpRewritePattern<GenericOp> { 325c7bb69bcSAart Bik public: 326c7bb69bcSAart Bik using OpRewritePattern<GenericOp>::OpRewritePattern; 327c7bb69bcSAart Bik 328c7bb69bcSAart Bik LogicalResult matchAndRewrite(GenericOp op, 329c7bb69bcSAart Bik PatternRewriter &rewriter) const override { 3300a8e3dd4SMatthias Springer if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 || 3316a45339bSAart Bik !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) || 33263baab8bSPeiming Liu !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse()) 333c7bb69bcSAart Bik return failure(); 334255c3f11Swren romano auto outputType = getRankedTensorType(op.getResult(0)); 3356a45339bSAart Bik // Yielding zero on newly materialized sparse tensor can be 3366a45339bSAart Bik // optimized directly (regardless of dynamic or static size). 337ec495b53SPeiming Liu if (getSparseTensorEncoding(outputType)) { 338b4db15a9SAlexander Belyaev rewriter.replaceOp(op, op.getDpsInitOperand(0)->get()); 339ec495b53SPeiming Liu return success(); 340ec495b53SPeiming Liu } 3416a45339bSAart Bik // Use static zero value directly instead of materialization. 342ec495b53SPeiming Liu if (!outputType.hasStaticShape()) 343ec495b53SPeiming Liu return failure(); 3446a45339bSAart Bik Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp(); 3456a45339bSAart Bik rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType)); 3466a45339bSAart Bik rewriter.eraseOp(def); 347c7bb69bcSAart Bik return success(); 348c7bb69bcSAart Bik } 349c7bb69bcSAart Bik }; 350c7bb69bcSAart Bik 35128ebb0b6SAart Bik /// Rewriting rule that converts two kernels: 35228ebb0b6SAart Bik /// 35328ebb0b6SAart Bik /// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... 
) 35428ebb0b6SAart Bik /// X(i,j) = S(i,j) * T(i,j) 35528ebb0b6SAart Bik /// 35628ebb0b6SAart Bik /// into a single kernel, using distributive law: 35728ebb0b6SAart Bik /// 35828ebb0b6SAart Bik /// X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... ) 35928ebb0b6SAart Bik /// 36028ebb0b6SAart Bik /// This kind of fusion (merging two ops into one but using arithmetic 36128ebb0b6SAart Bik /// equalities that may not hold for floating-point computations) would 36228ebb0b6SAart Bik /// be undesirable in the dense case, since we distribute the multiplication 36328ebb0b6SAart Bik /// into the reduction loop. However, for sparse sampling tensor S, such 36428ebb0b6SAart Bik /// a fusion may actually reduce the asymptotic complexity of the kernel, 36528ebb0b6SAart Bik /// since intermediate results may be nullified. 36628ebb0b6SAart Bik struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> { 36728ebb0b6SAart Bik public: 36828ebb0b6SAart Bik using OpRewritePattern<GenericOp>::OpRewritePattern; 36928ebb0b6SAart Bik 37028ebb0b6SAart Bik LogicalResult matchAndRewrite(GenericOp op, 37128ebb0b6SAart Bik PatternRewriter &rewriter) const override { 37228ebb0b6SAart Bik // Check consumer. 3730a8e3dd4SMatthias Springer if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 || 37428ebb0b6SAart Bik op.getNumResults() != 1 || 37528ebb0b6SAart Bik op.getNumParallelLoops() != op.getNumLoops() || 376b4db15a9SAlexander Belyaev !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() || 377b4db15a9SAlexander Belyaev !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() || 378b4db15a9SAlexander Belyaev !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity()) 37928ebb0b6SAart Bik return failure(); 38028ebb0b6SAart Bik // Find consuming OP2(sparse, other) or OP2(other, sparse). 
The other 38128ebb0b6SAart Bik // operand can be sparse or dense, since the point of this rewriting rule 38228ebb0b6SAart Bik // is detecting a situation in which *more* sparsity is introduced into 38328ebb0b6SAart Bik // a computation, be it already sparse or still dense. 38428ebb0b6SAart Bik unsigned other = 0; 385b4db15a9SAlexander Belyaev if (isSparseTensor(op.getDpsInputOperand(0))) 38628ebb0b6SAart Bik other = 1; 387b4db15a9SAlexander Belyaev else if (!isSparseTensor(op.getDpsInputOperand(1))) 38828ebb0b6SAart Bik return failure(); 38928ebb0b6SAart Bik // Check producer. 39028ebb0b6SAart Bik auto prod = dyn_cast_or_null<GenericOp>( 391b4db15a9SAlexander Belyaev op.getDpsInputOperand(other)->get().getDefiningOp()); 3920a8e3dd4SMatthias Springer if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 || 39328ebb0b6SAart Bik !prod.getResult(0).hasOneUse()) 39428ebb0b6SAart Bik return failure(); 39528ebb0b6SAart Bik // Sampling consumer and sum of multiplication chain producer. 3966a45339bSAart Bik if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) || 3976a45339bSAart Bik !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) || 398ce3d0e87SAart Bik !isSampling(op) || !isSumOfMul(prod)) 39928ebb0b6SAart Bik return failure(); 40028ebb0b6SAart Bik // Modify operand structure of producer and consumer. 40128ebb0b6SAart Bik Location loc = prod.getLoc(); 402a7cccb9cSAlexander Belyaev SmallVector<Value> inputOps = prod.getInputs(); 403a7cccb9cSAlexander Belyaev SmallVector<Value> outputOps = op.getOutputs(); 404d2c0572bSJacques Pienaar SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray(); 405b4db15a9SAlexander Belyaev inputOps.push_back(op.getDpsInputOperand(1 - other)->get()); 40628ebb0b6SAart Bik fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other 40728ebb0b6SAart Bik // Fuse producer and consumer into a new generic op. 
40828ebb0b6SAart Bik auto fusedOp = rewriter.create<GenericOp>( 40928ebb0b6SAart Bik loc, op.getResult(0).getType(), inputOps, outputOps, 410c38d9cf2SOleg Shyshkov rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(), 41128ebb0b6SAart Bik /*doc=*/nullptr, /*library_call=*/nullptr); 412d3b3f765SJacques Pienaar Block &prodBlock = prod.getRegion().front(); 413d3b3f765SJacques Pienaar Block &consBlock = op.getRegion().front(); 4144d67b278SJeff Niu IRMapping mapper; 41591d5653eSMatthias Springer Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion()); 41628ebb0b6SAart Bik unsigned num = prodBlock.getNumArguments(); 41728ebb0b6SAart Bik for (unsigned i = 0; i < num - 1; i++) 41828ebb0b6SAart Bik addArg(mapper, fusedBlock, prodBlock.getArgument(i)); 41928ebb0b6SAart Bik addArg(mapper, fusedBlock, consBlock.getArgument(1 - other)); 42028ebb0b6SAart Bik addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1)); 42128ebb0b6SAart Bik // Clone bodies of the producer and consumer in new evaluation order. 42228ebb0b6SAart Bik auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp(); 42328ebb0b6SAart Bik auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp(); 42428ebb0b6SAart Bik Value last; 42528ebb0b6SAart Bik for (auto &op : prodBlock.without_terminator()) 42628ebb0b6SAart Bik if (&op != acc) { 42728ebb0b6SAart Bik last = op.getResult(0); 42828ebb0b6SAart Bik rewriter.clone(op, mapper); 42928ebb0b6SAart Bik } 43028ebb0b6SAart Bik mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0)); 43128ebb0b6SAart Bik mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0)); 43228ebb0b6SAart Bik last = rewriter.clone(*acc, mapper)->getResult(0); 43328ebb0b6SAart Bik rewriter.create<linalg::YieldOp>(loc, last); 434ce3d0e87SAart Bik // Force initial value on merged allocation for dense outputs. 
4356a45339bSAart Bik // TODO: deal with non alloc tensor here one day 436ce3d0e87SAart Bik if (!getSparseTensorEncoding(op.getResult(0).getType())) { 437b4db15a9SAlexander Belyaev Value init = prod.getDpsInitOperand(0) 438c7bb69bcSAart Bik ->get() 439c7bb69bcSAart Bik .getDefiningOp<AllocTensorOp>() 440c7bb69bcSAart Bik .getCopy(); 441c7bb69bcSAart Bik AllocTensorOp a = 442b4db15a9SAlexander Belyaev op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>(); 4435fcf907bSMatthias Springer rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); }); 444ce3d0e87SAart Bik } 44528ebb0b6SAart Bik // Replace consumer with fused operation. Old producer 44628ebb0b6SAart Bik // and consumer ops will be removed by DCE. 44728ebb0b6SAart Bik rewriter.replaceOp(op, fusedOp->getResults()); 44828ebb0b6SAart Bik return success(); 44928ebb0b6SAart Bik } 45028ebb0b6SAart Bik 45128ebb0b6SAart Bik private: 45228ebb0b6SAart Bik // Helper to add argument and record the mapping. 4534d67b278SJeff Niu static void addArg(IRMapping &mapper, Block *b, BlockArgument a) { 45428ebb0b6SAart Bik mapper.map(a, b->addArgument(a.getType(), a.getLoc())); 45528ebb0b6SAart Bik } 45628ebb0b6SAart Bik }; 45728ebb0b6SAart Bik 4589a018a7bSAart Bik // Fuse a tensor cast into producing operation. Note that a tensor.cast 4599a018a7bSAart Bik // should really not be used to convert between sparse encodings. Since 4609a018a7bSAart Bik // the pattern currently appears as a result of some prior rewriting 4619a018a7bSAart Bik // we make an attempt to repair very obvious cases. 
4629a018a7bSAart Bik // TODO: audit the pure tensor dialect rewriting rules 4639a018a7bSAart Bik struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> { 4649a018a7bSAart Bik public: 4659a018a7bSAart Bik using OpRewritePattern<tensor::CastOp>::OpRewritePattern; 4669a018a7bSAart Bik 4679a018a7bSAart Bik LogicalResult matchAndRewrite(tensor::CastOp op, 4689a018a7bSAart Bik PatternRewriter &rewriter) const override { 4699a018a7bSAart Bik Type srcType = op.getSource().getType(); 4709a018a7bSAart Bik Type dstType = op.getDest().getType(); 4719a018a7bSAart Bik // A nop cast simply folds away. 4729a018a7bSAart Bik if (srcType == dstType) { 4739a018a7bSAart Bik rewriter.replaceOp(op, op->getResults()); 4749a018a7bSAart Bik return success(); 4759a018a7bSAart Bik } 4769a018a7bSAart Bik // See if a sparsity changing cast can be fused into producer. 4779a018a7bSAart Bik if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) { 4789a018a7bSAart Bik if (Operation *def = op.getSource().getDefiningOp()) { 4799a018a7bSAart Bik if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) { 4805fcf907bSMatthias Springer rewriter.modifyOpInPlace(def, [&]() { 4819a018a7bSAart Bik def->getResult(0).setType(op->getResultTypes()[0]); 4823d90c812SMatthias Springer }); 4839a018a7bSAart Bik rewriter.replaceOp(op, def->getResult(0)); 4849a018a7bSAart Bik return success(); 4859a018a7bSAart Bik } 4869a018a7bSAart Bik } 4879a018a7bSAart Bik } 4889a018a7bSAart Bik // Repair tensor casts with at least one sparse operand into the 4899a018a7bSAart Bik // the properly supported sparse_tensor.convert. 4909a018a7bSAart Bik if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) { 4919a018a7bSAart Bik rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource()); 4929a018a7bSAart Bik return success(); 4939a018a7bSAart Bik } 4949a018a7bSAart Bik // Fail otherwise. 
      return failure();
  }
};

/// Rewrites a sequence of operations for sparse tensor selections into
/// semi-ring operations such that they can be compiled correctly by the
/// sparsifier. E.g., transforming the following sequence
///
/// %sel = arith.select %cond, %sp1, %sp2
///
/// to
///
/// %sel = binary %sp1, %sp2:
///          both  (%l, %r) {yield select %cond, %l, %r}
///          left  (%l)     {yield select %cond, %l, 0}
///          right (%r)     {yield select %cond, 0, %r}
///
/// TODO: We require that the tensor used for extracting conditions to be dense
/// to sparsify the code. To support a sparse condition tensor, we need a
/// tri-nary operation.
struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Rejects non sparse kernels.
    if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
      return failure();

    Location loc = op.getLoc();
    // Collects every (select, binary) pair that was rewritten; the actual
    // replacement is deferred until after the body traversal (see below).
    SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
    for (Operation &inst : *op.getBody()) {
      // Matches pattern.
      auto matched = isRewritablePattern(op, &inst);
      if (!matched.has_value())
        continue;

      rewriter.setInsertionPoint(&inst);
      auto [c, t, f] = matched.value();
      assert(t.getType() == f.getType());
      auto selTp = t.getType();
      auto c0 = constantZero(rewriter, loc, selTp);
      auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
      // Initializes all the blocks.
      rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
                           {t.getLoc(), f.getLoc()});
      rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
      rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());

      for (auto *r : binOp.getRegions()) {
        Block *b = &r->front();
        rewriter.setInsertionPointToStart(b);

        IRMapping irMap;
        // Clones the cmp operations into the region to make the binary op
        // admissible.
        Value newC = c;
        if (auto *def = c.getDefiningOp())
          newC = rewriter.clone(*def, irMap)->getResult(0);

        irMap.map(c, newC);
        if (r == &binOp.getLeftRegion()) {
          // Right-hand side is absent: substitute zero for the false value.
          irMap.map(t, b->getArgument(0));
          irMap.map(f, c0);
        } else if (r == &binOp.getRightRegion()) {
          // Left-hand side is absent: substitute zero for the true value.
          irMap.map(t, c0);
          irMap.map(f, b->getArgument(0));
        } else {
          // Overlap region: both operands are present.
          irMap.map(t, b->getArgument(0));
          irMap.map(f, b->getArgument(1));
        }
        auto y = rewriter.clone(inst, irMap)->getResult(0);
        rewriter.create<sparse_tensor::YieldOp>(loc, y);
      }

      // We successfully rewrote the operation. We cannot do the replacement
      // here because it would invalidate the iterator used by the current
      // loop to traverse the instructions.
      semiRings.emplace_back(&inst, binOp);
    }

    // Finalizes the replacement.
    for (auto [sel, semi] : semiRings)
      rewriter.replaceOp(sel, semi->getResults());

    return success(!semiRings.empty());
  }

private:
  // Matches an arith.select whose true/false values are loaded directly from
  // input tensors and whose condition depends only on dense inputs or loop
  // invariants. Returns (condition, trueVal, falseVal) on success.
  static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
  isRewritablePattern(GenericOp op, Operation *v) {
    auto sel = dyn_cast<arith::SelectOp>(v);
    if (!sel)
      return std::nullopt;

    auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
    auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
    // TODO: For simplicity, we only handle cases where both true/false value
    // are directly loaded the input tensor. We can probably admit more cases
    // in theory.
    if (!tVal || !fVal)
      return std::nullopt;

    // Helper lambda to determine whether the value is loaded from a dense input
    // or is a loop invariant.
    auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
      if (auto bArg = dyn_cast<BlockArgument>(v);
          bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
        return true;
      // If the value is defined outside the loop, it is a loop invariant.
      return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
    };

    // If the condition value is load directly from a dense tensor or
    // loop-invariants, we can sparsify the kernel.
    auto cond = sel.getCondition();
    if (isValFromDenseInputOrInvariant(cond))
      return std::make_tuple(cond, tVal, fVal);

    Value cmpL, cmpR;
    if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
                                               matchers::m_Any(&cmpR))) ||
        matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
                                               matchers::m_Any(&cmpR)))) {
      // TODO: we can do it recursively to check whether all the leaf values are
      // loaded from dense tensors or are loop invariants.
      if (isValFromDenseInputOrInvariant(cmpL) ||
          isValFromDenseInputOrInvariant(cmpR))
        return std::make_tuple(cond, tVal, fVal);
    }

    return std::nullopt;
  };
};

/// Rewrites a sparse reduction that would not sparsify directly since
/// doing so would only iterate over the stored elements, ignoring the
/// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
/// (note that reductions like add/sub/or/xor can directly be sparsified
/// since the implicit zeros do not contribute to the final result).
/// Note that prod/and are still included since, even though they often
/// are nullified in sparse data, they may still occur for special
/// situations in which e.g. some rows in a sparse matrix are fully
/// dense. For min/max, including the implicit zeros is a much more
/// common situation.
///
/// TODO: this essentially "densifies" the operation; we want to implement
/// this much more efficiently by performing the reduction over the
/// stored values, and feed in the zero once if there were *any*
/// implicit zeros as well; but for now, at least we provide
/// the functionality
///
struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Reject non-reductions.
    if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
        op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
      return failure();
    auto *inp = op.getDpsInputOperand(0);
    auto *init = op.getDpsInitOperand(0);
    if (!isSparseTensor(inp))
      return failure();
    // Look for direct x = x OP y for semi-ring ready reductions.
    auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
                    .getOperand(0)
                    .getDefiningOp();
    if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
             arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
             arith::MaxUIOp>(red))
      return failure();
    Value s0 = op.getBlock()->getArgument(0);
    Value s1 = op.getBlock()->getArgument(1);
    // The reduction must combine exactly the two block arguments, in either
    // operand order; anything else is not a direct x = x OP y reduction.
    if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
        (red->getOperand(0) != s1 || red->getOperand(1) != s0))
      return failure();
    // Identity.
    Location loc = op.getLoc();
    Value identity =
        rewriter.create<tensor::ExtractOp>(loc, init->get(), ValueRange());
    // Unary {
    //    present -> value
    //    absent  -> zero.
    // }
    Type rtp = s0.getType();
    rewriter.setInsertionPointToStart(&op.getRegion().front());
    auto semiring = rewriter.create<sparse_tensor::UnaryOp>(loc, rtp, s0);
    Block *present =
        rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
    rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
    rewriter.create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
    rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
    rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(rtp));
    rewriter.create<sparse_tensor::YieldOp>(loc, zero);
    rewriter.setInsertionPointAfter(semiring);
    // CustomReduce {
    //    x = x REDUC y, identity
    // }
    auto custom = rewriter.create<sparse_tensor::ReduceOp>(
        loc, rtp, semiring.getResult(), s1, identity);
    Block *region =
        rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
    rewriter.setInsertionPointToStart(&custom.getRegion().front());
    IRMapping irMap;
    irMap.map(red->getOperand(0), region->getArgument(0));
    irMap.map(red->getOperand(1), region->getArgument(1));
    auto *cloned = rewriter.clone(*red, irMap);
    rewriter.create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
    rewriter.setInsertionPointAfter(custom);
    rewriter.replaceOp(red, custom.getResult());
    return success();
  }
};

/// Sparse rewriting rule for the print operator. This operation is mainly used
/// for debugging and testing. As such, it lowers to the vector.print operation
/// which only require very light-weight runtime support.
struct PrintRewriter : public OpRewritePattern<PrintOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(PrintOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto tensor = op.getTensor();
    auto stt = getSparseTensorType(tensor);
    // Header with NSE.
    auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
    rewriter.create<vector::PrintOp>(
        loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
    rewriter.create<vector::PrintOp>(loc, nse);
    // Print run-time contents for dim/lvl sizes.
    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("dim = "));
    printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("lvl = "));
    printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
    // Use the "codegen" foreach loop construct to iterate over
    // all typical sparse tensor components for printing.
    foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
                                            &stt](Type, FieldIndex,
                                                  SparseTensorFieldKind kind,
                                                  Level l, LevelType) {
      switch (kind) {
      case SparseTensorFieldKind::StorageSpec: {
        // Storage specifier carries no printable payload.
        break;
      }
      case SparseTensorFieldKind::PosMemRef: {
        // Emits "pos[l] : " followed by the positions buffer contents.
        auto lvl = constantIndex(rewriter, loc, l);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
        rewriter.create<vector::PrintOp>(
            loc, lvl, vector::PrintPunctuation::NoPunctuation);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
        auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
        printContents(rewriter, loc, pos);
        break;
      }
      case SparseTensorFieldKind::CrdMemRef: {
        // Emits "crd[l] : " followed by the coordinates buffer contents.
        auto lvl = constantIndex(rewriter, loc, l);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
        rewriter.create<vector::PrintOp>(
            loc, lvl, vector::PrintPunctuation::NoPunctuation);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
        Value crd = nullptr;
        // For COO AoS storage, we want to print a single, linear view of
        // the full coordinate storage at this level. For any other storage,
        // we show the coordinate storage for every individual level.
        if (stt.getAoSCOOStart() == l)
          crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
        else
          crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
        printContents(rewriter, loc, crd);
        break;
      }
      case SparseTensorFieldKind::ValMemRef: {
        // Emits "values : " followed by the values buffer contents.
        rewriter.create<vector::PrintOp>(loc,
                                         rewriter.getStringAttr("values : "));
        auto val = rewriter.create<ToValuesOp>(loc, tensor);
        printContents(rewriter, loc, val);
        break;
      }
      }
      return true;
    });
    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
    rewriter.eraseOp(op);
    return success();
  }

private:
  // Helper to print contents of a single memref. For "push_back" vectors,
  // we assume that the previous getters for pos/crd/val have added a
  // slice-to-size view to make sure we just print the size and not the
  // full capacity.
  //
  // Generates code to print (1-dim or higher):
  //    ( a0, a1, ... )
  static void printContents(PatternRewriter &rewriter, Location loc,
                            Value vec) {
    auto shape = cast<ShapedType>(vec.getType()).getShape();
    SmallVector<Value> idxs;
    printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
  }

  // Helper to the helper: recursively emits one scf.for loop per dimension
  // of `vec`, with `idxs` accumulating the induction variables of the
  // enclosing loops; innermost level prints the actual elements.
  static void printContentsLevel(PatternRewriter &rewriter, Location loc,
                                 Value vec, unsigned i, ArrayRef<int64_t> shape,
                                 SmallVectorImpl<Value> &idxs) {
    // Open bracket.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
    // Generate for loop.
    auto zero = constantIndex(rewriter, loc, 0);
    auto index = constantIndex(rewriter, loc, i);
    auto size = rewriter.create<memref::DimOp>(loc, vec, index);
    auto step = constantIndex(rewriter, loc, 1);
    auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
    idxs.push_back(forOp.getInductionVar());
    rewriter.setInsertionPointToStart(forOp.getBody());
    if (i < shape.size() - 1) {
      // Enter deeper loop nest.
      printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
    } else {
      // Actual contents printing.
      auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
      if (llvm::isa<ComplexType>(val.getType())) {
        // Since the vector dialect does not support complex types in any op,
        // we split those into (real, imag) pairs here.
        Value real = rewriter.create<complex::ReOp>(loc, val);
        Value imag = rewriter.create<complex::ImOp>(loc, val);
        rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
        rewriter.create<vector::PrintOp>(loc, real,
                                         vector::PrintPunctuation::Comma);
        rewriter.create<vector::PrintOp>(loc, imag,
                                         vector::PrintPunctuation::Close);
      } else {
        rewriter.create<vector::PrintOp>(
            loc, val, vector::PrintPunctuation::NoPunctuation);
      }
      // Terminating comma (except at end).
      auto bound = rewriter.create<arith::AddIOp>(loc, idxs.back(), step);
      Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                                  bound, size);
      scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
    }
    idxs.pop_back();
    rewriter.setInsertionPointAfter(forOp);
    // Close bracket.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
  }

  // Helper method to print run-time lvl/dim sizes.
  static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
                         unsigned size, bool isDim) {
    // Open bracket.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
    // Print unrolled contents (dimop requires constant value).
    for (unsigned i = 0; i < size; i++) {
      auto idx = constantIndex(rewriter, loc, i);
      Value val;
      if (isDim)
        val = rewriter.create<tensor::DimOp>(loc, tensor, idx);
      else
        val = rewriter.create<LvlOp>(loc, tensor, idx);
      rewriter.create<vector::PrintOp>(
          loc, val,
          i != size - 1 ? vector::PrintPunctuation::Comma
                        : vector::PrintPunctuation::NoPunctuation);
    }
    // Close bracket and end of line.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
  }
};

/// Sparse rewriting rule for sparse-to-sparse reshape operator.
struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
public:
  using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value srcTensor = op.getSource();
    const auto srcTp = tryGetSparseTensorType(srcTensor);
    const auto dstTp = tryGetSparseTensorType(op.getResult());
    if (!srcTp || !dstTp)
      return failure();

    // Only handles sparse-to-sparse reshapes with a static destination shape.
    if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
        !dstTp->hasStaticDimShape())
      return failure();

    SmallVector<Value> srcSizes;
    sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
    SmallVector<Value> dstSizes;
    for (Dimension d : dstTp->getDimShape())
      dstSizes.push_back(constantIndex(rewriter, loc, d));

    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
    // Only need an unordered COO buffer if input and output are not sorted
    // in the same way.
    Type bufferTp = getBufferType(
        dstTp->withoutDimToLvl(),
        !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
    SmallVector<Value> dynSizes;
    Value buffer = rewriter
                       .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
                                              nnz, Attribute())
                       .getResult();

    // Convert src coordinates to dst coordinates by first collapsing it to 1D
    // and then expand it to the match the rank of the destination tensor.
    // Implemented as follows:
    //   foreach srcCoords %srcTensor
    //     collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
    //     expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
    //     insert expandedCoords, %buffer
    //
    // followed by an optional
    //   %t = sparse_tensor.cast %tmp
    // depending on whether the input/output are sorted in the same way.
    const auto encSrc = srcTp->getEncoding();
    ForeachOp foreachOp = rewriter.create<ForeachOp>(
        loc, srcTensor, buffer,
        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
            ValueRange reduc) {
          // Map the level coordinates back to dimension coordinates.
          const Dimension srcRank = srcTp->getDimRank();
          SmallVector<Value> srcDcvs;
          srcDcvs.reserve(srcRank);
          for (Dimension d = 0; d < srcRank; d++) {
            Level lvl = toLvl(encSrc, d);
            srcDcvs.push_back(srcLcvs[lvl]);
          }

          // Total element count of the fully collapsed 1-D view.
          Value collapseSize = constantIndex(builder, loc, 1);
          for (Dimension d = 0; d < srcRank; d++)
            collapseSize =
                builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
          SmallVector<Value, 1> collapsedSizes = {collapseSize};

          // Collapse all source dims into one reassociation group.
          ReassociationIndices collapseIdx;
          for (Dimension i = 0; i < srcRank; i++)
            collapseIdx.push_back(i);
          SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
          SmallVector<Value, 1> collapsedDcvs;
          reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
                     collapsedSizes, collapsedDcvs);

          // Expand the single collapsed coordinate to the destination rank.
          ReassociationIndices expandIdx;
          for (Dimension i = 0; i < dstTp->getDimRank(); i++)
            expandIdx.push_back(i);
          SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
          SmallVector<Value> dstDcvs;
          reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
                     dstSizes, dstDcvs);

          auto t =
              builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
          builder.create<sparse_tensor::YieldOp>(loc, t);
        });

    Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
    if (bufferTp != *dstTp) {
      // The buffer was an unordered COO: convert into the requested type
      // and release the temporary.
      auto dstRTT = dstTp->getRankedTensorType();
      Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
      rewriter.create<DeallocTensorOp>(loc, t);
      t = converted;
    }
    rewriter.replaceOp(op, t);
    return success();
  }
};

/// Sparse rewriting rule for sparse-to-sparse reshape operator.
template <typename ReshapeOp>
struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value srcTensor = op.getSrc();
    const auto srcTp = getSparseTensorType(srcTensor);
    const auto dstTp = getSparseTensorType(op.getResult());
    if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
      return failure();

    // Generate code to represent the static dimension constants or compute
    // the dynamic dimension values.
    SmallVector<Value> srcSizes;
    sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
    SmallVector<Value> dstSizes;
    SmallVector<Value> dstDynSizes;
    if (dstTp.hasStaticDimShape()) {
      for (Dimension d : dstTp.getDimShape())
        dstSizes.push_back(constantIndex(rewriter, loc, d));
    } else {
      ArrayRef<Size> dstShape = dstTp.getDimShape();
      genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
                         op.getReassociationIndices());
      for (auto [idx, shape] : llvm::enumerate(dstShape)) {
        if (shape == ShapedType::kDynamic)
          dstDynSizes.push_back(dstSizes[idx]);
      }
    }
    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
    // Only need an unordered COO buffer if input and output are not sorted
    // in the same way.
    Type bufferTp = getBufferType(
        dstTp.withoutDimToLvl(),
        !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());

    Value buffer =
        rewriter
            .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
                                   /*sizeHint=*/nnz, Attribute())
            .getResult();

    // Implement the sparse2sparse reshape as follows:
    //   foreach srcCoords %srcTensor
    //     insert reshapeCvs(srcCoords), %buffer
    //
    // followed by an optional
    //   %t = sparse_tensor.cast %tmp
    // depending on whether the input/output are sorted in the same way.
    const auto encSrc = srcTp.getEncoding();
    ForeachOp foreachOp = rewriter.create<ForeachOp>(
        loc, srcTensor, buffer,
        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
            ValueRange reduc) {
          // Map the level coordinates back to dimension coordinates.
          const Dimension dimRank = srcTp.getDimRank();
          SmallVector<Value> srcDcvs;
          srcDcvs.reserve(dimRank);
          for (Dimension d = 0; d < dimRank; d++) {
            Level lvl = toLvl(encSrc, d);
            srcDcvs.push_back(srcLcvs[lvl]);
          }
          // Translate the coordinates through the reassociation and insert.
          SmallVector<Value> dstDcvs;
          reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
                     srcDcvs, dstSizes, dstDcvs);
          auto t =
              builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
          builder.create<sparse_tensor::YieldOp>(loc, t);
        });

    Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
    if (bufferTp != dstTp) {
      // The buffer was an unordered COO: convert into the requested type
      // and release the temporary.
      auto dstRTT = dstTp.getRankedTensorType();
      Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
      rewriter.create<DeallocTensorOp>(loc, t);
      t = converted;
    }
    rewriter.replaceOp(op, t);
    return success();
  }
};

/// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
/// operator.
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto encDst = getSparseTensorEncoding(op.getResult().getType());
    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
    // Since a pure dense expansion is very cheap (change of view), for
    // a sparse2dense or dense2sparse, we can simply unfuse a sparse
    // conversion from the reshape operation itself.
    // All other cases are handled elsewhere.
    if (encDst && encSrc) {
      return failure();
    }
    if (encSrc) {
      // Sparse source: convert to dense first, then reshape densely.
      auto rtp = getRankedTensorType(op.getSrc());
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
      rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
      return success();
    }
    if (encDst) {
      // Sparse destination: reshape densely, then convert to sparse.
      auto rtp = getRankedTensorType(op.getResult());
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      ReshapeOp reshape;
      // ExpandShapeOp carries an output shape that must be forwarded.
      if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
        reshape = rewriter.create<ReshapeOp>(
            loc, denseTp, op.getSrc(), op.getReassociation(),
            op.getOutputShape(), op.getStaticOutputShape());
      } else {
        reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
                                             op.getReassociation());
      }
      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
      rewriter.replaceOp(op, convert);
      return success();
    }
    return failure();
  }
};

// A trivial wrapper to help generate different operations for dense/sparse
// tensors.
struct TensorLike {
  // Materializes the destination buffer: a bufferization.alloc_tensor with
  // the given dynamic sizes. Dense destinations are additionally
  // zero-initialized with linalg.fill; sparse destinations start out empty.
  TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
             ValueRange sizes) {
    SmallVector<Value> dynSzs;
    getDynamicSizes(rtt, sizes, dynSzs);

    val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
    if (!isSparse()) {
      Value c0 = constantZero(builder, loc, rtt.getElementType());
      val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
    }
  }

  // Inserts value `v` at coordinates `crds`, threading the SSA chain by
  // replacing `val` with the insertion result.
  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
    val = builder.create<tensor::InsertOp>(loc, v, val, crds);
  }

  // Finishes the insertion chain: sparse tensors need a final
  // sparse_tensor.load (with hasInserts=true); dense ones are returned as-is.
  Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
    if (isSparse())
      return builder.create<LoadOp>(loc, val, true);
    return val;
  }

  bool isSparse() const {
    return getSparseTensorEncoding(val.getType()) != nullptr;
  }

  // Current SSA value of the tensor being built (updated on every insert).
  Value val;
};

/// Rewrites tensor.dim on a sparse tensor into level-size queries
/// (sparse_tensor.lvl), translating through the encoding's lvl2dim map when
/// the encoding is not a simple permutation.
struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::DimOp op,
                                PatternRewriter &rewriter) const override {
    // Only constant dim indices on encoded sparse tensors are handled.
    std::optional<int64_t> dim = op.getConstantIndex();
    auto stt = tryGetSparseTensorType(op.getSource());
    if (!dim || !stt || !stt->hasEncoding())
      return failure();

    // Permutation maps: dimension size equals the size of the mapped level.
    if (stt->isPermutation()) {
      rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
                                         toLvl(stt->getEncoding(), *dim));
      return success();
    }

    // Non-permutation dim2lvl/lvl2dim maps.
    // Compute as follows:
    //   affine.apply #map (l0 - 1, l1 - 1, ...) + 1
    // Note that it is not the most efficient way (but a more general one) for
    // the lvl to dim translation, e.g., for BSR, the dimension size for can be
    // computed simply by lvl_size * block_size.
    Location loc = op.getLoc();
    SmallVector<Value> maxLvlCrds;
    for (Level l = 0; l < stt->getLvlRank(); l++) {
      Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
      Value maxLvlCrd = rewriter.create<arith::SubIOp>(
          loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
      maxLvlCrds.push_back(maxLvlCrd);
    }

    // Apply the lvl2dim expression for the queried dimension to the maximal
    // level coordinates, then add one to turn a coordinate back into a size.
    AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
    Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
        op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
        maxLvlCrds);

    Value dimSz = rewriter.create<arith::AddIOp>(
        loc, maxDimCrd, constantOne(rewriter, loc, rewriter.getIndexType()));
    rewriter.replaceOp(op, dimSz);
1177c780352dSPeiming Liu return success(); 1178c780352dSPeiming Liu } 1179c780352dSPeiming Liu }; 1180c780352dSPeiming Liu 1181761c9dd9SPeiming Liu struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> { 1182761c9dd9SPeiming Liu using OpRewritePattern::OpRewritePattern; 1183761c9dd9SPeiming Liu LogicalResult matchAndRewrite(ConcatenateOp op, 1184761c9dd9SPeiming Liu PatternRewriter &rewriter) const override { 1185761c9dd9SPeiming Liu if (op.needsExtraSort()) 1186761c9dd9SPeiming Liu op.emitError("ConcatenateOp not staged"); 1187761c9dd9SPeiming Liu 1188761c9dd9SPeiming Liu const Location loc = op.getLoc(); 1189761c9dd9SPeiming Liu const auto dstTp = getSparseTensorType(op); 1190761c9dd9SPeiming Liu const Dimension conDim = op.getDimension(); 1191761c9dd9SPeiming Liu SmallVector<Value> sizes; 1192761c9dd9SPeiming Liu concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim); 1193761c9dd9SPeiming Liu 1194761c9dd9SPeiming Liu // %t = concatenate %s1, %s2, %s3 {dim = 1} 1195761c9dd9SPeiming Liu // ==> 1196761c9dd9SPeiming Liu // if (isSparseDst) 1197761c9dd9SPeiming Liu // if (allDense) 1198761c9dd9SPeiming Liu // %tmp = bufferization.alloc_tensor dstTp 1199761c9dd9SPeiming Liu // else 1200761c9dd9SPeiming Liu // %tmp = bufferization.alloc_tensor : unordered COO 1201761c9dd9SPeiming Liu // else 1202761c9dd9SPeiming Liu // %tmp = memref.alloc : dense tensor 1203761c9dd9SPeiming Liu // foreach in %s1 : insert d0, d1, %tmp 1204761c9dd9SPeiming Liu // foreach in %s2 : insert d0, d1 + size(s1), %tmp 1205761c9dd9SPeiming Liu // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp 1206761c9dd9SPeiming Liu 1207761c9dd9SPeiming Liu TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes); 1208761c9dd9SPeiming Liu Value offset = constantIndex(rewriter, loc, 0); 120971c97c73SPeiming Liu Value iterArg = dstBuf.val; 1210761c9dd9SPeiming Liu 1211761c9dd9SPeiming Liu ForeachOp foreachOp; 1212761c9dd9SPeiming Liu for (Value input : 
op.getInputs()) { 1213761c9dd9SPeiming Liu // Builds a for op for each input tensor to append new values into the 1214761c9dd9SPeiming Liu // output tensor. 1215761c9dd9SPeiming Liu foreachOp = rewriter.create<ForeachOp>( 121671c97c73SPeiming Liu loc, input, iterArg, 1217761c9dd9SPeiming Liu [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v, 1218761c9dd9SPeiming Liu ValueRange reduc) { 1219ef100c22SPeiming Liu SmallVector<Value> offDimCrd(dcvs); 1220ef100c22SPeiming Liu offDimCrd[conDim] = 1221ef100c22SPeiming Liu builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset); 1222ef100c22SPeiming Liu 122371c97c73SPeiming Liu // Enters foreach, updates the SSA chain. 122471c97c73SPeiming Liu dstBuf.val = reduc.front(); 1225761c9dd9SPeiming Liu if (!dstTp.isAllDense()) { 1226761c9dd9SPeiming Liu Value cond = genIsNonzero(builder, loc, v); 1227761c9dd9SPeiming Liu auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond, 1228761c9dd9SPeiming Liu /*else*/ true); 1229761c9dd9SPeiming Liu builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 123071c97c73SPeiming Liu builder.create<scf::YieldOp>(loc, dstBuf.val); 1231761c9dd9SPeiming Liu 1232761c9dd9SPeiming Liu builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 1233ef100c22SPeiming Liu dstBuf.insert(builder, loc, v, offDimCrd); 123471c97c73SPeiming Liu builder.create<scf::YieldOp>(loc, dstBuf.val); 1235761c9dd9SPeiming Liu 1236761c9dd9SPeiming Liu // Exits the ifOp, update the sparse tensor SSA value. 1237761c9dd9SPeiming Liu builder.setInsertionPointAfter(ifOp); 123871c97c73SPeiming Liu dstBuf.val = ifOp.getResult(0); 1239761c9dd9SPeiming Liu } else { 1240ef100c22SPeiming Liu dstBuf.insert(builder, loc, v, offDimCrd); 1241761c9dd9SPeiming Liu } 124271c97c73SPeiming Liu builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val); 1243761c9dd9SPeiming Liu }); 1244761c9dd9SPeiming Liu // Accumulates the offset. 
Note that only static-shaped inputs are allowed 1245761c9dd9SPeiming Liu // by concatenate op verifier, which saves us from computing the offset 1246761c9dd9SPeiming Liu // dynamically. 124722212ca7SAart Bik const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim); 124822212ca7SAart Bik assert(!ShapedType::isDynamic(sz)); 124922212ca7SAart Bik offset = rewriter.create<arith::AddIOp>(loc, offset, 125022212ca7SAart Bik constantIndex(rewriter, loc, sz)); 1251761c9dd9SPeiming Liu iterArg = foreachOp.getResult(0); 125271c97c73SPeiming Liu dstBuf.val = iterArg; 1253761c9dd9SPeiming Liu } 1254761c9dd9SPeiming Liu 125571c97c73SPeiming Liu dstBuf.val = iterArg; 1256761c9dd9SPeiming Liu Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType()); 1257761c9dd9SPeiming Liu rewriter.replaceOp(op, ret); 1258761c9dd9SPeiming Liu return success(); 1259761c9dd9SPeiming Liu } 1260761c9dd9SPeiming Liu }; 1261761c9dd9SPeiming Liu 1262dda3dc5eSPeiming Liu struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> { 1263eb877006Sbixia1 using OpRewritePattern::OpRewritePattern; 1264eb877006Sbixia1 LogicalResult matchAndRewrite(ConvertOp op, 1265eb877006Sbixia1 PatternRewriter &rewriter) const override { 1266761c9dd9SPeiming Liu if (op.needsExtraSort()) 1267f248d0b2SPeiming Liu return op.emitError("ConvertOp not staged."); 1268dda3dc5eSPeiming Liu 1269dda3dc5eSPeiming Liu // TODO: Maybe we want a different operation for this too. 1270eb877006Sbixia1 auto encDst = getSparseTensorEncoding(op.getType()); 1271eb877006Sbixia1 auto encSrc = getSparseTensorEncoding(op.getSource().getType()); 127233267f40SPeiming Liu if (encDst && encSrc && !encSrc.isSlice() && 127385dbb3fcSPeiming Liu encSrc.withoutBitWidths() == encDst.withoutBitWidths()) { 127485dbb3fcSPeiming Liu // Trivial tensor conversion and simple element type conversion is handled 127585dbb3fcSPeiming Liu // in codegen. 
1276eb877006Sbixia1 return failure(); 1277eb877006Sbixia1 } 1278eb877006Sbixia1 1279eb877006Sbixia1 Location loc = op.getLoc(); 1280eb877006Sbixia1 Value src = op.getSource(); 1281dda3dc5eSPeiming Liu 1282dda3dc5eSPeiming Liu SparseTensorType srcStt = getSparseTensorType(op.getSource()); 1283dda3dc5eSPeiming Liu SparseTensorType dstStt = getSparseTensorType(op.getDest()); 1284eb877006Sbixia1 1285e6cbb914SAart Bik bool fromSparseConst = false; 1286dda3dc5eSPeiming Liu if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>()) 1287dda3dc5eSPeiming Liu if (dyn_cast<SparseElementsAttr>(constOp.getValue())) 1288e6cbb914SAart Bik fromSparseConst = true; 1289e6cbb914SAart Bik 129076647fceSwren romano const AffineMapAttr foreachOrder = 1291dda3dc5eSPeiming Liu (!dstStt.isIdentity() && fromSparseConst) 1292dda3dc5eSPeiming Liu ? AffineMapAttr::get(dstStt.getExpandedDimToLvl()) 129376647fceSwren romano : nullptr; 129441089f86SPeiming Liu 1295dda3dc5eSPeiming Liu bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst; 1296eb877006Sbixia1 12970e1708ffSAart Bik SmallVector<Value> sizes; 1298dda3dc5eSPeiming Liu sizesFromSrc(rewriter, sizes, loc, src); 1299dda3dc5eSPeiming Liu ValueRange vs; 1300dda3dc5eSPeiming Liu TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes); 1301a61a9a70SAart Bik 1302dda3dc5eSPeiming Liu auto foreachOp = rewriter.create<ForeachOp>( 130371c97c73SPeiming Liu loc, src, dstBuf.val, foreachOrder, 1304dda3dc5eSPeiming Liu [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v, 1305dda3dc5eSPeiming Liu ValueRange reduc) { 1306dda3dc5eSPeiming Liu // Enters the loop, update the SSA value for insertion chain. 
130771c97c73SPeiming Liu dstBuf.val = reduc.front(); 1308dda3dc5eSPeiming Liu if (!skipZeroCheck) { 1309dda3dc5eSPeiming Liu Value cond = genIsNonzero(builder, loc, v); 1310dda3dc5eSPeiming Liu auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond, 1311dda3dc5eSPeiming Liu /*else*/ true); 1312dda3dc5eSPeiming Liu builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 131371c97c73SPeiming Liu builder.create<scf::YieldOp>(loc, dstBuf.val); 1314dda3dc5eSPeiming Liu 1315dda3dc5eSPeiming Liu builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 1316ef100c22SPeiming Liu dstBuf.insert(builder, loc, v, dcvs); 131771c97c73SPeiming Liu builder.create<scf::YieldOp>(loc, dstBuf.val); 1318dda3dc5eSPeiming Liu 1319dda3dc5eSPeiming Liu // Exits the ifOp, update the sparse tensor SSA value. 1320dda3dc5eSPeiming Liu builder.setInsertionPointAfter(ifOp); 132171c97c73SPeiming Liu dstBuf.val = ifOp.getResult(0); 1322dda3dc5eSPeiming Liu } else { 1323ef100c22SPeiming Liu dstBuf.insert(builder, loc, v, dcvs); 1324dda3dc5eSPeiming Liu } 132571c97c73SPeiming Liu builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val); 1326eb877006Sbixia1 }); 1327eb877006Sbixia1 1328dda3dc5eSPeiming Liu rewriter.setInsertionPointAfter(foreachOp); 1329eb877006Sbixia1 1330dda3dc5eSPeiming Liu // Exits the for loop, links the SSA chain. 
133171c97c73SPeiming Liu dstBuf.val = foreachOp.getResult(0); 1332eb877006Sbixia1 1333dda3dc5eSPeiming Liu Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType()); 1334dda3dc5eSPeiming Liu rewriter.replaceOp(op, ret); 1335eb877006Sbixia1 return success(); 1336eb877006Sbixia1 } 1337eb877006Sbixia1 }; 1338eb877006Sbixia1 13393426d330SPeiming Liu struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> { 13403426d330SPeiming Liu using OpRewritePattern::OpRewritePattern; 13413426d330SPeiming Liu LogicalResult matchAndRewrite(CrdTranslateOp op, 13423426d330SPeiming Liu PatternRewriter &rewriter) const override { 13433426d330SPeiming Liu AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl 13443426d330SPeiming Liu ? op.getEncoder().getDimToLvl() 13453426d330SPeiming Liu : op.getEncoder().getLvlToDim(); 13463426d330SPeiming Liu 13473426d330SPeiming Liu SmallVector<Value> outCrds; 13483426d330SPeiming Liu for (AffineExpr result : map.getResults()) { 13493426d330SPeiming Liu // TODO: we should probably expand the affine map to IR using our own 13503426d330SPeiming Liu // rules, since affine.apply assume signed value, while the cooridinates 13513426d330SPeiming Liu // we provided must always be signless. 13523426d330SPeiming Liu Value trans = rewriter.create<affine::AffineApplyOp>( 13533426d330SPeiming Liu op.getLoc(), AffineMap::get(map.getNumDims(), 0, result), 13543426d330SPeiming Liu op.getInCrds()); 13553426d330SPeiming Liu outCrds.push_back(trans); 13563426d330SPeiming Liu } 13573426d330SPeiming Liu rewriter.replaceOp(op, outCrds); 13583426d330SPeiming Liu return success(); 13593426d330SPeiming Liu } 13603426d330SPeiming Liu }; 13613426d330SPeiming Liu 1362550288cbSPeiming Liu /// Sparse rewriting rule for the foreach operator. 
struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ForeachOp op,
                                PatternRewriter &rewriter) const override {

    auto loc = op.getLoc();
    Value input = op.getTensor();
    SmallVector<Value> reduc = op.getInitArgs();
    const auto stt = getSparseTensorType(input);
    const Level lvlRank = stt.getLvlRank();

    // Special-case: for each over a sparse constant uses its own rewriting
    // rule.
    if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
        return genForeachOnSparseConstant(op, rewriter, attr);
      }
    }

    // Otherwise, use loop emitter to generate loops.
    const auto enc = stt.getEncoding();

    // 1. Generates loop for the sparse input: one loop sequence (with a
    // single co-iteration over this tensor) per storage level.
    LoopEmitter loopEmitter(
        ValueRange{input},
        StringAttr::get(getContext(), ForeachOp::getOperationName()));
    loopEmitter.initializeLoopEmit(rewriter, loc);
    for (Level l = 0; l < lvlRank; l++) {
      // TODO: provide utility function for loop sequences that only contains
      // one for loop?
      const SmallVector<TensorLevel, 1> tidLvls{
          loopEmitter.makeTensorLevel(0, l)};
      loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
      // Note that reduc will be taken care of by loop emitter and get updated
      // in place.
      loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
                                                    reduc);
    }

    SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
    if (op.getOrder()) {
      // TODO: Support it so that we can do direct conversion from CSR->BSR.
      llvm_unreachable(
          "Level order not yet implemented on non-constant input tensors.");
    }

    Value vals = loopEmitter.getValBuffer()[0];
    SmallVector<Value> pos = loopEmitter.getValPosits(0);
    // Loads the value from sparse tensor using position-index;
    // loads the value from dense tensor using coords.
    Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
                    : rewriter.create<memref::LoadOp>(loc, vals, lcvs);

    // 2. Inline the block in the foreach operator.
    Block *srcBlock = op.getBody();

    // Remap coordinates: translate level coordinates back to dimension
    // coordinates for the user-visible block arguments.
    // NOTE(review): `enc` is used here even on the dense path above — assumes
    // translateCrds handles a null encoding; confirm against its definition.
    SmallVector<Value> args =
        enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);

    // Remap value.
    args.push_back(val);
    // Remap reduction variables.
    args.append(reduc);

    // Remove sparse_tensor.yield (its operands become the reduction results).
    SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
    rewriter.eraseOp(srcBlock->getTerminator());

    Operation &last = rewriter.getBlock()->back();
    if (llvm::isa<scf::YieldOp>(last)) {
      // Because `scf.for` inserts an implicit yield op when there is no
      // reduction variable upon creation, we reset the insertion point such
      // that the block is inlined *before* the yield op.
      rewriter.setInsertionPoint(&last);
    }

    rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
                               rewriter.getInsertionPoint(), args);
    rewriter.setInsertionPointToEnd(rewriter.getBlock());
    for (Level l = 0; l < lvlRank; l++) {
      // Link the reduction chain. Note that loop emitter updates the
      // reducValue in place.
      loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
      loopEmitter.exitCurrentLoopSeq(rewriter, loc);
    }

    // Replace the foreach operator with the value returned by the outermost
    // for loop.
    rewriter.replaceOp(op, reducValue);
    return success();
  }
};

/// Sparse rewriting rule for the new operator.
struct NewRewriter : public OpRewritePattern<NewOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(NewOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    // Only applies when the result is encoded and is not already an
    // AoS COO starting at level 0 (which the runtime can read directly).
    if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
      return failure();

    // Implement the NewOp as follows:
    //   %orderedCoo = sparse_tensor.new %filename
    //   %t = sparse_tensor.convert %orderedCoo
    // with enveloping reinterpreted_map ops for non-permutations.
    RankedTensorType dstTp = stt.getRankedTensorType();
    RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
    Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
    Value convert = cooTensor;
    auto enc = stt.getEncoding();
    if (!stt.isPermutation()) { // demap coo, demap dstTp
      auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
      convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
      dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
    }
    convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
    if (!stt.isPermutation()) // remap to original enc
      convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
    rewriter.replaceOp(op, convert);

    // Release the temporary ordered COO tensor.
    rewriter.setInsertionPointAfterValue(convert);
    rewriter.create<DeallocTensorOp>(loc, cooTensor);

    return success();
  }
};

/// Sparse rewriting rule for the out operator: lowers sparse_tensor.out into
/// runtime writer calls (create/meta-data/next/delete) driven by a foreach
/// loop over the source tensor.
struct OutRewriter : public OpRewritePattern<OutOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(OutOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Calculate NNZ.
    Value src = op.getTensor();
    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);

    // Allocate a temporary buffer for storing dimension-sizes/coordinates.
    const auto srcTp = getSparseTensorType(src);
    const Dimension dimRank = srcTp.getDimRank();
    Type indexTp = rewriter.getIndexType();
    Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);

    // Generate code to calculate dimension size values and store the values to
    // the buffer.
    SmallVector<Value> dims;
    sizesForTensor(rewriter, dims, loc, srcTp, src);
    for (Dimension d = 0; d < dimRank; d++) {
      rewriter.create<memref::StoreOp>(loc, dims[d], dimSizes,
                                       constantIndex(rewriter, loc, d));
    }

    // Create a sparse tensor writer and output meta data.
    Type opaqueTp = getOpaquePointerType(rewriter);
    Value writer =
        createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
                       {op.getDest()}, EmitCInterface::Off)
            .getResult(0);
    Value rankValue = constantIndex(rewriter, loc, dimRank);
    createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
                   {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);

    Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords.
    Type eltTp = srcTp.getElementType();
    SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
                                    primaryTypeFunctionSuffix(eltTp)};
    Value value = genAllocaScalar(rewriter, loc, eltTp);
    ModuleOp module = op->getParentOfType<ModuleOp>();

    // For each element in the source tensor, output the element.
    rewriter.create<ForeachOp>(
        loc, src, std::nullopt,
        [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
            ValueRange reduc) {
          // Store the element's coordinates and value into the scratch
          // buffers, then call the runtime "writer next" entry point.
          for (Dimension d = 0; d < dimRank; d++) {
            rewriter.create<memref::StoreOp>(loc, dcvs[d], dimCoords,
                                             constantIndex(builder, loc, d));
          }
          rewriter.create<memref::StoreOp>(loc, v, value);
          SmallVector<Value> operands{writer, rankValue, dimCoords, value};
          FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
                                         EmitCInterface::On);
          builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
          builder.create<sparse_tensor::YieldOp>(loc);
        });

    // Release the writer.
    createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
                   EmitCInterface::Off);

    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
156828ebb0b6SAart Bik //===---------------------------------------------------------------------===// 1569f81f0cb7Sbixia1 1570f81f0cb7Sbixia1 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) { 15713aeb28b9SPeiming Liu patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer, 15723aeb28b9SPeiming Liu FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast, 15733aeb28b9SPeiming Liu GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>( 15743aeb28b9SPeiming Liu patterns.getContext()); 1575f81f0cb7Sbixia1 } 1576f81f0cb7Sbixia1 1577f82bee13SPeiming Liu void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns, 1578f81f0cb7Sbixia1 bool enableRT, 1579eb877006Sbixia1 bool enableConvert) { 1580ef100c22SPeiming Liu patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>, 1581bc878f70SPeiming Liu ReshapeRewriter<tensor::CollapseShapeOp>, 1582bc878f70SPeiming Liu Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>, 1583bc878f70SPeiming Liu Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>, 15847d608ee2SPeiming Liu SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>( 1585c780352dSPeiming Liu patterns.getContext()); 1586f82bee13SPeiming Liu 1587eb877006Sbixia1 if (enableConvert) 1588dda3dc5eSPeiming Liu patterns.add<DirectConvertRewriter>(patterns.getContext()); 1589f248d0b2SPeiming Liu if (!enableRT) 15907d608ee2SPeiming Liu patterns.add<NewRewriter>(patterns.getContext()); 159128ebb0b6SAart Bik } 1592f82bee13SPeiming Liu 1593f82bee13SPeiming Liu void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) { 15943426d330SPeiming Liu // Run CrdTranslateRewriter later in the pipeline so that operation can be 15953426d330SPeiming Liu // folded before lowering to affine.apply 15963426d330SPeiming Liu patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext()); 1597f82bee13SPeiming Liu } 1598