xref: /llvm-project/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h (revision c42bbda42542cbf811f42a288f63b10e72e204de)
1 //===- Enums.h - Enums for the SparseTensor dialect -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Typedefs and enums shared between MLIR code for manipulating the
10 // IR, and the lightweight runtime support library for sparse tensor
11 // manipulations.  That is, all the enums are used to define the API
12 // of the runtime library and hence are also needed when generating
13 // calls into the runtime library.  Moreover, the `LevelType` enum
14 // is also used as the internal IR encoding of dimension level types,
15 // to avoid code duplication (e.g., for the predicates).
16 //
17 // This file also defines x-macros <https://en.wikipedia.org/wiki/X_Macro>
18 // so that we can generate variations of the public functions for each
19 // supported primary- and/or overhead-type.
20 //
21 // Because this file defines a library which is a dependency of the
22 // runtime library itself, this file must not depend on any MLIR internals
23 // (e.g., operators, attributes, ArrayRefs, etc) lest the runtime library
24 // inherit those dependencies.
25 //
26 //===----------------------------------------------------------------------===//
27 
28 #ifndef MLIR_DIALECT_SPARSETENSOR_IR_ENUMS_H
29 #define MLIR_DIALECT_SPARSETENSOR_IR_ENUMS_H
30 
31 // NOTE: Client code will need to include "mlir/ExecutionEngine/Float16bits.h"
32 // if they want to use the `MLIR_SPARSETENSOR_FOREVERY_V` macro.
33 
34 #include <cassert>
35 #include <cinttypes>
36 #include <complex>
37 #include <optional>
38 #include <vector>
39 
40 namespace mlir {
41 namespace sparse_tensor {
42 
/// Type used everywhere the public API mirrors MLIR's built-in "index"
/// type.  It is hard-wired to 64 bits for now; targets whose "index" has a
/// different bitwidth must link against an alternatively built runtime
/// support library.
using index_type = uint64_t;
48 
/// Encoding of overhead types (position overhead and coordinate overhead
/// alike), used to "overload" the @newSparseTensor entry point.
enum class OverheadType : uint32_t {
  kIndex = 0, // the architecture-dependent `index_type`
  kU64 = 1,
  kU32 = 2,
  kU16 = 3,
  kU8 = 4
};
58 
// This x-macro calls its argument on every fixed-width overhead type.
// It deliberately omits `index_type`, because that type is often handled
// specially (e.g., translated into the architecture-dependent equivalent
// fixed-width overhead type).
#define MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DO)                                 \
  DO(64, uint64_t)                                                             \
  DO(32, uint32_t)                                                             \
  DO(16, uint16_t)                                                             \
  DO(8, uint8_t)

// This x-macro calls its argument on every overhead type, `index_type`
// included.
#define MLIR_SPARSETENSOR_FOREVERY_O(DO)                                       \
  MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DO)                                       \
  DO(0, index_type)
// These aliases are not mere shorthands: they pin down the particular
// complex implementation used (as opposed to, e.g., C99's `complex double`
// or MLIR's `ComplexType`).
using complex64 = std::complex<double>;
using complex32 = std::complex<float>;
80 
/// Encoding of the elemental type, for "overloading" @newSparseTensor.
/// The ordering is significant: the contiguous runs kF64..kBF16 (floats),
/// kI64..kI8 (integers), and kC64..kC32 (complex) are relied upon by the
/// range-based predicates below.
enum class PrimaryType : uint32_t {
  kF64 = 1,
  kF32 = 2,
  kF16 = 3,
  kBF16 = 4,
  kI64 = 5,
  kI32 = 6,
  kI16 = 7,
  kI8 = 8,
  kC64 = 9,
  kC32 = 10
};
94 
// This x-macro calls its argument on every `V` (value) type.  Note that
// the `f16`/`bf16` entries require "mlir/ExecutionEngine/Float16bits.h"
// at the expansion site (see the note at the top of this file).
#define MLIR_SPARSETENSOR_FOREVERY_V(DO)                                       \
  DO(F64, double)                                                              \
  DO(F32, float)                                                               \
  DO(F16, f16)                                                                 \
  DO(BF16, bf16)                                                               \
  DO(I64, int64_t)                                                             \
  DO(I32, int32_t)                                                             \
  DO(I16, int16_t)                                                             \
  DO(I8, int8_t)                                                               \
  DO(C64, complex64)                                                           \
  DO(C32, complex32)
107 
// This x-macro calls its argument on every `V` type, forwarding the extra
// variadic arguments to each invocation of `DO`.
#define MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, ...)                              \
  DO(F64, double, __VA_ARGS__)                                                 \
  DO(F32, float, __VA_ARGS__)                                                  \
  DO(F16, f16, __VA_ARGS__)                                                    \
  DO(BF16, bf16, __VA_ARGS__)                                                  \
  DO(I64, int64_t, __VA_ARGS__)                                                \
  DO(I32, int32_t, __VA_ARGS__)                                                \
  DO(I16, int16_t, __VA_ARGS__)                                                \
  DO(I8, int8_t, __VA_ARGS__)                                                  \
  DO(C64, complex64, __VA_ARGS__)                                              \
  DO(C32, complex32, __VA_ARGS__)

// This x-macro calls its argument on every (overhead, `V`) type pair, by
// instantiating the variadic x-macro once per overhead type.
#define MLIR_SPARSETENSOR_FOREVERY_V_O(DO)                                     \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 64, uint64_t)                           \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 32, uint32_t)                           \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 16, uint16_t)                           \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 8, uint8_t)                             \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 0, index_type)
128 
isFloatingPrimaryType(PrimaryType valTy)129 constexpr bool isFloatingPrimaryType(PrimaryType valTy) {
130   return PrimaryType::kF64 <= valTy && valTy <= PrimaryType::kBF16;
131 }
132 
isIntegralPrimaryType(PrimaryType valTy)133 constexpr bool isIntegralPrimaryType(PrimaryType valTy) {
134   return PrimaryType::kI64 <= valTy && valTy <= PrimaryType::kI8;
135 }
136 
isRealPrimaryType(PrimaryType valTy)137 constexpr bool isRealPrimaryType(PrimaryType valTy) {
138   return PrimaryType::kF64 <= valTy && valTy <= PrimaryType::kI8;
139 }
140 
isComplexPrimaryType(PrimaryType valTy)141 constexpr bool isComplexPrimaryType(PrimaryType valTy) {
142   return PrimaryType::kC64 <= valTy && valTy <= PrimaryType::kC32;
143 }
144 
/// The actions performed by @newSparseTensor.
enum class Action : uint32_t {
  kEmpty = 0,
  kFromReader = 1,
  kPack = 2,
  kSortCOOInPlace = 3,
};
152 
/// This enum defines all supported storage formats, without the level
/// properties.  Each format owns one distinct bit in the 16-bit format
/// field of a `LevelType` encoding.
enum class LevelFormat : uint64_t {
  Undef = 0x00000000,
  Dense = 0x00010000,
  Batch = 0x00020000,
  Compressed = 0x00040000,
  Singleton = 0x00080000,
  LooseCompressed = 0x00100000,
  NOutOfM = 0x00200000,
};

/// Returns true when the format's encoding has at most one bit set.
/// (Note that `Undef`, whose encoding is zero, also satisfies this check.)
constexpr bool encPowOfTwo(LevelFormat fmt) {
  const auto bits = static_cast<std::underlying_type_t<LevelFormat>>(fmt);
  return (bits & (bits - 1)) == 0;
}

// All LevelFormats must have only one bit set (power of two).
static_assert(encPowOfTwo(LevelFormat::Dense) &&
              encPowOfTwo(LevelFormat::Batch) &&
              encPowOfTwo(LevelFormat::Compressed) &&
              encPowOfTwo(LevelFormat::Singleton) &&
              encPowOfTwo(LevelFormat::LooseCompressed) &&
              encPowOfTwo(LevelFormat::NOutOfM));

/// Returns true iff `fmt` equals one of the `targets`.
template <LevelFormat... targets>
constexpr bool isAnyOfFmt(LevelFormat fmt) {
  return ((fmt == targets) || ...);
}

/// Returns string representation of the given level format.
/// Unknown bit patterns map to the empty string.
constexpr const char *toFormatString(LevelFormat lvlFmt) {
  if (lvlFmt == LevelFormat::Undef)
    return "undef";
  if (lvlFmt == LevelFormat::Dense)
    return "dense";
  if (lvlFmt == LevelFormat::Batch)
    return "batch";
  if (lvlFmt == LevelFormat::Compressed)
    return "compressed";
  if (lvlFmt == LevelFormat::Singleton)
    return "singleton";
  if (lvlFmt == LevelFormat::LooseCompressed)
    return "loose_compressed";
  if (lvlFmt == LevelFormat::NOutOfM)
    return "structured";
  return "";
}
202 
/// This enum defines all the nondefault properties for storage formats.
/// Each property occupies its own bit in the low 16 bits of a `LevelType`.
enum class LevelPropNonDefault : uint64_t {
  Nonunique = 0x0001,  // 0b001
  Nonordered = 0x0002, // 0b010
  SoA = 0x0004,        // 0b100
};

/// Returns string representation of the given level property.
/// Unknown bit patterns map to the empty string.
constexpr const char *toPropString(LevelPropNonDefault lvlProp) {
  if (lvlProp == LevelPropNonDefault::Nonunique)
    return "nonunique";
  if (lvlProp == LevelPropNonDefault::Nonordered)
    return "nonordered";
  if (lvlProp == LevelPropNonDefault::SoA)
    return "soa";
  return "";
}
222 
223 /// This enum defines all the sparse representations supportable by
224 /// the SparseTensor dialect. We use a lightweight encoding to encode
225 /// the "format" per se (dense, compressed, singleton, loose_compressed,
226 /// n-out-of-m), the "properties" (ordered, unique) as well as n and m when
227 /// the format is NOutOfM.
228 /// The encoding is chosen for performance of the runtime library, and thus may
229 /// change in future versions; consequently, client code should use the
230 /// predicate functions defined below, rather than relying on knowledge
231 /// about the particular binary encoding.
232 ///
233 /// The `Undef` "format" is a special value used internally for cases
234 /// where we need to store an undefined or indeterminate `LevelType`.
235 /// It should not be used externally, since it does not indicate an
236 /// actual/representable format.
237 
238 struct LevelType {
239 public:
240   /// Check that the `LevelType` contains a valid (possibly undefined) value.
isValidLvlBitsLevelType241   static constexpr bool isValidLvlBits(uint64_t lvlBits) {
242     auto fmt = static_cast<LevelFormat>(lvlBits & 0xffff0000);
243     const uint64_t propertyBits = lvlBits & 0xffff;
244     // If undefined/dense/batch/NOutOfM, then must be unique and ordered.
245     // Otherwise, the format must be one of the known ones.
246     return (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense,
247                        LevelFormat::Batch, LevelFormat::NOutOfM>(fmt))
248                ? (propertyBits == 0)
249                : (isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton,
250                              LevelFormat::LooseCompressed>(fmt));
251   }
252 
253   /// Convert a LevelFormat to its corresponding LevelType with the given
254   /// properties. Returns std::nullopt when the properties are not applicable
255   /// for the input level format.
256   static std::optional<LevelType>
257   buildLvlType(LevelFormat lf,
258                const std::vector<LevelPropNonDefault> &properties,
259                uint64_t n = 0, uint64_t m = 0) {
260     assert((n & 0xff) == n && (m & 0xff) == m);
261     uint64_t newN = n << 32;
262     uint64_t newM = m << 40;
263     uint64_t ltBits = static_cast<uint64_t>(lf) | newN | newM;
264     for (auto p : properties)
265       ltBits |= static_cast<uint64_t>(p);
266 
267     return isValidLvlBits(ltBits) ? std::optional(LevelType(ltBits))
268                                   : std::nullopt;
269   }
270   static std::optional<LevelType> buildLvlType(LevelFormat lf, bool ordered,
271                                                bool unique, uint64_t n = 0,
272                                                uint64_t m = 0) {
273     std::vector<LevelPropNonDefault> properties;
274     if (!ordered)
275       properties.push_back(LevelPropNonDefault::Nonordered);
276     if (!unique)
277       properties.push_back(LevelPropNonDefault::Nonunique);
278     return buildLvlType(lf, properties, n, m);
279   }
280 
281   /// Explicit conversion from uint64_t.
LevelTypeLevelType282   constexpr explicit LevelType(uint64_t bits) : lvlBits(bits) {
283     assert(isValidLvlBits(bits));
284   };
285 
286   /// Constructs a LevelType with the given format using all default properties.
LevelTypeLevelType287   /*implicit*/ LevelType(LevelFormat f) : lvlBits(static_cast<uint64_t>(f)) {
288     assert(isValidLvlBits(lvlBits) && !isa<LevelFormat::NOutOfM>());
289   };
290 
291   /// Converts to uint64_t
uint64_tLevelType292   explicit operator uint64_t() const { return lvlBits; }
293 
294   bool operator==(const LevelType lhs) const {
295     return static_cast<uint64_t>(lhs) == lvlBits;
296   }
297   bool operator!=(const LevelType lhs) const { return !(*this == lhs); }
298 
stripStorageIrrelevantPropertiesLevelType299   LevelType stripStorageIrrelevantProperties() const {
300     // Properties other than `SoA` do not change the storage scheme of the
301     // sparse tensor.
302     constexpr uint64_t mask =
303         0xffff & ~static_cast<uint64_t>(LevelPropNonDefault::SoA);
304     return LevelType(lvlBits & ~mask);
305   }
306 
307   /// Get N of NOutOfM level type.
getNLevelType308   constexpr uint64_t getN() const {
309     assert(isa<LevelFormat::NOutOfM>());
310     return (lvlBits >> 32) & 0xff;
311   }
312 
313   /// Get M of NOutOfM level type.
getMLevelType314   constexpr uint64_t getM() const {
315     assert(isa<LevelFormat::NOutOfM>());
316     return (lvlBits >> 40) & 0xff;
317   }
318 
319   /// Get the `LevelFormat` of the `LevelType`.
getLvlFmtLevelType320   constexpr LevelFormat getLvlFmt() const {
321     return static_cast<LevelFormat>(lvlBits & 0xffff0000);
322   }
323 
324   /// Check if the `LevelType` is in the `LevelFormat`.
325   template <LevelFormat... fmt>
isaLevelType326   constexpr bool isa() const {
327     return (... || (getLvlFmt() == fmt)) || false;
328   }
329 
330   /// Check if the `LevelType` has the properties
331   template <LevelPropNonDefault p>
isaLevelType332   constexpr bool isa() const {
333     return lvlBits & static_cast<uint64_t>(p);
334   }
335 
336   /// Check if the `LevelType` is considered to be sparse.
hasSparseSemanticLevelType337   constexpr bool hasSparseSemantic() const {
338     return isa<LevelFormat::Compressed, LevelFormat::Singleton,
339                LevelFormat::LooseCompressed, LevelFormat::NOutOfM>();
340   }
341 
342   /// Check if the `LevelType` is considered to be dense-like.
hasDenseSemanticLevelType343   constexpr bool hasDenseSemantic() const {
344     return isa<LevelFormat::Dense, LevelFormat::Batch>();
345   }
346 
347   /// Check if the `LevelType` needs positions array.
isWithPosLTLevelType348   constexpr bool isWithPosLT() const {
349     assert(!isa<LevelFormat::Undef>());
350     return isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>();
351   }
352 
353   /// Check if the `LevelType` needs coordinates array.
isWithCrdLTLevelType354   constexpr bool isWithCrdLT() const {
355     assert(!isa<LevelFormat::Undef>());
356     // All sparse levels has coordinate array.
357     return hasSparseSemantic();
358   }
359 
getNumBufferLevelType360   constexpr unsigned getNumBuffer() const {
361     return hasDenseSemantic() ? 0 : (isWithPosLT() ? 2 : 1);
362   }
363 
toMLIRStringLevelType364   std::string toMLIRString() const {
365     std::string lvlStr = toFormatString(getLvlFmt());
366     std::string propStr = "";
367     if (isa<LevelFormat::NOutOfM>()) {
368       lvlStr +=
369           "[" + std::to_string(getN()) + ", " + std::to_string(getM()) + "]";
370     }
371     if (isa<LevelPropNonDefault::Nonunique>())
372       propStr += toPropString(LevelPropNonDefault::Nonunique);
373 
374     if (isa<LevelPropNonDefault::Nonordered>()) {
375       if (!propStr.empty())
376         propStr += ", ";
377       propStr += toPropString(LevelPropNonDefault::Nonordered);
378     }
379     if (isa<LevelPropNonDefault::SoA>()) {
380       if (!propStr.empty())
381         propStr += ", ";
382       propStr += toPropString(LevelPropNonDefault::SoA);
383     }
384     if (!propStr.empty())
385       lvlStr += ("(" + propStr + ")");
386     return lvlStr;
387   }
388 
389 private:
390   /// Bit manipulations for LevelType:
391   ///
392   /// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
393   ///
394   uint64_t lvlBits;
395 };
396 
// For backward-compatibility. TODO: remove below after fully migration.

/// Shifts `n` into the n field (bits 32..39) of a `LevelType` encoding.
constexpr uint64_t nToBits(uint64_t n) { return n << 32; }

/// Shifts `m` into the m field (bits 40..47) of a `LevelType` encoding.
constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
400 
401 inline std::optional<LevelType>
402 buildLevelType(LevelFormat lf,
403                const std::vector<LevelPropNonDefault> &properties,
404                uint64_t n = 0, uint64_t m = 0) {
405   return LevelType::buildLvlType(lf, properties, n, m);
406 }
407 inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
408                                                bool unique, uint64_t n = 0,
409                                                uint64_t m = 0) {
410   return LevelType::buildLvlType(lf, ordered, unique, n, m);
411 }
isUndefLT(LevelType lt)412 inline bool isUndefLT(LevelType lt) { return lt.isa<LevelFormat::Undef>(); }
isDenseLT(LevelType lt)413 inline bool isDenseLT(LevelType lt) { return lt.isa<LevelFormat::Dense>(); }
isBatchLT(LevelType lt)414 inline bool isBatchLT(LevelType lt) { return lt.isa<LevelFormat::Batch>(); }
isCompressedLT(LevelType lt)415 inline bool isCompressedLT(LevelType lt) {
416   return lt.isa<LevelFormat::Compressed>();
417 }
isLooseCompressedLT(LevelType lt)418 inline bool isLooseCompressedLT(LevelType lt) {
419   return lt.isa<LevelFormat::LooseCompressed>();
420 }
isSingletonLT(LevelType lt)421 inline bool isSingletonLT(LevelType lt) {
422   return lt.isa<LevelFormat::Singleton>();
423 }
isNOutOfMLT(LevelType lt)424 inline bool isNOutOfMLT(LevelType lt) { return lt.isa<LevelFormat::NOutOfM>(); }
isOrderedLT(LevelType lt)425 inline bool isOrderedLT(LevelType lt) {
426   return !lt.isa<LevelPropNonDefault::Nonordered>();
427 }
isUniqueLT(LevelType lt)428 inline bool isUniqueLT(LevelType lt) {
429   return !lt.isa<LevelPropNonDefault::Nonunique>();
430 }
isWithCrdLT(LevelType lt)431 inline bool isWithCrdLT(LevelType lt) { return lt.isWithCrdLT(); }
isWithPosLT(LevelType lt)432 inline bool isWithPosLT(LevelType lt) { return lt.isWithPosLT(); }
isValidLT(LevelType lt)433 inline bool isValidLT(LevelType lt) {
434   return LevelType::isValidLvlBits(static_cast<uint64_t>(lt));
435 }
getLevelFormat(LevelType lt)436 inline std::optional<LevelFormat> getLevelFormat(LevelType lt) {
437   LevelFormat fmt = lt.getLvlFmt();
438   if (fmt == LevelFormat::Undef)
439     return std::nullopt;
440   return fmt;
441 }
getN(LevelType lt)442 inline uint64_t getN(LevelType lt) { return lt.getN(); }
getM(LevelType lt)443 inline uint64_t getM(LevelType lt) { return lt.getM(); }
isValidNOutOfMLT(LevelType lt,uint64_t n,uint64_t m)444 inline bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
445   return isNOutOfMLT(lt) && lt.getN() == n && lt.getM() == m;
446 }
toMLIRString(LevelType lt)447 inline std::string toMLIRString(LevelType lt) { return lt.toMLIRString(); }
448 
449 /// Bit manipulations for affine encoding.
450 ///
451 /// Note that because the indices in the mappings refer to dimensions
452 /// and levels (and *not* the sizes of these dimensions and levels), the
453 /// 64-bit encoding gives ample room for a compact encoding of affine
454 /// operations in the higher bits. Pure permutations still allow for
455 /// 60-bit indices. But non-permutations reserve 20-bits for the
456 /// potential three components (index i, constant, index ii).
457 ///
458 /// The compact encoding is as follows:
459 ///
460 ///  0xffffffffffffffff
461 /// |0000      |                        60-bit idx| e.g. i
462 /// |0001 floor|           20-bit const|20-bit idx| e.g. i floor c
463 /// |0010 mod  |           20-bit const|20-bit idx| e.g. i mod c
464 /// |0011 mul  |20-bit idx|20-bit const|20-bit idx| e.g. i + c * ii
465 ///
466 /// This encoding provides sufficient generality for currently supported
467 /// sparse tensor types. To generalize this more, we will need to provide
468 /// a broader encoding scheme for affine functions. Also, the library
469 /// encoding may be replaced with pure "direct-IR" code in the future.
470 ///
/// Encodes a dimension expression: plain index `i`, `i floor cf`, or
/// `i mod cm`, in the compact 64-bit form documented above.  At most one
/// of `cf`/`cm` may be nonzero.
constexpr uint64_t encodeDim(uint64_t i, uint64_t cf, uint64_t cm) {
  if (cf != 0) {
    // Floor: tag 0b0001 in the top nibble, constant in bits 20..39.
    assert(cf <= 0xfffffu && cm == 0 && i <= 0xfffffu);
    return (static_cast<uint64_t>(0x01u) << 60) | (cf << 20) | i;
  }
  if (cm != 0) {
    // Mod: tag 0b0010 in the top nibble, constant in bits 20..39.
    assert(cm <= 0xfffffu && i <= 0xfffffu);
    return (static_cast<uint64_t>(0x02u) << 60) | (cm << 20) | i;
  }
  // Pure permutation: the index may use up to 60 bits.
  assert(i <= 0x0fffffffffffffffu);
  return i;
}
/// Encodes a level expression: plain index `i`, or `i + c * ii` when `c`
/// is nonzero, in the compact 64-bit form documented above.
constexpr uint64_t encodeLvl(uint64_t i, uint64_t c, uint64_t ii) {
  if (c != 0) {
    // Mul: tag 0b0011, second index in bits 40..59, constant in 20..39.
    assert(c <= 0xfffffu && ii <= 0xfffffu && i <= 0xfffffu);
    return (static_cast<uint64_t>(0x03u) << 60) | (c << 20) | (ii << 40) | i;
  }
  // Pure permutation: the index may use up to 60 bits.
  assert(i <= 0x0fffffffffffffffu);
  return i;
}
/// Tag tests and field extractors for the compact affine encoding above.
constexpr bool isEncodedFloor(uint64_t v) { return (v >> 60) == 0x01u; }
constexpr bool isEncodedMod(uint64_t v) { return (v >> 60) == 0x02u; }
constexpr bool isEncodedMul(uint64_t v) { return (v >> 60) == 0x03u; }
constexpr uint64_t decodeIndex(uint64_t v) { return v & 0xfffffu; }
constexpr uint64_t decodeConst(uint64_t v) { return (v >> 20) & 0xfffffu; }
constexpr uint64_t decodeMulc(uint64_t v) { return (v >> 20) & 0xfffffu; }
constexpr uint64_t decodeMuli(uint64_t v) { return (v >> 40) & 0xfffffu; }
498 
499 } // namespace sparse_tensor
500 } // namespace mlir
501 
502 #endif // MLIR_DIALECT_SPARSETENSOR_IR_ENUMS_H
503