xref: /llvm-project/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (revision b719ab4eef634f24605ca7ccd4874338c34e05bd)
1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 #include <cstdint>
22 #include <iterator>
23 
24 using namespace mlir;
25 using namespace mlir::spirv;
26 
27 //===----------------------------------------------------------------------===//
28 // ArrayType
29 //===----------------------------------------------------------------------===//
30 
31 struct spirv::detail::ArrayTypeStorage : public TypeStorage {
32   using KeyTy = std::tuple<Type, unsigned, unsigned>;
33 
34   static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
35                                      const KeyTy &key) {
36     return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
37   }
38 
39   bool operator==(const KeyTy &key) const {
40     return key == KeyTy(elementType, elementCount, stride);
41   }
42 
43   ArrayTypeStorage(const KeyTy &key)
44       : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
45         stride(std::get<2>(key)) {}
46 
47   Type elementType;
48   unsigned elementCount;
49   unsigned stride;
50 };
51 
52 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
53   assert(elementCount && "ArrayType needs at least one element");
54   return Base::get(elementType.getContext(), elementType, elementCount,
55                    /*stride=*/0);
56 }
57 
58 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
59                          unsigned stride) {
60   assert(elementCount && "ArrayType needs at least one element");
61   return Base::get(elementType.getContext(), elementType, elementCount, stride);
62 }
63 
64 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
65 
66 Type ArrayType::getElementType() const { return getImpl()->elementType; }
67 
68 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
69 
70 void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
71                               std::optional<StorageClass> storage) {
72   llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
73 }
74 
75 void ArrayType::getCapabilities(
76     SPIRVType::CapabilityArrayRefVector &capabilities,
77     std::optional<StorageClass> storage) {
78   llvm::cast<SPIRVType>(getElementType())
79       .getCapabilities(capabilities, storage);
80 }
81 
82 std::optional<int64_t> ArrayType::getSizeInBytes() {
83   auto elementType = llvm::cast<SPIRVType>(getElementType());
84   std::optional<int64_t> size = elementType.getSizeInBytes();
85   if (!size)
86     return std::nullopt;
87   return (*size + getArrayStride()) * getNumElements();
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // CompositeType
92 //===----------------------------------------------------------------------===//
93 
94 bool CompositeType::classof(Type type) {
95   if (auto vectorType = llvm::dyn_cast<VectorType>(type))
96     return isValid(vectorType);
97   return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
98                    spirv::MatrixType, spirv::RuntimeArrayType,
99                    spirv::StructType>(type);
100 }
101 
102 bool CompositeType::isValid(VectorType type) {
103   return type.getRank() == 1 &&
104          llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
105          llvm::isa<ScalarType>(type.getElementType());
106 }
107 
108 Type CompositeType::getElementType(unsigned index) const {
109   return TypeSwitch<Type, Type>(*this)
110       .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType>(
111           [](auto type) { return type.getElementType(); })
112       .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
113       .Case<StructType>(
114           [index](StructType type) { return type.getElementType(index); })
115       .Default(
116           [](Type) -> Type { llvm_unreachable("invalid composite type"); });
117 }
118 
119 unsigned CompositeType::getNumElements() const {
120   if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
121     return arrayType.getNumElements();
122   if (auto matrixType = llvm::dyn_cast<MatrixType>(*this))
123     return matrixType.getNumColumns();
124   if (auto structType = llvm::dyn_cast<StructType>(*this))
125     return structType.getNumElements();
126   if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
127     return vectorType.getNumElements();
128   if (llvm::isa<CooperativeMatrixType>(*this)) {
129     llvm_unreachable(
130         "invalid to query number of elements of spirv Cooperative Matrix type");
131   }
132   if (llvm::isa<RuntimeArrayType>(*this)) {
133     llvm_unreachable(
134         "invalid to query number of elements of spirv::RuntimeArray type");
135   }
136   llvm_unreachable("invalid composite type");
137 }
138 
139 bool CompositeType::hasCompileTimeKnownNumElements() const {
140   return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
141 }
142 
143 void CompositeType::getExtensions(
144     SPIRVType::ExtensionArrayRefVector &extensions,
145     std::optional<StorageClass> storage) {
146   TypeSwitch<Type>(*this)
147       .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
148             StructType>(
149           [&](auto type) { type.getExtensions(extensions, storage); })
150       .Case<VectorType>([&](VectorType type) {
151         return llvm::cast<ScalarType>(type.getElementType())
152             .getExtensions(extensions, storage);
153       })
154       .Default([](Type) { llvm_unreachable("invalid composite type"); });
155 }
156 
157 void CompositeType::getCapabilities(
158     SPIRVType::CapabilityArrayRefVector &capabilities,
159     std::optional<StorageClass> storage) {
160   TypeSwitch<Type>(*this)
161       .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
162             StructType>(
163           [&](auto type) { type.getCapabilities(capabilities, storage); })
164       .Case<VectorType>([&](VectorType type) {
165         auto vecSize = getNumElements();
166         if (vecSize == 8 || vecSize == 16) {
167           static const Capability caps[] = {Capability::Vector16};
168           ArrayRef<Capability> ref(caps, std::size(caps));
169           capabilities.push_back(ref);
170         }
171         return llvm::cast<ScalarType>(type.getElementType())
172             .getCapabilities(capabilities, storage);
173       })
174       .Default([](Type) { llvm_unreachable("invalid composite type"); });
175 }
176 
177 std::optional<int64_t> CompositeType::getSizeInBytes() {
178   if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
179     return arrayType.getSizeInBytes();
180   if (auto structType = llvm::dyn_cast<StructType>(*this))
181     return structType.getSizeInBytes();
182   if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
183     std::optional<int64_t> elementSize =
184         llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
185     if (!elementSize)
186       return std::nullopt;
187     return *elementSize * vectorType.getNumElements();
188   }
189   return std::nullopt;
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // CooperativeMatrixType
194 //===----------------------------------------------------------------------===//
195 
196 struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
197   using KeyTy =
198       std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
199 
200   static CooperativeMatrixTypeStorage *
201   construct(TypeStorageAllocator &allocator, const KeyTy &key) {
202     return new (allocator.allocate<CooperativeMatrixTypeStorage>())
203         CooperativeMatrixTypeStorage(key);
204   }
205 
206   bool operator==(const KeyTy &key) const {
207     return key == KeyTy(elementType, rows, columns, scope, use);
208   }
209 
210   CooperativeMatrixTypeStorage(const KeyTy &key)
211       : elementType(std::get<0>(key)), rows(std::get<1>(key)),
212         columns(std::get<2>(key)), scope(std::get<3>(key)),
213         use(std::get<4>(key)) {}
214 
215   Type elementType;
216   uint32_t rows;
217   uint32_t columns;
218   Scope scope;
219   CooperativeMatrixUseKHR use;
220 };
221 
222 CooperativeMatrixType CooperativeMatrixType::get(Type elementType,
223                                                  uint32_t rows,
224                                                  uint32_t columns, Scope scope,
225                                                  CooperativeMatrixUseKHR use) {
226   return Base::get(elementType.getContext(), elementType, rows, columns, scope,
227                    use);
228 }
229 
230 Type CooperativeMatrixType::getElementType() const {
231   return getImpl()->elementType;
232 }
233 
234 uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; }
235 
236 uint32_t CooperativeMatrixType::getColumns() const {
237   return getImpl()->columns;
238 }
239 
240 Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
241 
242 CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
243   return getImpl()->use;
244 }
245 
246 void CooperativeMatrixType::getExtensions(
247     SPIRVType::ExtensionArrayRefVector &extensions,
248     std::optional<StorageClass> storage) {
249   llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
250   static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
251   extensions.push_back(exts);
252 }
253 
254 void CooperativeMatrixType::getCapabilities(
255     SPIRVType::CapabilityArrayRefVector &capabilities,
256     std::optional<StorageClass> storage) {
257   llvm::cast<SPIRVType>(getElementType())
258       .getCapabilities(capabilities, storage);
259   static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
260   capabilities.push_back(caps);
261 }
262 
263 //===----------------------------------------------------------------------===//
264 // ImageType
265 //===----------------------------------------------------------------------===//
266 
267 template <typename T>
268 static constexpr unsigned getNumBits() {
269   return 0;
270 }
271 template <>
272 constexpr unsigned getNumBits<Dim>() {
273   static_assert((1 << 3) > getMaxEnumValForDim(),
274                 "Not enough bits to encode Dim value");
275   return 3;
276 }
277 template <>
278 constexpr unsigned getNumBits<ImageDepthInfo>() {
279   static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
280                 "Not enough bits to encode ImageDepthInfo value");
281   return 2;
282 }
283 template <>
284 constexpr unsigned getNumBits<ImageArrayedInfo>() {
285   static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
286                 "Not enough bits to encode ImageArrayedInfo value");
287   return 1;
288 }
289 template <>
290 constexpr unsigned getNumBits<ImageSamplingInfo>() {
291   static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
292                 "Not enough bits to encode ImageSamplingInfo value");
293   return 1;
294 }
295 template <>
296 constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
297   static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
298                 "Not enough bits to encode ImageSamplerUseInfo value");
299   return 2;
300 }
301 template <>
302 constexpr unsigned getNumBits<ImageFormat>() {
303   static_assert((1 << 6) > getMaxEnumValForImageFormat(),
304                 "Not enough bits to encode ImageFormat value");
305   return 6;
306 }
307 
308 struct spirv::detail::ImageTypeStorage : public TypeStorage {
309 public:
310   using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
311                            ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
312 
313   static ImageTypeStorage *construct(TypeStorageAllocator &allocator,
314                                      const KeyTy &key) {
315     return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
316   }
317 
318   bool operator==(const KeyTy &key) const {
319     return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
320                         samplerUseInfo, format);
321   }
322 
323   ImageTypeStorage(const KeyTy &key)
324       : elementType(std::get<0>(key)), dim(std::get<1>(key)),
325         depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
326         samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
327         format(std::get<6>(key)) {}
328 
329   Type elementType;
330   Dim dim : getNumBits<Dim>();
331   ImageDepthInfo depthInfo : getNumBits<ImageDepthInfo>();
332   ImageArrayedInfo arrayedInfo : getNumBits<ImageArrayedInfo>();
333   ImageSamplingInfo samplingInfo : getNumBits<ImageSamplingInfo>();
334   ImageSamplerUseInfo samplerUseInfo : getNumBits<ImageSamplerUseInfo>();
335   ImageFormat format : getNumBits<ImageFormat>();
336 };
337 
338 ImageType
339 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
340                           ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
341                    value) {
342   return Base::get(std::get<0>(value).getContext(), value);
343 }
344 
345 Type ImageType::getElementType() const { return getImpl()->elementType; }
346 
347 Dim ImageType::getDim() const { return getImpl()->dim; }
348 
349 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
350 
351 ImageArrayedInfo ImageType::getArrayedInfo() const {
352   return getImpl()->arrayedInfo;
353 }
354 
355 ImageSamplingInfo ImageType::getSamplingInfo() const {
356   return getImpl()->samplingInfo;
357 }
358 
359 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
360   return getImpl()->samplerUseInfo;
361 }
362 
363 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
364 
365 void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
366                               std::optional<StorageClass>) {
367   // Image types do not require extra extensions thus far.
368 }
369 
370 void ImageType::getCapabilities(
371     SPIRVType::CapabilityArrayRefVector &capabilities,
372     std::optional<StorageClass>) {
373   if (auto dimCaps = spirv::getCapabilities(getDim()))
374     capabilities.push_back(*dimCaps);
375 
376   if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
377     capabilities.push_back(*fmtCaps);
378 }
379 
380 //===----------------------------------------------------------------------===//
381 // PointerType
382 //===----------------------------------------------------------------------===//
383 
384 struct spirv::detail::PointerTypeStorage : public TypeStorage {
385   // (Type, StorageClass) as the key: Type stored in this struct, and
386   // StorageClass stored as TypeStorage's subclass data.
387   using KeyTy = std::pair<Type, StorageClass>;
388 
389   static PointerTypeStorage *construct(TypeStorageAllocator &allocator,
390                                        const KeyTy &key) {
391     return new (allocator.allocate<PointerTypeStorage>())
392         PointerTypeStorage(key);
393   }
394 
395   bool operator==(const KeyTy &key) const {
396     return key == KeyTy(pointeeType, storageClass);
397   }
398 
399   PointerTypeStorage(const KeyTy &key)
400       : pointeeType(key.first), storageClass(key.second) {}
401 
402   Type pointeeType;
403   StorageClass storageClass;
404 };
405 
406 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
407   return Base::get(pointeeType.getContext(), pointeeType, storageClass);
408 }
409 
410 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
411 
412 StorageClass PointerType::getStorageClass() const {
413   return getImpl()->storageClass;
414 }
415 
416 void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
417                                 std::optional<StorageClass> storage) {
418   // Use this pointer type's storage class because this pointer indicates we are
419   // using the pointee type in that specific storage class.
420   llvm::cast<SPIRVType>(getPointeeType())
421       .getExtensions(extensions, getStorageClass());
422 
423   if (auto scExts = spirv::getExtensions(getStorageClass()))
424     extensions.push_back(*scExts);
425 }
426 
427 void PointerType::getCapabilities(
428     SPIRVType::CapabilityArrayRefVector &capabilities,
429     std::optional<StorageClass> storage) {
430   // Use this pointer type's storage class because this pointer indicates we are
431   // using the pointee type in that specific storage class.
432   llvm::cast<SPIRVType>(getPointeeType())
433       .getCapabilities(capabilities, getStorageClass());
434 
435   if (auto scCaps = spirv::getCapabilities(getStorageClass()))
436     capabilities.push_back(*scCaps);
437 }
438 
439 //===----------------------------------------------------------------------===//
440 // RuntimeArrayType
441 //===----------------------------------------------------------------------===//
442 
443 struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
444   using KeyTy = std::pair<Type, unsigned>;
445 
446   static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator,
447                                             const KeyTy &key) {
448     return new (allocator.allocate<RuntimeArrayTypeStorage>())
449         RuntimeArrayTypeStorage(key);
450   }
451 
452   bool operator==(const KeyTy &key) const {
453     return key == KeyTy(elementType, stride);
454   }
455 
456   RuntimeArrayTypeStorage(const KeyTy &key)
457       : elementType(key.first), stride(key.second) {}
458 
459   Type elementType;
460   unsigned stride;
461 };
462 
463 RuntimeArrayType RuntimeArrayType::get(Type elementType) {
464   return Base::get(elementType.getContext(), elementType, /*stride=*/0);
465 }
466 
467 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
468   return Base::get(elementType.getContext(), elementType, stride);
469 }
470 
471 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
472 
473 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
474 
475 void RuntimeArrayType::getExtensions(
476     SPIRVType::ExtensionArrayRefVector &extensions,
477     std::optional<StorageClass> storage) {
478   llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
479 }
480 
481 void RuntimeArrayType::getCapabilities(
482     SPIRVType::CapabilityArrayRefVector &capabilities,
483     std::optional<StorageClass> storage) {
484   {
485     static const Capability caps[] = {Capability::Shader};
486     ArrayRef<Capability> ref(caps, std::size(caps));
487     capabilities.push_back(ref);
488   }
489   llvm::cast<SPIRVType>(getElementType())
490       .getCapabilities(capabilities, storage);
491 }
492 
493 //===----------------------------------------------------------------------===//
494 // ScalarType
495 //===----------------------------------------------------------------------===//
496 
497 bool ScalarType::classof(Type type) {
498   if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
499     return isValid(floatType);
500   }
501   if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
502     return isValid(intType);
503   }
504   return false;
505 }
506 
507 bool ScalarType::isValid(FloatType type) {
508   return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
509 }
510 
511 bool ScalarType::isValid(IntegerType type) {
512   return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
513 }
514 
515 void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
516                                std::optional<StorageClass> storage) {
517   // 8- or 16-bit integer/floating-point numbers will require extra extensions
518   // to appear in interface storage classes. See SPV_KHR_16bit_storage and
519   // SPV_KHR_8bit_storage for more details.
520   if (!storage)
521     return;
522 
523   switch (*storage) {
524   case StorageClass::PushConstant:
525   case StorageClass::StorageBuffer:
526   case StorageClass::Uniform:
527     if (getIntOrFloatBitWidth() == 8) {
528       static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
529       ArrayRef<Extension> ref(exts, std::size(exts));
530       extensions.push_back(ref);
531     }
532     [[fallthrough]];
533   case StorageClass::Input:
534   case StorageClass::Output:
535     if (getIntOrFloatBitWidth() == 16) {
536       static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
537       ArrayRef<Extension> ref(exts, std::size(exts));
538       extensions.push_back(ref);
539     }
540     break;
541   default:
542     break;
543   }
544 }
545 
546 void ScalarType::getCapabilities(
547     SPIRVType::CapabilityArrayRefVector &capabilities,
548     std::optional<StorageClass> storage) {
549   unsigned bitwidth = getIntOrFloatBitWidth();
550 
551   // 8- or 16-bit integer/floating-point numbers will require extra capabilities
552   // to appear in interface storage classes. See SPV_KHR_16bit_storage and
553   // SPV_KHR_8bit_storage for more details.
554 
555 #define STORAGE_CASE(storage, cap8, cap16)                                     \
556   case StorageClass::storage: {                                                \
557     if (bitwidth == 8) {                                                       \
558       static const Capability caps[] = {Capability::cap8};                     \
559       ArrayRef<Capability> ref(caps, std::size(caps));                         \
560       capabilities.push_back(ref);                                             \
561       return;                                                                  \
562     }                                                                          \
563     if (bitwidth == 16) {                                                      \
564       static const Capability caps[] = {Capability::cap16};                    \
565       ArrayRef<Capability> ref(caps, std::size(caps));                         \
566       capabilities.push_back(ref);                                             \
567       return;                                                                  \
568     }                                                                          \
569     /* For 64-bit integers/floats, Int64/Float64 enables support for all */    \
570     /* storage classes. Fall through to the next section. */                   \
571   } break
572 
573   // This part only handles the cases where special bitwidths appearing in
574   // interface storage classes.
575   if (storage) {
576     switch (*storage) {
577       STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
578       STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
579                    StorageBuffer16BitAccess);
580       STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
581                    StorageUniform16);
582     case StorageClass::Input:
583     case StorageClass::Output: {
584       if (bitwidth == 16) {
585         static const Capability caps[] = {Capability::StorageInputOutput16};
586         ArrayRef<Capability> ref(caps, std::size(caps));
587         capabilities.push_back(ref);
588         return;
589       }
590       break;
591     }
592     default:
593       break;
594     }
595   }
596 #undef STORAGE_CASE
597 
598   // For other non-interface storage classes, require a different set of
599   // capabilities for special bitwidths.
600 
601 #define WIDTH_CASE(type, width)                                                \
602   case width: {                                                                \
603     static const Capability caps[] = {Capability::type##width};                \
604     ArrayRef<Capability> ref(caps, std::size(caps));                           \
605     capabilities.push_back(ref);                                               \
606   } break
607 
608   if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
609     switch (bitwidth) {
610       WIDTH_CASE(Int, 8);
611       WIDTH_CASE(Int, 16);
612       WIDTH_CASE(Int, 64);
613     case 1:
614     case 32:
615       break;
616     default:
617       llvm_unreachable("invalid bitwidth to getCapabilities");
618     }
619   } else {
620     assert(llvm::isa<FloatType>(*this));
621     switch (bitwidth) {
622       WIDTH_CASE(Float, 16);
623       WIDTH_CASE(Float, 64);
624     case 32:
625       break;
626     default:
627       llvm_unreachable("invalid bitwidth to getCapabilities");
628     }
629   }
630 
631 #undef WIDTH_CASE
632 }
633 
634 std::optional<int64_t> ScalarType::getSizeInBytes() {
635   auto bitWidth = getIntOrFloatBitWidth();
636   // According to the SPIR-V spec:
637   // "There is no physical size or bit pattern defined for values with boolean
638   // type. If they are stored (in conjunction with OpVariable), they can only
639   // be used with logical addressing operations, not physical, and only with
640   // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
641   // Private, Function, Input, and Output."
642   if (bitWidth == 1)
643     return std::nullopt;
644   return bitWidth / 8;
645 }
646 
647 //===----------------------------------------------------------------------===//
648 // SPIRVType
649 //===----------------------------------------------------------------------===//
650 
651 bool SPIRVType::classof(Type type) {
652   // Allow SPIR-V dialect types
653   if (llvm::isa<SPIRVDialect>(type.getDialect()))
654     return true;
655   if (llvm::isa<ScalarType>(type))
656     return true;
657   if (auto vectorType = llvm::dyn_cast<VectorType>(type))
658     return CompositeType::isValid(vectorType);
659   return false;
660 }
661 
662 bool SPIRVType::isScalarOrVector() {
663   return isIntOrFloat() || llvm::isa<VectorType>(*this);
664 }
665 
666 void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
667                               std::optional<StorageClass> storage) {
668   if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
669     scalarType.getExtensions(extensions, storage);
670   } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
671     compositeType.getExtensions(extensions, storage);
672   } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
673     imageType.getExtensions(extensions, storage);
674   } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
675     sampledImageType.getExtensions(extensions, storage);
676   } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
677     matrixType.getExtensions(extensions, storage);
678   } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
679     ptrType.getExtensions(extensions, storage);
680   } else {
681     llvm_unreachable("invalid SPIR-V Type to getExtensions");
682   }
683 }
684 
685 void SPIRVType::getCapabilities(
686     SPIRVType::CapabilityArrayRefVector &capabilities,
687     std::optional<StorageClass> storage) {
688   if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
689     scalarType.getCapabilities(capabilities, storage);
690   } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
691     compositeType.getCapabilities(capabilities, storage);
692   } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
693     imageType.getCapabilities(capabilities, storage);
694   } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
695     sampledImageType.getCapabilities(capabilities, storage);
696   } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
697     matrixType.getCapabilities(capabilities, storage);
698   } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
699     ptrType.getCapabilities(capabilities, storage);
700   } else {
701     llvm_unreachable("invalid SPIR-V Type to getCapabilities");
702   }
703 }
704 
705 std::optional<int64_t> SPIRVType::getSizeInBytes() {
706   if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
707     return scalarType.getSizeInBytes();
708   if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
709     return compositeType.getSizeInBytes();
710   return std::nullopt;
711 }
712 
713 //===----------------------------------------------------------------------===//
714 // SampledImageType
715 //===----------------------------------------------------------------------===//
716 struct spirv::detail::SampledImageTypeStorage : public TypeStorage {
717   using KeyTy = Type;
718 
719   SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
720 
721   bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
722 
723   static SampledImageTypeStorage *construct(TypeStorageAllocator &allocator,
724                                             const KeyTy &key) {
725     return new (allocator.allocate<SampledImageTypeStorage>())
726         SampledImageTypeStorage(key);
727   }
728 
729   Type imageType;
730 };
731 
732 SampledImageType SampledImageType::get(Type imageType) {
733   return Base::get(imageType.getContext(), imageType);
734 }
735 
736 SampledImageType
737 SampledImageType::getChecked(function_ref<InFlightDiagnostic()> emitError,
738                              Type imageType) {
739   return Base::getChecked(emitError, imageType.getContext(), imageType);
740 }
741 
742 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
743 
744 LogicalResult
745 SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
746                                    Type imageType) {
747   if (!llvm::isa<ImageType>(imageType))
748     return emitError() << "expected image type";
749 
750   return success();
751 }
752 
753 void SampledImageType::getExtensions(
754     SPIRVType::ExtensionArrayRefVector &extensions,
755     std::optional<StorageClass> storage) {
756   llvm::cast<ImageType>(getImageType()).getExtensions(extensions, storage);
757 }
758 
759 void SampledImageType::getCapabilities(
760     SPIRVType::CapabilityArrayRefVector &capabilities,
761     std::optional<StorageClass> storage) {
762   llvm::cast<ImageType>(getImageType()).getCapabilities(capabilities, storage);
763 }
764 
765 //===----------------------------------------------------------------------===//
766 // StructType
767 //===----------------------------------------------------------------------===//
768 
769 /// Type storage for SPIR-V structure types:
770 ///
771 /// Structures are uniqued using:
772 /// - for identified structs:
773 ///   - a string identifier;
774 /// - for literal structs:
775 ///   - a list of member types;
776 ///   - a list of member offset info;
777 ///   - a list of member decoration info.
778 ///
779 /// Identified structures only have a mutable component consisting of:
780 /// - a list of member types;
781 /// - a list of member offset info;
782 /// - a list of member decoration info.
783 struct spirv::detail::StructTypeStorage : public TypeStorage {
784   /// Construct a storage object for an identified struct type. A struct type
785   /// associated with such storage must call StructType::trySetBody(...) later
786   /// in order to mutate the storage object providing the actual content.
787   StructTypeStorage(StringRef identifier)
788       : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
789         numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
790         identifier(identifier) {}
791 
792   /// Construct a storage object for a literal struct type. A struct type
793   /// associated with such storage is immutable.
794   StructTypeStorage(
795       unsigned numMembers, Type const *memberTypes,
796       StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
797       StructType::MemberDecorationInfo const *memberDecorationsInfo)
798       : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
799         numMembers(numMembers), numMemberDecorations(numMemberDecorations),
800         memberDecorationsInfo(memberDecorationsInfo) {}
801 
802   /// A storage key is divided into 2 parts:
803   /// - for identified structs:
804   ///   - a StringRef representing the struct identifier;
805   /// - for literal structs:
806   ///   - an ArrayRef<Type> for member types;
807   ///   - an ArrayRef<StructType::OffsetInfo> for member offset info;
808   ///   - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
809   ///     info.
810   ///
811   /// An identified struct type is uniqued only by the first part (field 0)
812   /// of the key.
813   ///
814   /// A literal struct type is uniqued only by the second part (fields 1, 2, and
815   /// 3) of the key. The identifier field (field 0) must be empty.
816   using KeyTy =
817       std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
818                  ArrayRef<StructType::MemberDecorationInfo>>;
819 
820   /// For identified structs, return true if the given key contains the same
821   /// identifier.
822   ///
823   /// For literal structs, return true if the given key contains a matching list
824   /// of member types + offset info + decoration info.
825   bool operator==(const KeyTy &key) const {
826     if (isIdentified()) {
827       // Identified types are uniqued by their identifier.
828       return getIdentifier() == std::get<0>(key);
829     }
830 
831     return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
832                         getMemberDecorationsInfo());
833   }
834 
835   /// If the given key contains a non-empty identifier, this method constructs
836   /// an identified struct and leaves the rest of the struct type data to be set
837   /// through a later call to StructType::trySetBody(...).
838   ///
839   /// If, on the other hand, the key contains an empty identifier, a literal
840   /// struct is constructed using the other fields of the key.
841   static StructTypeStorage *construct(TypeStorageAllocator &allocator,
842                                       const KeyTy &key) {
843     StringRef keyIdentifier = std::get<0>(key);
844 
845     if (!keyIdentifier.empty()) {
846       StringRef identifier = allocator.copyInto(keyIdentifier);
847 
848       // Identified StructType body/members will be set through trySetBody(...)
849       // later.
850       return new (allocator.allocate<StructTypeStorage>())
851           StructTypeStorage(identifier);
852     }
853 
854     ArrayRef<Type> keyTypes = std::get<1>(key);
855 
856     // Copy the member type and layout information into the bump pointer
857     const Type *typesList = nullptr;
858     if (!keyTypes.empty()) {
859       typesList = allocator.copyInto(keyTypes).data();
860     }
861 
862     const StructType::OffsetInfo *offsetInfoList = nullptr;
863     if (!std::get<2>(key).empty()) {
864       ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
865       assert(keyOffsetInfo.size() == keyTypes.size() &&
866              "size of offset information must be same as the size of number of "
867              "elements");
868       offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
869     }
870 
871     const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
872     unsigned numMemberDecorations = 0;
873     if (!std::get<3>(key).empty()) {
874       auto keyMemberDecorations = std::get<3>(key);
875       numMemberDecorations = keyMemberDecorations.size();
876       memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
877     }
878 
879     return new (allocator.allocate<StructTypeStorage>())
880         StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
881                           numMemberDecorations, memberDecorationList);
882   }
883 
884   ArrayRef<Type> getMemberTypes() const {
885     return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
886   }
887 
888   ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
889     if (offsetInfo) {
890       return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
891     }
892     return {};
893   }
894 
895   ArrayRef<StructType::MemberDecorationInfo> getMemberDecorationsInfo() const {
896     if (memberDecorationsInfo) {
897       return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
898                                                         numMemberDecorations);
899     }
900     return {};
901   }
902 
903   StringRef getIdentifier() const { return identifier; }
904 
905   bool isIdentified() const { return !identifier.empty(); }
906 
907   /// Sets the struct type content for identified structs. Calling this method
908   /// is only valid for identified structs.
909   ///
910   /// Fails under the following conditions:
911   /// - If called for a literal struct;
912   /// - If called for an identified struct whose body was set before (through a
913   /// call to this method) but with different contents from the passed
914   /// arguments.
915   LogicalResult mutate(
916       TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
917       ArrayRef<StructType::OffsetInfo> structOffsetInfo,
918       ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
919     if (!isIdentified())
920       return failure();
921 
922     if (memberTypesAndIsBodySet.getInt() &&
923         (getMemberTypes() != structMemberTypes ||
924          getOffsetInfo() != structOffsetInfo ||
925          getMemberDecorationsInfo() != structMemberDecorationInfo))
926       return failure();
927 
928     memberTypesAndIsBodySet.setInt(true);
929     numMembers = structMemberTypes.size();
930 
931     // Copy the member type and layout information into the bump pointer.
932     if (!structMemberTypes.empty())
933       memberTypesAndIsBodySet.setPointer(
934           allocator.copyInto(structMemberTypes).data());
935 
936     if (!structOffsetInfo.empty()) {
937       assert(structOffsetInfo.size() == structMemberTypes.size() &&
938              "size of offset information must be same as the size of number of "
939              "elements");
940       offsetInfo = allocator.copyInto(structOffsetInfo).data();
941     }
942 
943     if (!structMemberDecorationInfo.empty()) {
944       numMemberDecorations = structMemberDecorationInfo.size();
945       memberDecorationsInfo =
946           allocator.copyInto(structMemberDecorationInfo).data();
947     }
948 
949     return success();
950   }
951 
952   llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
953   StructType::OffsetInfo const *offsetInfo;
954   unsigned numMembers;
955   unsigned numMemberDecorations;
956   StructType::MemberDecorationInfo const *memberDecorationsInfo;
957   StringRef identifier;
958 };
959 
960 StructType
961 StructType::get(ArrayRef<Type> memberTypes,
962                 ArrayRef<StructType::OffsetInfo> offsetInfo,
963                 ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
964   assert(!memberTypes.empty() && "Struct needs at least one member type");
965   // Sort the decorations.
966   SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
967       memberDecorations);
968   llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
969   return Base::get(memberTypes.vec().front().getContext(),
970                    /*identifier=*/StringRef(), memberTypes, offsetInfo,
971                    sortedDecorations);
972 }
973 
974 StructType StructType::getIdentified(MLIRContext *context,
975                                      StringRef identifier) {
976   assert(!identifier.empty() &&
977          "StructType identifier must be non-empty string");
978 
979   return Base::get(context, identifier, ArrayRef<Type>(),
980                    ArrayRef<StructType::OffsetInfo>(),
981                    ArrayRef<StructType::MemberDecorationInfo>());
982 }
983 
984 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
985   StructType newStructType = Base::get(
986       context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
987       ArrayRef<StructType::MemberDecorationInfo>());
988   // Set an empty body in case this is a identified struct.
989   if (newStructType.isIdentified() &&
990       failed(newStructType.trySetBody(
991           ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
992           ArrayRef<StructType::MemberDecorationInfo>())))
993     return StructType();
994 
995   return newStructType;
996 }
997 
998 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
999 
1000 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1001 
1002 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1003 
1004 Type StructType::getElementType(unsigned index) const {
1005   assert(getNumElements() > index && "member index out of range");
1006   return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1007 }
1008 
1009 TypeRange StructType::getElementTypes() const {
1010   return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1011                    getNumElements());
1012 }
1013 
1014 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1015 
1016 uint64_t StructType::getMemberOffset(unsigned index) const {
1017   assert(getNumElements() > index && "member index out of range");
1018   return getImpl()->offsetInfo[index];
1019 }
1020 
1021 void StructType::getMemberDecorations(
1022     SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorations)
1023     const {
1024   memberDecorations.clear();
1025   auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1026   memberDecorations.append(implMemberDecorations.begin(),
1027                            implMemberDecorations.end());
1028 }
1029 
1030 void StructType::getMemberDecorations(
1031     unsigned index,
1032     SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1033   assert(getNumElements() > index && "member index out of range");
1034   auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1035   decorationsInfo.clear();
1036   for (const auto &memberDecoration : memberDecorations) {
1037     if (memberDecoration.memberIndex == index) {
1038       decorationsInfo.push_back(memberDecoration);
1039     }
1040     if (memberDecoration.memberIndex > index) {
1041       // Early exit since the decorations are stored sorted.
1042       return;
1043     }
1044   }
1045 }
1046 
1047 LogicalResult
1048 StructType::trySetBody(ArrayRef<Type> memberTypes,
1049                        ArrayRef<OffsetInfo> offsetInfo,
1050                        ArrayRef<MemberDecorationInfo> memberDecorations) {
1051   return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1052 }
1053 
1054 void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1055                                std::optional<StorageClass> storage) {
1056   for (Type elementType : getElementTypes())
1057     llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
1058 }
1059 
1060 void StructType::getCapabilities(
1061     SPIRVType::CapabilityArrayRefVector &capabilities,
1062     std::optional<StorageClass> storage) {
1063   for (Type elementType : getElementTypes())
1064     llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage);
1065 }
1066 
1067 llvm::hash_code spirv::hash_value(
1068     const StructType::MemberDecorationInfo &memberDecorationInfo) {
1069   return llvm::hash_combine(memberDecorationInfo.memberIndex,
1070                             memberDecorationInfo.decoration);
1071 }
1072 
1073 //===----------------------------------------------------------------------===//
1074 // MatrixType
1075 //===----------------------------------------------------------------------===//
1076 
1077 struct spirv::detail::MatrixTypeStorage : public TypeStorage {
1078   MatrixTypeStorage(Type columnType, uint32_t columnCount)
1079       : columnType(columnType), columnCount(columnCount) {}
1080 
1081   using KeyTy = std::tuple<Type, uint32_t>;
1082 
1083   static MatrixTypeStorage *construct(TypeStorageAllocator &allocator,
1084                                       const KeyTy &key) {
1085 
1086     // Initialize the memory using placement new.
1087     return new (allocator.allocate<MatrixTypeStorage>())
1088         MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1089   }
1090 
1091   bool operator==(const KeyTy &key) const {
1092     return key == KeyTy(columnType, columnCount);
1093   }
1094 
1095   Type columnType;
1096   const uint32_t columnCount;
1097 };
1098 
1099 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1100   return Base::get(columnType.getContext(), columnType, columnCount);
1101 }
1102 
1103 MatrixType MatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
1104                                   Type columnType, uint32_t columnCount) {
1105   return Base::getChecked(emitError, columnType.getContext(), columnType,
1106                           columnCount);
1107 }
1108 
1109 LogicalResult
1110 MatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
1111                              Type columnType, uint32_t columnCount) {
1112   if (columnCount < 2 || columnCount > 4)
1113     return emitError() << "matrix can have 2, 3, or 4 columns only";
1114 
1115   if (!isValidColumnType(columnType))
1116     return emitError() << "matrix columns must be vectors of floats";
1117 
1118   /// The underlying vectors (columns) must be of size 2, 3, or 4
1119   ArrayRef<int64_t> columnShape = llvm::cast<VectorType>(columnType).getShape();
1120   if (columnShape.size() != 1)
1121     return emitError() << "matrix columns must be 1D vectors";
1122 
1123   if (columnShape[0] < 2 || columnShape[0] > 4)
1124     return emitError() << "matrix columns must be of size 2, 3, or 4";
1125 
1126   return success();
1127 }
1128 
1129 /// Returns true if the matrix elements are vectors of float elements
1130 bool MatrixType::isValidColumnType(Type columnType) {
1131   if (auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
1132     if (llvm::isa<FloatType>(vectorType.getElementType()))
1133       return true;
1134   }
1135   return false;
1136 }
1137 
1138 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1139 
1140 Type MatrixType::getElementType() const {
1141   return llvm::cast<VectorType>(getImpl()->columnType).getElementType();
1142 }
1143 
1144 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1145 
1146 unsigned MatrixType::getNumRows() const {
1147   return llvm::cast<VectorType>(getImpl()->columnType).getShape()[0];
1148 }
1149 
1150 unsigned MatrixType::getNumElements() const {
1151   return (getImpl()->columnCount) * getNumRows();
1152 }
1153 
1154 void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1155                                std::optional<StorageClass> storage) {
1156   llvm::cast<SPIRVType>(getColumnType()).getExtensions(extensions, storage);
1157 }
1158 
1159 void MatrixType::getCapabilities(
1160     SPIRVType::CapabilityArrayRefVector &capabilities,
1161     std::optional<StorageClass> storage) {
1162   {
1163     static const Capability caps[] = {Capability::Matrix};
1164     ArrayRef<Capability> ref(caps, std::size(caps));
1165     capabilities.push_back(ref);
1166   }
1167   // Add any capabilities associated with the underlying vectors (i.e., columns)
1168   llvm::cast<SPIRVType>(getColumnType()).getCapabilities(capabilities, storage);
1169 }
1170 
1171 //===----------------------------------------------------------------------===//
1172 // SPIR-V Dialect
1173 //===----------------------------------------------------------------------===//
1174 
1175 void SPIRVDialect::registerTypes() {
1176   addTypes<ArrayType, CooperativeMatrixType, ImageType, MatrixType, PointerType,
1177            RuntimeArrayType, SampledImageType, StructType>();
1178 }
1179