xref: /llvm-project/clang/lib/Sema/SemaHLSL.cpp (revision aab25f20f6c06bab7aac6fb83d54705ec4cdfadd)
1 //===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This implements Semantic Analysis for HLSL constructs.
9 //===----------------------------------------------------------------------===//
10 
11 #include "clang/Sema/SemaHLSL.h"
12 #include "clang/AST/ASTContext.h"
13 #include "clang/AST/Attr.h"
14 #include "clang/AST/Attrs.inc"
15 #include "clang/AST/Decl.h"
16 #include "clang/AST/DeclBase.h"
17 #include "clang/AST/DeclCXX.h"
18 #include "clang/AST/DeclarationName.h"
19 #include "clang/AST/DynamicRecursiveASTVisitor.h"
20 #include "clang/AST/Expr.h"
21 #include "clang/AST/Type.h"
22 #include "clang/AST/TypeLoc.h"
23 #include "clang/Basic/Builtins.h"
24 #include "clang/Basic/DiagnosticSema.h"
25 #include "clang/Basic/IdentifierTable.h"
26 #include "clang/Basic/LLVM.h"
27 #include "clang/Basic/SourceLocation.h"
28 #include "clang/Basic/Specifiers.h"
29 #include "clang/Basic/TargetInfo.h"
30 #include "clang/Sema/Initialization.h"
31 #include "clang/Sema/ParsedAttr.h"
32 #include "clang/Sema/Sema.h"
33 #include "clang/Sema/Template.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/SmallVector.h"
36 #include "llvm/ADT/StringExtras.h"
37 #include "llvm/ADT/StringRef.h"
38 #include "llvm/ADT/Twine.h"
39 #include "llvm/Support/Casting.h"
40 #include "llvm/Support/DXILABI.h"
41 #include "llvm/Support/ErrorHandling.h"
42 #include "llvm/TargetParser/Triple.h"
43 #include <cstddef>
44 #include <iterator>
45 #include <utility>
46 
47 using namespace clang;
48 using RegisterType = HLSLResourceBindingAttr::RegisterType;
49 
50 static CXXRecordDecl *createHostLayoutStruct(Sema &S,
51                                              CXXRecordDecl *StructDecl);
52 
53 static RegisterType getRegisterType(ResourceClass RC) {
54   switch (RC) {
55   case ResourceClass::SRV:
56     return RegisterType::SRV;
57   case ResourceClass::UAV:
58     return RegisterType::UAV;
59   case ResourceClass::CBuffer:
60     return RegisterType::CBuffer;
61   case ResourceClass::Sampler:
62     return RegisterType::Sampler;
63   }
64   llvm_unreachable("unexpected ResourceClass value");
65 }
66 
67 // Converts the first letter of string Slot to RegisterType.
68 // Returns false if the letter does not correspond to a valid register type.
69 static bool convertToRegisterType(StringRef Slot, RegisterType *RT) {
70   assert(RT != nullptr);
71   switch (Slot[0]) {
72   case 't':
73   case 'T':
74     *RT = RegisterType::SRV;
75     return true;
76   case 'u':
77   case 'U':
78     *RT = RegisterType::UAV;
79     return true;
80   case 'b':
81   case 'B':
82     *RT = RegisterType::CBuffer;
83     return true;
84   case 's':
85   case 'S':
86     *RT = RegisterType::Sampler;
87     return true;
88   case 'c':
89   case 'C':
90     *RT = RegisterType::C;
91     return true;
92   case 'i':
93   case 'I':
94     *RT = RegisterType::I;
95     return true;
96   default:
97     return false;
98   }
99 }
100 
101 static ResourceClass getResourceClass(RegisterType RT) {
102   switch (RT) {
103   case RegisterType::SRV:
104     return ResourceClass::SRV;
105   case RegisterType::UAV:
106     return ResourceClass::UAV;
107   case RegisterType::CBuffer:
108     return ResourceClass::CBuffer;
109   case RegisterType::Sampler:
110     return ResourceClass::Sampler;
111   case RegisterType::C:
112   case RegisterType::I:
113     // Deliberately falling through to the unreachable below.
114     break;
115   }
116   llvm_unreachable("unexpected RegisterType value");
117 }
118 
119 DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
120                                                       ResourceClass ResClass) {
121   assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
122          "DeclBindingInfo already added");
123   assert(!hasBindingInfoForDecl(VD) || BindingsList.back().Decl == VD);
124   // VarDecl may have multiple entries for different resource classes.
125   // DeclToBindingListIndex stores the index of the first binding we saw
126   // for this decl. If there are any additional ones then that index
127   // shouldn't be updated.
128   DeclToBindingListIndex.try_emplace(VD, BindingsList.size());
129   return &BindingsList.emplace_back(VD, ResClass);
130 }
131 
132 DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD,
133                                                       ResourceClass ResClass) {
134   auto Entry = DeclToBindingListIndex.find(VD);
135   if (Entry != DeclToBindingListIndex.end()) {
136     for (unsigned Index = Entry->getSecond();
137          Index < BindingsList.size() && BindingsList[Index].Decl == VD;
138          ++Index) {
139       if (BindingsList[Index].ResClass == ResClass)
140         return &BindingsList[Index];
141     }
142   }
143   return nullptr;
144 }
145 
146 bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) const {
147   return DeclToBindingListIndex.contains(VD);
148 }
149 
150 SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
151 
152 Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
153                                  SourceLocation KwLoc, IdentifierInfo *Ident,
154                                  SourceLocation IdentLoc,
155                                  SourceLocation LBrace) {
156   // For anonymous namespace, take the location of the left brace.
157   DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
158   HLSLBufferDecl *Result = HLSLBufferDecl::Create(
159       getASTContext(), LexicalParent, CBuffer, KwLoc, Ident, IdentLoc, LBrace);
160 
161   // if CBuffer is false, then it's a TBuffer
162   auto RC = CBuffer ? llvm::hlsl::ResourceClass::CBuffer
163                     : llvm::hlsl::ResourceClass::SRV;
164   auto RK = CBuffer ? llvm::hlsl::ResourceKind::CBuffer
165                     : llvm::hlsl::ResourceKind::TBuffer;
166   Result->addAttr(HLSLResourceClassAttr::CreateImplicit(getASTContext(), RC));
167   Result->addAttr(HLSLResourceAttr::CreateImplicit(getASTContext(), RK));
168 
169   SemaRef.PushOnScopeChains(Result, BufferScope);
170   SemaRef.PushDeclContext(BufferScope, Result);
171 
172   return Result;
173 }
174 
175 // Calculate the size of a legacy cbuffer type in bytes based on
176 // https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules
177 static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
178                                            QualType T) {
179   unsigned Size = 0;
180   constexpr unsigned CBufferAlign = 16;
181   if (const RecordType *RT = T->getAs<RecordType>()) {
182     const RecordDecl *RD = RT->getDecl();
183     for (const FieldDecl *Field : RD->fields()) {
184       QualType Ty = Field->getType();
185       unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty);
186       // FIXME: This is not the correct alignment, it does not work for 16-bit
187       // types. See llvm/llvm-project#119641.
188       unsigned FieldAlign = 4;
189       if (Ty->isAggregateType())
190         FieldAlign = CBufferAlign;
191       Size = llvm::alignTo(Size, FieldAlign);
192       Size += FieldSize;
193     }
194   } else if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
195     if (unsigned ElementCount = AT->getSize().getZExtValue()) {
196       unsigned ElementSize =
197           calculateLegacyCbufferSize(Context, AT->getElementType());
198       unsigned AlignedElementSize = llvm::alignTo(ElementSize, CBufferAlign);
199       Size = AlignedElementSize * (ElementCount - 1) + ElementSize;
200     }
201   } else if (const VectorType *VT = T->getAs<VectorType>()) {
202     unsigned ElementCount = VT->getNumElements();
203     unsigned ElementSize =
204         calculateLegacyCbufferSize(Context, VT->getElementType());
205     Size = ElementSize * ElementCount;
206   } else {
207     Size = Context.getTypeSize(T) / 8;
208   }
209   return Size;
210 }
211 
212 // Validate packoffset:
213 // - if packoffset it used it must be set on all declarations inside the buffer
214 // - packoffset ranges must not overlap
215 static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) {
216   llvm::SmallVector<std::pair<VarDecl *, HLSLPackOffsetAttr *>> PackOffsetVec;
217 
218   // Make sure the packoffset annotations are either on all declarations
219   // or on none.
220   bool HasPackOffset = false;
221   bool HasNonPackOffset = false;
222   for (auto *Field : BufDecl->decls()) {
223     VarDecl *Var = dyn_cast<VarDecl>(Field);
224     if (!Var)
225       continue;
226     if (Field->hasAttr<HLSLPackOffsetAttr>()) {
227       PackOffsetVec.emplace_back(Var, Field->getAttr<HLSLPackOffsetAttr>());
228       HasPackOffset = true;
229     } else {
230       HasNonPackOffset = true;
231     }
232   }
233 
234   if (!HasPackOffset)
235     return;
236 
237   if (HasNonPackOffset)
238     S.Diag(BufDecl->getLocation(), diag::warn_hlsl_packoffset_mix);
239 
240   // Make sure there is no overlap in packoffset - sort PackOffsetVec by offset
241   // and compare adjacent values.
242   ASTContext &Context = S.getASTContext();
243   std::sort(PackOffsetVec.begin(), PackOffsetVec.end(),
244             [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
245                const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
246               return LHS.second->getOffsetInBytes() <
247                      RHS.second->getOffsetInBytes();
248             });
249   for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
250     VarDecl *Var = PackOffsetVec[i].first;
251     HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
252     unsigned Size = calculateLegacyCbufferSize(Context, Var->getType());
253     unsigned Begin = Attr->getOffsetInBytes();
254     unsigned End = Begin + Size;
255     unsigned NextBegin = PackOffsetVec[i + 1].second->getOffsetInBytes();
256     if (End > NextBegin) {
257       VarDecl *NextVar = PackOffsetVec[i + 1].first;
258       S.Diag(NextVar->getLocation(), diag::err_hlsl_packoffset_overlap)
259           << NextVar << Var;
260     }
261   }
262 }
263 
264 // Returns true if the array has a zero size = if any of the dimensions is 0
265 static bool isZeroSizedArray(const ConstantArrayType *CAT) {
266   while (CAT && !CAT->isZeroSize())
267     CAT = dyn_cast<ConstantArrayType>(
268         CAT->getElementType()->getUnqualifiedDesugaredType());
269   return CAT != nullptr;
270 }
271 
272 // Returns true if the record type is an HLSL resource class
273 static bool isResourceRecordType(const Type *Ty) {
274   return HLSLAttributedResourceType::findHandleTypeOnResource(Ty) != nullptr;
275 }
276 
277 // Returns true if the type is a leaf element type that is not valid to be
278 // included in HLSL Buffer, such as a resource class, empty struct, zero-sized
279 // array, or a builtin intangible type. Returns false it is a valid leaf element
280 // type or if it is a record type that needs to be inspected further.
281 static bool isInvalidConstantBufferLeafElementType(const Type *Ty) {
282   if (Ty->isRecordType()) {
283     if (isResourceRecordType(Ty) || Ty->getAsCXXRecordDecl()->isEmpty())
284       return true;
285     return false;
286   }
287   if (Ty->isConstantArrayType() &&
288       isZeroSizedArray(cast<ConstantArrayType>(Ty)))
289     return true;
290   if (Ty->isHLSLBuiltinIntangibleType())
291     return true;
292   return false;
293 }
294 
295 // Returns true if the struct contains at least one element that prevents it
296 // from being included inside HLSL Buffer as is, such as an intangible type,
297 // empty struct, or zero-sized array. If it does, a new implicit layout struct
298 // needs to be created for HLSL Buffer use that will exclude these unwanted
299 // declarations (see createHostLayoutStruct function).
300 static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD) {
301   if (RD->getTypeForDecl()->isHLSLIntangibleType() || RD->isEmpty())
302     return true;
303   // check fields
304   for (const FieldDecl *Field : RD->fields()) {
305     QualType Ty = Field->getType();
306     if (isInvalidConstantBufferLeafElementType(Ty.getTypePtr()))
307       return true;
308     if (Ty->isRecordType() &&
309         requiresImplicitBufferLayoutStructure(Ty->getAsCXXRecordDecl()))
310       return true;
311   }
312   // check bases
313   for (const CXXBaseSpecifier &Base : RD->bases())
314     if (requiresImplicitBufferLayoutStructure(
315             Base.getType()->getAsCXXRecordDecl()))
316       return true;
317   return false;
318 }
319 
320 static CXXRecordDecl *findRecordDeclInContext(IdentifierInfo *II,
321                                               DeclContext *DC) {
322   CXXRecordDecl *RD = nullptr;
323   for (NamedDecl *Decl :
324        DC->getNonTransparentContext()->lookup(DeclarationName(II))) {
325     if (CXXRecordDecl *FoundRD = dyn_cast<CXXRecordDecl>(Decl)) {
326       assert(RD == nullptr &&
327              "there should be at most 1 record by a given name in a scope");
328       RD = FoundRD;
329     }
330   }
331   return RD;
332 }
333 
334 // Creates a name for buffer layout struct using the provide name base.
335 // If the name must be unique (not previously defined), a suffix is added
336 // until a unique name is found.
337 static IdentifierInfo *getHostLayoutStructName(Sema &S, NamedDecl *BaseDecl,
338                                                bool MustBeUnique) {
339   ASTContext &AST = S.getASTContext();
340 
341   IdentifierInfo *NameBaseII = BaseDecl->getIdentifier();
342   llvm::SmallString<64> Name("__layout_");
343   if (NameBaseII) {
344     Name.append(NameBaseII->getName());
345   } else {
346     // anonymous struct
347     Name.append("anon");
348     MustBeUnique = true;
349   }
350 
351   size_t NameLength = Name.size();
352   IdentifierInfo *II = &AST.Idents.get(Name, tok::TokenKind::identifier);
353   if (!MustBeUnique)
354     return II;
355 
356   unsigned suffix = 0;
357   while (true) {
358     if (suffix != 0) {
359       Name.append("_");
360       Name.append(llvm::Twine(suffix).str());
361       II = &AST.Idents.get(Name, tok::TokenKind::identifier);
362     }
363     if (!findRecordDeclInContext(II, BaseDecl->getDeclContext()))
364       return II;
365     // declaration with that name already exists - increment suffix and try
366     // again until unique name is found
367     suffix++;
368     Name.truncate(NameLength);
369   };
370 }
371 
372 // Creates a field declaration of given name and type for HLSL buffer layout
373 // struct. Returns nullptr if the type cannot be use in HLSL Buffer layout.
374 static FieldDecl *createFieldForHostLayoutStruct(Sema &S, const Type *Ty,
375                                                  IdentifierInfo *II,
376                                                  CXXRecordDecl *LayoutStruct) {
377   if (isInvalidConstantBufferLeafElementType(Ty))
378     return nullptr;
379 
380   if (Ty->isRecordType()) {
381     CXXRecordDecl *RD = Ty->getAsCXXRecordDecl();
382     if (requiresImplicitBufferLayoutStructure(RD)) {
383       RD = createHostLayoutStruct(S, RD);
384       if (!RD)
385         return nullptr;
386       Ty = RD->getTypeForDecl();
387     }
388   }
389 
390   QualType QT = QualType(Ty, 0);
391   ASTContext &AST = S.getASTContext();
392   TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo(QT, SourceLocation());
393   auto *Field = FieldDecl::Create(AST, LayoutStruct, SourceLocation(),
394                                   SourceLocation(), II, QT, TSI, nullptr, false,
395                                   InClassInitStyle::ICIS_NoInit);
396   Field->setAccess(AccessSpecifier::AS_private);
397   return Field;
398 }
399 
400 // Creates host layout struct for a struct included in HLSL Buffer.
401 // The layout struct will include only fields that are allowed in HLSL buffer.
402 // These fields will be filtered out:
403 // - resource classes
404 // - empty structs
405 // - zero-sized arrays
406 // Returns nullptr if the resulting layout struct would be empty.
407 static CXXRecordDecl *createHostLayoutStruct(Sema &S,
408                                              CXXRecordDecl *StructDecl) {
409   assert(requiresImplicitBufferLayoutStructure(StructDecl) &&
410          "struct is already HLSL buffer compatible");
411 
412   ASTContext &AST = S.getASTContext();
413   DeclContext *DC = StructDecl->getDeclContext();
414   IdentifierInfo *II = getHostLayoutStructName(S, StructDecl, false);
415 
416   // reuse existing if the layout struct if it already exists
417   if (CXXRecordDecl *RD = findRecordDeclInContext(II, DC))
418     return RD;
419 
420   CXXRecordDecl *LS = CXXRecordDecl::Create(
421       AST, TagDecl::TagKind::Class, DC, SourceLocation(), SourceLocation(), II);
422   LS->setImplicit(true);
423   LS->startDefinition();
424 
425   // copy base struct, create HLSL Buffer compatible version if needed
426   if (unsigned NumBases = StructDecl->getNumBases()) {
427     assert(NumBases == 1 && "HLSL supports only one base type");
428     (void)NumBases;
429     CXXBaseSpecifier Base = *StructDecl->bases_begin();
430     CXXRecordDecl *BaseDecl = Base.getType()->getAsCXXRecordDecl();
431     if (requiresImplicitBufferLayoutStructure(BaseDecl)) {
432       BaseDecl = createHostLayoutStruct(S, BaseDecl);
433       if (BaseDecl) {
434         TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo(
435             QualType(BaseDecl->getTypeForDecl(), 0));
436         Base = CXXBaseSpecifier(SourceRange(), false, StructDecl->isClass(),
437                                 AS_none, TSI, SourceLocation());
438       }
439     }
440     if (BaseDecl) {
441       const CXXBaseSpecifier *BasesArray[1] = {&Base};
442       LS->setBases(BasesArray, 1);
443     }
444   }
445 
446   // filter struct fields
447   for (const FieldDecl *FD : StructDecl->fields()) {
448     const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
449     if (FieldDecl *NewFD =
450             createFieldForHostLayoutStruct(S, Ty, FD->getIdentifier(), LS))
451       LS->addDecl(NewFD);
452   }
453   LS->completeDefinition();
454 
455   if (LS->field_empty() && LS->getNumBases() == 0)
456     return nullptr;
457 
458   DC->addDecl(LS);
459   return LS;
460 }
461 
462 // Creates host layout struct for HLSL Buffer. The struct will include only
463 // fields of types that are allowed in HLSL buffer and it will filter out:
464 // - static variable declarations
465 // - resource classes
466 // - empty structs
467 // - zero-sized arrays
468 // - non-variable declarations
469 // The layour struct will be added to the HLSLBufferDecl declarations.
470 void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl) {
471   ASTContext &AST = S.getASTContext();
472   IdentifierInfo *II = getHostLayoutStructName(S, BufDecl, true);
473 
474   CXXRecordDecl *LS =
475       CXXRecordDecl::Create(AST, TagDecl::TagKind::Class, BufDecl,
476                             SourceLocation(), SourceLocation(), II);
477   LS->setImplicit(true);
478   LS->startDefinition();
479 
480   for (Decl *D : BufDecl->decls()) {
481     VarDecl *VD = dyn_cast<VarDecl>(D);
482     if (!VD || VD->getStorageClass() == SC_Static)
483       continue;
484     const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
485     if (FieldDecl *FD =
486             createFieldForHostLayoutStruct(S, Ty, VD->getIdentifier(), LS)) {
487       // add the field decl to the layout struct
488       LS->addDecl(FD);
489       // update address space of the original decl to hlsl_constant
490       QualType NewTy =
491           AST.getAddrSpaceQualType(VD->getType(), LangAS::hlsl_constant);
492       VD->setType(NewTy);
493     }
494   }
495   LS->completeDefinition();
496   BufDecl->addDecl(LS);
497 }
498 
499 // Handle end of cbuffer/tbuffer declaration
500 void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
501   auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
502   BufDecl->setRBraceLoc(RBrace);
503 
504   validatePackoffset(SemaRef, BufDecl);
505 
506   // create buffer layout struct
507   createHostLayoutStructForBuffer(SemaRef, BufDecl);
508 
509   SemaRef.PopDeclContext();
510 }
511 
512 HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
513                                                   const AttributeCommonInfo &AL,
514                                                   int X, int Y, int Z) {
515   if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
516     if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
517       Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
518       Diag(AL.getLoc(), diag::note_conflicting_attribute);
519     }
520     return nullptr;
521   }
522   return ::new (getASTContext())
523       HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
524 }
525 
526 HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
527                                               const AttributeCommonInfo &AL,
528                                               int Min, int Max, int Preferred,
529                                               int SpelledArgsCount) {
530   if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) {
531     if (WS->getMin() != Min || WS->getMax() != Max ||
532         WS->getPreferred() != Preferred ||
533         WS->getSpelledArgsCount() != SpelledArgsCount) {
534       Diag(WS->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
535       Diag(AL.getLoc(), diag::note_conflicting_attribute);
536     }
537     return nullptr;
538   }
539   HLSLWaveSizeAttr *Result = ::new (getASTContext())
540       HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
541   Result->setSpelledArgsCount(SpelledArgsCount);
542   return Result;
543 }
544 
545 HLSLShaderAttr *
546 SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
547                           llvm::Triple::EnvironmentType ShaderType) {
548   if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
549     if (NT->getType() != ShaderType) {
550       Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
551       Diag(AL.getLoc(), diag::note_conflicting_attribute);
552     }
553     return nullptr;
554   }
555   return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL);
556 }
557 
558 HLSLParamModifierAttr *
559 SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
560                                  HLSLParamModifierAttr::Spelling Spelling) {
561   // We can only merge an `in` attribute with an `out` attribute. All other
562   // combinations of duplicated attributes are ill-formed.
563   if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
564     if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
565         (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
566       D->dropAttr<HLSLParamModifierAttr>();
567       SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
568       return HLSLParamModifierAttr::Create(
569           getASTContext(), /*MergedSpelling=*/true, AdjustedRange,
570           HLSLParamModifierAttr::Keyword_inout);
571     }
572     Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
573     Diag(PA->getLocation(), diag::note_conflicting_attribute);
574     return nullptr;
575   }
576   return HLSLParamModifierAttr::Create(getASTContext(), AL);
577 }
578 
579 void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
580   auto &TargetInfo = getASTContext().getTargetInfo();
581 
582   if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
583     return;
584 
585   llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
586   if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
587     if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
588       // The entry point is already annotated - check that it matches the
589       // triple.
590       if (Shader->getType() != Env) {
591         Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
592             << Shader;
593         FD->setInvalidDecl();
594       }
595     } else {
596       // Implicitly add the shader attribute if the entry function isn't
597       // explicitly annotated.
598       FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
599                                                  FD->getBeginLoc()));
600     }
601   } else {
602     switch (Env) {
603     case llvm::Triple::UnknownEnvironment:
604     case llvm::Triple::Library:
605       break;
606     default:
607       llvm_unreachable("Unhandled environment in triple");
608     }
609   }
610 }
611 
612 void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
613   const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
614   assert(ShaderAttr && "Entry point has no shader attribute");
615   llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
616   auto &TargetInfo = getASTContext().getTargetInfo();
617   VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
618   switch (ST) {
619   case llvm::Triple::Pixel:
620   case llvm::Triple::Vertex:
621   case llvm::Triple::Geometry:
622   case llvm::Triple::Hull:
623   case llvm::Triple::Domain:
624   case llvm::Triple::RayGeneration:
625   case llvm::Triple::Intersection:
626   case llvm::Triple::AnyHit:
627   case llvm::Triple::ClosestHit:
628   case llvm::Triple::Miss:
629   case llvm::Triple::Callable:
630     if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
631       DiagnoseAttrStageMismatch(NT, ST,
632                                 {llvm::Triple::Compute,
633                                  llvm::Triple::Amplification,
634                                  llvm::Triple::Mesh});
635       FD->setInvalidDecl();
636     }
637     if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
638       DiagnoseAttrStageMismatch(WS, ST,
639                                 {llvm::Triple::Compute,
640                                  llvm::Triple::Amplification,
641                                  llvm::Triple::Mesh});
642       FD->setInvalidDecl();
643     }
644     break;
645 
646   case llvm::Triple::Compute:
647   case llvm::Triple::Amplification:
648   case llvm::Triple::Mesh:
649     if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
650       Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
651           << llvm::Triple::getEnvironmentTypeName(ST);
652       FD->setInvalidDecl();
653     }
654     if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
655       if (Ver < VersionTuple(6, 6)) {
656         Diag(WS->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model)
657             << WS << "6.6";
658         FD->setInvalidDecl();
659       } else if (WS->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
660         Diag(
661             WS->getLocation(),
662             diag::err_hlsl_attribute_number_arguments_insufficient_shader_model)
663             << WS << WS->getSpelledArgsCount() << "6.8";
664         FD->setInvalidDecl();
665       }
666     }
667     break;
668   default:
669     llvm_unreachable("Unhandled environment in triple");
670   }
671 
672   for (ParmVarDecl *Param : FD->parameters()) {
673     if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) {
674       CheckSemanticAnnotation(FD, Param, AnnotationAttr);
675     } else {
676       // FIXME: Handle struct parameters where annotations are on struct fields.
677       // See: https://github.com/llvm/llvm-project/issues/57875
678       Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation);
679       Diag(Param->getLocation(), diag::note_previous_decl) << Param;
680       FD->setInvalidDecl();
681     }
682   }
683   // FIXME: Verify return type semantic annotation.
684 }
685 
686 void SemaHLSL::CheckSemanticAnnotation(
687     FunctionDecl *EntryPoint, const Decl *Param,
688     const HLSLAnnotationAttr *AnnotationAttr) {
689   auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
690   assert(ShaderAttr && "Entry point has no shader attribute");
691   llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
692 
693   switch (AnnotationAttr->getKind()) {
694   case attr::HLSLSV_DispatchThreadID:
695   case attr::HLSLSV_GroupIndex:
696   case attr::HLSLSV_GroupThreadID:
697   case attr::HLSLSV_GroupID:
698     if (ST == llvm::Triple::Compute)
699       return;
700     DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});
701     break;
702   default:
703     llvm_unreachable("Unknown HLSLAnnotationAttr");
704   }
705 }
706 
707 void SemaHLSL::DiagnoseAttrStageMismatch(
708     const Attr *A, llvm::Triple::EnvironmentType Stage,
709     std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
710   SmallVector<StringRef, 8> StageStrings;
711   llvm::transform(AllowedStages, std::back_inserter(StageStrings),
712                   [](llvm::Triple::EnvironmentType ST) {
713                     return StringRef(
714                         HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST));
715                   });
716   Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
717       << A << llvm::Triple::getEnvironmentTypeName(Stage)
718       << (AllowedStages.size() != 1) << join(StageStrings, ", ");
719 }
720 
721 template <CastKind Kind>
722 static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
723   if (const auto *VTy = Ty->getAs<VectorType>())
724     Ty = VTy->getElementType();
725   Ty = S.getASTContext().getExtVectorType(Ty, Sz);
726   E = S.ImpCastExprToType(E.get(), Ty, Kind);
727 }
728 
729 template <CastKind Kind>
730 static QualType castElement(Sema &S, ExprResult &E, QualType Ty) {
731   E = S.ImpCastExprToType(E.get(), Ty, Kind);
732   return Ty;
733 }
734 
735 static QualType handleFloatVectorBinOpConversion(
736     Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
737     QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
738   bool LHSFloat = LElTy->isRealFloatingType();
739   bool RHSFloat = RElTy->isRealFloatingType();
740 
741   if (LHSFloat && RHSFloat) {
742     if (IsCompAssign ||
743         SemaRef.getASTContext().getFloatingTypeOrder(LElTy, RElTy) > 0)
744       return castElement<CK_FloatingCast>(SemaRef, RHS, LHSType);
745 
746     return castElement<CK_FloatingCast>(SemaRef, LHS, RHSType);
747   }
748 
749   if (LHSFloat)
750     return castElement<CK_IntegralToFloating>(SemaRef, RHS, LHSType);
751 
752   assert(RHSFloat);
753   if (IsCompAssign)
754     return castElement<clang::CK_FloatingToIntegral>(SemaRef, RHS, LHSType);
755 
756   return castElement<CK_IntegralToFloating>(SemaRef, LHS, RHSType);
757 }
758 
759 static QualType handleIntegerVectorBinOpConversion(
760     Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
761     QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
762 
763   int IntOrder = SemaRef.Context.getIntegerTypeOrder(LElTy, RElTy);
764   bool LHSSigned = LElTy->hasSignedIntegerRepresentation();
765   bool RHSSigned = RElTy->hasSignedIntegerRepresentation();
766   auto &Ctx = SemaRef.getASTContext();
767 
768   // If both types have the same signedness, use the higher ranked type.
769   if (LHSSigned == RHSSigned) {
770     if (IsCompAssign || IntOrder >= 0)
771       return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
772 
773     return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
774   }
775 
776   // If the unsigned type has greater than or equal rank of the signed type, use
777   // the unsigned type.
778   if (IntOrder != (LHSSigned ? 1 : -1)) {
779     if (IsCompAssign || RHSSigned)
780       return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
781     return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
782   }
783 
784   // At this point the signed type has higher rank than the unsigned type, which
785   // means it will be the same size or bigger. If the signed type is bigger, it
786   // can represent all the values of the unsigned type, so select it.
787   if (Ctx.getIntWidth(LElTy) != Ctx.getIntWidth(RElTy)) {
788     if (IsCompAssign || LHSSigned)
789       return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
790     return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
791   }
792 
793   // This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due
794   // to C/C++ leaking through. The place this happens today is long vs long
795   // long. When arguments are vector<unsigned long, N> and vector<long long, N>,
796   // the long long has higher rank than long even though they are the same size.
797 
798   // If this is a compound assignment cast the right hand side to the left hand
799   // side's type.
800   if (IsCompAssign)
801     return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
802 
803   // If this isn't a compound assignment we convert to unsigned long long.
804   QualType ElTy = Ctx.getCorrespondingUnsignedType(LHSSigned ? LElTy : RElTy);
805   QualType NewTy = Ctx.getExtVectorType(
806       ElTy, RHSType->castAs<VectorType>()->getNumElements());
807   (void)castElement<CK_IntegralCast>(SemaRef, RHS, NewTy);
808 
809   return castElement<CK_IntegralCast>(SemaRef, LHS, NewTy);
810 }
811 
812 static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy,
813                                   QualType SrcTy) {
814   if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType())
815     return CK_FloatingCast;
816   if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx))
817     return CK_IntegralCast;
818   if (DestTy->isRealFloatingType())
819     return CK_IntegralToFloating;
820   assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx));
821   return CK_FloatingToIntegral;
822 }
823 
824 QualType SemaHLSL::handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
825                                                QualType LHSType,
826                                                QualType RHSType,
827                                                bool IsCompAssign) {
828   const auto *LVecTy = LHSType->getAs<VectorType>();
829   const auto *RVecTy = RHSType->getAs<VectorType>();
830   auto &Ctx = getASTContext();
831 
832   // If the LHS is not a vector and this is a compound assignment, we truncate
833   // the argument to a scalar then convert it to the LHS's type.
834   if (!LVecTy && IsCompAssign) {
835     QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
836     RHS = SemaRef.ImpCastExprToType(RHS.get(), RElTy, CK_HLSLVectorTruncation);
837     RHSType = RHS.get()->getType();
838     if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
839       return LHSType;
840     RHS = SemaRef.ImpCastExprToType(RHS.get(), LHSType,
841                                     getScalarCastKind(Ctx, LHSType, RHSType));
842     return LHSType;
843   }
844 
845   unsigned EndSz = std::numeric_limits<unsigned>::max();
846   unsigned LSz = 0;
847   if (LVecTy)
848     LSz = EndSz = LVecTy->getNumElements();
849   if (RVecTy)
850     EndSz = std::min(RVecTy->getNumElements(), EndSz);
851   assert(EndSz != std::numeric_limits<unsigned>::max() &&
852          "one of the above should have had a value");
853 
854   // In a compound assignment, the left operand does not change type, the right
855   // operand is converted to the type of the left operand.
856   if (IsCompAssign && LSz != EndSz) {
857     Diag(LHS.get()->getBeginLoc(),
858          diag::err_hlsl_vector_compound_assignment_truncation)
859         << LHSType << RHSType;
860     return QualType();
861   }
862 
863   if (RVecTy && RVecTy->getNumElements() > EndSz)
864     castVector<CK_HLSLVectorTruncation>(SemaRef, RHS, RHSType, EndSz);
865   if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz)
866     castVector<CK_HLSLVectorTruncation>(SemaRef, LHS, LHSType, EndSz);
867 
868   if (!RVecTy)
869     castVector<CK_VectorSplat>(SemaRef, RHS, RHSType, EndSz);
870   if (!IsCompAssign && !LVecTy)
871     castVector<CK_VectorSplat>(SemaRef, LHS, LHSType, EndSz);
872 
873   // If we're at the same type after resizing we can stop here.
874   if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
875     return Ctx.getCommonSugaredType(LHSType, RHSType);
876 
877   QualType LElTy = LHSType->castAs<VectorType>()->getElementType();
878   QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
879 
880   // Handle conversion for floating point vectors.
881   if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType())
882     return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
883                                             LElTy, RElTy, IsCompAssign);
884 
885   assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) &&
886          "HLSL Vectors can only contain integer or floating point types");
887   return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
888                                             LElTy, RElTy, IsCompAssign);
889 }
890 
891 void SemaHLSL::emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS,
892                                         BinaryOperatorKind Opc) {
893   assert((Opc == BO_LOr || Opc == BO_LAnd) &&
894          "Called with non-logical operator");
895   llvm::SmallVector<char, 256> Buff;
896   llvm::raw_svector_ostream OS(Buff);
897   PrintingPolicy PP(SemaRef.getLangOpts());
898   StringRef NewFnName = Opc == BO_LOr ? "or" : "and";
899   OS << NewFnName << "(";
900   LHS->printPretty(OS, nullptr, PP);
901   OS << ", ";
902   RHS->printPretty(OS, nullptr, PP);
903   OS << ")";
904   SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc());
905   SemaRef.Diag(LHS->getBeginLoc(), diag::note_function_suggestion)
906       << NewFnName << FixItHint::CreateReplacement(FullRange, OS.str());
907 }
908 
909 void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
910   llvm::VersionTuple SMVersion =
911       getASTContext().getTargetInfo().getTriple().getOSVersion();
912   uint32_t ZMax = 1024;
913   uint32_t ThreadMax = 1024;
914   if (SMVersion.getMajor() <= 4) {
915     ZMax = 1;
916     ThreadMax = 768;
917   } else if (SMVersion.getMajor() == 5) {
918     ZMax = 64;
919     ThreadMax = 1024;
920   }
921 
922   uint32_t X;
923   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), X))
924     return;
925   if (X > 1024) {
926     Diag(AL.getArgAsExpr(0)->getExprLoc(),
927          diag::err_hlsl_numthreads_argument_oor)
928         << 0 << 1024;
929     return;
930   }
931   uint32_t Y;
932   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Y))
933     return;
934   if (Y > 1024) {
935     Diag(AL.getArgAsExpr(1)->getExprLoc(),
936          diag::err_hlsl_numthreads_argument_oor)
937         << 1 << 1024;
938     return;
939   }
940   uint32_t Z;
941   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Z))
942     return;
943   if (Z > ZMax) {
944     SemaRef.Diag(AL.getArgAsExpr(2)->getExprLoc(),
945                  diag::err_hlsl_numthreads_argument_oor)
946         << 2 << ZMax;
947     return;
948   }
949 
950   if (X * Y * Z > ThreadMax) {
951     Diag(AL.getLoc(), diag::err_hlsl_numthreads_invalid) << ThreadMax;
952     return;
953   }
954 
955   HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z);
956   if (NewAttr)
957     D->addAttr(NewAttr);
958 }
959 
960 static bool isValidWaveSizeValue(unsigned Value) {
961   return llvm::isPowerOf2_32(Value) && Value >= 4 && Value <= 128;
962 }
963 
964 void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
965   // validate that the wavesize argument is a power of 2 between 4 and 128
966   // inclusive
967   unsigned SpelledArgsCount = AL.getNumArgs();
968   if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
969     return;
970 
971   uint32_t Min;
972   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Min))
973     return;
974 
975   uint32_t Max = 0;
976   if (SpelledArgsCount > 1 &&
977       !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Max))
978     return;
979 
980   uint32_t Preferred = 0;
981   if (SpelledArgsCount > 2 &&
982       !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Preferred))
983     return;
984 
985   if (SpelledArgsCount > 2) {
986     if (!isValidWaveSizeValue(Preferred)) {
987       Diag(AL.getArgAsExpr(2)->getExprLoc(),
988            diag::err_attribute_power_of_two_in_range)
989           << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize
990           << Preferred;
991       return;
992     }
993     // Preferred not in range.
994     if (Preferred < Min || Preferred > Max) {
995       Diag(AL.getArgAsExpr(2)->getExprLoc(),
996            diag::err_attribute_power_of_two_in_range)
997           << AL << Min << Max << Preferred;
998       return;
999     }
1000   } else if (SpelledArgsCount > 1) {
1001     if (!isValidWaveSizeValue(Max)) {
1002       Diag(AL.getArgAsExpr(1)->getExprLoc(),
1003            diag::err_attribute_power_of_two_in_range)
1004           << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max;
1005       return;
1006     }
1007     if (Max < Min) {
1008       Diag(AL.getLoc(), diag::err_attribute_argument_invalid) << AL << 1;
1009       return;
1010     } else if (Max == Min) {
1011       Diag(AL.getLoc(), diag::warn_attr_min_eq_max) << AL;
1012     }
1013   } else {
1014     if (!isValidWaveSizeValue(Min)) {
1015       Diag(AL.getArgAsExpr(0)->getExprLoc(),
1016            diag::err_attribute_power_of_two_in_range)
1017           << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min;
1018       return;
1019     }
1020   }
1021 
1022   HLSLWaveSizeAttr *NewAttr =
1023       mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
1024   if (NewAttr)
1025     D->addAttr(NewAttr);
1026 }
1027 
1028 bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
1029   const auto *VT = T->getAs<VectorType>();
1030 
1031   if (!T->hasUnsignedIntegerRepresentation() ||
1032       (VT && VT->getNumElements() > 3)) {
1033     Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1034         << AL << "uint/uint2/uint3";
1035     return false;
1036   }
1037 
1038   return true;
1039 }
1040 
1041 void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
1042   auto *VD = cast<ValueDecl>(D);
1043   if (!diagnoseInputIDType(VD->getType(), AL))
1044     return;
1045 
1046   D->addAttr(::new (getASTContext())
1047                  HLSLSV_DispatchThreadIDAttr(getASTContext(), AL));
1048 }
1049 
1050 void SemaHLSL::handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL) {
1051   auto *VD = cast<ValueDecl>(D);
1052   if (!diagnoseInputIDType(VD->getType(), AL))
1053     return;
1054 
1055   D->addAttr(::new (getASTContext())
1056                  HLSLSV_GroupThreadIDAttr(getASTContext(), AL));
1057 }
1058 
1059 void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) {
1060   auto *VD = cast<ValueDecl>(D);
1061   if (!diagnoseInputIDType(VD->getType(), AL))
1062     return;
1063 
1064   D->addAttr(::new (getASTContext()) HLSLSV_GroupIDAttr(getASTContext(), AL));
1065 }
1066 
1067 void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) {
1068   if (!isa<VarDecl>(D) || !isa<HLSLBufferDecl>(D->getDeclContext())) {
1069     Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)
1070         << AL << "shader constant in a constant buffer";
1071     return;
1072   }
1073 
1074   uint32_t SubComponent;
1075   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), SubComponent))
1076     return;
1077   uint32_t Component;
1078   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Component))
1079     return;
1080 
1081   QualType T = cast<VarDecl>(D)->getType().getCanonicalType();
1082   // Check if T is an array or struct type.
1083   // TODO: mark matrix type as aggregate type.
1084   bool IsAggregateTy = (T->isArrayType() || T->isStructureType());
1085 
1086   // Check Component is valid for T.
1087   if (Component) {
1088     unsigned Size = getASTContext().getTypeSize(T);
1089     if (IsAggregateTy || Size > 128) {
1090       Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);
1091       return;
1092     } else {
1093       // Make sure Component + sizeof(T) <= 4.
1094       if ((Component * 32 + Size) > 128) {
1095         Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);
1096         return;
1097       }
1098       QualType EltTy = T;
1099       if (const auto *VT = T->getAs<VectorType>())
1100         EltTy = VT->getElementType();
1101       unsigned Align = getASTContext().getTypeAlign(EltTy);
1102       if (Align > 32 && Component == 1) {
1103         // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary.
1104         // So we only need to check Component 1 here.
1105         Diag(AL.getLoc(), diag::err_hlsl_packoffset_alignment_mismatch)
1106             << Align << EltTy;
1107         return;
1108       }
1109     }
1110   }
1111 
1112   D->addAttr(::new (getASTContext()) HLSLPackOffsetAttr(
1113       getASTContext(), AL, SubComponent, Component));
1114 }
1115 
1116 void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
1117   StringRef Str;
1118   SourceLocation ArgLoc;
1119   if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
1120     return;
1121 
1122   llvm::Triple::EnvironmentType ShaderType;
1123   if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) {
1124     Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
1125         << AL << Str << ArgLoc;
1126     return;
1127   }
1128 
1129   // FIXME: check function match the shader stage.
1130 
1131   HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType);
1132   if (NewAttr)
1133     D->addAttr(NewAttr);
1134 }
1135 
1136 bool clang::CreateHLSLAttributedResourceType(
1137     Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList,
1138     QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo) {
1139   assert(AttrList.size() && "expected list of resource attributes");
1140 
1141   QualType ContainedTy = QualType();
1142   TypeSourceInfo *ContainedTyInfo = nullptr;
1143   SourceLocation LocBegin = AttrList[0]->getRange().getBegin();
1144   SourceLocation LocEnd = AttrList[0]->getRange().getEnd();
1145 
1146   HLSLAttributedResourceType::Attributes ResAttrs;
1147 
1148   bool HasResourceClass = false;
1149   for (const Attr *A : AttrList) {
1150     if (!A)
1151       continue;
1152     LocEnd = A->getRange().getEnd();
1153     switch (A->getKind()) {
1154     case attr::HLSLResourceClass: {
1155       ResourceClass RC = cast<HLSLResourceClassAttr>(A)->getResourceClass();
1156       if (HasResourceClass) {
1157         S.Diag(A->getLocation(), ResAttrs.ResourceClass == RC
1158                                      ? diag::warn_duplicate_attribute_exact
1159                                      : diag::warn_duplicate_attribute)
1160             << A;
1161         return false;
1162       }
1163       ResAttrs.ResourceClass = RC;
1164       HasResourceClass = true;
1165       break;
1166     }
1167     case attr::HLSLROV:
1168       if (ResAttrs.IsROV) {
1169         S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
1170         return false;
1171       }
1172       ResAttrs.IsROV = true;
1173       break;
1174     case attr::HLSLRawBuffer:
1175       if (ResAttrs.RawBuffer) {
1176         S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
1177         return false;
1178       }
1179       ResAttrs.RawBuffer = true;
1180       break;
1181     case attr::HLSLContainedType: {
1182       const HLSLContainedTypeAttr *CTAttr = cast<HLSLContainedTypeAttr>(A);
1183       QualType Ty = CTAttr->getType();
1184       if (!ContainedTy.isNull()) {
1185         S.Diag(A->getLocation(), ContainedTy == Ty
1186                                      ? diag::warn_duplicate_attribute_exact
1187                                      : diag::warn_duplicate_attribute)
1188             << A;
1189         return false;
1190       }
1191       ContainedTy = Ty;
1192       ContainedTyInfo = CTAttr->getTypeLoc();
1193       break;
1194     }
1195     default:
1196       llvm_unreachable("unhandled resource attribute type");
1197     }
1198   }
1199 
1200   if (!HasResourceClass) {
1201     S.Diag(AttrList.back()->getRange().getEnd(),
1202            diag::err_hlsl_missing_resource_class);
1203     return false;
1204   }
1205 
1206   ResType = S.getASTContext().getHLSLAttributedResourceType(
1207       Wrapped, ContainedTy, ResAttrs);
1208 
1209   if (LocInfo && ContainedTyInfo) {
1210     LocInfo->Range = SourceRange(LocBegin, LocEnd);
1211     LocInfo->ContainedTyInfo = ContainedTyInfo;
1212   }
1213   return true;
1214 }
1215 
1216 // Validates and creates an HLSL attribute that is applied as type attribute on
1217 // HLSL resource. The attributes are collected in HLSLResourcesTypeAttrs and at
1218 // the end of the declaration they are applied to the declaration type by
1219 // wrapping it in HLSLAttributedResourceType.
1220 bool SemaHLSL::handleResourceTypeAttr(QualType T, const ParsedAttr &AL) {
1221   // only allow resource type attributes on intangible types
1222   if (!T->isHLSLResourceType()) {
1223     Diag(AL.getLoc(), diag::err_hlsl_attribute_needs_intangible_type)
1224         << AL << getASTContext().HLSLResourceTy;
1225     return false;
1226   }
1227 
1228   // validate number of arguments
1229   if (!AL.checkExactlyNumArgs(SemaRef, AL.getMinArgs()))
1230     return false;
1231 
1232   Attr *A = nullptr;
1233   switch (AL.getKind()) {
1234   case ParsedAttr::AT_HLSLResourceClass: {
1235     if (!AL.isArgIdent(0)) {
1236       Diag(AL.getLoc(), diag::err_attribute_argument_type)
1237           << AL << AANT_ArgumentIdentifier;
1238       return false;
1239     }
1240 
1241     IdentifierLoc *Loc = AL.getArgAsIdent(0);
1242     StringRef Identifier = Loc->Ident->getName();
1243     SourceLocation ArgLoc = Loc->Loc;
1244 
1245     // Validate resource class value
1246     ResourceClass RC;
1247     if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Identifier, RC)) {
1248       Diag(ArgLoc, diag::warn_attribute_type_not_supported)
1249           << "ResourceClass" << Identifier;
1250       return false;
1251     }
1252     A = HLSLResourceClassAttr::Create(getASTContext(), RC, AL.getLoc());
1253     break;
1254   }
1255 
1256   case ParsedAttr::AT_HLSLROV:
1257     A = HLSLROVAttr::Create(getASTContext(), AL.getLoc());
1258     break;
1259 
1260   case ParsedAttr::AT_HLSLRawBuffer:
1261     A = HLSLRawBufferAttr::Create(getASTContext(), AL.getLoc());
1262     break;
1263 
1264   case ParsedAttr::AT_HLSLContainedType: {
1265     if (AL.getNumArgs() != 1 && !AL.hasParsedType()) {
1266       Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1;
1267       return false;
1268     }
1269 
1270     TypeSourceInfo *TSI = nullptr;
1271     QualType QT = SemaRef.GetTypeFromParser(AL.getTypeArg(), &TSI);
1272     assert(TSI && "no type source info for attribute argument");
1273     if (SemaRef.RequireCompleteType(TSI->getTypeLoc().getBeginLoc(), QT,
1274                                     diag::err_incomplete_type))
1275       return false;
1276     A = HLSLContainedTypeAttr::Create(getASTContext(), TSI, AL.getLoc());
1277     break;
1278   }
1279 
1280   default:
1281     llvm_unreachable("unhandled HLSL attribute");
1282   }
1283 
1284   HLSLResourcesTypeAttrs.emplace_back(A);
1285   return true;
1286 }
1287 
1288 // Combines all resource type attributes and creates HLSLAttributedResourceType.
1289 QualType SemaHLSL::ProcessResourceTypeAttributes(QualType CurrentType) {
1290   if (!HLSLResourcesTypeAttrs.size())
1291     return CurrentType;
1292 
1293   QualType QT = CurrentType;
1294   HLSLAttributedResourceLocInfo LocInfo;
1295   if (CreateHLSLAttributedResourceType(SemaRef, CurrentType,
1296                                        HLSLResourcesTypeAttrs, QT, &LocInfo)) {
1297     const HLSLAttributedResourceType *RT =
1298         cast<HLSLAttributedResourceType>(QT.getTypePtr());
1299 
1300     // Temporarily store TypeLoc information for the new type.
1301     // It will be transferred to HLSLAttributesResourceTypeLoc
1302     // shortly after the type is created by TypeSpecLocFiller which
1303     // will call the TakeLocForHLSLAttribute method below.
1304     LocsForHLSLAttributedResources.insert(std::pair(RT, LocInfo));
1305   }
1306   HLSLResourcesTypeAttrs.clear();
1307   return QT;
1308 }
1309 
1310 // Returns source location for the HLSLAttributedResourceType
1311 HLSLAttributedResourceLocInfo
1312 SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
1313   HLSLAttributedResourceLocInfo LocInfo = {};
1314   auto I = LocsForHLSLAttributedResources.find(RT);
1315   if (I != LocsForHLSLAttributedResources.end()) {
1316     LocInfo = I->second;
1317     LocsForHLSLAttributedResources.erase(I);
1318     return LocInfo;
1319   }
1320   LocInfo.Range = SourceRange();
1321   return LocInfo;
1322 }
1323 
1324 // Walks though the global variable declaration, collects all resource binding
1325 // requirements and adds them to Bindings
1326 void SemaHLSL::collectResourcesOnUserRecordDecl(const VarDecl *VD,
1327                                                 const RecordType *RT) {
1328   const RecordDecl *RD = RT->getDecl();
1329   for (FieldDecl *FD : RD->fields()) {
1330     const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
1331 
1332     // Unwrap arrays
1333     // FIXME: Calculate array size while unwrapping
1334     assert(!Ty->isIncompleteArrayType() &&
1335            "incomplete arrays inside user defined types are not supported");
1336     while (Ty->isConstantArrayType()) {
1337       const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
1338       Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
1339     }
1340 
1341     if (!Ty->isRecordType())
1342       continue;
1343 
1344     if (const HLSLAttributedResourceType *AttrResType =
1345             HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
1346       // Add a new DeclBindingInfo to Bindings if it does not already exist
1347       ResourceClass RC = AttrResType->getAttrs().ResourceClass;
1348       DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, RC);
1349       if (!DBI)
1350         Bindings.addDeclBindingInfo(VD, RC);
1351     } else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
1352       // Recursively scan embedded struct or class; it would be nice to do this
1353       // without recursion, but tricky to correctly calculate the size of the
1354       // binding, which is something we are probably going to need to do later
1355       // on. Hopefully nesting of structs in structs too many levels is
1356       // unlikely.
1357       collectResourcesOnUserRecordDecl(VD, RT);
1358     }
1359   }
1360 }
1361 
1362 // Diagnore localized register binding errors for a single binding; does not
1363 // diagnose resource binding on user record types, that will be done later
1364 // in processResourceBindingOnDecl based on the information collected in
1365 // collectResourcesOnVarDecl.
1366 // Returns false if the register binding is not valid.
1367 static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
1368                                          Decl *D, RegisterType RegType,
1369                                          bool SpecifiedSpace) {
1370   int RegTypeNum = static_cast<int>(RegType);
1371 
1372   // check if the decl type is groupshared
1373   if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
1374     S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
1375     return false;
1376   }
1377 
1378   // Cbuffers and Tbuffers are HLSLBufferDecl types
1379   if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
1380     ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
1381                                                      : ResourceClass::SRV;
1382     if (RegType == getRegisterType(RC))
1383       return true;
1384 
1385     S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
1386         << RegTypeNum;
1387     return false;
1388   }
1389 
1390   // Samplers, UAVs, and SRVs are VarDecl types
1391   assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
1392   VarDecl *VD = cast<VarDecl>(D);
1393 
1394   // Resource
1395   if (const HLSLAttributedResourceType *AttrResType =
1396           HLSLAttributedResourceType::findHandleTypeOnResource(
1397               VD->getType().getTypePtr())) {
1398     if (RegType == getRegisterType(AttrResType->getAttrs().ResourceClass))
1399       return true;
1400 
1401     S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
1402         << RegTypeNum;
1403     return false;
1404   }
1405 
1406   const clang::Type *Ty = VD->getType().getTypePtr();
1407   while (Ty->isArrayType())
1408     Ty = Ty->getArrayElementTypeNoTypeQual();
1409 
1410   // Basic types
1411   if (Ty->isArithmeticType()) {
1412     bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
1413     if (SpecifiedSpace && !DeclaredInCOrTBuffer)
1414       S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
1415 
1416     if (!DeclaredInCOrTBuffer &&
1417         (Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) {
1418       // Default Globals
1419       if (RegType == RegisterType::CBuffer)
1420         S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
1421       else if (RegType != RegisterType::C)
1422         S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
1423     } else {
1424       if (RegType == RegisterType::C)
1425         S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
1426       else
1427         S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
1428     }
1429     return false;
1430   }
1431   if (Ty->isRecordType())
1432     // RecordTypes will be diagnosed in processResourceBindingOnDecl
1433     // that is called from ActOnVariableDeclarator
1434     return true;
1435 
1436   // Anything else is an error
1437   S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
1438   return false;
1439 }
1440 
1441 static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
1442                                                 RegisterType regType) {
1443   // make sure that there are no two register annotations
1444   // applied to the decl with the same register type
1445   bool RegisterTypesDetected[5] = {false};
1446   RegisterTypesDetected[static_cast<int>(regType)] = true;
1447 
1448   for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) {
1449     if (HLSLResourceBindingAttr *attr =
1450             dyn_cast<HLSLResourceBindingAttr>(*it)) {
1451 
1452       RegisterType otherRegType = attr->getRegisterType();
1453       if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
1454         int otherRegTypeNum = static_cast<int>(otherRegType);
1455         S.Diag(TheDecl->getLocation(),
1456                diag::err_hlsl_duplicate_register_annotation)
1457             << otherRegTypeNum;
1458         return false;
1459       }
1460       RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
1461     }
1462   }
1463   return true;
1464 }
1465 
1466 static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
1467                                           Decl *D, RegisterType RegType,
1468                                           bool SpecifiedSpace) {
1469 
1470   // exactly one of these two types should be set
1471   assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
1472           (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
1473          "expecting VarDecl or HLSLBufferDecl");
1474 
1475   // check if the declaration contains resource matching the register type
1476   if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
1477     return false;
1478 
1479   // next, if multiple register annotations exist, check that none conflict.
1480   return ValidateMultipleRegisterAnnotations(S, D, RegType);
1481 }
1482 
1483 void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
1484   if (isa<VarDecl>(TheDecl)) {
1485     if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(),
1486                                     cast<ValueDecl>(TheDecl)->getType(),
1487                                     diag::err_incomplete_type))
1488       return;
1489   }
1490   StringRef Space = "space0";
1491   StringRef Slot = "";
1492 
1493   if (!AL.isArgIdent(0)) {
1494     Diag(AL.getLoc(), diag::err_attribute_argument_type)
1495         << AL << AANT_ArgumentIdentifier;
1496     return;
1497   }
1498 
1499   IdentifierLoc *Loc = AL.getArgAsIdent(0);
1500   StringRef Str = Loc->Ident->getName();
1501   SourceLocation ArgLoc = Loc->Loc;
1502 
1503   SourceLocation SpaceArgLoc;
1504   bool SpecifiedSpace = false;
1505   if (AL.getNumArgs() == 2) {
1506     SpecifiedSpace = true;
1507     Slot = Str;
1508     if (!AL.isArgIdent(1)) {
1509       Diag(AL.getLoc(), diag::err_attribute_argument_type)
1510           << AL << AANT_ArgumentIdentifier;
1511       return;
1512     }
1513 
1514     IdentifierLoc *Loc = AL.getArgAsIdent(1);
1515     Space = Loc->Ident->getName();
1516     SpaceArgLoc = Loc->Loc;
1517   } else {
1518     Slot = Str;
1519   }
1520 
1521   RegisterType RegType;
1522   unsigned SlotNum = 0;
1523   unsigned SpaceNum = 0;
1524 
1525   // Validate.
1526   if (!Slot.empty()) {
1527     if (!convertToRegisterType(Slot, &RegType)) {
1528       Diag(ArgLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
1529       return;
1530     }
1531     if (RegType == RegisterType::I) {
1532       Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i);
1533       return;
1534     }
1535 
1536     StringRef SlotNumStr = Slot.substr(1);
1537     if (SlotNumStr.getAsInteger(10, SlotNum)) {
1538       Diag(ArgLoc, diag::err_hlsl_unsupported_register_number);
1539       return;
1540     }
1541   }
1542 
1543   if (!Space.starts_with("space")) {
1544     Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space;
1545     return;
1546   }
1547   StringRef SpaceNumStr = Space.substr(5);
1548   if (SpaceNumStr.getAsInteger(10, SpaceNum)) {
1549     Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space;
1550     return;
1551   }
1552 
1553   if (!DiagnoseHLSLRegisterAttribute(SemaRef, ArgLoc, TheDecl, RegType,
1554                                      SpecifiedSpace))
1555     return;
1556 
1557   HLSLResourceBindingAttr *NewAttr =
1558       HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL);
1559   if (NewAttr) {
1560     NewAttr->setBinding(RegType, SlotNum, SpaceNum);
1561     TheDecl->addAttr(NewAttr);
1562   }
1563 }
1564 
1565 void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {
1566   HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(
1567       D, AL,
1568       static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
1569   if (NewAttr)
1570     D->addAttr(NewAttr);
1571 }
1572 
1573 namespace {
1574 
1575 /// This class implements HLSL availability diagnostics for default
1576 /// and relaxed mode
1577 ///
1578 /// The goal of this diagnostic is to emit an error or warning when an
1579 /// unavailable API is found in code that is reachable from the shader
1580 /// entry function or from an exported function (when compiling a shader
1581 /// library).
1582 ///
1583 /// This is done by traversing the AST of all shader entry point functions
1584 /// and of all exported functions, and any functions that are referenced
1585 /// from this AST. In other words, any functions that are reachable from
1586 /// the entry points.
1587 class DiagnoseHLSLAvailability : public DynamicRecursiveASTVisitor {
1588   Sema &SemaRef;
1589 
1590   // Stack of functions to be scaned
1591   llvm::SmallVector<const FunctionDecl *, 8> DeclsToScan;
1592 
1593   // Tracks which environments functions have been scanned in.
1594   //
1595   // Maps FunctionDecl to an unsigned number that represents the set of shader
1596   // environments the function has been scanned for.
1597   // The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
1598   // to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
1599   // (verified by static_asserts in Triple.cpp), we can use it to index
1600   // individual bits in the set, as long as we shift the values to start with 0
1601   // by subtracting the value of llvm::Triple::Pixel first.
1602   //
1603   // The N'th bit in the set will be set if the function has been scanned
1604   // in shader environment whose llvm::Triple::EnvironmentType integer value
1605   // equals (llvm::Triple::Pixel + N).
1606   //
1607   // For example, if a function has been scanned in compute and pixel stage
1608   // environment, the value will be 0x21 (100001 binary) because:
1609   //
1610   //   (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
1611   //   (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
1612   //
1613   // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
1614   // been scanned in any environment.
1615   llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;
1616 
1617   // Do not access these directly, use the get/set methods below to make
1618   // sure the values are in sync
1619   llvm::Triple::EnvironmentType CurrentShaderEnvironment;
1620   unsigned CurrentShaderStageBit;
1621 
1622   // True if scanning a function that was already scanned in a different
1623   // shader stage context, and therefore we should not report issues that
1624   // depend only on shader model version because they would be duplicate.
1625   bool ReportOnlyShaderStageIssues;
1626 
1627   // Helper methods for dealing with current stage context / environment
1628   void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
1629     static_assert(sizeof(unsigned) >= 4);
1630     assert(HLSLShaderAttr::isValidShaderType(ShaderType));
1631     assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
1632            "ShaderType is too big for this bitmap"); // 31 is reserved for
1633                                                      // "unknown"
1634 
1635     unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
1636     CurrentShaderEnvironment = ShaderType;
1637     CurrentShaderStageBit = (1 << bitmapIndex);
1638   }
1639 
1640   void SetUnknownShaderStageContext() {
1641     CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment;
1642     CurrentShaderStageBit = (1 << 31);
1643   }
1644 
1645   llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const {
1646     return CurrentShaderEnvironment;
1647   }
1648 
1649   bool InUnknownShaderStageContext() const {
1650     return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment;
1651   }
1652 
1653   // Helper methods for dealing with shader stage bitmap
1654   void AddToScannedFunctions(const FunctionDecl *FD) {
1655     unsigned &ScannedStages = ScannedDecls[FD];
1656     ScannedStages |= CurrentShaderStageBit;
1657   }
1658 
1659   unsigned GetScannedStages(const FunctionDecl *FD) { return ScannedDecls[FD]; }
1660 
1661   bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) {
1662     return WasAlreadyScannedInCurrentStage(GetScannedStages(FD));
1663   }
1664 
1665   bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) {
1666     return ScannerStages & CurrentShaderStageBit;
1667   }
1668 
1669   static bool NeverBeenScanned(unsigned ScannedStages) {
1670     return ScannedStages == 0;
1671   }
1672 
1673   // Scanning methods
1674   void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr);
1675   void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA,
1676                              SourceRange Range);
1677   const AvailabilityAttr *FindAvailabilityAttr(const Decl *D);
1678   bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA);
1679 
1680 public:
1681   DiagnoseHLSLAvailability(Sema &SemaRef)
1682       : SemaRef(SemaRef),
1683         CurrentShaderEnvironment(llvm::Triple::UnknownEnvironment),
1684         CurrentShaderStageBit(0), ReportOnlyShaderStageIssues(false) {}
1685 
1686   // AST traversal methods
1687   void RunOnTranslationUnit(const TranslationUnitDecl *TU);
1688   void RunOnFunction(const FunctionDecl *FD);
1689 
1690   bool VisitDeclRefExpr(DeclRefExpr *DRE) override {
1691     FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(DRE->getDecl());
1692     if (FD)
1693       HandleFunctionOrMethodRef(FD, DRE);
1694     return true;
1695   }
1696 
1697   bool VisitMemberExpr(MemberExpr *ME) override {
1698     FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(ME->getMemberDecl());
1699     if (FD)
1700       HandleFunctionOrMethodRef(FD, ME);
1701     return true;
1702   }
1703 };
1704 
1705 void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD,
1706                                                          Expr *RefExpr) {
1707   assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) &&
1708          "expected DeclRefExpr or MemberExpr");
1709 
1710   // has a definition -> add to stack to be scanned
1711   const FunctionDecl *FDWithBody = nullptr;
1712   if (FD->hasBody(FDWithBody)) {
1713     if (!WasAlreadyScannedInCurrentStage(FDWithBody))
1714       DeclsToScan.push_back(FDWithBody);
1715     return;
1716   }
1717 
1718   // no body -> diagnose availability
1719   const AvailabilityAttr *AA = FindAvailabilityAttr(FD);
1720   if (AA)
1721     CheckDeclAvailability(
1722         FD, AA, SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc()));
1723 }
1724 
1725 void DiagnoseHLSLAvailability::RunOnTranslationUnit(
1726     const TranslationUnitDecl *TU) {
1727 
1728   // Iterate over all shader entry functions and library exports, and for those
1729   // that have a body (definiton), run diag scan on each, setting appropriate
1730   // shader environment context based on whether it is a shader entry function
1731   // or an exported function. Exported functions can be in namespaces and in
1732   // export declarations so we need to scan those declaration contexts as well.
1733   llvm::SmallVector<const DeclContext *, 8> DeclContextsToScan;
1734   DeclContextsToScan.push_back(TU);
1735 
1736   while (!DeclContextsToScan.empty()) {
1737     const DeclContext *DC = DeclContextsToScan.pop_back_val();
1738     for (auto &D : DC->decls()) {
1739       // do not scan implicit declaration generated by the implementation
1740       if (D->isImplicit())
1741         continue;
1742 
1743       // for namespace or export declaration add the context to the list to be
1744       // scanned later
1745       if (llvm::dyn_cast<NamespaceDecl>(D) || llvm::dyn_cast<ExportDecl>(D)) {
1746         DeclContextsToScan.push_back(llvm::dyn_cast<DeclContext>(D));
1747         continue;
1748       }
1749 
1750       // skip over other decls or function decls without body
1751       const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(D);
1752       if (!FD || !FD->isThisDeclarationADefinition())
1753         continue;
1754 
1755       // shader entry point
1756       if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {
1757         SetShaderStageContext(ShaderAttr->getType());
1758         RunOnFunction(FD);
1759         continue;
1760       }
1761       // exported library function
1762       // FIXME: replace this loop with external linkage check once issue #92071
1763       // is resolved
1764       bool isExport = FD->isInExportDeclContext();
1765       if (!isExport) {
1766         for (const auto *Redecl : FD->redecls()) {
1767           if (Redecl->isInExportDeclContext()) {
1768             isExport = true;
1769             break;
1770           }
1771         }
1772       }
1773       if (isExport) {
1774         SetUnknownShaderStageContext();
1775         RunOnFunction(FD);
1776         continue;
1777       }
1778     }
1779   }
1780 }
1781 
1782 void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) {
1783   assert(DeclsToScan.empty() && "DeclsToScan should be empty");
1784   DeclsToScan.push_back(FD);
1785 
1786   while (!DeclsToScan.empty()) {
1787     // Take one decl from the stack and check it by traversing its AST.
1788     // For any CallExpr found during the traversal add it's callee to the top of
1789     // the stack to be processed next. Functions already processed are stored in
1790     // ScannedDecls.
1791     const FunctionDecl *FD = DeclsToScan.pop_back_val();
1792 
1793     // Decl was already scanned
1794     const unsigned ScannedStages = GetScannedStages(FD);
1795     if (WasAlreadyScannedInCurrentStage(ScannedStages))
1796       continue;
1797 
1798     ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages);
1799 
1800     AddToScannedFunctions(FD);
1801     TraverseStmt(FD->getBody());
1802   }
1803 }
1804 
1805 bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone(
1806     const AvailabilityAttr *AA) {
1807   IdentifierInfo *IIEnvironment = AA->getEnvironment();
1808   if (!IIEnvironment)
1809     return true;
1810 
1811   llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment();
1812   if (CurrentEnv == llvm::Triple::UnknownEnvironment)
1813     return false;
1814 
1815   llvm::Triple::EnvironmentType AttrEnv =
1816       AvailabilityAttr::getEnvironmentType(IIEnvironment->getName());
1817 
1818   return CurrentEnv == AttrEnv;
1819 }
1820 
1821 const AvailabilityAttr *
1822 DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) {
1823   AvailabilityAttr const *PartialMatch = nullptr;
1824   // Check each AvailabilityAttr to find the one for this platform.
1825   // For multiple attributes with the same platform try to find one for this
1826   // environment.
1827   for (const auto *A : D->attrs()) {
1828     if (const auto *Avail = dyn_cast<AvailabilityAttr>(A)) {
1829       StringRef AttrPlatform = Avail->getPlatform()->getName();
1830       StringRef TargetPlatform =
1831           SemaRef.getASTContext().getTargetInfo().getPlatformName();
1832 
1833       // Match the platform name.
1834       if (AttrPlatform == TargetPlatform) {
1835         // Find the best matching attribute for this environment
1836         if (HasMatchingEnvironmentOrNone(Avail))
1837           return Avail;
1838         PartialMatch = Avail;
1839       }
1840     }
1841   }
1842   return PartialMatch;
1843 }
1844 
1845 // Check availability against target shader model version and current shader
1846 // stage and emit diagnostic
1847 void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,
1848                                                      const AvailabilityAttr *AA,
1849                                                      SourceRange Range) {
1850 
1851   IdentifierInfo *IIEnv = AA->getEnvironment();
1852 
1853   if (!IIEnv) {
1854     // The availability attribute does not have environment -> it depends only
1855     // on shader model version and not on specific the shader stage.
1856 
1857     // Skip emitting the diagnostics if the diagnostic mode is set to
1858     // strict (-fhlsl-strict-availability) because all relevant diagnostics
1859     // were already emitted in the DiagnoseUnguardedAvailability scan
1860     // (SemaAvailability.cpp).
1861     if (SemaRef.getLangOpts().HLSLStrictAvailability)
1862       return;
1863 
1864     // Do not report shader-stage-independent issues if scanning a function
1865     // that was already scanned in a different shader stage context (they would
1866     // be duplicate)
1867     if (ReportOnlyShaderStageIssues)
1868       return;
1869 
1870   } else {
1871     // The availability attribute has environment -> we need to know
1872     // the current stage context to property diagnose it.
1873     if (InUnknownShaderStageContext())
1874       return;
1875   }
1876 
1877   // Check introduced version and if environment matches
1878   bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA);
1879   VersionTuple Introduced = AA->getIntroduced();
1880   VersionTuple TargetVersion =
1881       SemaRef.Context.getTargetInfo().getPlatformMinVersion();
1882 
1883   if (TargetVersion >= Introduced && EnvironmentMatches)
1884     return;
1885 
1886   // Emit diagnostic message
1887   const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
1888   llvm::StringRef PlatformName(
1889       AvailabilityAttr::getPrettyPlatformName(TI.getPlatformName()));
1890 
1891   llvm::StringRef CurrentEnvStr =
1892       llvm::Triple::getEnvironmentTypeName(GetCurrentShaderEnvironment());
1893 
1894   llvm::StringRef AttrEnvStr =
1895       AA->getEnvironment() ? AA->getEnvironment()->getName() : "";
1896   bool UseEnvironment = !AttrEnvStr.empty();
1897 
1898   if (EnvironmentMatches) {
1899     SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability)
1900         << Range << D << PlatformName << Introduced.getAsString()
1901         << UseEnvironment << CurrentEnvStr;
1902   } else {
1903     SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability_unavailable)
1904         << Range << D;
1905   }
1906 
1907   SemaRef.Diag(D->getLocation(), diag::note_partial_availability_specified_here)
1908       << D << PlatformName << Introduced.getAsString()
1909       << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString()
1910       << UseEnvironment << AttrEnvStr << CurrentEnvStr;
1911 }
1912 
1913 } // namespace
1914 
1915 void SemaHLSL::DiagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
1916   // Skip running the diagnostics scan if the diagnostic mode is
1917   // strict (-fhlsl-strict-availability) and the target shader stage is known
1918   // because all relevant diagnostics were already emitted in the
1919   // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp).
1920   const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
1921   if (SemaRef.getLangOpts().HLSLStrictAvailability &&
1922       TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library)
1923     return;
1924 
1925   DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
1926 }
1927 
1928 // Helper function for CheckHLSLBuiltinFunctionCall
1929 static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
1930   assert(TheCall->getNumArgs() > 1);
1931   ExprResult A = TheCall->getArg(0);
1932 
1933   QualType ArgTyA = A.get()->getType();
1934 
1935   auto *VecTyA = ArgTyA->getAs<VectorType>();
1936   SourceLocation BuiltinLoc = TheCall->getBeginLoc();
1937 
1938   bool AllBArgAreVectors = true;
1939   for (unsigned i = 1; i < TheCall->getNumArgs(); ++i) {
1940     ExprResult B = TheCall->getArg(i);
1941     QualType ArgTyB = B.get()->getType();
1942     auto *VecTyB = ArgTyB->getAs<VectorType>();
1943     if (VecTyB == nullptr)
1944       AllBArgAreVectors &= false;
1945     if (VecTyA && VecTyB == nullptr) {
1946       // Note: if we get here 'B' is scalar which
1947       // requires a VectorSplat on ArgN
1948       S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
1949           << TheCall->getDirectCallee() << /*useAllTerminology*/ true
1950           << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
1951       return true;
1952     }
1953     if (VecTyA && VecTyB) {
1954       bool retValue = false;
1955       if (VecTyA->getElementType() != VecTyB->getElementType()) {
1956         // Note: type promotion is intended to be handeled via the intrinsics
1957         //  and not the builtin itself.
1958         S->Diag(TheCall->getBeginLoc(),
1959                 diag::err_vec_builtin_incompatible_vector)
1960             << TheCall->getDirectCallee() << /*useAllTerminology*/ true
1961             << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
1962         retValue = true;
1963       }
1964       if (VecTyA->getNumElements() != VecTyB->getNumElements()) {
1965         // You should only be hitting this case if you are calling the builtin
1966         // directly. HLSL intrinsics should avoid this case via a
1967         // HLSLVectorTruncation.
1968         S->Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
1969             << TheCall->getDirectCallee() << /*useAllTerminology*/ true
1970             << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
1971         retValue = true;
1972       }
1973       if (retValue)
1974         return retValue;
1975     }
1976   }
1977 
1978   if (VecTyA == nullptr && AllBArgAreVectors) {
1979     // Note: if we get here 'A' is a scalar which
1980     // requires a VectorSplat on Arg0
1981     S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
1982         << TheCall->getDirectCallee() << /*useAllTerminology*/ true
1983         << SourceRange(A.get()->getBeginLoc(), A.get()->getEndLoc());
1984     return true;
1985   }
1986   return false;
1987 }
1988 
1989 static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
1990   QualType ArgType = Arg->getType();
1991   if (!S->getASTContext().hasSameUnqualifiedType(ArgType, ExpectedType)) {
1992     S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
1993         << ArgType << ExpectedType << 1 << 0 << 0;
1994     return true;
1995   }
1996   return false;
1997 }
1998 
1999 static bool CheckArgTypeIsCorrect(
2000     Sema *S, Expr *Arg, QualType ExpectedType,
2001     llvm::function_ref<bool(clang::QualType PassedType)> Check) {
2002   QualType PassedType = Arg->getType();
2003   if (Check(PassedType)) {
2004     if (auto *VecTyA = PassedType->getAs<VectorType>())
2005       ExpectedType = S->Context.getVectorType(
2006           ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
2007     S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
2008         << PassedType << ExpectedType << 1 << 0 << 0;
2009     return true;
2010   }
2011   return false;
2012 }
2013 
2014 static bool CheckAllArgTypesAreCorrect(
2015     Sema *S, CallExpr *TheCall, QualType ExpectedType,
2016     llvm::function_ref<bool(clang::QualType PassedType)> Check) {
2017   for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
2018     Expr *Arg = TheCall->getArg(i);
2019     if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
2020       return true;
2021     }
2022   }
2023   return false;
2024 }
2025 
2026 static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
2027   auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
2028     return !PassedType->hasFloatingRepresentation();
2029   };
2030   return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
2031                                     checkAllFloatTypes);
2032 }
2033 
2034 static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
2035   auto checkFloatorHalf = [](clang::QualType PassedType) -> bool {
2036     clang::QualType BaseType =
2037         PassedType->isVectorType()
2038             ? PassedType->getAs<clang::VectorType>()->getElementType()
2039             : PassedType;
2040     return !BaseType->isHalfType() && !BaseType->isFloat32Type();
2041   };
2042   return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
2043                                     checkFloatorHalf);
2044 }
2045 
2046 static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
2047                                   unsigned ArgIndex) {
2048   auto *Arg = TheCall->getArg(ArgIndex);
2049   SourceLocation OrigLoc = Arg->getExprLoc();
2050   if (Arg->IgnoreCasts()->isModifiableLvalue(S->Context, &OrigLoc) ==
2051       Expr::MLV_Valid)
2052     return false;
2053   S->Diag(OrigLoc, diag::error_hlsl_inout_lvalue) << Arg << 0;
2054   return true;
2055 }
2056 
2057 static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
2058   auto checkDoubleVector = [](clang::QualType PassedType) -> bool {
2059     if (const auto *VecTy = PassedType->getAs<VectorType>())
2060       return VecTy->getElementType()->isDoubleType();
2061     return false;
2062   };
2063   return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
2064                                     checkDoubleVector);
2065 }
2066 static bool CheckFloatingOrIntRepresentation(Sema *S, CallExpr *TheCall) {
2067   auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool {
2068     return !PassedType->hasIntegerRepresentation() &&
2069            !PassedType->hasFloatingRepresentation();
2070   };
2071   return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.IntTy,
2072                                     checkAllSignedTypes);
2073 }
2074 
2075 static bool CheckUnsignedIntRepresentation(Sema *S, CallExpr *TheCall) {
2076   auto checkAllUnsignedTypes = [](clang::QualType PassedType) -> bool {
2077     return !PassedType->hasUnsignedIntegerRepresentation();
2078   };
2079   return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
2080                                     checkAllUnsignedTypes);
2081 }
2082 
2083 static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
2084                                        QualType ReturnType) {
2085   auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>();
2086   if (VecTyA)
2087     ReturnType = S->Context.getVectorType(ReturnType, VecTyA->getNumElements(),
2088                                           VectorKind::Generic);
2089   TheCall->setType(ReturnType);
2090 }
2091 
2092 static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
2093                                 unsigned ArgIndex) {
2094   assert(TheCall->getNumArgs() >= ArgIndex);
2095   QualType ArgType = TheCall->getArg(ArgIndex)->getType();
2096   auto *VTy = ArgType->getAs<VectorType>();
2097   // not the scalar or vector<scalar>
2098   if (!(S->Context.hasSameUnqualifiedType(ArgType, Scalar) ||
2099         (VTy &&
2100          S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar)))) {
2101     S->Diag(TheCall->getArg(0)->getBeginLoc(),
2102             diag::err_typecheck_expect_scalar_or_vector)
2103         << ArgType << Scalar;
2104     return true;
2105   }
2106   return false;
2107 }
2108 
2109 static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
2110                                    unsigned ArgIndex) {
2111   assert(TheCall->getNumArgs() >= ArgIndex);
2112   QualType ArgType = TheCall->getArg(ArgIndex)->getType();
2113   auto *VTy = ArgType->getAs<VectorType>();
2114   // not the scalar or vector<scalar>
2115   if (!(ArgType->isScalarType() ||
2116         (VTy && VTy->getElementType()->isScalarType()))) {
2117     S->Diag(TheCall->getArg(0)->getBeginLoc(),
2118             diag::err_typecheck_expect_any_scalar_or_vector)
2119         << ArgType << 1;
2120     return true;
2121   }
2122   return false;
2123 }
2124 
2125 static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
2126   QualType BoolType = S->getASTContext().BoolTy;
2127   assert(TheCall->getNumArgs() >= 1);
2128   QualType ArgType = TheCall->getArg(0)->getType();
2129   auto *VTy = ArgType->getAs<VectorType>();
2130   // is the bool or vector<bool>
2131   if (S->Context.hasSameUnqualifiedType(ArgType, BoolType) ||
2132       (VTy &&
2133        S->Context.hasSameUnqualifiedType(VTy->getElementType(), BoolType))) {
2134     S->Diag(TheCall->getArg(0)->getBeginLoc(),
2135             diag::err_typecheck_expect_any_scalar_or_vector)
2136         << ArgType << 0;
2137     return true;
2138   }
2139   return false;
2140 }
2141 
2142 static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
2143   assert(TheCall->getNumArgs() == 3);
2144   Expr *Arg1 = TheCall->getArg(1);
2145   Expr *Arg2 = TheCall->getArg(2);
2146   if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
2147     S->Diag(TheCall->getBeginLoc(),
2148             diag::err_typecheck_call_different_arg_types)
2149         << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
2150         << Arg2->getSourceRange();
2151     return true;
2152   }
2153 
2154   TheCall->setType(Arg1->getType());
2155   return false;
2156 }
2157 
2158 static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
2159   assert(TheCall->getNumArgs() == 3);
2160   Expr *Arg1 = TheCall->getArg(1);
2161   Expr *Arg2 = TheCall->getArg(2);
2162   if (!Arg1->getType()->isVectorType()) {
2163     S->Diag(Arg1->getBeginLoc(), diag::err_builtin_non_vector_type)
2164         << "Second" << TheCall->getDirectCallee() << Arg1->getType()
2165         << Arg1->getSourceRange();
2166     return true;
2167   }
2168 
2169   if (!Arg2->getType()->isVectorType()) {
2170     S->Diag(Arg2->getBeginLoc(), diag::err_builtin_non_vector_type)
2171         << "Third" << TheCall->getDirectCallee() << Arg2->getType()
2172         << Arg2->getSourceRange();
2173     return true;
2174   }
2175 
2176   if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
2177     S->Diag(TheCall->getBeginLoc(),
2178             diag::err_typecheck_call_different_arg_types)
2179         << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
2180         << Arg2->getSourceRange();
2181     return true;
2182   }
2183 
2184   // caller has checked that Arg0 is a vector.
2185   // check all three args have the same length.
2186   if (TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() !=
2187       Arg1->getType()->getAs<VectorType>()->getNumElements()) {
2188     S->Diag(TheCall->getBeginLoc(),
2189             diag::err_typecheck_vector_lengths_not_equal)
2190         << TheCall->getArg(0)->getType() << Arg1->getType()
2191         << TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange();
2192     return true;
2193   }
2194   TheCall->setType(Arg1->getType());
2195   return false;
2196 }
2197 
2198 static bool CheckResourceHandle(
2199     Sema *S, CallExpr *TheCall, unsigned ArgIndex,
2200     llvm::function_ref<bool(const HLSLAttributedResourceType *ResType)> Check =
2201         nullptr) {
2202   assert(TheCall->getNumArgs() >= ArgIndex);
2203   QualType ArgType = TheCall->getArg(ArgIndex)->getType();
2204   const HLSLAttributedResourceType *ResTy =
2205       ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>();
2206   if (!ResTy) {
2207     S->Diag(TheCall->getArg(ArgIndex)->getBeginLoc(),
2208             diag::err_typecheck_expect_hlsl_resource)
2209         << ArgType;
2210     return true;
2211   }
2212   if (Check && Check(ResTy)) {
2213     S->Diag(TheCall->getArg(ArgIndex)->getExprLoc(),
2214             diag::err_invalid_hlsl_resource_type)
2215         << ArgType;
2216     return true;
2217   }
2218   return false;
2219 }
2220 
2221 // Note: returning true in this case results in CheckBuiltinFunctionCall
2222 // returning an ExprError
2223 bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
2224   switch (BuiltinID) {
2225   case Builtin::BI__builtin_hlsl_resource_getpointer: {
2226     if (SemaRef.checkArgCount(TheCall, 2) ||
2227         CheckResourceHandle(&SemaRef, TheCall, 0) ||
2228         CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
2229                             SemaRef.getASTContext().UnsignedIntTy))
2230       return true;
2231 
2232     auto *ResourceTy =
2233         TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
2234     QualType ContainedTy = ResourceTy->getContainedType();
2235     // TODO: Map to an hlsl_device address space.
2236     TheCall->setType(getASTContext().getPointerType(ContainedTy));
2237     TheCall->setValueKind(VK_LValue);
2238 
2239     break;
2240   }
2241   case Builtin::BI__builtin_hlsl_all:
2242   case Builtin::BI__builtin_hlsl_any: {
2243     if (SemaRef.checkArgCount(TheCall, 1))
2244       return true;
2245     break;
2246   }
2247   case Builtin::BI__builtin_hlsl_asdouble: {
2248     if (SemaRef.checkArgCount(TheCall, 2))
2249       return true;
2250     if (CheckUnsignedIntRepresentation(&SemaRef, TheCall))
2251       return true;
2252 
2253     SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().DoubleTy);
2254     break;
2255   }
2256   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
2257     if (SemaRef.checkArgCount(TheCall, 3))
2258       return true;
2259     if (CheckVectorElementCallArgs(&SemaRef, TheCall))
2260       return true;
2261     if (SemaRef.BuiltinElementwiseTernaryMath(
2262             TheCall, /*CheckForFloatArgs*/
2263             TheCall->getArg(0)->getType()->hasFloatingRepresentation()))
2264       return true;
2265     break;
2266   }
2267   case Builtin::BI__builtin_hlsl_cross: {
2268     if (SemaRef.checkArgCount(TheCall, 2))
2269       return true;
2270     if (CheckVectorElementCallArgs(&SemaRef, TheCall))
2271       return true;
2272     if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
2273       return true;
2274     // ensure both args have 3 elements
2275     int NumElementsArg1 =
2276         TheCall->getArg(0)->getType()->castAs<VectorType>()->getNumElements();
2277     int NumElementsArg2 =
2278         TheCall->getArg(1)->getType()->castAs<VectorType>()->getNumElements();
2279 
2280     if (NumElementsArg1 != 3) {
2281       int LessOrMore = NumElementsArg1 > 3 ? 1 : 0;
2282       SemaRef.Diag(TheCall->getBeginLoc(),
2283                    diag::err_vector_incorrect_num_elements)
2284           << LessOrMore << 3 << NumElementsArg1 << /*operand*/ 1;
2285       return true;
2286     }
2287     if (NumElementsArg2 != 3) {
2288       int LessOrMore = NumElementsArg2 > 3 ? 1 : 0;
2289 
2290       SemaRef.Diag(TheCall->getBeginLoc(),
2291                    diag::err_vector_incorrect_num_elements)
2292           << LessOrMore << 3 << NumElementsArg2 << /*operand*/ 1;
2293       return true;
2294     }
2295 
2296     ExprResult A = TheCall->getArg(0);
2297     QualType ArgTyA = A.get()->getType();
2298     // return type is the same as the input type
2299     TheCall->setType(ArgTyA);
2300     break;
2301   }
2302   case Builtin::BI__builtin_hlsl_dot: {
2303     if (SemaRef.checkArgCount(TheCall, 2))
2304       return true;
2305     if (CheckVectorElementCallArgs(&SemaRef, TheCall))
2306       return true;
2307     if (SemaRef.BuiltinVectorToScalarMath(TheCall))
2308       return true;
2309     if (CheckNoDoubleVectors(&SemaRef, TheCall))
2310       return true;
2311     break;
2312   }
2313   case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
2314   case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {
2315     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
2316       return true;
2317 
2318     const Expr *Arg = TheCall->getArg(0);
2319     QualType ArgTy = Arg->getType();
2320     QualType EltTy = ArgTy;
2321 
2322     QualType ResTy = SemaRef.Context.UnsignedIntTy;
2323 
2324     if (auto *VecTy = EltTy->getAs<VectorType>()) {
2325       EltTy = VecTy->getElementType();
2326       ResTy = SemaRef.Context.getVectorType(ResTy, VecTy->getNumElements(),
2327                                             VecTy->getVectorKind());
2328     }
2329 
2330     if (!EltTy->isIntegerType()) {
2331       Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
2332           << 1 << /* integer ty */ 6 << ArgTy;
2333       return true;
2334     }
2335 
2336     TheCall->setType(ResTy);
2337     break;
2338   }
2339   case Builtin::BI__builtin_hlsl_select: {
2340     if (SemaRef.checkArgCount(TheCall, 3))
2341       return true;
2342     if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
2343       return true;
2344     QualType ArgTy = TheCall->getArg(0)->getType();
2345     if (ArgTy->isBooleanType() && CheckBoolSelect(&SemaRef, TheCall))
2346       return true;
2347     auto *VTy = ArgTy->getAs<VectorType>();
2348     if (VTy && VTy->getElementType()->isBooleanType() &&
2349         CheckVectorSelect(&SemaRef, TheCall))
2350       return true;
2351     break;
2352   }
2353   case Builtin::BI__builtin_hlsl_elementwise_saturate:
2354   case Builtin::BI__builtin_hlsl_elementwise_rcp: {
2355     if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall))
2356       return true;
2357     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
2358       return true;
2359     break;
2360   }
2361   case Builtin::BI__builtin_hlsl_elementwise_degrees:
2362   case Builtin::BI__builtin_hlsl_elementwise_radians:
2363   case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
2364   case Builtin::BI__builtin_hlsl_elementwise_frac: {
2365     if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
2366       return true;
2367     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
2368       return true;
2369     break;
2370   }
2371   case Builtin::BI__builtin_hlsl_elementwise_isinf: {
2372     if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
2373       return true;
2374     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
2375       return true;
2376     SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().BoolTy);
2377     break;
2378   }
2379   case Builtin::BI__builtin_hlsl_lerp: {
2380     if (SemaRef.checkArgCount(TheCall, 3))
2381       return true;
2382     if (CheckVectorElementCallArgs(&SemaRef, TheCall))
2383       return true;
2384     if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
2385       return true;
2386     if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
2387       return true;
2388     break;
2389   }
2390   case Builtin::BI__builtin_hlsl_mad: {
2391     if (SemaRef.checkArgCount(TheCall, 3))
2392       return true;
2393     if (CheckVectorElementCallArgs(&SemaRef, TheCall))
2394       return true;
2395     if (SemaRef.BuiltinElementwiseTernaryMath(
2396             TheCall, /*CheckForFloatArgs*/
2397             TheCall->getArg(0)->getType()->hasFloatingRepresentation()))
2398       return true;
2399     break;
2400   }
2401   case Builtin::BI__builtin_hlsl_normalize: {
2402     if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
2403       return true;
2404     if (SemaRef.checkArgCount(TheCall, 1))
2405       return true;
2406 
2407     ExprResult A = TheCall->getArg(0);
2408     QualType ArgTyA = A.get()->getType();
2409     // return type is the same as the input type
2410     TheCall->setType(ArgTyA);
2411     break;
2412   }
2413   case Builtin::BI__builtin_hlsl_elementwise_sign: {
2414     if (CheckFloatingOrIntRepresentation(&SemaRef, TheCall))
2415       return true;
2416     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
2417       return true;
2418     SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().IntTy);
2419     break;
2420   }
2421   case Builtin::BI__builtin_hlsl_step: {
2422     if (SemaRef.checkArgCount(TheCall, 2))
2423       return true;
2424     if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
2425       return true;
2426 
2427     ExprResult A = TheCall->getArg(0);
2428     QualType ArgTyA = A.get()->getType();
2429     // return type is the same as the input type
2430     TheCall->setType(ArgTyA);
2431     break;
2432   }
2433   case Builtin::BI__builtin_hlsl_wave_active_max:
2434   case Builtin::BI__builtin_hlsl_wave_active_sum: {
2435     if (SemaRef.checkArgCount(TheCall, 1))
2436       return true;
2437 
2438     // Ensure input expr type is a scalar/vector and the same as the return type
2439     if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
2440       return true;
2441     if (CheckWaveActive(&SemaRef, TheCall))
2442       return true;
2443     ExprResult Expr = TheCall->getArg(0);
2444     QualType ArgTyExpr = Expr.get()->getType();
2445     TheCall->setType(ArgTyExpr);
2446     break;
2447   }
  // Note: these are LLVM builtins for which we want to catch invalid intrinsic
  // generation. Normal handling of these builtins will occur elsewhere.
2450   case Builtin::BI__builtin_elementwise_bitreverse: {
2451     if (CheckUnsignedIntRepresentation(&SemaRef, TheCall))
2452       return true;
2453     break;
2454   }
2455   case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
2456     if (SemaRef.checkArgCount(TheCall, 2))
2457       return true;
2458 
2459     // Ensure index parameter type can be interpreted as a uint
2460     ExprResult Index = TheCall->getArg(1);
2461     QualType ArgTyIndex = Index.get()->getType();
2462     if (!ArgTyIndex->isIntegerType()) {
2463       SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
2464                    diag::err_typecheck_convert_incompatible)
2465           << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
2466       return true;
2467     }
2468 
2469     // Ensure input expr type is a scalar/vector and the same as the return type
2470     if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
2471       return true;
2472 
2473     ExprResult Expr = TheCall->getArg(0);
2474     QualType ArgTyExpr = Expr.get()->getType();
2475     TheCall->setType(ArgTyExpr);
2476     break;
2477   }
2478   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
2479     if (SemaRef.checkArgCount(TheCall, 0))
2480       return true;
2481     break;
2482   }
2483   case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
2484     if (SemaRef.checkArgCount(TheCall, 3))
2485       return true;
2486 
2487     if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.DoubleTy, 0) ||
2488         CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
2489                             1) ||
2490         CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
2491                             2))
2492       return true;
2493 
2494     if (CheckModifiableLValue(&SemaRef, TheCall, 1) ||
2495         CheckModifiableLValue(&SemaRef, TheCall, 2))
2496       return true;
2497     break;
2498   }
2499   case Builtin::BI__builtin_hlsl_elementwise_clip: {
2500     if (SemaRef.checkArgCount(TheCall, 1))
2501       return true;
2502 
2503     if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.FloatTy, 0))
2504       return true;
2505     break;
2506   }
2507   case Builtin::BI__builtin_elementwise_acos:
2508   case Builtin::BI__builtin_elementwise_asin:
2509   case Builtin::BI__builtin_elementwise_atan:
2510   case Builtin::BI__builtin_elementwise_atan2:
2511   case Builtin::BI__builtin_elementwise_ceil:
2512   case Builtin::BI__builtin_elementwise_cos:
2513   case Builtin::BI__builtin_elementwise_cosh:
2514   case Builtin::BI__builtin_elementwise_exp:
2515   case Builtin::BI__builtin_elementwise_exp2:
2516   case Builtin::BI__builtin_elementwise_floor:
2517   case Builtin::BI__builtin_elementwise_fmod:
2518   case Builtin::BI__builtin_elementwise_log:
2519   case Builtin::BI__builtin_elementwise_log2:
2520   case Builtin::BI__builtin_elementwise_log10:
2521   case Builtin::BI__builtin_elementwise_pow:
2522   case Builtin::BI__builtin_elementwise_roundeven:
2523   case Builtin::BI__builtin_elementwise_sin:
2524   case Builtin::BI__builtin_elementwise_sinh:
2525   case Builtin::BI__builtin_elementwise_sqrt:
2526   case Builtin::BI__builtin_elementwise_tan:
2527   case Builtin::BI__builtin_elementwise_tanh:
2528   case Builtin::BI__builtin_elementwise_trunc: {
2529     if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
2530       return true;
2531     break;
2532   }
2533   case Builtin::BI__builtin_hlsl_buffer_update_counter: {
2534     auto checkResTy = [](const HLSLAttributedResourceType *ResTy) -> bool {
2535       return !(ResTy->getAttrs().ResourceClass == ResourceClass::UAV &&
2536                ResTy->getAttrs().RawBuffer && ResTy->hasContainedType());
2537     };
2538     if (SemaRef.checkArgCount(TheCall, 2) ||
2539         CheckResourceHandle(&SemaRef, TheCall, 0, checkResTy) ||
2540         CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
2541                             SemaRef.getASTContext().IntTy))
2542       return true;
2543     Expr *OffsetExpr = TheCall->getArg(1);
2544     std::optional<llvm::APSInt> Offset =
2545         OffsetExpr->getIntegerConstantExpr(SemaRef.getASTContext());
2546     if (!Offset.has_value() || std::abs(Offset->getExtValue()) != 1) {
2547       SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
2548                    diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
2549           << 1;
2550       return true;
2551     }
2552     break;
2553   }
2554   }
2555   return false;
2556 }
2557 
2558 static void BuildFlattenedTypeList(QualType BaseTy,
2559                                    llvm::SmallVectorImpl<QualType> &List) {
2560   llvm::SmallVector<QualType, 16> WorkList;
2561   WorkList.push_back(BaseTy);
2562   while (!WorkList.empty()) {
2563     QualType T = WorkList.pop_back_val();
2564     T = T.getCanonicalType().getUnqualifiedType();
2565     assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");
2566     if (const auto *AT = dyn_cast<ConstantArrayType>(T)) {
2567       llvm::SmallVector<QualType, 16> ElementFields;
2568       // Generally I've avoided recursion in this algorithm, but arrays of
2569       // structs could be time-consuming to flatten and churn through on the
2570       // work list. Hopefully nesting arrays of structs containing arrays
2571       // of structs too many levels deep is unlikely.
2572       BuildFlattenedTypeList(AT->getElementType(), ElementFields);
2573       // Repeat the element's field list n times.
2574       for (uint64_t Ct = 0; Ct < AT->getZExtSize(); ++Ct)
2575         List.insert(List.end(), ElementFields.begin(), ElementFields.end());
2576       continue;
2577     }
2578     // Vectors can only have element types that are builtin types, so this can
2579     // add directly to the list instead of to the WorkList.
2580     if (const auto *VT = dyn_cast<VectorType>(T)) {
2581       List.insert(List.end(), VT->getNumElements(), VT->getElementType());
2582       continue;
2583     }
2584     if (const auto *RT = dyn_cast<RecordType>(T)) {
2585       const RecordDecl *RD = RT->getDecl();
2586       if (RD->isUnion()) {
2587         List.push_back(T);
2588         continue;
2589       }
2590       const CXXRecordDecl *CXXD = dyn_cast<CXXRecordDecl>(RD);
2591 
2592       llvm::SmallVector<QualType, 16> FieldTypes;
2593       if (CXXD && CXXD->isStandardLayout())
2594         RD = CXXD->getStandardLayoutBaseWithFields();
2595 
2596       for (const auto *FD : RD->fields())
2597         FieldTypes.push_back(FD->getType());
2598       // Reverse the newly added sub-range.
2599       std::reverse(FieldTypes.begin(), FieldTypes.end());
2600       WorkList.insert(WorkList.end(), FieldTypes.begin(), FieldTypes.end());
2601 
2602       // If this wasn't a standard layout type we may also have some base
2603       // classes to deal with.
2604       if (CXXD && !CXXD->isStandardLayout()) {
2605         FieldTypes.clear();
2606         for (const auto &Base : CXXD->bases())
2607           FieldTypes.push_back(Base.getType());
2608         std::reverse(FieldTypes.begin(), FieldTypes.end());
2609         WorkList.insert(WorkList.end(), FieldTypes.begin(), FieldTypes.end());
2610       }
2611       continue;
2612     }
2613     List.push_back(T);
2614   }
2615 }
2616 
2617 bool SemaHLSL::IsTypedResourceElementCompatible(clang::QualType QT) {
2618   // null and array types are not allowed.
2619   if (QT.isNull() || QT->isArrayType())
2620     return false;
2621 
2622   // UDT types are not allowed
2623   if (QT->isRecordType())
2624     return false;
2625 
2626   if (QT->isBooleanType() || QT->isEnumeralType())
2627     return false;
2628 
2629   // the only other valid builtin types are scalars or vectors
2630   if (QT->isArithmeticType()) {
2631     if (SemaRef.Context.getTypeSize(QT) / 8 > 16)
2632       return false;
2633     return true;
2634   }
2635 
2636   if (const VectorType *VT = QT->getAs<VectorType>()) {
2637     int ArraySize = VT->getNumElements();
2638 
2639     if (ArraySize > 4)
2640       return false;
2641 
2642     QualType ElTy = VT->getElementType();
2643     if (ElTy->isBooleanType())
2644       return false;
2645 
2646     if (SemaRef.Context.getTypeSize(QT) / 8 > 16)
2647       return false;
2648     return true;
2649   }
2650 
2651   return false;
2652 }
2653 
2654 bool SemaHLSL::IsScalarizedLayoutCompatible(QualType T1, QualType T2) const {
2655   if (T1.isNull() || T2.isNull())
2656     return false;
2657 
2658   T1 = T1.getCanonicalType().getUnqualifiedType();
2659   T2 = T2.getCanonicalType().getUnqualifiedType();
2660 
2661   // If both types are the same canonical type, they're obviously compatible.
2662   if (SemaRef.getASTContext().hasSameType(T1, T2))
2663     return true;
2664 
2665   llvm::SmallVector<QualType, 16> T1Types;
2666   BuildFlattenedTypeList(T1, T1Types);
2667   llvm::SmallVector<QualType, 16> T2Types;
2668   BuildFlattenedTypeList(T2, T2Types);
2669 
2670   // Check the flattened type list
2671   return llvm::equal(T1Types, T2Types,
2672                      [this](QualType LHS, QualType RHS) -> bool {
2673                        return SemaRef.IsLayoutCompatible(LHS, RHS);
2674                      });
2675 }
2676 
2677 bool SemaHLSL::CheckCompatibleParameterABI(FunctionDecl *New,
2678                                            FunctionDecl *Old) {
2679   if (New->getNumParams() != Old->getNumParams())
2680     return true;
2681 
2682   bool HadError = false;
2683 
2684   for (unsigned i = 0, e = New->getNumParams(); i != e; ++i) {
2685     ParmVarDecl *NewParam = New->getParamDecl(i);
2686     ParmVarDecl *OldParam = Old->getParamDecl(i);
2687 
2688     // HLSL parameter declarations for inout and out must match between
2689     // declarations. In HLSL inout and out are ambiguous at the call site,
2690     // but have different calling behavior, so you cannot overload a
2691     // method based on a difference between inout and out annotations.
2692     const auto *NDAttr = NewParam->getAttr<HLSLParamModifierAttr>();
2693     unsigned NSpellingIdx = (NDAttr ? NDAttr->getSpellingListIndex() : 0);
2694     const auto *ODAttr = OldParam->getAttr<HLSLParamModifierAttr>();
2695     unsigned OSpellingIdx = (ODAttr ? ODAttr->getSpellingListIndex() : 0);
2696 
2697     if (NSpellingIdx != OSpellingIdx) {
2698       SemaRef.Diag(NewParam->getLocation(),
2699                    diag::err_hlsl_param_qualifier_mismatch)
2700           << NDAttr << NewParam;
2701       SemaRef.Diag(OldParam->getLocation(), diag::note_previous_declaration_as)
2702           << ODAttr;
2703       HadError = true;
2704     }
2705   }
2706   return HadError;
2707 }
2708 
2709 ExprResult SemaHLSL::ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg) {
2710   assert(Param->hasAttr<HLSLParamModifierAttr>() &&
2711          "We should not get here without a parameter modifier expression");
2712   const auto *Attr = Param->getAttr<HLSLParamModifierAttr>();
2713   if (Attr->getABI() == ParameterABI::Ordinary)
2714     return ExprResult(Arg);
2715 
2716   bool IsInOut = Attr->getABI() == ParameterABI::HLSLInOut;
2717   if (!Arg->isLValue()) {
2718     SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_lvalue)
2719         << Arg << (IsInOut ? 1 : 0);
2720     return ExprError();
2721   }
2722 
2723   ASTContext &Ctx = SemaRef.getASTContext();
2724 
2725   QualType Ty = Param->getType().getNonLValueExprType(Ctx);
2726 
2727   // HLSL allows implicit conversions from scalars to vectors, but not the
2728   // inverse, so we need to disallow `inout` with scalar->vector or
2729   // scalar->matrix conversions.
2730   if (Arg->getType()->isScalarType() != Ty->isScalarType()) {
2731     SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_scalar_extension)
2732         << Arg << (IsInOut ? 1 : 0);
2733     return ExprError();
2734   }
2735 
2736   auto *ArgOpV = new (Ctx) OpaqueValueExpr(Param->getBeginLoc(), Arg->getType(),
2737                                            VK_LValue, OK_Ordinary, Arg);
2738 
2739   // Parameters are initialized via copy initialization. This allows for
2740   // overload resolution of argument constructors.
2741   InitializedEntity Entity =
2742       InitializedEntity::InitializeParameter(Ctx, Ty, false);
2743   ExprResult Res =
2744       SemaRef.PerformCopyInitialization(Entity, Param->getBeginLoc(), ArgOpV);
2745   if (Res.isInvalid())
2746     return ExprError();
2747   Expr *Base = Res.get();
2748   // After the cast, drop the reference type when creating the exprs.
2749   Ty = Ty.getNonLValueExprType(Ctx);
2750   auto *OpV = new (Ctx)
2751       OpaqueValueExpr(Param->getBeginLoc(), Ty, VK_LValue, OK_Ordinary, Base);
2752 
2753   // Writebacks are performed with `=` binary operator, which allows for
2754   // overload resolution on writeback result expressions.
2755   Res = SemaRef.ActOnBinOp(SemaRef.getCurScope(), Param->getBeginLoc(),
2756                            tok::equal, ArgOpV, OpV);
2757 
2758   if (Res.isInvalid())
2759     return ExprError();
2760   Expr *Writeback = Res.get();
2761   auto *OutExpr =
2762       HLSLOutArgExpr::Create(Ctx, Ty, ArgOpV, OpV, Writeback, IsInOut);
2763 
2764   return ExprResult(OutExpr);
2765 }
2766 
2767 QualType SemaHLSL::getInoutParameterType(QualType Ty) {
2768   // If HLSL gains support for references, all the cites that use this will need
2769   // to be updated with semantic checking to produce errors for
2770   // pointers/references.
2771   assert(!Ty->isReferenceType() &&
2772          "Pointer and reference types cannot be inout or out parameters");
2773   Ty = SemaRef.getASTContext().getLValueReferenceType(Ty);
2774   Ty.addRestrict();
2775   return Ty;
2776 }
2777 
2778 void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
2779   if (VD->hasGlobalStorage()) {
2780     // make sure the declaration has a complete type
2781     if (SemaRef.RequireCompleteType(
2782             VD->getLocation(),
2783             SemaRef.getASTContext().getBaseElementType(VD->getType()),
2784             diag::err_typecheck_decl_incomplete_type)) {
2785       VD->setInvalidDecl();
2786       return;
2787     }
2788 
2789     // find all resources on decl
2790     if (VD->getType()->isHLSLIntangibleType())
2791       collectResourcesOnVarDecl(VD);
2792 
2793     // process explicit bindings
2794     processExplicitBindingsOnDecl(VD);
2795   }
2796 }
2797 
2798 // Walks though the global variable declaration, collects all resource binding
2799 // requirements and adds them to Bindings
2800 void SemaHLSL::collectResourcesOnVarDecl(VarDecl *VD) {
2801   assert(VD->hasGlobalStorage() && VD->getType()->isHLSLIntangibleType() &&
2802          "expected global variable that contains HLSL resource");
2803 
2804   // Cbuffers and Tbuffers are HLSLBufferDecl types
2805   if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) {
2806     Bindings.addDeclBindingInfo(VD, CBufferOrTBuffer->isCBuffer()
2807                                         ? ResourceClass::CBuffer
2808                                         : ResourceClass::SRV);
2809     return;
2810   }
2811 
2812   // Unwrap arrays
2813   // FIXME: Calculate array size while unwrapping
2814   const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
2815   while (Ty->isConstantArrayType()) {
2816     const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
2817     Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
2818   }
2819 
2820   // Resource (or array of resources)
2821   if (const HLSLAttributedResourceType *AttrResType =
2822           HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
2823     Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass);
2824     return;
2825   }
2826 
2827   // User defined record type
2828   if (const RecordType *RT = dyn_cast<RecordType>(Ty))
2829     collectResourcesOnUserRecordDecl(VD, RT);
2830 }
2831 
2832 // Walks though the explicit resource binding attributes on the declaration,
2833 // and makes sure there is a resource that matched the binding and updates
2834 // DeclBindingInfoLists
2835 void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
2836   assert(VD->hasGlobalStorage() && "expected global variable");
2837 
2838   for (Attr *A : VD->attrs()) {
2839     HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
2840     if (!RBA)
2841       continue;
2842 
2843     RegisterType RT = RBA->getRegisterType();
2844     assert(RT != RegisterType::I && "invalid or obsolete register type should "
2845                                     "never have an attribute created");
2846 
2847     if (RT == RegisterType::C) {
2848       if (Bindings.hasBindingInfoForDecl(VD))
2849         SemaRef.Diag(VD->getLocation(),
2850                      diag::warn_hlsl_user_defined_type_missing_member)
2851             << static_cast<int>(RT);
2852       continue;
2853     }
2854 
2855     // Find DeclBindingInfo for this binding and update it, or report error
2856     // if it does not exist (user type does to contain resources with the
2857     // expected resource class).
2858     ResourceClass RC = getResourceClass(RT);
2859     if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) {
2860       // update binding info
2861       BI->setBindingAttribute(RBA, BindingType::Explicit);
2862     } else {
2863       SemaRef.Diag(VD->getLocation(),
2864                    diag::warn_hlsl_user_defined_type_missing_member)
2865           << static_cast<int>(RT);
2866     }
2867   }
2868 }
2869