1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the types in the SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 16 #include "mlir/IR/Attributes.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "llvm/ADT/STLExtras.h" 19 #include "llvm/ADT/TypeSwitch.h" 20 21 #include <cstdint> 22 #include <iterator> 23 24 using namespace mlir; 25 using namespace mlir::spirv; 26 27 //===----------------------------------------------------------------------===// 28 // ArrayType 29 //===----------------------------------------------------------------------===// 30 31 struct spirv::detail::ArrayTypeStorage : public TypeStorage { 32 using KeyTy = std::tuple<Type, unsigned, unsigned>; 33 34 static ArrayTypeStorage *construct(TypeStorageAllocator &allocator, 35 const KeyTy &key) { 36 return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key); 37 } 38 39 bool operator==(const KeyTy &key) const { 40 return key == KeyTy(elementType, elementCount, stride); 41 } 42 43 ArrayTypeStorage(const KeyTy &key) 44 : elementType(std::get<0>(key)), elementCount(std::get<1>(key)), 45 stride(std::get<2>(key)) {} 46 47 Type elementType; 48 unsigned elementCount; 49 unsigned stride; 50 }; 51 52 ArrayType ArrayType::get(Type elementType, unsigned elementCount) { 53 assert(elementCount && "ArrayType needs at least one element"); 54 return Base::get(elementType.getContext(), elementType, elementCount, 55 /*stride=*/0); 56 } 57 58 ArrayType ArrayType::get(Type elementType, unsigned elementCount, 59 unsigned stride) { 60 assert(elementCount && "ArrayType needs at least one element"); 61 return Base::get(elementType.getContext(), elementType, elementCount, stride); 62 } 63 64 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; } 65 66 Type ArrayType::getElementType() const { return getImpl()->elementType; } 67 68 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; } 69 70 void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 71 std::optional<StorageClass> storage) { 72 llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage); 73 } 74 75 void ArrayType::getCapabilities( 76 SPIRVType::CapabilityArrayRefVector &capabilities, 77 std::optional<StorageClass> storage) { 78 llvm::cast<SPIRVType>(getElementType()) 79 .getCapabilities(capabilities, storage); 80 } 81 82 std::optional<int64_t> ArrayType::getSizeInBytes() { 83 auto elementType = llvm::cast<SPIRVType>(getElementType()); 84 std::optional<int64_t> size = elementType.getSizeInBytes(); 85 if (!size) 86 return std::nullopt; 87 return (*size + getArrayStride()) * getNumElements(); 88 } 89 90 //===----------------------------------------------------------------------===// 91 // CompositeType 92 //===----------------------------------------------------------------------===// 93 94 bool CompositeType::classof(Type type) { 95 if (auto vectorType = llvm::dyn_cast<VectorType>(type)) 96 return isValid(vectorType); 97 return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType, 98 spirv::MatrixType, spirv::RuntimeArrayType, 99 spirv::StructType>(type); 100 } 101 102 bool CompositeType::isValid(VectorType type) { 103 return type.getRank() == 1 && 104 llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) && 105 llvm::isa<ScalarType>(type.getElementType()); 106 } 107 108 Type CompositeType::getElementType(unsigned index) const { 109 return TypeSwitch<Type, Type>(*this) 110 .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType>( 111 [](auto type) { return type.getElementType(); }) 112 .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); }) 113 .Case<StructType>( 114 [index](StructType type) { return type.getElementType(index); }) 115 .Default( 116 [](Type) -> Type { llvm_unreachable("invalid composite type"); }); 117 } 118 119 unsigned CompositeType::getNumElements() const { 120 if (auto arrayType = llvm::dyn_cast<ArrayType>(*this)) 121 return arrayType.getNumElements(); 122 if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) 123 return matrixType.getNumColumns(); 124 if (auto structType = llvm::dyn_cast<StructType>(*this)) 125 return structType.getNumElements(); 126 if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) 127 return vectorType.getNumElements(); 128 if (llvm::isa<CooperativeMatrixType>(*this)) { 129 llvm_unreachable( 130 "invalid to query number of elements of spirv Cooperative Matrix type"); 131 } 132 if (llvm::isa<RuntimeArrayType>(*this)) { 133 llvm_unreachable( 134 "invalid to query number of elements of spirv::RuntimeArray type"); 135 } 136 llvm_unreachable("invalid composite type"); 137 } 138 139 bool CompositeType::hasCompileTimeKnownNumElements() const { 140 return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this); 141 } 142 143 void CompositeType::getExtensions( 144 SPIRVType::ExtensionArrayRefVector &extensions, 145 std::optional<StorageClass> storage) { 146 TypeSwitch<Type>(*this) 147 .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType, 148 StructType>( 149 [&](auto type) { type.getExtensions(extensions, storage); }) 150 .Case<VectorType>([&](VectorType type) { 151 return llvm::cast<ScalarType>(type.getElementType()) 152 .getExtensions(extensions, storage); 153 }) 154 .Default([](Type) { llvm_unreachable("invalid composite type"); }); 155 } 156 157 void CompositeType::getCapabilities( 158 SPIRVType::CapabilityArrayRefVector &capabilities, 159 std::optional<StorageClass> storage) { 160 TypeSwitch<Type>(*this) 161 .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType, 162 StructType>( 163 [&](auto type) { type.getCapabilities(capabilities, storage); }) 164 .Case<VectorType>([&](VectorType type) { 165 auto vecSize = getNumElements(); 166 if (vecSize == 8 || vecSize == 16) { 167 static const Capability caps[] = {Capability::Vector16}; 168 ArrayRef<Capability> ref(caps, std::size(caps)); 169 capabilities.push_back(ref); 170 } 171 return llvm::cast<ScalarType>(type.getElementType()) 172 .getCapabilities(capabilities, storage); 173 }) 174 .Default([](Type) { llvm_unreachable("invalid composite type"); }); 175 } 176 177 std::optional<int64_t> CompositeType::getSizeInBytes() { 178 if (auto arrayType = llvm::dyn_cast<ArrayType>(*this)) 179 return arrayType.getSizeInBytes(); 180 if (auto structType = llvm::dyn_cast<StructType>(*this)) 181 return structType.getSizeInBytes(); 182 if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) { 183 std::optional<int64_t> elementSize = 184 llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes(); 185 if (!elementSize) 186 return std::nullopt; 187 return *elementSize * vectorType.getNumElements(); 188 } 189 return std::nullopt; 190 } 191 192 //===----------------------------------------------------------------------===// 193 // CooperativeMatrixType 194 //===----------------------------------------------------------------------===// 195 196 struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage { 197 using KeyTy = 198 std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>; 199 200 static CooperativeMatrixTypeStorage * 201 construct(TypeStorageAllocator &allocator, const KeyTy &key) { 202 return new (allocator.allocate<CooperativeMatrixTypeStorage>()) 203 CooperativeMatrixTypeStorage(key); 204 } 205 206 bool operator==(const KeyTy &key) const { 207 return key == KeyTy(elementType, rows, columns, scope, use); 208 } 209 210 CooperativeMatrixTypeStorage(const KeyTy &key) 211 : elementType(std::get<0>(key)), rows(std::get<1>(key)), 212 columns(std::get<2>(key)), scope(std::get<3>(key)), 213 use(std::get<4>(key)) {} 214 215 Type elementType; 216 uint32_t rows; 217 uint32_t columns; 218 Scope scope; 219 CooperativeMatrixUseKHR use; 220 }; 221 222 CooperativeMatrixType CooperativeMatrixType::get(Type elementType, 223 uint32_t rows, 224 uint32_t columns, Scope scope, 225 CooperativeMatrixUseKHR use) { 226 return Base::get(elementType.getContext(), elementType, rows, columns, scope, 227 use); 228 } 229 230 Type CooperativeMatrixType::getElementType() const { 231 return getImpl()->elementType; 232 } 233 234 uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; } 235 236 uint32_t CooperativeMatrixType::getColumns() const { 237 return getImpl()->columns; 238 } 239 240 Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; } 241 242 CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const { 243 return getImpl()->use; 244 } 245 246 void CooperativeMatrixType::getExtensions( 247 SPIRVType::ExtensionArrayRefVector &extensions, 248 std::optional<StorageClass> storage) { 249 llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage); 250 static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix}; 251 extensions.push_back(exts); 252 } 253 254 void CooperativeMatrixType::getCapabilities( 255 SPIRVType::CapabilityArrayRefVector &capabilities, 256 std::optional<StorageClass> storage) { 257 llvm::cast<SPIRVType>(getElementType()) 258 .getCapabilities(capabilities, storage); 259 static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR}; 260 capabilities.push_back(caps); 261 } 262 263 //===----------------------------------------------------------------------===// 264 // ImageType 265 //===----------------------------------------------------------------------===// 266 267 template <typename T> 268 static constexpr unsigned getNumBits() { 269 return 0; 270 } 271 template <> 272 constexpr unsigned getNumBits<Dim>() { 273 static_assert((1 << 3) > getMaxEnumValForDim(), 274 "Not enough bits to encode Dim value"); 275 return 3; 276 } 277 template <> 278 constexpr unsigned getNumBits<ImageDepthInfo>() { 279 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(), 280 "Not enough bits to encode ImageDepthInfo value"); 281 return 2; 282 } 283 template <> 284 constexpr unsigned getNumBits<ImageArrayedInfo>() { 285 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(), 286 "Not enough bits to encode ImageArrayedInfo value"); 287 return 1; 288 } 289 template <> 290 constexpr unsigned getNumBits<ImageSamplingInfo>() { 291 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(), 292 "Not enough bits to encode ImageSamplingInfo value"); 293 return 1; 294 } 295 template <> 296 constexpr unsigned getNumBits<ImageSamplerUseInfo>() { 297 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(), 298 "Not enough bits to encode ImageSamplerUseInfo value"); 299 return 2; 300 } 301 template <> 302 constexpr unsigned getNumBits<ImageFormat>() { 303 static_assert((1 << 6) > getMaxEnumValForImageFormat(), 304 "Not enough bits to encode ImageFormat value"); 305 return 6; 306 } 307 308 struct spirv::detail::ImageTypeStorage : public TypeStorage { 309 public: 310 using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo, 311 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>; 312 313 static ImageTypeStorage *construct(TypeStorageAllocator &allocator, 314 const KeyTy &key) { 315 return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key); 316 } 317 318 bool operator==(const KeyTy &key) const { 319 return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo, 320 samplerUseInfo, format); 321 } 322 323 ImageTypeStorage(const KeyTy &key) 324 : elementType(std::get<0>(key)), dim(std::get<1>(key)), 325 depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)), 326 samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)), 327 format(std::get<6>(key)) {} 328 329 Type elementType; 330 Dim dim : getNumBits<Dim>(); 331 ImageDepthInfo depthInfo : getNumBits<ImageDepthInfo>(); 332 ImageArrayedInfo arrayedInfo : getNumBits<ImageArrayedInfo>(); 333 ImageSamplingInfo samplingInfo : getNumBits<ImageSamplingInfo>(); 334 ImageSamplerUseInfo samplerUseInfo : getNumBits<ImageSamplerUseInfo>(); 335 ImageFormat format : getNumBits<ImageFormat>(); 336 }; 337 338 ImageType 339 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo, 340 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat> 341 value) { 342 return Base::get(std::get<0>(value).getContext(), value); 343 } 344 345 Type ImageType::getElementType() const { return getImpl()->elementType; } 346 347 Dim ImageType::getDim() const { return getImpl()->dim; } 348 349 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; } 350 351 ImageArrayedInfo ImageType::getArrayedInfo() const { 352 return getImpl()->arrayedInfo; 353 } 354 355 ImageSamplingInfo ImageType::getSamplingInfo() const { 356 return getImpl()->samplingInfo; 357 } 358 359 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const { 360 return getImpl()->samplerUseInfo; 361 } 362 363 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; } 364 365 void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &, 366 std::optional<StorageClass>) { 367 // Image types do not require extra extensions thus far. 368 } 369 370 void ImageType::getCapabilities( 371 SPIRVType::CapabilityArrayRefVector &capabilities, 372 std::optional<StorageClass>) { 373 if (auto dimCaps = spirv::getCapabilities(getDim())) 374 capabilities.push_back(*dimCaps); 375 376 if (auto fmtCaps = spirv::getCapabilities(getImageFormat())) 377 capabilities.push_back(*fmtCaps); 378 } 379 380 //===----------------------------------------------------------------------===// 381 // PointerType 382 //===----------------------------------------------------------------------===// 383 384 struct spirv::detail::PointerTypeStorage : public TypeStorage { 385 // (Type, StorageClass) as the key: Type stored in this struct, and 386 // StorageClass stored as TypeStorage's subclass data. 387 using KeyTy = std::pair<Type, StorageClass>; 388 389 static PointerTypeStorage *construct(TypeStorageAllocator &allocator, 390 const KeyTy &key) { 391 return new (allocator.allocate<PointerTypeStorage>()) 392 PointerTypeStorage(key); 393 } 394 395 bool operator==(const KeyTy &key) const { 396 return key == KeyTy(pointeeType, storageClass); 397 } 398 399 PointerTypeStorage(const KeyTy &key) 400 : pointeeType(key.first), storageClass(key.second) {} 401 402 Type pointeeType; 403 StorageClass storageClass; 404 }; 405 406 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) { 407 return Base::get(pointeeType.getContext(), pointeeType, storageClass); 408 } 409 410 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; } 411 412 StorageClass PointerType::getStorageClass() const { 413 return getImpl()->storageClass; 414 } 415 416 void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 417 std::optional<StorageClass> storage) { 418 // Use this pointer type's storage class because this pointer indicates we are 419 // using the pointee type in that specific storage class. 420 llvm::cast<SPIRVType>(getPointeeType()) 421 .getExtensions(extensions, getStorageClass()); 422 423 if (auto scExts = spirv::getExtensions(getStorageClass())) 424 extensions.push_back(*scExts); 425 } 426 427 void PointerType::getCapabilities( 428 SPIRVType::CapabilityArrayRefVector &capabilities, 429 std::optional<StorageClass> storage) { 430 // Use this pointer type's storage class because this pointer indicates we are 431 // using the pointee type in that specific storage class. 432 llvm::cast<SPIRVType>(getPointeeType()) 433 .getCapabilities(capabilities, getStorageClass()); 434 435 if (auto scCaps = spirv::getCapabilities(getStorageClass())) 436 capabilities.push_back(*scCaps); 437 } 438 439 //===----------------------------------------------------------------------===// 440 // RuntimeArrayType 441 //===----------------------------------------------------------------------===// 442 443 struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage { 444 using KeyTy = std::pair<Type, unsigned>; 445 446 static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator, 447 const KeyTy &key) { 448 return new (allocator.allocate<RuntimeArrayTypeStorage>()) 449 RuntimeArrayTypeStorage(key); 450 } 451 452 bool operator==(const KeyTy &key) const { 453 return key == KeyTy(elementType, stride); 454 } 455 456 RuntimeArrayTypeStorage(const KeyTy &key) 457 : elementType(key.first), stride(key.second) {} 458 459 Type elementType; 460 unsigned stride; 461 }; 462 463 RuntimeArrayType RuntimeArrayType::get(Type elementType) { 464 return Base::get(elementType.getContext(), elementType, /*stride=*/0); 465 } 466 467 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) { 468 return Base::get(elementType.getContext(), elementType, stride); 469 } 470 471 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; } 472 473 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; } 474 475 void RuntimeArrayType::getExtensions( 476 SPIRVType::ExtensionArrayRefVector &extensions, 477 std::optional<StorageClass> storage) { 478 llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage); 479 } 480 481 void RuntimeArrayType::getCapabilities( 482 SPIRVType::CapabilityArrayRefVector &capabilities, 483 std::optional<StorageClass> storage) { 484 { 485 static const Capability caps[] = {Capability::Shader}; 486 ArrayRef<Capability> ref(caps, std::size(caps)); 487 capabilities.push_back(ref); 488 } 489 llvm::cast<SPIRVType>(getElementType()) 490 .getCapabilities(capabilities, storage); 491 } 492 493 //===----------------------------------------------------------------------===// 494 // ScalarType 495 //===----------------------------------------------------------------------===// 496 497 bool ScalarType::classof(Type type) { 498 if (auto floatType = llvm::dyn_cast<FloatType>(type)) { 499 return isValid(floatType); 500 } 501 if (auto intType = llvm::dyn_cast<IntegerType>(type)) { 502 return isValid(intType); 503 } 504 return false; 505 } 506 507 bool ScalarType::isValid(FloatType type) { 508 return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16(); 509 } 510 511 bool ScalarType::isValid(IntegerType type) { 512 return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth()); 513 } 514 515 void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 516 std::optional<StorageClass> storage) { 517 // 8- or 16-bit integer/floating-point numbers will require extra extensions 518 // to appear in interface storage classes. See SPV_KHR_16bit_storage and 519 // SPV_KHR_8bit_storage for more details. 520 if (!storage) 521 return; 522 523 switch (*storage) { 524 case StorageClass::PushConstant: 525 case StorageClass::StorageBuffer: 526 case StorageClass::Uniform: 527 if (getIntOrFloatBitWidth() == 8) { 528 static const Extension exts[] = {Extension::SPV_KHR_8bit_storage}; 529 ArrayRef<Extension> ref(exts, std::size(exts)); 530 extensions.push_back(ref); 531 } 532 [[fallthrough]]; 533 case StorageClass::Input: 534 case StorageClass::Output: 535 if (getIntOrFloatBitWidth() == 16) { 536 static const Extension exts[] = {Extension::SPV_KHR_16bit_storage}; 537 ArrayRef<Extension> ref(exts, std::size(exts)); 538 extensions.push_back(ref); 539 } 540 break; 541 default: 542 break; 543 } 544 } 545 546 void ScalarType::getCapabilities( 547 SPIRVType::CapabilityArrayRefVector &capabilities, 548 std::optional<StorageClass> storage) { 549 unsigned bitwidth = getIntOrFloatBitWidth(); 550 551 // 8- or 16-bit integer/floating-point numbers will require extra capabilities 552 // to appear in interface storage classes. See SPV_KHR_16bit_storage and 553 // SPV_KHR_8bit_storage for more details. 554 555 #define STORAGE_CASE(storage, cap8, cap16) \ 556 case StorageClass::storage: { \ 557 if (bitwidth == 8) { \ 558 static const Capability caps[] = {Capability::cap8}; \ 559 ArrayRef<Capability> ref(caps, std::size(caps)); \ 560 capabilities.push_back(ref); \ 561 return; \ 562 } \ 563 if (bitwidth == 16) { \ 564 static const Capability caps[] = {Capability::cap16}; \ 565 ArrayRef<Capability> ref(caps, std::size(caps)); \ 566 capabilities.push_back(ref); \ 567 return; \ 568 } \ 569 /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \ 570 /* storage classes. Fall through to the next section. */ \ 571 } break 572 573 // This part only handles the cases where special bitwidths appearing in 574 // interface storage classes. 575 if (storage) { 576 switch (*storage) { 577 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16); 578 STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess, 579 StorageBuffer16BitAccess); 580 STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess, 581 StorageUniform16); 582 case StorageClass::Input: 583 case StorageClass::Output: { 584 if (bitwidth == 16) { 585 static const Capability caps[] = {Capability::StorageInputOutput16}; 586 ArrayRef<Capability> ref(caps, std::size(caps)); 587 capabilities.push_back(ref); 588 return; 589 } 590 break; 591 } 592 default: 593 break; 594 } 595 } 596 #undef STORAGE_CASE 597 598 // For other non-interface storage classes, require a different set of 599 // capabilities for special bitwidths. 600 601 #define WIDTH_CASE(type, width) \ 602 case width: { \ 603 static const Capability caps[] = {Capability::type##width}; \ 604 ArrayRef<Capability> ref(caps, std::size(caps)); \ 605 capabilities.push_back(ref); \ 606 } break 607 608 if (auto intType = llvm::dyn_cast<IntegerType>(*this)) { 609 switch (bitwidth) { 610 WIDTH_CASE(Int, 8); 611 WIDTH_CASE(Int, 16); 612 WIDTH_CASE(Int, 64); 613 case 1: 614 case 32: 615 break; 616 default: 617 llvm_unreachable("invalid bitwidth to getCapabilities"); 618 } 619 } else { 620 assert(llvm::isa<FloatType>(*this)); 621 switch (bitwidth) { 622 WIDTH_CASE(Float, 16); 623 WIDTH_CASE(Float, 64); 624 case 32: 625 break; 626 default: 627 llvm_unreachable("invalid bitwidth to getCapabilities"); 628 } 629 } 630 631 #undef WIDTH_CASE 632 } 633 634 std::optional<int64_t> ScalarType::getSizeInBytes() { 635 auto bitWidth = getIntOrFloatBitWidth(); 636 // According to the SPIR-V spec: 637 // "There is no physical size or bit pattern defined for values with boolean 638 // type. If they are stored (in conjunction with OpVariable), they can only 639 // be used with logical addressing operations, not physical, and only with 640 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup, 641 // Private, Function, Input, and Output." 642 if (bitWidth == 1) 643 return std::nullopt; 644 return bitWidth / 8; 645 } 646 647 //===----------------------------------------------------------------------===// 648 // SPIRVType 649 //===----------------------------------------------------------------------===// 650 651 bool SPIRVType::classof(Type type) { 652 // Allow SPIR-V dialect types 653 if (llvm::isa<SPIRVDialect>(type.getDialect())) 654 return true; 655 if (llvm::isa<ScalarType>(type)) 656 return true; 657 if (auto vectorType = llvm::dyn_cast<VectorType>(type)) 658 return CompositeType::isValid(vectorType); 659 return false; 660 } 661 662 bool SPIRVType::isScalarOrVector() { 663 return isIntOrFloat() || llvm::isa<VectorType>(*this); 664 } 665 666 void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 667 std::optional<StorageClass> storage) { 668 if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) { 669 scalarType.getExtensions(extensions, storage); 670 } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) { 671 compositeType.getExtensions(extensions, storage); 672 } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) { 673 imageType.getExtensions(extensions, storage); 674 } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) { 675 sampledImageType.getExtensions(extensions, storage); 676 } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) { 677 matrixType.getExtensions(extensions, storage); 678 } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) { 679 ptrType.getExtensions(extensions, storage); 680 } else { 681 llvm_unreachable("invalid SPIR-V Type to getExtensions"); 682 } 683 } 684 685 void SPIRVType::getCapabilities( 686 SPIRVType::CapabilityArrayRefVector &capabilities, 687 std::optional<StorageClass> storage) { 688 if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) { 689 scalarType.getCapabilities(capabilities, storage); 690 } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) { 691 compositeType.getCapabilities(capabilities, storage); 692 } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) { 693 imageType.getCapabilities(capabilities, storage); 694 } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) { 695 sampledImageType.getCapabilities(capabilities, storage); 696 } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) { 697 matrixType.getCapabilities(capabilities, storage); 698 } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) { 699 ptrType.getCapabilities(capabilities, storage); 700 } else { 701 llvm_unreachable("invalid SPIR-V Type to getCapabilities"); 702 } 703 } 704 705 std::optional<int64_t> SPIRVType::getSizeInBytes() { 706 if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) 707 return scalarType.getSizeInBytes(); 708 if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) 709 return compositeType.getSizeInBytes(); 710 return std::nullopt; 711 } 712 713 //===----------------------------------------------------------------------===// 714 // SampledImageType 715 //===----------------------------------------------------------------------===// 716 struct spirv::detail::SampledImageTypeStorage : public TypeStorage { 717 using KeyTy = Type; 718 719 SampledImageTypeStorage(const KeyTy &key) : imageType{key} {} 720 721 bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); } 722 723 static SampledImageTypeStorage *construct(TypeStorageAllocator &allocator, 724 const KeyTy &key) { 725 return new (allocator.allocate<SampledImageTypeStorage>()) 726 SampledImageTypeStorage(key); 727 } 728 729 Type imageType; 730 }; 731 732 SampledImageType SampledImageType::get(Type imageType) { 733 return Base::get(imageType.getContext(), imageType); 734 } 735 736 SampledImageType 737 SampledImageType::getChecked(function_ref<InFlightDiagnostic()> emitError, 738 Type imageType) { 739 return Base::getChecked(emitError, imageType.getContext(), imageType); 740 } 741 742 Type SampledImageType::getImageType() const { return getImpl()->imageType; } 743 744 LogicalResult 745 SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError, 746 Type imageType) { 747 if (!llvm::isa<ImageType>(imageType)) 748 return emitError() << "expected image type"; 749 750 return success(); 751 } 752 753 void SampledImageType::getExtensions( 754 SPIRVType::ExtensionArrayRefVector &extensions, 755 std::optional<StorageClass> storage) { 756 llvm::cast<ImageType>(getImageType()).getExtensions(extensions, storage); 757 } 758 759 void SampledImageType::getCapabilities( 760 SPIRVType::CapabilityArrayRefVector &capabilities, 761 std::optional<StorageClass> storage) { 762 llvm::cast<ImageType>(getImageType()).getCapabilities(capabilities, storage); 763 } 764 765 //===----------------------------------------------------------------------===// 766 // StructType 767 //===----------------------------------------------------------------------===// 768 769 /// Type storage for SPIR-V structure types: 770 /// 771 /// Structures are uniqued using: 772 /// - for identified structs: 773 /// - a string identifier; 774 /// - for literal structs: 775 /// - a list of member types; 776 /// - a list of member offset info; 777 /// - a list of member decoration info. 778 /// 779 /// Identified structures only have a mutable component consisting of: 780 /// - a list of member types; 781 /// - a list of member offset info; 782 /// - a list of member decoration info. 783 struct spirv::detail::StructTypeStorage : public TypeStorage { 784 /// Construct a storage object for an identified struct type. A struct type 785 /// associated with such storage must call StructType::trySetBody(...) later 786 /// in order to mutate the storage object providing the actual content. 787 StructTypeStorage(StringRef identifier) 788 : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr), 789 numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr), 790 identifier(identifier) {} 791 792 /// Construct a storage object for a literal struct type. A struct type 793 /// associated with such storage is immutable. 794 StructTypeStorage( 795 unsigned numMembers, Type const *memberTypes, 796 StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, 797 StructType::MemberDecorationInfo const *memberDecorationsInfo) 798 : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo), 799 numMembers(numMembers), numMemberDecorations(numMemberDecorations), 800 memberDecorationsInfo(memberDecorationsInfo) {} 801 802 /// A storage key is divided into 2 parts: 803 /// - for identified structs: 804 /// - a StringRef representing the struct identifier; 805 /// - for literal structs: 806 /// - an ArrayRef<Type> for member types; 807 /// - an ArrayRef<StructType::OffsetInfo> for member offset info; 808 /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration 809 /// info. 810 /// 811 /// An identified struct type is uniqued only by the first part (field 0) 812 /// of the key. 813 /// 814 /// A literal struct type is uniqued only by the second part (fields 1, 2, and 815 /// 3) of the key. The identifier field (field 0) must be empty. 816 using KeyTy = 817 std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>, 818 ArrayRef<StructType::MemberDecorationInfo>>; 819 820 /// For identified structs, return true if the given key contains the same 821 /// identifier. 822 /// 823 /// For literal structs, return true if the given key contains a matching list 824 /// of member types + offset info + decoration info. 825 bool operator==(const KeyTy &key) const { 826 if (isIdentified()) { 827 // Identified types are uniqued by their identifier. 828 return getIdentifier() == std::get<0>(key); 829 } 830 831 return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(), 832 getMemberDecorationsInfo()); 833 } 834 835 /// If the given key contains a non-empty identifier, this method constructs 836 /// an identified struct and leaves the rest of the struct type data to be set 837 /// through a later call to StructType::trySetBody(...). 838 /// 839 /// If, on the other hand, the key contains an empty identifier, a literal 840 /// struct is constructed using the other fields of the key. 841 static StructTypeStorage *construct(TypeStorageAllocator &allocator, 842 const KeyTy &key) { 843 StringRef keyIdentifier = std::get<0>(key); 844 845 if (!keyIdentifier.empty()) { 846 StringRef identifier = allocator.copyInto(keyIdentifier); 847 848 // Identified StructType body/members will be set through trySetBody(...) 849 // later. 850 return new (allocator.allocate<StructTypeStorage>()) 851 StructTypeStorage(identifier); 852 } 853 854 ArrayRef<Type> keyTypes = std::get<1>(key); 855 856 // Copy the member type and layout information into the bump pointer 857 const Type *typesList = nullptr; 858 if (!keyTypes.empty()) { 859 typesList = allocator.copyInto(keyTypes).data(); 860 } 861 862 const StructType::OffsetInfo *offsetInfoList = nullptr; 863 if (!std::get<2>(key).empty()) { 864 ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key); 865 assert(keyOffsetInfo.size() == keyTypes.size() && 866 "size of offset information must be same as the size of number of " 867 "elements"); 868 offsetInfoList = allocator.copyInto(keyOffsetInfo).data(); 869 } 870 871 const StructType::MemberDecorationInfo *memberDecorationList = nullptr; 872 unsigned numMemberDecorations = 0; 873 if (!std::get<3>(key).empty()) { 874 auto keyMemberDecorations = std::get<3>(key); 875 numMemberDecorations = keyMemberDecorations.size(); 876 memberDecorationList = allocator.copyInto(keyMemberDecorations).data(); 877 } 878 879 return new (allocator.allocate<StructTypeStorage>()) 880 StructTypeStorage(keyTypes.size(), typesList, offsetInfoList, 881 numMemberDecorations, memberDecorationList); 882 } 883 884 ArrayRef<Type> getMemberTypes() const { 885 return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers); 886 } 887 888 ArrayRef<StructType::OffsetInfo> getOffsetInfo() const { 889 if (offsetInfo) { 890 return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers); 891 } 892 return {}; 893 } 894 895 ArrayRef<StructType::MemberDecorationInfo> getMemberDecorationsInfo() const { 896 if (memberDecorationsInfo) { 897 return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo, 898 numMemberDecorations); 899 } 900 return {}; 901 } 902 903 StringRef getIdentifier() const { return identifier; } 904 905 bool isIdentified() const { return !identifier.empty(); } 906 907 /// Sets the struct type content for identified structs. Calling this method 908 /// is only valid for identified structs. 909 /// 910 /// Fails under the following conditions: 911 /// - If called for a literal struct; 912 /// - If called for an identified struct whose body was set before (through a 913 /// call to this method) but with different contents from the passed 914 /// arguments. 915 LogicalResult mutate( 916 TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes, 917 ArrayRef<StructType::OffsetInfo> structOffsetInfo, 918 ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) { 919 if (!isIdentified()) 920 return failure(); 921 922 if (memberTypesAndIsBodySet.getInt() && 923 (getMemberTypes() != structMemberTypes || 924 getOffsetInfo() != structOffsetInfo || 925 getMemberDecorationsInfo() != structMemberDecorationInfo)) 926 return failure(); 927 928 memberTypesAndIsBodySet.setInt(true); 929 numMembers = structMemberTypes.size(); 930 931 // Copy the member type and layout information into the bump pointer. 932 if (!structMemberTypes.empty()) 933 memberTypesAndIsBodySet.setPointer( 934 allocator.copyInto(structMemberTypes).data()); 935 936 if (!structOffsetInfo.empty()) { 937 assert(structOffsetInfo.size() == structMemberTypes.size() && 938 "size of offset information must be same as the size of number of " 939 "elements"); 940 offsetInfo = allocator.copyInto(structOffsetInfo).data(); 941 } 942 943 if (!structMemberDecorationInfo.empty()) { 944 numMemberDecorations = structMemberDecorationInfo.size(); 945 memberDecorationsInfo = 946 allocator.copyInto(structMemberDecorationInfo).data(); 947 } 948 949 return success(); 950 } 951 952 llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet; 953 StructType::OffsetInfo const *offsetInfo; 954 unsigned numMembers; 955 unsigned numMemberDecorations; 956 StructType::MemberDecorationInfo const *memberDecorationsInfo; 957 StringRef identifier; 958 }; 959 960 StructType 961 StructType::get(ArrayRef<Type> memberTypes, 962 ArrayRef<StructType::OffsetInfo> offsetInfo, 963 ArrayRef<StructType::MemberDecorationInfo> memberDecorations) { 964 assert(!memberTypes.empty() && "Struct needs at least one member type"); 965 // Sort the decorations. 966 SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations( 967 memberDecorations); 968 llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end()); 969 return Base::get(memberTypes.vec().front().getContext(), 970 /*identifier=*/StringRef(), memberTypes, offsetInfo, 971 sortedDecorations); 972 } 973 974 StructType StructType::getIdentified(MLIRContext *context, 975 StringRef identifier) { 976 assert(!identifier.empty() && 977 "StructType identifier must be non-empty string"); 978 979 return Base::get(context, identifier, ArrayRef<Type>(), 980 ArrayRef<StructType::OffsetInfo>(), 981 ArrayRef<StructType::MemberDecorationInfo>()); 982 } 983 984 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) { 985 StructType newStructType = Base::get( 986 context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(), 987 ArrayRef<StructType::MemberDecorationInfo>()); 988 // Set an empty body in case this is a identified struct. 989 if (newStructType.isIdentified() && 990 failed(newStructType.trySetBody( 991 ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(), 992 ArrayRef<StructType::MemberDecorationInfo>()))) 993 return StructType(); 994 995 return newStructType; 996 } 997 998 StringRef StructType::getIdentifier() const { return getImpl()->identifier; } 999 1000 bool StructType::isIdentified() const { return getImpl()->isIdentified(); } 1001 1002 unsigned StructType::getNumElements() const { return getImpl()->numMembers; } 1003 1004 Type StructType::getElementType(unsigned index) const { 1005 assert(getNumElements() > index && "member index out of range"); 1006 return getImpl()->memberTypesAndIsBodySet.getPointer()[index]; 1007 } 1008 1009 TypeRange StructType::getElementTypes() const { 1010 return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(), 1011 getNumElements()); 1012 } 1013 1014 bool StructType::hasOffset() const { return getImpl()->offsetInfo; } 1015 1016 uint64_t StructType::getMemberOffset(unsigned index) const { 1017 assert(getNumElements() > index && "member index out of range"); 1018 return getImpl()->offsetInfo[index]; 1019 } 1020 1021 void StructType::getMemberDecorations( 1022 SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorations) 1023 const { 1024 memberDecorations.clear(); 1025 auto implMemberDecorations = getImpl()->getMemberDecorationsInfo(); 1026 memberDecorations.append(implMemberDecorations.begin(), 1027 implMemberDecorations.end()); 1028 } 1029 1030 void StructType::getMemberDecorations( 1031 unsigned index, 1032 SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const { 1033 assert(getNumElements() > index && "member index out of range"); 1034 auto memberDecorations = getImpl()->getMemberDecorationsInfo(); 1035 decorationsInfo.clear(); 1036 for (const auto &memberDecoration : memberDecorations) { 1037 if (memberDecoration.memberIndex == index) { 1038 decorationsInfo.push_back(memberDecoration); 1039 } 1040 if (memberDecoration.memberIndex > index) { 1041 // Early exit since the decorations are stored sorted. 1042 return; 1043 } 1044 } 1045 } 1046 1047 LogicalResult 1048 StructType::trySetBody(ArrayRef<Type> memberTypes, 1049 ArrayRef<OffsetInfo> offsetInfo, 1050 ArrayRef<MemberDecorationInfo> memberDecorations) { 1051 return Base::mutate(memberTypes, offsetInfo, memberDecorations); 1052 } 1053 1054 void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 1055 std::optional<StorageClass> storage) { 1056 for (Type elementType : getElementTypes()) 1057 llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage); 1058 } 1059 1060 void StructType::getCapabilities( 1061 SPIRVType::CapabilityArrayRefVector &capabilities, 1062 std::optional<StorageClass> storage) { 1063 for (Type elementType : getElementTypes()) 1064 llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage); 1065 } 1066 1067 llvm::hash_code spirv::hash_value( 1068 const StructType::MemberDecorationInfo &memberDecorationInfo) { 1069 return llvm::hash_combine(memberDecorationInfo.memberIndex, 1070 memberDecorationInfo.decoration); 1071 } 1072 1073 //===----------------------------------------------------------------------===// 1074 // MatrixType 1075 //===----------------------------------------------------------------------===// 1076 1077 struct spirv::detail::MatrixTypeStorage : public TypeStorage { 1078 MatrixTypeStorage(Type columnType, uint32_t columnCount) 1079 : columnType(columnType), columnCount(columnCount) {} 1080 1081 using KeyTy = std::tuple<Type, uint32_t>; 1082 1083 static MatrixTypeStorage *construct(TypeStorageAllocator &allocator, 1084 const KeyTy &key) { 1085 1086 // Initialize the memory using placement new. 1087 return new (allocator.allocate<MatrixTypeStorage>()) 1088 MatrixTypeStorage(std::get<0>(key), std::get<1>(key)); 1089 } 1090 1091 bool operator==(const KeyTy &key) const { 1092 return key == KeyTy(columnType, columnCount); 1093 } 1094 1095 Type columnType; 1096 const uint32_t columnCount; 1097 }; 1098 1099 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) { 1100 return Base::get(columnType.getContext(), columnType, columnCount); 1101 } 1102 1103 MatrixType MatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError, 1104 Type columnType, uint32_t columnCount) { 1105 return Base::getChecked(emitError, columnType.getContext(), columnType, 1106 columnCount); 1107 } 1108 1109 LogicalResult 1110 MatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError, 1111 Type columnType, uint32_t columnCount) { 1112 if (columnCount < 2 || columnCount > 4) 1113 return emitError() << "matrix can have 2, 3, or 4 columns only"; 1114 1115 if (!isValidColumnType(columnType)) 1116 return emitError() << "matrix columns must be vectors of floats"; 1117 1118 /// The underlying vectors (columns) must be of size 2, 3, or 4 1119 ArrayRef<int64_t> columnShape = llvm::cast<VectorType>(columnType).getShape(); 1120 if (columnShape.size() != 1) 1121 return emitError() << "matrix columns must be 1D vectors"; 1122 1123 if (columnShape[0] < 2 || columnShape[0] > 4) 1124 return emitError() << "matrix columns must be of size 2, 3, or 4"; 1125 1126 return success(); 1127 } 1128 1129 /// Returns true if the matrix elements are vectors of float elements 1130 bool MatrixType::isValidColumnType(Type columnType) { 1131 if (auto vectorType = llvm::dyn_cast<VectorType>(columnType)) { 1132 if (llvm::isa<FloatType>(vectorType.getElementType())) 1133 return true; 1134 } 1135 return false; 1136 } 1137 1138 Type MatrixType::getColumnType() const { return getImpl()->columnType; } 1139 1140 Type MatrixType::getElementType() const { 1141 return llvm::cast<VectorType>(getImpl()->columnType).getElementType(); 1142 } 1143 1144 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; } 1145 1146 unsigned MatrixType::getNumRows() const { 1147 return llvm::cast<VectorType>(getImpl()->columnType).getShape()[0]; 1148 } 1149 1150 unsigned MatrixType::getNumElements() const { 1151 return (getImpl()->columnCount) * getNumRows(); 1152 } 1153 1154 void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 1155 std::optional<StorageClass> storage) { 1156 llvm::cast<SPIRVType>(getColumnType()).getExtensions(extensions, storage); 1157 } 1158 1159 void MatrixType::getCapabilities( 1160 SPIRVType::CapabilityArrayRefVector &capabilities, 1161 std::optional<StorageClass> storage) { 1162 { 1163 static const Capability caps[] = {Capability::Matrix}; 1164 ArrayRef<Capability> ref(caps, std::size(caps)); 1165 capabilities.push_back(ref); 1166 } 1167 // Add any capabilities associated with the underlying vectors (i.e., columns) 1168 llvm::cast<SPIRVType>(getColumnType()).getCapabilities(capabilities, storage); 1169 } 1170 1171 //===----------------------------------------------------------------------===// 1172 // SPIR-V Dialect 1173 //===----------------------------------------------------------------------===// 1174 1175 void SPIRVDialect::registerTypes() { 1176 addTypes<ArrayType, CooperativeMatrixType, ImageType, MatrixType, PointerType, 1177 RuntimeArrayType, SampledImageType, StructType>(); 1178 } 1179