1*12c85518Srobert //===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
2*12c85518Srobert //
3*12c85518Srobert // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*12c85518Srobert // See https://llvm.org/LICENSE.txt for license information.
5*12c85518Srobert // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*12c85518Srobert //
7*12c85518Srobert //===----------------------------------------------------------------------===//
8*12c85518Srobert //
9*12c85518Srobert // This provides an abstract class for HLSL code generation. Concrete
10*12c85518Srobert // subclasses of this implement code generation for specific HLSL
11*12c85518Srobert // runtime libraries.
12*12c85518Srobert //
13*12c85518Srobert //===----------------------------------------------------------------------===//
14*12c85518Srobert
15*12c85518Srobert #include "CGHLSLRuntime.h"
16*12c85518Srobert #include "CGDebugInfo.h"
17*12c85518Srobert #include "CodeGenModule.h"
18*12c85518Srobert #include "clang/AST/Decl.h"
19*12c85518Srobert #include "clang/Basic/TargetOptions.h"
20*12c85518Srobert #include "llvm/IR/IntrinsicsDirectX.h"
21*12c85518Srobert #include "llvm/IR/Metadata.h"
22*12c85518Srobert #include "llvm/IR/Module.h"
23*12c85518Srobert #include "llvm/Support/FormatVariadic.h"
24*12c85518Srobert
25*12c85518Srobert using namespace clang;
26*12c85518Srobert using namespace CodeGen;
27*12c85518Srobert using namespace clang::hlsl;
28*12c85518Srobert using namespace llvm;
29*12c85518Srobert
30*12c85518Srobert namespace {
31*12c85518Srobert
addDxilValVersion(StringRef ValVersionStr,llvm::Module & M)32*12c85518Srobert void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
33*12c85518Srobert // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
34*12c85518Srobert // Assume ValVersionStr is legal here.
35*12c85518Srobert VersionTuple Version;
36*12c85518Srobert if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
37*12c85518Srobert Version.getSubminor() || !Version.getMinor()) {
38*12c85518Srobert return;
39*12c85518Srobert }
40*12c85518Srobert
41*12c85518Srobert uint64_t Major = Version.getMajor();
42*12c85518Srobert uint64_t Minor = *Version.getMinor();
43*12c85518Srobert
44*12c85518Srobert auto &Ctx = M.getContext();
45*12c85518Srobert IRBuilder<> B(M.getContext());
46*12c85518Srobert MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),
47*12c85518Srobert ConstantAsMetadata::get(B.getInt32(Minor))});
48*12c85518Srobert StringRef DXILValKey = "dx.valver";
49*12c85518Srobert auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);
50*12c85518Srobert DXILValMD->addOperand(Val);
51*12c85518Srobert }
addDisableOptimizations(llvm::Module & M)52*12c85518Srobert void addDisableOptimizations(llvm::Module &M) {
53*12c85518Srobert StringRef Key = "dx.disable_optimizations";
54*12c85518Srobert M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1);
55*12c85518Srobert }
// A cbuffer is translated into a single global variable that holds all of its
// constants. Expressed in C terms,
//   cbuffer A {
//     float a;
//     float b;
//   }
//   float foo() { return a + b; }
//
// is translated into
//
//   struct A {
//     float a;
//     float b;
//   } cbuffer_A __attribute__((address_space(4)));
//   float foo() { return cbuffer_A.a + cbuffer_A.b; }
//
// (The address space in this illustration is conceptual; replaceBuffer
// currently creates the global without specifying an address space.)
//
// layoutBuffer creates the struct A type.
// replaceBuffer replaces uses of the global variables a and b with
// cbuffer_A.a and cbuffer_A.b.
//
layoutBuffer(CGHLSLRuntime::Buffer & Buf,const DataLayout & DL)76*12c85518Srobert void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) {
77*12c85518Srobert if (Buf.Constants.empty())
78*12c85518Srobert return;
79*12c85518Srobert
80*12c85518Srobert std::vector<llvm::Type *> EltTys;
81*12c85518Srobert for (auto &Const : Buf.Constants) {
82*12c85518Srobert GlobalVariable *GV = Const.first;
83*12c85518Srobert Const.second = EltTys.size();
84*12c85518Srobert llvm::Type *Ty = GV->getValueType();
85*12c85518Srobert EltTys.emplace_back(Ty);
86*12c85518Srobert }
87*12c85518Srobert Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys);
88*12c85518Srobert }
89*12c85518Srobert
replaceBuffer(CGHLSLRuntime::Buffer & Buf)90*12c85518Srobert GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) {
91*12c85518Srobert // Create global variable for CB.
92*12c85518Srobert GlobalVariable *CBGV = new GlobalVariable(
93*12c85518Srobert Buf.LayoutStruct, /*isConstant*/ true,
94*12c85518Srobert GlobalValue::LinkageTypes::ExternalLinkage, nullptr,
95*12c85518Srobert llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),
96*12c85518Srobert GlobalValue::NotThreadLocal);
97*12c85518Srobert
98*12c85518Srobert IRBuilder<> B(CBGV->getContext());
99*12c85518Srobert Value *ZeroIdx = B.getInt32(0);
100*12c85518Srobert // Replace Const use with CB use.
101*12c85518Srobert for (auto &[GV, Offset] : Buf.Constants) {
102*12c85518Srobert Value *GEP =
103*12c85518Srobert B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)});
104*12c85518Srobert
105*12c85518Srobert assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() &&
106*12c85518Srobert "constant type mismatch");
107*12c85518Srobert
108*12c85518Srobert // Replace.
109*12c85518Srobert GV->replaceAllUsesWith(GEP);
110*12c85518Srobert // Erase GV.
111*12c85518Srobert GV->removeDeadConstantUsers();
112*12c85518Srobert GV->eraseFromParent();
113*12c85518Srobert }
114*12c85518Srobert return CBGV;
115*12c85518Srobert }
116*12c85518Srobert
117*12c85518Srobert } // namespace
118*12c85518Srobert
addConstant(VarDecl * D,Buffer & CB)119*12c85518Srobert void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {
120*12c85518Srobert if (D->getStorageClass() == SC_Static) {
121*12c85518Srobert // For static inside cbuffer, take as global static.
122*12c85518Srobert // Don't add to cbuffer.
123*12c85518Srobert CGM.EmitGlobal(D);
124*12c85518Srobert return;
125*12c85518Srobert }
126*12c85518Srobert
127*12c85518Srobert auto *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D));
128*12c85518Srobert // Add debug info for constVal.
129*12c85518Srobert if (CGDebugInfo *DI = CGM.getModuleDebugInfo())
130*12c85518Srobert if (CGM.getCodeGenOpts().getDebugInfo() >=
131*12c85518Srobert codegenoptions::DebugInfoKind::LimitedDebugInfo)
132*12c85518Srobert DI->EmitGlobalVariable(cast<GlobalVariable>(GV), D);
133*12c85518Srobert
134*12c85518Srobert // FIXME: support packoffset.
135*12c85518Srobert // See https://github.com/llvm/llvm-project/issues/57914.
136*12c85518Srobert uint32_t Offset = 0;
137*12c85518Srobert bool HasUserOffset = false;
138*12c85518Srobert
139*12c85518Srobert unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX;
140*12c85518Srobert CB.Constants.emplace_back(std::make_pair(GV, LowerBound));
141*12c85518Srobert }
142*12c85518Srobert
addBufferDecls(const DeclContext * DC,Buffer & CB)143*12c85518Srobert void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {
144*12c85518Srobert for (Decl *it : DC->decls()) {
145*12c85518Srobert if (auto *ConstDecl = dyn_cast<VarDecl>(it)) {
146*12c85518Srobert addConstant(ConstDecl, CB);
147*12c85518Srobert } else if (isa<CXXRecordDecl, EmptyDecl>(it)) {
148*12c85518Srobert // Nothing to do for this declaration.
149*12c85518Srobert } else if (isa<FunctionDecl>(it)) {
150*12c85518Srobert // A function within an cbuffer is effectively a top-level function,
151*12c85518Srobert // as it only refers to globally scoped declarations.
152*12c85518Srobert CGM.EmitTopLevelDecl(it);
153*12c85518Srobert }
154*12c85518Srobert }
155*12c85518Srobert }
156*12c85518Srobert
addBuffer(const HLSLBufferDecl * D)157*12c85518Srobert void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) {
158*12c85518Srobert Buffers.emplace_back(Buffer(D));
159*12c85518Srobert addBufferDecls(D, Buffers.back());
160*12c85518Srobert }
161*12c85518Srobert
finishCodeGen()162*12c85518Srobert void CGHLSLRuntime::finishCodeGen() {
163*12c85518Srobert auto &TargetOpts = CGM.getTarget().getTargetOpts();
164*12c85518Srobert llvm::Module &M = CGM.getModule();
165*12c85518Srobert Triple T(M.getTargetTriple());
166*12c85518Srobert if (T.getArch() == Triple::ArchType::dxil)
167*12c85518Srobert addDxilValVersion(TargetOpts.DxilValidatorVersion, M);
168*12c85518Srobert
169*12c85518Srobert generateGlobalCtorDtorCalls();
170*12c85518Srobert if (CGM.getCodeGenOpts().OptimizationLevel == 0)
171*12c85518Srobert addDisableOptimizations(M);
172*12c85518Srobert
173*12c85518Srobert const DataLayout &DL = M.getDataLayout();
174*12c85518Srobert
175*12c85518Srobert for (auto &Buf : Buffers) {
176*12c85518Srobert layoutBuffer(Buf, DL);
177*12c85518Srobert GlobalVariable *GV = replaceBuffer(Buf);
178*12c85518Srobert M.getGlobalList().push_back(GV);
179*12c85518Srobert llvm::hlsl::ResourceClass RC = Buf.IsCBuffer
180*12c85518Srobert ? llvm::hlsl::ResourceClass::CBuffer
181*12c85518Srobert : llvm::hlsl::ResourceClass::SRV;
182*12c85518Srobert llvm::hlsl::ResourceKind RK = Buf.IsCBuffer
183*12c85518Srobert ? llvm::hlsl::ResourceKind::CBuffer
184*12c85518Srobert : llvm::hlsl::ResourceKind::TBuffer;
185*12c85518Srobert std::string TyName =
186*12c85518Srobert Buf.Name.str() + (Buf.IsCBuffer ? ".cb." : ".tb.") + "ty";
187*12c85518Srobert addBufferResourceAnnotation(GV, TyName, RC, RK, Buf.Binding);
188*12c85518Srobert }
189*12c85518Srobert }
190*12c85518Srobert
Buffer(const HLSLBufferDecl * D)191*12c85518Srobert CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D)
192*12c85518Srobert : Name(D->getName()), IsCBuffer(D->isCBuffer()),
193*12c85518Srobert Binding(D->getAttr<HLSLResourceBindingAttr>()) {}
194*12c85518Srobert
addBufferResourceAnnotation(llvm::GlobalVariable * GV,llvm::StringRef TyName,llvm::hlsl::ResourceClass RC,llvm::hlsl::ResourceKind RK,BufferResBinding & Binding)195*12c85518Srobert void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
196*12c85518Srobert llvm::StringRef TyName,
197*12c85518Srobert llvm::hlsl::ResourceClass RC,
198*12c85518Srobert llvm::hlsl::ResourceKind RK,
199*12c85518Srobert BufferResBinding &Binding) {
200*12c85518Srobert llvm::Module &M = CGM.getModule();
201*12c85518Srobert
202*12c85518Srobert NamedMDNode *ResourceMD = nullptr;
203*12c85518Srobert switch (RC) {
204*12c85518Srobert case llvm::hlsl::ResourceClass::UAV:
205*12c85518Srobert ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs");
206*12c85518Srobert break;
207*12c85518Srobert case llvm::hlsl::ResourceClass::SRV:
208*12c85518Srobert ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs");
209*12c85518Srobert break;
210*12c85518Srobert case llvm::hlsl::ResourceClass::CBuffer:
211*12c85518Srobert ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs");
212*12c85518Srobert break;
213*12c85518Srobert default:
214*12c85518Srobert assert(false && "Unsupported buffer type!");
215*12c85518Srobert return;
216*12c85518Srobert }
217*12c85518Srobert
218*12c85518Srobert assert(ResourceMD != nullptr &&
219*12c85518Srobert "ResourceMD must have been set by the switch above.");
220*12c85518Srobert
221*12c85518Srobert llvm::hlsl::FrontendResource Res(
222*12c85518Srobert GV, TyName, RK, Binding.Reg.value_or(UINT_MAX), Binding.Space);
223*12c85518Srobert ResourceMD->addOperand(Res.getMetadata());
224*12c85518Srobert }
225*12c85518Srobert
226*12c85518Srobert static llvm::hlsl::ResourceKind
castResourceShapeToResourceKind(HLSLResourceAttr::ResourceKind RK)227*12c85518Srobert castResourceShapeToResourceKind(HLSLResourceAttr::ResourceKind RK) {
228*12c85518Srobert switch (RK) {
229*12c85518Srobert case HLSLResourceAttr::ResourceKind::Texture1D:
230*12c85518Srobert return llvm::hlsl::ResourceKind::Texture1D;
231*12c85518Srobert case HLSLResourceAttr::ResourceKind::Texture2D:
232*12c85518Srobert return llvm::hlsl::ResourceKind::Texture2D;
233*12c85518Srobert case HLSLResourceAttr::ResourceKind::Texture2DMS:
234*12c85518Srobert return llvm::hlsl::ResourceKind::Texture2DMS;
235*12c85518Srobert case HLSLResourceAttr::ResourceKind::Texture3D:
236*12c85518Srobert return llvm::hlsl::ResourceKind::Texture3D;
237*12c85518Srobert case HLSLResourceAttr::ResourceKind::TextureCube:
238*12c85518Srobert return llvm::hlsl::ResourceKind::TextureCube;
239*12c85518Srobert case HLSLResourceAttr::ResourceKind::Texture1DArray:
240*12c85518Srobert return llvm::hlsl::ResourceKind::Texture1DArray;
241*12c85518Srobert case HLSLResourceAttr::ResourceKind::Texture2DArray:
242*12c85518Srobert return llvm::hlsl::ResourceKind::Texture2DArray;
243*12c85518Srobert case HLSLResourceAttr::ResourceKind::Texture2DMSArray:
244*12c85518Srobert return llvm::hlsl::ResourceKind::Texture2DMSArray;
245*12c85518Srobert case HLSLResourceAttr::ResourceKind::TextureCubeArray:
246*12c85518Srobert return llvm::hlsl::ResourceKind::TextureCubeArray;
247*12c85518Srobert case HLSLResourceAttr::ResourceKind::TypedBuffer:
248*12c85518Srobert return llvm::hlsl::ResourceKind::TypedBuffer;
249*12c85518Srobert case HLSLResourceAttr::ResourceKind::RawBuffer:
250*12c85518Srobert return llvm::hlsl::ResourceKind::RawBuffer;
251*12c85518Srobert case HLSLResourceAttr::ResourceKind::StructuredBuffer:
252*12c85518Srobert return llvm::hlsl::ResourceKind::StructuredBuffer;
253*12c85518Srobert case HLSLResourceAttr::ResourceKind::CBufferKind:
254*12c85518Srobert return llvm::hlsl::ResourceKind::CBuffer;
255*12c85518Srobert case HLSLResourceAttr::ResourceKind::SamplerKind:
256*12c85518Srobert return llvm::hlsl::ResourceKind::Sampler;
257*12c85518Srobert case HLSLResourceAttr::ResourceKind::TBuffer:
258*12c85518Srobert return llvm::hlsl::ResourceKind::TBuffer;
259*12c85518Srobert case HLSLResourceAttr::ResourceKind::RTAccelerationStructure:
260*12c85518Srobert return llvm::hlsl::ResourceKind::RTAccelerationStructure;
261*12c85518Srobert case HLSLResourceAttr::ResourceKind::FeedbackTexture2D:
262*12c85518Srobert return llvm::hlsl::ResourceKind::FeedbackTexture2D;
263*12c85518Srobert case HLSLResourceAttr::ResourceKind::FeedbackTexture2DArray:
264*12c85518Srobert return llvm::hlsl::ResourceKind::FeedbackTexture2DArray;
265*12c85518Srobert }
266*12c85518Srobert // Make sure to update HLSLResourceAttr::ResourceKind when add new Kind to
267*12c85518Srobert // hlsl::ResourceKind. Assume FeedbackTexture2DArray is the last enum for
268*12c85518Srobert // HLSLResourceAttr::ResourceKind.
269*12c85518Srobert static_assert(
270*12c85518Srobert static_cast<uint32_t>(
271*12c85518Srobert HLSLResourceAttr::ResourceKind::FeedbackTexture2DArray) ==
272*12c85518Srobert (static_cast<uint32_t>(llvm::hlsl::ResourceKind::NumEntries) - 2));
273*12c85518Srobert llvm_unreachable("all switch cases should be covered");
274*12c85518Srobert }
275*12c85518Srobert
annotateHLSLResource(const VarDecl * D,GlobalVariable * GV)276*12c85518Srobert void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
277*12c85518Srobert const Type *Ty = D->getType()->getPointeeOrArrayElementType();
278*12c85518Srobert if (!Ty)
279*12c85518Srobert return;
280*12c85518Srobert const auto *RD = Ty->getAsCXXRecordDecl();
281*12c85518Srobert if (!RD)
282*12c85518Srobert return;
283*12c85518Srobert const auto *Attr = RD->getAttr<HLSLResourceAttr>();
284*12c85518Srobert if (!Attr)
285*12c85518Srobert return;
286*12c85518Srobert
287*12c85518Srobert HLSLResourceAttr::ResourceClass RC = Attr->getResourceType();
288*12c85518Srobert llvm::hlsl::ResourceKind RK =
289*12c85518Srobert castResourceShapeToResourceKind(Attr->getResourceShape());
290*12c85518Srobert
291*12c85518Srobert QualType QT(Ty, 0);
292*12c85518Srobert BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>());
293*12c85518Srobert addBufferResourceAnnotation(GV, QT.getAsString(),
294*12c85518Srobert static_cast<llvm::hlsl::ResourceClass>(RC), RK,
295*12c85518Srobert Binding);
296*12c85518Srobert }
297*12c85518Srobert
BufferResBinding(HLSLResourceBindingAttr * Binding)298*12c85518Srobert CGHLSLRuntime::BufferResBinding::BufferResBinding(
299*12c85518Srobert HLSLResourceBindingAttr *Binding) {
300*12c85518Srobert if (Binding) {
301*12c85518Srobert llvm::APInt RegInt(64, 0);
302*12c85518Srobert Binding->getSlot().substr(1).getAsInteger(10, RegInt);
303*12c85518Srobert Reg = RegInt.getLimitedValue();
304*12c85518Srobert llvm::APInt SpaceInt(64, 0);
305*12c85518Srobert Binding->getSpace().substr(5).getAsInteger(10, SpaceInt);
306*12c85518Srobert Space = SpaceInt.getLimitedValue();
307*12c85518Srobert } else {
308*12c85518Srobert Space = 0;
309*12c85518Srobert }
310*12c85518Srobert }
311*12c85518Srobert
setHLSLEntryAttributes(const FunctionDecl * FD,llvm::Function * Fn)312*12c85518Srobert void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
313*12c85518Srobert const FunctionDecl *FD, llvm::Function *Fn) {
314*12c85518Srobert const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
315*12c85518Srobert assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
316*12c85518Srobert const StringRef ShaderAttrKindStr = "hlsl.shader";
317*12c85518Srobert Fn->addFnAttr(ShaderAttrKindStr,
318*12c85518Srobert ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
319*12c85518Srobert if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
320*12c85518Srobert const StringRef NumThreadsKindStr = "hlsl.numthreads";
321*12c85518Srobert std::string NumThreadsStr =
322*12c85518Srobert formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
323*12c85518Srobert NumThreadsAttr->getZ());
324*12c85518Srobert Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
325*12c85518Srobert }
326*12c85518Srobert }
327*12c85518Srobert
buildVectorInput(IRBuilder<> & B,Function * F,llvm::Type * Ty)328*12c85518Srobert static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
329*12c85518Srobert if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) {
330*12c85518Srobert Value *Result = PoisonValue::get(Ty);
331*12c85518Srobert for (unsigned I = 0; I < VT->getNumElements(); ++I) {
332*12c85518Srobert Value *Elt = B.CreateCall(F, {B.getInt32(I)});
333*12c85518Srobert Result = B.CreateInsertElement(Result, Elt, I);
334*12c85518Srobert }
335*12c85518Srobert return Result;
336*12c85518Srobert }
337*12c85518Srobert return B.CreateCall(F, {B.getInt32(0)});
338*12c85518Srobert }
339*12c85518Srobert
emitInputSemantic(IRBuilder<> & B,const ParmVarDecl & D,llvm::Type * Ty)340*12c85518Srobert llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
341*12c85518Srobert const ParmVarDecl &D,
342*12c85518Srobert llvm::Type *Ty) {
343*12c85518Srobert assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
344*12c85518Srobert if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
345*12c85518Srobert llvm::Function *DxGroupIndex =
346*12c85518Srobert CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);
347*12c85518Srobert return B.CreateCall(FunctionCallee(DxGroupIndex));
348*12c85518Srobert }
349*12c85518Srobert if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
350*12c85518Srobert llvm::Function *DxThreadID = CGM.getIntrinsic(Intrinsic::dx_thread_id);
351*12c85518Srobert return buildVectorInput(B, DxThreadID, Ty);
352*12c85518Srobert }
353*12c85518Srobert assert(false && "Unhandled parameter attribute");
354*12c85518Srobert return nullptr;
355*12c85518Srobert }
356*12c85518Srobert
emitEntryFunction(const FunctionDecl * FD,llvm::Function * Fn)357*12c85518Srobert void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
358*12c85518Srobert llvm::Function *Fn) {
359*12c85518Srobert llvm::Module &M = CGM.getModule();
360*12c85518Srobert llvm::LLVMContext &Ctx = M.getContext();
361*12c85518Srobert auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
362*12c85518Srobert Function *EntryFn =
363*12c85518Srobert Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);
364*12c85518Srobert
365*12c85518Srobert // Copy function attributes over, we have no argument or return attributes
366*12c85518Srobert // that can be valid on the real entry.
367*12c85518Srobert AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,
368*12c85518Srobert Fn->getAttributes().getFnAttrs());
369*12c85518Srobert EntryFn->setAttributes(NewAttrs);
370*12c85518Srobert setHLSLEntryAttributes(FD, EntryFn);
371*12c85518Srobert
372*12c85518Srobert // Set the called function as internal linkage.
373*12c85518Srobert Fn->setLinkage(GlobalValue::InternalLinkage);
374*12c85518Srobert
375*12c85518Srobert BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
376*12c85518Srobert IRBuilder<> B(BB);
377*12c85518Srobert llvm::SmallVector<Value *> Args;
378*12c85518Srobert // FIXME: support struct parameters where semantics are on members.
379*12c85518Srobert // See: https://github.com/llvm/llvm-project/issues/57874
380*12c85518Srobert unsigned SRetOffset = 0;
381*12c85518Srobert for (const auto &Param : Fn->args()) {
382*12c85518Srobert if (Param.hasStructRetAttr()) {
383*12c85518Srobert // FIXME: support output.
384*12c85518Srobert // See: https://github.com/llvm/llvm-project/issues/57874
385*12c85518Srobert SRetOffset = 1;
386*12c85518Srobert Args.emplace_back(PoisonValue::get(Param.getType()));
387*12c85518Srobert continue;
388*12c85518Srobert }
389*12c85518Srobert const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
390*12c85518Srobert Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
391*12c85518Srobert }
392*12c85518Srobert
393*12c85518Srobert CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
394*12c85518Srobert (void)CI;
395*12c85518Srobert // FIXME: Handle codegen for return type semantics.
396*12c85518Srobert // See: https://github.com/llvm/llvm-project/issues/57875
397*12c85518Srobert B.CreateRetVoid();
398*12c85518Srobert }
399*12c85518Srobert
gatherFunctions(SmallVectorImpl<Function * > & Fns,llvm::Module & M,bool CtorOrDtor)400*12c85518Srobert static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M,
401*12c85518Srobert bool CtorOrDtor) {
402*12c85518Srobert const auto *GV =
403*12c85518Srobert M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
404*12c85518Srobert if (!GV)
405*12c85518Srobert return;
406*12c85518Srobert const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer());
407*12c85518Srobert if (!CA)
408*12c85518Srobert return;
409*12c85518Srobert // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
410*12c85518Srobert // HLSL neither supports priorities or COMDat values, so we will check those
411*12c85518Srobert // in an assert but not handle them.
412*12c85518Srobert
413*12c85518Srobert llvm::SmallVector<Function *> CtorFns;
414*12c85518Srobert for (const auto &Ctor : CA->operands()) {
415*12c85518Srobert if (isa<ConstantAggregateZero>(Ctor))
416*12c85518Srobert continue;
417*12c85518Srobert ConstantStruct *CS = cast<ConstantStruct>(Ctor);
418*12c85518Srobert
419*12c85518Srobert assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&
420*12c85518Srobert "HLSL doesn't support setting priority for global ctors.");
421*12c85518Srobert assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&
422*12c85518Srobert "HLSL doesn't support COMDat for global ctors.");
423*12c85518Srobert Fns.push_back(cast<Function>(CS->getOperand(1)));
424*12c85518Srobert }
425*12c85518Srobert }
426*12c85518Srobert
generateGlobalCtorDtorCalls()427*12c85518Srobert void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
428*12c85518Srobert llvm::Module &M = CGM.getModule();
429*12c85518Srobert SmallVector<Function *> CtorFns;
430*12c85518Srobert SmallVector<Function *> DtorFns;
431*12c85518Srobert gatherFunctions(CtorFns, M, true);
432*12c85518Srobert gatherFunctions(DtorFns, M, false);
433*12c85518Srobert
434*12c85518Srobert // Insert a call to the global constructor at the beginning of the entry block
435*12c85518Srobert // to externally exported functions. This is a bit of a hack, but HLSL allows
436*12c85518Srobert // global constructors, but doesn't support driver initialization of globals.
437*12c85518Srobert for (auto &F : M.functions()) {
438*12c85518Srobert if (!F.hasFnAttribute("hlsl.shader"))
439*12c85518Srobert continue;
440*12c85518Srobert IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin());
441*12c85518Srobert for (auto *Fn : CtorFns)
442*12c85518Srobert B.CreateCall(FunctionCallee(Fn));
443*12c85518Srobert
444*12c85518Srobert // Insert global dtors before the terminator of the last instruction
445*12c85518Srobert B.SetInsertPoint(F.back().getTerminator());
446*12c85518Srobert for (auto *Fn : DtorFns)
447*12c85518Srobert B.CreateCall(FunctionCallee(Fn));
448*12c85518Srobert }
449*12c85518Srobert
450*12c85518Srobert // No need to keep global ctors/dtors for non-lib profile after call to
451*12c85518Srobert // ctors/dtors added for entry.
452*12c85518Srobert Triple T(M.getTargetTriple());
453*12c85518Srobert if (T.getEnvironment() != Triple::EnvironmentType::Library) {
454*12c85518Srobert if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))
455*12c85518Srobert GV->eraseFromParent();
456*12c85518Srobert if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))
457*12c85518Srobert GV->eraseFromParent();
458*12c85518Srobert }
459*12c85518Srobert }
460