//===- SparseStorageSpecifierToLLVM.cpp - convert specifier to llvm -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "Utils/CodegenUtils.h" #include "mlir/Conversion/LLVMCommon/StructBuilder.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include using namespace mlir; using namespace sparse_tensor; namespace { //===----------------------------------------------------------------------===// // Helper methods. //===----------------------------------------------------------------------===// static SmallVector getSpecifierFields(StorageSpecifierType tp) { MLIRContext *ctx = tp.getContext(); auto enc = tp.getEncoding(); const Level lvlRank = enc.getLvlRank(); SmallVector result; // TODO: how can we get the lowering type for index type in the later pipeline // to be consistent? LLVM::StructureType does not allow index fields. auto sizeType = IntegerType::get(tp.getContext(), 64); auto lvlSizes = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank); auto memSizes = LLVM::LLVMArrayType::get(ctx, sizeType, getNumDataFieldsFromEncoding(enc)); result.push_back(lvlSizes); result.push_back(memSizes); if (enc.isSlice()) { // Extra fields are required for the slice information. auto dimOffset = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank); auto dimStride = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank); result.push_back(dimOffset); result.push_back(dimStride); } return result; } static Type convertSpecifier(StorageSpecifierType tp) { return LLVM::LLVMStructType::getLiteral(tp.getContext(), getSpecifierFields(tp)); } //===----------------------------------------------------------------------===// // Specifier struct builder. //===----------------------------------------------------------------------===// constexpr uint64_t kLvlSizePosInSpecifier = 0; constexpr uint64_t kMemSizePosInSpecifier = 1; constexpr uint64_t kDimOffsetPosInSpecifier = 2; constexpr uint64_t kDimStridePosInSpecifier = 3; class SpecifierStructBuilder : public StructBuilder { private: Value extractField(OpBuilder &builder, Location loc, ArrayRef indices) const { return genCast(builder, loc, builder.create(loc, value, indices), builder.getIndexType()); } void insertField(OpBuilder &builder, Location loc, ArrayRef indices, Value v) { value = builder.create( loc, value, genCast(builder, loc, v, builder.getIntegerType(64)), indices); } public: explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) { assert(value); } // Undef value for dimension sizes, all zero value for memory sizes. static Value getInitValue(OpBuilder &builder, Location loc, Type structType, Value source); Value lvlSize(OpBuilder &builder, Location loc, Level lvl) const; void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value size); Value dimOffset(OpBuilder &builder, Location loc, Dimension dim) const; void setDimOffset(OpBuilder &builder, Location loc, Dimension dim, Value size); Value dimStride(OpBuilder &builder, Location loc, Dimension dim) const; void setDimStride(OpBuilder &builder, Location loc, Dimension dim, Value size); Value memSize(OpBuilder &builder, Location loc, FieldIndex fidx) const; void setMemSize(OpBuilder &builder, Location loc, FieldIndex fidx, Value size); Value memSizeArray(OpBuilder &builder, Location loc) const; void setMemSizeArray(OpBuilder &builder, Location loc, Value array); }; Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc, Type structType, Value source) { Value metaData = builder.create(loc, structType); SpecifierStructBuilder md(metaData); if (!source) { auto memSizeArrayType = cast(cast(structType) .getBody()[kMemSizePosInSpecifier]); Value zero = constantZero(builder, loc, memSizeArrayType.getElementType()); // Fill memSizes array with zero. for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++) md.setMemSize(builder, loc, i, zero); } else { // We copy non-slice information (memory sizes array) from source SpecifierStructBuilder sourceMd(source); md.setMemSizeArray(builder, loc, sourceMd.memSizeArray(builder, loc)); } return md; } /// Builds IR extracting the pos-th offset from the descriptor. Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc, Dimension dim) const { return extractField( builder, loc, ArrayRef{kDimOffsetPosInSpecifier, static_cast(dim)}); } /// Builds IR inserting the pos-th offset into the descriptor. void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc, Dimension dim, Value size) { insertField( builder, loc, ArrayRef{kDimOffsetPosInSpecifier, static_cast(dim)}, size); } /// Builds IR extracting the `lvl`-th level-size from the descriptor. Value SpecifierStructBuilder::lvlSize(OpBuilder &builder, Location loc, Level lvl) const { // This static_cast makes the narrowing of `lvl` explicit, as required // by the braces notation for the ctor. return extractField( builder, loc, ArrayRef{kLvlSizePosInSpecifier, static_cast(lvl)}); } /// Builds IR inserting the `lvl`-th level-size into the descriptor. void SpecifierStructBuilder::setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value size) { // This static_cast makes the narrowing of `lvl` explicit, as required // by the braces notation for the ctor. insertField( builder, loc, ArrayRef{kLvlSizePosInSpecifier, static_cast(lvl)}, size); } /// Builds IR extracting the pos-th stride from the descriptor. Value SpecifierStructBuilder::dimStride(OpBuilder &builder, Location loc, Dimension dim) const { return extractField( builder, loc, ArrayRef{kDimStridePosInSpecifier, static_cast(dim)}); } /// Builds IR inserting the pos-th stride into the descriptor. void SpecifierStructBuilder::setDimStride(OpBuilder &builder, Location loc, Dimension dim, Value size) { insertField( builder, loc, ArrayRef{kDimStridePosInSpecifier, static_cast(dim)}, size); } /// Builds IR extracting the pos-th memory size into the descriptor. Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc, FieldIndex fidx) const { return extractField( builder, loc, ArrayRef{kMemSizePosInSpecifier, static_cast(fidx)}); } /// Builds IR inserting the `fidx`-th memory-size into the descriptor. void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc, FieldIndex fidx, Value size) { insertField( builder, loc, ArrayRef{kMemSizePosInSpecifier, static_cast(fidx)}, size); } /// Builds IR extracting the memory size array from the descriptor. Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder, Location loc) const { return builder.create(loc, value, kMemSizePosInSpecifier); } /// Builds IR inserting the memory size array into the descriptor. void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc, Value array) { value = builder.create(loc, value, array, kMemSizePosInSpecifier); } } // namespace //===----------------------------------------------------------------------===// // The sparse storage specifier type converter (defined in Passes.h). //===----------------------------------------------------------------------===// StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() { addConversion([](Type type) { return type; }); addConversion(convertSpecifier); } //===----------------------------------------------------------------------===// // Storage specifier conversion rules. //===----------------------------------------------------------------------===// template class SpecifierGetterSetterOpConverter : public OpConversionPattern { public: using OpAdaptor = typename SourceOp::Adaptor; using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SpecifierStructBuilder spec(adaptor.getSpecifier()); switch (op.getSpecifierKind()) { case StorageSpecifierKind::LvlSize: { Value v = Base::onLvlSize(rewriter, op, spec, (*op.getLevel())); rewriter.replaceOp(op, v); return success(); } case StorageSpecifierKind::DimOffset: { Value v = Base::onDimOffset(rewriter, op, spec, (*op.getLevel())); rewriter.replaceOp(op, v); return success(); } case StorageSpecifierKind::DimStride: { Value v = Base::onDimStride(rewriter, op, spec, (*op.getLevel())); rewriter.replaceOp(op, v); return success(); } case StorageSpecifierKind::CrdMemSize: case StorageSpecifierKind::PosMemSize: case StorageSpecifierKind::ValMemSize: { auto enc = op.getSpecifier().getType().getEncoding(); StorageLayout layout(enc); std::optional lvl; if (op.getLevel()) lvl = (*op.getLevel()); unsigned idx = layout.getMemRefFieldIndex(toFieldKind(op.getSpecifierKind()), lvl); Value v = Base::onMemSize(rewriter, op, spec, idx); rewriter.replaceOp(op, v); return success(); } } llvm_unreachable("unrecognized specifer kind"); } }; struct StorageSpecifierSetOpConverter : public SpecifierGetterSetterOpConverter { using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter; static Value onLvlSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Level lvl) { spec.setLvlSize(builder, op.getLoc(), lvl, op.getValue()); return spec; } static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Dimension d) { spec.setDimOffset(builder, op.getLoc(), d, op.getValue()); return spec; } static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Dimension d) { spec.setDimStride(builder, op.getLoc(), d, op.getValue()); return spec; } static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, FieldIndex fidx) { spec.setMemSize(builder, op.getLoc(), fidx, op.getValue()); return spec; } }; struct StorageSpecifierGetOpConverter : public SpecifierGetterSetterOpConverter { using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter; static Value onLvlSize(OpBuilder &builder, GetStorageSpecifierOp op, SpecifierStructBuilder &spec, Level lvl) { return spec.lvlSize(builder, op.getLoc(), lvl); } static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op, const SpecifierStructBuilder &spec, Dimension d) { return spec.dimOffset(builder, op.getLoc(), d); } static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op, const SpecifierStructBuilder &spec, Dimension d) { return spec.dimStride(builder, op.getLoc(), d); } static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op, SpecifierStructBuilder &spec, FieldIndex fidx) { return spec.memSize(builder, op.getLoc(), fidx); } }; struct StorageSpecifierInitOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type llvmType = getTypeConverter()->convertType(op.getResult().getType()); rewriter.replaceOp( op, SpecifierStructBuilder::getInitValue( rewriter, op.getLoc(), llvmType, adaptor.getSource())); return success(); } }; //===----------------------------------------------------------------------===// // Public method for populating conversion rules. //===----------------------------------------------------------------------===// void mlir::populateStorageSpecifierToLLVMPatterns( const TypeConverter &converter, RewritePatternSet &patterns) { patterns.add(converter, patterns.getContext()); }