xref: /llvm-project/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h (revision b719ab4eef634f24605ca7ccd4874338c34e05bd)
1 //===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
14 #define MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
15 
16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/IR/Location.h"
20 #include "mlir/IR/TypeSupport.h"
21 #include "mlir/IR/Types.h"
22 
23 #include <cstdint>
24 #include <tuple>
25 
26 namespace mlir {
27 namespace spirv {
28 
29 namespace detail {
30 struct ArrayTypeStorage;
31 struct CooperativeMatrixTypeStorage;
32 struct ImageTypeStorage;
33 struct MatrixTypeStorage;
34 struct PointerTypeStorage;
35 struct RuntimeArrayTypeStorage;
36 struct SampledImageTypeStorage;
37 struct StructTypeStorage;
38 
39 } // namespace detail
40 
41 // Base SPIR-V type for providing availability queries.
42 class SPIRVType : public Type {
43 public:
44   using Type::Type;
45 
46   static bool classof(Type type);
47 
48   bool isScalarOrVector();
49 
50   /// The extension requirements for each type are following the
51   /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
52   /// convention.
53   using ExtensionArrayRefVector = SmallVectorImpl<ArrayRef<Extension>>;
54 
55   /// Appends to `extensions` the extensions needed for this type to appear in
56   /// the given `storage` class. This method does not guarantee the uniqueness
57   /// of extensions; the same extension may be appended multiple times.
58   void getExtensions(ExtensionArrayRefVector &extensions,
59                      std::optional<StorageClass> storage = std::nullopt);
60 
61   /// The capability requirements for each type are following the
62   /// ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D))
63   /// convention.
64   using CapabilityArrayRefVector = SmallVectorImpl<ArrayRef<Capability>>;
65 
66   /// Appends to `capabilities` the capabilities needed for this type to appear
67   /// in the given `storage` class. This method does not guarantee the
68   /// uniqueness of capabilities; the same capability may be appended multiple
69   /// times.
70   void getCapabilities(CapabilityArrayRefVector &capabilities,
71                        std::optional<StorageClass> storage = std::nullopt);
72 
73   /// Returns the size in bytes for each type. If no size can be calculated,
74   /// returns `std::nullopt`. Note that if the type has explicit layout, it is
75   /// also taken into account in calculation.
76   std::optional<int64_t> getSizeInBytes();
77 };
78 
79 // SPIR-V scalar type: bool type, integer type, floating point type.
80 class ScalarType : public SPIRVType {
81 public:
82   using SPIRVType::SPIRVType;
83 
84   static bool classof(Type type);
85 
86   /// Returns true if the given integer type is valid for the SPIR-V dialect.
87   static bool isValid(FloatType);
88   /// Returns true if the given float type is valid for the SPIR-V dialect.
89   static bool isValid(IntegerType);
90 
91   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
92                      std::optional<StorageClass> storage = std::nullopt);
93   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
94                        std::optional<StorageClass> storage = std::nullopt);
95 
96   std::optional<int64_t> getSizeInBytes();
97 };
98 
99 // SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
100 class CompositeType : public SPIRVType {
101 public:
102   using SPIRVType::SPIRVType;
103 
104   static bool classof(Type type);
105 
106   /// Returns true if the given vector type is valid for the SPIR-V dialect.
107   static bool isValid(VectorType);
108 
109   /// Return the number of elements of the type. This should only be called if
110   /// hasCompileTimeKnownNumElements is true.
111   unsigned getNumElements() const;
112 
113   Type getElementType(unsigned) const;
114 
115   /// Return true if the number of elements is known at compile time and is not
116   /// implementation dependent.
117   bool hasCompileTimeKnownNumElements() const;
118 
119   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
120                      std::optional<StorageClass> storage = std::nullopt);
121   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
122                        std::optional<StorageClass> storage = std::nullopt);
123 
124   std::optional<int64_t> getSizeInBytes();
125 };
126 
127 // SPIR-V array type
128 class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
129                                         detail::ArrayTypeStorage> {
130 public:
131   using Base::Base;
132 
133   static constexpr StringLiteral name = "spirv.array";
134 
135   static ArrayType get(Type elementType, unsigned elementCount);
136 
137   /// Returns an array type with the given stride in bytes.
138   static ArrayType get(Type elementType, unsigned elementCount,
139                        unsigned stride);
140 
141   unsigned getNumElements() const;
142 
143   Type getElementType() const;
144 
145   /// Returns the array stride in bytes. 0 means no stride decorated on this
146   /// type.
147   unsigned getArrayStride() const;
148 
149   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
150                      std::optional<StorageClass> storage = std::nullopt);
151   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
152                        std::optional<StorageClass> storage = std::nullopt);
153 
154   /// Returns the array size in bytes. Since array type may have an explicit
155   /// stride declaration (in bytes), we also include it in the calculation.
156   std::optional<int64_t> getSizeInBytes();
157 };
158 
159 // SPIR-V image type
160 class ImageType
161     : public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> {
162 public:
163   using Base::Base;
164 
165   static constexpr StringLiteral name = "spirv.image";
166 
167   static ImageType
168   get(Type elementType, Dim dim,
169       ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
170       ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed,
171       ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled,
172       ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown,
173       ImageFormat format = ImageFormat::Unknown) {
174     return ImageType::get(
175         std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
176                    ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>(
177             elementType, dim, depth, arrayed, samplingInfo, samplerUse,
178             format));
179   }
180 
181   static ImageType
182       get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
183                      ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>);
184 
185   Type getElementType() const;
186   Dim getDim() const;
187   ImageDepthInfo getDepthInfo() const;
188   ImageArrayedInfo getArrayedInfo() const;
189   ImageSamplingInfo getSamplingInfo() const;
190   ImageSamplerUseInfo getSamplerUseInfo() const;
191   ImageFormat getImageFormat() const;
192   // TODO: Add support for Access qualifier
193 
194   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
195                      std::optional<StorageClass> storage = std::nullopt);
196   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
197                        std::optional<StorageClass> storage = std::nullopt);
198 };
199 
200 // SPIR-V pointer type
201 class PointerType : public Type::TypeBase<PointerType, SPIRVType,
202                                           detail::PointerTypeStorage> {
203 public:
204   using Base::Base;
205 
206   static constexpr StringLiteral name = "spirv.pointer";
207 
208   static PointerType get(Type pointeeType, StorageClass storageClass);
209 
210   Type getPointeeType() const;
211 
212   StorageClass getStorageClass() const;
213 
214   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
215                      std::optional<StorageClass> storage = std::nullopt);
216   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
217                        std::optional<StorageClass> storage = std::nullopt);
218 };
219 
220 // SPIR-V run-time array type
221 class RuntimeArrayType
222     : public Type::TypeBase<RuntimeArrayType, SPIRVType,
223                             detail::RuntimeArrayTypeStorage> {
224 public:
225   using Base::Base;
226 
227   static constexpr StringLiteral name = "spirv.rtarray";
228 
229   static RuntimeArrayType get(Type elementType);
230 
231   /// Returns a runtime array type with the given stride in bytes.
232   static RuntimeArrayType get(Type elementType, unsigned stride);
233 
234   Type getElementType() const;
235 
236   /// Returns the array stride in bytes. 0 means no stride decorated on this
237   /// type.
238   unsigned getArrayStride() const;
239 
240   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
241                      std::optional<StorageClass> storage = std::nullopt);
242   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
243                        std::optional<StorageClass> storage = std::nullopt);
244 };
245 
246 // SPIR-V sampled image type
247 class SampledImageType
248     : public Type::TypeBase<SampledImageType, SPIRVType,
249                             detail::SampledImageTypeStorage> {
250 public:
251   using Base::Base;
252 
253   static constexpr StringLiteral name = "spirv.sampled_image";
254 
255   static SampledImageType get(Type imageType);
256 
257   static SampledImageType
258   getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);
259 
260   static LogicalResult
261   verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
262                    Type imageType);
263 
264   Type getImageType() const;
265 
266   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
267                      std::optional<spirv::StorageClass> storage = std::nullopt);
268   void
269   getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
270                   std::optional<spirv::StorageClass> storage = std::nullopt);
271 };
272 
273 /// SPIR-V struct type. Two kinds of struct types are supported:
274 /// - Literal: a literal struct type is uniqued by its fields (types + offset
275 /// info + decoration info).
276 /// - Identified: an indentified struct type is uniqued by its string identifier
277 /// (name). This is useful in representing recursive structs. For example, the
278 /// following C struct:
279 ///
280 /// struct A {
281 ///   A* next;
282 /// };
283 ///
284 /// would be represented in MLIR as:
285 ///
286 /// !spirv.struct<A, (!spirv.ptr<!spirv.struct<A>, Generic>)>
287 ///
288 /// In the above, expressing recursive struct types is accomplished by giving a
289 /// recursive struct a unique identified and using that identifier in the struct
290 /// definition for recursive references.
291 class StructType
292     : public Type::TypeBase<StructType, CompositeType,
293                             detail::StructTypeStorage, TypeTrait::IsMutable> {
294 public:
295   using Base::Base;
296 
297   // Type for specifying the offset of the struct members
298   using OffsetInfo = uint32_t;
299 
300   static constexpr StringLiteral name = "spirv.struct";
301 
302   // Type for specifying the decoration(s) on struct members
303   struct MemberDecorationInfo {
304     uint32_t memberIndex : 31;
305     uint32_t hasValue : 1;
306     Decoration decoration;
307     uint32_t decorationValue;
308 
309     MemberDecorationInfo(uint32_t index, uint32_t hasValue,
310                          Decoration decoration, uint32_t decorationValue)
311         : memberIndex(index), hasValue(hasValue), decoration(decoration),
312           decorationValue(decorationValue) {}
313 
314     bool operator==(const MemberDecorationInfo &other) const {
315       return (this->memberIndex == other.memberIndex) &&
316              (this->decoration == other.decoration) &&
317              (this->decorationValue == other.decorationValue);
318     }
319 
320     bool operator<(const MemberDecorationInfo &other) const {
321       return this->memberIndex < other.memberIndex ||
322              (this->memberIndex == other.memberIndex &&
323               static_cast<uint32_t>(this->decoration) <
324                   static_cast<uint32_t>(other.decoration));
325     }
326   };
327 
328   /// Construct a literal StructType with at least one member.
329   static StructType get(ArrayRef<Type> memberTypes,
330                         ArrayRef<OffsetInfo> offsetInfo = {},
331                         ArrayRef<MemberDecorationInfo> memberDecorations = {});
332 
333   /// Construct an identified StructType. This creates a StructType whose body
334   /// (member types, offset info, and decorations) is not set yet. A call to
335   /// StructType::trySetBody(...) must follow when the StructType contents are
336   /// available (e.g. parsed or deserialized).
337   ///
338   /// Note: If another thread creates (or had already created) a struct with the
339   /// same identifier, that struct will be returned as a result.
340   static StructType getIdentified(MLIRContext *context, StringRef identifier);
341 
342   /// Construct a (possibly identified) StructType with no members.
343   ///
344   /// Note: this method might fail in a multi-threaded setup if another thread
345   /// created an identified struct with the same identifier but with different
346   /// contents before returning. In which case, an empty (default-constructed)
347   /// StructType is returned.
348   static StructType getEmpty(MLIRContext *context, StringRef identifier = "");
349 
350   /// For literal structs, return an empty string.
351   /// For identified structs, return the struct's identifier.
352   StringRef getIdentifier() const;
353 
354   /// Returns true if the StructType is identified.
355   bool isIdentified() const;
356 
357   unsigned getNumElements() const;
358 
359   Type getElementType(unsigned) const;
360 
361   TypeRange getElementTypes() const;
362 
363   bool hasOffset() const;
364 
365   uint64_t getMemberOffset(unsigned) const;
366 
367   // Returns in `memberDecorations` the Decorations (apart from Offset)
368   // associated with all members of the StructType.
369   void getMemberDecorations(SmallVectorImpl<StructType::MemberDecorationInfo>
370                                 &memberDecorations) const;
371 
372   // Returns in `decorationsInfo` all the Decorations (apart from Offset)
373   // associated with the `i`-th member of the StructType.
374   void getMemberDecorations(
375       unsigned i,
376       SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const;
377 
378   /// Sets the contents of an incomplete identified StructType. This method must
379   /// be called only for identified StructTypes and it must be called only once
380   /// per instance. Otherwise, failure() is returned.
381   LogicalResult
382   trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
383              ArrayRef<MemberDecorationInfo> memberDecorations = {});
384 
385   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
386                      std::optional<StorageClass> storage = std::nullopt);
387   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
388                        std::optional<StorageClass> storage = std::nullopt);
389 };
390 
391 llvm::hash_code
392 hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
393 
394 // SPIR-V KHR cooperative matrix type
395 class CooperativeMatrixType
396     : public Type::TypeBase<CooperativeMatrixType, CompositeType,
397                             detail::CooperativeMatrixTypeStorage> {
398 public:
399   using Base::Base;
400 
401   static constexpr StringLiteral name = "spirv.coopmatrix";
402 
403   static CooperativeMatrixType get(Type elementType, uint32_t rows,
404                                    uint32_t columns, Scope scope,
405                                    CooperativeMatrixUseKHR use);
406   Type getElementType() const;
407 
408   /// Returns the scope of the matrix.
409   Scope getScope() const;
410   /// Returns the number of rows of the matrix.
411   uint32_t getRows() const;
412   /// Returns the number of columns of the matrix.
413   uint32_t getColumns() const;
414   /// Returns the use parameter of the cooperative matrix.
415   CooperativeMatrixUseKHR getUse() const;
416 
417   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
418                      std::optional<StorageClass> storage = std::nullopt);
419   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
420                        std::optional<StorageClass> storage = std::nullopt);
421 };
422 
423 // SPIR-V matrix type
424 class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
425                                          detail::MatrixTypeStorage> {
426 public:
427   using Base::Base;
428 
429   static constexpr StringLiteral name = "spirv.matrix";
430 
431   static MatrixType get(Type columnType, uint32_t columnCount);
432 
433   static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
434                                Type columnType, uint32_t columnCount);
435 
436   static LogicalResult
437   verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
438                    Type columnType, uint32_t columnCount);
439 
440   /// Returns true if the matrix elements are vectors of float elements.
441   static bool isValidColumnType(Type columnType);
442 
443   Type getColumnType() const;
444 
445   /// Returns the number of rows.
446   unsigned getNumRows() const;
447 
448   /// Returns the number of columns.
449   unsigned getNumColumns() const;
450 
451   /// Returns total number of elements (rows*columns).
452   unsigned getNumElements() const;
453 
454   /// Returns the elements' type (i.e, single element type).
455   Type getElementType() const;
456 
457   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
458                      std::optional<StorageClass> storage = std::nullopt);
459   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
460                        std::optional<StorageClass> storage = std::nullopt);
461 };
462 
463 } // namespace spirv
464 } // namespace mlir
465 
466 #endif // MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
467