1 //===- DXILShaderFlags.h - DXIL Shader Flags helper objects ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains helper objects and APIs for working with DXIL 10 /// Shader Flags. 11 /// 12 //===----------------------------------------------------------------------===// 13 14 #ifndef LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H 15 #define LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H 16 17 #include "llvm/IR/Function.h" 18 #include "llvm/IR/PassManager.h" 19 #include "llvm/Pass.h" 20 #include "llvm/Support/Compiler.h" 21 #include "llvm/Support/Debug.h" 22 #include "llvm/Support/raw_ostream.h" 23 #include <cstdint> 24 #include <memory> 25 26 namespace llvm { 27 class Module; 28 class GlobalVariable; 29 class DXILResourceTypeMap; 30 31 namespace dxil { 32 33 struct ComputedShaderFlags { 34 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \ 35 bool FlagName : 1; 36 #define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) bool FlagName : 1; 37 #include "llvm/BinaryFormat/DXContainerConstants.def" 38 39 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \ 40 FlagName = false; 41 #define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) FlagName = false; 42 ComputedShaderFlags() { 43 #include "llvm/BinaryFormat/DXContainerConstants.def" 44 } 45 46 constexpr uint64_t getMask(int Bit) const { 47 return Bit != -1 ? 1ull << Bit : 0; 48 } 49 50 uint64_t getModuleFlags() const { 51 uint64_t ModuleFlags = 0; 52 #define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \ 53 ModuleFlags |= FlagName ? getMask(DxilModuleBit) : 0ull; 54 #include "llvm/BinaryFormat/DXContainerConstants.def" 55 return ModuleFlags; 56 } 57 58 operator uint64_t() const { 59 uint64_t FlagValue = getModuleFlags(); 60 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \ 61 FlagValue |= FlagName ? getMask(DxilModuleBit) : 0ull; 62 #include "llvm/BinaryFormat/DXContainerConstants.def" 63 return FlagValue; 64 } 65 66 uint64_t getFeatureFlags() const { 67 uint64_t FeatureFlags = 0; 68 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \ 69 FeatureFlags |= FlagName ? getMask(FeatureBit) : 0ull; 70 #include "llvm/BinaryFormat/DXContainerConstants.def" 71 return FeatureFlags; 72 } 73 74 void merge(const ComputedShaderFlags CSF) { 75 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \ 76 FlagName |= CSF.FlagName; 77 #define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) FlagName |= CSF.FlagName; 78 #include "llvm/BinaryFormat/DXContainerConstants.def" 79 } 80 81 void print(raw_ostream &OS = dbgs()) const; 82 LLVM_DUMP_METHOD void dump() const { print(); } 83 }; 84 85 struct ModuleShaderFlags { 86 void initialize(Module &, DXILResourceTypeMap &DRTM); 87 const ComputedShaderFlags &getFunctionFlags(const Function *) const; 88 const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; } 89 90 private: 91 /// Map of Function-Shader Flag Mask pairs representing properties of each of 92 /// the functions in the module. Shader Flags of each function represent both 93 /// module-level and function-level flags 94 DenseMap<const Function *, ComputedShaderFlags> FunctionFlags; 95 /// Combined Shader Flag Mask of all functions of the module 96 ComputedShaderFlags CombinedSFMask{}; 97 void updateFunctionFlags(ComputedShaderFlags &, const Instruction &, 98 DXILResourceTypeMap &); 99 }; 100 101 class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> { 102 friend AnalysisInfoMixin<ShaderFlagsAnalysis>; 103 static AnalysisKey Key; 104 105 public: 106 ShaderFlagsAnalysis() = default; 107 108 using Result = ModuleShaderFlags; 109 110 ModuleShaderFlags run(Module &M, ModuleAnalysisManager &AM); 111 }; 112 113 /// Printer pass for ShaderFlagsAnalysis results. 114 class ShaderFlagsAnalysisPrinter 115 : public PassInfoMixin<ShaderFlagsAnalysisPrinter> { 116 raw_ostream &OS; 117 118 public: 119 explicit ShaderFlagsAnalysisPrinter(raw_ostream &OS) : OS(OS) {} 120 PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM); 121 }; 122 123 /// Wrapper pass for the legacy pass manager. 124 /// 125 /// This is required because the passes that will depend on this are codegen 126 /// passes which run through the legacy pass manager. 127 class ShaderFlagsAnalysisWrapper : public ModulePass { 128 ModuleShaderFlags MSFI; 129 130 public: 131 static char ID; 132 133 ShaderFlagsAnalysisWrapper() : ModulePass(ID) {} 134 135 const ModuleShaderFlags &getShaderFlags() { return MSFI; } 136 137 bool runOnModule(Module &M) override; 138 139 void getAnalysisUsage(AnalysisUsage &AU) const override; 140 }; 141 142 } // namespace dxil 143 } // namespace llvm 144 145 #endif // LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H 146