1 //===- DXILTranslateMetadata.cpp - Pass to emit DXIL metadata -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "DXILTranslateMetadata.h" 10 #include "DXILResource.h" 11 #include "DXILResourceAnalysis.h" 12 #include "DXILShaderFlags.h" 13 #include "DirectX.h" 14 #include "llvm/ADT/SmallVector.h" 15 #include "llvm/ADT/Twine.h" 16 #include "llvm/Analysis/DXILMetadataAnalysis.h" 17 #include "llvm/Analysis/DXILResource.h" 18 #include "llvm/IR/BasicBlock.h" 19 #include "llvm/IR/Constants.h" 20 #include "llvm/IR/DiagnosticInfo.h" 21 #include "llvm/IR/DiagnosticPrinter.h" 22 #include "llvm/IR/Function.h" 23 #include "llvm/IR/IRBuilder.h" 24 #include "llvm/IR/LLVMContext.h" 25 #include "llvm/IR/MDBuilder.h" 26 #include "llvm/IR/Metadata.h" 27 #include "llvm/IR/Module.h" 28 #include "llvm/InitializePasses.h" 29 #include "llvm/Pass.h" 30 #include "llvm/Support/ErrorHandling.h" 31 #include "llvm/Support/VersionTuple.h" 32 #include "llvm/TargetParser/Triple.h" 33 #include <cstdint> 34 35 using namespace llvm; 36 using namespace llvm::dxil; 37 38 namespace { 39 /// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic 40 /// for TranslateMetadata pass 41 class DiagnosticInfoTranslateMD : public DiagnosticInfo { 42 private: 43 const Twine &Msg; 44 const Module &Mod; 45 46 public: 47 /// \p M is the module for which the diagnostic is being emitted. \p Msg is 48 /// the message to show. Note that this class does not copy this message, so 49 /// this reference must be valid for the whole life time of the diagnostic. 50 DiagnosticInfoTranslateMD(const Module &M, const Twine &Msg, 51 DiagnosticSeverity Severity = DS_Error) 52 : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {} 53 54 void print(DiagnosticPrinter &DP) const override { 55 DP << Mod.getName() << ": " << Msg << '\n'; 56 } 57 }; 58 59 enum class EntryPropsTag { 60 ShaderFlags = 0, 61 GSState, 62 DSState, 63 HSState, 64 NumThreads, 65 AutoBindingSpace, 66 RayPayloadSize, 67 RayAttribSize, 68 ShaderKind, 69 MSState, 70 ASStateTag, 71 WaveSize, 72 EntryRootSig, 73 }; 74 75 } // namespace 76 77 static NamedMDNode *emitResourceMetadata(Module &M, DXILBindingMap &DBM, 78 DXILResourceTypeMap &DRTM, 79 const dxil::Resources &MDResources) { 80 LLVMContext &Context = M.getContext(); 81 82 for (ResourceBindingInfo &RI : DBM) 83 if (!RI.hasSymbol()) 84 RI.createSymbol(M, DRTM[RI.getHandleTy()].createElementStruct()); 85 86 SmallVector<Metadata *> SRVs, UAVs, CBufs, Smps; 87 for (const ResourceBindingInfo &RI : DBM.srvs()) 88 SRVs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()])); 89 for (const ResourceBindingInfo &RI : DBM.uavs()) 90 UAVs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()])); 91 for (const ResourceBindingInfo &RI : DBM.cbuffers()) 92 CBufs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()])); 93 for (const ResourceBindingInfo &RI : DBM.samplers()) 94 Smps.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()])); 95 96 Metadata *SRVMD = SRVs.empty() ? nullptr : MDNode::get(Context, SRVs); 97 Metadata *UAVMD = UAVs.empty() ? nullptr : MDNode::get(Context, UAVs); 98 Metadata *CBufMD = CBufs.empty() ? nullptr : MDNode::get(Context, CBufs); 99 Metadata *SmpMD = Smps.empty() ? nullptr : MDNode::get(Context, Smps); 100 bool HasResources = !DBM.empty(); 101 102 if (MDResources.hasUAVs()) { 103 assert(!UAVMD && "Old and new UAV representations can't coexist"); 104 UAVMD = MDResources.writeUAVs(M); 105 HasResources = true; 106 } 107 108 if (MDResources.hasCBuffers()) { 109 assert(!CBufMD && "Old and new cbuffer representations can't coexist"); 110 CBufMD = MDResources.writeCBuffers(M); 111 HasResources = true; 112 } 113 114 if (!HasResources) 115 return nullptr; 116 117 NamedMDNode *ResourceMD = M.getOrInsertNamedMetadata("dx.resources"); 118 ResourceMD->addOperand( 119 MDNode::get(M.getContext(), {SRVMD, UAVMD, CBufMD, SmpMD})); 120 121 return ResourceMD; 122 } 123 124 static StringRef getShortShaderStage(Triple::EnvironmentType Env) { 125 switch (Env) { 126 case Triple::Pixel: 127 return "ps"; 128 case Triple::Vertex: 129 return "vs"; 130 case Triple::Geometry: 131 return "gs"; 132 case Triple::Hull: 133 return "hs"; 134 case Triple::Domain: 135 return "ds"; 136 case Triple::Compute: 137 return "cs"; 138 case Triple::Library: 139 return "lib"; 140 case Triple::Mesh: 141 return "ms"; 142 case Triple::Amplification: 143 return "as"; 144 default: 145 break; 146 } 147 llvm_unreachable("Unsupported environment for DXIL generation."); 148 } 149 150 static uint32_t getShaderStage(Triple::EnvironmentType Env) { 151 return (uint32_t)Env - (uint32_t)llvm::Triple::Pixel; 152 } 153 154 static SmallVector<Metadata *> 155 getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) { 156 SmallVector<Metadata *> MDVals; 157 MDVals.emplace_back(ConstantAsMetadata::get( 158 ConstantInt::get(Type::getInt32Ty(Ctx), static_cast<int>(Tag)))); 159 switch (Tag) { 160 case EntryPropsTag::ShaderFlags: 161 MDVals.emplace_back(ConstantAsMetadata::get( 162 ConstantInt::get(Type::getInt64Ty(Ctx), Value))); 163 break; 164 case EntryPropsTag::ShaderKind: 165 MDVals.emplace_back(ConstantAsMetadata::get( 166 ConstantInt::get(Type::getInt32Ty(Ctx), Value))); 167 break; 168 case EntryPropsTag::GSState: 169 case EntryPropsTag::DSState: 170 case EntryPropsTag::HSState: 171 case EntryPropsTag::NumThreads: 172 case EntryPropsTag::AutoBindingSpace: 173 case EntryPropsTag::RayPayloadSize: 174 case EntryPropsTag::RayAttribSize: 175 case EntryPropsTag::MSState: 176 case EntryPropsTag::ASStateTag: 177 case EntryPropsTag::WaveSize: 178 case EntryPropsTag::EntryRootSig: 179 llvm_unreachable("NYI: Unhandled entry property tag"); 180 } 181 return MDVals; 182 } 183 184 static MDTuple * 185 getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags, 186 const Triple::EnvironmentType ShaderProfile) { 187 SmallVector<Metadata *> MDVals; 188 LLVMContext &Ctx = EP.Entry->getContext(); 189 if (EntryShaderFlags != 0) 190 MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderFlags, 191 EntryShaderFlags, Ctx)); 192 193 if (EP.Entry != nullptr) { 194 // FIXME: support more props. 195 // See https://github.com/llvm/llvm-project/issues/57948. 196 // Add shader kind for lib entries. 197 if (ShaderProfile == Triple::EnvironmentType::Library && 198 EP.ShaderStage != Triple::EnvironmentType::Library) 199 MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind, 200 getShaderStage(EP.ShaderStage), Ctx)); 201 202 if (EP.ShaderStage == Triple::EnvironmentType::Compute) { 203 MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get( 204 Type::getInt32Ty(Ctx), static_cast<int>(EntryPropsTag::NumThreads)))); 205 Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get( 206 Type::getInt32Ty(Ctx), EP.NumThreadsX)), 207 ConstantAsMetadata::get(ConstantInt::get( 208 Type::getInt32Ty(Ctx), EP.NumThreadsY)), 209 ConstantAsMetadata::get(ConstantInt::get( 210 Type::getInt32Ty(Ctx), EP.NumThreadsZ))}; 211 MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals)); 212 } 213 } 214 if (MDVals.empty()) 215 return nullptr; 216 return MDNode::get(Ctx, MDVals); 217 } 218 219 MDTuple *constructEntryMetadata(const Function *EntryFn, MDTuple *Signatures, 220 MDNode *Resources, MDTuple *Properties, 221 LLVMContext &Ctx) { 222 // Each entry point metadata record specifies: 223 // * reference to the entry point function global symbol 224 // * unmangled name 225 // * list of signatures 226 // * list of resources 227 // * list of tag-value pairs of shader capabilities and other properties 228 Metadata *MDVals[5]; 229 MDVals[0] = 230 EntryFn ? ValueAsMetadata::get(const_cast<Function *>(EntryFn)) : nullptr; 231 MDVals[1] = MDString::get(Ctx, EntryFn ? EntryFn->getName() : ""); 232 MDVals[2] = Signatures; 233 MDVals[3] = Resources; 234 MDVals[4] = Properties; 235 return MDNode::get(Ctx, MDVals); 236 } 237 238 static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures, 239 MDNode *MDResources, 240 const uint64_t EntryShaderFlags, 241 const Triple::EnvironmentType ShaderProfile) { 242 MDTuple *Properties = 243 getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile); 244 return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties, 245 EP.Entry->getContext()); 246 } 247 248 static void emitValidatorVersionMD(Module &M, const ModuleMetadataInfo &MMDI) { 249 if (MMDI.ValidatorVersion.empty()) 250 return; 251 252 LLVMContext &Ctx = M.getContext(); 253 IRBuilder<> IRB(Ctx); 254 Metadata *MDVals[2]; 255 MDVals[0] = 256 ConstantAsMetadata::get(IRB.getInt32(MMDI.ValidatorVersion.getMajor())); 257 MDVals[1] = ConstantAsMetadata::get( 258 IRB.getInt32(MMDI.ValidatorVersion.getMinor().value_or(0))); 259 NamedMDNode *ValVerNode = M.getOrInsertNamedMetadata("dx.valver"); 260 // Set validator version obtained from DXIL Metadata Analysis pass 261 ValVerNode->clearOperands(); 262 ValVerNode->addOperand(MDNode::get(Ctx, MDVals)); 263 } 264 265 static void emitShaderModelVersionMD(Module &M, 266 const ModuleMetadataInfo &MMDI) { 267 LLVMContext &Ctx = M.getContext(); 268 IRBuilder<> IRB(Ctx); 269 Metadata *SMVals[3]; 270 VersionTuple SM = MMDI.ShaderModelVersion; 271 SMVals[0] = MDString::get(Ctx, getShortShaderStage(MMDI.ShaderProfile)); 272 SMVals[1] = ConstantAsMetadata::get(IRB.getInt32(SM.getMajor())); 273 SMVals[2] = ConstantAsMetadata::get(IRB.getInt32(SM.getMinor().value_or(0))); 274 NamedMDNode *SMMDNode = M.getOrInsertNamedMetadata("dx.shaderModel"); 275 SMMDNode->addOperand(MDNode::get(Ctx, SMVals)); 276 } 277 278 static void emitDXILVersionTupleMD(Module &M, const ModuleMetadataInfo &MMDI) { 279 LLVMContext &Ctx = M.getContext(); 280 IRBuilder<> IRB(Ctx); 281 VersionTuple DXILVer = MMDI.DXILVersion; 282 Metadata *DXILVals[2]; 283 DXILVals[0] = ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMajor())); 284 DXILVals[1] = 285 ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMinor().value_or(0))); 286 NamedMDNode *DXILVerMDNode = M.getOrInsertNamedMetadata("dx.version"); 287 DXILVerMDNode->addOperand(MDNode::get(Ctx, DXILVals)); 288 } 289 290 static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD, 291 uint64_t ShaderFlags) { 292 LLVMContext &Ctx = M.getContext(); 293 MDTuple *Properties = nullptr; 294 if (ShaderFlags != 0) { 295 SmallVector<Metadata *> MDVals; 296 MDVals.append( 297 getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx)); 298 Properties = MDNode::get(Ctx, MDVals); 299 } 300 // Library has an entry metadata with resource table metadata and all other 301 // MDNodes as null. 302 return constructEntryMetadata(nullptr, nullptr, RMD, Properties, Ctx); 303 } 304 305 // TODO: We might need to refactor this to be more generic, 306 // in case we need more metadata to be replaced. 307 static void translateBranchMetadata(Module &M) { 308 for (Function &F : M) { 309 for (BasicBlock &BB : F) { 310 Instruction *BBTerminatorInst = BB.getTerminator(); 311 312 MDNode *HlslControlFlowMD = 313 BBTerminatorInst->getMetadata("hlsl.controlflow.hint"); 314 315 if (!HlslControlFlowMD) 316 continue; 317 318 assert(HlslControlFlowMD->getNumOperands() == 2 && 319 "invalid operands for hlsl.controlflow.hint"); 320 321 MDBuilder MDHelper(M.getContext()); 322 ConstantInt *Op1 = 323 mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1)); 324 325 SmallVector<llvm::Metadata *, 2> Vals( 326 ArrayRef<Metadata *>{MDHelper.createString("dx.controlflow.hints"), 327 MDHelper.createConstant(Op1)}); 328 329 MDNode *MDNode = llvm::MDNode::get(M.getContext(), Vals); 330 331 BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode); 332 BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr); 333 } 334 } 335 } 336 337 static void translateMetadata(Module &M, DXILBindingMap &DBM, 338 DXILResourceTypeMap &DRTM, 339 const Resources &MDResources, 340 const ModuleShaderFlags &ShaderFlags, 341 const ModuleMetadataInfo &MMDI) { 342 LLVMContext &Ctx = M.getContext(); 343 IRBuilder<> IRB(Ctx); 344 SmallVector<MDNode *> EntryFnMDNodes; 345 346 emitValidatorVersionMD(M, MMDI); 347 emitShaderModelVersionMD(M, MMDI); 348 emitDXILVersionTupleMD(M, MMDI); 349 NamedMDNode *NamedResourceMD = 350 emitResourceMetadata(M, DBM, DRTM, MDResources); 351 auto *ResourceMD = 352 (NamedResourceMD != nullptr) ? NamedResourceMD->getOperand(0) : nullptr; 353 // FIXME: Add support to construct Signatures 354 // See https://github.com/llvm/llvm-project/issues/57928 355 MDTuple *Signatures = nullptr; 356 357 if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) { 358 // Get the combined shader flag mask of all functions in the library to be 359 // used as shader flags mask value associated with top-level library entry 360 // metadata. 361 uint64_t CombinedMask = ShaderFlags.getCombinedFlags(); 362 EntryFnMDNodes.emplace_back( 363 emitTopLevelLibraryNode(M, ResourceMD, CombinedMask)); 364 } else if (MMDI.EntryPropertyVec.size() > 1) { 365 M.getContext().diagnose(DiagnosticInfoTranslateMD( 366 M, "Non-library shader: One and only one entry expected")); 367 } 368 369 for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) { 370 const ComputedShaderFlags &EntrySFMask = 371 ShaderFlags.getFunctionFlags(EntryProp.Entry); 372 373 // If ShaderProfile is Library, mask is already consolidated in the 374 // top-level library node. Hence it is not emitted. 375 uint64_t EntryShaderFlags = 0; 376 if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) { 377 EntryShaderFlags = EntrySFMask; 378 if (EntryProp.ShaderStage != MMDI.ShaderProfile) { 379 M.getContext().diagnose(DiagnosticInfoTranslateMD( 380 M, 381 "Shader stage '" + 382 Twine(getShortShaderStage(EntryProp.ShaderStage) + 383 "' for entry '" + Twine(EntryProp.Entry->getName()) + 384 "' different from specified target profile '" + 385 Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) + 386 "'")))); 387 } 388 } 389 EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD, 390 EntryShaderFlags, 391 MMDI.ShaderProfile)); 392 } 393 394 NamedMDNode *EntryPointsNamedMD = 395 M.getOrInsertNamedMetadata("dx.entryPoints"); 396 for (auto *Entry : EntryFnMDNodes) 397 EntryPointsNamedMD->addOperand(Entry); 398 } 399 400 PreservedAnalyses DXILTranslateMetadata::run(Module &M, 401 ModuleAnalysisManager &MAM) { 402 DXILBindingMap &DBM = MAM.getResult<DXILResourceBindingAnalysis>(M); 403 DXILResourceTypeMap &DRTM = MAM.getResult<DXILResourceTypeAnalysis>(M); 404 const dxil::Resources &MDResources = MAM.getResult<DXILResourceMDAnalysis>(M); 405 const ModuleShaderFlags &ShaderFlags = MAM.getResult<ShaderFlagsAnalysis>(M); 406 const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M); 407 408 translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI); 409 translateBranchMetadata(M); 410 411 return PreservedAnalyses::all(); 412 } 413 414 namespace { 415 class DXILTranslateMetadataLegacy : public ModulePass { 416 public: 417 static char ID; // Pass identification, replacement for typeid 418 explicit DXILTranslateMetadataLegacy() : ModulePass(ID) {} 419 420 StringRef getPassName() const override { return "DXIL Translate Metadata"; } 421 422 void getAnalysisUsage(AnalysisUsage &AU) const override { 423 AU.addRequired<DXILResourceTypeWrapperPass>(); 424 AU.addRequired<DXILResourceBindingWrapperPass>(); 425 AU.addRequired<DXILResourceMDWrapper>(); 426 AU.addRequired<ShaderFlagsAnalysisWrapper>(); 427 AU.addRequired<DXILMetadataAnalysisWrapperPass>(); 428 AU.addPreserved<DXILResourceBindingWrapperPass>(); 429 AU.addPreserved<DXILResourceMDWrapper>(); 430 AU.addPreserved<DXILMetadataAnalysisWrapperPass>(); 431 AU.addPreserved<ShaderFlagsAnalysisWrapper>(); 432 } 433 434 bool runOnModule(Module &M) override { 435 DXILBindingMap &DBM = 436 getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap(); 437 DXILResourceTypeMap &DRTM = 438 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap(); 439 const dxil::Resources &MDResources = 440 getAnalysis<DXILResourceMDWrapper>().getDXILResource(); 441 const ModuleShaderFlags &ShaderFlags = 442 getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags(); 443 dxil::ModuleMetadataInfo MMDI = 444 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); 445 446 translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI); 447 translateBranchMetadata(M); 448 return true; 449 } 450 }; 451 452 } // namespace 453 454 char DXILTranslateMetadataLegacy::ID = 0; 455 456 ModulePass *llvm::createDXILTranslateMetadataLegacyPass() { 457 return new DXILTranslateMetadataLegacy(); 458 } 459 460 INITIALIZE_PASS_BEGIN(DXILTranslateMetadataLegacy, "dxil-translate-metadata", 461 "DXIL Translate Metadata", false, false) 462 INITIALIZE_PASS_DEPENDENCY(DXILResourceBindingWrapperPass) 463 INITIALIZE_PASS_DEPENDENCY(DXILResourceMDWrapper) 464 INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper) 465 INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass) 466 INITIALIZE_PASS_END(DXILTranslateMetadataLegacy, "dxil-translate-metadata", 467 "DXIL Translate Metadata", false, false) 468