xref: /llvm-project/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp (revision 380bb51b70b6d9f3da07a87f56fc3fe44bc78691)
1 //===- DXILTranslateMetadata.cpp - Pass to emit DXIL metadata -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "DXILTranslateMetadata.h"
10 #include "DXILResource.h"
11 #include "DXILResourceAnalysis.h"
12 #include "DXILShaderFlags.h"
13 #include "DirectX.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/ADT/Twine.h"
16 #include "llvm/Analysis/DXILMetadataAnalysis.h"
17 #include "llvm/Analysis/DXILResource.h"
18 #include "llvm/IR/BasicBlock.h"
19 #include "llvm/IR/Constants.h"
20 #include "llvm/IR/DiagnosticInfo.h"
21 #include "llvm/IR/DiagnosticPrinter.h"
22 #include "llvm/IR/Function.h"
23 #include "llvm/IR/IRBuilder.h"
24 #include "llvm/IR/LLVMContext.h"
25 #include "llvm/IR/MDBuilder.h"
26 #include "llvm/IR/Metadata.h"
27 #include "llvm/IR/Module.h"
28 #include "llvm/InitializePasses.h"
29 #include "llvm/Pass.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/VersionTuple.h"
32 #include "llvm/TargetParser/Triple.h"
33 #include <cstdint>
34 
35 using namespace llvm;
36 using namespace llvm::dxil;
37 
38 namespace {
39 /// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
40 /// for TranslateMetadata pass
41 class DiagnosticInfoTranslateMD : public DiagnosticInfo {
42 private:
43   const Twine &Msg;
44   const Module &Mod;
45 
46 public:
47   /// \p M is the module for which the diagnostic is being emitted. \p Msg is
48   /// the message to show. Note that this class does not copy this message, so
49   /// this reference must be valid for the whole life time of the diagnostic.
50   DiagnosticInfoTranslateMD(const Module &M, const Twine &Msg,
51                             DiagnosticSeverity Severity = DS_Error)
52       : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}
53 
54   void print(DiagnosticPrinter &DP) const override {
55     DP << Mod.getName() << ": " << Msg << '\n';
56   }
57 };
58 
59 enum class EntryPropsTag {
60   ShaderFlags = 0,
61   GSState,
62   DSState,
63   HSState,
64   NumThreads,
65   AutoBindingSpace,
66   RayPayloadSize,
67   RayAttribSize,
68   ShaderKind,
69   MSState,
70   ASStateTag,
71   WaveSize,
72   EntryRootSig,
73 };
74 
75 } // namespace
76 
77 static NamedMDNode *emitResourceMetadata(Module &M, DXILBindingMap &DBM,
78                                          DXILResourceTypeMap &DRTM,
79                                          const dxil::Resources &MDResources) {
80   LLVMContext &Context = M.getContext();
81 
82   for (ResourceBindingInfo &RI : DBM)
83     if (!RI.hasSymbol())
84       RI.createSymbol(M, DRTM[RI.getHandleTy()].createElementStruct());
85 
86   SmallVector<Metadata *> SRVs, UAVs, CBufs, Smps;
87   for (const ResourceBindingInfo &RI : DBM.srvs())
88     SRVs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
89   for (const ResourceBindingInfo &RI : DBM.uavs())
90     UAVs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
91   for (const ResourceBindingInfo &RI : DBM.cbuffers())
92     CBufs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
93   for (const ResourceBindingInfo &RI : DBM.samplers())
94     Smps.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
95 
96   Metadata *SRVMD = SRVs.empty() ? nullptr : MDNode::get(Context, SRVs);
97   Metadata *UAVMD = UAVs.empty() ? nullptr : MDNode::get(Context, UAVs);
98   Metadata *CBufMD = CBufs.empty() ? nullptr : MDNode::get(Context, CBufs);
99   Metadata *SmpMD = Smps.empty() ? nullptr : MDNode::get(Context, Smps);
100   bool HasResources = !DBM.empty();
101 
102   if (MDResources.hasUAVs()) {
103     assert(!UAVMD && "Old and new UAV representations can't coexist");
104     UAVMD = MDResources.writeUAVs(M);
105     HasResources = true;
106   }
107 
108   if (MDResources.hasCBuffers()) {
109     assert(!CBufMD && "Old and new cbuffer representations can't coexist");
110     CBufMD = MDResources.writeCBuffers(M);
111     HasResources = true;
112   }
113 
114   if (!HasResources)
115     return nullptr;
116 
117   NamedMDNode *ResourceMD = M.getOrInsertNamedMetadata("dx.resources");
118   ResourceMD->addOperand(
119       MDNode::get(M.getContext(), {SRVMD, UAVMD, CBufMD, SmpMD}));
120 
121   return ResourceMD;
122 }
123 
124 static StringRef getShortShaderStage(Triple::EnvironmentType Env) {
125   switch (Env) {
126   case Triple::Pixel:
127     return "ps";
128   case Triple::Vertex:
129     return "vs";
130   case Triple::Geometry:
131     return "gs";
132   case Triple::Hull:
133     return "hs";
134   case Triple::Domain:
135     return "ds";
136   case Triple::Compute:
137     return "cs";
138   case Triple::Library:
139     return "lib";
140   case Triple::Mesh:
141     return "ms";
142   case Triple::Amplification:
143     return "as";
144   default:
145     break;
146   }
147   llvm_unreachable("Unsupported environment for DXIL generation.");
148 }
149 
150 static uint32_t getShaderStage(Triple::EnvironmentType Env) {
151   return (uint32_t)Env - (uint32_t)llvm::Triple::Pixel;
152 }
153 
154 static SmallVector<Metadata *>
155 getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
156   SmallVector<Metadata *> MDVals;
157   MDVals.emplace_back(ConstantAsMetadata::get(
158       ConstantInt::get(Type::getInt32Ty(Ctx), static_cast<int>(Tag))));
159   switch (Tag) {
160   case EntryPropsTag::ShaderFlags:
161     MDVals.emplace_back(ConstantAsMetadata::get(
162         ConstantInt::get(Type::getInt64Ty(Ctx), Value)));
163     break;
164   case EntryPropsTag::ShaderKind:
165     MDVals.emplace_back(ConstantAsMetadata::get(
166         ConstantInt::get(Type::getInt32Ty(Ctx), Value)));
167     break;
168   case EntryPropsTag::GSState:
169   case EntryPropsTag::DSState:
170   case EntryPropsTag::HSState:
171   case EntryPropsTag::NumThreads:
172   case EntryPropsTag::AutoBindingSpace:
173   case EntryPropsTag::RayPayloadSize:
174   case EntryPropsTag::RayAttribSize:
175   case EntryPropsTag::MSState:
176   case EntryPropsTag::ASStateTag:
177   case EntryPropsTag::WaveSize:
178   case EntryPropsTag::EntryRootSig:
179     llvm_unreachable("NYI: Unhandled entry property tag");
180   }
181   return MDVals;
182 }
183 
184 static MDTuple *
185 getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
186                        const Triple::EnvironmentType ShaderProfile) {
187   SmallVector<Metadata *> MDVals;
188   LLVMContext &Ctx = EP.Entry->getContext();
189   if (EntryShaderFlags != 0)
190     MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderFlags,
191                                         EntryShaderFlags, Ctx));
192 
193   if (EP.Entry != nullptr) {
194     // FIXME: support more props.
195     // See https://github.com/llvm/llvm-project/issues/57948.
196     // Add shader kind for lib entries.
197     if (ShaderProfile == Triple::EnvironmentType::Library &&
198         EP.ShaderStage != Triple::EnvironmentType::Library)
199       MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind,
200                                           getShaderStage(EP.ShaderStage), Ctx));
201 
202     if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
203       MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get(
204           Type::getInt32Ty(Ctx), static_cast<int>(EntryPropsTag::NumThreads))));
205       Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get(
206                                        Type::getInt32Ty(Ctx), EP.NumThreadsX)),
207                                    ConstantAsMetadata::get(ConstantInt::get(
208                                        Type::getInt32Ty(Ctx), EP.NumThreadsY)),
209                                    ConstantAsMetadata::get(ConstantInt::get(
210                                        Type::getInt32Ty(Ctx), EP.NumThreadsZ))};
211       MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
212     }
213   }
214   if (MDVals.empty())
215     return nullptr;
216   return MDNode::get(Ctx, MDVals);
217 }
218 
219 MDTuple *constructEntryMetadata(const Function *EntryFn, MDTuple *Signatures,
220                                 MDNode *Resources, MDTuple *Properties,
221                                 LLVMContext &Ctx) {
222   // Each entry point metadata record specifies:
223   //  * reference to the entry point function global symbol
224   //  * unmangled name
225   //  * list of signatures
226   //  * list of resources
227   //  * list of tag-value pairs of shader capabilities and other properties
228   Metadata *MDVals[5];
229   MDVals[0] =
230       EntryFn ? ValueAsMetadata::get(const_cast<Function *>(EntryFn)) : nullptr;
231   MDVals[1] = MDString::get(Ctx, EntryFn ? EntryFn->getName() : "");
232   MDVals[2] = Signatures;
233   MDVals[3] = Resources;
234   MDVals[4] = Properties;
235   return MDNode::get(Ctx, MDVals);
236 }
237 
238 static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures,
239                             MDNode *MDResources,
240                             const uint64_t EntryShaderFlags,
241                             const Triple::EnvironmentType ShaderProfile) {
242   MDTuple *Properties =
243       getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile);
244   return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties,
245                                 EP.Entry->getContext());
246 }
247 
248 static void emitValidatorVersionMD(Module &M, const ModuleMetadataInfo &MMDI) {
249   if (MMDI.ValidatorVersion.empty())
250     return;
251 
252   LLVMContext &Ctx = M.getContext();
253   IRBuilder<> IRB(Ctx);
254   Metadata *MDVals[2];
255   MDVals[0] =
256       ConstantAsMetadata::get(IRB.getInt32(MMDI.ValidatorVersion.getMajor()));
257   MDVals[1] = ConstantAsMetadata::get(
258       IRB.getInt32(MMDI.ValidatorVersion.getMinor().value_or(0)));
259   NamedMDNode *ValVerNode = M.getOrInsertNamedMetadata("dx.valver");
260   // Set validator version obtained from DXIL Metadata Analysis pass
261   ValVerNode->clearOperands();
262   ValVerNode->addOperand(MDNode::get(Ctx, MDVals));
263 }
264 
265 static void emitShaderModelVersionMD(Module &M,
266                                      const ModuleMetadataInfo &MMDI) {
267   LLVMContext &Ctx = M.getContext();
268   IRBuilder<> IRB(Ctx);
269   Metadata *SMVals[3];
270   VersionTuple SM = MMDI.ShaderModelVersion;
271   SMVals[0] = MDString::get(Ctx, getShortShaderStage(MMDI.ShaderProfile));
272   SMVals[1] = ConstantAsMetadata::get(IRB.getInt32(SM.getMajor()));
273   SMVals[2] = ConstantAsMetadata::get(IRB.getInt32(SM.getMinor().value_or(0)));
274   NamedMDNode *SMMDNode = M.getOrInsertNamedMetadata("dx.shaderModel");
275   SMMDNode->addOperand(MDNode::get(Ctx, SMVals));
276 }
277 
278 static void emitDXILVersionTupleMD(Module &M, const ModuleMetadataInfo &MMDI) {
279   LLVMContext &Ctx = M.getContext();
280   IRBuilder<> IRB(Ctx);
281   VersionTuple DXILVer = MMDI.DXILVersion;
282   Metadata *DXILVals[2];
283   DXILVals[0] = ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMajor()));
284   DXILVals[1] =
285       ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMinor().value_or(0)));
286   NamedMDNode *DXILVerMDNode = M.getOrInsertNamedMetadata("dx.version");
287   DXILVerMDNode->addOperand(MDNode::get(Ctx, DXILVals));
288 }
289 
290 static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
291                                         uint64_t ShaderFlags) {
292   LLVMContext &Ctx = M.getContext();
293   MDTuple *Properties = nullptr;
294   if (ShaderFlags != 0) {
295     SmallVector<Metadata *> MDVals;
296     MDVals.append(
297         getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx));
298     Properties = MDNode::get(Ctx, MDVals);
299   }
300   // Library has an entry metadata with resource table metadata and all other
301   // MDNodes as null.
302   return constructEntryMetadata(nullptr, nullptr, RMD, Properties, Ctx);
303 }
304 
305 // TODO: We might need to refactor this to be more generic,
306 // in case we need more metadata to be replaced.
307 static void translateBranchMetadata(Module &M) {
308   for (Function &F : M) {
309     for (BasicBlock &BB : F) {
310       Instruction *BBTerminatorInst = BB.getTerminator();
311 
312       MDNode *HlslControlFlowMD =
313           BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
314 
315       if (!HlslControlFlowMD)
316         continue;
317 
318       assert(HlslControlFlowMD->getNumOperands() == 2 &&
319              "invalid operands for hlsl.controlflow.hint");
320 
321       MDBuilder MDHelper(M.getContext());
322       ConstantInt *Op1 =
323           mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1));
324 
325       SmallVector<llvm::Metadata *, 2> Vals(
326           ArrayRef<Metadata *>{MDHelper.createString("dx.controlflow.hints"),
327                                MDHelper.createConstant(Op1)});
328 
329       MDNode *MDNode = llvm::MDNode::get(M.getContext(), Vals);
330 
331       BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode);
332       BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr);
333     }
334   }
335 }
336 
337 static void translateMetadata(Module &M, DXILBindingMap &DBM,
338                               DXILResourceTypeMap &DRTM,
339                               const Resources &MDResources,
340                               const ModuleShaderFlags &ShaderFlags,
341                               const ModuleMetadataInfo &MMDI) {
342   LLVMContext &Ctx = M.getContext();
343   IRBuilder<> IRB(Ctx);
344   SmallVector<MDNode *> EntryFnMDNodes;
345 
346   emitValidatorVersionMD(M, MMDI);
347   emitShaderModelVersionMD(M, MMDI);
348   emitDXILVersionTupleMD(M, MMDI);
349   NamedMDNode *NamedResourceMD =
350       emitResourceMetadata(M, DBM, DRTM, MDResources);
351   auto *ResourceMD =
352       (NamedResourceMD != nullptr) ? NamedResourceMD->getOperand(0) : nullptr;
353   // FIXME: Add support to construct Signatures
354   // See https://github.com/llvm/llvm-project/issues/57928
355   MDTuple *Signatures = nullptr;
356 
357   if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) {
358     // Get the combined shader flag mask of all functions in the library to be
359     // used as shader flags mask value associated with top-level library entry
360     // metadata.
361     uint64_t CombinedMask = ShaderFlags.getCombinedFlags();
362     EntryFnMDNodes.emplace_back(
363         emitTopLevelLibraryNode(M, ResourceMD, CombinedMask));
364   } else if (MMDI.EntryPropertyVec.size() > 1) {
365     M.getContext().diagnose(DiagnosticInfoTranslateMD(
366         M, "Non-library shader: One and only one entry expected"));
367   }
368 
369   for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
370     const ComputedShaderFlags &EntrySFMask =
371         ShaderFlags.getFunctionFlags(EntryProp.Entry);
372 
373     // If ShaderProfile is Library, mask is already consolidated in the
374     // top-level library node. Hence it is not emitted.
375     uint64_t EntryShaderFlags = 0;
376     if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
377       EntryShaderFlags = EntrySFMask;
378       if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
379         M.getContext().diagnose(DiagnosticInfoTranslateMD(
380             M,
381             "Shader stage '" +
382                 Twine(getShortShaderStage(EntryProp.ShaderStage) +
383                       "' for entry '" + Twine(EntryProp.Entry->getName()) +
384                       "' different from specified target profile '" +
385                       Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
386                             "'"))));
387       }
388     }
389     EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
390                                             EntryShaderFlags,
391                                             MMDI.ShaderProfile));
392   }
393 
394   NamedMDNode *EntryPointsNamedMD =
395       M.getOrInsertNamedMetadata("dx.entryPoints");
396   for (auto *Entry : EntryFnMDNodes)
397     EntryPointsNamedMD->addOperand(Entry);
398 }
399 
400 PreservedAnalyses DXILTranslateMetadata::run(Module &M,
401                                              ModuleAnalysisManager &MAM) {
402   DXILBindingMap &DBM = MAM.getResult<DXILResourceBindingAnalysis>(M);
403   DXILResourceTypeMap &DRTM = MAM.getResult<DXILResourceTypeAnalysis>(M);
404   const dxil::Resources &MDResources = MAM.getResult<DXILResourceMDAnalysis>(M);
405   const ModuleShaderFlags &ShaderFlags = MAM.getResult<ShaderFlagsAnalysis>(M);
406   const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
407 
408   translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI);
409   translateBranchMetadata(M);
410 
411   return PreservedAnalyses::all();
412 }
413 
414 namespace {
415 class DXILTranslateMetadataLegacy : public ModulePass {
416 public:
417   static char ID; // Pass identification, replacement for typeid
418   explicit DXILTranslateMetadataLegacy() : ModulePass(ID) {}
419 
420   StringRef getPassName() const override { return "DXIL Translate Metadata"; }
421 
422   void getAnalysisUsage(AnalysisUsage &AU) const override {
423     AU.addRequired<DXILResourceTypeWrapperPass>();
424     AU.addRequired<DXILResourceBindingWrapperPass>();
425     AU.addRequired<DXILResourceMDWrapper>();
426     AU.addRequired<ShaderFlagsAnalysisWrapper>();
427     AU.addRequired<DXILMetadataAnalysisWrapperPass>();
428     AU.addPreserved<DXILResourceBindingWrapperPass>();
429     AU.addPreserved<DXILResourceMDWrapper>();
430     AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
431     AU.addPreserved<ShaderFlagsAnalysisWrapper>();
432   }
433 
434   bool runOnModule(Module &M) override {
435     DXILBindingMap &DBM =
436         getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap();
437     DXILResourceTypeMap &DRTM =
438         getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
439     const dxil::Resources &MDResources =
440         getAnalysis<DXILResourceMDWrapper>().getDXILResource();
441     const ModuleShaderFlags &ShaderFlags =
442         getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
443     dxil::ModuleMetadataInfo MMDI =
444         getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
445 
446     translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI);
447     translateBranchMetadata(M);
448     return true;
449   }
450 };
451 
452 } // namespace
453 
454 char DXILTranslateMetadataLegacy::ID = 0;
455 
456 ModulePass *llvm::createDXILTranslateMetadataLegacyPass() {
457   return new DXILTranslateMetadataLegacy();
458 }
459 
460 INITIALIZE_PASS_BEGIN(DXILTranslateMetadataLegacy, "dxil-translate-metadata",
461                       "DXIL Translate Metadata", false, false)
462 INITIALIZE_PASS_DEPENDENCY(DXILResourceBindingWrapperPass)
463 INITIALIZE_PASS_DEPENDENCY(DXILResourceMDWrapper)
464 INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper)
465 INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
466 INITIALIZE_PASS_END(DXILTranslateMetadataLegacy, "dxil-translate-metadata",
467                     "DXIL Translate Metadata", false, false)
468