1 //===- SparseTensorType.h - Wrapper around RankedTensorType -----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header defines the `SparseTensorType` wrapper class. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_ 14 #define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_ 15 16 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 17 18 namespace mlir { 19 namespace sparse_tensor { 20 21 //===----------------------------------------------------------------------===// 22 /// A wrapper around `RankedTensorType`, which has three goals: 23 /// 24 /// (1) To provide a uniform API for querying aspects of sparse-tensor 25 /// types; in particular, to make the "dimension" vs "level" distinction 26 /// overt (i.e., explicit everywhere). Thus, throughout the sparsifier 27 /// this class should be preferred over using `RankedTensorType` or 28 /// `ShapedType` directly, since the methods of the latter do not make 29 /// the "dimension" vs "level" distinction overt. 30 /// 31 /// (2) To provide a uniform abstraction over both sparse-tensor 32 /// types (i.e., `RankedTensorType` with `SparseTensorEncodingAttr`) 33 /// and dense-tensor types (i.e., `RankedTensorType` without an encoding). 34 /// That is, we want to manipulate dense-tensor types using the same API 35 /// that we use for manipulating sparse-tensor types; both to keep the 36 /// "dimension" vs "level" distinction overt, and to avoid needing to 37 /// handle certain cases specially in the sparsifier. 38 /// 39 /// (3) To provide uniform handling of "defaults". In particular 40 /// this means that dense-tensors should always return the same answers 41 /// as sparse-tensors with a default encoding. But it additionally means 42 /// that the answers should be normalized, so that there's no way to 43 /// distinguish between non-provided data (which is filled in by default) 44 /// vs explicitly-provided data which equals the defaults. 45 /// 46 class SparseTensorType { 47 public: 48 // We memoize `lvlRank`, `dimToLvl`, and `lvlToDim` to avoid repeating 49 // the conditionals throughout the rest of the class. 50 SparseTensorType(RankedTensorType rtp) 51 : rtp(rtp), enc(getSparseTensorEncoding(rtp)), 52 lvlRank(enc ? enc.getLvlRank() : getDimRank()), 53 dimToLvl(enc.isIdentity() ? AffineMap() : enc.getDimToLvl()), 54 lvlToDim(enc.isIdentity() ? AffineMap() : enc.getLvlToDim()) { 55 assert(rtp && "got null RankedTensorType"); 56 assert((!isIdentity() || getDimRank() == lvlRank) && "Rank mismatch"); 57 } 58 59 SparseTensorType(ShapedType stp, SparseTensorEncodingAttr enc) 60 : SparseTensorType( 61 RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {} 62 63 SparseTensorType &operator=(const SparseTensorType &) = delete; 64 SparseTensorType(const SparseTensorType &) = default; 65 66 // 67 // Factory methods to construct a new `SparseTensorType` 68 // with the same dimension-shape and element type. 69 // 70 71 SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const { 72 return SparseTensorType(rtp, newEnc); 73 } 74 75 SparseTensorType withDimToLvl(AffineMap dimToLvl) const { 76 return withEncoding(enc.withDimToLvl(dimToLvl)); 77 } 78 79 SparseTensorType withDimToLvl(SparseTensorEncodingAttr dimToLvlEnc) const { 80 return withEncoding(enc.withDimToLvl(dimToLvlEnc)); 81 } 82 83 SparseTensorType withDimToLvl(const SparseTensorType &dimToLvlSTT) const { 84 return withDimToLvl(dimToLvlSTT.getEncoding()); 85 } 86 87 SparseTensorType withoutDimToLvl() const { 88 return withEncoding(enc.withoutDimToLvl()); 89 } 90 91 SparseTensorType withBitWidths(unsigned posWidth, unsigned crdWidth) const { 92 return withEncoding(enc.withBitWidths(posWidth, crdWidth)); 93 } 94 95 SparseTensorType withoutBitWidths() const { 96 return withEncoding(enc.withoutBitWidths()); 97 } 98 99 SparseTensorType withExplicitVal(Attribute explicitVal) const { 100 return withEncoding(enc.withExplicitVal(explicitVal)); 101 } 102 103 SparseTensorType withoutExplicitVal() const { 104 return withEncoding(enc.withoutExplicitVal()); 105 } 106 107 SparseTensorType withImplicitVal(Attribute implicitVal) const { 108 return withEncoding(enc.withImplicitVal(implicitVal)); 109 } 110 111 SparseTensorType withoutImplicitVal() const { 112 return withEncoding(enc.withoutImplicitVal()); 113 } 114 115 SparseTensorType 116 withDimSlices(ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { 117 return withEncoding(enc.withDimSlices(dimSlices)); 118 } 119 120 SparseTensorType withoutDimSlices() const { 121 return withEncoding(enc.withoutDimSlices()); 122 } 123 124 /// Allow implicit conversion to `RankedTensorType`, `ShapedType`, 125 /// and `Type`. These are implicit to help alleviate the impedance 126 /// mismatch for code that has not been converted to use `SparseTensorType` 127 /// directly. Once more uses have been converted to `SparseTensorType`, 128 /// we may want to make these explicit instead. 129 /// 130 /// WARNING: This user-defined-conversion method causes overload 131 /// ambiguity whenever passing a `SparseTensorType` directly to a 132 /// function which is overloaded to accept either `Type` or `TypeRange`. 133 /// In particular, this includes `RewriterBase::replaceOpWithNewOp<OpTy>` 134 /// and `OpBuilder::create<OpTy>` whenever the `OpTy::build` is overloaded 135 /// thus. This happens because the `TypeRange<T>(T&&)` ctor is implicit 136 /// as well, and there's no SFINAE we can add to this method that would 137 /// block subsequent application of that ctor. The only way to fix the 138 /// overload ambiguity is to avoid *implicit* conversion at the callsite: 139 /// e.g., by using `static_cast` to make the conversion explicit, by 140 /// assigning the `SparseTensorType` to a temporary variable of the 141 /// desired type, etc. 142 // 143 // NOTE: We implement this as a single templated user-defined-conversion 144 // function to avoid ambiguity problems when the desired result is `Type` 145 // (since both `RankedTensorType` and `ShapedType` can be implicitly 146 // converted to `Type`). 147 template <typename T, typename = std::enable_if_t< 148 std::is_convertible_v<RankedTensorType, T>>> 149 /*implicit*/ operator T() const { 150 return rtp; 151 } 152 153 /// Explicitly convert to `RankedTensorType`. This method is 154 /// a convenience for resolving overload-ambiguity issues with 155 /// implicit conversion. 156 RankedTensorType getRankedTensorType() const { return rtp; } 157 158 bool operator==(const SparseTensorType &other) const { 159 // All other fields are derived from `rtp` and therefore don't need 160 // to be checked. 161 return rtp == other.rtp; 162 } 163 164 bool operator!=(const SparseTensorType &other) const { 165 return !(*this == other); 166 } 167 168 MLIRContext *getContext() const { return rtp.getContext(); } 169 170 Type getElementType() const { return rtp.getElementType(); } 171 172 SparseTensorEncodingAttr getEncoding() const { return enc; } 173 174 // 175 // SparseTensorEncodingAttr delegators 176 // 177 178 /// Returns true for tensors which have an encoding, and false for 179 /// those which do not. Therefore tensors with an all-dense encoding 180 /// return true. 181 bool hasEncoding() const { return static_cast<bool>(enc); } 182 183 /// Returns true for tensors where every level is dense. 184 /// (This is always true for dense-tensors.) 185 bool isAllDense() const { return enc.isAllDense(); } 186 187 /// Returns true for tensors where every level is ordered. 188 /// (This is always true for dense-tensors.) 189 bool isAllOrdered() const { return enc.isAllOrdered(); } 190 191 /// Translates between level / dimension coordinate space. 192 ValueRange translateCrds(OpBuilder &builder, Location loc, ValueRange crds, 193 CrdTransDirectionKind dir) const { 194 return enc.translateCrds(builder, loc, crds, dir); 195 } 196 197 /// Returns true if the dimToLvl mapping is a permutation. 198 /// (This is always true for dense-tensors.) 199 bool isPermutation() const { return enc.isPermutation(); } 200 201 /// Returns true if the dimToLvl mapping is the identity. 202 /// (This is always true for dense-tensors.) 203 bool isIdentity() const { return enc.isIdentity(); } 204 205 // 206 // Other methods. 207 // 208 209 /// Returns the dimToLvl mapping (or the null-map for the identity). 210 /// If you intend to compare the results of this method for equality, 211 /// see `hasSameDimToLvl` instead. 212 AffineMap getDimToLvl() const { return dimToLvl; } 213 214 /// Returns the lvlToDiml mapping (or the null-map for the identity). 215 AffineMap getLvlToDim() const { return lvlToDim; } 216 217 /// Returns the dimToLvl mapping, where the identity map is expanded out 218 /// into a full `AffineMap`. This method is provided as a convenience, 219 /// but for most purposes other methods (`isIdentity`, `getDimToLvl`, 220 /// etc) will be more helpful. 221 AffineMap getExpandedDimToLvl() const { 222 return dimToLvl 223 ? dimToLvl 224 : AffineMap::getMultiDimIdentityMap(getDimRank(), getContext()); 225 } 226 227 /// Returns true iff the two types have the same mapping. This method 228 /// takes care to handle identity maps properly, so it should be preferred 229 /// over using `getDimToLvl` followed by `AffineMap::operator==`. 230 bool hasSameDimToLvl(const SparseTensorType &other) const { 231 // If the maps are the identity, then we need to check the rank 232 // to be sure they're the same size identity. (And since identity 233 // means dimRank==lvlRank, we use lvlRank as a minor optimization.) 234 return isIdentity() ? (other.isIdentity() && lvlRank == other.lvlRank) 235 : (dimToLvl == other.dimToLvl); 236 } 237 238 /// Returns the dimension-rank. 239 Dimension getDimRank() const { return rtp.getRank(); } 240 241 /// Returns the level-rank. 242 Level getLvlRank() const { return lvlRank; } 243 244 /// Returns the dimension-shape. 245 ArrayRef<Size> getDimShape() const { return rtp.getShape(); } 246 247 /// Returns the level-shape. 248 SmallVector<Size> getLvlShape() const { 249 return getEncoding().translateShape(getDimShape(), 250 CrdTransDirectionKind::dim2lvl); 251 } 252 253 /// Returns the batched level-rank. 254 unsigned getBatchLvlRank() const { return getEncoding().getBatchLvlRank(); } 255 256 /// Returns the batched level-shape. 257 SmallVector<Size> getBatchLvlShape() const { 258 auto lvlShape = getEncoding().translateShape( 259 getDimShape(), CrdTransDirectionKind::dim2lvl); 260 lvlShape.truncate(getEncoding().getBatchLvlRank()); 261 return lvlShape; 262 } 263 264 /// Returns the type with an identity mapping. 265 RankedTensorType getDemappedType() const { 266 return RankedTensorType::get(getLvlShape(), getElementType(), 267 enc.withoutDimToLvl()); 268 } 269 270 /// Safely looks up the requested dimension-DynSize. If you intend 271 /// to check the result with `ShapedType::isDynamic`, then see the 272 /// `getStaticDimSize` method instead. 273 Size getDynamicDimSize(Dimension d) const { 274 assert(d < getDimRank() && "Dimension is out of bounds"); 275 return getDimShape()[d]; 276 } 277 278 /// Returns true if no dimension has dynamic size. 279 bool hasStaticDimShape() const { return rtp.hasStaticShape(); } 280 281 /// Returns true if any dimension has dynamic size. 282 bool hasDynamicDimShape() const { return !hasStaticDimShape(); } 283 284 /// Returns true if the given dimension has dynamic size. If you 285 /// intend to call `getDynamicDimSize` based on the result, then see 286 /// the `getStaticDimSize` method instead. 287 bool isDynamicDim(Dimension d) const { 288 // We don't use `rtp.isDynamicDim(d)` because we want the 289 // OOB error message to be consistent with `getDynamicDimSize`. 290 return ShapedType::isDynamic(getDynamicDimSize(d)); 291 } 292 293 /// Returns the number of dimensions which have dynamic sizes. 294 /// The return type is `int64_t` to maintain consistency with 295 /// `ShapedType::Trait<T>::getNumDynamicDims`. 296 size_t getNumDynamicDims() const { return rtp.getNumDynamicDims(); } 297 298 ArrayRef<LevelType> getLvlTypes() const { return enc.getLvlTypes(); } 299 LevelType getLvlType(Level l) const { 300 // This OOB check is for dense-tensors, since this class knows 301 // their lvlRank (whereas STEA::getLvlType will/can only check 302 // OOB for sparse-tensors). 303 assert(l < lvlRank && "Level out of bounds"); 304 return enc.getLvlType(l); 305 } 306 307 // We can't just delegate these, since we want to use this class's 308 // `getLvlType` method instead of STEA's. 309 bool isDenseLvl(Level l) const { return isDenseLT(getLvlType(l)); } 310 bool isCompressedLvl(Level l) const { return isCompressedLT(getLvlType(l)); } 311 bool isLooseCompressedLvl(Level l) const { 312 return isLooseCompressedLT(getLvlType(l)); 313 } 314 bool isSingletonLvl(Level l) const { return isSingletonLT(getLvlType(l)); } 315 bool isNOutOfMLvl(Level l) const { return isNOutOfMLT(getLvlType(l)); } 316 bool isOrderedLvl(Level l) const { return isOrderedLT(getLvlType(l)); } 317 bool isUniqueLvl(Level l) const { return isUniqueLT(getLvlType(l)); } 318 bool isWithPos(Level l) const { return isWithPosLT(getLvlType(l)); } 319 bool isWithCrd(Level l) const { return isWithCrdLT(getLvlType(l)); } 320 321 /// Returns the coordinate-overhead bitwidth, defaulting to zero. 322 unsigned getCrdWidth() const { return enc ? enc.getCrdWidth() : 0; } 323 324 /// Returns the position-overhead bitwidth, defaulting to zero. 325 unsigned getPosWidth() const { return enc ? enc.getPosWidth() : 0; } 326 327 /// Returns the explicit value, defaulting to null Attribute for unset. 328 Attribute getExplicitVal() const { 329 return enc ? enc.getExplicitVal() : nullptr; 330 } 331 332 /// Returns the implicit value, defaulting to null Attribute for 0. 333 Attribute getImplicitVal() const { 334 return enc ? enc.getImplicitVal() : nullptr; 335 } 336 337 /// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`. 338 Type getCrdType() const { return enc.getCrdElemType(); } 339 340 /// Returns the position-overhead MLIR type, defaulting to `IndexType`. 341 Type getPosType() const { return enc.getPosElemType(); } 342 343 /// Returns true iff this sparse tensor type has a trailing 344 /// COO region starting at the given level. By default, it 345 /// tests for a unique COO type at top level. 346 bool isCOOType(Level startLvl = 0, bool isUnique = true) const; 347 348 /// Returns the starting level of this sparse tensor type for a 349 /// trailing COO region that spans **at least** two levels. If 350 /// no such COO region is found, then returns the level-rank. 351 /// 352 /// DEPRECATED: use getCOOSegment instead; 353 Level getAoSCOOStart() const { return getEncoding().getAoSCOOStart(); }; 354 355 /// Returns [un]ordered COO type for this sparse tensor type. 356 RankedTensorType getCOOType(bool ordered) const; 357 358 /// Returns a list of COO segments in the sparse tensor types. 359 SmallVector<COOSegment> getCOOSegments() const { 360 return getEncoding().getCOOSegments(); 361 } 362 363 private: 364 // These two must be const, to ensure coherence of the memoized fields. 365 const RankedTensorType rtp; 366 const SparseTensorEncodingAttr enc; 367 // Memoized to avoid frequent redundant conditionals. 368 const Level lvlRank; 369 const AffineMap dimToLvl; 370 const AffineMap lvlToDim; 371 }; 372 373 /// Convenience methods to obtain a SparseTensorType from a Value. 374 inline SparseTensorType getSparseTensorType(Value val) { 375 return SparseTensorType(cast<RankedTensorType>(val.getType())); 376 } 377 inline std::optional<SparseTensorType> tryGetSparseTensorType(Value val) { 378 if (auto rtp = dyn_cast<RankedTensorType>(val.getType())) 379 return SparseTensorType(rtp); 380 return std::nullopt; 381 } 382 383 } // namespace sparse_tensor 384 } // namespace mlir 385 386 #endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_ 387