xref: /llvm-project/clang/lib/CodeGen/CGHLSLRuntime.h (revision 719f0d92538c917306004e541f38c79717d0c07d)
1 //===----- CGHLSLRuntime.h - Interface to HLSL Runtimes -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This provides the class used for HLSL code generation. It declares a
// virtual destructor but no pure virtual members; target-specific intrinsic
// selection (DirectX vs. SPIR-V) is done inside the class based on getArch().
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef LLVM_CLANG_LIB_CODEGEN_CGHLSLRUNTIME_H
16 #define LLVM_CLANG_LIB_CODEGEN_CGHLSLRUNTIME_H
17 
18 #include "llvm/IR/IRBuilder.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/IntrinsicsDirectX.h"
21 #include "llvm/IR/IntrinsicsSPIRV.h"
22 
23 #include "clang/Basic/Builtins.h"
24 #include "clang/Basic/HLSLRuntime.h"
25 
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/Frontend/HLSL/HLSLResource.h"
29 
30 #include <optional>
31 #include <vector>
32 
// Expands to a member function get<FunctionName>Intrinsic() that returns the
// LLVM intrinsic ID matching the current target: dx_<IntrinsicPostfix> when
// getArch() reports DXIL, spv_<IntrinsicPostfix> when it reports SPIR-V. Any
// other architecture reaches llvm_unreachable. The enclosing class must
// provide a getArch() member returning llvm::Triple::ArchType.
#define GENERATE_HLSL_INTRINSIC_FUNCTION(FunctionName, IntrinsicPostfix)       \
  llvm::Intrinsic::ID get##FunctionName##Intrinsic() {                         \
    const llvm::Triple::ArchType TargetArch = getArch();                       \
    if (TargetArch == llvm::Triple::dxil)                                      \
      return llvm::Intrinsic::dx_##IntrinsicPostfix;                           \
    if (TargetArch == llvm::Triple::spirv)                                     \
      return llvm::Intrinsic::spv_##IntrinsicPostfix;                          \
    llvm_unreachable("Intrinsic " #IntrinsicPostfix                            \
                     " not supported by target architecture");                 \
  }
48 
namespace llvm {
// LLVM IR types used only through pointers in this header.
class Function;
class GlobalVariable;
class StructType;
} // namespace llvm
54 
55 namespace clang {
56 class VarDecl;
57 class ParmVarDecl;
58 class HLSLBufferDecl;
59 class HLSLResourceBindingAttr;
60 class Type;
61 class DeclContext;
62 
63 class FunctionDecl;
64 
65 namespace CodeGen {
66 
67 class CodeGenModule;
68 
69 class CGHLSLRuntime {
70 public:
71   //===----------------------------------------------------------------------===//
72   // Start of reserved area for HLSL intrinsic getters.
73   //===----------------------------------------------------------------------===//
74 
75   GENERATE_HLSL_INTRINSIC_FUNCTION(All, all)
76   GENERATE_HLSL_INTRINSIC_FUNCTION(Any, any)
77   GENERATE_HLSL_INTRINSIC_FUNCTION(Cross, cross)
78   GENERATE_HLSL_INTRINSIC_FUNCTION(Degrees, degrees)
79   GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
80   GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
81   GENERATE_HLSL_INTRINSIC_FUNCTION(Normalize, normalize)
82   GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
83   GENERATE_HLSL_INTRINSIC_FUNCTION(Saturate, saturate)
84   GENERATE_HLSL_INTRINSIC_FUNCTION(Sign, sign)
85   GENERATE_HLSL_INTRINSIC_FUNCTION(Step, step)
86   GENERATE_HLSL_INTRINSIC_FUNCTION(Radians, radians)
87   GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
88   GENERATE_HLSL_INTRINSIC_FUNCTION(GroupThreadId, thread_id_in_group)
89   GENERATE_HLSL_INTRINSIC_FUNCTION(GroupId, group_id)
90   GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
91   GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
92   GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
93   GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
94   GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
95   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
96   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
97   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
98   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
99   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
100   GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh)
101   GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitSHigh, firstbitshigh)
102   GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitLow, firstbitlow)
103   GENERATE_HLSL_INTRINSIC_FUNCTION(NClamp, nclamp)
104   GENERATE_HLSL_INTRINSIC_FUNCTION(SClamp, sclamp)
105   GENERATE_HLSL_INTRINSIC_FUNCTION(UClamp, uclamp)
106 
107   GENERATE_HLSL_INTRINSIC_FUNCTION(CreateResourceGetPointer,
108                                    resource_getpointer)
109   GENERATE_HLSL_INTRINSIC_FUNCTION(CreateHandleFromBinding,
110                                    resource_handlefrombinding)
111   GENERATE_HLSL_INTRINSIC_FUNCTION(BufferUpdateCounter, resource_updatecounter)
112   GENERATE_HLSL_INTRINSIC_FUNCTION(GroupMemoryBarrierWithGroupSync,
113                                    group_memory_barrier_with_group_sync)
114 
115   //===----------------------------------------------------------------------===//
116   // End of reserved area for HLSL intrinsic getters.
117   //===----------------------------------------------------------------------===//
118 
119   struct BufferResBinding {
120     // The ID like 2 in register(b2, space1).
121     std::optional<unsigned> Reg;
122     // The Space like 1 is register(b2, space1).
123     // Default value is 0.
124     unsigned Space;
125     BufferResBinding(HLSLResourceBindingAttr *Attr);
126   };
127   struct Buffer {
128     Buffer(const HLSLBufferDecl *D);
129     llvm::StringRef Name;
130     // IsCBuffer - Whether the buffer is a cbuffer (and not a tbuffer).
131     bool IsCBuffer;
132     BufferResBinding Binding;
133     // Global variable and offset for each constant.
134     std::vector<std::pair<llvm::GlobalVariable *, unsigned>> Constants;
135     llvm::StructType *LayoutStruct = nullptr;
136   };
137 
138 protected:
139   CodeGenModule &CGM;
140 
141   llvm::Value *emitInputSemantic(llvm::IRBuilder<> &B, const ParmVarDecl &D,
142                                  llvm::Type *Ty);
143 
144 public:
145   CGHLSLRuntime(CodeGenModule &CGM) : CGM(CGM) {}
146   virtual ~CGHLSLRuntime() {}
147 
148   llvm::Type *convertHLSLSpecificType(const Type *T);
149 
150   void annotateHLSLResource(const VarDecl *D, llvm::GlobalVariable *GV);
151   void generateGlobalCtorDtorCalls();
152 
153   void addBuffer(const HLSLBufferDecl *D);
154   void finishCodeGen();
155 
156   void setHLSLEntryAttributes(const FunctionDecl *FD, llvm::Function *Fn);
157 
158   void emitEntryFunction(const FunctionDecl *FD, llvm::Function *Fn);
159   void setHLSLFunctionAttributes(const FunctionDecl *FD, llvm::Function *Fn);
160   void handleGlobalVarDefinition(const VarDecl *VD, llvm::GlobalVariable *Var);
161 
162   llvm::Instruction *getConvergenceToken(llvm::BasicBlock &BB);
163 
164 private:
165   void addBufferResourceAnnotation(llvm::GlobalVariable *GV,
166                                    llvm::hlsl::ResourceClass RC,
167                                    llvm::hlsl::ResourceKind RK, bool IsROV,
168                                    llvm::hlsl::ElementType ET,
169                                    BufferResBinding &Binding);
170   void addConstant(VarDecl *D, Buffer &CB);
171   void addBufferDecls(const DeclContext *DC, Buffer &CB);
172   llvm::Triple::ArchType getArch();
173   llvm::SmallVector<Buffer> Buffers;
174 };
175 
176 } // namespace CodeGen
177 } // namespace clang
178 
179 #endif
180