1 //===- SparseStorageSpecifierToLLVM.cpp - convert specifier to llvm -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Utils/CodegenUtils.h" 10 11 #include "mlir/Conversion/LLVMCommon/StructBuilder.h" 12 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" 13 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 14 15 #include <optional> 16 17 using namespace mlir; 18 using namespace sparse_tensor; 19 20 namespace { 21 22 //===----------------------------------------------------------------------===// 23 // Helper methods. 24 //===----------------------------------------------------------------------===// 25 26 static SmallVector<Type, 4> getSpecifierFields(StorageSpecifierType tp) { 27 MLIRContext *ctx = tp.getContext(); 28 auto enc = tp.getEncoding(); 29 const Level lvlRank = enc.getLvlRank(); 30 31 SmallVector<Type, 4> result; 32 // TODO: how can we get the lowering type for index type in the later pipeline 33 // to be consistent? LLVM::StructureType does not allow index fields. 34 auto sizeType = IntegerType::get(tp.getContext(), 64); 35 auto lvlSizes = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank); 36 auto memSizes = LLVM::LLVMArrayType::get(ctx, sizeType, 37 getNumDataFieldsFromEncoding(enc)); 38 result.push_back(lvlSizes); 39 result.push_back(memSizes); 40 41 if (enc.isSlice()) { 42 // Extra fields are required for the slice information. 43 auto dimOffset = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank); 44 auto dimStride = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank); 45 46 result.push_back(dimOffset); 47 result.push_back(dimStride); 48 } 49 50 return result; 51 } 52 53 static Type convertSpecifier(StorageSpecifierType tp) { 54 return LLVM::LLVMStructType::getLiteral(tp.getContext(), 55 getSpecifierFields(tp)); 56 } 57 58 //===----------------------------------------------------------------------===// 59 // Specifier struct builder. 60 //===----------------------------------------------------------------------===// 61 62 constexpr uint64_t kLvlSizePosInSpecifier = 0; 63 constexpr uint64_t kMemSizePosInSpecifier = 1; 64 constexpr uint64_t kDimOffsetPosInSpecifier = 2; 65 constexpr uint64_t kDimStridePosInSpecifier = 3; 66 67 class SpecifierStructBuilder : public StructBuilder { 68 private: 69 Value extractField(OpBuilder &builder, Location loc, 70 ArrayRef<int64_t> indices) const { 71 return genCast(builder, loc, 72 builder.create<LLVM::ExtractValueOp>(loc, value, indices), 73 builder.getIndexType()); 74 } 75 76 void insertField(OpBuilder &builder, Location loc, ArrayRef<int64_t> indices, 77 Value v) { 78 value = builder.create<LLVM::InsertValueOp>( 79 loc, value, genCast(builder, loc, v, builder.getIntegerType(64)), 80 indices); 81 } 82 83 public: 84 explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) { 85 assert(value); 86 } 87 88 // Undef value for dimension sizes, all zero value for memory sizes. 89 static Value getInitValue(OpBuilder &builder, Location loc, Type structType, 90 Value source); 91 92 Value lvlSize(OpBuilder &builder, Location loc, Level lvl) const; 93 void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value size); 94 95 Value dimOffset(OpBuilder &builder, Location loc, Dimension dim) const; 96 void setDimOffset(OpBuilder &builder, Location loc, Dimension dim, 97 Value size); 98 99 Value dimStride(OpBuilder &builder, Location loc, Dimension dim) const; 100 void setDimStride(OpBuilder &builder, Location loc, Dimension dim, 101 Value size); 102 103 Value memSize(OpBuilder &builder, Location loc, FieldIndex fidx) const; 104 void setMemSize(OpBuilder &builder, Location loc, FieldIndex fidx, 105 Value size); 106 107 Value memSizeArray(OpBuilder &builder, Location loc) const; 108 void setMemSizeArray(OpBuilder &builder, Location loc, Value array); 109 }; 110 111 Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc, 112 Type structType, Value source) { 113 Value metaData = builder.create<LLVM::UndefOp>(loc, structType); 114 SpecifierStructBuilder md(metaData); 115 if (!source) { 116 auto memSizeArrayType = 117 cast<LLVM::LLVMArrayType>(cast<LLVM::LLVMStructType>(structType) 118 .getBody()[kMemSizePosInSpecifier]); 119 120 Value zero = constantZero(builder, loc, memSizeArrayType.getElementType()); 121 // Fill memSizes array with zero. 122 for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++) 123 md.setMemSize(builder, loc, i, zero); 124 } else { 125 // We copy non-slice information (memory sizes array) from source 126 SpecifierStructBuilder sourceMd(source); 127 md.setMemSizeArray(builder, loc, sourceMd.memSizeArray(builder, loc)); 128 } 129 return md; 130 } 131 132 /// Builds IR extracting the pos-th offset from the descriptor. 133 Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc, 134 Dimension dim) const { 135 return extractField( 136 builder, loc, 137 ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}); 138 } 139 140 /// Builds IR inserting the pos-th offset into the descriptor. 141 void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc, 142 Dimension dim, Value size) { 143 insertField( 144 builder, loc, 145 ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}, 146 size); 147 } 148 149 /// Builds IR extracting the `lvl`-th level-size from the descriptor. 150 Value SpecifierStructBuilder::lvlSize(OpBuilder &builder, Location loc, 151 Level lvl) const { 152 // This static_cast makes the narrowing of `lvl` explicit, as required 153 // by the braces notation for the ctor. 154 return extractField( 155 builder, loc, 156 ArrayRef<int64_t>{kLvlSizePosInSpecifier, static_cast<int64_t>(lvl)}); 157 } 158 159 /// Builds IR inserting the `lvl`-th level-size into the descriptor. 160 void SpecifierStructBuilder::setLvlSize(OpBuilder &builder, Location loc, 161 Level lvl, Value size) { 162 // This static_cast makes the narrowing of `lvl` explicit, as required 163 // by the braces notation for the ctor. 164 insertField( 165 builder, loc, 166 ArrayRef<int64_t>{kLvlSizePosInSpecifier, static_cast<int64_t>(lvl)}, 167 size); 168 } 169 170 /// Builds IR extracting the pos-th stride from the descriptor. 171 Value SpecifierStructBuilder::dimStride(OpBuilder &builder, Location loc, 172 Dimension dim) const { 173 return extractField( 174 builder, loc, 175 ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)}); 176 } 177 178 /// Builds IR inserting the pos-th stride into the descriptor. 179 void SpecifierStructBuilder::setDimStride(OpBuilder &builder, Location loc, 180 Dimension dim, Value size) { 181 insertField( 182 builder, loc, 183 ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)}, 184 size); 185 } 186 187 /// Builds IR extracting the pos-th memory size into the descriptor. 188 Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc, 189 FieldIndex fidx) const { 190 return extractField( 191 builder, loc, 192 ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)}); 193 } 194 195 /// Builds IR inserting the `fidx`-th memory-size into the descriptor. 196 void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc, 197 FieldIndex fidx, Value size) { 198 insertField( 199 builder, loc, 200 ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)}, 201 size); 202 } 203 204 /// Builds IR extracting the memory size array from the descriptor. 205 Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder, 206 Location loc) const { 207 return builder.create<LLVM::ExtractValueOp>(loc, value, 208 kMemSizePosInSpecifier); 209 } 210 211 /// Builds IR inserting the memory size array into the descriptor. 212 void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc, 213 Value array) { 214 value = builder.create<LLVM::InsertValueOp>(loc, value, array, 215 kMemSizePosInSpecifier); 216 } 217 218 } // namespace 219 220 //===----------------------------------------------------------------------===// 221 // The sparse storage specifier type converter (defined in Passes.h). 222 //===----------------------------------------------------------------------===// 223 224 StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() { 225 addConversion([](Type type) { return type; }); 226 addConversion(convertSpecifier); 227 } 228 229 //===----------------------------------------------------------------------===// 230 // Storage specifier conversion rules. 231 //===----------------------------------------------------------------------===// 232 233 template <typename Base, typename SourceOp> 234 class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> { 235 public: 236 using OpAdaptor = typename SourceOp::Adaptor; 237 using OpConversionPattern<SourceOp>::OpConversionPattern; 238 239 LogicalResult 240 matchAndRewrite(SourceOp op, OpAdaptor adaptor, 241 ConversionPatternRewriter &rewriter) const override { 242 SpecifierStructBuilder spec(adaptor.getSpecifier()); 243 switch (op.getSpecifierKind()) { 244 case StorageSpecifierKind::LvlSize: { 245 Value v = Base::onLvlSize(rewriter, op, spec, (*op.getLevel())); 246 rewriter.replaceOp(op, v); 247 return success(); 248 } 249 case StorageSpecifierKind::DimOffset: { 250 Value v = Base::onDimOffset(rewriter, op, spec, (*op.getLevel())); 251 rewriter.replaceOp(op, v); 252 return success(); 253 } 254 case StorageSpecifierKind::DimStride: { 255 Value v = Base::onDimStride(rewriter, op, spec, (*op.getLevel())); 256 rewriter.replaceOp(op, v); 257 return success(); 258 } 259 case StorageSpecifierKind::CrdMemSize: 260 case StorageSpecifierKind::PosMemSize: 261 case StorageSpecifierKind::ValMemSize: { 262 auto enc = op.getSpecifier().getType().getEncoding(); 263 StorageLayout layout(enc); 264 std::optional<unsigned> lvl; 265 if (op.getLevel()) 266 lvl = (*op.getLevel()); 267 unsigned idx = 268 layout.getMemRefFieldIndex(toFieldKind(op.getSpecifierKind()), lvl); 269 Value v = Base::onMemSize(rewriter, op, spec, idx); 270 rewriter.replaceOp(op, v); 271 return success(); 272 } 273 } 274 llvm_unreachable("unrecognized specifer kind"); 275 } 276 }; 277 278 struct StorageSpecifierSetOpConverter 279 : public SpecifierGetterSetterOpConverter<StorageSpecifierSetOpConverter, 280 SetStorageSpecifierOp> { 281 using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter; 282 283 static Value onLvlSize(OpBuilder &builder, SetStorageSpecifierOp op, 284 SpecifierStructBuilder &spec, Level lvl) { 285 spec.setLvlSize(builder, op.getLoc(), lvl, op.getValue()); 286 return spec; 287 } 288 289 static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op, 290 SpecifierStructBuilder &spec, Dimension d) { 291 spec.setDimOffset(builder, op.getLoc(), d, op.getValue()); 292 return spec; 293 } 294 295 static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op, 296 SpecifierStructBuilder &spec, Dimension d) { 297 spec.setDimStride(builder, op.getLoc(), d, op.getValue()); 298 return spec; 299 } 300 301 static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op, 302 SpecifierStructBuilder &spec, FieldIndex fidx) { 303 spec.setMemSize(builder, op.getLoc(), fidx, op.getValue()); 304 return spec; 305 } 306 }; 307 308 struct StorageSpecifierGetOpConverter 309 : public SpecifierGetterSetterOpConverter<StorageSpecifierGetOpConverter, 310 GetStorageSpecifierOp> { 311 using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter; 312 313 static Value onLvlSize(OpBuilder &builder, GetStorageSpecifierOp op, 314 SpecifierStructBuilder &spec, Level lvl) { 315 return spec.lvlSize(builder, op.getLoc(), lvl); 316 } 317 318 static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op, 319 const SpecifierStructBuilder &spec, Dimension d) { 320 return spec.dimOffset(builder, op.getLoc(), d); 321 } 322 323 static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op, 324 const SpecifierStructBuilder &spec, Dimension d) { 325 return spec.dimStride(builder, op.getLoc(), d); 326 } 327 328 static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op, 329 SpecifierStructBuilder &spec, FieldIndex fidx) { 330 return spec.memSize(builder, op.getLoc(), fidx); 331 } 332 }; 333 334 struct StorageSpecifierInitOpConverter 335 : public OpConversionPattern<StorageSpecifierInitOp> { 336 public: 337 using OpConversionPattern::OpConversionPattern; 338 LogicalResult 339 matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor, 340 ConversionPatternRewriter &rewriter) const override { 341 Type llvmType = getTypeConverter()->convertType(op.getResult().getType()); 342 rewriter.replaceOp( 343 op, SpecifierStructBuilder::getInitValue( 344 rewriter, op.getLoc(), llvmType, adaptor.getSource())); 345 return success(); 346 } 347 }; 348 349 //===----------------------------------------------------------------------===// 350 // Public method for populating conversion rules. 351 //===----------------------------------------------------------------------===// 352 353 void mlir::populateStorageSpecifierToLLVMPatterns( 354 const TypeConverter &converter, RewritePatternSet &patterns) { 355 patterns.add<StorageSpecifierGetOpConverter, StorageSpecifierSetOpConverter, 356 StorageSpecifierInitOpConverter>(converter, 357 patterns.getContext()); 358 } 359