1 //===- DXILShaderFlags.cpp - DXIL Shader Flags helper objects -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains helper objects and APIs for working with DXIL 10 /// Shader Flags. 11 /// 12 //===----------------------------------------------------------------------===// 13 14 #include "DXILShaderFlags.h" 15 #include "DirectX.h" 16 #include "llvm/ADT/SCCIterator.h" 17 #include "llvm/ADT/SmallVector.h" 18 #include "llvm/Analysis/CallGraph.h" 19 #include "llvm/Analysis/DXILResource.h" 20 #include "llvm/IR/Instruction.h" 21 #include "llvm/IR/Instructions.h" 22 #include "llvm/IR/IntrinsicInst.h" 23 #include "llvm/IR/Intrinsics.h" 24 #include "llvm/IR/IntrinsicsDirectX.h" 25 #include "llvm/IR/Module.h" 26 #include "llvm/InitializePasses.h" 27 #include "llvm/Support/FormatVariadic.h" 28 #include "llvm/Support/raw_ostream.h" 29 30 using namespace llvm; 31 using namespace llvm::dxil; 32 33 /// Update the shader flags mask based on the given instruction. 34 /// \param CSF Shader flags mask to update. 35 /// \param I Instruction to check. 36 void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF, 37 const Instruction &I, 38 DXILResourceTypeMap &DRTM) { 39 if (!CSF.Doubles) 40 CSF.Doubles = I.getType()->isDoubleTy(); 41 42 if (!CSF.Doubles) { 43 for (const Value *Op : I.operands()) { 44 if (Op->getType()->isDoubleTy()) { 45 CSF.Doubles = true; 46 break; 47 } 48 } 49 } 50 51 if (CSF.Doubles) { 52 switch (I.getOpcode()) { 53 case Instruction::FDiv: 54 case Instruction::UIToFP: 55 case Instruction::SIToFP: 56 case Instruction::FPToUI: 57 case Instruction::FPToSI: 58 CSF.DX11_1_DoubleExtensions = true; 59 break; 60 } 61 } 62 63 if (auto *II = dyn_cast<IntrinsicInst>(&I)) { 64 switch (II->getIntrinsicID()) { 65 default: 66 break; 67 case Intrinsic::dx_resource_handlefrombinding: 68 switch (DRTM[cast<TargetExtType>(II->getType())].getResourceKind()) { 69 case dxil::ResourceKind::StructuredBuffer: 70 case dxil::ResourceKind::RawBuffer: 71 CSF.EnableRawAndStructuredBuffers = true; 72 break; 73 default: 74 break; 75 } 76 break; 77 case Intrinsic::dx_resource_load_typedbuffer: { 78 dxil::ResourceTypeInfo &RTI = 79 DRTM[cast<TargetExtType>(II->getArgOperand(0)->getType())]; 80 if (RTI.isTyped()) 81 CSF.TypedUAVLoadAdditionalFormats |= RTI.getTyped().ElementCount > 1; 82 break; 83 } 84 } 85 } 86 // Handle call instructions 87 if (auto *CI = dyn_cast<CallInst>(&I)) { 88 const Function *CF = CI->getCalledFunction(); 89 // Merge-in shader flags mask of the called function in the current module 90 if (FunctionFlags.contains(CF)) 91 CSF.merge(FunctionFlags[CF]); 92 93 // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic 94 // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554 95 } 96 } 97 98 /// Construct ModuleShaderFlags for module Module M 99 void ModuleShaderFlags::initialize(Module &M, DXILResourceTypeMap &DRTM) { 100 CallGraph CG(M); 101 102 // Compute Shader Flags Mask for all functions using post-order visit of SCC 103 // of the call graph. 104 for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd(); 105 ++SCCI) { 106 const std::vector<CallGraphNode *> &CurSCC = *SCCI; 107 108 // Union of shader masks of all functions in CurSCC 109 ComputedShaderFlags SCCSF; 110 // List of functions in CurSCC that are neither external nor declarations 111 // and hence whose flags are collected 112 SmallVector<Function *> CurSCCFuncs; 113 for (CallGraphNode *CGN : CurSCC) { 114 Function *F = CGN->getFunction(); 115 if (!F) 116 continue; 117 118 if (F->isDeclaration()) { 119 assert(!F->getName().starts_with("dx.op.") && 120 "DXIL Shader Flag analysis should not be run post-lowering."); 121 continue; 122 } 123 124 ComputedShaderFlags CSF; 125 for (const auto &BB : *F) 126 for (const auto &I : BB) 127 updateFunctionFlags(CSF, I, DRTM); 128 // Update combined shader flags mask for all functions in this SCC 129 SCCSF.merge(CSF); 130 131 CurSCCFuncs.push_back(F); 132 } 133 134 // Update combined shader flags mask for all functions of the module 135 CombinedSFMask.merge(SCCSF); 136 137 // Shader flags mask of each of the functions in an SCC of the call graph is 138 // the union of all functions in the SCC. Update shader flags masks of 139 // functions in CurSCC accordingly. This is trivially true if SCC contains 140 // one function. 141 for (Function *F : CurSCCFuncs) 142 // Merge SCCSF with that of F 143 FunctionFlags[F].merge(SCCSF); 144 } 145 } 146 147 void ComputedShaderFlags::print(raw_ostream &OS) const { 148 uint64_t FlagVal = (uint64_t) * this; 149 OS << formatv("; Shader Flags Value: {0:x8}\n;\n", FlagVal); 150 if (FlagVal == 0) 151 return; 152 OS << "; Note: shader requires additional functionality:\n"; 153 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleNum, FlagName, Str) \ 154 if (FlagName) \ 155 (OS << ";").indent(7) << Str << "\n"; 156 #include "llvm/BinaryFormat/DXContainerConstants.def" 157 OS << "; Note: extra DXIL module flags:\n"; 158 #define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \ 159 if (FlagName) \ 160 (OS << ";").indent(7) << Str << "\n"; 161 #include "llvm/BinaryFormat/DXContainerConstants.def" 162 OS << ";\n"; 163 } 164 165 /// Return the shader flags mask of the specified function Func. 166 const ComputedShaderFlags & 167 ModuleShaderFlags::getFunctionFlags(const Function *Func) const { 168 auto Iter = FunctionFlags.find(Func); 169 assert((Iter != FunctionFlags.end() && Iter->first == Func) && 170 "Get Shader Flags : No Shader Flags Mask exists for function"); 171 return Iter->second; 172 } 173 174 //===----------------------------------------------------------------------===// 175 // ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass 176 177 // Provide an explicit template instantiation for the static ID. 178 AnalysisKey ShaderFlagsAnalysis::Key; 179 180 ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M, 181 ModuleAnalysisManager &AM) { 182 DXILResourceTypeMap &DRTM = AM.getResult<DXILResourceTypeAnalysis>(M); 183 184 ModuleShaderFlags MSFI; 185 MSFI.initialize(M, DRTM); 186 187 return MSFI; 188 } 189 190 PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M, 191 ModuleAnalysisManager &AM) { 192 const ModuleShaderFlags &FlagsInfo = AM.getResult<ShaderFlagsAnalysis>(M); 193 // Print description of combined shader flags for all module functions 194 OS << "; Combined Shader Flags for Module\n"; 195 FlagsInfo.getCombinedFlags().print(OS); 196 // Print shader flags mask for each of the module functions 197 OS << "; Shader Flags for Module Functions\n"; 198 for (const auto &F : M.getFunctionList()) { 199 if (F.isDeclaration()) 200 continue; 201 const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F); 202 OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(), 203 (uint64_t)(SFMask)); 204 } 205 206 return PreservedAnalyses::all(); 207 } 208 209 //===----------------------------------------------------------------------===// 210 // ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass 211 212 bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) { 213 DXILResourceTypeMap &DRTM = 214 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap(); 215 216 MSFI.initialize(M, DRTM); 217 return false; 218 } 219 220 void ShaderFlagsAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const { 221 AU.setPreservesAll(); 222 AU.addRequiredTransitive<DXILResourceTypeWrapperPass>(); 223 } 224 225 char ShaderFlagsAnalysisWrapper::ID = 0; 226 227 INITIALIZE_PASS_BEGIN(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis", 228 "DXIL Shader Flag Analysis", true, true) 229 INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass) 230 INITIALIZE_PASS_END(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis", 231 "DXIL Shader Flag Analysis", true, true) 232