1 //===- SparseTensorDescriptor.h ---------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header file defines utilities for the sparse memory layout. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_ 14 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_ 15 16 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 17 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" 18 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" 19 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 20 21 namespace mlir { 22 namespace sparse_tensor { 23 24 class SparseTensorSpecifier { 25 public: 26 explicit SparseTensorSpecifier(Value specifier) 27 : specifier(cast<TypedValue<StorageSpecifierType>>(specifier)) {} 28 29 // Undef value for level-sizes, all zero values for memory-sizes. 30 static Value getInitValue(OpBuilder &builder, Location loc, 31 SparseTensorType stt); 32 33 /*implicit*/ operator Value() { return specifier; } 34 35 Value getSpecifierField(OpBuilder &builder, Location loc, 36 StorageSpecifierKind kind, std::optional<Level> lvl); 37 38 void setSpecifierField(OpBuilder &builder, Location loc, Value v, 39 StorageSpecifierKind kind, std::optional<Level> lvl); 40 41 private: 42 TypedValue<StorageSpecifierType> specifier; 43 }; 44 45 /// A helper class around an array of values that corresponds to a sparse 46 /// tensor. This class provides a set of meaningful APIs to query and update 47 /// a particular field in a consistent way. Users should not make assumptions 48 /// on how a sparse tensor is laid out but instead rely on this class to access 49 /// the right value for the right field. 50 template <typename ValueArrayRef> 51 class SparseTensorDescriptorImpl { 52 protected: 53 SparseTensorDescriptorImpl(SparseTensorType stt, ValueArrayRef fields) 54 : rType(stt), fields(fields), layout(stt) { 55 assert(layout.getNumFields() == getNumFields()); 56 // We should make sure the class is trivially copyable (and should be small 57 // enough) such that we can pass it by value. 58 static_assert(std::is_trivially_copyable_v< 59 SparseTensorDescriptorImpl<ValueArrayRef>>); 60 } 61 62 public: 63 FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind, 64 std::optional<Level> lvl) const { 65 // Delegates to storage layout. 66 return layout.getMemRefFieldIndex(kind, lvl); 67 } 68 69 unsigned getNumFields() const { return fields.size(); } 70 71 /// 72 /// Getters: get the value for required field. 73 /// 74 75 Value getSpecifier() const { return fields.back(); } 76 77 Value getSpecifierField(OpBuilder &builder, Location loc, 78 StorageSpecifierKind kind, 79 std::optional<Level> lvl) const { 80 SparseTensorSpecifier md(fields.back()); 81 return md.getSpecifierField(builder, loc, kind, lvl); 82 } 83 84 Value getLvlSize(OpBuilder &builder, Location loc, Level lvl) const { 85 return getSpecifierField(builder, loc, StorageSpecifierKind::LvlSize, lvl); 86 } 87 88 Value getPosMemRef(Level lvl) const { 89 return getMemRefField(SparseTensorFieldKind::PosMemRef, lvl); 90 } 91 92 Value getValMemRef() const { 93 return getMemRefField(SparseTensorFieldKind::ValMemRef, std::nullopt); 94 } 95 96 Value getMemRefField(SparseTensorFieldKind kind, 97 std::optional<Level> lvl) const { 98 return getField(getMemRefFieldIndex(kind, lvl)); 99 } 100 101 Value getMemRefField(FieldIndex fidx) const { 102 assert(fidx < fields.size() - 1); 103 return getField(fidx); 104 } 105 106 Value getPosMemSize(OpBuilder &builder, Location loc, Level lvl) const { 107 return getSpecifierField(builder, loc, StorageSpecifierKind::PosMemSize, 108 lvl); 109 } 110 111 Value getCrdMemSize(OpBuilder &builder, Location loc, Level lvl) const { 112 return getSpecifierField(builder, loc, StorageSpecifierKind::CrdMemSize, 113 lvl); 114 } 115 116 Value getValMemSize(OpBuilder &builder, Location loc) const { 117 return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize, 118 std::nullopt); 119 } 120 121 Type getMemRefElementType(SparseTensorFieldKind kind, 122 std::optional<Level> lvl) const { 123 return getMemRefType(getMemRefField(kind, lvl)).getElementType(); 124 } 125 126 Value getField(FieldIndex fidx) const { 127 assert(fidx < fields.size()); 128 return fields[fidx]; 129 } 130 131 ValueRange getMemRefFields() const { 132 return fields.drop_back(); // drop the last metadata fields 133 } 134 135 std::pair<FieldIndex, unsigned> getCrdMemRefIndexAndStride(Level lvl) const { 136 return layout.getFieldIndexAndStride(SparseTensorFieldKind::CrdMemRef, lvl); 137 } 138 139 Value getAOSMemRef() const { 140 const Level cooStart = rType.getAoSCOOStart(); 141 assert(cooStart < rType.getLvlRank()); 142 return getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart); 143 } 144 145 RankedTensorType getRankedTensorType() const { return rType; } 146 ValueArrayRef getFields() const { return fields; } 147 StorageLayout getLayout() const { return layout; } 148 149 protected: 150 SparseTensorType rType; 151 ValueArrayRef fields; 152 StorageLayout layout; 153 }; 154 155 /// Uses ValueRange for immutable descriptors. 156 class SparseTensorDescriptor : public SparseTensorDescriptorImpl<ValueRange> { 157 public: 158 SparseTensorDescriptor(SparseTensorType stt, ValueRange buffers) 159 : SparseTensorDescriptorImpl<ValueRange>(stt, buffers) {} 160 161 Value getCrdMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const; 162 }; 163 164 /// Using SmallVector for mutable descriptor allows users to reuse it as a 165 /// tmp buffers to append value for some special cases, though users should 166 /// be responsible to restore the buffer to legal states after their use. It 167 /// is probably not a clean way, but it is the most efficient way to avoid 168 /// copying the fields into another SmallVector. If a more clear way is 169 /// wanted, we should change it to MutableArrayRef instead. 170 class MutSparseTensorDescriptor 171 : public SparseTensorDescriptorImpl<SmallVectorImpl<Value> &> { 172 public: 173 MutSparseTensorDescriptor(SparseTensorType stt, 174 SmallVectorImpl<Value> &buffers) 175 : SparseTensorDescriptorImpl<SmallVectorImpl<Value> &>(stt, buffers) {} 176 177 // Allow implicit type conversion from mutable descriptors to immutable ones 178 // (but not vice versa). 179 /*implicit*/ operator SparseTensorDescriptor() const { 180 return SparseTensorDescriptor(rType, fields); 181 } 182 183 /// 184 /// Adds additional setters for mutable descriptor, update the value for 185 /// required field. 186 /// 187 188 void setMemRefField(SparseTensorFieldKind kind, std::optional<Level> lvl, 189 Value v) { 190 fields[getMemRefFieldIndex(kind, lvl)] = v; 191 } 192 193 void setMemRefField(FieldIndex fidx, Value v) { 194 assert(fidx < fields.size() - 1); 195 fields[fidx] = v; 196 } 197 198 void setField(FieldIndex fidx, Value v) { 199 assert(fidx < fields.size()); 200 fields[fidx] = v; 201 } 202 203 void setSpecifier(Value newSpec) { fields.back() = newSpec; } 204 205 void setSpecifierField(OpBuilder &builder, Location loc, 206 StorageSpecifierKind kind, std::optional<Level> lvl, 207 Value v) { 208 SparseTensorSpecifier md(fields.back()); 209 md.setSpecifierField(builder, loc, v, kind, lvl); 210 fields.back() = md; 211 } 212 213 void setValMemSize(OpBuilder &builder, Location loc, Value v) { 214 setSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize, 215 std::nullopt, v); 216 } 217 218 void setCrdMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) { 219 setSpecifierField(builder, loc, StorageSpecifierKind::CrdMemSize, lvl, v); 220 } 221 222 void setPosMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) { 223 setSpecifierField(builder, loc, StorageSpecifierKind::PosMemSize, lvl, v); 224 } 225 226 void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value v) { 227 setSpecifierField(builder, loc, StorageSpecifierKind::LvlSize, lvl, v); 228 } 229 }; 230 231 /// Packs the given values as a "tuple" value. 232 inline Value genTuple(OpBuilder &builder, Location loc, Type tp, 233 ValueRange values) { 234 return builder.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values) 235 .getResult(0); 236 } 237 238 inline Value genTuple(OpBuilder &builder, Location loc, 239 SparseTensorDescriptor desc) { 240 return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields()); 241 } 242 243 inline SparseTensorDescriptor 244 getDescriptorFromTensorTuple(ValueRange adaptorValues, RankedTensorType type) { 245 return SparseTensorDescriptor(SparseTensorType(type), adaptorValues); 246 } 247 248 inline MutSparseTensorDescriptor 249 getMutDescriptorFromTensorTuple(ValueRange adaptorValues, 250 SmallVectorImpl<Value> &fields, 251 RankedTensorType type) { 252 fields.assign(adaptorValues.begin(), adaptorValues.end()); 253 return MutSparseTensorDescriptor(SparseTensorType(type), fields); 254 } 255 256 } // namespace sparse_tensor 257 } // namespace mlir 258 259 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSODESCRIPTOR_H_ 260