xref: /llvm-project/llvm/lib/Target/DirectX/DXILShaderFlags.h (revision a4b7a2d021ca7371752f0e8180200ffd7b48ca70)
1 //===- DXILShaderFlags.h - DXIL Shader Flags helper objects ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file This file contains helper objects and APIs for working with DXIL
10 ///       Shader Flags.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
15 #define LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
16 
17 #include "llvm/IR/Function.h"
18 #include "llvm/IR/PassManager.h"
19 #include "llvm/Pass.h"
20 #include "llvm/Support/Compiler.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include <cstdint>
24 #include <memory>
25 
26 namespace llvm {
// Forward declarations for types this header only names by reference, so the
// full definitions need not be included here.
// NOTE(review): GlobalVariable is not referenced anywhere in this header --
// presumably kept for files that include it; confirm before removing.
class DXILResourceTypeMap;
class GlobalVariable;
class Module;
30 
31 namespace dxil {
32 
33 struct ComputedShaderFlags {
34 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str)          \
35   bool FlagName : 1;
36 #define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) bool FlagName : 1;
37 #include "llvm/BinaryFormat/DXContainerConstants.def"
38 
39 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str)          \
40   FlagName = false;
41 #define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) FlagName = false;
42   ComputedShaderFlags() {
43 #include "llvm/BinaryFormat/DXContainerConstants.def"
44   }
45 
46   constexpr uint64_t getMask(int Bit) const {
47     return Bit != -1 ? 1ull << Bit : 0;
48   }
49 
50   uint64_t getModuleFlags() const {
51     uint64_t ModuleFlags = 0;
52 #define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str)                         \
53   ModuleFlags |= FlagName ? getMask(DxilModuleBit) : 0ull;
54 #include "llvm/BinaryFormat/DXContainerConstants.def"
55     return ModuleFlags;
56   }
57 
58   operator uint64_t() const {
59     uint64_t FlagValue = getModuleFlags();
60 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str)          \
61   FlagValue |= FlagName ? getMask(DxilModuleBit) : 0ull;
62 #include "llvm/BinaryFormat/DXContainerConstants.def"
63     return FlagValue;
64   }
65 
66   uint64_t getFeatureFlags() const {
67     uint64_t FeatureFlags = 0;
68 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str)          \
69   FeatureFlags |= FlagName ? getMask(FeatureBit) : 0ull;
70 #include "llvm/BinaryFormat/DXContainerConstants.def"
71     return FeatureFlags;
72   }
73 
74   void merge(const ComputedShaderFlags CSF) {
75 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str)          \
76   FlagName |= CSF.FlagName;
77 #define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) FlagName |= CSF.FlagName;
78 #include "llvm/BinaryFormat/DXContainerConstants.def"
79   }
80 
81   void print(raw_ostream &OS = dbgs()) const;
82   LLVM_DUMP_METHOD void dump() const { print(); }
83 };
84 
85 struct ModuleShaderFlags {
86   void initialize(Module &, DXILResourceTypeMap &DRTM);
87   const ComputedShaderFlags &getFunctionFlags(const Function *) const;
88   const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; }
89 
90 private:
91   /// Map of Function-Shader Flag Mask pairs representing properties of each of
92   /// the functions in the module. Shader Flags of each function represent both
93   /// module-level and function-level flags
94   DenseMap<const Function *, ComputedShaderFlags> FunctionFlags;
95   /// Combined Shader Flag Mask of all functions of the module
96   ComputedShaderFlags CombinedSFMask{};
97   void updateFunctionFlags(ComputedShaderFlags &, const Instruction &,
98                            DXILResourceTypeMap &);
99 };
100 
101 class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
102   friend AnalysisInfoMixin<ShaderFlagsAnalysis>;
103   static AnalysisKey Key;
104 
105 public:
106   ShaderFlagsAnalysis() = default;
107 
108   using Result = ModuleShaderFlags;
109 
110   ModuleShaderFlags run(Module &M, ModuleAnalysisManager &AM);
111 };
112 
113 /// Printer pass for ShaderFlagsAnalysis results.
114 class ShaderFlagsAnalysisPrinter
115     : public PassInfoMixin<ShaderFlagsAnalysisPrinter> {
116   raw_ostream &OS;
117 
118 public:
119   explicit ShaderFlagsAnalysisPrinter(raw_ostream &OS) : OS(OS) {}
120   PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
121 };
122 
123 /// Wrapper pass for the legacy pass manager.
124 ///
125 /// This is required because the passes that will depend on this are codegen
126 /// passes which run through the legacy pass manager.
127 class ShaderFlagsAnalysisWrapper : public ModulePass {
128   ModuleShaderFlags MSFI;
129 
130 public:
131   static char ID;
132 
133   ShaderFlagsAnalysisWrapper() : ModulePass(ID) {}
134 
135   const ModuleShaderFlags &getShaderFlags() { return MSFI; }
136 
137   bool runOnModule(Module &M) override;
138 
139   void getAnalysisUsage(AnalysisUsage &AU) const override;
140 };
141 
142 } // namespace dxil
143 } // namespace llvm
144 
145 #endif // LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
146