xref: /llvm-project/llvm/lib/Target/DirectX/DXContainerGlobals.cpp (revision 3eca15cbb9888a992749ddd24f0fb666dad733bf)
1 //===- DXContainerGlobals.cpp - DXContainer global generator pass ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // DXContainerGlobalsPass implementation.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "DXILShaderFlags.h"
14 #include "DirectX.h"
15 #include "llvm/ADT/SmallVector.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/ADT/StringRef.h"
18 #include "llvm/Analysis/DXILMetadataAnalysis.h"
19 #include "llvm/Analysis/DXILResource.h"
20 #include "llvm/BinaryFormat/DXContainer.h"
21 #include "llvm/CodeGen/Passes.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/Module.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/MC/DXContainerPSVInfo.h"
26 #include "llvm/Pass.h"
27 #include "llvm/Support/MD5.h"
28 #include "llvm/Transforms/Utils/ModuleUtils.h"
29 
30 using namespace llvm;
31 using namespace llvm::dxil;
32 using namespace llvm::mcdxbc;
33 
34 namespace {
35 class DXContainerGlobals : public llvm::ModulePass {
36 
37   GlobalVariable *buildContainerGlobal(Module &M, Constant *Content,
38                                        StringRef Name, StringRef SectionName);
39   GlobalVariable *getFeatureFlags(Module &M);
40   GlobalVariable *computeShaderHash(Module &M);
41   GlobalVariable *buildSignature(Module &M, Signature &Sig, StringRef Name,
42                                  StringRef SectionName);
43   void addSignature(Module &M, SmallVector<GlobalValue *> &Globals);
44   void addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV);
45   void addPipelineStateValidationInfo(Module &M,
46                                       SmallVector<GlobalValue *> &Globals);
47 
48 public:
49   static char ID; // Pass identification, replacement for typeid
50   DXContainerGlobals() : ModulePass(ID) {
51     initializeDXContainerGlobalsPass(*PassRegistry::getPassRegistry());
52   }
53 
54   StringRef getPassName() const override {
55     return "DXContainer Global Emitter";
56   }
57 
58   bool runOnModule(Module &M) override;
59 
60   void getAnalysisUsage(AnalysisUsage &AU) const override {
61     AU.setPreservesAll();
62     AU.addRequired<ShaderFlagsAnalysisWrapper>();
63     AU.addRequired<DXILMetadataAnalysisWrapperPass>();
64     AU.addRequired<DXILResourceTypeWrapperPass>();
65     AU.addRequired<DXILResourceBindingWrapperPass>();
66   }
67 };
68 
69 } // namespace
70 
71 bool DXContainerGlobals::runOnModule(Module &M) {
72   llvm::SmallVector<GlobalValue *> Globals;
73   Globals.push_back(getFeatureFlags(M));
74   Globals.push_back(computeShaderHash(M));
75   addSignature(M, Globals);
76   addPipelineStateValidationInfo(M, Globals);
77   appendToCompilerUsed(M, Globals);
78   return true;
79 }
80 
81 GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) {
82   uint64_t CombinedFeatureFlags = getAnalysis<ShaderFlagsAnalysisWrapper>()
83                                       .getShaderFlags()
84                                       .getCombinedFlags()
85                                       .getFeatureFlags();
86 
87   Constant *FeatureFlagsConstant =
88       ConstantInt::get(M.getContext(), APInt(64, CombinedFeatureFlags));
89   return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0");
90 }
91 
92 GlobalVariable *DXContainerGlobals::computeShaderHash(Module &M) {
93   auto *DXILConstant =
94       cast<ConstantDataArray>(M.getNamedGlobal("dx.dxil")->getInitializer());
95   MD5 Digest;
96   Digest.update(DXILConstant->getRawDataValues());
97   MD5::MD5Result Result = Digest.final();
98 
99   dxbc::ShaderHash HashData = {0, {0}};
100   // The Hash's IncludesSource flag gets set whenever the hashed shader includes
101   // debug information.
102   if (M.debug_compile_units_begin() != M.debug_compile_units_end())
103     HashData.Flags = static_cast<uint32_t>(dxbc::HashFlags::IncludesSource);
104 
105   memcpy(reinterpret_cast<void *>(&HashData.Digest), Result.data(), 16);
106   if (sys::IsBigEndianHost)
107     HashData.swapBytes();
108   StringRef Data(reinterpret_cast<char *>(&HashData), sizeof(dxbc::ShaderHash));
109 
110   Constant *ModuleConstant =
111       ConstantDataArray::get(M.getContext(), arrayRefFromStringRef(Data));
112   return buildContainerGlobal(M, ModuleConstant, "dx.hash", "HASH");
113 }
114 
115 GlobalVariable *DXContainerGlobals::buildContainerGlobal(
116     Module &M, Constant *Content, StringRef Name, StringRef SectionName) {
117   auto *GV = new llvm::GlobalVariable(
118       M, Content->getType(), true, GlobalValue::PrivateLinkage, Content, Name);
119   GV->setSection(SectionName);
120   GV->setAlignment(Align(4));
121   return GV;
122 }
123 
124 GlobalVariable *DXContainerGlobals::buildSignature(Module &M, Signature &Sig,
125                                                    StringRef Name,
126                                                    StringRef SectionName) {
127   SmallString<256> Data;
128   raw_svector_ostream OS(Data);
129   Sig.write(OS);
130   Constant *Constant =
131       ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
132   return buildContainerGlobal(M, Constant, Name, SectionName);
133 }
134 
135 void DXContainerGlobals::addSignature(Module &M,
136                                       SmallVector<GlobalValue *> &Globals) {
137   // FIXME: support graphics shader.
138   //  see issue https://github.com/llvm/llvm-project/issues/90504.
139 
140   Signature InputSig;
141   Globals.emplace_back(buildSignature(M, InputSig, "dx.isg1", "ISG1"));
142 
143   Signature OutputSig;
144   Globals.emplace_back(buildSignature(M, OutputSig, "dx.osg1", "OSG1"));
145 }
146 
147 void DXContainerGlobals::addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV) {
148   const DXILBindingMap &DBM =
149       getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap();
150   DXILResourceTypeMap &DRTM =
151       getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
152 
153   for (const dxil::ResourceBindingInfo &RBI : DBM) {
154     const dxil::ResourceBindingInfo::ResourceBinding &Binding =
155         RBI.getBinding();
156     dxbc::PSV::v2::ResourceBindInfo BindInfo;
157     BindInfo.LowerBound = Binding.LowerBound;
158     BindInfo.UpperBound = Binding.LowerBound + Binding.Size - 1;
159     BindInfo.Space = Binding.Space;
160 
161     dxil::ResourceTypeInfo &TypeInfo = DRTM[RBI.getHandleTy()];
162     dxbc::PSV::ResourceType ResType = dxbc::PSV::ResourceType::Invalid;
163     bool IsUAV = TypeInfo.getResourceClass() == dxil::ResourceClass::UAV;
164     switch (TypeInfo.getResourceKind()) {
165     case dxil::ResourceKind::Sampler:
166       ResType = dxbc::PSV::ResourceType::Sampler;
167       break;
168     case dxil::ResourceKind::CBuffer:
169       ResType = dxbc::PSV::ResourceType::CBV;
170       break;
171     case dxil::ResourceKind::StructuredBuffer:
172       ResType = IsUAV ? dxbc::PSV::ResourceType::UAVStructured
173                       : dxbc::PSV::ResourceType::SRVStructured;
174       if (IsUAV && TypeInfo.getUAV().HasCounter)
175         ResType = dxbc::PSV::ResourceType::UAVStructuredWithCounter;
176       break;
177     case dxil::ResourceKind::RTAccelerationStructure:
178       ResType = dxbc::PSV::ResourceType::SRVRaw;
179       break;
180     case dxil::ResourceKind::RawBuffer:
181       ResType = IsUAV ? dxbc::PSV::ResourceType::UAVRaw
182                       : dxbc::PSV::ResourceType::SRVRaw;
183       break;
184     default:
185       ResType = IsUAV ? dxbc::PSV::ResourceType::UAVTyped
186                       : dxbc::PSV::ResourceType::SRVTyped;
187       break;
188     }
189     BindInfo.Type = ResType;
190 
191     BindInfo.Kind =
192         static_cast<dxbc::PSV::ResourceKind>(TypeInfo.getResourceKind());
193     // TODO: Add support for dxbc::PSV::ResourceFlag::UsedByAtomic64, tracking
194     // with https://github.com/llvm/llvm-project/issues/104392
195     BindInfo.Flags.Flags = 0u;
196 
197     PSV.Resources.emplace_back(BindInfo);
198   }
199 }
200 
201 void DXContainerGlobals::addPipelineStateValidationInfo(
202     Module &M, SmallVector<GlobalValue *> &Globals) {
203   SmallString<256> Data;
204   raw_svector_ostream OS(Data);
205   PSVRuntimeInfo PSV;
206   PSV.BaseData.MinimumWaveLaneCount = 0;
207   PSV.BaseData.MaximumWaveLaneCount = std::numeric_limits<uint32_t>::max();
208 
209   dxil::ModuleMetadataInfo &MMI =
210       getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
211   assert(MMI.EntryPropertyVec.size() == 1 ||
212          MMI.ShaderProfile == Triple::Library);
213   PSV.BaseData.ShaderStage =
214       static_cast<uint8_t>(MMI.ShaderProfile - Triple::Pixel);
215 
216   addResourcesForPSV(M, PSV);
217 
218   // Hardcoded values here to unblock loading the shader into D3D.
219   //
220   // TODO: Lots more stuff to do here!
221   //
222   // See issue https://github.com/llvm/llvm-project/issues/96674.
223   switch (MMI.ShaderProfile) {
224   case Triple::Compute:
225     PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX;
226     PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY;
227     PSV.BaseData.NumThreadsZ = MMI.EntryPropertyVec[0].NumThreadsZ;
228     break;
229   default:
230     break;
231   }
232 
233   if (MMI.ShaderProfile != Triple::Library)
234     PSV.EntryName = MMI.EntryPropertyVec[0].Entry->getName();
235 
236   PSV.finalize(MMI.ShaderProfile);
237   PSV.write(OS);
238   Constant *Constant =
239       ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
240   Globals.emplace_back(buildContainerGlobal(M, Constant, "dx.psv0", "PSV0"));
241 }
242 
// Address-based pass identifier used by the legacy pass registry.
char DXContainerGlobals::ID = 0;
// Register the pass as "dxil-globals" and list the analyses it requires so the
// registry initializes them first. The trailing (false, true) arguments are
// the INITIALIZE_PASS macros' cfg-only and is-analysis flags.
// NOTE(review): the pass adds globals to the module yet is registered with
// is-analysis = true; confirm this is intentional.
INITIALIZE_PASS_BEGIN(DXContainerGlobals, "dxil-globals",
                      "DXContainer Global Emitter", false, true)
INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper)
INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DXILResourceBindingWrapperPass)
INITIALIZE_PASS_END(DXContainerGlobals, "dxil-globals",
                    "DXContainer Global Emitter", false, true)
252 
253 ModulePass *llvm::createDXContainerGlobalsPass() {
254   return new DXContainerGlobals();
255 }
256