xref: /llvm-project/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h (revision bfde17834dd9bd30da8f56166cd545f566f64895)
1 //===- SparseTensorType.h - Wrapper around RankedTensorType -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header defines the `SparseTensorType` wrapper class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_
14 #define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_
15 
16 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
17 
18 namespace mlir {
19 namespace sparse_tensor {
20 
21 //===----------------------------------------------------------------------===//
22 /// A wrapper around `RankedTensorType`, which has three goals:
23 ///
24 /// (1) To provide a uniform API for querying aspects of sparse-tensor
25 /// types; in particular, to make the "dimension" vs "level" distinction
26 /// overt (i.e., explicit everywhere).  Thus, throughout the sparsifier
27 /// this class should be preferred over using `RankedTensorType` or
28 /// `ShapedType` directly, since the methods of the latter do not make
29 /// the "dimension" vs "level" distinction overt.
30 ///
31 /// (2) To provide a uniform abstraction over both sparse-tensor
32 /// types (i.e., `RankedTensorType` with `SparseTensorEncodingAttr`)
33 /// and dense-tensor types (i.e., `RankedTensorType` without an encoding).
34 /// That is, we want to manipulate dense-tensor types using the same API
35 /// that we use for manipulating sparse-tensor types; both to keep the
36 /// "dimension" vs "level" distinction overt, and to avoid needing to
37 /// handle certain cases specially in the sparsifier.
38 ///
39 /// (3) To provide uniform handling of "defaults".  In particular
40 /// this means that dense-tensors should always return the same answers
41 /// as sparse-tensors with a default encoding.  But it additionally means
42 /// that the answers should be normalized, so that there's no way to
43 /// distinguish between non-provided data (which is filled in by default)
44 /// vs explicitly-provided data which equals the defaults.
45 ///
class SparseTensorType {
public:
  // We memoize `lvlRank`, `dimToLvl`, and `lvlToDim` to avoid repeating
  // the conditionals throughout the rest of the class.
  //
  // NOTE(review): for dense tensors `enc` is the null attribute, yet the
  // member-initializers below still invoke `enc.isIdentity()` on it; this
  // relies on `SparseTensorEncodingAttr` methods tolerating a null
  // attribute — confirm against the attribute definition in
  // SparseTensor.h.
  SparseTensorType(RankedTensorType rtp)
      : rtp(rtp), enc(getSparseTensorEncoding(rtp)),
        lvlRank(enc ? enc.getLvlRank() : getDimRank()),
        dimToLvl(enc.isIdentity() ? AffineMap() : enc.getDimToLvl()),
        lvlToDim(enc.isIdentity() ? AffineMap() : enc.getLvlToDim()) {
    assert(rtp && "got null RankedTensorType");
    // The identity mapping implies dimension-rank equals level-rank.
    assert((!isIdentity() || getDimRank() == lvlRank) && "Rank mismatch");
  }

  /// Constructs a `SparseTensorType` with the shape and element type of
  /// `stp`, using the given encoding `enc` (which replaces whatever
  /// encoding `stp` itself may have carried).
  SparseTensorType(ShapedType stp, SparseTensorEncodingAttr enc)
      : SparseTensorType(
            RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {}

  // Copy-construction is allowed, but copy-assignment is deleted: all
  // fields are `const`, to keep the memoized values coherent with `rtp`.
  SparseTensorType &operator=(const SparseTensorType &) = delete;
  SparseTensorType(const SparseTensorType &) = default;

  //
  // Factory methods to construct a new `SparseTensorType`
  // with the same dimension-shape and element type.
  //

  /// Returns a copy of this type with the encoding replaced by `newEnc`.
  SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const {
    return SparseTensorType(rtp, newEnc);
  }

  /// Returns a copy of this type whose encoding uses the given
  /// dimToLvl mapping.
  SparseTensorType withDimToLvl(AffineMap dimToLvl) const {
    return withEncoding(enc.withDimToLvl(dimToLvl));
  }

  /// Returns a copy of this type whose encoding uses the dimToLvl
  /// mapping of `dimToLvlEnc`.
  SparseTensorType withDimToLvl(SparseTensorEncodingAttr dimToLvlEnc) const {
    return withEncoding(enc.withDimToLvl(dimToLvlEnc));
  }

  /// Returns a copy of this type whose encoding uses the dimToLvl
  /// mapping of `dimToLvlSTT`.
  SparseTensorType withDimToLvl(const SparseTensorType &dimToLvlSTT) const {
    return withDimToLvl(dimToLvlSTT.getEncoding());
  }

  /// Returns a copy of this type with the dimToLvl mapping removed
  /// from the encoding.
  SparseTensorType withoutDimToLvl() const {
    return withEncoding(enc.withoutDimToLvl());
  }

  /// Returns a copy of this type with the given position/coordinate
  /// overhead bitwidths.
  SparseTensorType withBitWidths(unsigned posWidth, unsigned crdWidth) const {
    return withEncoding(enc.withBitWidths(posWidth, crdWidth));
  }

  /// Returns a copy of this type with the overhead bitwidths removed
  /// from the encoding.
  SparseTensorType withoutBitWidths() const {
    return withEncoding(enc.withoutBitWidths());
  }

  /// Returns a copy of this type with the given explicit value.
  SparseTensorType withExplicitVal(Attribute explicitVal) const {
    return withEncoding(enc.withExplicitVal(explicitVal));
  }

  /// Returns a copy of this type with the explicit value removed
  /// from the encoding.
  SparseTensorType withoutExplicitVal() const {
    return withEncoding(enc.withoutExplicitVal());
  }

  /// Returns a copy of this type with the given implicit value.
  SparseTensorType withImplicitVal(Attribute implicitVal) const {
    return withEncoding(enc.withImplicitVal(implicitVal));
  }

  /// Returns a copy of this type with the implicit value removed
  /// from the encoding.
  SparseTensorType withoutImplicitVal() const {
    return withEncoding(enc.withoutImplicitVal());
  }

  /// Returns a copy of this type with the given dimension slices.
  SparseTensorType
  withDimSlices(ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
    return withEncoding(enc.withDimSlices(dimSlices));
  }

  /// Returns a copy of this type with the dimension slices removed
  /// from the encoding.
  SparseTensorType withoutDimSlices() const {
    return withEncoding(enc.withoutDimSlices());
  }

  /// Allow implicit conversion to `RankedTensorType`, `ShapedType`,
  /// and `Type`.  These are implicit to help alleviate the impedance
  /// mismatch for code that has not been converted to use `SparseTensorType`
  /// directly.  Once more uses have been converted to `SparseTensorType`,
  /// we may want to make these explicit instead.
  ///
  /// WARNING: This user-defined-conversion method causes overload
  /// ambiguity whenever passing a `SparseTensorType` directly to a
  /// function which is overloaded to accept either `Type` or `TypeRange`.
  /// In particular, this includes `RewriterBase::replaceOpWithNewOp<OpTy>`
  /// and `OpBuilder::create<OpTy>` whenever the `OpTy::build` is overloaded
  /// thus.  This happens because the `TypeRange<T>(T&&)` ctor is implicit
  /// as well, and there's no SFINAE we can add to this method that would
  /// block subsequent application of that ctor.  The only way to fix the
  /// overload ambiguity is to avoid *implicit* conversion at the callsite:
  /// e.g., by using `static_cast` to make the conversion explicit, by
  /// assigning the `SparseTensorType` to a temporary variable of the
  /// desired type, etc.
  //
  // NOTE: We implement this as a single templated user-defined-conversion
  // function to avoid ambiguity problems when the desired result is `Type`
  // (since both `RankedTensorType` and `ShapedType` can be implicitly
  // converted to `Type`).
  template <typename T, typename = std::enable_if_t<
                            std::is_convertible_v<RankedTensorType, T>>>
  /*implicit*/ operator T() const {
    return rtp;
  }

  /// Explicitly convert to `RankedTensorType`.  This method is
  /// a convenience for resolving overload-ambiguity issues with
  /// implicit conversion.
  RankedTensorType getRankedTensorType() const { return rtp; }

  bool operator==(const SparseTensorType &other) const {
    // All other fields are derived from `rtp` and therefore don't need
    // to be checked.
    return rtp == other.rtp;
  }

  bool operator!=(const SparseTensorType &other) const {
    return !(*this == other);
  }

  /// Returns the context of the underlying `RankedTensorType`.
  MLIRContext *getContext() const { return rtp.getContext(); }

  /// Returns the element type of the underlying `RankedTensorType`.
  Type getElementType() const { return rtp.getElementType(); }

  /// Returns the encoding (null for dense tensors without one).
  SparseTensorEncodingAttr getEncoding() const { return enc; }

  //
  // SparseTensorEncodingAttr delegators
  //

  /// Returns true for tensors which have an encoding, and false for
  /// those which do not.  Therefore tensors with an all-dense encoding
  /// return true.
  bool hasEncoding() const { return static_cast<bool>(enc); }

  /// Returns true for tensors where every level is dense.
  /// (This is always true for dense-tensors.)
  bool isAllDense() const { return enc.isAllDense(); }

  /// Returns true for tensors where every level is ordered.
  /// (This is always true for dense-tensors.)
  bool isAllOrdered() const { return enc.isAllOrdered(); }

  /// Translates between level / dimension coordinate space.
  ValueRange translateCrds(OpBuilder &builder, Location loc, ValueRange crds,
                           CrdTransDirectionKind dir) const {
    return enc.translateCrds(builder, loc, crds, dir);
  }

  /// Returns true if the dimToLvl mapping is a permutation.
  /// (This is always true for dense-tensors.)
  bool isPermutation() const { return enc.isPermutation(); }

  /// Returns true if the dimToLvl mapping is the identity.
  /// (This is always true for dense-tensors.)
  bool isIdentity() const { return enc.isIdentity(); }

  //
  // Other methods.
  //

  /// Returns the dimToLvl mapping (or the null-map for the identity).
  /// If you intend to compare the results of this method for equality,
  /// see `hasSameDimToLvl` instead.
  AffineMap getDimToLvl() const { return dimToLvl; }

  /// Returns the lvlToDim mapping (or the null-map for the identity).
  AffineMap getLvlToDim() const { return lvlToDim; }

  /// Returns the dimToLvl mapping, where the identity map is expanded out
  /// into a full `AffineMap`.  This method is provided as a convenience,
  /// but for most purposes other methods (`isIdentity`, `getDimToLvl`,
  /// etc) will be more helpful.
  AffineMap getExpandedDimToLvl() const {
    return dimToLvl
               ? dimToLvl
               : AffineMap::getMultiDimIdentityMap(getDimRank(), getContext());
  }

  /// Returns true iff the two types have the same mapping.  This method
  /// takes care to handle identity maps properly, so it should be preferred
  /// over using `getDimToLvl` followed by `AffineMap::operator==`.
  bool hasSameDimToLvl(const SparseTensorType &other) const {
    // If the maps are the identity, then we need to check the rank
    // to be sure they're the same size identity.  (And since identity
    // means dimRank==lvlRank, we use lvlRank as a minor optimization.)
    return isIdentity() ? (other.isIdentity() && lvlRank == other.lvlRank)
                        : (dimToLvl == other.dimToLvl);
  }

  /// Returns the dimension-rank.
  Dimension getDimRank() const { return rtp.getRank(); }

  /// Returns the level-rank.
  Level getLvlRank() const { return lvlRank; }

  /// Returns the dimension-shape.
  ArrayRef<Size> getDimShape() const { return rtp.getShape(); }

  /// Returns the level-shape (the dimension-shape translated through
  /// the dimToLvl mapping of the encoding).
  SmallVector<Size> getLvlShape() const {
    return getEncoding().translateShape(getDimShape(),
                                        CrdTransDirectionKind::dim2lvl);
  }

  /// Returns the batched level-rank.
  unsigned getBatchLvlRank() const { return getEncoding().getBatchLvlRank(); }

  /// Returns the batched level-shape (the level-shape truncated to the
  /// leading batch levels).
  SmallVector<Size> getBatchLvlShape() const {
    auto lvlShape = getEncoding().translateShape(
        getDimShape(), CrdTransDirectionKind::dim2lvl);
    lvlShape.truncate(getEncoding().getBatchLvlRank());
    return lvlShape;
  }

  /// Returns the type with an identity mapping.
  RankedTensorType getDemappedType() const {
    return RankedTensorType::get(getLvlShape(), getElementType(),
                                 enc.withoutDimToLvl());
  }

  /// Safely looks up the requested dimension-DynSize.  If you intend
  /// to check the result with `ShapedType::isDynamic`, then see the
  /// `getStaticDimSize` method instead.
  Size getDynamicDimSize(Dimension d) const {
    assert(d < getDimRank() && "Dimension is out of bounds");
    return getDimShape()[d];
  }

  /// Returns true if no dimension has dynamic size.
  bool hasStaticDimShape() const { return rtp.hasStaticShape(); }

  /// Returns true if any dimension has dynamic size.
  bool hasDynamicDimShape() const { return !hasStaticDimShape(); }

  /// Returns true if the given dimension has dynamic size.  If you
  /// intend to call `getDynamicDimSize` based on the result, then see
  /// the `getStaticDimSize` method instead.
  bool isDynamicDim(Dimension d) const {
    // We don't use `rtp.isDynamicDim(d)` because we want the
    // OOB error message to be consistent with `getDynamicDimSize`.
    return ShapedType::isDynamic(getDynamicDimSize(d));
  }

  /// Returns the number of dimensions which have dynamic sizes.
  /// NOTE(review): this delegates to
  /// `ShapedType::Trait<T>::getNumDynamicDims` but is declared to
  /// return `size_t` rather than the delegate's `int64_t`; the value
  /// is implicitly converted here.
  size_t getNumDynamicDims() const { return rtp.getNumDynamicDims(); }

  /// Returns the per-level types of the encoding.
  ArrayRef<LevelType> getLvlTypes() const { return enc.getLvlTypes(); }
  /// Returns the level-type of the given level.
  LevelType getLvlType(Level l) const {
    // This OOB check is for dense-tensors, since this class knows
    // their lvlRank (whereas STEA::getLvlType will/can only check
    // OOB for sparse-tensors).
    assert(l < lvlRank && "Level out of bounds");
    return enc.getLvlType(l);
  }

  // We can't just delegate these, since we want to use this class's
  // `getLvlType` method instead of STEA's.
  bool isDenseLvl(Level l) const { return isDenseLT(getLvlType(l)); }
  bool isCompressedLvl(Level l) const { return isCompressedLT(getLvlType(l)); }
  bool isLooseCompressedLvl(Level l) const {
    return isLooseCompressedLT(getLvlType(l));
  }
  bool isSingletonLvl(Level l) const { return isSingletonLT(getLvlType(l)); }
  bool isNOutOfMLvl(Level l) const { return isNOutOfMLT(getLvlType(l)); }
  bool isOrderedLvl(Level l) const { return isOrderedLT(getLvlType(l)); }
  bool isUniqueLvl(Level l) const { return isUniqueLT(getLvlType(l)); }
  bool isWithPos(Level l) const { return isWithPosLT(getLvlType(l)); }
  bool isWithCrd(Level l) const { return isWithCrdLT(getLvlType(l)); }

  /// Returns the coordinate-overhead bitwidth, defaulting to zero.
  unsigned getCrdWidth() const { return enc ? enc.getCrdWidth() : 0; }

  /// Returns the position-overhead bitwidth, defaulting to zero.
  unsigned getPosWidth() const { return enc ? enc.getPosWidth() : 0; }

  /// Returns the explicit value, defaulting to null Attribute for unset.
  Attribute getExplicitVal() const {
    return enc ? enc.getExplicitVal() : nullptr;
  }

  /// Returns the implicit value, defaulting to null Attribute for 0.
  Attribute getImplicitVal() const {
    return enc ? enc.getImplicitVal() : nullptr;
  }

  /// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`.
  Type getCrdType() const { return enc.getCrdElemType(); }

  /// Returns the position-overhead MLIR type, defaulting to `IndexType`.
  Type getPosType() const { return enc.getPosElemType(); }

  /// Returns true iff this sparse tensor type has a trailing
  /// COO region starting at the given level. By default, it
  /// tests for a unique COO type at top level.
  bool isCOOType(Level startLvl = 0, bool isUnique = true) const;

  /// Returns the starting level of this sparse tensor type for a
  /// trailing COO region that spans **at least** two levels. If
  /// no such COO region is found, then returns the level-rank.
  ///
  /// DEPRECATED: use getCOOSegment instead;
  Level getAoSCOOStart() const { return getEncoding().getAoSCOOStart(); };

  /// Returns [un]ordered COO type for this sparse tensor type.
  RankedTensorType getCOOType(bool ordered) const;

  /// Returns a list of COO segments in the sparse tensor types.
  SmallVector<COOSegment> getCOOSegments() const {
    return getEncoding().getCOOSegments();
  }

private:
  // These two must be const, to ensure coherence of the memoized fields.
  const RankedTensorType rtp;
  const SparseTensorEncodingAttr enc;
  // Memoized to avoid frequent redundant conditionals.
  const Level lvlRank;
  const AffineMap dimToLvl;
  const AffineMap lvlToDim;
};
372 
373 /// Convenience methods to obtain a SparseTensorType from a Value.
374 inline SparseTensorType getSparseTensorType(Value val) {
375   return SparseTensorType(cast<RankedTensorType>(val.getType()));
376 }
377 inline std::optional<SparseTensorType> tryGetSparseTensorType(Value val) {
378   if (auto rtp = dyn_cast<RankedTensorType>(val.getType()))
379     return SparseTensorType(rtp);
380   return std::nullopt;
381 }
382 
383 } // namespace sparse_tensor
384 } // namespace mlir
385 
386 #endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_
387