121895486Swren romano //===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===// 2a2c9d4bbSAart Bik // 3a2c9d4bbSAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4a2c9d4bbSAart Bik // See https://llvm.org/LICENSE.txt for license information. 5a2c9d4bbSAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6a2c9d4bbSAart Bik // 7a2c9d4bbSAart Bik //===----------------------------------------------------------------------===// 8a2c9d4bbSAart Bik // 986b22d31SAart Bik // A pass that converts sparse tensor primitives into calls into a runtime 1086b22d31SAart Bik // support library. Sparse tensor types are converted into opaque pointers 1186b22d31SAart Bik // to the underlying sparse storage schemes. The use of opaque pointers 1286b22d31SAart Bik // together with runtime support library keeps the conversion relatively 1386b22d31SAart Bik // simple, but at the expense of IR opacity, which obscures opportunities 1486b22d31SAart Bik // for subsequent optimization of the IR. An alternative is provided by 1586b22d31SAart Bik // the SparseTensorCodegen pass. 
16a2c9d4bbSAart Bik // 17a2c9d4bbSAart Bik //===----------------------------------------------------------------------===// 18a2c9d4bbSAart Bik 19365777ecSAart Bik #include "Utils/CodegenUtils.h" 20efa15f41SAart Bik 21c66303c2SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 2257470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 23236a9080SAart Bik #include "mlir/Dialect/Linalg/Utils/Utils.h" 24a2c9d4bbSAart Bik #include "mlir/Dialect/MemRef/IR/MemRef.h" 258b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 26062f29c8Swren romano #include "mlir/Dialect/SparseTensor/IR/Enums.h" 27a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 28f708a549Swren romano #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" 29a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 30ca5d0a73SAart Bik #include "mlir/Dialect/Tensor/IR/Tensor.h" 31a2c9d4bbSAart Bik #include "mlir/Transforms/DialectConversion.h" 32a2c9d4bbSAart Bik 33a2c9d4bbSAart Bik using namespace mlir; 3496a23911SAart Bik using namespace mlir::sparse_tensor; 35a2c9d4bbSAart Bik 36a2c9d4bbSAart Bik namespace { 37a2c9d4bbSAart Bik 3805c7f450SAart Bik //===----------------------------------------------------------------------===// 3905c7f450SAart Bik // Helper methods. 4005c7f450SAart Bik //===----------------------------------------------------------------------===// 4105c7f450SAart Bik 4286b22d31SAart Bik /// Maps each sparse tensor type to an opaque pointer. 430de16fafSRamkumar Ramachandra static std::optional<Type> convertSparseTensorTypes(Type type) { 4486b22d31SAart Bik if (getSparseTensorEncoding(type) != nullptr) 45dcae289dSChristian Ulmann return LLVM::LLVMPointerType::get(type.getContext()); 461a36588eSKazu Hirata return std::nullopt; 4786b22d31SAart Bik } 4886b22d31SAart Bik 4986f91e45Swren romano /// Generates call to lookup a level-size. 
N.B., this only generates 5086f91e45Swren romano /// the raw function call, and therefore (intentionally) does not perform 5186f91e45Swren romano /// any dim<->lvl conversion or other logic. 5286f91e45Swren romano static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor, 53c518745bSwren romano uint64_t lvl) { 54c518745bSwren romano StringRef name = "sparseLvlSize"; 5586f91e45Swren romano SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)}; 56e9fa5590SMatthias Springer Type iTp = builder.getIndexType(); 57ee986ab7SPeiming Liu return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off) 58d8731bfcSwren romano .getResult(0); 599d1db3d4SAart Bik } 609d1db3d4SAart Bik 6186f91e45Swren romano /// Generates call to lookup a dimension-size. N.B., this only generates 6286f91e45Swren romano /// the raw function call, and therefore (intentionally) does not perform 6386f91e45Swren romano /// any dim<->lvl conversion or other logic. 6486f91e45Swren romano static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor, 6586f91e45Swren romano uint64_t dim) { 6686f91e45Swren romano StringRef name = "sparseDimSize"; 6786f91e45Swren romano SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)}; 6886f91e45Swren romano Type iTp = builder.getIndexType(); 6986f91e45Swren romano return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off) 7086f91e45Swren romano .getResult(0); 7186f91e45Swren romano } 7286f91e45Swren romano 7386f91e45Swren romano /// Looks up a level-size by returning a statically-computed constant 7486f91e45Swren romano /// (when possible), or by calling `genLvlSizeCall` (when dynamic). 7586f91e45Swren romano static Value createOrFoldLvlCall(OpBuilder &builder, Location loc, 76f708a549Swren romano SparseTensorType stt, Value tensor, 77f708a549Swren romano Level lvl) { 7886f91e45Swren romano // Only sparse tensors have "levels" to query. 
79f708a549Swren romano assert(stt.hasEncoding()); 8086f91e45Swren romano // TODO: The following implementation only handles permutations; 8186f91e45Swren romano // we'll need to generalize this to handle arbitrary AffineExpr. 8286f91e45Swren romano // 8386f91e45Swren romano // There's no need to assert `isPermutation` here: because 8486f91e45Swren romano // `getDimPosition` checks that the expr isa `AffineDimExpr`, 8586f91e45Swren romano // which is all we care about (for supporting permutations). 86f708a549Swren romano const Dimension dim = 8776647fceSwren romano stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl); 8822212ca7SAart Bik const Size sz = stt.getDynamicDimSize(dim); 8922212ca7SAart Bik if (!ShapedType::isDynamic(sz)) 9022212ca7SAart Bik return constantIndex(builder, loc, sz); 9186f91e45Swren romano // If we cannot statically compute the size from the shape, then we 9286f91e45Swren romano // must dynamically query it. (In principle we could also dynamically 9386f91e45Swren romano // compute it, but since we already did so to construct the `tensor` 9486f91e45Swren romano // in the first place, we might as well query rather than recompute.) 9586f91e45Swren romano return genLvlSizeCall(builder, loc, tensor, lvl); 96c248219bSPeiming Liu } 97c248219bSPeiming Liu 9886f91e45Swren romano /// Looks up a dimension-size by returning a constant from the shape 9986f91e45Swren romano /// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes 10086f91e45Swren romano /// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes 10186f91e45Swren romano /// of dense tensors). 
10286f91e45Swren romano static Value createOrFoldDimCall(OpBuilder &builder, Location loc, 103f708a549Swren romano SparseTensorType stt, Value tensor, 104f708a549Swren romano Dimension dim) { 10522212ca7SAart Bik const Size sz = stt.getDynamicDimSize(dim); 10622212ca7SAart Bik if (!ShapedType::isDynamic(sz)) 10722212ca7SAart Bik return constantIndex(builder, loc, sz); 108f708a549Swren romano if (stt.hasEncoding()) 10986f91e45Swren romano return genDimSizeCall(builder, loc, tensor, dim); 11086f91e45Swren romano return linalg::createOrFoldDimOp(builder, loc, tensor, dim); 111c248219bSPeiming Liu } 112c248219bSPeiming Liu 11386f91e45Swren romano /// Populates the array with the dimension-sizes of the given tensor. 114f708a549Swren romano static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt, 11586f91e45Swren romano Value tensor, SmallVectorImpl<Value> &out) { 116f708a549Swren romano const Dimension dimRank = stt.getDimRank(); 117f708a549Swren romano out.clear(); 11886f91e45Swren romano out.reserve(dimRank); 119f708a549Swren romano for (Dimension d = 0; d < dimRank; d++) 120f708a549Swren romano out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d)); 1219d1db3d4SAart Bik } 12286f91e45Swren romano 12386f91e45Swren romano /// Returns an array with the dimension-sizes of the given tensor. 124fa6726e2SPeiming Liu /// If the *tensor* parameters is null, the tensor type is assumed to have a 125fa6726e2SPeiming Liu /// static shape. 
12686f91e45Swren romano static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc, 127fa6726e2SPeiming Liu SparseTensorType stt, 128fa6726e2SPeiming Liu Value tensor = Value()) { 12986f91e45Swren romano SmallVector<Value> out; 130f708a549Swren romano fillDimSizes(builder, loc, stt, tensor, out); 13186f91e45Swren romano return out; 13286f91e45Swren romano } 13386f91e45Swren romano 1340b55f94dSAart Bik /// Generates an uninitialized buffer of the given size and type, 1350b55f94dSAart Bik /// but returns it as type `memref<? x $tp>` (rather than as type 1360b55f94dSAart Bik /// `memref<$sz x $tp>`). Unlike temporary buffers on the stack, 1370b55f94dSAart Bik /// this buffer must be explicitly deallocated by client. 138e9fa5590SMatthias Springer static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) { 139399638f9SAliia Khasanova auto memTp = MemRefType::get({ShapedType::kDynamic}, tp); 1400b55f94dSAart Bik return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz}); 1410b55f94dSAart Bik } 1420b55f94dSAart Bik 1432af2e4dbSwren romano /// Generates a temporary buffer for the level-types of the given encoding. 1442af2e4dbSwren romano static Value genLvlTypesBuffer(OpBuilder &builder, Location loc, 145f708a549Swren romano SparseTensorType stt) { 1462af2e4dbSwren romano SmallVector<Value> lvlTypes; 147f708a549Swren romano lvlTypes.reserve(stt.getLvlRank()); 1481dd387e1SAart Bik for (const auto lt : stt.getEncoding().getLvlTypes()) 1491944c4f7SAart Bik lvlTypes.push_back(constantLevelTypeEncoding(builder, loc, lt)); 15062896428Swren romano return allocaBuffer(builder, loc, lvlTypes); 1512af2e4dbSwren romano } 1522af2e4dbSwren romano 153fa6726e2SPeiming Liu /// Extracts the bare (aligned) pointers that point to the tensor. 
154fa6726e2SPeiming Liu static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc, 155fa6726e2SPeiming Liu Value tensor) { 156fa6726e2SPeiming Liu auto buf = genToMemref(builder, loc, tensor); 157fa6726e2SPeiming Liu return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf); 158fa6726e2SPeiming Liu } 159fa6726e2SPeiming Liu 160fa6726e2SPeiming Liu /// Generates a temporary buffer for the level-types of the given encoding. 161fa6726e2SPeiming Liu static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc, 162fa6726e2SPeiming Liu ValueRange lvlTensors, Value valTensor) { 163fa6726e2SPeiming Liu SmallVector<Value> lvlBarePtrs; 164fa6726e2SPeiming Liu lvlBarePtrs.reserve(lvlTensors.size() + 1); 165fa6726e2SPeiming Liu // Passing in lvl buffer pointers. 166fa6726e2SPeiming Liu for (const auto lvl : lvlTensors) 167fa6726e2SPeiming Liu lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl)); 168fa6726e2SPeiming Liu 169fa6726e2SPeiming Liu // Passing in value buffer pointers. 170fa6726e2SPeiming Liu lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor)); 171fa6726e2SPeiming Liu Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>( 172fa6726e2SPeiming Liu loc, allocaBuffer(builder, loc, lvlBarePtrs)); 173fa6726e2SPeiming Liu Value idxCast = 174fa6726e2SPeiming Liu builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr); 175fa6726e2SPeiming Liu return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder), 176fa6726e2SPeiming Liu idxCast); 177fa6726e2SPeiming Liu } 178fa6726e2SPeiming Liu 179f5ce99afSwren romano /// This class abstracts over the API of `_mlir_ciface_newSparseTensor`: 180f5ce99afSwren romano /// the "swiss army knife" method of the sparse runtime support library 181f5ce99afSwren romano /// for materializing sparse tensors into the computation. This abstraction 182b7188d28SAart Bik /// reduces the need for modifications when the API changes. 
183f5ce99afSwren romano class NewCallParams final { 184f5ce99afSwren romano public: 185b7188d28SAart Bik /// Allocates the `ValueRange` for the `func::CallOp` parameters. 186f5ce99afSwren romano NewCallParams(OpBuilder &builder, Location loc) 187f5ce99afSwren romano : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {} 188f5ce99afSwren romano 189f5ce99afSwren romano /// Initializes all static parameters (i.e., those which indicate 190f5ce99afSwren romano /// type-level information such as the encoding and sizes), generating 191f5ce99afSwren romano /// MLIR buffers as needed, and returning `this` for method chaining. 192b7188d28SAart Bik NewCallParams &genBuffers(SparseTensorType stt, 193d392073fSAart Bik ArrayRef<Value> dimSizesValues, 194d392073fSAart Bik Value dimSizesBuffer = Value()) { 195a942f7c8SAart Bik assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank())); 196b7188d28SAart Bik // Sparsity annotations. 197b7188d28SAart Bik params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt); 198b7188d28SAart Bik // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers. 199d392073fSAart Bik params[kParamDimSizes] = dimSizesBuffer 200d392073fSAart Bik ? dimSizesBuffer 201d392073fSAart Bik : allocaBuffer(builder, loc, dimSizesValues); 2022323f48eSAart Bik SmallVector<Value> lvlSizesValues; // unused 2032323f48eSAart Bik params[kParamLvlSizes] = genMapBuffers( 2042323f48eSAart Bik builder, loc, stt, dimSizesValues, params[kParamDimSizes], 2052323f48eSAart Bik lvlSizesValues, params[kParamDim2Lvl], params[kParamLvl2Dim]); 206b7188d28SAart Bik // Secondary and primary types encoding. 
207f708a549Swren romano const auto enc = stt.getEncoding(); 20884cd51bbSwren romano params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc); 20984cd51bbSwren romano params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc); 210f5ce99afSwren romano params[kParamValTp] = 211f708a549Swren romano constantPrimaryTypeEncoding(builder, loc, stt.getElementType()); 212bbecd422SAart Bik // Return `this` for method chaining. 213f5ce99afSwren romano return *this; 214f5ce99afSwren romano } 215f5ce99afSwren romano 216f5ce99afSwren romano /// Checks whether all the static parameters have been initialized. 217f5ce99afSwren romano bool isInitialized() const { 218f5ce99afSwren romano for (unsigned i = 0; i < kNumStaticParams; ++i) 219f5ce99afSwren romano if (!params[i]) 220f5ce99afSwren romano return false; 221f5ce99afSwren romano return true; 222f5ce99afSwren romano } 223f5ce99afSwren romano 224f5ce99afSwren romano /// Generates a function call, with the current static parameters 225f5ce99afSwren romano /// and the given dynamic arguments. 226f5ce99afSwren romano Value genNewCall(Action action, Value ptr = Value()) { 227f5ce99afSwren romano assert(isInitialized() && "Must initialize before genNewCall"); 228f5ce99afSwren romano StringRef name = "newSparseTensor"; 229f5ce99afSwren romano params[kParamAction] = constantAction(builder, loc, action); 23085175eddSTobias Gysi params[kParamPtr] = ptr ? 
ptr : builder.create<LLVM::ZeroOp>(loc, pTp); 231f5ce99afSwren romano return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On) 232f5ce99afSwren romano .getResult(0); 233f5ce99afSwren romano } 234f5ce99afSwren romano 235f5ce99afSwren romano private: 236c518745bSwren romano static constexpr unsigned kNumStaticParams = 8; 237f5ce99afSwren romano static constexpr unsigned kNumDynamicParams = 2; 238f5ce99afSwren romano static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams; 239c518745bSwren romano static constexpr unsigned kParamDimSizes = 0; 240c518745bSwren romano static constexpr unsigned kParamLvlSizes = 1; 241c518745bSwren romano static constexpr unsigned kParamLvlTypes = 2; 242b7188d28SAart Bik static constexpr unsigned kParamDim2Lvl = 3; 243b7188d28SAart Bik static constexpr unsigned kParamLvl2Dim = 4; 24484cd51bbSwren romano static constexpr unsigned kParamPosTp = 5; 24584cd51bbSwren romano static constexpr unsigned kParamCrdTp = 6; 246c518745bSwren romano static constexpr unsigned kParamValTp = 7; 247c518745bSwren romano static constexpr unsigned kParamAction = 8; 248c518745bSwren romano static constexpr unsigned kParamPtr = 9; 249f5ce99afSwren romano 250f5ce99afSwren romano OpBuilder &builder; 251f5ce99afSwren romano Location loc; 252f5ce99afSwren romano Type pTp; 253f5ce99afSwren romano Value params[kNumParams]; 254f5ce99afSwren romano }; 255f5ce99afSwren romano 2560f3e4d1aSAart Bik /// Generates a call to obtain the values array. 
257af8428c0SAart Bik static Value genValuesCall(OpBuilder &builder, Location loc, 258af8428c0SAart Bik SparseTensorType stt, Value ptr) { 259af8428c0SAart Bik auto eltTp = stt.getElementType(); 260af8428c0SAart Bik auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp); 261af8428c0SAart Bik SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)}; 262af8428c0SAart Bik return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On) 263af8428c0SAart Bik .getResult(0); 264af8428c0SAart Bik } 265af8428c0SAart Bik 266af8428c0SAart Bik /// Generates a call to obtain the positions array. 267af8428c0SAart Bik static Value genPositionsCall(OpBuilder &builder, Location loc, 268af8428c0SAart Bik SparseTensorType stt, Value ptr, Level l) { 269af8428c0SAart Bik Type posTp = stt.getPosType(); 270af8428c0SAart Bik auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp); 271af8428c0SAart Bik Value lvl = constantIndex(builder, loc, l); 272af8428c0SAart Bik SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)}; 273af8428c0SAart Bik return createFuncCall(builder, loc, name, resTp, {ptr, lvl}, 274af8428c0SAart Bik EmitCInterface::On) 275af8428c0SAart Bik .getResult(0); 276af8428c0SAart Bik } 277af8428c0SAart Bik 278dc4cfdbbSAart Bik /// Generates a call to obtain the coordinates array. 
279af8428c0SAart Bik static Value genCoordinatesCall(OpBuilder &builder, Location loc, 280af8428c0SAart Bik SparseTensorType stt, Value ptr, Level l) { 281af8428c0SAart Bik Type crdTp = stt.getCrdType(); 282af8428c0SAart Bik auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp); 283af8428c0SAart Bik Value lvl = constantIndex(builder, loc, l); 284af8428c0SAart Bik SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)}; 285af8428c0SAart Bik return createFuncCall(builder, loc, name, resTp, {ptr, lvl}, 286af8428c0SAart Bik EmitCInterface::On) 2870f3e4d1aSAart Bik .getResult(0); 2880f3e4d1aSAart Bik } 2890f3e4d1aSAart Bik 290dc4cfdbbSAart Bik /// Generates a call to obtain the coordinates array (AoS view). 291dc4cfdbbSAart Bik static Value genCoordinatesBufferCall(OpBuilder &builder, Location loc, 292dc4cfdbbSAart Bik SparseTensorType stt, Value ptr, 293dc4cfdbbSAart Bik Level l) { 294dc4cfdbbSAart Bik Type crdTp = stt.getCrdType(); 295dc4cfdbbSAart Bik auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp); 296dc4cfdbbSAart Bik Value lvl = constantIndex(builder, loc, l); 297dc4cfdbbSAart Bik SmallString<25> name{"sparseCoordinatesBuffer", 298dc4cfdbbSAart Bik overheadTypeFunctionSuffix(crdTp)}; 299dc4cfdbbSAart Bik return createFuncCall(builder, loc, name, resTp, {ptr, lvl}, 300dc4cfdbbSAart Bik EmitCInterface::On) 301dc4cfdbbSAart Bik .getResult(0); 302dc4cfdbbSAart Bik } 303dc4cfdbbSAart Bik 30405c7f450SAart Bik //===----------------------------------------------------------------------===// 30505c7f450SAart Bik // Conversion rules. 30605c7f450SAart Bik //===----------------------------------------------------------------------===// 30705c7f450SAart Bik 30896a23911SAart Bik /// Sparse conversion rule for returns. 
30923aa5a74SRiver Riddle class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> { 31096a23911SAart Bik public: 311a2c9d4bbSAart Bik using OpConversionPattern::OpConversionPattern; 312a2c9d4bbSAart Bik LogicalResult 31323aa5a74SRiver Riddle matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, 314a2c9d4bbSAart Bik ConversionPatternRewriter &rewriter) const override { 31523aa5a74SRiver Riddle rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands()); 316a2c9d4bbSAart Bik return success(); 317a2c9d4bbSAart Bik } 318a2c9d4bbSAart Bik }; 319a2c9d4bbSAart Bik 320c780352dSPeiming Liu /// Sparse conversion rule for accessing level-sizes. 321c780352dSPeiming Liu class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> { 322a2c9d4bbSAart Bik public: 323a2c9d4bbSAart Bik using OpConversionPattern::OpConversionPattern; 324a2c9d4bbSAart Bik LogicalResult 325c780352dSPeiming Liu matchAndRewrite(LvlOp op, OpAdaptor adaptor, 326a2c9d4bbSAart Bik ConversionPatternRewriter &rewriter) const override { 327f708a549Swren romano const auto stt = getSparseTensorType(op.getSource()); 32886f91e45Swren romano // Only rewrite sparse DimOp. 329f708a549Swren romano if (!stt.hasEncoding()) 330d37d72eaSAart Bik return failure(); 331c780352dSPeiming Liu 33286f91e45Swren romano // Only rewrite DimOp with constant index. 333c780352dSPeiming Liu std::optional<int64_t> lvl = op.getConstantLvlIndex(); 334c780352dSPeiming Liu 335c780352dSPeiming Liu if (!lvl) 336d37d72eaSAart Bik return failure(); 337c780352dSPeiming Liu 338c780352dSPeiming Liu // By now, if the level size is constant, the operation should have already 339c780352dSPeiming Liu // been folded by LvlOp's folder, so we generate the call unconditionally. 
3409d1db3d4SAart Bik Value src = adaptor.getOperands()[0]; 341c780352dSPeiming Liu rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl)); 342a2c9d4bbSAart Bik return success(); 343a2c9d4bbSAart Bik } 344a2c9d4bbSAart Bik }; 345a2c9d4bbSAart Bik 3461b15160eSAart Bik /// Sparse conversion rule for trivial tensor casts. 3471b15160eSAart Bik class SparseCastConverter : public OpConversionPattern<tensor::CastOp> { 348faa00c13SAart Bik public: 3491b15160eSAart Bik using OpConversionPattern::OpConversionPattern; 3501b15160eSAart Bik LogicalResult 3511b15160eSAart Bik matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, 3521b15160eSAart Bik ConversionPatternRewriter &rewriter) const override { 3531b15160eSAart Bik // Only rewrite identically annotated source/dest. 3541b15160eSAart Bik auto encDst = getSparseTensorEncoding(op.getType()); 3558df54a6aSJacques Pienaar auto encSrc = getSparseTensorEncoding(op.getSource().getType()); 3561b15160eSAart Bik if (!encDst || encDst != encSrc) 3571b15160eSAart Bik return failure(); 3581b15160eSAart Bik rewriter.replaceOp(op, adaptor.getOperands()); 3591b15160eSAart Bik return success(); 3601b15160eSAart Bik } 3611b15160eSAart Bik }; 3621b15160eSAart Bik 363ef222988SPeiming Liu class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> { 364ef222988SPeiming Liu public: 365ef222988SPeiming Liu using OpConversionPattern::OpConversionPattern; 366ef222988SPeiming Liu LogicalResult 367ef222988SPeiming Liu matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor, 368ef222988SPeiming Liu ConversionPatternRewriter &rewriter) const override { 369ef222988SPeiming Liu // Simply fold the operation. 370ef222988SPeiming Liu rewriter.replaceOp(op, adaptor.getSource()); 371ef222988SPeiming Liu return success(); 372ef222988SPeiming Liu } 373ef222988SPeiming Liu }; 374ef222988SPeiming Liu 37596a23911SAart Bik /// Sparse conversion rule for the new operator. 
37696a23911SAart Bik class SparseTensorNewConverter : public OpConversionPattern<NewOp> { 377faa00c13SAart Bik public: 37896a23911SAart Bik using OpConversionPattern::OpConversionPattern; 37996a23911SAart Bik LogicalResult 380b54c724bSRiver Riddle matchAndRewrite(NewOp op, OpAdaptor adaptor, 38196a23911SAart Bik ConversionPatternRewriter &rewriter) const override { 382ee986ab7SPeiming Liu Location loc = op.getLoc(); 383f708a549Swren romano const auto stt = getSparseTensorType(op); 384f708a549Swren romano if (!stt.hasEncoding()) 38596a23911SAart Bik return failure(); 386d392073fSAart Bik // Construct the `reader` opening method calls. 3872323f48eSAart Bik SmallVector<Value> dimSizesValues; 3882af2e4dbSwren romano Value dimSizesBuffer; 389d3af6535SAart Bik Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0], 3902323f48eSAart Bik dimSizesValues, dimSizesBuffer); 3912af2e4dbSwren romano // Use the `reader` to parse the file. 392d392073fSAart Bik Value tensor = NewCallParams(rewriter, loc) 3932323f48eSAart Bik .genBuffers(stt, dimSizesValues, dimSizesBuffer) 394d392073fSAart Bik .genNewCall(Action::kFromReader, reader); 3952af2e4dbSwren romano // Free the memory for `reader`. 3962af2e4dbSwren romano createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader}, 3972af2e4dbSwren romano EmitCInterface::Off); 3982af2e4dbSwren romano rewriter.replaceOp(op, tensor); 399b24788abSAart Bik return success(); 400b24788abSAart Bik } 401b24788abSAart Bik }; 402b24788abSAart Bik 4036232a8f3SMatthias Springer /// Sparse conversion rule for the alloc operator. 
404c6472f57SAart Bik /// TODO(springerm): remove when bufferization.alloc_tensor is gone 4056232a8f3SMatthias Springer class SparseTensorAllocConverter 4066232a8f3SMatthias Springer : public OpConversionPattern<bufferization::AllocTensorOp> { 407faa00c13SAart Bik public: 408b24788abSAart Bik using OpConversionPattern::OpConversionPattern; 409b24788abSAart Bik LogicalResult 4106232a8f3SMatthias Springer matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor, 411b24788abSAart Bik ConversionPatternRewriter &rewriter) const override { 412f708a549Swren romano const auto stt = getSparseTensorType(op); 413f708a549Swren romano if (!stt.hasEncoding()) 414b24788abSAart Bik return failure(); 415160d483bSAart Bik if (op.getCopy()) 416160d483bSAart Bik return rewriter.notifyMatchFailure(op, "alloc copy not implemented"); 4176232a8f3SMatthias Springer // Gather all dimension sizes as SSA values. 418160d483bSAart Bik Location loc = op.getLoc(); 419f708a549Swren romano const Dimension dimRank = stt.getDimRank(); 4202323f48eSAart Bik SmallVector<Value> dimSizesValues; 4212323f48eSAart Bik dimSizesValues.reserve(dimRank); 422f708a549Swren romano unsigned operandCtr = 0; 423af8428c0SAart Bik for (Dimension d = 0; d < dimRank; d++) { 4242323f48eSAart Bik dimSizesValues.push_back( 425f708a549Swren romano stt.isDynamicDim(d) 426f708a549Swren romano ? adaptor.getOperands()[operandCtr++] 427f708a549Swren romano : constantIndex(rewriter, loc, op.getStaticSize(d))); 4286232a8f3SMatthias Springer } 4299d1db3d4SAart Bik // Generate the call to construct empty tensor. The sizes are 4306232a8f3SMatthias Springer // explicitly defined by the arguments to the alloc operator. 
431f708a549Swren romano rewriter.replaceOp(op, NewCallParams(rewriter, loc) 4322323f48eSAart Bik .genBuffers(stt, dimSizesValues) 433f5ce99afSwren romano .genNewCall(Action::kEmpty)); 43496a23911SAart Bik return success(); 43596a23911SAart Bik } 43696a23911SAart Bik }; 43796a23911SAart Bik 438c6472f57SAart Bik /// Sparse conversion rule for the empty tensor. 439c6472f57SAart Bik class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> { 440c6472f57SAart Bik public: 441c6472f57SAart Bik using OpConversionPattern::OpConversionPattern; 442c6472f57SAart Bik LogicalResult 443c6472f57SAart Bik matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor, 444c6472f57SAart Bik ConversionPatternRewriter &rewriter) const override { 445c6472f57SAart Bik Location loc = op.getLoc(); 446c6472f57SAart Bik const auto stt = getSparseTensorType(op); 447c6472f57SAart Bik if (!stt.hasEncoding()) 448c6472f57SAart Bik return failure(); 449c6472f57SAart Bik // Gather all dimension sizes as SSA values. 450c6472f57SAart Bik const Dimension dimRank = stt.getDimRank(); 4512323f48eSAart Bik SmallVector<Value> dimSizesValues; 4522323f48eSAart Bik dimSizesValues.reserve(dimRank); 453c6472f57SAart Bik auto shape = op.getType().getShape(); 454c6472f57SAart Bik unsigned operandCtr = 0; 455af8428c0SAart Bik for (Dimension d = 0; d < dimRank; d++) { 4562323f48eSAart Bik dimSizesValues.push_back(stt.isDynamicDim(d) 457c6472f57SAart Bik ? adaptor.getOperands()[operandCtr++] 458c6472f57SAart Bik : constantIndex(rewriter, loc, shape[d])); 459c6472f57SAart Bik } 460c6472f57SAart Bik // Generate the call to construct empty tensor. The sizes are 461c6472f57SAart Bik // explicitly defined by the arguments to the alloc operator. 
462c6472f57SAart Bik rewriter.replaceOp(op, NewCallParams(rewriter, loc) 4632323f48eSAart Bik .genBuffers(stt, dimSizesValues) 464c6472f57SAart Bik .genNewCall(Action::kEmpty)); 465c6472f57SAart Bik return success(); 466c6472f57SAart Bik } 467c6472f57SAart Bik }; 468c6472f57SAart Bik 469697ea09dSAart Bik /// Sparse conversion rule for the convert operator. 470f248d0b2SPeiming Liu class SparseTensorReorderCOOConverter 471f248d0b2SPeiming Liu : public OpConversionPattern<ReorderCOOOp> { 472c7e24db4Swren romano public: 473697ea09dSAart Bik using OpConversionPattern::OpConversionPattern; 474c7e24db4Swren romano 475697ea09dSAart Bik LogicalResult 476f248d0b2SPeiming Liu matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor, 477697ea09dSAart Bik ConversionPatternRewriter &rewriter) const override { 478f708a549Swren romano const Location loc = op->getLoc(); 479f248d0b2SPeiming Liu const auto srcTp = getSparseTensorType(op.getInputCoo()); 480f708a549Swren romano const auto dstTp = getSparseTensorType(op); 481f708a549Swren romano 482f248d0b2SPeiming Liu const Value src = adaptor.getInputCoo(); 483f248d0b2SPeiming Liu 484f5ce99afSwren romano NewCallParams params(rewriter, loc); 4852323f48eSAart Bik SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, srcTp, src); 4862323f48eSAart Bik rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizesValues) 487f248d0b2SPeiming Liu .genNewCall(Action::kSortCOOInPlace, src)); 488faa00c13SAart Bik 489f248d0b2SPeiming Liu return success(); 490f248d0b2SPeiming Liu } 491697ea09dSAart Bik }; 492697ea09dSAart Bik 49327a431f5SMatthias Springer /// Sparse conversion rule for the dealloc operator. 
49427a431f5SMatthias Springer class SparseTensorDeallocConverter 49527a431f5SMatthias Springer : public OpConversionPattern<bufferization::DeallocTensorOp> { 49616b8f4ddSAart Bik public: 49716b8f4ddSAart Bik using OpConversionPattern::OpConversionPattern; 49816b8f4ddSAart Bik LogicalResult 49927a431f5SMatthias Springer matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor, 50016b8f4ddSAart Bik ConversionPatternRewriter &rewriter) const override { 501f708a549Swren romano if (!getSparseTensorType(op.getTensor()).hasEncoding()) 50227a431f5SMatthias Springer return failure(); 50316b8f4ddSAart Bik StringRef name = "delSparseTensor"; 504ee986ab7SPeiming Liu createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(), 505d8731bfcSwren romano EmitCInterface::Off); 50616b8f4ddSAart Bik rewriter.eraseOp(op); 50716b8f4ddSAart Bik return success(); 50816b8f4ddSAart Bik } 50916b8f4ddSAart Bik }; 51016b8f4ddSAart Bik 51184cd51bbSwren romano /// Sparse conversion rule for position accesses. 51284cd51bbSwren romano class SparseTensorToPositionsConverter 51384cd51bbSwren romano : public OpConversionPattern<ToPositionsOp> { 514a2c9d4bbSAart Bik public: 515a2c9d4bbSAart Bik using OpConversionPattern::OpConversionPattern; 516a2c9d4bbSAart Bik LogicalResult 51784cd51bbSwren romano matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor, 518a2c9d4bbSAart Bik ConversionPatternRewriter &rewriter) const override { 519af8428c0SAart Bik auto stt = getSparseTensorType(op.getTensor()); 520af8428c0SAart Bik auto poss = genPositionsCall(rewriter, op.getLoc(), stt, 521af8428c0SAart Bik adaptor.getTensor(), op.getLevel()); 522af8428c0SAart Bik rewriter.replaceOp(op, poss); 523a2c9d4bbSAart Bik return success(); 524a2c9d4bbSAart Bik } 525a2c9d4bbSAart Bik }; 526a2c9d4bbSAart Bik 52784cd51bbSwren romano /// Sparse conversion rule for coordinate accesses. 
52884cd51bbSwren romano class SparseTensorToCoordinatesConverter 52984cd51bbSwren romano : public OpConversionPattern<ToCoordinatesOp> { 530a2c9d4bbSAart Bik public: 531a2c9d4bbSAart Bik using OpConversionPattern::OpConversionPattern; 532a2c9d4bbSAart Bik LogicalResult 53384cd51bbSwren romano matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor, 534a2c9d4bbSAart Bik ConversionPatternRewriter &rewriter) const override { 535dc4cfdbbSAart Bik const Location loc = op.getLoc(); 536af8428c0SAart Bik auto stt = getSparseTensorType(op.getTensor()); 537dc4cfdbbSAart Bik auto crds = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), 538dc4cfdbbSAart Bik op.getLevel()); 53990aa4362Sbixia1 // Cast the MemRef type to the type expected by the users, though these 54090aa4362Sbixia1 // two types should be compatible at runtime. 541af8428c0SAart Bik if (op.getType() != crds.getType()) 542dc4cfdbbSAart Bik crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds); 543dc4cfdbbSAart Bik rewriter.replaceOp(op, crds); 544dc4cfdbbSAart Bik return success(); 545dc4cfdbbSAart Bik } 546dc4cfdbbSAart Bik }; 547dc4cfdbbSAart Bik 548dc4cfdbbSAart Bik /// Sparse conversion rule for coordinate accesses (AoS style). 
549dc4cfdbbSAart Bik class SparseToCoordinatesBufferConverter 550dc4cfdbbSAart Bik : public OpConversionPattern<ToCoordinatesBufferOp> { 551dc4cfdbbSAart Bik public: 552dc4cfdbbSAart Bik using OpConversionPattern::OpConversionPattern; 553dc4cfdbbSAart Bik LogicalResult 554dc4cfdbbSAart Bik matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor, 555dc4cfdbbSAart Bik ConversionPatternRewriter &rewriter) const override { 556dc4cfdbbSAart Bik const Location loc = op.getLoc(); 557dc4cfdbbSAart Bik auto stt = getSparseTensorType(op.getTensor()); 558dc4cfdbbSAart Bik auto crds = genCoordinatesBufferCall( 559dc4cfdbbSAart Bik rewriter, loc, stt, adaptor.getTensor(), stt.getAoSCOOStart()); 560dc4cfdbbSAart Bik // Cast the MemRef type to the type expected by the users, though these 561dc4cfdbbSAart Bik // two types should be compatible at runtime. 562dc4cfdbbSAart Bik if (op.getType() != crds.getType()) 563dc4cfdbbSAart Bik crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds); 564af8428c0SAart Bik rewriter.replaceOp(op, crds); 565a2c9d4bbSAart Bik return success(); 566a2c9d4bbSAart Bik } 567a2c9d4bbSAart Bik }; 568a2c9d4bbSAart Bik 569a2c9d4bbSAart Bik /// Sparse conversion rule for value accesses. 
57096a23911SAart Bik class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> { 571a2c9d4bbSAart Bik public: 572a2c9d4bbSAart Bik using OpConversionPattern::OpConversionPattern; 573a2c9d4bbSAart Bik LogicalResult 574b54c724bSRiver Riddle matchAndRewrite(ToValuesOp op, OpAdaptor adaptor, 575a2c9d4bbSAart Bik ConversionPatternRewriter &rewriter) const override { 576af8428c0SAart Bik auto stt = getSparseTensorType(op.getTensor()); 577af8428c0SAart Bik auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor()); 578af8428c0SAart Bik rewriter.replaceOp(op, vals); 5790f3e4d1aSAart Bik return success(); 5800f3e4d1aSAart Bik } 5810f3e4d1aSAart Bik }; 5820f3e4d1aSAart Bik 5830f3e4d1aSAart Bik /// Sparse conversion rule for number of entries operator. 5840f3e4d1aSAart Bik class SparseNumberOfEntriesConverter 5850f3e4d1aSAart Bik : public OpConversionPattern<NumberOfEntriesOp> { 5860f3e4d1aSAart Bik public: 5870f3e4d1aSAart Bik using OpConversionPattern::OpConversionPattern; 5880f3e4d1aSAart Bik LogicalResult 5890f3e4d1aSAart Bik matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor, 5900f3e4d1aSAart Bik ConversionPatternRewriter &rewriter) const override { 5910f3e4d1aSAart Bik // Query values array size for the actually stored values size. 592af8428c0SAart Bik auto stt = getSparseTensorType(op.getTensor()); 593af8428c0SAart Bik auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor()); 594af8428c0SAart Bik auto zero = constantIndex(rewriter, op.getLoc(), 0); 595af8428c0SAart Bik rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero); 596a2c9d4bbSAart Bik return success(); 597a2c9d4bbSAart Bik } 598a2c9d4bbSAart Bik }; 599a2c9d4bbSAart Bik 600f66e5769SAart Bik /// Sparse conversion rule for tensor rematerialization. 
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endLexInsert";
      createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    // The converted tensor is an opaque pointer to the runtime storage, so
    // rematerialization simply forwards the operand as the result.
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter
    : public OpConversionPattern<tensor::InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through stack
    // allocated memrefs.
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getDest());

    // Insertion into a dense tensor is not handled here; bail out so that
    // other patterns can convert the op.
    if (!stt.hasEncoding())
      return failure();

    assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    // Allocate two stack buffers: one for the level coordinates and one
    // holding the scalar value, both passed by reference to the runtime.
    Value lvlCoords, vref;
    {
      // The guard restores the insertion point to the op after the allocas
      // have been emitted (possibly hoisted).
      OpBuilder::InsertionGuard guard(rewriter);
      Operation *loop = op;
      // Finds the outermost loop.
      while (auto l = loop->getParentOfType<LoopLikeOpInterface>())
        loop = l;

      if (llvm::isa<LoopLikeOpInterface>(loop)) {
        // Hoists alloca outside the loop to avoid stack overflow.
        rewriter.setInsertionPoint(loop);
      }
      lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
      vref = genAllocaScalar(rewriter, loc, elemTp);
    }
    // Store the coordinates and the value, then call the runtime insertion
    // routine specialized by the element type.
    storeAll(rewriter, loc, lvlCoords, adaptor.getIndices());
    rewriter.create<memref::StoreOp>(loc, adaptor.getScalar(), vref);
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getDest(), lvlCoords, vref}, EmitCInterface::On);
    // The runtime mutates the storage in place; the result is the same
    // opaque pointer as the destination.
    rewriter.replaceOp(op, adaptor.getDest());
    return success();
  }
};

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    // NOTE(review): this assumes the source tensor has a defining op (i.e.
    // is not a block argument) — confirm against callers of expand.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    // Pass the level coordinates by reference through a stack buffer, then
    // call the runtime compression routine specialized by element type.
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit of the loop nest. The buffers were
    // heap-allocated by the matching expand conversion above.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, dstTp);
    // Use a library method to transfer the external buffers from
    // clients to the internal SparseTensorStorage. Since we cannot
    // assume clients transfer ownership of the buffers, this method
    // will copy all data over into a new SparseTensorStorage.
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizesValues)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.disassemble operator.
/// Note that the current implementation simply exposes the buffers to
/// the external client. This assumes the client only reads the buffers
/// (usually copying it to the external data structures, such as numpy
/// arrays). The semantics of the disassemble operation technically
/// require that the copying is done here already using the out-levels
/// and out-values clause.
class SparseTensorDisassembleConverter
    : public OpConversionPattern<DisassembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    // Buffers returned to the client (retVal) and the actually-used length
    // of each buffer (retLen), appended after the buffers at the end.
    SmallVector<Value> retVal;
    SmallVector<Value> retLen;
    // Get the positions and coordinates buffers.
    const Level lvlRank = stt.getLvlRank();
    Level trailCOOLen = 0;
    for (Level l = 0; l < lvlRank; l++) {
      if (!stt.isUniqueLvl(l) &&
          (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
        // A `(loose)compressed_nu` level marks the start of trailing COO
        // start level. Since the target coordinate buffer used for trailing
        // COO is passed in as AoS scheme and SparseTensorStorage uses a SoA
        // scheme, we cannot simply use the internal buffers.
        trailCOOLen = lvlRank - l;
        break;
      }
      if (stt.isWithPos(l)) {
        auto poss =
            genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
        auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retVal.push_back(poss);
        retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      }
      if (stt.isWithCrd(l)) {
        auto crds =
            genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retVal.push_back(crds);
        retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
      }
    }
    // Handle AoS vs. SoA mismatch for COO.
    if (trailCOOLen != 0) {
      uint64_t cooStartLvl = lvlRank - trailCOOLen;
      assert(!stt.isUniqueLvl(cooStartLvl) &&
             (stt.isCompressedLvl(cooStartLvl) ||
              stt.isLooseCompressedLvl(cooStartLvl)));
      // Positions.
      auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
                                   cooStartLvl);
      auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
      auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retVal.push_back(poss);
      retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      // Coordinates, copied over with:
      //    for (i = 0; i < crdLen; i++)
      //      buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
      // NOTE(review): only two COO coordinate levels are copied here; this
      // presumably assumes trailing COO of exactly rank 2 — confirm.
      auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
      auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl);
      auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl + 1);
      auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds0, 0);
      auto two = constantIndex(rewriter, loc, 2);
      auto bufLen = rewriter.create<arith::MulIOp>(loc, crdLen, two);
      Type indexType = rewriter.getIndexType();
      auto zero = constantZero(rewriter, loc, indexType);
      auto one = constantOne(rewriter, loc, indexType);
      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, zero, crdLen, one);
      auto idx = forOp.getInductionVar();
      rewriter.setInsertionPointToStart(forOp.getBody());
      auto c0 = rewriter.create<memref::LoadOp>(loc, crds0, idx);
      auto c1 = rewriter.create<memref::LoadOp>(loc, crds1, idx);
      SmallVector<Value> args;
      args.push_back(idx);
      args.push_back(zero);
      rewriter.create<memref::StoreOp>(loc, c0, buf, args);
      args[1] = one;
      rewriter.create<memref::StoreOp>(loc, c1, buf, args);
      rewriter.setInsertionPointAfter(forOp);
      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retVal.push_back(buf);
      retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
    }
    // Get the values buffer last.
    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
    auto valLenTp = op.getValLen().getType();
    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
    retVal.push_back(vals);
    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));

    // Converts MemRefs back to Tensors.
    assert(retVal.size() + retLen.size() == op.getNumResults());
    for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
      auto tensor = rewriter.create<bufferization::ToTensorOp>(loc, retVal[i]);
      retVal[i] =
          rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
    }

    // Appends the actual memory length used in each buffer returned.
    retVal.append(retLen.begin(), retLen.end());
    rewriter.replaceOp(op, retVal);
    return success();
  }
};

/// Sparse conversion rule for the has_runtime_library operator: this
/// conversion path targets the runtime support library, so the query
/// folds to the constant true.
struct SparseHasRuntimeLibraryConverter
    : public OpConversionPattern<HasRuntimeLibraryOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto i1Type = rewriter.getI1Type();
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, i1Type, rewriter.getIntegerAttr(i1Type, 1));
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  // Default: all types pass through unchanged; sparse tensor types are
  // mapped to opaque pointers by convertSparseTensorTypes.
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
90805c7f450SAart Bik //===----------------------------------------------------------------------===// 90905c7f450SAart Bik 910a2c9d4bbSAart Bik /// Populates the given patterns list with conversion rules required for 911a2c9d4bbSAart Bik /// the sparsification of linear algebra operations. 912*206fad0eSMatthias Springer void mlir::populateSparseTensorConversionPatterns( 913*206fad0eSMatthias Springer const TypeConverter &typeConverter, RewritePatternSet &patterns) { 914f248d0b2SPeiming Liu patterns 915c780352dSPeiming Liu .add<SparseReturnConverter, SparseTensorLvlOpConverter, 916ef222988SPeiming Liu SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter, 917c3b01b46SPeiming Liu SparseTensorAllocConverter, SparseTensorEmptyConverter, 918f248d0b2SPeiming Liu SparseTensorDeallocConverter, SparseTensorReorderCOOConverter, 919f248d0b2SPeiming Liu SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter, 920dc4cfdbbSAart Bik SparseToCoordinatesBufferConverter, SparseTensorToValuesConverter, 921dc4cfdbbSAart Bik SparseNumberOfEntriesConverter, SparseTensorLoadConverter, 922dc4cfdbbSAart Bik SparseTensorInsertConverter, SparseTensorExpandConverter, 923dc4cfdbbSAart Bik SparseTensorCompressConverter, SparseTensorAssembleConverter, 924dc4cfdbbSAart Bik SparseTensorDisassembleConverter, SparseHasRuntimeLibraryConverter>( 925dc4cfdbbSAart Bik typeConverter, patterns.getContext()); 926a2c9d4bbSAart Bik } 927