1 //===- SparseTensorDescriptor.cpp -----------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "SparseTensorDescriptor.h" 10 #include "CodegenUtils.h" 11 12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 13 #include "mlir/Dialect/MemRef/IR/MemRef.h" 14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 15 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 18 using namespace mlir; 19 using namespace sparse_tensor; 20 21 //===----------------------------------------------------------------------===// 22 // Private helper methods. 23 //===----------------------------------------------------------------------===// 24 25 /// Constructs a nullable `LevelAttr` from the `std::optional<Level>`. 26 static IntegerAttr optionalLevelAttr(MLIRContext *ctx, 27 std::optional<Level> lvl) { 28 return lvl ? IntegerAttr::get(IndexType::get(ctx), lvl.value()) 29 : IntegerAttr(); 30 } 31 32 // This is only ever called from `SparseTensorTypeToBufferConverter`, 33 // which is why the first argument is `RankedTensorType` rather than 34 // `SparseTensorType`. 35 static std::optional<LogicalResult> 36 convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) { 37 const SparseTensorType stt(rtp); 38 if (!stt.hasEncoding()) 39 return std::nullopt; 40 41 foreachFieldAndTypeInSparseTensor( 42 stt, 43 [&fields](Type fieldType, FieldIndex fieldIdx, 44 SparseTensorFieldKind /*fieldKind*/, Level /*lvl*/, 45 LevelType /*lt*/) -> bool { 46 assert(fieldIdx == fields.size()); 47 fields.push_back(fieldType); 48 return true; 49 }); 50 return success(); 51 } 52 53 //===----------------------------------------------------------------------===// 54 // The sparse tensor type converter (defined in Passes.h). 55 //===----------------------------------------------------------------------===// 56 57 static Value materializeTuple(OpBuilder &builder, RankedTensorType tp, 58 ValueRange inputs, Location loc) { 59 if (!getSparseTensorEncoding(tp)) 60 // Not a sparse tensor. 61 return Value(); 62 // Sparsifier knows how to cancel out these casts. 63 return genTuple(builder, loc, tp, inputs); 64 } 65 66 SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { 67 addConversion([](Type type) { return type; }); 68 addConversion(convertSparseTensorType); 69 70 // Required by scf.for 1:N type conversion. 71 addSourceMaterialization(materializeTuple); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // StorageTensorSpecifier methods. 76 //===----------------------------------------------------------------------===// 77 78 Value SparseTensorSpecifier::getInitValue(OpBuilder &builder, Location loc, 79 SparseTensorType stt) { 80 return builder.create<StorageSpecifierInitOp>( 81 loc, StorageSpecifierType::get(stt.getEncoding())); 82 } 83 84 Value SparseTensorSpecifier::getSpecifierField(OpBuilder &builder, Location loc, 85 StorageSpecifierKind kind, 86 std::optional<Level> lvl) { 87 return builder.create<GetStorageSpecifierOp>( 88 loc, specifier, kind, optionalLevelAttr(specifier.getContext(), lvl)); 89 } 90 91 void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc, 92 Value v, 93 StorageSpecifierKind kind, 94 std::optional<Level> lvl) { 95 // TODO: make `v` have type `TypedValue<IndexType>` instead. 96 assert(v.getType().isIndex()); 97 specifier = builder.create<SetStorageSpecifierOp>( 98 loc, specifier, kind, optionalLevelAttr(specifier.getContext(), lvl), v); 99 } 100 101 //===----------------------------------------------------------------------===// 102 // SparseTensorDescriptor methods. 103 //===----------------------------------------------------------------------===// 104 105 Value sparse_tensor::SparseTensorDescriptor::getCrdMemRefOrView( 106 OpBuilder &builder, Location loc, Level lvl) const { 107 const Level cooStart = rType.getAoSCOOStart(); 108 if (lvl < cooStart) 109 return getMemRefField(SparseTensorFieldKind::CrdMemRef, lvl); 110 111 Value stride = constantIndex(builder, loc, rType.getLvlRank() - cooStart); 112 Value size = getCrdMemSize(builder, loc, cooStart); 113 size = builder.create<arith::DivUIOp>(loc, size, stride); 114 return builder.create<memref::SubViewOp>( 115 loc, getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart), 116 /*offset=*/ValueRange{constantIndex(builder, loc, lvl - cooStart)}, 117 /*size=*/ValueRange{size}, 118 /*step=*/ValueRange{stride}); 119 } 120