xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h (revision 9df63b2651b2435c02a7d825953ca2ddc65c778e)
1 //===- SparseTensorDescriptor.h ---------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities for the sparse memory layout.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_
14 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_
15 
16 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
17 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
18 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
19 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
20 
21 namespace mlir {
22 namespace sparse_tensor {
23 
24 class SparseTensorSpecifier {
25 public:
26   explicit SparseTensorSpecifier(Value specifier)
27       : specifier(cast<TypedValue<StorageSpecifierType>>(specifier)) {}
28 
29   // Undef value for level-sizes, all zero values for memory-sizes.
30   static Value getInitValue(OpBuilder &builder, Location loc,
31                             SparseTensorType stt);
32 
33   /*implicit*/ operator Value() { return specifier; }
34 
35   Value getSpecifierField(OpBuilder &builder, Location loc,
36                           StorageSpecifierKind kind, std::optional<Level> lvl);
37 
38   void setSpecifierField(OpBuilder &builder, Location loc, Value v,
39                          StorageSpecifierKind kind, std::optional<Level> lvl);
40 
41 private:
42   TypedValue<StorageSpecifierType> specifier;
43 };
44 
45 /// A helper class around an array of values that corresponds to a sparse
46 /// tensor. This class provides a set of meaningful APIs to query and update
47 /// a particular field in a consistent way. Users should not make assumptions
48 /// on how a sparse tensor is laid out but instead rely on this class to access
49 /// the right value for the right field.
50 template <typename ValueArrayRef>
51 class SparseTensorDescriptorImpl {
52 protected:
53   SparseTensorDescriptorImpl(SparseTensorType stt, ValueArrayRef fields)
54       : rType(stt), fields(fields), layout(stt) {
55     assert(layout.getNumFields() == getNumFields());
56     // We should make sure the class is trivially copyable (and should be small
57     // enough) such that we can pass it by value.
58     static_assert(std::is_trivially_copyable_v<
59                   SparseTensorDescriptorImpl<ValueArrayRef>>);
60   }
61 
62 public:
63   FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind,
64                                  std::optional<Level> lvl) const {
65     // Delegates to storage layout.
66     return layout.getMemRefFieldIndex(kind, lvl);
67   }
68 
69   unsigned getNumFields() const { return fields.size(); }
70 
71   ///
72   /// Getters: get the value for required field.
73   ///
74 
75   Value getSpecifier() const { return fields.back(); }
76 
77   Value getSpecifierField(OpBuilder &builder, Location loc,
78                           StorageSpecifierKind kind,
79                           std::optional<Level> lvl) const {
80     SparseTensorSpecifier md(fields.back());
81     return md.getSpecifierField(builder, loc, kind, lvl);
82   }
83 
84   Value getLvlSize(OpBuilder &builder, Location loc, Level lvl) const {
85     return getSpecifierField(builder, loc, StorageSpecifierKind::LvlSize, lvl);
86   }
87 
88   Value getPosMemRef(Level lvl) const {
89     return getMemRefField(SparseTensorFieldKind::PosMemRef, lvl);
90   }
91 
92   Value getValMemRef() const {
93     return getMemRefField(SparseTensorFieldKind::ValMemRef, std::nullopt);
94   }
95 
96   Value getMemRefField(SparseTensorFieldKind kind,
97                        std::optional<Level> lvl) const {
98     return getField(getMemRefFieldIndex(kind, lvl));
99   }
100 
101   Value getMemRefField(FieldIndex fidx) const {
102     assert(fidx < fields.size() - 1);
103     return getField(fidx);
104   }
105 
106   Value getPosMemSize(OpBuilder &builder, Location loc, Level lvl) const {
107     return getSpecifierField(builder, loc, StorageSpecifierKind::PosMemSize,
108                              lvl);
109   }
110 
111   Value getCrdMemSize(OpBuilder &builder, Location loc, Level lvl) const {
112     return getSpecifierField(builder, loc, StorageSpecifierKind::CrdMemSize,
113                              lvl);
114   }
115 
116   Value getValMemSize(OpBuilder &builder, Location loc) const {
117     return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
118                              std::nullopt);
119   }
120 
121   Type getMemRefElementType(SparseTensorFieldKind kind,
122                             std::optional<Level> lvl) const {
123     return getMemRefType(getMemRefField(kind, lvl)).getElementType();
124   }
125 
126   Value getField(FieldIndex fidx) const {
127     assert(fidx < fields.size());
128     return fields[fidx];
129   }
130 
131   ValueRange getMemRefFields() const {
132     return fields.drop_back(); // drop the last metadata fields
133   }
134 
135   std::pair<FieldIndex, unsigned> getCrdMemRefIndexAndStride(Level lvl) const {
136     return layout.getFieldIndexAndStride(SparseTensorFieldKind::CrdMemRef, lvl);
137   }
138 
139   Value getAOSMemRef() const {
140     const Level cooStart = rType.getAoSCOOStart();
141     assert(cooStart < rType.getLvlRank());
142     return getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart);
143   }
144 
145   RankedTensorType getRankedTensorType() const { return rType; }
146   ValueArrayRef getFields() const { return fields; }
147   StorageLayout getLayout() const { return layout; }
148 
149 protected:
150   SparseTensorType rType;
151   ValueArrayRef fields;
152   StorageLayout layout;
153 };
154 
155 /// Uses ValueRange for immutable descriptors.
156 class SparseTensorDescriptor : public SparseTensorDescriptorImpl<ValueRange> {
157 public:
158   SparseTensorDescriptor(SparseTensorType stt, ValueRange buffers)
159       : SparseTensorDescriptorImpl<ValueRange>(stt, buffers) {}
160 
161   Value getCrdMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const;
162 };
163 
164 /// Using SmallVector for mutable descriptor allows users to reuse it as a
165 /// tmp buffers to append value for some special cases, though users should
166 /// be responsible to restore the buffer to legal states after their use. It
167 /// is probably not a clean way, but it is the most efficient way to avoid
168 /// copying the fields into another SmallVector. If a more clear way is
169 /// wanted, we should change it to MutableArrayRef instead.
170 class MutSparseTensorDescriptor
171     : public SparseTensorDescriptorImpl<SmallVectorImpl<Value> &> {
172 public:
173   MutSparseTensorDescriptor(SparseTensorType stt,
174                             SmallVectorImpl<Value> &buffers)
175       : SparseTensorDescriptorImpl<SmallVectorImpl<Value> &>(stt, buffers) {}
176 
177   // Allow implicit type conversion from mutable descriptors to immutable ones
178   // (but not vice versa).
179   /*implicit*/ operator SparseTensorDescriptor() const {
180     return SparseTensorDescriptor(rType, fields);
181   }
182 
183   ///
184   /// Adds additional setters for mutable descriptor, update the value for
185   /// required field.
186   ///
187 
188   void setMemRefField(SparseTensorFieldKind kind, std::optional<Level> lvl,
189                       Value v) {
190     fields[getMemRefFieldIndex(kind, lvl)] = v;
191   }
192 
193   void setMemRefField(FieldIndex fidx, Value v) {
194     assert(fidx < fields.size() - 1);
195     fields[fidx] = v;
196   }
197 
198   void setField(FieldIndex fidx, Value v) {
199     assert(fidx < fields.size());
200     fields[fidx] = v;
201   }
202 
203   void setSpecifier(Value newSpec) { fields.back() = newSpec; }
204 
205   void setSpecifierField(OpBuilder &builder, Location loc,
206                          StorageSpecifierKind kind, std::optional<Level> lvl,
207                          Value v) {
208     SparseTensorSpecifier md(fields.back());
209     md.setSpecifierField(builder, loc, v, kind, lvl);
210     fields.back() = md;
211   }
212 
213   void setValMemSize(OpBuilder &builder, Location loc, Value v) {
214     setSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
215                       std::nullopt, v);
216   }
217 
218   void setCrdMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
219     setSpecifierField(builder, loc, StorageSpecifierKind::CrdMemSize, lvl, v);
220   }
221 
222   void setPosMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
223     setSpecifierField(builder, loc, StorageSpecifierKind::PosMemSize, lvl, v);
224   }
225 
226   void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
227     setSpecifierField(builder, loc, StorageSpecifierKind::LvlSize, lvl, v);
228   }
229 };
230 
231 /// Packs the given values as a "tuple" value.
232 inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
233                       ValueRange values) {
234   return builder.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values)
235       .getResult(0);
236 }
237 
238 inline Value genTuple(OpBuilder &builder, Location loc,
239                       SparseTensorDescriptor desc) {
240   return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields());
241 }
242 
243 inline SparseTensorDescriptor
244 getDescriptorFromTensorTuple(ValueRange adaptorValues, RankedTensorType type) {
245   return SparseTensorDescriptor(SparseTensorType(type), adaptorValues);
246 }
247 
248 inline MutSparseTensorDescriptor
249 getMutDescriptorFromTensorTuple(ValueRange adaptorValues,
250                                 SmallVectorImpl<Value> &fields,
251                                 RankedTensorType type) {
252   fields.assign(adaptorValues.begin(), adaptorValues.end());
253   return MutSparseTensorDescriptor(SparseTensorType(type), fields);
254 }
255 
256 } // namespace sparse_tensor
257 } // namespace mlir
258 
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_
260