1 //===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Affine/IR/AffineOps.h" 10 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 11 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" 12 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 13 #include "mlir/Dialect/Tensor/IR/Tensor.h" 14 #include "mlir/IR/AffineMap.h" 15 16 using namespace mlir; 17 using namespace mlir::sparse_tensor; 18 19 namespace { 20 21 // TODO: 22 // (1) insert the zero-cost sparse_tensor.reinterpret_map ops 23 // (2) rewrite linalg.generic ops traits on level crds 24 // (3) compute topsort, and resolve cyles with sparse_tensor.convert ops 25 26 // CRTP to help implementing a rewriter that demaps all its inputs and remaps 27 // all its outputs. 28 template <typename SubClass, typename SourceOp> 29 struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> { 30 using OpRewritePattern<SourceOp>::OpRewritePattern; 31 using OpAdaptor = typename SourceOp::Adaptor; 32 33 LogicalResult matchAndRewrite(SourceOp op, 34 PatternRewriter &rewriter) const override { 35 if (!static_cast<const SubClass *>(this)->matchOp(op)) 36 return failure(); 37 38 Location loc = op.getLoc(); 39 // Demaps non-trivial inputs. 40 SmallVector<Value> deMappedIns(op->getOperands()); 41 for (Value &in : deMappedIns) 42 if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) 43 in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in); 44 45 // CRTP call. 46 OpAdaptor adaptor(deMappedIns); 47 ValueRange outs = 48 static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter); 49 assert(outs.size() == op->getResults().size()); 50 51 // Remap outputs. 52 SmallVector<Value> reMappedOuts(outs); 53 for (auto [r, a] : llvm::zip(reMappedOuts, op->getResults())) 54 if (r.getType() != a.getType()) 55 r = rewriter.create<ReinterpretMapOp>(loc, a.getType(), r); 56 57 rewriter.replaceOp(op, reMappedOuts); 58 return success(); 59 } 60 }; 61 62 //===----------------------------------------------------------------------===// 63 // Reinterpret Map Rewriters for operations other than linalg.generics 64 //===----------------------------------------------------------------------===// 65 66 struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> { 67 using OpRewritePattern::OpRewritePattern; 68 LogicalResult matchAndRewrite(CrdTranslateOp op, 69 PatternRewriter &rewriter) const override { 70 AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl 71 ? op.getEncoder().getDimToLvl() 72 : op.getEncoder().getLvlToDim(); 73 74 SmallVector<Value> outCrds; 75 for (AffineExpr result : map.getResults()) { 76 // TODO: we should probably expand the affine map to IR using our own 77 // rules, since affine.apply assume signed value, while the cooridinates 78 // we provided must always be signless. 79 Value trans = rewriter.create<affine::AffineApplyOp>( 80 op.getLoc(), AffineMap::get(map.getNumDims(), 0, result), 81 op.getInCrds()); 82 outCrds.push_back(trans); 83 } 84 rewriter.replaceOp(op, outCrds); 85 return success(); 86 } 87 }; 88 89 struct TensorInsertRewriter 90 : public DemapInsRemapOutsRewriter<TensorInsertRewriter, tensor::InsertOp> { 91 using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter; 92 93 bool matchOp(tensor::InsertOp op) const { 94 return op.getResult().getType().getEncoding() != nullptr; 95 } 96 97 ValueRange rewriteOp(tensor::InsertOp op, OpAdaptor adaptor, 98 PatternRewriter &rewriter) const { 99 Location loc = op.getLoc(); 100 auto stt = getSparseTensorType(op.getResult()); 101 ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(), 102 CrdTransDirectionKind::dim2lvl); 103 Operation *insertOp = rewriter.create<sparse_tensor::InsertOp>( 104 loc, op.getScalar(), adaptor.getDest(), lvlCrd); 105 return insertOp->getResults(); 106 } 107 }; 108 109 } // namespace 110 111 void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns, 112 ReinterpretMapScope scope) { 113 if (scope == ReinterpretMapScope::kAll || 114 scope == ReinterpretMapScope::kExceptGeneric) { 115 patterns.add<CrdTranslateRewriter, TensorInsertRewriter>( 116 patterns.getContext()); 117 } 118 } 119