1 //===- DXContainer.h - DXContainer file implementation ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file declares the DXContainerFile class, which implements the ObjectFile 10 // interface for DXContainer files. 11 // 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef LLVM_OBJECT_DXCONTAINER_H 16 #define LLVM_OBJECT_DXCONTAINER_H 17 18 #include "llvm/ADT/SmallVector.h" 19 #include "llvm/ADT/StringRef.h" 20 #include "llvm/BinaryFormat/DXContainer.h" 21 #include "llvm/Support/Error.h" 22 #include "llvm/Support/MemoryBufferRef.h" 23 #include "llvm/TargetParser/Triple.h" 24 #include <array> 25 #include <variant> 26 27 namespace llvm { 28 namespace object { 29 30 namespace detail { 31 template <typename T> 32 std::enable_if_t<std::is_arithmetic<T>::value, void> swapBytes(T &value) { 33 sys::swapByteOrder(value); 34 } 35 36 template <typename T> 37 std::enable_if_t<std::is_class<T>::value, void> swapBytes(T &value) { 38 value.swapBytes(); 39 } 40 } // namespace detail 41 42 // This class provides a view into the underlying resource array. The Resource 43 // data is little-endian encoded and may not be properly aligned to read 44 // directly from. The dereference operator creates a copy of the data and byte 45 // swaps it as appropriate. 46 template <typename T> struct ViewArray { 47 StringRef Data; 48 uint32_t Stride = sizeof(T); // size of each element in the list. 49 50 ViewArray() = default; 51 ViewArray(StringRef D, size_t S) : Data(D), Stride(S) {} 52 53 using value_type = T; 54 static constexpr uint32_t MaxStride() { 55 return static_cast<uint32_t>(sizeof(value_type)); 56 } 57 58 struct iterator { 59 StringRef Data; 60 uint32_t Stride; // size of each element in the list. 61 const char *Current; 62 63 iterator(const ViewArray &A, const char *C) 64 : Data(A.Data), Stride(A.Stride), Current(C) {} 65 iterator(const iterator &) = default; 66 67 value_type operator*() { 68 // Explicitly zero the structure so that unused fields are zeroed. It is 69 // up to the user to know if the fields are used by verifying the PSV 70 // version. 71 value_type Val; 72 std::memset(&Val, 0, sizeof(value_type)); 73 if (Current >= Data.end()) 74 return Val; 75 memcpy(static_cast<void *>(&Val), Current, std::min(Stride, MaxStride())); 76 if (sys::IsBigEndianHost) 77 detail::swapBytes(Val); 78 return Val; 79 } 80 81 iterator operator++() { 82 if (Current < Data.end()) 83 Current += Stride; 84 return *this; 85 } 86 87 iterator operator++(int) { 88 iterator Tmp = *this; 89 ++*this; 90 return Tmp; 91 } 92 93 iterator operator--() { 94 if (Current > Data.begin()) 95 Current -= Stride; 96 return *this; 97 } 98 99 iterator operator--(int) { 100 iterator Tmp = *this; 101 --*this; 102 return Tmp; 103 } 104 105 bool operator==(const iterator I) { return I.Current == Current; } 106 bool operator!=(const iterator I) { return !(*this == I); } 107 }; 108 109 iterator begin() const { return iterator(*this, Data.begin()); } 110 111 iterator end() const { return iterator(*this, Data.end()); } 112 113 size_t size() const { return Data.size() / Stride; } 114 115 bool isEmpty() const { return Data.empty(); } 116 }; 117 118 namespace DirectX { 119 class PSVRuntimeInfo { 120 121 using ResourceArray = ViewArray<dxbc::PSV::v2::ResourceBindInfo>; 122 using SigElementArray = ViewArray<dxbc::PSV::v0::SignatureElement>; 123 124 StringRef Data; 125 uint32_t Size; 126 using InfoStruct = 127 std::variant<std::monostate, dxbc::PSV::v0::RuntimeInfo, 128 dxbc::PSV::v1::RuntimeInfo, dxbc::PSV::v2::RuntimeInfo, 129 dxbc::PSV::v3::RuntimeInfo>; 130 InfoStruct BasicInfo; 131 ResourceArray Resources; 132 StringRef StringTable; 133 SmallVector<uint32_t> SemanticIndexTable; 134 SigElementArray SigInputElements; 135 SigElementArray SigOutputElements; 136 SigElementArray SigPatchOrPrimElements; 137 138 std::array<ViewArray<uint32_t>, 4> OutputVectorMasks; 139 ViewArray<uint32_t> PatchOrPrimMasks; 140 std::array<ViewArray<uint32_t>, 4> InputOutputMap; 141 ViewArray<uint32_t> InputPatchMap; 142 ViewArray<uint32_t> PatchOutputMap; 143 144 public: 145 PSVRuntimeInfo(StringRef D) : Data(D), Size(0) {} 146 147 // Parsing depends on the shader kind 148 Error parse(uint16_t ShaderKind); 149 150 uint32_t getSize() const { return Size; } 151 uint32_t getResourceCount() const { return Resources.size(); } 152 ResourceArray getResources() const { return Resources; } 153 154 uint32_t getVersion() const { 155 return Size >= sizeof(dxbc::PSV::v3::RuntimeInfo) 156 ? 3 157 : (Size >= sizeof(dxbc::PSV::v2::RuntimeInfo) ? 2 158 : (Size >= sizeof(dxbc::PSV::v1::RuntimeInfo)) ? 1 159 : 0); 160 } 161 162 uint32_t getResourceStride() const { return Resources.Stride; } 163 164 const InfoStruct &getInfo() const { return BasicInfo; } 165 166 template <typename T> const T *getInfoAs() const { 167 if (const auto *P = std::get_if<dxbc::PSV::v3::RuntimeInfo>(&BasicInfo)) 168 return static_cast<const T *>(P); 169 if (std::is_same<T, dxbc::PSV::v3::RuntimeInfo>::value) 170 return nullptr; 171 172 if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo)) 173 return static_cast<const T *>(P); 174 if (std::is_same<T, dxbc::PSV::v2::RuntimeInfo>::value) 175 return nullptr; 176 177 if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(&BasicInfo)) 178 return static_cast<const T *>(P); 179 if (std::is_same<T, dxbc::PSV::v1::RuntimeInfo>::value) 180 return nullptr; 181 182 if (const auto *P = std::get_if<dxbc::PSV::v0::RuntimeInfo>(&BasicInfo)) 183 return static_cast<const T *>(P); 184 return nullptr; 185 } 186 187 StringRef getStringTable() const { return StringTable; } 188 ArrayRef<uint32_t> getSemanticIndexTable() const { 189 return SemanticIndexTable; 190 } 191 192 uint8_t getSigInputCount() const; 193 uint8_t getSigOutputCount() const; 194 uint8_t getSigPatchOrPrimCount() const; 195 196 SigElementArray getSigInputElements() const { return SigInputElements; } 197 SigElementArray getSigOutputElements() const { return SigOutputElements; } 198 SigElementArray getSigPatchOrPrimElements() const { 199 return SigPatchOrPrimElements; 200 } 201 202 ViewArray<uint32_t> getOutputVectorMasks(size_t Idx) const { 203 assert(Idx < 4); 204 return OutputVectorMasks[Idx]; 205 } 206 207 ViewArray<uint32_t> getPatchOrPrimMasks() const { return PatchOrPrimMasks; } 208 209 ViewArray<uint32_t> getInputOutputMap(size_t Idx) const { 210 assert(Idx < 4); 211 return InputOutputMap[Idx]; 212 } 213 214 ViewArray<uint32_t> getInputPatchMap() const { return InputPatchMap; } 215 ViewArray<uint32_t> getPatchOutputMap() const { return PatchOutputMap; } 216 217 uint32_t getSigElementStride() const { return SigInputElements.Stride; } 218 219 bool usesViewID() const { 220 if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>()) 221 return P->UsesViewID != 0; 222 return false; 223 } 224 225 uint8_t getInputVectorCount() const { 226 if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>()) 227 return P->SigInputVectors; 228 return 0; 229 } 230 231 ArrayRef<uint8_t> getOutputVectorCounts() const { 232 if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>()) 233 return ArrayRef<uint8_t>(P->SigOutputVectors); 234 return ArrayRef<uint8_t>(); 235 } 236 237 uint8_t getPatchConstOrPrimVectorCount() const { 238 if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>()) 239 return P->GeomData.SigPatchConstOrPrimVectors; 240 return 0; 241 } 242 }; 243 244 class Signature { 245 ViewArray<dxbc::ProgramSignatureElement> Parameters; 246 uint32_t StringTableOffset; 247 StringRef StringTable; 248 249 public: 250 ViewArray<dxbc::ProgramSignatureElement>::iterator begin() const { 251 return Parameters.begin(); 252 } 253 254 ViewArray<dxbc::ProgramSignatureElement>::iterator end() const { 255 return Parameters.end(); 256 } 257 258 StringRef getName(uint32_t Offset) const { 259 assert(Offset >= StringTableOffset && 260 Offset < StringTableOffset + StringTable.size() && 261 "Offset out of range."); 262 // Name offsets are from the start of the signature data, not from the start 263 // of the string table. The header encodes the start offset of the sting 264 // table, so we convert the offset here. 265 uint32_t TableOffset = Offset - StringTableOffset; 266 return StringTable.slice(TableOffset, StringTable.find('\0', TableOffset)); 267 } 268 269 bool isEmpty() const { return Parameters.isEmpty(); } 270 271 Error initialize(StringRef Part); 272 }; 273 274 } // namespace DirectX 275 276 class DXContainer { 277 public: 278 using DXILData = std::pair<dxbc::ProgramHeader, const char *>; 279 280 private: 281 DXContainer(MemoryBufferRef O); 282 283 MemoryBufferRef Data; 284 dxbc::Header Header; 285 SmallVector<uint32_t, 4> PartOffsets; 286 std::optional<DXILData> DXIL; 287 std::optional<uint64_t> ShaderFeatureFlags; 288 std::optional<dxbc::ShaderHash> Hash; 289 std::optional<DirectX::PSVRuntimeInfo> PSVInfo; 290 DirectX::Signature InputSignature; 291 DirectX::Signature OutputSignature; 292 DirectX::Signature PatchConstantSignature; 293 294 Error parseHeader(); 295 Error parsePartOffsets(); 296 Error parseDXILHeader(StringRef Part); 297 Error parseShaderFeatureFlags(StringRef Part); 298 Error parseHash(StringRef Part); 299 Error parsePSVInfo(StringRef Part); 300 Error parseSignature(StringRef Part, DirectX::Signature &Array); 301 friend class PartIterator; 302 303 public: 304 // The PartIterator is a wrapper around the iterator for the PartOffsets 305 // member of the DXContainer. It contains a refernce to the container, and the 306 // current iterator value, as well as storage for a parsed part header. 307 class PartIterator { 308 const DXContainer &Container; 309 SmallVectorImpl<uint32_t>::const_iterator OffsetIt; 310 struct PartData { 311 dxbc::PartHeader Part; 312 uint32_t Offset; 313 StringRef Data; 314 } IteratorState; 315 316 friend class DXContainer; 317 318 PartIterator(const DXContainer &C, 319 SmallVectorImpl<uint32_t>::const_iterator It) 320 : Container(C), OffsetIt(It) { 321 if (OffsetIt == Container.PartOffsets.end()) 322 updateIteratorImpl(Container.PartOffsets.back()); 323 else 324 updateIterator(); 325 } 326 327 // Updates the iterator's state data. This results in copying the part 328 // header into the iterator and handling any required byte swapping. This is 329 // called when incrementing or decrementing the iterator. 330 void updateIterator() { 331 if (OffsetIt != Container.PartOffsets.end()) 332 updateIteratorImpl(*OffsetIt); 333 } 334 335 // Implementation for updating the iterator state based on a specified 336 // offest. 337 void updateIteratorImpl(const uint32_t Offset); 338 339 public: 340 PartIterator &operator++() { 341 if (OffsetIt == Container.PartOffsets.end()) 342 return *this; 343 ++OffsetIt; 344 updateIterator(); 345 return *this; 346 } 347 348 PartIterator operator++(int) { 349 PartIterator Tmp = *this; 350 ++(*this); 351 return Tmp; 352 } 353 354 bool operator==(const PartIterator &RHS) const { 355 return OffsetIt == RHS.OffsetIt; 356 } 357 358 bool operator!=(const PartIterator &RHS) const { 359 return OffsetIt != RHS.OffsetIt; 360 } 361 362 const PartData &operator*() { return IteratorState; } 363 const PartData *operator->() { return &IteratorState; } 364 }; 365 366 PartIterator begin() const { 367 return PartIterator(*this, PartOffsets.begin()); 368 } 369 370 PartIterator end() const { return PartIterator(*this, PartOffsets.end()); } 371 372 StringRef getData() const { return Data.getBuffer(); } 373 static Expected<DXContainer> create(MemoryBufferRef Object); 374 375 const dxbc::Header &getHeader() const { return Header; } 376 377 const std::optional<DXILData> &getDXIL() const { return DXIL; } 378 379 std::optional<uint64_t> getShaderFeatureFlags() const { 380 return ShaderFeatureFlags; 381 } 382 383 std::optional<dxbc::ShaderHash> getShaderHash() const { return Hash; } 384 385 const std::optional<DirectX::PSVRuntimeInfo> &getPSVInfo() const { 386 return PSVInfo; 387 }; 388 389 const DirectX::Signature &getInputSignature() const { return InputSignature; } 390 const DirectX::Signature &getOutputSignature() const { 391 return OutputSignature; 392 } 393 const DirectX::Signature &getPatchConstantSignature() const { 394 return PatchConstantSignature; 395 } 396 }; 397 398 } // namespace object 399 } // namespace llvm 400 401 #endif // LLVM_OBJECT_DXCONTAINER_H 402