1 //===- DXContainerGlobals.cpp - DXContainer global generator pass ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // DXContainerGlobalsPass implementation. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "DXILShaderFlags.h" 14 #include "DirectX.h" 15 #include "llvm/ADT/SmallVector.h" 16 #include "llvm/ADT/StringExtras.h" 17 #include "llvm/ADT/StringRef.h" 18 #include "llvm/Analysis/DXILMetadataAnalysis.h" 19 #include "llvm/Analysis/DXILResource.h" 20 #include "llvm/BinaryFormat/DXContainer.h" 21 #include "llvm/CodeGen/Passes.h" 22 #include "llvm/IR/Constants.h" 23 #include "llvm/IR/Module.h" 24 #include "llvm/InitializePasses.h" 25 #include "llvm/MC/DXContainerPSVInfo.h" 26 #include "llvm/Pass.h" 27 #include "llvm/Support/MD5.h" 28 #include "llvm/Transforms/Utils/ModuleUtils.h" 29 30 using namespace llvm; 31 using namespace llvm::dxil; 32 using namespace llvm::mcdxbc; 33 34 namespace { 35 class DXContainerGlobals : public llvm::ModulePass { 36 37 GlobalVariable *buildContainerGlobal(Module &M, Constant *Content, 38 StringRef Name, StringRef SectionName); 39 GlobalVariable *getFeatureFlags(Module &M); 40 GlobalVariable *computeShaderHash(Module &M); 41 GlobalVariable *buildSignature(Module &M, Signature &Sig, StringRef Name, 42 StringRef SectionName); 43 void addSignature(Module &M, SmallVector<GlobalValue *> &Globals); 44 void addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV); 45 void addPipelineStateValidationInfo(Module &M, 46 SmallVector<GlobalValue *> &Globals); 47 48 public: 49 static char ID; // Pass identification, replacement for typeid 50 DXContainerGlobals() : ModulePass(ID) { 51 initializeDXContainerGlobalsPass(*PassRegistry::getPassRegistry()); 52 } 53 54 StringRef getPassName() const override { 55 return "DXContainer Global Emitter"; 56 } 57 58 bool runOnModule(Module &M) override; 59 60 void getAnalysisUsage(AnalysisUsage &AU) const override { 61 AU.setPreservesAll(); 62 AU.addRequired<ShaderFlagsAnalysisWrapper>(); 63 AU.addRequired<DXILMetadataAnalysisWrapperPass>(); 64 AU.addRequired<DXILResourceTypeWrapperPass>(); 65 AU.addRequired<DXILResourceBindingWrapperPass>(); 66 } 67 }; 68 69 } // namespace 70 71 bool DXContainerGlobals::runOnModule(Module &M) { 72 llvm::SmallVector<GlobalValue *> Globals; 73 Globals.push_back(getFeatureFlags(M)); 74 Globals.push_back(computeShaderHash(M)); 75 addSignature(M, Globals); 76 addPipelineStateValidationInfo(M, Globals); 77 appendToCompilerUsed(M, Globals); 78 return true; 79 } 80 81 GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) { 82 uint64_t CombinedFeatureFlags = getAnalysis<ShaderFlagsAnalysisWrapper>() 83 .getShaderFlags() 84 .getCombinedFlags() 85 .getFeatureFlags(); 86 87 Constant *FeatureFlagsConstant = 88 ConstantInt::get(M.getContext(), APInt(64, CombinedFeatureFlags)); 89 return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0"); 90 } 91 92 GlobalVariable *DXContainerGlobals::computeShaderHash(Module &M) { 93 auto *DXILConstant = 94 cast<ConstantDataArray>(M.getNamedGlobal("dx.dxil")->getInitializer()); 95 MD5 Digest; 96 Digest.update(DXILConstant->getRawDataValues()); 97 MD5::MD5Result Result = Digest.final(); 98 99 dxbc::ShaderHash HashData = {0, {0}}; 100 // The Hash's IncludesSource flag gets set whenever the hashed shader includes 101 // debug information. 102 if (M.debug_compile_units_begin() != M.debug_compile_units_end()) 103 HashData.Flags = static_cast<uint32_t>(dxbc::HashFlags::IncludesSource); 104 105 memcpy(reinterpret_cast<void *>(&HashData.Digest), Result.data(), 16); 106 if (sys::IsBigEndianHost) 107 HashData.swapBytes(); 108 StringRef Data(reinterpret_cast<char *>(&HashData), sizeof(dxbc::ShaderHash)); 109 110 Constant *ModuleConstant = 111 ConstantDataArray::get(M.getContext(), arrayRefFromStringRef(Data)); 112 return buildContainerGlobal(M, ModuleConstant, "dx.hash", "HASH"); 113 } 114 115 GlobalVariable *DXContainerGlobals::buildContainerGlobal( 116 Module &M, Constant *Content, StringRef Name, StringRef SectionName) { 117 auto *GV = new llvm::GlobalVariable( 118 M, Content->getType(), true, GlobalValue::PrivateLinkage, Content, Name); 119 GV->setSection(SectionName); 120 GV->setAlignment(Align(4)); 121 return GV; 122 } 123 124 GlobalVariable *DXContainerGlobals::buildSignature(Module &M, Signature &Sig, 125 StringRef Name, 126 StringRef SectionName) { 127 SmallString<256> Data; 128 raw_svector_ostream OS(Data); 129 Sig.write(OS); 130 Constant *Constant = 131 ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false); 132 return buildContainerGlobal(M, Constant, Name, SectionName); 133 } 134 135 void DXContainerGlobals::addSignature(Module &M, 136 SmallVector<GlobalValue *> &Globals) { 137 // FIXME: support graphics shader. 138 // see issue https://github.com/llvm/llvm-project/issues/90504. 139 140 Signature InputSig; 141 Globals.emplace_back(buildSignature(M, InputSig, "dx.isg1", "ISG1")); 142 143 Signature OutputSig; 144 Globals.emplace_back(buildSignature(M, OutputSig, "dx.osg1", "OSG1")); 145 } 146 147 void DXContainerGlobals::addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV) { 148 const DXILBindingMap &DBM = 149 getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap(); 150 DXILResourceTypeMap &DRTM = 151 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap(); 152 153 for (const dxil::ResourceBindingInfo &RBI : DBM) { 154 const dxil::ResourceBindingInfo::ResourceBinding &Binding = 155 RBI.getBinding(); 156 dxbc::PSV::v2::ResourceBindInfo BindInfo; 157 BindInfo.LowerBound = Binding.LowerBound; 158 BindInfo.UpperBound = Binding.LowerBound + Binding.Size - 1; 159 BindInfo.Space = Binding.Space; 160 161 dxil::ResourceTypeInfo &TypeInfo = DRTM[RBI.getHandleTy()]; 162 dxbc::PSV::ResourceType ResType = dxbc::PSV::ResourceType::Invalid; 163 bool IsUAV = TypeInfo.getResourceClass() == dxil::ResourceClass::UAV; 164 switch (TypeInfo.getResourceKind()) { 165 case dxil::ResourceKind::Sampler: 166 ResType = dxbc::PSV::ResourceType::Sampler; 167 break; 168 case dxil::ResourceKind::CBuffer: 169 ResType = dxbc::PSV::ResourceType::CBV; 170 break; 171 case dxil::ResourceKind::StructuredBuffer: 172 ResType = IsUAV ? dxbc::PSV::ResourceType::UAVStructured 173 : dxbc::PSV::ResourceType::SRVStructured; 174 if (IsUAV && TypeInfo.getUAV().HasCounter) 175 ResType = dxbc::PSV::ResourceType::UAVStructuredWithCounter; 176 break; 177 case dxil::ResourceKind::RTAccelerationStructure: 178 ResType = dxbc::PSV::ResourceType::SRVRaw; 179 break; 180 case dxil::ResourceKind::RawBuffer: 181 ResType = IsUAV ? dxbc::PSV::ResourceType::UAVRaw 182 : dxbc::PSV::ResourceType::SRVRaw; 183 break; 184 default: 185 ResType = IsUAV ? dxbc::PSV::ResourceType::UAVTyped 186 : dxbc::PSV::ResourceType::SRVTyped; 187 break; 188 } 189 BindInfo.Type = ResType; 190 191 BindInfo.Kind = 192 static_cast<dxbc::PSV::ResourceKind>(TypeInfo.getResourceKind()); 193 // TODO: Add support for dxbc::PSV::ResourceFlag::UsedByAtomic64, tracking 194 // with https://github.com/llvm/llvm-project/issues/104392 195 BindInfo.Flags.Flags = 0u; 196 197 PSV.Resources.emplace_back(BindInfo); 198 } 199 } 200 201 void DXContainerGlobals::addPipelineStateValidationInfo( 202 Module &M, SmallVector<GlobalValue *> &Globals) { 203 SmallString<256> Data; 204 raw_svector_ostream OS(Data); 205 PSVRuntimeInfo PSV; 206 PSV.BaseData.MinimumWaveLaneCount = 0; 207 PSV.BaseData.MaximumWaveLaneCount = std::numeric_limits<uint32_t>::max(); 208 209 dxil::ModuleMetadataInfo &MMI = 210 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); 211 assert(MMI.EntryPropertyVec.size() == 1 || 212 MMI.ShaderProfile == Triple::Library); 213 PSV.BaseData.ShaderStage = 214 static_cast<uint8_t>(MMI.ShaderProfile - Triple::Pixel); 215 216 addResourcesForPSV(M, PSV); 217 218 // Hardcoded values here to unblock loading the shader into D3D. 219 // 220 // TODO: Lots more stuff to do here! 221 // 222 // See issue https://github.com/llvm/llvm-project/issues/96674. 223 switch (MMI.ShaderProfile) { 224 case Triple::Compute: 225 PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX; 226 PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY; 227 PSV.BaseData.NumThreadsZ = MMI.EntryPropertyVec[0].NumThreadsZ; 228 break; 229 default: 230 break; 231 } 232 233 if (MMI.ShaderProfile != Triple::Library) 234 PSV.EntryName = MMI.EntryPropertyVec[0].Entry->getName(); 235 236 PSV.finalize(MMI.ShaderProfile); 237 PSV.write(OS); 238 Constant *Constant = 239 ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false); 240 Globals.emplace_back(buildContainerGlobal(M, Constant, "dx.psv0", "PSV0")); 241 } 242 243 char DXContainerGlobals::ID = 0; 244 INITIALIZE_PASS_BEGIN(DXContainerGlobals, "dxil-globals", 245 "DXContainer Global Emitter", false, true) 246 INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper) 247 INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass) 248 INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass) 249 INITIALIZE_PASS_DEPENDENCY(DXILResourceBindingWrapperPass) 250 INITIALIZE_PASS_END(DXContainerGlobals, "dxil-globals", 251 "DXContainer Global Emitter", false, true) 252 253 ModulePass *llvm::createDXContainerGlobalsPass() { 254 return new DXContainerGlobals(); 255 } 256