1 //===- Enums.h - Enums for the SparseTensor dialect -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Typedefs and enums shared between MLIR code for manipulating the
10 // IR, and the lightweight runtime support library for sparse tensor
11 // manipulations. That is, all the enums are used to define the API
12 // of the runtime library and hence are also needed when generating
13 // calls into the runtime library. Moveover, the `LevelType` enum
14 // is also used as the internal IR encoding of dimension level types,
15 // to avoid code duplication (e.g., for the predicates).
16 //
17 // This file also defines x-macros <https://en.wikipedia.org/wiki/X_Macro>
18 // so that we can generate variations of the public functions for each
19 // supported primary- and/or overhead-type.
20 //
21 // Because this file defines a library which is a dependency of the
22 // runtime library itself, this file must not depend on any MLIR internals
23 // (e.g., operators, attributes, ArrayRefs, etc) lest the runtime library
24 // inherit those dependencies.
25 //
26 //===----------------------------------------------------------------------===//
27
28 #ifndef MLIR_DIALECT_SPARSETENSOR_IR_ENUMS_H
29 #define MLIR_DIALECT_SPARSETENSOR_IR_ENUMS_H
30
31 // NOTE: Client code will need to include "mlir/ExecutionEngine/Float16bits.h"
32 // if they want to use the `MLIR_SPARSETENSOR_FOREVERY_V` macro.
33
34 #include <cassert>
35 #include <cinttypes>
36 #include <complex>
37 #include <optional>
38 #include <vector>
39
40 namespace mlir {
41 namespace sparse_tensor {
42
/// This type is used in the public API at all places where MLIR expects
/// values with the built-in type "index". For now, we simply assume that
/// type is 64-bit, but targets with different "index" bitwidths should
/// link with an alternatively built runtime support library.
using index_type = uint64_t;

/// Encoding of overhead types (both position overhead and coordinate
/// overhead), for "overloading" @newSparseTensor.
enum class OverheadType : uint32_t {
  kIndex = 0,
  kU64 = 1,
  kU32 = 2,
  kU16 = 3,
  kU8 = 4
};

// This x-macro calls its argument on every fixed-width overhead type.
// It excludes `index_type`, since that type is often handled specially
// (e.g., by translating it into the architecture-dependent equivalent
// fixed-width overhead type).
#define MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DO)                                 \
  DO(64, uint64_t)                                                             \
  DO(32, uint32_t)                                                             \
  DO(16, uint16_t)                                                             \
  DO(8, uint8_t)

// This x-macro calls its argument on every overhead type, with
// `index_type` included.
#define MLIR_SPARSETENSOR_FOREVERY_O(DO)                                       \
  MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DO)                                       \
  DO(0, index_type)

// These aliases are not just shorthands: they pin down the particular
// implementation used (e.g., as opposed to C99's `complex double`,
// or MLIR's `ComplexType`).
using complex64 = std::complex<double>;
using complex32 = std::complex<float>;
80
/// Encoding of the elemental type, for "overloading" @newSparseTensor.
enum class PrimaryType : uint32_t {
  kF64 = 1,
  kF32 = 2,
  kF16 = 3,
  kBF16 = 4,
  kI64 = 5,
  kI32 = 6,
  kI16 = 7,
  kI8 = 8,
  kC64 = 9,
  kC32 = 10
};

// This x-macro includes all `V` types.
#define MLIR_SPARSETENSOR_FOREVERY_V(DO)                                       \
  DO(F64, double)                                                              \
  DO(F32, float)                                                               \
  DO(F16, f16)                                                                 \
  DO(BF16, bf16)                                                               \
  DO(I64, int64_t)                                                             \
  DO(I32, int32_t)                                                             \
  DO(I16, int16_t)                                                             \
  DO(I8, int8_t)                                                               \
  DO(C64, complex64)                                                           \
  DO(C32, complex32)

// This x-macro includes all `V` types and supports variadic arguments.
#define MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, ...)                              \
  DO(F64, double, __VA_ARGS__)                                                 \
  DO(F32, float, __VA_ARGS__)                                                  \
  DO(F16, f16, __VA_ARGS__)                                                    \
  DO(BF16, bf16, __VA_ARGS__)                                                  \
  DO(I64, int64_t, __VA_ARGS__)                                                \
  DO(I32, int32_t, __VA_ARGS__)                                                \
  DO(I16, int16_t, __VA_ARGS__)                                                \
  DO(I8, int8_t, __VA_ARGS__)                                                  \
  DO(C64, complex64, __VA_ARGS__)                                              \
  DO(C32, complex32, __VA_ARGS__)

// This x-macro calls its argument on every pair of overhead and `V` types.
#define MLIR_SPARSETENSOR_FOREVERY_V_O(DO)                                     \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 64, uint64_t)                           \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 32, uint32_t)                           \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 16, uint16_t)                           \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 8, uint8_t)                             \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 0, index_type)

// The following range predicates rely on the enumerators above being
// assigned in ascending groups: floats [kF64, kBF16], integers
// [kI64, kI8], and complex [kC64, kC32].

/// Returns true iff the value type is a floating-point type.
constexpr bool isFloatingPrimaryType(PrimaryType valTy) {
  const auto enc = static_cast<uint32_t>(valTy);
  return static_cast<uint32_t>(PrimaryType::kF64) <= enc &&
         enc <= static_cast<uint32_t>(PrimaryType::kBF16);
}

/// Returns true iff the value type is an integral (non-complex) type.
constexpr bool isIntegralPrimaryType(PrimaryType valTy) {
  const auto enc = static_cast<uint32_t>(valTy);
  return static_cast<uint32_t>(PrimaryType::kI64) <= enc &&
         enc <= static_cast<uint32_t>(PrimaryType::kI8);
}

/// Returns true iff the value type is real (floating-point or integral,
/// i.e., not complex).
constexpr bool isRealPrimaryType(PrimaryType valTy) {
  const auto enc = static_cast<uint32_t>(valTy);
  return static_cast<uint32_t>(PrimaryType::kF64) <= enc &&
         enc <= static_cast<uint32_t>(PrimaryType::kI8);
}

/// Returns true iff the value type is a complex type.
constexpr bool isComplexPrimaryType(PrimaryType valTy) {
  const auto enc = static_cast<uint32_t>(valTy);
  return static_cast<uint32_t>(PrimaryType::kC64) <= enc &&
         enc <= static_cast<uint32_t>(PrimaryType::kC32);
}
144
/// The actions performed by @newSparseTensor.
enum class Action : uint32_t {
  kEmpty = 0,
  kFromReader = 1,
  kPack = 2,
  kSortCOOInPlace = 3,
};
152
/// This enum defines all supported storage formats, without the level
/// properties.
enum class LevelFormat : uint64_t {
  Undef = 0x00000000,
  Dense = 0x00010000,
  Batch = 0x00020000,
  Compressed = 0x00040000,
  Singleton = 0x00080000,
  LooseCompressed = 0x00100000,
  NOutOfM = 0x00200000,
};

/// Returns true iff the format's encoding has at most one bit set.
/// (Note that `Undef`, whose encoding is zero, also passes this check.)
constexpr bool encPowOfTwo(LevelFormat fmt) {
  const auto enc = static_cast<std::underlying_type_t<LevelFormat>>(fmt);
  return (enc & (enc - 1)) == 0;
}

// All LevelFormats must have only one bit set (power of two), so the
// formats occupy disjoint bit patterns within the encoding.
static_assert(encPowOfTwo(LevelFormat::Dense) &&
              encPowOfTwo(LevelFormat::Batch) &&
              encPowOfTwo(LevelFormat::Compressed) &&
              encPowOfTwo(LevelFormat::Singleton) &&
              encPowOfTwo(LevelFormat::LooseCompressed) &&
              encPowOfTwo(LevelFormat::NOutOfM));

/// Returns true iff `fmt` equals one of the `targets`.
template <LevelFormat... targets>
constexpr bool isAnyOfFmt(LevelFormat fmt) {
  return (... || (targets == fmt));
}

/// Returns the string representation of the given level format.
constexpr const char *toFormatString(LevelFormat lvlFmt) {
  switch (lvlFmt) {
  case LevelFormat::Undef:
    return "undef";
  case LevelFormat::Dense:
    return "dense";
  case LevelFormat::Batch:
    return "batch";
  case LevelFormat::Compressed:
    return "compressed";
  case LevelFormat::Singleton:
    return "singleton";
  case LevelFormat::LooseCompressed:
    return "loose_compressed";
  case LevelFormat::NOutOfM:
    return "structured";
  }
  return "";
}
202
/// This enum defines all the nondefault properties for storage formats.
enum class LevelPropNonDefault : uint64_t {
  Nonunique = 0x0001,  // 0b001
  Nonordered = 0x0002, // 0b010
  SoA = 0x0004,        // 0b100
};

/// Returns the string representation of the given level property.
constexpr const char *toPropString(LevelPropNonDefault lvlProp) {
  switch (lvlProp) {
  case LevelPropNonDefault::Nonunique:
    return "nonunique";
  case LevelPropNonDefault::Nonordered:
    return "nonordered";
  case LevelPropNonDefault::SoA:
    return "soa";
  }
  return "";
}
222
223 /// This enum defines all the sparse representations supportable by
224 /// the SparseTensor dialect. We use a lightweight encoding to encode
225 /// the "format" per se (dense, compressed, singleton, loose_compressed,
226 /// n-out-of-m), the "properties" (ordered, unique) as well as n and m when
227 /// the format is NOutOfM.
228 /// The encoding is chosen for performance of the runtime library, and thus may
229 /// change in future versions; consequently, client code should use the
230 /// predicate functions defined below, rather than relying on knowledge
231 /// about the particular binary encoding.
232 ///
233 /// The `Undef` "format" is a special value used internally for cases
234 /// where we need to store an undefined or indeterminate `LevelType`.
235 /// It should not be used externally, since it does not indicate an
236 /// actual/representable format.
237
238 struct LevelType {
239 public:
240 /// Check that the `LevelType` contains a valid (possibly undefined) value.
isValidLvlBitsLevelType241 static constexpr bool isValidLvlBits(uint64_t lvlBits) {
242 auto fmt = static_cast<LevelFormat>(lvlBits & 0xffff0000);
243 const uint64_t propertyBits = lvlBits & 0xffff;
244 // If undefined/dense/batch/NOutOfM, then must be unique and ordered.
245 // Otherwise, the format must be one of the known ones.
246 return (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense,
247 LevelFormat::Batch, LevelFormat::NOutOfM>(fmt))
248 ? (propertyBits == 0)
249 : (isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton,
250 LevelFormat::LooseCompressed>(fmt));
251 }
252
253 /// Convert a LevelFormat to its corresponding LevelType with the given
254 /// properties. Returns std::nullopt when the properties are not applicable
255 /// for the input level format.
256 static std::optional<LevelType>
257 buildLvlType(LevelFormat lf,
258 const std::vector<LevelPropNonDefault> &properties,
259 uint64_t n = 0, uint64_t m = 0) {
260 assert((n & 0xff) == n && (m & 0xff) == m);
261 uint64_t newN = n << 32;
262 uint64_t newM = m << 40;
263 uint64_t ltBits = static_cast<uint64_t>(lf) | newN | newM;
264 for (auto p : properties)
265 ltBits |= static_cast<uint64_t>(p);
266
267 return isValidLvlBits(ltBits) ? std::optional(LevelType(ltBits))
268 : std::nullopt;
269 }
270 static std::optional<LevelType> buildLvlType(LevelFormat lf, bool ordered,
271 bool unique, uint64_t n = 0,
272 uint64_t m = 0) {
273 std::vector<LevelPropNonDefault> properties;
274 if (!ordered)
275 properties.push_back(LevelPropNonDefault::Nonordered);
276 if (!unique)
277 properties.push_back(LevelPropNonDefault::Nonunique);
278 return buildLvlType(lf, properties, n, m);
279 }
280
281 /// Explicit conversion from uint64_t.
LevelTypeLevelType282 constexpr explicit LevelType(uint64_t bits) : lvlBits(bits) {
283 assert(isValidLvlBits(bits));
284 };
285
286 /// Constructs a LevelType with the given format using all default properties.
LevelTypeLevelType287 /*implicit*/ LevelType(LevelFormat f) : lvlBits(static_cast<uint64_t>(f)) {
288 assert(isValidLvlBits(lvlBits) && !isa<LevelFormat::NOutOfM>());
289 };
290
291 /// Converts to uint64_t
uint64_tLevelType292 explicit operator uint64_t() const { return lvlBits; }
293
294 bool operator==(const LevelType lhs) const {
295 return static_cast<uint64_t>(lhs) == lvlBits;
296 }
297 bool operator!=(const LevelType lhs) const { return !(*this == lhs); }
298
stripStorageIrrelevantPropertiesLevelType299 LevelType stripStorageIrrelevantProperties() const {
300 // Properties other than `SoA` do not change the storage scheme of the
301 // sparse tensor.
302 constexpr uint64_t mask =
303 0xffff & ~static_cast<uint64_t>(LevelPropNonDefault::SoA);
304 return LevelType(lvlBits & ~mask);
305 }
306
307 /// Get N of NOutOfM level type.
getNLevelType308 constexpr uint64_t getN() const {
309 assert(isa<LevelFormat::NOutOfM>());
310 return (lvlBits >> 32) & 0xff;
311 }
312
313 /// Get M of NOutOfM level type.
getMLevelType314 constexpr uint64_t getM() const {
315 assert(isa<LevelFormat::NOutOfM>());
316 return (lvlBits >> 40) & 0xff;
317 }
318
319 /// Get the `LevelFormat` of the `LevelType`.
getLvlFmtLevelType320 constexpr LevelFormat getLvlFmt() const {
321 return static_cast<LevelFormat>(lvlBits & 0xffff0000);
322 }
323
324 /// Check if the `LevelType` is in the `LevelFormat`.
325 template <LevelFormat... fmt>
isaLevelType326 constexpr bool isa() const {
327 return (... || (getLvlFmt() == fmt)) || false;
328 }
329
330 /// Check if the `LevelType` has the properties
331 template <LevelPropNonDefault p>
isaLevelType332 constexpr bool isa() const {
333 return lvlBits & static_cast<uint64_t>(p);
334 }
335
336 /// Check if the `LevelType` is considered to be sparse.
hasSparseSemanticLevelType337 constexpr bool hasSparseSemantic() const {
338 return isa<LevelFormat::Compressed, LevelFormat::Singleton,
339 LevelFormat::LooseCompressed, LevelFormat::NOutOfM>();
340 }
341
342 /// Check if the `LevelType` is considered to be dense-like.
hasDenseSemanticLevelType343 constexpr bool hasDenseSemantic() const {
344 return isa<LevelFormat::Dense, LevelFormat::Batch>();
345 }
346
347 /// Check if the `LevelType` needs positions array.
isWithPosLTLevelType348 constexpr bool isWithPosLT() const {
349 assert(!isa<LevelFormat::Undef>());
350 return isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>();
351 }
352
353 /// Check if the `LevelType` needs coordinates array.
isWithCrdLTLevelType354 constexpr bool isWithCrdLT() const {
355 assert(!isa<LevelFormat::Undef>());
356 // All sparse levels has coordinate array.
357 return hasSparseSemantic();
358 }
359
getNumBufferLevelType360 constexpr unsigned getNumBuffer() const {
361 return hasDenseSemantic() ? 0 : (isWithPosLT() ? 2 : 1);
362 }
363
toMLIRStringLevelType364 std::string toMLIRString() const {
365 std::string lvlStr = toFormatString(getLvlFmt());
366 std::string propStr = "";
367 if (isa<LevelFormat::NOutOfM>()) {
368 lvlStr +=
369 "[" + std::to_string(getN()) + ", " + std::to_string(getM()) + "]";
370 }
371 if (isa<LevelPropNonDefault::Nonunique>())
372 propStr += toPropString(LevelPropNonDefault::Nonunique);
373
374 if (isa<LevelPropNonDefault::Nonordered>()) {
375 if (!propStr.empty())
376 propStr += ", ";
377 propStr += toPropString(LevelPropNonDefault::Nonordered);
378 }
379 if (isa<LevelPropNonDefault::SoA>()) {
380 if (!propStr.empty())
381 propStr += ", ";
382 propStr += toPropString(LevelPropNonDefault::SoA);
383 }
384 if (!propStr.empty())
385 lvlStr += ("(" + propStr + ")");
386 return lvlStr;
387 }
388
389 private:
390 /// Bit manipulations for LevelType:
391 ///
392 /// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
393 ///
394 uint64_t lvlBits;
395 };
396
397 // For backward-compatibility. TODO: remove below after fully migration.
nToBits(uint64_t n)398 constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
mToBits(uint64_t m)399 constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
400
401 inline std::optional<LevelType>
402 buildLevelType(LevelFormat lf,
403 const std::vector<LevelPropNonDefault> &properties,
404 uint64_t n = 0, uint64_t m = 0) {
405 return LevelType::buildLvlType(lf, properties, n, m);
406 }
407 inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
408 bool unique, uint64_t n = 0,
409 uint64_t m = 0) {
410 return LevelType::buildLvlType(lf, ordered, unique, n, m);
411 }
isUndefLT(LevelType lt)412 inline bool isUndefLT(LevelType lt) { return lt.isa<LevelFormat::Undef>(); }
isDenseLT(LevelType lt)413 inline bool isDenseLT(LevelType lt) { return lt.isa<LevelFormat::Dense>(); }
isBatchLT(LevelType lt)414 inline bool isBatchLT(LevelType lt) { return lt.isa<LevelFormat::Batch>(); }
isCompressedLT(LevelType lt)415 inline bool isCompressedLT(LevelType lt) {
416 return lt.isa<LevelFormat::Compressed>();
417 }
isLooseCompressedLT(LevelType lt)418 inline bool isLooseCompressedLT(LevelType lt) {
419 return lt.isa<LevelFormat::LooseCompressed>();
420 }
isSingletonLT(LevelType lt)421 inline bool isSingletonLT(LevelType lt) {
422 return lt.isa<LevelFormat::Singleton>();
423 }
isNOutOfMLT(LevelType lt)424 inline bool isNOutOfMLT(LevelType lt) { return lt.isa<LevelFormat::NOutOfM>(); }
isOrderedLT(LevelType lt)425 inline bool isOrderedLT(LevelType lt) {
426 return !lt.isa<LevelPropNonDefault::Nonordered>();
427 }
isUniqueLT(LevelType lt)428 inline bool isUniqueLT(LevelType lt) {
429 return !lt.isa<LevelPropNonDefault::Nonunique>();
430 }
isWithCrdLT(LevelType lt)431 inline bool isWithCrdLT(LevelType lt) { return lt.isWithCrdLT(); }
isWithPosLT(LevelType lt)432 inline bool isWithPosLT(LevelType lt) { return lt.isWithPosLT(); }
isValidLT(LevelType lt)433 inline bool isValidLT(LevelType lt) {
434 return LevelType::isValidLvlBits(static_cast<uint64_t>(lt));
435 }
getLevelFormat(LevelType lt)436 inline std::optional<LevelFormat> getLevelFormat(LevelType lt) {
437 LevelFormat fmt = lt.getLvlFmt();
438 if (fmt == LevelFormat::Undef)
439 return std::nullopt;
440 return fmt;
441 }
getN(LevelType lt)442 inline uint64_t getN(LevelType lt) { return lt.getN(); }
getM(LevelType lt)443 inline uint64_t getM(LevelType lt) { return lt.getM(); }
isValidNOutOfMLT(LevelType lt,uint64_t n,uint64_t m)444 inline bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
445 return isNOutOfMLT(lt) && lt.getN() == n && lt.getM() == m;
446 }
toMLIRString(LevelType lt)447 inline std::string toMLIRString(LevelType lt) { return lt.toMLIRString(); }
448
/// Bit manipulations for affine encoding.
///
/// Note that because the indices in the mappings refer to dimensions
/// and levels (and *not* the sizes of these dimensions and levels), the
/// 64-bit encoding gives ample room for a compact encoding of affine
/// operations in the higher bits. Pure permutations still allow for
/// 60-bit indices. But non-permutations reserve 20-bits for the
/// potential three components (index i, constant, index ii).
///
/// The compact encoding is as follows:
///
///  0xffffffffffffffff
/// |0000      |                        60-bit idx| e.g. i
/// |0001 floor|           20-bit const|20-bit idx| e.g. i floor c
/// |0010 mod  |           20-bit const|20-bit idx| e.g. i mod c
/// |0011 mul  |20-bit idx|20-bit const|20-bit idx| e.g. i + c * ii
///
/// This encoding provides sufficient generality for currently supported
/// sparse tensor types. To generalize this more, we will need to provide
/// a broader encoding scheme for affine functions. Also, the library
/// encoding may be replaced with pure "direct-IR" code in the future.
///
constexpr uint64_t encodeDim(uint64_t i, uint64_t cf, uint64_t cm) {
  // A nonzero floor constant selects the `floor` encoding (tag 0b0001).
  if (cf != 0) {
    assert(cf <= 0xfffffu && cm == 0 && i <= 0xfffffu);
    return (uint64_t{1} << 60) | (cf << 20) | i;
  }
  // A nonzero mod constant selects the `mod` encoding (tag 0b0010).
  if (cm != 0) {
    assert(cm <= 0xfffffu && i <= 0xfffffu);
    return (uint64_t{2} << 60) | (cm << 20) | i;
  }
  // Otherwise a pure 60-bit index (tag 0b0000).
  assert(i <= 0x0fffffffffffffffu);
  return i;
}
constexpr uint64_t encodeLvl(uint64_t i, uint64_t c, uint64_t ii) {
  // A nonzero constant selects the `mul` encoding `i + c * ii` (tag 0b0011).
  if (c != 0) {
    assert(c <= 0xfffffu && ii <= 0xfffffu && i <= 0xfffffu);
    return (uint64_t{3} << 60) | (c << 20) | (ii << 40) | i;
  }
  // Otherwise a pure 60-bit index (tag 0b0000).
  assert(i <= 0x0fffffffffffffffu);
  return i;
}
// Tag tests for the three non-permutation encodings.
constexpr bool isEncodedFloor(uint64_t v) { return (v >> 60) == 0x01u; }
constexpr bool isEncodedMod(uint64_t v) { return (v >> 60) == 0x02u; }
constexpr bool isEncodedMul(uint64_t v) { return (v >> 60) == 0x03u; }
// Field extractors (see the layout diagram above).
constexpr uint64_t decodeIndex(uint64_t v) { return v & 0xfffffu; }
constexpr uint64_t decodeConst(uint64_t v) { return (v >> 20) & 0xfffffu; }
constexpr uint64_t decodeMulc(uint64_t v) { return (v >> 20) & 0xfffffu; }
constexpr uint64_t decodeMuli(uint64_t v) { return (v >> 40) & 0xfffffu; }
498
499 } // namespace sparse_tensor
500 } // namespace mlir
501
502 #endif // MLIR_DIALECT_SPARSETENSOR_IR_ENUMS_H
503