1 //===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file declares the types in the SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_ 14 #define MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_ 15 16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "mlir/IR/Diagnostics.h" 19 #include "mlir/IR/Location.h" 20 #include "mlir/IR/TypeSupport.h" 21 #include "mlir/IR/Types.h" 22 23 #include <cstdint> 24 #include <tuple> 25 26 namespace mlir { 27 namespace spirv { 28 29 namespace detail { 30 struct ArrayTypeStorage; 31 struct CooperativeMatrixTypeStorage; 32 struct ImageTypeStorage; 33 struct MatrixTypeStorage; 34 struct PointerTypeStorage; 35 struct RuntimeArrayTypeStorage; 36 struct SampledImageTypeStorage; 37 struct StructTypeStorage; 38 39 } // namespace detail 40 41 // Base SPIR-V type for providing availability queries. 42 class SPIRVType : public Type { 43 public: 44 using Type::Type; 45 46 static bool classof(Type type); 47 48 bool isScalarOrVector(); 49 50 /// The extension requirements for each type are following the 51 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) 52 /// convention. 53 using ExtensionArrayRefVector = SmallVectorImpl<ArrayRef<Extension>>; 54 55 /// Appends to `extensions` the extensions needed for this type to appear in 56 /// the given `storage` class. This method does not guarantee the uniqueness 57 /// of extensions; the same extension may be appended multiple times. 58 void getExtensions(ExtensionArrayRefVector &extensions, 59 std::optional<StorageClass> storage = std::nullopt); 60 61 /// The capability requirements for each type are following the 62 /// ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D)) 63 /// convention. 64 using CapabilityArrayRefVector = SmallVectorImpl<ArrayRef<Capability>>; 65 66 /// Appends to `capabilities` the capabilities needed for this type to appear 67 /// in the given `storage` class. This method does not guarantee the 68 /// uniqueness of capabilities; the same capability may be appended multiple 69 /// times. 70 void getCapabilities(CapabilityArrayRefVector &capabilities, 71 std::optional<StorageClass> storage = std::nullopt); 72 73 /// Returns the size in bytes for each type. If no size can be calculated, 74 /// returns `std::nullopt`. Note that if the type has explicit layout, it is 75 /// also taken into account in calculation. 76 std::optional<int64_t> getSizeInBytes(); 77 }; 78 79 // SPIR-V scalar type: bool type, integer type, floating point type. 80 class ScalarType : public SPIRVType { 81 public: 82 using SPIRVType::SPIRVType; 83 84 static bool classof(Type type); 85 86 /// Returns true if the given integer type is valid for the SPIR-V dialect. 87 static bool isValid(FloatType); 88 /// Returns true if the given float type is valid for the SPIR-V dialect. 89 static bool isValid(IntegerType); 90 91 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 92 std::optional<StorageClass> storage = std::nullopt); 93 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 94 std::optional<StorageClass> storage = std::nullopt); 95 96 std::optional<int64_t> getSizeInBytes(); 97 }; 98 99 // SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType. 100 class CompositeType : public SPIRVType { 101 public: 102 using SPIRVType::SPIRVType; 103 104 static bool classof(Type type); 105 106 /// Returns true if the given vector type is valid for the SPIR-V dialect. 107 static bool isValid(VectorType); 108 109 /// Return the number of elements of the type. This should only be called if 110 /// hasCompileTimeKnownNumElements is true. 111 unsigned getNumElements() const; 112 113 Type getElementType(unsigned) const; 114 115 /// Return true if the number of elements is known at compile time and is not 116 /// implementation dependent. 117 bool hasCompileTimeKnownNumElements() const; 118 119 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 120 std::optional<StorageClass> storage = std::nullopt); 121 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 122 std::optional<StorageClass> storage = std::nullopt); 123 124 std::optional<int64_t> getSizeInBytes(); 125 }; 126 127 // SPIR-V array type 128 class ArrayType : public Type::TypeBase<ArrayType, CompositeType, 129 detail::ArrayTypeStorage> { 130 public: 131 using Base::Base; 132 133 static constexpr StringLiteral name = "spirv.array"; 134 135 static ArrayType get(Type elementType, unsigned elementCount); 136 137 /// Returns an array type with the given stride in bytes. 138 static ArrayType get(Type elementType, unsigned elementCount, 139 unsigned stride); 140 141 unsigned getNumElements() const; 142 143 Type getElementType() const; 144 145 /// Returns the array stride in bytes. 0 means no stride decorated on this 146 /// type. 147 unsigned getArrayStride() const; 148 149 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 150 std::optional<StorageClass> storage = std::nullopt); 151 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 152 std::optional<StorageClass> storage = std::nullopt); 153 154 /// Returns the array size in bytes. Since array type may have an explicit 155 /// stride declaration (in bytes), we also include it in the calculation. 156 std::optional<int64_t> getSizeInBytes(); 157 }; 158 159 // SPIR-V image type 160 class ImageType 161 : public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> { 162 public: 163 using Base::Base; 164 165 static constexpr StringLiteral name = "spirv.image"; 166 167 static ImageType 168 get(Type elementType, Dim dim, 169 ImageDepthInfo depth = ImageDepthInfo::DepthUnknown, 170 ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed, 171 ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled, 172 ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown, 173 ImageFormat format = ImageFormat::Unknown) { 174 return ImageType::get( 175 std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo, 176 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>( 177 elementType, dim, depth, arrayed, samplingInfo, samplerUse, 178 format)); 179 } 180 181 static ImageType 182 get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo, 183 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>); 184 185 Type getElementType() const; 186 Dim getDim() const; 187 ImageDepthInfo getDepthInfo() const; 188 ImageArrayedInfo getArrayedInfo() const; 189 ImageSamplingInfo getSamplingInfo() const; 190 ImageSamplerUseInfo getSamplerUseInfo() const; 191 ImageFormat getImageFormat() const; 192 // TODO: Add support for Access qualifier 193 194 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 195 std::optional<StorageClass> storage = std::nullopt); 196 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 197 std::optional<StorageClass> storage = std::nullopt); 198 }; 199 200 // SPIR-V pointer type 201 class PointerType : public Type::TypeBase<PointerType, SPIRVType, 202 detail::PointerTypeStorage> { 203 public: 204 using Base::Base; 205 206 static constexpr StringLiteral name = "spirv.pointer"; 207 208 static PointerType get(Type pointeeType, StorageClass storageClass); 209 210 Type getPointeeType() const; 211 212 StorageClass getStorageClass() const; 213 214 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 215 std::optional<StorageClass> storage = std::nullopt); 216 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 217 std::optional<StorageClass> storage = std::nullopt); 218 }; 219 220 // SPIR-V run-time array type 221 class RuntimeArrayType 222 : public Type::TypeBase<RuntimeArrayType, SPIRVType, 223 detail::RuntimeArrayTypeStorage> { 224 public: 225 using Base::Base; 226 227 static constexpr StringLiteral name = "spirv.rtarray"; 228 229 static RuntimeArrayType get(Type elementType); 230 231 /// Returns a runtime array type with the given stride in bytes. 232 static RuntimeArrayType get(Type elementType, unsigned stride); 233 234 Type getElementType() const; 235 236 /// Returns the array stride in bytes. 0 means no stride decorated on this 237 /// type. 238 unsigned getArrayStride() const; 239 240 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 241 std::optional<StorageClass> storage = std::nullopt); 242 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 243 std::optional<StorageClass> storage = std::nullopt); 244 }; 245 246 // SPIR-V sampled image type 247 class SampledImageType 248 : public Type::TypeBase<SampledImageType, SPIRVType, 249 detail::SampledImageTypeStorage> { 250 public: 251 using Base::Base; 252 253 static constexpr StringLiteral name = "spirv.sampled_image"; 254 255 static SampledImageType get(Type imageType); 256 257 static SampledImageType 258 getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType); 259 260 static LogicalResult 261 verifyInvariants(function_ref<InFlightDiagnostic()> emitError, 262 Type imageType); 263 264 Type getImageType() const; 265 266 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 267 std::optional<spirv::StorageClass> storage = std::nullopt); 268 void 269 getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 270 std::optional<spirv::StorageClass> storage = std::nullopt); 271 }; 272 273 /// SPIR-V struct type. Two kinds of struct types are supported: 274 /// - Literal: a literal struct type is uniqued by its fields (types + offset 275 /// info + decoration info). 276 /// - Identified: an indentified struct type is uniqued by its string identifier 277 /// (name). This is useful in representing recursive structs. For example, the 278 /// following C struct: 279 /// 280 /// struct A { 281 /// A* next; 282 /// }; 283 /// 284 /// would be represented in MLIR as: 285 /// 286 /// !spirv.struct<A, (!spirv.ptr<!spirv.struct<A>, Generic>)> 287 /// 288 /// In the above, expressing recursive struct types is accomplished by giving a 289 /// recursive struct a unique identified and using that identifier in the struct 290 /// definition for recursive references. 291 class StructType 292 : public Type::TypeBase<StructType, CompositeType, 293 detail::StructTypeStorage, TypeTrait::IsMutable> { 294 public: 295 using Base::Base; 296 297 // Type for specifying the offset of the struct members 298 using OffsetInfo = uint32_t; 299 300 static constexpr StringLiteral name = "spirv.struct"; 301 302 // Type for specifying the decoration(s) on struct members 303 struct MemberDecorationInfo { 304 uint32_t memberIndex : 31; 305 uint32_t hasValue : 1; 306 Decoration decoration; 307 uint32_t decorationValue; 308 309 MemberDecorationInfo(uint32_t index, uint32_t hasValue, 310 Decoration decoration, uint32_t decorationValue) 311 : memberIndex(index), hasValue(hasValue), decoration(decoration), 312 decorationValue(decorationValue) {} 313 314 bool operator==(const MemberDecorationInfo &other) const { 315 return (this->memberIndex == other.memberIndex) && 316 (this->decoration == other.decoration) && 317 (this->decorationValue == other.decorationValue); 318 } 319 320 bool operator<(const MemberDecorationInfo &other) const { 321 return this->memberIndex < other.memberIndex || 322 (this->memberIndex == other.memberIndex && 323 static_cast<uint32_t>(this->decoration) < 324 static_cast<uint32_t>(other.decoration)); 325 } 326 }; 327 328 /// Construct a literal StructType with at least one member. 329 static StructType get(ArrayRef<Type> memberTypes, 330 ArrayRef<OffsetInfo> offsetInfo = {}, 331 ArrayRef<MemberDecorationInfo> memberDecorations = {}); 332 333 /// Construct an identified StructType. This creates a StructType whose body 334 /// (member types, offset info, and decorations) is not set yet. A call to 335 /// StructType::trySetBody(...) must follow when the StructType contents are 336 /// available (e.g. parsed or deserialized). 337 /// 338 /// Note: If another thread creates (or had already created) a struct with the 339 /// same identifier, that struct will be returned as a result. 340 static StructType getIdentified(MLIRContext *context, StringRef identifier); 341 342 /// Construct a (possibly identified) StructType with no members. 343 /// 344 /// Note: this method might fail in a multi-threaded setup if another thread 345 /// created an identified struct with the same identifier but with different 346 /// contents before returning. In which case, an empty (default-constructed) 347 /// StructType is returned. 348 static StructType getEmpty(MLIRContext *context, StringRef identifier = ""); 349 350 /// For literal structs, return an empty string. 351 /// For identified structs, return the struct's identifier. 352 StringRef getIdentifier() const; 353 354 /// Returns true if the StructType is identified. 355 bool isIdentified() const; 356 357 unsigned getNumElements() const; 358 359 Type getElementType(unsigned) const; 360 361 TypeRange getElementTypes() const; 362 363 bool hasOffset() const; 364 365 uint64_t getMemberOffset(unsigned) const; 366 367 // Returns in `memberDecorations` the Decorations (apart from Offset) 368 // associated with all members of the StructType. 369 void getMemberDecorations(SmallVectorImpl<StructType::MemberDecorationInfo> 370 &memberDecorations) const; 371 372 // Returns in `decorationsInfo` all the Decorations (apart from Offset) 373 // associated with the `i`-th member of the StructType. 374 void getMemberDecorations( 375 unsigned i, 376 SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const; 377 378 /// Sets the contents of an incomplete identified StructType. This method must 379 /// be called only for identified StructTypes and it must be called only once 380 /// per instance. Otherwise, failure() is returned. 381 LogicalResult 382 trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {}, 383 ArrayRef<MemberDecorationInfo> memberDecorations = {}); 384 385 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 386 std::optional<StorageClass> storage = std::nullopt); 387 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 388 std::optional<StorageClass> storage = std::nullopt); 389 }; 390 391 llvm::hash_code 392 hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo); 393 394 // SPIR-V KHR cooperative matrix type 395 class CooperativeMatrixType 396 : public Type::TypeBase<CooperativeMatrixType, CompositeType, 397 detail::CooperativeMatrixTypeStorage> { 398 public: 399 using Base::Base; 400 401 static constexpr StringLiteral name = "spirv.coopmatrix"; 402 403 static CooperativeMatrixType get(Type elementType, uint32_t rows, 404 uint32_t columns, Scope scope, 405 CooperativeMatrixUseKHR use); 406 Type getElementType() const; 407 408 /// Returns the scope of the matrix. 409 Scope getScope() const; 410 /// Returns the number of rows of the matrix. 411 uint32_t getRows() const; 412 /// Returns the number of columns of the matrix. 413 uint32_t getColumns() const; 414 /// Returns the use parameter of the cooperative matrix. 415 CooperativeMatrixUseKHR getUse() const; 416 417 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 418 std::optional<StorageClass> storage = std::nullopt); 419 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 420 std::optional<StorageClass> storage = std::nullopt); 421 }; 422 423 // SPIR-V matrix type 424 class MatrixType : public Type::TypeBase<MatrixType, CompositeType, 425 detail::MatrixTypeStorage> { 426 public: 427 using Base::Base; 428 429 static constexpr StringLiteral name = "spirv.matrix"; 430 431 static MatrixType get(Type columnType, uint32_t columnCount); 432 433 static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError, 434 Type columnType, uint32_t columnCount); 435 436 static LogicalResult 437 verifyInvariants(function_ref<InFlightDiagnostic()> emitError, 438 Type columnType, uint32_t columnCount); 439 440 /// Returns true if the matrix elements are vectors of float elements. 441 static bool isValidColumnType(Type columnType); 442 443 Type getColumnType() const; 444 445 /// Returns the number of rows. 446 unsigned getNumRows() const; 447 448 /// Returns the number of columns. 449 unsigned getNumColumns() const; 450 451 /// Returns total number of elements (rows*columns). 452 unsigned getNumElements() const; 453 454 /// Returns the elements' type (i.e, single element type). 455 Type getElementType() const; 456 457 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 458 std::optional<StorageClass> storage = std::nullopt); 459 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 460 std::optional<StorageClass> storage = std::nullopt); 461 }; 462 463 } // namespace spirv 464 } // namespace mlir 465 466 #endif // MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_ 467