xref: /llvm-project/llvm/lib/Target/DirectX/DXILShaderFlags.cpp (revision b6287fd9714d2a34b27e7ef4953f6e68f39463a4)
1 //===- DXILShaderFlags.cpp - DXIL Shader Flags helper objects -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file This file contains helper objects and APIs for working with DXIL
10 ///       Shader Flags.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "DXILShaderFlags.h"
15 #include "DirectX.h"
16 #include "llvm/ADT/SCCIterator.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/CallGraph.h"
19 #include "llvm/Analysis/DXILResource.h"
20 #include "llvm/IR/Instruction.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/IR/Intrinsics.h"
24 #include "llvm/IR/IntrinsicsDirectX.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/raw_ostream.h"
29 
30 using namespace llvm;
31 using namespace llvm::dxil;
32 
33 /// Update the shader flags mask based on the given instruction.
34 /// \param CSF Shader flags mask to update.
35 /// \param I Instruction to check.
36 void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF,
37                                             const Instruction &I,
38                                             DXILResourceTypeMap &DRTM) {
39   if (!CSF.Doubles)
40     CSF.Doubles = I.getType()->isDoubleTy();
41 
42   if (!CSF.Doubles) {
43     for (const Value *Op : I.operands()) {
44       if (Op->getType()->isDoubleTy()) {
45         CSF.Doubles = true;
46         break;
47       }
48     }
49   }
50 
51   if (CSF.Doubles) {
52     switch (I.getOpcode()) {
53     case Instruction::FDiv:
54     case Instruction::UIToFP:
55     case Instruction::SIToFP:
56     case Instruction::FPToUI:
57     case Instruction::FPToSI:
58       CSF.DX11_1_DoubleExtensions = true;
59       break;
60     }
61   }
62 
63   if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
64     switch (II->getIntrinsicID()) {
65     default:
66       break;
67     case Intrinsic::dx_resource_handlefrombinding:
68       switch (DRTM[cast<TargetExtType>(II->getType())].getResourceKind()) {
69       case dxil::ResourceKind::StructuredBuffer:
70       case dxil::ResourceKind::RawBuffer:
71         CSF.EnableRawAndStructuredBuffers = true;
72         break;
73       default:
74         break;
75       }
76       break;
77     case Intrinsic::dx_resource_load_typedbuffer: {
78       dxil::ResourceTypeInfo &RTI =
79           DRTM[cast<TargetExtType>(II->getArgOperand(0)->getType())];
80       if (RTI.isTyped())
81         CSF.TypedUAVLoadAdditionalFormats |= RTI.getTyped().ElementCount > 1;
82       break;
83     }
84     }
85   }
86   // Handle call instructions
87   if (auto *CI = dyn_cast<CallInst>(&I)) {
88     const Function *CF = CI->getCalledFunction();
89     // Merge-in shader flags mask of the called function in the current module
90     if (FunctionFlags.contains(CF))
91       CSF.merge(FunctionFlags[CF]);
92 
93     // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic
94     // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554
95   }
96 }
97 
98 /// Construct ModuleShaderFlags for module Module M
99 void ModuleShaderFlags::initialize(Module &M, DXILResourceTypeMap &DRTM) {
100   CallGraph CG(M);
101 
102   // Compute Shader Flags Mask for all functions using post-order visit of SCC
103   // of the call graph.
104   for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd();
105        ++SCCI) {
106     const std::vector<CallGraphNode *> &CurSCC = *SCCI;
107 
108     // Union of shader masks of all functions in CurSCC
109     ComputedShaderFlags SCCSF;
110     // List of functions in CurSCC that are neither external nor declarations
111     // and hence whose flags are collected
112     SmallVector<Function *> CurSCCFuncs;
113     for (CallGraphNode *CGN : CurSCC) {
114       Function *F = CGN->getFunction();
115       if (!F)
116         continue;
117 
118       if (F->isDeclaration()) {
119         assert(!F->getName().starts_with("dx.op.") &&
120                "DXIL Shader Flag analysis should not be run post-lowering.");
121         continue;
122       }
123 
124       ComputedShaderFlags CSF;
125       for (const auto &BB : *F)
126         for (const auto &I : BB)
127           updateFunctionFlags(CSF, I, DRTM);
128       // Update combined shader flags mask for all functions in this SCC
129       SCCSF.merge(CSF);
130 
131       CurSCCFuncs.push_back(F);
132     }
133 
134     // Update combined shader flags mask for all functions of the module
135     CombinedSFMask.merge(SCCSF);
136 
137     // Shader flags mask of each of the functions in an SCC of the call graph is
138     // the union of all functions in the SCC. Update shader flags masks of
139     // functions in CurSCC accordingly. This is trivially true if SCC contains
140     // one function.
141     for (Function *F : CurSCCFuncs)
142       // Merge SCCSF with that of F
143       FunctionFlags[F].merge(SCCSF);
144   }
145 }
146 
147 void ComputedShaderFlags::print(raw_ostream &OS) const {
148   uint64_t FlagVal = (uint64_t) * this;
149   OS << formatv("; Shader Flags Value: {0:x8}\n;\n", FlagVal);
150   if (FlagVal == 0)
151     return;
152   OS << "; Note: shader requires additional functionality:\n";
153 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleNum, FlagName, Str)          \
154   if (FlagName)                                                                \
155     (OS << ";").indent(7) << Str << "\n";
156 #include "llvm/BinaryFormat/DXContainerConstants.def"
157   OS << "; Note: extra DXIL module flags:\n";
158 #define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str)                         \
159   if (FlagName)                                                                \
160     (OS << ";").indent(7) << Str << "\n";
161 #include "llvm/BinaryFormat/DXContainerConstants.def"
162   OS << ";\n";
163 }
164 
165 /// Return the shader flags mask of the specified function Func.
166 const ComputedShaderFlags &
167 ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
168   auto Iter = FunctionFlags.find(Func);
169   assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
170          "Get Shader Flags : No Shader Flags Mask exists for function");
171   return Iter->second;
172 }
173 
174 //===----------------------------------------------------------------------===//
// ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinter
176 
// Out-of-line definition of the static AnalysisKey member through which the
// new pass manager identifies ShaderFlagsAnalysis results.
AnalysisKey ShaderFlagsAnalysis::Key;
179 
180 ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M,
181                                            ModuleAnalysisManager &AM) {
182   DXILResourceTypeMap &DRTM = AM.getResult<DXILResourceTypeAnalysis>(M);
183 
184   ModuleShaderFlags MSFI;
185   MSFI.initialize(M, DRTM);
186 
187   return MSFI;
188 }
189 
190 PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
191                                                   ModuleAnalysisManager &AM) {
192   const ModuleShaderFlags &FlagsInfo = AM.getResult<ShaderFlagsAnalysis>(M);
193   // Print description of combined shader flags for all module functions
194   OS << "; Combined Shader Flags for Module\n";
195   FlagsInfo.getCombinedFlags().print(OS);
196   // Print shader flags mask for each of the module functions
197   OS << "; Shader Flags for Module Functions\n";
198   for (const auto &F : M.getFunctionList()) {
199     if (F.isDeclaration())
200       continue;
201     const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F);
202     OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(),
203                   (uint64_t)(SFMask));
204   }
205 
206   return PreservedAnalyses::all();
207 }
208 
209 //===----------------------------------------------------------------------===//
// ShaderFlagsAnalysisWrapper (legacy pass manager)
211 
212 bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
213   DXILResourceTypeMap &DRTM =
214       getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
215 
216   MSFI.initialize(M, DRTM);
217   return false;
218 }
219 
220 void ShaderFlagsAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const {
221   AU.setPreservesAll();
222   AU.addRequiredTransitive<DXILResourceTypeWrapperPass>();
223 }
224 
// Identifier of the legacy wrapper pass; the legacy pass manager tells passes
// apart by the address of ID, so the stored value itself is irrelevant.
char ShaderFlagsAnalysisWrapper::ID = 0;

// Register the legacy wrapper under the command-line name
// "dx-shader-flag-analysis" and record its dependency on
// DXILResourceTypeWrapperPass. The trailing `true, true` arguments mark the
// pass as CFG-only and as an analysis.
INITIALIZE_PASS_BEGIN(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
                      "DXIL Shader Flag Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
INITIALIZE_PASS_END(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
                    "DXIL Shader Flag Analysis", true, true)
232