xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp (revision 3ace685105d3b50bca68328bf0c945af22d70f23)
1 //===- SparseTensorDescriptor.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "SparseTensorDescriptor.h"
10 #include "CodegenUtils.h"
11 
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 
18 using namespace mlir;
19 using namespace sparse_tensor;
20 
21 //===----------------------------------------------------------------------===//
22 // Private helper methods.
23 //===----------------------------------------------------------------------===//
24 
25 /// Constructs a nullable `LevelAttr` from the `std::optional<Level>`.
26 static IntegerAttr optionalLevelAttr(MLIRContext *ctx,
27                                      std::optional<Level> lvl) {
28   return lvl ? IntegerAttr::get(IndexType::get(ctx), lvl.value())
29              : IntegerAttr();
30 }
31 
32 // This is only ever called from `SparseTensorTypeToBufferConverter`,
33 // which is why the first argument is `RankedTensorType` rather than
34 // `SparseTensorType`.
35 static std::optional<LogicalResult>
36 convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
37   const SparseTensorType stt(rtp);
38   if (!stt.hasEncoding())
39     return std::nullopt;
40 
41   foreachFieldAndTypeInSparseTensor(
42       stt,
43       [&fields](Type fieldType, FieldIndex fieldIdx,
44                 SparseTensorFieldKind /*fieldKind*/, Level /*lvl*/,
45                 LevelType /*lt*/) -> bool {
46         assert(fieldIdx == fields.size());
47         fields.push_back(fieldType);
48         return true;
49       });
50   return success();
51 }
52 
53 //===----------------------------------------------------------------------===//
54 // The sparse tensor type converter (defined in Passes.h).
55 //===----------------------------------------------------------------------===//
56 
57 static Value materializeTuple(OpBuilder &builder, RankedTensorType tp,
58                               ValueRange inputs, Location loc) {
59   if (!getSparseTensorEncoding(tp))
60     // Not a sparse tensor.
61     return Value();
62   // Sparsifier knows how to cancel out these casts.
63   return genTuple(builder, loc, tp, inputs);
64 }
65 
66 SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
67   addConversion([](Type type) { return type; });
68   addConversion(convertSparseTensorType);
69 
70   // Required by scf.for 1:N type conversion.
71   addSourceMaterialization(materializeTuple);
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // StorageTensorSpecifier methods.
76 //===----------------------------------------------------------------------===//
77 
78 Value SparseTensorSpecifier::getInitValue(OpBuilder &builder, Location loc,
79                                           SparseTensorType stt) {
80   return builder.create<StorageSpecifierInitOp>(
81       loc, StorageSpecifierType::get(stt.getEncoding()));
82 }
83 
84 Value SparseTensorSpecifier::getSpecifierField(OpBuilder &builder, Location loc,
85                                                StorageSpecifierKind kind,
86                                                std::optional<Level> lvl) {
87   return builder.create<GetStorageSpecifierOp>(
88       loc, specifier, kind, optionalLevelAttr(specifier.getContext(), lvl));
89 }
90 
91 void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc,
92                                               Value v,
93                                               StorageSpecifierKind kind,
94                                               std::optional<Level> lvl) {
95   // TODO: make `v` have type `TypedValue<IndexType>` instead.
96   assert(v.getType().isIndex());
97   specifier = builder.create<SetStorageSpecifierOp>(
98       loc, specifier, kind, optionalLevelAttr(specifier.getContext(), lvl), v);
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // SparseTensorDescriptor methods.
103 //===----------------------------------------------------------------------===//
104 
105 Value sparse_tensor::SparseTensorDescriptor::getCrdMemRefOrView(
106     OpBuilder &builder, Location loc, Level lvl) const {
107   const Level cooStart = rType.getAoSCOOStart();
108   if (lvl < cooStart)
109     return getMemRefField(SparseTensorFieldKind::CrdMemRef, lvl);
110 
111   Value stride = constantIndex(builder, loc, rType.getLvlRank() - cooStart);
112   Value size = getCrdMemSize(builder, loc, cooStart);
113   size = builder.create<arith::DivUIOp>(loc, size, stride);
114   return builder.create<memref::SubViewOp>(
115       loc, getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart),
116       /*offset=*/ValueRange{constantIndex(builder, loc, lvl - cooStart)},
117       /*size=*/ValueRange{size},
118       /*step=*/ValueRange{stride});
119 }
120