1 //===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // This implements Semantic Analysis for HLSL constructs. 9 //===----------------------------------------------------------------------===// 10 11 #include "clang/Sema/SemaHLSL.h" 12 #include "clang/AST/ASTContext.h" 13 #include "clang/AST/Attr.h" 14 #include "clang/AST/Attrs.inc" 15 #include "clang/AST/Decl.h" 16 #include "clang/AST/DeclBase.h" 17 #include "clang/AST/DeclCXX.h" 18 #include "clang/AST/DeclarationName.h" 19 #include "clang/AST/DynamicRecursiveASTVisitor.h" 20 #include "clang/AST/Expr.h" 21 #include "clang/AST/Type.h" 22 #include "clang/AST/TypeLoc.h" 23 #include "clang/Basic/Builtins.h" 24 #include "clang/Basic/DiagnosticSema.h" 25 #include "clang/Basic/IdentifierTable.h" 26 #include "clang/Basic/LLVM.h" 27 #include "clang/Basic/SourceLocation.h" 28 #include "clang/Basic/Specifiers.h" 29 #include "clang/Basic/TargetInfo.h" 30 #include "clang/Sema/Initialization.h" 31 #include "clang/Sema/ParsedAttr.h" 32 #include "clang/Sema/Sema.h" 33 #include "clang/Sema/Template.h" 34 #include "llvm/ADT/STLExtras.h" 35 #include "llvm/ADT/SmallVector.h" 36 #include "llvm/ADT/StringExtras.h" 37 #include "llvm/ADT/StringRef.h" 38 #include "llvm/ADT/Twine.h" 39 #include "llvm/Support/Casting.h" 40 #include "llvm/Support/DXILABI.h" 41 #include "llvm/Support/ErrorHandling.h" 42 #include "llvm/TargetParser/Triple.h" 43 #include <cstddef> 44 #include <iterator> 45 #include <utility> 46 47 using namespace clang; 48 using RegisterType = HLSLResourceBindingAttr::RegisterType; 49 50 static CXXRecordDecl *createHostLayoutStruct(Sema &S, 51 CXXRecordDecl *StructDecl); 52 53 static RegisterType getRegisterType(ResourceClass RC) { 54 switch (RC) { 55 case ResourceClass::SRV: 56 return RegisterType::SRV; 57 case ResourceClass::UAV: 58 return RegisterType::UAV; 59 case ResourceClass::CBuffer: 60 return RegisterType::CBuffer; 61 case ResourceClass::Sampler: 62 return RegisterType::Sampler; 63 } 64 llvm_unreachable("unexpected ResourceClass value"); 65 } 66 67 // Converts the first letter of string Slot to RegisterType. 68 // Returns false if the letter does not correspond to a valid register type. 69 static bool convertToRegisterType(StringRef Slot, RegisterType *RT) { 70 assert(RT != nullptr); 71 switch (Slot[0]) { 72 case 't': 73 case 'T': 74 *RT = RegisterType::SRV; 75 return true; 76 case 'u': 77 case 'U': 78 *RT = RegisterType::UAV; 79 return true; 80 case 'b': 81 case 'B': 82 *RT = RegisterType::CBuffer; 83 return true; 84 case 's': 85 case 'S': 86 *RT = RegisterType::Sampler; 87 return true; 88 case 'c': 89 case 'C': 90 *RT = RegisterType::C; 91 return true; 92 case 'i': 93 case 'I': 94 *RT = RegisterType::I; 95 return true; 96 default: 97 return false; 98 } 99 } 100 101 static ResourceClass getResourceClass(RegisterType RT) { 102 switch (RT) { 103 case RegisterType::SRV: 104 return ResourceClass::SRV; 105 case RegisterType::UAV: 106 return ResourceClass::UAV; 107 case RegisterType::CBuffer: 108 return ResourceClass::CBuffer; 109 case RegisterType::Sampler: 110 return ResourceClass::Sampler; 111 case RegisterType::C: 112 case RegisterType::I: 113 // Deliberately falling through to the unreachable below. 
114 break; 115 } 116 llvm_unreachable("unexpected RegisterType value"); 117 } 118 119 DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD, 120 ResourceClass ResClass) { 121 assert(getDeclBindingInfo(VD, ResClass) == nullptr && 122 "DeclBindingInfo already added"); 123 assert(!hasBindingInfoForDecl(VD) || BindingsList.back().Decl == VD); 124 // VarDecl may have multiple entries for different resource classes. 125 // DeclToBindingListIndex stores the index of the first binding we saw 126 // for this decl. If there are any additional ones then that index 127 // shouldn't be updated. 128 DeclToBindingListIndex.try_emplace(VD, BindingsList.size()); 129 return &BindingsList.emplace_back(VD, ResClass); 130 } 131 132 DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD, 133 ResourceClass ResClass) { 134 auto Entry = DeclToBindingListIndex.find(VD); 135 if (Entry != DeclToBindingListIndex.end()) { 136 for (unsigned Index = Entry->getSecond(); 137 Index < BindingsList.size() && BindingsList[Index].Decl == VD; 138 ++Index) { 139 if (BindingsList[Index].ResClass == ResClass) 140 return &BindingsList[Index]; 141 } 142 } 143 return nullptr; 144 } 145 146 bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) const { 147 return DeclToBindingListIndex.contains(VD); 148 } 149 150 SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {} 151 152 Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer, 153 SourceLocation KwLoc, IdentifierInfo *Ident, 154 SourceLocation IdentLoc, 155 SourceLocation LBrace) { 156 // For anonymous namespace, take the location of the left brace. 157 DeclContext *LexicalParent = SemaRef.getCurLexicalContext(); 158 HLSLBufferDecl *Result = HLSLBufferDecl::Create( 159 getASTContext(), LexicalParent, CBuffer, KwLoc, Ident, IdentLoc, LBrace); 160 161 // if CBuffer is false, then it's a TBuffer 162 auto RC = CBuffer ? llvm::hlsl::ResourceClass::CBuffer 163 : llvm::hlsl::ResourceClass::SRV; 164 auto RK = CBuffer ? llvm::hlsl::ResourceKind::CBuffer 165 : llvm::hlsl::ResourceKind::TBuffer; 166 Result->addAttr(HLSLResourceClassAttr::CreateImplicit(getASTContext(), RC)); 167 Result->addAttr(HLSLResourceAttr::CreateImplicit(getASTContext(), RK)); 168 169 SemaRef.PushOnScopeChains(Result, BufferScope); 170 SemaRef.PushDeclContext(BufferScope, Result); 171 172 return Result; 173 } 174 175 // Calculate the size of a legacy cbuffer type in bytes based on 176 // https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules 177 static unsigned calculateLegacyCbufferSize(const ASTContext &Context, 178 QualType T) { 179 unsigned Size = 0; 180 constexpr unsigned CBufferAlign = 16; 181 if (const RecordType *RT = T->getAs<RecordType>()) { 182 const RecordDecl *RD = RT->getDecl(); 183 for (const FieldDecl *Field : RD->fields()) { 184 QualType Ty = Field->getType(); 185 unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty); 186 // FIXME: This is not the correct alignment, it does not work for 16-bit 187 // types. See llvm/llvm-project#119641. 
      unsigned FieldAlign = 4;
      if (Ty->isAggregateType())
        FieldAlign = CBufferAlign;
      Size = llvm::alignTo(Size, FieldAlign);
      Size += FieldSize;
    }
  } else if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
    if (unsigned ElementCount = AT->getSize().getZExtValue()) {
      unsigned ElementSize =
          calculateLegacyCbufferSize(Context, AT->getElementType());
      unsigned AlignedElementSize = llvm::alignTo(ElementSize, CBufferAlign);
      Size = AlignedElementSize * (ElementCount - 1) + ElementSize;
    }
  } else if (const VectorType *VT = T->getAs<VectorType>()) {
    unsigned ElementCount = VT->getNumElements();
    unsigned ElementSize =
        calculateLegacyCbufferSize(Context, VT->getElementType());
    Size = ElementSize * ElementCount;
  } else {
    Size = Context.getTypeSize(T) / 8;
  }
  return Size;
}

// Validate packoffset:
// - if packoffset is used it must be set on all declarations inside the buffer
// - packoffset ranges must not overlap
static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) {
  llvm::SmallVector<std::pair<VarDecl *, HLSLPackOffsetAttr *>> PackOffsetVec;

  // Make sure the packoffset annotations are either on all declarations
  // or on none.
  bool HasPackOffset = false;
  bool HasNonPackOffset = false;
  for (auto *Field : BufDecl->decls()) {
    VarDecl *Var = dyn_cast<VarDecl>(Field);
    if (!Var)
      continue;
    if (Field->hasAttr<HLSLPackOffsetAttr>()) {
      PackOffsetVec.emplace_back(Var, Field->getAttr<HLSLPackOffsetAttr>());
      HasPackOffset = true;
    } else {
      HasNonPackOffset = true;
    }
  }

  if (!HasPackOffset)
    return;

  if (HasNonPackOffset)
    S.Diag(BufDecl->getLocation(), diag::warn_hlsl_packoffset_mix);

  // Make sure there is no overlap in packoffset - sort PackOffsetVec by offset
  // and compare adjacent values.
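  // Illustrative HLSL (assumed usage, not taken from this file): the two
  // annotations below overlap and would trigger err_hlsl_packoffset_overlap,
  // because the float4 occupies bytes 0-15 of register c0 while c0.y names
  // byte offset 4:
  //   cbuffer CB {
  //     float4 A : packoffset(c0);
  //     float  B : packoffset(c0.y);
  //   }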
  ASTContext &Context = S.getASTContext();
  std::sort(PackOffsetVec.begin(), PackOffsetVec.end(),
            [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
               const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
              return LHS.second->getOffsetInBytes() <
                     RHS.second->getOffsetInBytes();
            });
  for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
    VarDecl *Var = PackOffsetVec[i].first;
    HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
    unsigned Size = calculateLegacyCbufferSize(Context, Var->getType());
    unsigned Begin = Attr->getOffsetInBytes();
    unsigned End = Begin + Size;
    unsigned NextBegin = PackOffsetVec[i + 1].second->getOffsetInBytes();
    if (End > NextBegin) {
      VarDecl *NextVar = PackOffsetVec[i + 1].first;
      S.Diag(NextVar->getLocation(), diag::err_hlsl_packoffset_overlap)
          << NextVar << Var;
    }
  }
}

// Returns true if the array has a zero size, i.e. if any of its dimensions is 0
static bool isZeroSizedArray(const ConstantArrayType *CAT) {
  while (CAT && !CAT->isZeroSize())
    CAT = dyn_cast<ConstantArrayType>(
        CAT->getElementType()->getUnqualifiedDesugaredType());
  return CAT != nullptr;
}

// Returns true if the record type is an HLSL resource class
static bool isResourceRecordType(const Type *Ty) {
  return HLSLAttributedResourceType::findHandleTypeOnResource(Ty) != nullptr;
}

// Returns true if the type is a leaf element type that is not valid to be
// included in HLSL Buffer, such as a resource class, empty struct, zero-sized
// array, or a builtin intangible type. Returns false if it is a valid leaf
// element type or if it is a record type that needs to be inspected further.
static bool isInvalidConstantBufferLeafElementType(const Type *Ty) {
  if (Ty->isRecordType()) {
    if (isResourceRecordType(Ty) || Ty->getAsCXXRecordDecl()->isEmpty())
      return true;
    return false;
  }
  if (Ty->isConstantArrayType() &&
      isZeroSizedArray(cast<ConstantArrayType>(Ty)))
    return true;
  if (Ty->isHLSLBuiltinIntangibleType())
    return true;
  return false;
}

// Returns true if the struct contains at least one element that prevents it
// from being included inside HLSL Buffer as is, such as an intangible type,
// empty struct, or zero-sized array. If it does, a new implicit layout struct
// needs to be created for HLSL Buffer use that will exclude these unwanted
// declarations (see createHostLayoutStruct function).
static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD) {
  if (RD->getTypeForDecl()->isHLSLIntangibleType() || RD->isEmpty())
    return true;
  // check fields
  for (const FieldDecl *Field : RD->fields()) {
    QualType Ty = Field->getType();
    if (isInvalidConstantBufferLeafElementType(Ty.getTypePtr()))
      return true;
    if (Ty->isRecordType() &&
        requiresImplicitBufferLayoutStructure(Ty->getAsCXXRecordDecl()))
      return true;
  }
  // check bases
  for (const CXXBaseSpecifier &Base : RD->bases())
    if (requiresImplicitBufferLayoutStructure(
            Base.getType()->getAsCXXRecordDecl()))
      return true;
  return false;
}

static CXXRecordDecl *findRecordDeclInContext(IdentifierInfo *II,
                                              DeclContext *DC) {
  CXXRecordDecl *RD = nullptr;
  for (NamedDecl *Decl :
       DC->getNonTransparentContext()->lookup(DeclarationName(II))) {
    if (CXXRecordDecl *FoundRD = dyn_cast<CXXRecordDecl>(Decl)) {
      assert(RD == nullptr &&
             "there should be at most 1 record by a given name in a scope");
      RD = FoundRD;
    }
  }
  return RD;
}

// Creates a name for buffer layout struct using the provided name base.
// If the name must be unique (not previously defined), a suffix is added
// until a unique name is found.
static IdentifierInfo *getHostLayoutStructName(Sema &S, NamedDecl *BaseDecl,
                                               bool MustBeUnique) {
  ASTContext &AST = S.getASTContext();

  IdentifierInfo *NameBaseII = BaseDecl->getIdentifier();
  llvm::SmallString<64> Name("__layout_");
  if (NameBaseII) {
    Name.append(NameBaseII->getName());
  } else {
    // anonymous struct
    Name.append("anon");
    MustBeUnique = true;
  }

  size_t NameLength = Name.size();
  IdentifierInfo *II = &AST.Idents.get(Name, tok::TokenKind::identifier);
  if (!MustBeUnique)
    return II;

  unsigned suffix = 0;
  while (true) {
    if (suffix != 0) {
      Name.append("_");
      Name.append(llvm::Twine(suffix).str());
      II = &AST.Idents.get(Name, tok::TokenKind::identifier);
    }
    if (!findRecordDeclInContext(II, BaseDecl->getDeclContext()))
      return II;
    // declaration with that name already exists - increment suffix and try
    // again until unique name is found
    suffix++;
    Name.truncate(NameLength);
  };
}

// Creates a field declaration of given name and type for HLSL buffer layout
// struct. Returns nullptr if the type cannot be used in HLSL Buffer layout.
static FieldDecl *createFieldForHostLayoutStruct(Sema &S, const Type *Ty,
                                                 IdentifierInfo *II,
                                                 CXXRecordDecl *LayoutStruct) {
  if (isInvalidConstantBufferLeafElementType(Ty))
    return nullptr;

  if (Ty->isRecordType()) {
    CXXRecordDecl *RD = Ty->getAsCXXRecordDecl();
    if (requiresImplicitBufferLayoutStructure(RD)) {
      RD = createHostLayoutStruct(S, RD);
      if (!RD)
        return nullptr;
      Ty = RD->getTypeForDecl();
    }
  }

  QualType QT = QualType(Ty, 0);
  ASTContext &AST = S.getASTContext();
  TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo(QT, SourceLocation());
  auto *Field = FieldDecl::Create(AST, LayoutStruct, SourceLocation(),
                                  SourceLocation(), II, QT, TSI, nullptr, false,
                                  InClassInitStyle::ICIS_NoInit);
  Field->setAccess(AccessSpecifier::AS_private);
  return Field;
}

// Creates host layout struct for a struct included in HLSL Buffer.
// The layout struct will include only fields that are allowed in HLSL buffer.
// These fields will be filtered out:
// - resource classes
// - empty structs
// - zero-sized arrays
// Returns nullptr if the resulting layout struct would be empty.
static CXXRecordDecl *createHostLayoutStruct(Sema &S,
                                             CXXRecordDecl *StructDecl) {
  assert(requiresImplicitBufferLayoutStructure(StructDecl) &&
         "struct is already HLSL buffer compatible");

  ASTContext &AST = S.getASTContext();
  DeclContext *DC = StructDecl->getDeclContext();
  IdentifierInfo *II = getHostLayoutStructName(S, StructDecl, false);

  // reuse the existing layout struct if it already exists
  if (CXXRecordDecl *RD = findRecordDeclInContext(II, DC))
    return RD;

  CXXRecordDecl *LS = CXXRecordDecl::Create(
      AST, TagDecl::TagKind::Class, DC, SourceLocation(), SourceLocation(), II);
  LS->setImplicit(true);
  LS->startDefinition();

  // copy base struct, create HLSL Buffer compatible version if needed
  if (unsigned NumBases = StructDecl->getNumBases()) {
    assert(NumBases == 1 && "HLSL supports only one base type");
    (void)NumBases;
    CXXBaseSpecifier Base = *StructDecl->bases_begin();
    CXXRecordDecl *BaseDecl = Base.getType()->getAsCXXRecordDecl();
    if (requiresImplicitBufferLayoutStructure(BaseDecl)) {
      BaseDecl = createHostLayoutStruct(S, BaseDecl);
      if (BaseDecl) {
        TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo(
            QualType(BaseDecl->getTypeForDecl(), 0));
        Base = CXXBaseSpecifier(SourceRange(), false, StructDecl->isClass(),
                                AS_none, TSI, SourceLocation());
      }
    }
    if (BaseDecl) {
      const CXXBaseSpecifier *BasesArray[1] = {&Base};
      LS->setBases(BasesArray, 1);
    }
  }

  // filter struct fields
  for (const FieldDecl *FD : StructDecl->fields()) {
    const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
    if (FieldDecl *NewFD =
            createFieldForHostLayoutStruct(S, Ty, FD->getIdentifier(), LS))
      LS->addDecl(NewFD);
  }
  LS->completeDefinition();

  if (LS->field_empty() && LS->getNumBases() == 0)
    return nullptr;

  DC->addDecl(LS);
  return LS;
}

// Creates host layout struct for HLSL Buffer. The struct will include only
// fields of types that are allowed in HLSL buffer and it will filter out:
// - static variable declarations
// - resource classes
// - empty structs
// - zero-sized arrays
// - non-variable declarations
// The layout struct will be added to the HLSLBufferDecl declarations.
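// For illustration (assumed HLSL usage): given
//   cbuffer CB {
//     float4 Color;
//     Buffer<float> Buf;   // resource - filtered out
//     static float Scale;  // static - filtered out
//   }
// only 'Color' becomes a field of the generated __layout_CB struct.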
470 void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl) { 471 ASTContext &AST = S.getASTContext(); 472 IdentifierInfo *II = getHostLayoutStructName(S, BufDecl, true); 473 474 CXXRecordDecl *LS = 475 CXXRecordDecl::Create(AST, TagDecl::TagKind::Class, BufDecl, 476 SourceLocation(), SourceLocation(), II); 477 LS->setImplicit(true); 478 LS->startDefinition(); 479 480 for (Decl *D : BufDecl->decls()) { 481 VarDecl *VD = dyn_cast<VarDecl>(D); 482 if (!VD || VD->getStorageClass() == SC_Static) 483 continue; 484 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType(); 485 if (FieldDecl *FD = 486 createFieldForHostLayoutStruct(S, Ty, VD->getIdentifier(), LS)) { 487 // add the field decl to the layout struct 488 LS->addDecl(FD); 489 // update address space of the original decl to hlsl_constant 490 QualType NewTy = 491 AST.getAddrSpaceQualType(VD->getType(), LangAS::hlsl_constant); 492 VD->setType(NewTy); 493 } 494 } 495 LS->completeDefinition(); 496 BufDecl->addDecl(LS); 497 } 498 499 // Handle end of cbuffer/tbuffer declaration 500 void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) { 501 auto *BufDecl = cast<HLSLBufferDecl>(Dcl); 502 BufDecl->setRBraceLoc(RBrace); 503 504 validatePackoffset(SemaRef, BufDecl); 505 506 // create buffer layout struct 507 createHostLayoutStructForBuffer(SemaRef, BufDecl); 508 509 SemaRef.PopDeclContext(); 510 } 511 512 HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D, 513 const AttributeCommonInfo &AL, 514 int X, int Y, int Z) { 515 if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) { 516 if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) { 517 Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; 518 Diag(AL.getLoc(), diag::note_conflicting_attribute); 519 } 520 return nullptr; 521 } 522 return ::new (getASTContext()) 523 HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z); 524 } 525 526 HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D, 527 const AttributeCommonInfo &AL, 528 int Min, int Max, int Preferred, 529 int SpelledArgsCount) { 530 if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) { 531 if (WS->getMin() != Min || WS->getMax() != Max || 532 WS->getPreferred() != Preferred || 533 WS->getSpelledArgsCount() != SpelledArgsCount) { 534 Diag(WS->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; 535 Diag(AL.getLoc(), diag::note_conflicting_attribute); 536 } 537 return nullptr; 538 } 539 HLSLWaveSizeAttr *Result = ::new (getASTContext()) 540 HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred); 541 Result->setSpelledArgsCount(SpelledArgsCount); 542 return Result; 543 } 544 545 HLSLShaderAttr * 546 SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL, 547 llvm::Triple::EnvironmentType ShaderType) { 548 if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) { 549 if (NT->getType() != ShaderType) { 550 Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; 551 Diag(AL.getLoc(), diag::note_conflicting_attribute); 552 } 553 return nullptr; 554 } 555 return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL); 556 } 557 558 HLSLParamModifierAttr * 559 SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL, 560 HLSLParamModifierAttr::Spelling Spelling) { 561 // We can only merge an `in` attribute with an `out` attribute. All other 562 // combinations of duplicated attributes are ill-formed. 
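  // For example (assumed HLSL usage): `void f(in out float X)` is merged into
  // a single `inout` modifier, while repeating the same keyword, as in
  // `void f(in in float X)`, is diagnosed as a duplicate parameter modifier.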
563 if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) { 564 if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) || 565 (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) { 566 D->dropAttr<HLSLParamModifierAttr>(); 567 SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()}; 568 return HLSLParamModifierAttr::Create( 569 getASTContext(), /*MergedSpelling=*/true, AdjustedRange, 570 HLSLParamModifierAttr::Keyword_inout); 571 } 572 Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL; 573 Diag(PA->getLocation(), diag::note_conflicting_attribute); 574 return nullptr; 575 } 576 return HLSLParamModifierAttr::Create(getASTContext(), AL); 577 } 578 579 void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) { 580 auto &TargetInfo = getASTContext().getTargetInfo(); 581 582 if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry) 583 return; 584 585 llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment(); 586 if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) { 587 if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) { 588 // The entry point is already annotated - check that it matches the 589 // triple. 590 if (Shader->getType() != Env) { 591 Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch) 592 << Shader; 593 FD->setInvalidDecl(); 594 } 595 } else { 596 // Implicitly add the shader attribute if the entry function isn't 597 // explicitly annotated. 598 FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env, 599 FD->getBeginLoc())); 600 } 601 } else { 602 switch (Env) { 603 case llvm::Triple::UnknownEnvironment: 604 case llvm::Triple::Library: 605 break; 606 default: 607 llvm_unreachable("Unhandled environment in triple"); 608 } 609 } 610 } 611 612 void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) { 613 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>(); 614 assert(ShaderAttr && "Entry point has no shader attribute"); 615 llvm::Triple::EnvironmentType ST = ShaderAttr->getType(); 616 auto &TargetInfo = getASTContext().getTargetInfo(); 617 VersionTuple Ver = TargetInfo.getTriple().getOSVersion(); 618 switch (ST) { 619 case llvm::Triple::Pixel: 620 case llvm::Triple::Vertex: 621 case llvm::Triple::Geometry: 622 case llvm::Triple::Hull: 623 case llvm::Triple::Domain: 624 case llvm::Triple::RayGeneration: 625 case llvm::Triple::Intersection: 626 case llvm::Triple::AnyHit: 627 case llvm::Triple::ClosestHit: 628 case llvm::Triple::Miss: 629 case llvm::Triple::Callable: 630 if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) { 631 DiagnoseAttrStageMismatch(NT, ST, 632 {llvm::Triple::Compute, 633 llvm::Triple::Amplification, 634 llvm::Triple::Mesh}); 635 FD->setInvalidDecl(); 636 } 637 if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) { 638 DiagnoseAttrStageMismatch(WS, ST, 639 {llvm::Triple::Compute, 640 llvm::Triple::Amplification, 641 llvm::Triple::Mesh}); 642 FD->setInvalidDecl(); 643 } 644 break; 645 646 case llvm::Triple::Compute: 647 case llvm::Triple::Amplification: 648 case llvm::Triple::Mesh: 649 if (!FD->hasAttr<HLSLNumThreadsAttr>()) { 650 Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads) 651 << llvm::Triple::getEnvironmentTypeName(ST); 652 FD->setInvalidDecl(); 653 } 654 if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) { 655 if (Ver < VersionTuple(6, 6)) { 656 Diag(WS->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model) 657 << WS << "6.6"; 658 FD->setInvalidDecl(); 659 } else if (WS->getSpelledArgsCount() > 
1 && Ver < VersionTuple(6, 8)) { 660 Diag( 661 WS->getLocation(), 662 diag::err_hlsl_attribute_number_arguments_insufficient_shader_model) 663 << WS << WS->getSpelledArgsCount() << "6.8"; 664 FD->setInvalidDecl(); 665 } 666 } 667 break; 668 default: 669 llvm_unreachable("Unhandled environment in triple"); 670 } 671 672 for (ParmVarDecl *Param : FD->parameters()) { 673 if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) { 674 CheckSemanticAnnotation(FD, Param, AnnotationAttr); 675 } else { 676 // FIXME: Handle struct parameters where annotations are on struct fields. 677 // See: https://github.com/llvm/llvm-project/issues/57875 678 Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation); 679 Diag(Param->getLocation(), diag::note_previous_decl) << Param; 680 FD->setInvalidDecl(); 681 } 682 } 683 // FIXME: Verify return type semantic annotation. 684 } 685 686 void SemaHLSL::CheckSemanticAnnotation( 687 FunctionDecl *EntryPoint, const Decl *Param, 688 const HLSLAnnotationAttr *AnnotationAttr) { 689 auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>(); 690 assert(ShaderAttr && "Entry point has no shader attribute"); 691 llvm::Triple::EnvironmentType ST = ShaderAttr->getType(); 692 693 switch (AnnotationAttr->getKind()) { 694 case attr::HLSLSV_DispatchThreadID: 695 case attr::HLSLSV_GroupIndex: 696 case attr::HLSLSV_GroupThreadID: 697 case attr::HLSLSV_GroupID: 698 if (ST == llvm::Triple::Compute) 699 return; 700 DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute}); 701 break; 702 default: 703 llvm_unreachable("Unknown HLSLAnnotationAttr"); 704 } 705 } 706 707 void SemaHLSL::DiagnoseAttrStageMismatch( 708 const Attr *A, llvm::Triple::EnvironmentType Stage, 709 std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) { 710 SmallVector<StringRef, 8> StageStrings; 711 llvm::transform(AllowedStages, std::back_inserter(StageStrings), 712 [](llvm::Triple::EnvironmentType ST) { 713 return StringRef( 714 HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST)); 715 }); 716 Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage) 717 << A << llvm::Triple::getEnvironmentTypeName(Stage) 718 << (AllowedStages.size() != 1) << join(StageStrings, ", "); 719 } 720 721 template <CastKind Kind> 722 static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) { 723 if (const auto *VTy = Ty->getAs<VectorType>()) 724 Ty = VTy->getElementType(); 725 Ty = S.getASTContext().getExtVectorType(Ty, Sz); 726 E = S.ImpCastExprToType(E.get(), Ty, Kind); 727 } 728 729 template <CastKind Kind> 730 static QualType castElement(Sema &S, ExprResult &E, QualType Ty) { 731 E = S.ImpCastExprToType(E.get(), Ty, Kind); 732 return Ty; 733 } 734 735 static QualType handleFloatVectorBinOpConversion( 736 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType, 737 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) { 738 bool LHSFloat = LElTy->isRealFloatingType(); 739 bool RHSFloat = RElTy->isRealFloatingType(); 740 741 if (LHSFloat && RHSFloat) { 742 if (IsCompAssign || 743 SemaRef.getASTContext().getFloatingTypeOrder(LElTy, RElTy) > 0) 744 return castElement<CK_FloatingCast>(SemaRef, RHS, LHSType); 745 746 return castElement<CK_FloatingCast>(SemaRef, LHS, RHSType); 747 } 748 749 if (LHSFloat) 750 return castElement<CK_IntegralToFloating>(SemaRef, RHS, LHSType); 751 752 assert(RHSFloat); 753 if (IsCompAssign) 754 return castElement<clang::CK_FloatingToIntegral>(SemaRef, RHS, LHSType); 755 756 return 
castElement<CK_IntegralToFloating>(SemaRef, LHS, RHSType); 757 } 758 759 static QualType handleIntegerVectorBinOpConversion( 760 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType, 761 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) { 762 763 int IntOrder = SemaRef.Context.getIntegerTypeOrder(LElTy, RElTy); 764 bool LHSSigned = LElTy->hasSignedIntegerRepresentation(); 765 bool RHSSigned = RElTy->hasSignedIntegerRepresentation(); 766 auto &Ctx = SemaRef.getASTContext(); 767 768 // If both types have the same signedness, use the higher ranked type. 769 if (LHSSigned == RHSSigned) { 770 if (IsCompAssign || IntOrder >= 0) 771 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType); 772 773 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType); 774 } 775 776 // If the unsigned type has greater than or equal rank of the signed type, use 777 // the unsigned type. 778 if (IntOrder != (LHSSigned ? 1 : -1)) { 779 if (IsCompAssign || RHSSigned) 780 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType); 781 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType); 782 } 783 784 // At this point the signed type has higher rank than the unsigned type, which 785 // means it will be the same size or bigger. If the signed type is bigger, it 786 // can represent all the values of the unsigned type, so select it. 787 if (Ctx.getIntWidth(LElTy) != Ctx.getIntWidth(RElTy)) { 788 if (IsCompAssign || LHSSigned) 789 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType); 790 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType); 791 } 792 793 // This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due 794 // to C/C++ leaking through. The place this happens today is long vs long 795 // long. When arguments are vector<unsigned long, N> and vector<long long, N>, 796 // the long long has higher rank than long even though they are the same size. 797 798 // If this is a compound assignment cast the right hand side to the left hand 799 // side's type. 800 if (IsCompAssign) 801 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType); 802 803 // If this isn't a compound assignment we convert to unsigned long long. 804 QualType ElTy = Ctx.getCorrespondingUnsignedType(LHSSigned ? LElTy : RElTy); 805 QualType NewTy = Ctx.getExtVectorType( 806 ElTy, RHSType->castAs<VectorType>()->getNumElements()); 807 (void)castElement<CK_IntegralCast>(SemaRef, RHS, NewTy); 808 809 return castElement<CK_IntegralCast>(SemaRef, LHS, NewTy); 810 } 811 812 static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy, 813 QualType SrcTy) { 814 if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType()) 815 return CK_FloatingCast; 816 if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx)) 817 return CK_IntegralCast; 818 if (DestTy->isRealFloatingType()) 819 return CK_IntegralToFloating; 820 assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx)); 821 return CK_FloatingToIntegral; 822 } 823 824 QualType SemaHLSL::handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS, 825 QualType LHSType, 826 QualType RHSType, 827 bool IsCompAssign) { 828 const auto *LVecTy = LHSType->getAs<VectorType>(); 829 const auto *RVecTy = RHSType->getAs<VectorType>(); 830 auto &Ctx = getASTContext(); 831 832 // If the LHS is not a vector and this is a compound assignment, we truncate 833 // the argument to a scalar then convert it to the LHS's type. 
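  // For example (assumed HLSL usage): in `float F; float4 V; F += V;` the RHS
  // vector is truncated to a single element and then converted to the type of
  // F before the compound assignment is evaluated.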
834 if (!LVecTy && IsCompAssign) { 835 QualType RElTy = RHSType->castAs<VectorType>()->getElementType(); 836 RHS = SemaRef.ImpCastExprToType(RHS.get(), RElTy, CK_HLSLVectorTruncation); 837 RHSType = RHS.get()->getType(); 838 if (Ctx.hasSameUnqualifiedType(LHSType, RHSType)) 839 return LHSType; 840 RHS = SemaRef.ImpCastExprToType(RHS.get(), LHSType, 841 getScalarCastKind(Ctx, LHSType, RHSType)); 842 return LHSType; 843 } 844 845 unsigned EndSz = std::numeric_limits<unsigned>::max(); 846 unsigned LSz = 0; 847 if (LVecTy) 848 LSz = EndSz = LVecTy->getNumElements(); 849 if (RVecTy) 850 EndSz = std::min(RVecTy->getNumElements(), EndSz); 851 assert(EndSz != std::numeric_limits<unsigned>::max() && 852 "one of the above should have had a value"); 853 854 // In a compound assignment, the left operand does not change type, the right 855 // operand is converted to the type of the left operand. 856 if (IsCompAssign && LSz != EndSz) { 857 Diag(LHS.get()->getBeginLoc(), 858 diag::err_hlsl_vector_compound_assignment_truncation) 859 << LHSType << RHSType; 860 return QualType(); 861 } 862 863 if (RVecTy && RVecTy->getNumElements() > EndSz) 864 castVector<CK_HLSLVectorTruncation>(SemaRef, RHS, RHSType, EndSz); 865 if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz) 866 castVector<CK_HLSLVectorTruncation>(SemaRef, LHS, LHSType, EndSz); 867 868 if (!RVecTy) 869 castVector<CK_VectorSplat>(SemaRef, RHS, RHSType, EndSz); 870 if (!IsCompAssign && !LVecTy) 871 castVector<CK_VectorSplat>(SemaRef, LHS, LHSType, EndSz); 872 873 // If we're at the same type after resizing we can stop here. 874 if (Ctx.hasSameUnqualifiedType(LHSType, RHSType)) 875 return Ctx.getCommonSugaredType(LHSType, RHSType); 876 877 QualType LElTy = LHSType->castAs<VectorType>()->getElementType(); 878 QualType RElTy = RHSType->castAs<VectorType>()->getElementType(); 879 880 // Handle conversion for floating point vectors. 881 if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType()) 882 return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType, 883 LElTy, RElTy, IsCompAssign); 884 885 assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) && 886 "HLSL Vectors can only contain integer or floating point types"); 887 return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType, 888 LElTy, RElTy, IsCompAssign); 889 } 890 891 void SemaHLSL::emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS, 892 BinaryOperatorKind Opc) { 893 assert((Opc == BO_LOr || Opc == BO_LAnd) && 894 "Called with non-logical operator"); 895 llvm::SmallVector<char, 256> Buff; 896 llvm::raw_svector_ostream OS(Buff); 897 PrintingPolicy PP(SemaRef.getLangOpts()); 898 StringRef NewFnName = Opc == BO_LOr ? 
"or" : "and"; 899 OS << NewFnName << "("; 900 LHS->printPretty(OS, nullptr, PP); 901 OS << ", "; 902 RHS->printPretty(OS, nullptr, PP); 903 OS << ")"; 904 SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc()); 905 SemaRef.Diag(LHS->getBeginLoc(), diag::note_function_suggestion) 906 << NewFnName << FixItHint::CreateReplacement(FullRange, OS.str()); 907 } 908 909 void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) { 910 llvm::VersionTuple SMVersion = 911 getASTContext().getTargetInfo().getTriple().getOSVersion(); 912 uint32_t ZMax = 1024; 913 uint32_t ThreadMax = 1024; 914 if (SMVersion.getMajor() <= 4) { 915 ZMax = 1; 916 ThreadMax = 768; 917 } else if (SMVersion.getMajor() == 5) { 918 ZMax = 64; 919 ThreadMax = 1024; 920 } 921 922 uint32_t X; 923 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), X)) 924 return; 925 if (X > 1024) { 926 Diag(AL.getArgAsExpr(0)->getExprLoc(), 927 diag::err_hlsl_numthreads_argument_oor) 928 << 0 << 1024; 929 return; 930 } 931 uint32_t Y; 932 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Y)) 933 return; 934 if (Y > 1024) { 935 Diag(AL.getArgAsExpr(1)->getExprLoc(), 936 diag::err_hlsl_numthreads_argument_oor) 937 << 1 << 1024; 938 return; 939 } 940 uint32_t Z; 941 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Z)) 942 return; 943 if (Z > ZMax) { 944 SemaRef.Diag(AL.getArgAsExpr(2)->getExprLoc(), 945 diag::err_hlsl_numthreads_argument_oor) 946 << 2 << ZMax; 947 return; 948 } 949 950 if (X * Y * Z > ThreadMax) { 951 Diag(AL.getLoc(), diag::err_hlsl_numthreads_invalid) << ThreadMax; 952 return; 953 } 954 955 HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z); 956 if (NewAttr) 957 D->addAttr(NewAttr); 958 } 959 960 static bool isValidWaveSizeValue(unsigned Value) { 961 return llvm::isPowerOf2_32(Value) && Value >= 4 && Value <= 128; 962 } 963 964 void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) { 965 // validate that the wavesize argument is a power of 2 between 4 and 128 966 // inclusive 967 unsigned SpelledArgsCount = AL.getNumArgs(); 968 if (SpelledArgsCount == 0 || SpelledArgsCount > 3) 969 return; 970 971 uint32_t Min; 972 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Min)) 973 return; 974 975 uint32_t Max = 0; 976 if (SpelledArgsCount > 1 && 977 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Max)) 978 return; 979 980 uint32_t Preferred = 0; 981 if (SpelledArgsCount > 2 && 982 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Preferred)) 983 return; 984 985 if (SpelledArgsCount > 2) { 986 if (!isValidWaveSizeValue(Preferred)) { 987 Diag(AL.getArgAsExpr(2)->getExprLoc(), 988 diag::err_attribute_power_of_two_in_range) 989 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize 990 << Preferred; 991 return; 992 } 993 // Preferred not in range. 
994 if (Preferred < Min || Preferred > Max) { 995 Diag(AL.getArgAsExpr(2)->getExprLoc(), 996 diag::err_attribute_power_of_two_in_range) 997 << AL << Min << Max << Preferred; 998 return; 999 } 1000 } else if (SpelledArgsCount > 1) { 1001 if (!isValidWaveSizeValue(Max)) { 1002 Diag(AL.getArgAsExpr(1)->getExprLoc(), 1003 diag::err_attribute_power_of_two_in_range) 1004 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max; 1005 return; 1006 } 1007 if (Max < Min) { 1008 Diag(AL.getLoc(), diag::err_attribute_argument_invalid) << AL << 1; 1009 return; 1010 } else if (Max == Min) { 1011 Diag(AL.getLoc(), diag::warn_attr_min_eq_max) << AL; 1012 } 1013 } else { 1014 if (!isValidWaveSizeValue(Min)) { 1015 Diag(AL.getArgAsExpr(0)->getExprLoc(), 1016 diag::err_attribute_power_of_two_in_range) 1017 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min; 1018 return; 1019 } 1020 } 1021 1022 HLSLWaveSizeAttr *NewAttr = 1023 mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount); 1024 if (NewAttr) 1025 D->addAttr(NewAttr); 1026 } 1027 1028 bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) { 1029 const auto *VT = T->getAs<VectorType>(); 1030 1031 if (!T->hasUnsignedIntegerRepresentation() || 1032 (VT && VT->getNumElements() > 3)) { 1033 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type) 1034 << AL << "uint/uint2/uint3"; 1035 return false; 1036 } 1037 1038 return true; 1039 } 1040 1041 void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) { 1042 auto *VD = cast<ValueDecl>(D); 1043 if (!diagnoseInputIDType(VD->getType(), AL)) 1044 return; 1045 1046 D->addAttr(::new (getASTContext()) 1047 HLSLSV_DispatchThreadIDAttr(getASTContext(), AL)); 1048 } 1049 1050 void SemaHLSL::handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL) { 1051 auto *VD = cast<ValueDecl>(D); 1052 if (!diagnoseInputIDType(VD->getType(), AL)) 1053 return; 1054 1055 D->addAttr(::new (getASTContext()) 1056 HLSLSV_GroupThreadIDAttr(getASTContext(), AL)); 1057 } 1058 1059 void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) { 1060 auto *VD = cast<ValueDecl>(D); 1061 if (!diagnoseInputIDType(VD->getType(), AL)) 1062 return; 1063 1064 D->addAttr(::new (getASTContext()) HLSLSV_GroupIDAttr(getASTContext(), AL)); 1065 } 1066 1067 void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) { 1068 if (!isa<VarDecl>(D) || !isa<HLSLBufferDecl>(D->getDeclContext())) { 1069 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node) 1070 << AL << "shader constant in a constant buffer"; 1071 return; 1072 } 1073 1074 uint32_t SubComponent; 1075 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), SubComponent)) 1076 return; 1077 uint32_t Component; 1078 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Component)) 1079 return; 1080 1081 QualType T = cast<VarDecl>(D)->getType().getCanonicalType(); 1082 // Check if T is an array or struct type. 1083 // TODO: mark matrix type as aggregate type. 1084 bool IsAggregateTy = (T->isArrayType() || T->isStructureType()); 1085 1086 // Check Component is valid for T. 1087 if (Component) { 1088 unsigned Size = getASTContext().getTypeSize(T); 1089 if (IsAggregateTy || Size > 128) { 1090 Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary); 1091 return; 1092 } else { 1093 // Make sure Component + sizeof(T) <= 4. 
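      // For example (assumed HLSL usage): `float2 X : packoffset(c0.w);`
      // starts at component 3 and needs 64 bits, so it would spill past the
      // 128-bit register and is rejected.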
1094 if ((Component * 32 + Size) > 128) { 1095 Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary); 1096 return; 1097 } 1098 QualType EltTy = T; 1099 if (const auto *VT = T->getAs<VectorType>()) 1100 EltTy = VT->getElementType(); 1101 unsigned Align = getASTContext().getTypeAlign(EltTy); 1102 if (Align > 32 && Component == 1) { 1103 // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary. 1104 // So we only need to check Component 1 here. 1105 Diag(AL.getLoc(), diag::err_hlsl_packoffset_alignment_mismatch) 1106 << Align << EltTy; 1107 return; 1108 } 1109 } 1110 } 1111 1112 D->addAttr(::new (getASTContext()) HLSLPackOffsetAttr( 1113 getASTContext(), AL, SubComponent, Component)); 1114 } 1115 1116 void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) { 1117 StringRef Str; 1118 SourceLocation ArgLoc; 1119 if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc)) 1120 return; 1121 1122 llvm::Triple::EnvironmentType ShaderType; 1123 if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) { 1124 Diag(AL.getLoc(), diag::warn_attribute_type_not_supported) 1125 << AL << Str << ArgLoc; 1126 return; 1127 } 1128 1129 // FIXME: check function match the shader stage. 1130 1131 HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType); 1132 if (NewAttr) 1133 D->addAttr(NewAttr); 1134 } 1135 1136 bool clang::CreateHLSLAttributedResourceType( 1137 Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList, 1138 QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo) { 1139 assert(AttrList.size() && "expected list of resource attributes"); 1140 1141 QualType ContainedTy = QualType(); 1142 TypeSourceInfo *ContainedTyInfo = nullptr; 1143 SourceLocation LocBegin = AttrList[0]->getRange().getBegin(); 1144 SourceLocation LocEnd = AttrList[0]->getRange().getEnd(); 1145 1146 HLSLAttributedResourceType::Attributes ResAttrs; 1147 1148 bool HasResourceClass = false; 1149 for (const Attr *A : AttrList) { 1150 if (!A) 1151 continue; 1152 LocEnd = A->getRange().getEnd(); 1153 switch (A->getKind()) { 1154 case attr::HLSLResourceClass: { 1155 ResourceClass RC = cast<HLSLResourceClassAttr>(A)->getResourceClass(); 1156 if (HasResourceClass) { 1157 S.Diag(A->getLocation(), ResAttrs.ResourceClass == RC 1158 ? diag::warn_duplicate_attribute_exact 1159 : diag::warn_duplicate_attribute) 1160 << A; 1161 return false; 1162 } 1163 ResAttrs.ResourceClass = RC; 1164 HasResourceClass = true; 1165 break; 1166 } 1167 case attr::HLSLROV: 1168 if (ResAttrs.IsROV) { 1169 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A; 1170 return false; 1171 } 1172 ResAttrs.IsROV = true; 1173 break; 1174 case attr::HLSLRawBuffer: 1175 if (ResAttrs.RawBuffer) { 1176 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A; 1177 return false; 1178 } 1179 ResAttrs.RawBuffer = true; 1180 break; 1181 case attr::HLSLContainedType: { 1182 const HLSLContainedTypeAttr *CTAttr = cast<HLSLContainedTypeAttr>(A); 1183 QualType Ty = CTAttr->getType(); 1184 if (!ContainedTy.isNull()) { 1185 S.Diag(A->getLocation(), ContainedTy == Ty 1186 ? 
diag::warn_duplicate_attribute_exact 1187 : diag::warn_duplicate_attribute) 1188 << A; 1189 return false; 1190 } 1191 ContainedTy = Ty; 1192 ContainedTyInfo = CTAttr->getTypeLoc(); 1193 break; 1194 } 1195 default: 1196 llvm_unreachable("unhandled resource attribute type"); 1197 } 1198 } 1199 1200 if (!HasResourceClass) { 1201 S.Diag(AttrList.back()->getRange().getEnd(), 1202 diag::err_hlsl_missing_resource_class); 1203 return false; 1204 } 1205 1206 ResType = S.getASTContext().getHLSLAttributedResourceType( 1207 Wrapped, ContainedTy, ResAttrs); 1208 1209 if (LocInfo && ContainedTyInfo) { 1210 LocInfo->Range = SourceRange(LocBegin, LocEnd); 1211 LocInfo->ContainedTyInfo = ContainedTyInfo; 1212 } 1213 return true; 1214 } 1215 1216 // Validates and creates an HLSL attribute that is applied as type attribute on 1217 // HLSL resource. The attributes are collected in HLSLResourcesTypeAttrs and at 1218 // the end of the declaration they are applied to the declaration type by 1219 // wrapping it in HLSLAttributedResourceType. 1220 bool SemaHLSL::handleResourceTypeAttr(QualType T, const ParsedAttr &AL) { 1221 // only allow resource type attributes on intangible types 1222 if (!T->isHLSLResourceType()) { 1223 Diag(AL.getLoc(), diag::err_hlsl_attribute_needs_intangible_type) 1224 << AL << getASTContext().HLSLResourceTy; 1225 return false; 1226 } 1227 1228 // validate number of arguments 1229 if (!AL.checkExactlyNumArgs(SemaRef, AL.getMinArgs())) 1230 return false; 1231 1232 Attr *A = nullptr; 1233 switch (AL.getKind()) { 1234 case ParsedAttr::AT_HLSLResourceClass: { 1235 if (!AL.isArgIdent(0)) { 1236 Diag(AL.getLoc(), diag::err_attribute_argument_type) 1237 << AL << AANT_ArgumentIdentifier; 1238 return false; 1239 } 1240 1241 IdentifierLoc *Loc = AL.getArgAsIdent(0); 1242 StringRef Identifier = Loc->Ident->getName(); 1243 SourceLocation ArgLoc = Loc->Loc; 1244 1245 // Validate resource class value 1246 ResourceClass RC; 1247 if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Identifier, RC)) { 1248 Diag(ArgLoc, diag::warn_attribute_type_not_supported) 1249 << "ResourceClass" << Identifier; 1250 return false; 1251 } 1252 A = HLSLResourceClassAttr::Create(getASTContext(), RC, AL.getLoc()); 1253 break; 1254 } 1255 1256 case ParsedAttr::AT_HLSLROV: 1257 A = HLSLROVAttr::Create(getASTContext(), AL.getLoc()); 1258 break; 1259 1260 case ParsedAttr::AT_HLSLRawBuffer: 1261 A = HLSLRawBufferAttr::Create(getASTContext(), AL.getLoc()); 1262 break; 1263 1264 case ParsedAttr::AT_HLSLContainedType: { 1265 if (AL.getNumArgs() != 1 && !AL.hasParsedType()) { 1266 Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1; 1267 return false; 1268 } 1269 1270 TypeSourceInfo *TSI = nullptr; 1271 QualType QT = SemaRef.GetTypeFromParser(AL.getTypeArg(), &TSI); 1272 assert(TSI && "no type source info for attribute argument"); 1273 if (SemaRef.RequireCompleteType(TSI->getTypeLoc().getBeginLoc(), QT, 1274 diag::err_incomplete_type)) 1275 return false; 1276 A = HLSLContainedTypeAttr::Create(getASTContext(), TSI, AL.getLoc()); 1277 break; 1278 } 1279 1280 default: 1281 llvm_unreachable("unhandled HLSL attribute"); 1282 } 1283 1284 HLSLResourcesTypeAttrs.emplace_back(A); 1285 return true; 1286 } 1287 1288 // Combines all resource type attributes and creates HLSLAttributedResourceType. 
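// For illustration (assumed spelling of the builtin type attributes): a handle
// declared roughly as
//   __hlsl_resource_t [[hlsl::resource_class(SRV)]] [[hlsl::contained_type(float)]]
// is wrapped into a single HLSLAttributedResourceType that records both the
// resource class and the contained type.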
QualType SemaHLSL::ProcessResourceTypeAttributes(QualType CurrentType) {
  if (!HLSLResourcesTypeAttrs.size())
    return CurrentType;

  QualType QT = CurrentType;
  HLSLAttributedResourceLocInfo LocInfo;
  if (CreateHLSLAttributedResourceType(SemaRef, CurrentType,
                                       HLSLResourcesTypeAttrs, QT, &LocInfo)) {
    const HLSLAttributedResourceType *RT =
        cast<HLSLAttributedResourceType>(QT.getTypePtr());

    // Temporarily store TypeLoc information for the new type.
    // It will be transferred to HLSLAttributesResourceTypeLoc
    // shortly after the type is created by TypeSpecLocFiller which
    // will call the TakeLocForHLSLAttribute method below.
    LocsForHLSLAttributedResources.insert(std::pair(RT, LocInfo));
  }
  HLSLResourcesTypeAttrs.clear();
  return QT;
}

// Returns source location for the HLSLAttributedResourceType
HLSLAttributedResourceLocInfo
SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
  HLSLAttributedResourceLocInfo LocInfo = {};
  auto I = LocsForHLSLAttributedResources.find(RT);
  if (I != LocsForHLSLAttributedResources.end()) {
    LocInfo = I->second;
    LocsForHLSLAttributedResources.erase(I);
    return LocInfo;
  }
  LocInfo.Range = SourceRange();
  return LocInfo;
}

// Walks through the global variable declaration, collects all resource binding
// requirements and adds them to Bindings
void SemaHLSL::collectResourcesOnUserRecordDecl(const VarDecl *VD,
                                                const RecordType *RT) {
  const RecordDecl *RD = RT->getDecl();
  for (FieldDecl *FD : RD->fields()) {
    const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();

    // Unwrap arrays
    // FIXME: Calculate array size while unwrapping
    assert(!Ty->isIncompleteArrayType() &&
           "incomplete arrays inside user defined types are not supported");
    while (Ty->isConstantArrayType()) {
      const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
      Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
    }

    if (!Ty->isRecordType())
      continue;

    if (const HLSLAttributedResourceType *AttrResType =
            HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
      // Add a new DeclBindingInfo to Bindings if it does not already exist
      ResourceClass RC = AttrResType->getAttrs().ResourceClass;
      DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, RC);
      if (!DBI)
        Bindings.addDeclBindingInfo(VD, RC);
    } else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
      // Recursively scan embedded struct or class; it would be nice to do this
      // without recursion, but tricky to correctly calculate the size of the
      // binding, which is something we are probably going to need to do later
      // on. Hopefully nesting of structs in structs too many levels is
      // unlikely.
      collectResourcesOnUserRecordDecl(VD, RT);
    }
  }
}

// Diagnose localized register binding errors for a single binding; does not
// diagnose resource binding on user record types, that will be done later
// in processResourceBindingOnDecl based on the information collected in
// collectResourcesOnVarDecl.
// Returns false if the register binding is not valid.
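// For example (assumed HLSL usage): `Texture2D T : register(u0);` pairs an SRV
// resource with a UAV register type and is diagnosed as a binding type
// mismatch; `register(t0)` would be the matching register type.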
1367 static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc, 1368 Decl *D, RegisterType RegType, 1369 bool SpecifiedSpace) { 1370 int RegTypeNum = static_cast<int>(RegType); 1371 1372 // check if the decl type is groupshared 1373 if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) { 1374 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; 1375 return false; 1376 } 1377 1378 // Cbuffers and Tbuffers are HLSLBufferDecl types 1379 if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) { 1380 ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer 1381 : ResourceClass::SRV; 1382 if (RegType == getRegisterType(RC)) 1383 return true; 1384 1385 S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) 1386 << RegTypeNum; 1387 return false; 1388 } 1389 1390 // Samplers, UAVs, and SRVs are VarDecl types 1391 assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl"); 1392 VarDecl *VD = cast<VarDecl>(D); 1393 1394 // Resource 1395 if (const HLSLAttributedResourceType *AttrResType = 1396 HLSLAttributedResourceType::findHandleTypeOnResource( 1397 VD->getType().getTypePtr())) { 1398 if (RegType == getRegisterType(AttrResType->getAttrs().ResourceClass)) 1399 return true; 1400 1401 S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) 1402 << RegTypeNum; 1403 return false; 1404 } 1405 1406 const clang::Type *Ty = VD->getType().getTypePtr(); 1407 while (Ty->isArrayType()) 1408 Ty = Ty->getArrayElementTypeNoTypeQual(); 1409 1410 // Basic types 1411 if (Ty->isArithmeticType()) { 1412 bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext()); 1413 if (SpecifiedSpace && !DeclaredInCOrTBuffer) 1414 S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant); 1415 1416 if (!DeclaredInCOrTBuffer && 1417 (Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) { 1418 // Default Globals 1419 if (RegType == RegisterType::CBuffer) 1420 S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b); 1421 else if (RegType != RegisterType::C) 1422 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; 1423 } else { 1424 if (RegType == RegisterType::C) 1425 S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset); 1426 else 1427 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; 1428 } 1429 return false; 1430 } 1431 if (Ty->isRecordType()) 1432 // RecordTypes will be diagnosed in processResourceBindingOnDecl 1433 // that is called from ActOnVariableDeclarator 1434 return true; 1435 1436 // Anything else is an error 1437 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; 1438 return false; 1439 } 1440 1441 static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, 1442 RegisterType regType) { 1443 // make sure that there are no two register annotations 1444 // applied to the decl with the same register type 1445 bool RegisterTypesDetected[5] = {false}; 1446 RegisterTypesDetected[static_cast<int>(regType)] = true; 1447 1448 for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) { 1449 if (HLSLResourceBindingAttr *attr = 1450 dyn_cast<HLSLResourceBindingAttr>(*it)) { 1451 1452 RegisterType otherRegType = attr->getRegisterType(); 1453 if (RegisterTypesDetected[static_cast<int>(otherRegType)]) { 1454 int otherRegTypeNum = static_cast<int>(otherRegType); 1455 S.Diag(TheDecl->getLocation(), 1456 diag::err_hlsl_duplicate_register_annotation) 1457 << otherRegTypeNum; 1458 return false; 1459 } 1460 
RegisterTypesDetected[static_cast<int>(otherRegType)] = true; 1461 } 1462 } 1463 return true; 1464 } 1465 1466 static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, 1467 Decl *D, RegisterType RegType, 1468 bool SpecifiedSpace) { 1469 1470 // exactly one of these two types should be set 1471 assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) || 1472 (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) && 1473 "expecting VarDecl or HLSLBufferDecl"); 1474 1475 // check if the declaration contains resource matching the register type 1476 if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace)) 1477 return false; 1478 1479 // next, if multiple register annotations exist, check that none conflict. 1480 return ValidateMultipleRegisterAnnotations(S, D, RegType); 1481 } 1482 1483 void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) { 1484 if (isa<VarDecl>(TheDecl)) { 1485 if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(), 1486 cast<ValueDecl>(TheDecl)->getType(), 1487 diag::err_incomplete_type)) 1488 return; 1489 } 1490 StringRef Space = "space0"; 1491 StringRef Slot = ""; 1492 1493 if (!AL.isArgIdent(0)) { 1494 Diag(AL.getLoc(), diag::err_attribute_argument_type) 1495 << AL << AANT_ArgumentIdentifier; 1496 return; 1497 } 1498 1499 IdentifierLoc *Loc = AL.getArgAsIdent(0); 1500 StringRef Str = Loc->Ident->getName(); 1501 SourceLocation ArgLoc = Loc->Loc; 1502 1503 SourceLocation SpaceArgLoc; 1504 bool SpecifiedSpace = false; 1505 if (AL.getNumArgs() == 2) { 1506 SpecifiedSpace = true; 1507 Slot = Str; 1508 if (!AL.isArgIdent(1)) { 1509 Diag(AL.getLoc(), diag::err_attribute_argument_type) 1510 << AL << AANT_ArgumentIdentifier; 1511 return; 1512 } 1513 1514 IdentifierLoc *Loc = AL.getArgAsIdent(1); 1515 Space = Loc->Ident->getName(); 1516 SpaceArgLoc = Loc->Loc; 1517 } else { 1518 Slot = Str; 1519 } 1520 1521 RegisterType RegType; 1522 unsigned SlotNum = 0; 1523 unsigned SpaceNum = 0; 1524 1525 // Validate. 
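  // For example (assumed HLSL usage): `cbuffer CB : register(b2, space1)`
  // arrives here with Slot == "b2" and Space == "space1" and is decoded into
  // RegisterType::CBuffer, slot number 2, space number 1.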
  if (!Slot.empty()) {
    if (!convertToRegisterType(Slot, &RegType)) {
      Diag(ArgLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
      return;
    }
    if (RegType == RegisterType::I) {
      Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i);
      return;
    }

    StringRef SlotNumStr = Slot.substr(1);
    if (SlotNumStr.getAsInteger(10, SlotNum)) {
      Diag(ArgLoc, diag::err_hlsl_unsupported_register_number);
      return;
    }
  }

  if (!Space.starts_with("space")) {
    Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space;
    return;
  }
  StringRef SpaceNumStr = Space.substr(5);
  if (SpaceNumStr.getAsInteger(10, SpaceNum)) {
    Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space;
    return;
  }

  if (!DiagnoseHLSLRegisterAttribute(SemaRef, ArgLoc, TheDecl, RegType,
                                     SpecifiedSpace))
    return;

  HLSLResourceBindingAttr *NewAttr =
      HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL);
  if (NewAttr) {
    NewAttr->setBinding(RegType, SlotNum, SpaceNum);
    TheDecl->addAttr(NewAttr);
  }
}

void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {
  HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(
      D, AL,
      static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
  if (NewAttr)
    D->addAttr(NewAttr);
}

namespace {

/// This class implements HLSL availability diagnostics for default
/// and relaxed mode
///
/// The goal of this diagnostic is to emit an error or warning when an
/// unavailable API is found in code that is reachable from the shader
/// entry function or from an exported function (when compiling a shader
/// library).
///
/// This is done by traversing the AST of all shader entry point functions
/// and of all exported functions, and any functions that are referenced
/// from this AST. In other words, any functions that are reachable from
/// the entry points.
class DiagnoseHLSLAvailability : public DynamicRecursiveASTVisitor {
  Sema &SemaRef;

  // Stack of functions to be scanned
  llvm::SmallVector<const FunctionDecl *, 8> DeclsToScan;

  // Tracks which environments functions have been scanned in.
  //
  // Maps FunctionDecl to an unsigned number that represents the set of shader
  // environments the function has been scanned for.
  // Since the llvm::Triple::EnvironmentType enum values for shader stages are
  // guaranteed to be numbered from llvm::Triple::Pixel to
  // llvm::Triple::Amplification (verified by static_asserts in Triple.cpp),
  // we can use them to index individual bits in the set, as long as we shift
  // the values to start at 0 by subtracting the value of llvm::Triple::Pixel
  // first.
  //
  // The N'th bit in the set will be set if the function has been scanned
  // in shader environment whose llvm::Triple::EnvironmentType integer value
  // equals (llvm::Triple::Pixel + N).
  //
  // For example, if a function has been scanned in compute and pixel stage
  // environment, the value will be 0x21 (100001 binary) because:
  //
  //   (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
  //   (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
  //
  // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
  // been scanned in any environment.
1615 llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls; 1616 1617 // Do not access these directly, use the get/set methods below to make 1618 // sure the values are in sync 1619 llvm::Triple::EnvironmentType CurrentShaderEnvironment; 1620 unsigned CurrentShaderStageBit; 1621 1622 // True if scanning a function that was already scanned in a different 1623 // shader stage context, and therefore we should not report issues that 1624 // depend only on shader model version because they would be duplicate. 1625 bool ReportOnlyShaderStageIssues; 1626 1627 // Helper methods for dealing with current stage context / environment 1628 void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) { 1629 static_assert(sizeof(unsigned) >= 4); 1630 assert(HLSLShaderAttr::isValidShaderType(ShaderType)); 1631 assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 && 1632 "ShaderType is too big for this bitmap"); // 31 is reserved for 1633 // "unknown" 1634 1635 unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel; 1636 CurrentShaderEnvironment = ShaderType; 1637 CurrentShaderStageBit = (1 << bitmapIndex); 1638 } 1639 1640 void SetUnknownShaderStageContext() { 1641 CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment; 1642 CurrentShaderStageBit = (1 << 31); 1643 } 1644 1645 llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const { 1646 return CurrentShaderEnvironment; 1647 } 1648 1649 bool InUnknownShaderStageContext() const { 1650 return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment; 1651 } 1652 1653 // Helper methods for dealing with shader stage bitmap 1654 void AddToScannedFunctions(const FunctionDecl *FD) { 1655 unsigned &ScannedStages = ScannedDecls[FD]; 1656 ScannedStages |= CurrentShaderStageBit; 1657 } 1658 1659 unsigned GetScannedStages(const FunctionDecl *FD) { return ScannedDecls[FD]; } 1660 1661 bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) { 1662 return WasAlreadyScannedInCurrentStage(GetScannedStages(FD)); 1663 } 1664 1665 bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) { 1666 return ScannerStages & CurrentShaderStageBit; 1667 } 1668 1669 static bool NeverBeenScanned(unsigned ScannedStages) { 1670 return ScannedStages == 0; 1671 } 1672 1673 // Scanning methods 1674 void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr); 1675 void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA, 1676 SourceRange Range); 1677 const AvailabilityAttr *FindAvailabilityAttr(const Decl *D); 1678 bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA); 1679 1680 public: 1681 DiagnoseHLSLAvailability(Sema &SemaRef) 1682 : SemaRef(SemaRef), 1683 CurrentShaderEnvironment(llvm::Triple::UnknownEnvironment), 1684 CurrentShaderStageBit(0), ReportOnlyShaderStageIssues(false) {} 1685 1686 // AST traversal methods 1687 void RunOnTranslationUnit(const TranslationUnitDecl *TU); 1688 void RunOnFunction(const FunctionDecl *FD); 1689 1690 bool VisitDeclRefExpr(DeclRefExpr *DRE) override { 1691 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(DRE->getDecl()); 1692 if (FD) 1693 HandleFunctionOrMethodRef(FD, DRE); 1694 return true; 1695 } 1696 1697 bool VisitMemberExpr(MemberExpr *ME) override { 1698 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(ME->getMemberDecl()); 1699 if (FD) 1700 HandleFunctionOrMethodRef(FD, ME); 1701 return true; 1702 } 1703 }; 1704 1705 void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD, 1706 Expr *RefExpr) { 1707 assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) 
         && "expected DeclRefExpr or MemberExpr");

  // has a definition -> add to stack to be scanned
  const FunctionDecl *FDWithBody = nullptr;
  if (FD->hasBody(FDWithBody)) {
    if (!WasAlreadyScannedInCurrentStage(FDWithBody))
      DeclsToScan.push_back(FDWithBody);
    return;
  }

  // no body -> diagnose availability
  const AvailabilityAttr *AA = FindAvailabilityAttr(FD);
  if (AA)
    CheckDeclAvailability(
        FD, AA, SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc()));
}

void DiagnoseHLSLAvailability::RunOnTranslationUnit(
    const TranslationUnitDecl *TU) {

  // Iterate over all shader entry functions and library exports, and for those
  // that have a body (definition), run the diagnostic scan on each, setting
  // the appropriate shader environment context based on whether it is a shader
  // entry function or an exported function. Exported functions can be in
  // namespaces and in export declarations so we need to scan those declaration
  // contexts as well.
  llvm::SmallVector<const DeclContext *, 8> DeclContextsToScan;
  DeclContextsToScan.push_back(TU);

  while (!DeclContextsToScan.empty()) {
    const DeclContext *DC = DeclContextsToScan.pop_back_val();
    for (auto &D : DC->decls()) {
      // do not scan implicit declarations generated by the implementation
      if (D->isImplicit())
        continue;

      // for a namespace or export declaration, add the context to the list to
      // be scanned later
      if (llvm::dyn_cast<NamespaceDecl>(D) || llvm::dyn_cast<ExportDecl>(D)) {
        DeclContextsToScan.push_back(llvm::dyn_cast<DeclContext>(D));
        continue;
      }

      // skip over other decls or function decls without a body
      const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(D);
      if (!FD || !FD->isThisDeclarationADefinition())
        continue;

      // shader entry point
      if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {
        SetShaderStageContext(ShaderAttr->getType());
        RunOnFunction(FD);
        continue;
      }
      // exported library function
      // FIXME: replace this loop with external linkage check once issue #92071
      // is resolved
      bool isExport = FD->isInExportDeclContext();
      if (!isExport) {
        for (const auto *Redecl : FD->redecls()) {
          if (Redecl->isInExportDeclContext()) {
            isExport = true;
            break;
          }
        }
      }
      if (isExport) {
        SetUnknownShaderStageContext();
        RunOnFunction(FD);
        continue;
      }
    }
  }
}

void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) {
  assert(DeclsToScan.empty() && "DeclsToScan should be empty");
  DeclsToScan.push_back(FD);

  while (!DeclsToScan.empty()) {
    // Take one decl from the stack and check it by traversing its AST.
    // For any CallExpr found during the traversal add its callee to the top of
    // the stack to be processed next. Functions already processed are stored
    // in ScannedDecls.
    const FunctionDecl *FD = DeclsToScan.pop_back_val();

    // Decl was already scanned
    const unsigned ScannedStages = GetScannedStages(FD);
    if (WasAlreadyScannedInCurrentStage(ScannedStages))
      continue;

    ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages);

    AddToScannedFunctions(FD);
    TraverseStmt(FD->getBody());
  }
}

bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone(
    const AvailabilityAttr *AA) {
  IdentifierInfo *IIEnvironment = AA->getEnvironment();
  if (!IIEnvironment)
    return true;

  llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment();
  if (CurrentEnv == llvm::Triple::UnknownEnvironment)
    return false;

  llvm::Triple::EnvironmentType AttrEnv =
      AvailabilityAttr::getEnvironmentType(IIEnvironment->getName());

  return CurrentEnv == AttrEnv;
}

const AvailabilityAttr *
DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) {
  AvailabilityAttr const *PartialMatch = nullptr;
  // Check each AvailabilityAttr to find the one for this platform.
  // For multiple attributes with the same platform try to find one for this
  // environment.
  for (const auto *A : D->attrs()) {
    if (const auto *Avail = dyn_cast<AvailabilityAttr>(A)) {
      StringRef AttrPlatform = Avail->getPlatform()->getName();
      StringRef TargetPlatform =
          SemaRef.getASTContext().getTargetInfo().getPlatformName();

      // Match the platform name.
      if (AttrPlatform == TargetPlatform) {
        // Find the best matching attribute for this environment
        if (HasMatchingEnvironmentOrNone(Avail))
          return Avail;
        PartialMatch = Avail;
      }
    }
  }
  return PartialMatch;
}

// Check availability against the target shader model version and the current
// shader stage and emit a diagnostic
void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,
                                                     const AvailabilityAttr *AA,
                                                     SourceRange Range) {

  IdentifierInfo *IIEnv = AA->getEnvironment();

  if (!IIEnv) {
    // The availability attribute does not have an environment -> it depends
    // only on the shader model version and not on a specific shader stage.

    // Skip emitting the diagnostics if the diagnostic mode is set to
    // strict (-fhlsl-strict-availability) because all relevant diagnostics
    // were already emitted in the DiagnoseUnguardedAvailability scan
    // (SemaAvailability.cpp).
    if (SemaRef.getLangOpts().HLSLStrictAvailability)
      return;

    // Do not report shader-stage-independent issues if scanning a function
    // that was already scanned in a different shader stage context (they would
    // be duplicates)
    if (ReportOnlyShaderStageIssues)
      return;

  } else {
    // The availability attribute has an environment -> we need to know
    // the current stage context to properly diagnose it.
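    // For example (illustrative only; attribute spelling assumed), an API
    // marked with
    //   availability(shadermodel, introduced = 6.5, environment = compute)
    // can only be checked when the current shader stage is known, so an
    // unknown stage context is skipped below.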
    if (InUnknownShaderStageContext())
      return;
  }

  // Check the introduced version and whether the environment matches
  bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA);
  VersionTuple Introduced = AA->getIntroduced();
  VersionTuple TargetVersion =
      SemaRef.Context.getTargetInfo().getPlatformMinVersion();

  if (TargetVersion >= Introduced && EnvironmentMatches)
    return;

  // Emit diagnostic message
  const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
  llvm::StringRef PlatformName(
      AvailabilityAttr::getPrettyPlatformName(TI.getPlatformName()));

  llvm::StringRef CurrentEnvStr =
      llvm::Triple::getEnvironmentTypeName(GetCurrentShaderEnvironment());

  llvm::StringRef AttrEnvStr =
      AA->getEnvironment() ? AA->getEnvironment()->getName() : "";
  bool UseEnvironment = !AttrEnvStr.empty();

  if (EnvironmentMatches) {
    SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability)
        << Range << D << PlatformName << Introduced.getAsString()
        << UseEnvironment << CurrentEnvStr;
  } else {
    SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability_unavailable)
        << Range << D;
  }

  SemaRef.Diag(D->getLocation(), diag::note_partial_availability_specified_here)
      << D << PlatformName << Introduced.getAsString()
      << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString()
      << UseEnvironment << AttrEnvStr << CurrentEnvStr;
}

} // namespace

void SemaHLSL::DiagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
  // Skip running the diagnostics scan if the diagnostic mode is
  // strict (-fhlsl-strict-availability) and the target shader stage is known
  // because all relevant diagnostics were already emitted in the
  // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp).
  const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
  if (SemaRef.getLangOpts().HLSLStrictAvailability &&
      TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library)
    return;

  DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
}

// Helper function for CheckHLSLBuiltinFunctionCall
static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
  assert(TheCall->getNumArgs() > 1);
  ExprResult A = TheCall->getArg(0);

  QualType ArgTyA = A.get()->getType();

  auto *VecTyA = ArgTyA->getAs<VectorType>();
  SourceLocation BuiltinLoc = TheCall->getBeginLoc();

  bool AllBArgAreVectors = true;
  for (unsigned i = 1; i < TheCall->getNumArgs(); ++i) {
    ExprResult B = TheCall->getArg(i);
    QualType ArgTyB = B.get()->getType();
    auto *VecTyB = ArgTyB->getAs<VectorType>();
    if (VecTyB == nullptr)
      AllBArgAreVectors &= false;
    if (VecTyA && VecTyB == nullptr) {
      // Note: if we get here 'B' is a scalar which
      // requires a VectorSplat on ArgN
      S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
          << TheCall->getDirectCallee() << /*useAllTerminology*/ true
          << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
      return true;
    }
    if (VecTyA && VecTyB) {
      bool retValue = false;
      if (VecTyA->getElementType() != VecTyB->getElementType()) {
        // Note: type promotion is intended to be handled via the intrinsics
        // and not the builtin itself.
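        // Illustrative example (not from the upstream sources): calling a
        // builtin such as __builtin_hlsl_dot directly with (int3, float3)
        // operands lands here; the HLSL-level intrinsics are expected to
        // promote operands to a common element type before this point.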
1958 S->Diag(TheCall->getBeginLoc(), 1959 diag::err_vec_builtin_incompatible_vector) 1960 << TheCall->getDirectCallee() << /*useAllTerminology*/ true 1961 << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc()); 1962 retValue = true; 1963 } 1964 if (VecTyA->getNumElements() != VecTyB->getNumElements()) { 1965 // You should only be hitting this case if you are calling the builtin 1966 // directly. HLSL intrinsics should avoid this case via a 1967 // HLSLVectorTruncation. 1968 S->Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector) 1969 << TheCall->getDirectCallee() << /*useAllTerminology*/ true 1970 << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc()); 1971 retValue = true; 1972 } 1973 if (retValue) 1974 return retValue; 1975 } 1976 } 1977 1978 if (VecTyA == nullptr && AllBArgAreVectors) { 1979 // Note: if we get here 'A' is a scalar which 1980 // requires a VectorSplat on Arg0 1981 S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector) 1982 << TheCall->getDirectCallee() << /*useAllTerminology*/ true 1983 << SourceRange(A.get()->getBeginLoc(), A.get()->getEndLoc()); 1984 return true; 1985 } 1986 return false; 1987 } 1988 1989 static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) { 1990 QualType ArgType = Arg->getType(); 1991 if (!S->getASTContext().hasSameUnqualifiedType(ArgType, ExpectedType)) { 1992 S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible) 1993 << ArgType << ExpectedType << 1 << 0 << 0; 1994 return true; 1995 } 1996 return false; 1997 } 1998 1999 static bool CheckArgTypeIsCorrect( 2000 Sema *S, Expr *Arg, QualType ExpectedType, 2001 llvm::function_ref<bool(clang::QualType PassedType)> Check) { 2002 QualType PassedType = Arg->getType(); 2003 if (Check(PassedType)) { 2004 if (auto *VecTyA = PassedType->getAs<VectorType>()) 2005 ExpectedType = S->Context.getVectorType( 2006 ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind()); 2007 S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible) 2008 << PassedType << ExpectedType << 1 << 0 << 0; 2009 return true; 2010 } 2011 return false; 2012 } 2013 2014 static bool CheckAllArgTypesAreCorrect( 2015 Sema *S, CallExpr *TheCall, QualType ExpectedType, 2016 llvm::function_ref<bool(clang::QualType PassedType)> Check) { 2017 for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) { 2018 Expr *Arg = TheCall->getArg(i); 2019 if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) { 2020 return true; 2021 } 2022 } 2023 return false; 2024 } 2025 2026 static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) { 2027 auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool { 2028 return !PassedType->hasFloatingRepresentation(); 2029 }; 2030 return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy, 2031 checkAllFloatTypes); 2032 } 2033 2034 static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) { 2035 auto checkFloatorHalf = [](clang::QualType PassedType) -> bool { 2036 clang::QualType BaseType = 2037 PassedType->isVectorType() 2038 ? 
PassedType->getAs<clang::VectorType>()->getElementType() 2039 : PassedType; 2040 return !BaseType->isHalfType() && !BaseType->isFloat32Type(); 2041 }; 2042 return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy, 2043 checkFloatorHalf); 2044 } 2045 2046 static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall, 2047 unsigned ArgIndex) { 2048 auto *Arg = TheCall->getArg(ArgIndex); 2049 SourceLocation OrigLoc = Arg->getExprLoc(); 2050 if (Arg->IgnoreCasts()->isModifiableLvalue(S->Context, &OrigLoc) == 2051 Expr::MLV_Valid) 2052 return false; 2053 S->Diag(OrigLoc, diag::error_hlsl_inout_lvalue) << Arg << 0; 2054 return true; 2055 } 2056 2057 static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) { 2058 auto checkDoubleVector = [](clang::QualType PassedType) -> bool { 2059 if (const auto *VecTy = PassedType->getAs<VectorType>()) 2060 return VecTy->getElementType()->isDoubleType(); 2061 return false; 2062 }; 2063 return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy, 2064 checkDoubleVector); 2065 } 2066 static bool CheckFloatingOrIntRepresentation(Sema *S, CallExpr *TheCall) { 2067 auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool { 2068 return !PassedType->hasIntegerRepresentation() && 2069 !PassedType->hasFloatingRepresentation(); 2070 }; 2071 return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.IntTy, 2072 checkAllSignedTypes); 2073 } 2074 2075 static bool CheckUnsignedIntRepresentation(Sema *S, CallExpr *TheCall) { 2076 auto checkAllUnsignedTypes = [](clang::QualType PassedType) -> bool { 2077 return !PassedType->hasUnsignedIntegerRepresentation(); 2078 }; 2079 return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy, 2080 checkAllUnsignedTypes); 2081 } 2082 2083 static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall, 2084 QualType ReturnType) { 2085 auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>(); 2086 if (VecTyA) 2087 ReturnType = S->Context.getVectorType(ReturnType, VecTyA->getNumElements(), 2088 VectorKind::Generic); 2089 TheCall->setType(ReturnType); 2090 } 2091 2092 static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar, 2093 unsigned ArgIndex) { 2094 assert(TheCall->getNumArgs() >= ArgIndex); 2095 QualType ArgType = TheCall->getArg(ArgIndex)->getType(); 2096 auto *VTy = ArgType->getAs<VectorType>(); 2097 // not the scalar or vector<scalar> 2098 if (!(S->Context.hasSameUnqualifiedType(ArgType, Scalar) || 2099 (VTy && 2100 S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar)))) { 2101 S->Diag(TheCall->getArg(0)->getBeginLoc(), 2102 diag::err_typecheck_expect_scalar_or_vector) 2103 << ArgType << Scalar; 2104 return true; 2105 } 2106 return false; 2107 } 2108 2109 static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall, 2110 unsigned ArgIndex) { 2111 assert(TheCall->getNumArgs() >= ArgIndex); 2112 QualType ArgType = TheCall->getArg(ArgIndex)->getType(); 2113 auto *VTy = ArgType->getAs<VectorType>(); 2114 // not the scalar or vector<scalar> 2115 if (!(ArgType->isScalarType() || 2116 (VTy && VTy->getElementType()->isScalarType()))) { 2117 S->Diag(TheCall->getArg(0)->getBeginLoc(), 2118 diag::err_typecheck_expect_any_scalar_or_vector) 2119 << ArgType << 1; 2120 return true; 2121 } 2122 return false; 2123 } 2124 2125 static bool CheckWaveActive(Sema *S, CallExpr *TheCall) { 2126 QualType BoolType = S->getASTContext().BoolTy; 2127 assert(TheCall->getNumArgs() >= 1); 2128 QualType ArgType = TheCall->getArg(0)->getType(); 2129 auto *VTy = 
ArgType->getAs<VectorType>(); 2130 // is the bool or vector<bool> 2131 if (S->Context.hasSameUnqualifiedType(ArgType, BoolType) || 2132 (VTy && 2133 S->Context.hasSameUnqualifiedType(VTy->getElementType(), BoolType))) { 2134 S->Diag(TheCall->getArg(0)->getBeginLoc(), 2135 diag::err_typecheck_expect_any_scalar_or_vector) 2136 << ArgType << 0; 2137 return true; 2138 } 2139 return false; 2140 } 2141 2142 static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) { 2143 assert(TheCall->getNumArgs() == 3); 2144 Expr *Arg1 = TheCall->getArg(1); 2145 Expr *Arg2 = TheCall->getArg(2); 2146 if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) { 2147 S->Diag(TheCall->getBeginLoc(), 2148 diag::err_typecheck_call_different_arg_types) 2149 << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange() 2150 << Arg2->getSourceRange(); 2151 return true; 2152 } 2153 2154 TheCall->setType(Arg1->getType()); 2155 return false; 2156 } 2157 2158 static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) { 2159 assert(TheCall->getNumArgs() == 3); 2160 Expr *Arg1 = TheCall->getArg(1); 2161 Expr *Arg2 = TheCall->getArg(2); 2162 if (!Arg1->getType()->isVectorType()) { 2163 S->Diag(Arg1->getBeginLoc(), diag::err_builtin_non_vector_type) 2164 << "Second" << TheCall->getDirectCallee() << Arg1->getType() 2165 << Arg1->getSourceRange(); 2166 return true; 2167 } 2168 2169 if (!Arg2->getType()->isVectorType()) { 2170 S->Diag(Arg2->getBeginLoc(), diag::err_builtin_non_vector_type) 2171 << "Third" << TheCall->getDirectCallee() << Arg2->getType() 2172 << Arg2->getSourceRange(); 2173 return true; 2174 } 2175 2176 if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) { 2177 S->Diag(TheCall->getBeginLoc(), 2178 diag::err_typecheck_call_different_arg_types) 2179 << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange() 2180 << Arg2->getSourceRange(); 2181 return true; 2182 } 2183 2184 // caller has checked that Arg0 is a vector. 2185 // check all three args have the same length. 
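  // Illustrative examples (assumed from the checks in this function):
  //   select(bool3, float3, float3) -> OK, result type float3
  //   select(bool2, float3, float3) -> diagnosed below (length mismatch)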
2186 if (TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() != 2187 Arg1->getType()->getAs<VectorType>()->getNumElements()) { 2188 S->Diag(TheCall->getBeginLoc(), 2189 diag::err_typecheck_vector_lengths_not_equal) 2190 << TheCall->getArg(0)->getType() << Arg1->getType() 2191 << TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange(); 2192 return true; 2193 } 2194 TheCall->setType(Arg1->getType()); 2195 return false; 2196 } 2197 2198 static bool CheckResourceHandle( 2199 Sema *S, CallExpr *TheCall, unsigned ArgIndex, 2200 llvm::function_ref<bool(const HLSLAttributedResourceType *ResType)> Check = 2201 nullptr) { 2202 assert(TheCall->getNumArgs() >= ArgIndex); 2203 QualType ArgType = TheCall->getArg(ArgIndex)->getType(); 2204 const HLSLAttributedResourceType *ResTy = 2205 ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>(); 2206 if (!ResTy) { 2207 S->Diag(TheCall->getArg(ArgIndex)->getBeginLoc(), 2208 diag::err_typecheck_expect_hlsl_resource) 2209 << ArgType; 2210 return true; 2211 } 2212 if (Check && Check(ResTy)) { 2213 S->Diag(TheCall->getArg(ArgIndex)->getExprLoc(), 2214 diag::err_invalid_hlsl_resource_type) 2215 << ArgType; 2216 return true; 2217 } 2218 return false; 2219 } 2220 2221 // Note: returning true in this case results in CheckBuiltinFunctionCall 2222 // returning an ExprError 2223 bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { 2224 switch (BuiltinID) { 2225 case Builtin::BI__builtin_hlsl_resource_getpointer: { 2226 if (SemaRef.checkArgCount(TheCall, 2) || 2227 CheckResourceHandle(&SemaRef, TheCall, 0) || 2228 CheckArgTypeMatches(&SemaRef, TheCall->getArg(1), 2229 SemaRef.getASTContext().UnsignedIntTy)) 2230 return true; 2231 2232 auto *ResourceTy = 2233 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>(); 2234 QualType ContainedTy = ResourceTy->getContainedType(); 2235 // TODO: Map to an hlsl_device address space. 2236 TheCall->setType(getASTContext().getPointerType(ContainedTy)); 2237 TheCall->setValueKind(VK_LValue); 2238 2239 break; 2240 } 2241 case Builtin::BI__builtin_hlsl_all: 2242 case Builtin::BI__builtin_hlsl_any: { 2243 if (SemaRef.checkArgCount(TheCall, 1)) 2244 return true; 2245 break; 2246 } 2247 case Builtin::BI__builtin_hlsl_asdouble: { 2248 if (SemaRef.checkArgCount(TheCall, 2)) 2249 return true; 2250 if (CheckUnsignedIntRepresentation(&SemaRef, TheCall)) 2251 return true; 2252 2253 SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().DoubleTy); 2254 break; 2255 } 2256 case Builtin::BI__builtin_hlsl_elementwise_clamp: { 2257 if (SemaRef.checkArgCount(TheCall, 3)) 2258 return true; 2259 if (CheckVectorElementCallArgs(&SemaRef, TheCall)) 2260 return true; 2261 if (SemaRef.BuiltinElementwiseTernaryMath( 2262 TheCall, /*CheckForFloatArgs*/ 2263 TheCall->getArg(0)->getType()->hasFloatingRepresentation())) 2264 return true; 2265 break; 2266 } 2267 case Builtin::BI__builtin_hlsl_cross: { 2268 if (SemaRef.checkArgCount(TheCall, 2)) 2269 return true; 2270 if (CheckVectorElementCallArgs(&SemaRef, TheCall)) 2271 return true; 2272 if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall)) 2273 return true; 2274 // ensure both args have 3 elements 2275 int NumElementsArg1 = 2276 TheCall->getArg(0)->getType()->castAs<VectorType>()->getNumElements(); 2277 int NumElementsArg2 = 2278 TheCall->getArg(1)->getType()->castAs<VectorType>()->getNumElements(); 2279 2280 if (NumElementsArg1 != 3) { 2281 int LessOrMore = NumElementsArg1 > 3 ? 
1 : 0; 2282 SemaRef.Diag(TheCall->getBeginLoc(), 2283 diag::err_vector_incorrect_num_elements) 2284 << LessOrMore << 3 << NumElementsArg1 << /*operand*/ 1; 2285 return true; 2286 } 2287 if (NumElementsArg2 != 3) { 2288 int LessOrMore = NumElementsArg2 > 3 ? 1 : 0; 2289 2290 SemaRef.Diag(TheCall->getBeginLoc(), 2291 diag::err_vector_incorrect_num_elements) 2292 << LessOrMore << 3 << NumElementsArg2 << /*operand*/ 1; 2293 return true; 2294 } 2295 2296 ExprResult A = TheCall->getArg(0); 2297 QualType ArgTyA = A.get()->getType(); 2298 // return type is the same as the input type 2299 TheCall->setType(ArgTyA); 2300 break; 2301 } 2302 case Builtin::BI__builtin_hlsl_dot: { 2303 if (SemaRef.checkArgCount(TheCall, 2)) 2304 return true; 2305 if (CheckVectorElementCallArgs(&SemaRef, TheCall)) 2306 return true; 2307 if (SemaRef.BuiltinVectorToScalarMath(TheCall)) 2308 return true; 2309 if (CheckNoDoubleVectors(&SemaRef, TheCall)) 2310 return true; 2311 break; 2312 } 2313 case Builtin::BI__builtin_hlsl_elementwise_firstbithigh: 2314 case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: { 2315 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall)) 2316 return true; 2317 2318 const Expr *Arg = TheCall->getArg(0); 2319 QualType ArgTy = Arg->getType(); 2320 QualType EltTy = ArgTy; 2321 2322 QualType ResTy = SemaRef.Context.UnsignedIntTy; 2323 2324 if (auto *VecTy = EltTy->getAs<VectorType>()) { 2325 EltTy = VecTy->getElementType(); 2326 ResTy = SemaRef.Context.getVectorType(ResTy, VecTy->getNumElements(), 2327 VecTy->getVectorKind()); 2328 } 2329 2330 if (!EltTy->isIntegerType()) { 2331 Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type) 2332 << 1 << /* integer ty */ 6 << ArgTy; 2333 return true; 2334 } 2335 2336 TheCall->setType(ResTy); 2337 break; 2338 } 2339 case Builtin::BI__builtin_hlsl_select: { 2340 if (SemaRef.checkArgCount(TheCall, 3)) 2341 return true; 2342 if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0)) 2343 return true; 2344 QualType ArgTy = TheCall->getArg(0)->getType(); 2345 if (ArgTy->isBooleanType() && CheckBoolSelect(&SemaRef, TheCall)) 2346 return true; 2347 auto *VTy = ArgTy->getAs<VectorType>(); 2348 if (VTy && VTy->getElementType()->isBooleanType() && 2349 CheckVectorSelect(&SemaRef, TheCall)) 2350 return true; 2351 break; 2352 } 2353 case Builtin::BI__builtin_hlsl_elementwise_saturate: 2354 case Builtin::BI__builtin_hlsl_elementwise_rcp: { 2355 if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall)) 2356 return true; 2357 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall)) 2358 return true; 2359 break; 2360 } 2361 case Builtin::BI__builtin_hlsl_elementwise_degrees: 2362 case Builtin::BI__builtin_hlsl_elementwise_radians: 2363 case Builtin::BI__builtin_hlsl_elementwise_rsqrt: 2364 case Builtin::BI__builtin_hlsl_elementwise_frac: { 2365 if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall)) 2366 return true; 2367 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall)) 2368 return true; 2369 break; 2370 } 2371 case Builtin::BI__builtin_hlsl_elementwise_isinf: { 2372 if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall)) 2373 return true; 2374 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall)) 2375 return true; 2376 SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().BoolTy); 2377 break; 2378 } 2379 case Builtin::BI__builtin_hlsl_lerp: { 2380 if (SemaRef.checkArgCount(TheCall, 3)) 2381 return true; 2382 if (CheckVectorElementCallArgs(&SemaRef, TheCall)) 2383 return true; 2384 if 
        (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
      return true;
    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_mad: {
    if (SemaRef.checkArgCount(TheCall, 3))
      return true;
    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
      return true;
    if (SemaRef.BuiltinElementwiseTernaryMath(
            TheCall, /*CheckForFloatArgs*/
            TheCall->getArg(0)->getType()->hasFloatingRepresentation()))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_normalize: {
    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
      return true;
    if (SemaRef.checkArgCount(TheCall, 1))
      return true;

    ExprResult A = TheCall->getArg(0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_sign: {
    if (CheckFloatingOrIntRepresentation(&SemaRef, TheCall))
      return true;
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().IntTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_step: {
    if (SemaRef.checkArgCount(TheCall, 2))
      return true;
    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
      return true;

    ExprResult A = TheCall->getArg(0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_wave_active_max:
  case Builtin::BI__builtin_hlsl_wave_active_sum: {
    if (SemaRef.checkArgCount(TheCall, 1))
      return true;

    // Ensure input expr type is a scalar/vector and the same as the return
    // type
    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
      return true;
    if (CheckWaveActive(&SemaRef, TheCall))
      return true;
    ExprResult Expr = TheCall->getArg(0);
    QualType ArgTyExpr = Expr.get()->getType();
    TheCall->setType(ArgTyExpr);
    break;
  }
  // Note: these are LLVM builtins for which we want to catch invalid intrinsic
  // generation. Normal handling of these builtins occurs elsewhere.
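  // Illustrative example (assumed): __builtin_elementwise_bitreverse(1.0f)
  // is rejected by CheckUnsignedIntRepresentation below, since bit reversal
  // is only checked here for unsigned integer representations.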
2450 case Builtin::BI__builtin_elementwise_bitreverse: { 2451 if (CheckUnsignedIntRepresentation(&SemaRef, TheCall)) 2452 return true; 2453 break; 2454 } 2455 case Builtin::BI__builtin_hlsl_wave_read_lane_at: { 2456 if (SemaRef.checkArgCount(TheCall, 2)) 2457 return true; 2458 2459 // Ensure index parameter type can be interpreted as a uint 2460 ExprResult Index = TheCall->getArg(1); 2461 QualType ArgTyIndex = Index.get()->getType(); 2462 if (!ArgTyIndex->isIntegerType()) { 2463 SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(), 2464 diag::err_typecheck_convert_incompatible) 2465 << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0; 2466 return true; 2467 } 2468 2469 // Ensure input expr type is a scalar/vector and the same as the return type 2470 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0)) 2471 return true; 2472 2473 ExprResult Expr = TheCall->getArg(0); 2474 QualType ArgTyExpr = Expr.get()->getType(); 2475 TheCall->setType(ArgTyExpr); 2476 break; 2477 } 2478 case Builtin::BI__builtin_hlsl_wave_get_lane_index: { 2479 if (SemaRef.checkArgCount(TheCall, 0)) 2480 return true; 2481 break; 2482 } 2483 case Builtin::BI__builtin_hlsl_elementwise_splitdouble: { 2484 if (SemaRef.checkArgCount(TheCall, 3)) 2485 return true; 2486 2487 if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.DoubleTy, 0) || 2488 CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy, 2489 1) || 2490 CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy, 2491 2)) 2492 return true; 2493 2494 if (CheckModifiableLValue(&SemaRef, TheCall, 1) || 2495 CheckModifiableLValue(&SemaRef, TheCall, 2)) 2496 return true; 2497 break; 2498 } 2499 case Builtin::BI__builtin_hlsl_elementwise_clip: { 2500 if (SemaRef.checkArgCount(TheCall, 1)) 2501 return true; 2502 2503 if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.FloatTy, 0)) 2504 return true; 2505 break; 2506 } 2507 case Builtin::BI__builtin_elementwise_acos: 2508 case Builtin::BI__builtin_elementwise_asin: 2509 case Builtin::BI__builtin_elementwise_atan: 2510 case Builtin::BI__builtin_elementwise_atan2: 2511 case Builtin::BI__builtin_elementwise_ceil: 2512 case Builtin::BI__builtin_elementwise_cos: 2513 case Builtin::BI__builtin_elementwise_cosh: 2514 case Builtin::BI__builtin_elementwise_exp: 2515 case Builtin::BI__builtin_elementwise_exp2: 2516 case Builtin::BI__builtin_elementwise_floor: 2517 case Builtin::BI__builtin_elementwise_fmod: 2518 case Builtin::BI__builtin_elementwise_log: 2519 case Builtin::BI__builtin_elementwise_log2: 2520 case Builtin::BI__builtin_elementwise_log10: 2521 case Builtin::BI__builtin_elementwise_pow: 2522 case Builtin::BI__builtin_elementwise_roundeven: 2523 case Builtin::BI__builtin_elementwise_sin: 2524 case Builtin::BI__builtin_elementwise_sinh: 2525 case Builtin::BI__builtin_elementwise_sqrt: 2526 case Builtin::BI__builtin_elementwise_tan: 2527 case Builtin::BI__builtin_elementwise_tanh: 2528 case Builtin::BI__builtin_elementwise_trunc: { 2529 if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall)) 2530 return true; 2531 break; 2532 } 2533 case Builtin::BI__builtin_hlsl_buffer_update_counter: { 2534 auto checkResTy = [](const HLSLAttributedResourceType *ResTy) -> bool { 2535 return !(ResTy->getAttrs().ResourceClass == ResourceClass::UAV && 2536 ResTy->getAttrs().RawBuffer && ResTy->hasContainedType()); 2537 }; 2538 if (SemaRef.checkArgCount(TheCall, 2) || 2539 CheckResourceHandle(&SemaRef, TheCall, 0, checkResTy) || 2540 CheckArgTypeMatches(&SemaRef, TheCall->getArg(1), 2541 
SemaRef.getASTContext().IntTy)) 2542 return true; 2543 Expr *OffsetExpr = TheCall->getArg(1); 2544 std::optional<llvm::APSInt> Offset = 2545 OffsetExpr->getIntegerConstantExpr(SemaRef.getASTContext()); 2546 if (!Offset.has_value() || std::abs(Offset->getExtValue()) != 1) { 2547 SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(), 2548 diag::err_hlsl_expect_arg_const_int_one_or_neg_one) 2549 << 1; 2550 return true; 2551 } 2552 break; 2553 } 2554 } 2555 return false; 2556 } 2557 2558 static void BuildFlattenedTypeList(QualType BaseTy, 2559 llvm::SmallVectorImpl<QualType> &List) { 2560 llvm::SmallVector<QualType, 16> WorkList; 2561 WorkList.push_back(BaseTy); 2562 while (!WorkList.empty()) { 2563 QualType T = WorkList.pop_back_val(); 2564 T = T.getCanonicalType().getUnqualifiedType(); 2565 assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL"); 2566 if (const auto *AT = dyn_cast<ConstantArrayType>(T)) { 2567 llvm::SmallVector<QualType, 16> ElementFields; 2568 // Generally I've avoided recursion in this algorithm, but arrays of 2569 // structs could be time-consuming to flatten and churn through on the 2570 // work list. Hopefully nesting arrays of structs containing arrays 2571 // of structs too many levels deep is unlikely. 2572 BuildFlattenedTypeList(AT->getElementType(), ElementFields); 2573 // Repeat the element's field list n times. 2574 for (uint64_t Ct = 0; Ct < AT->getZExtSize(); ++Ct) 2575 List.insert(List.end(), ElementFields.begin(), ElementFields.end()); 2576 continue; 2577 } 2578 // Vectors can only have element types that are builtin types, so this can 2579 // add directly to the list instead of to the WorkList. 2580 if (const auto *VT = dyn_cast<VectorType>(T)) { 2581 List.insert(List.end(), VT->getNumElements(), VT->getElementType()); 2582 continue; 2583 } 2584 if (const auto *RT = dyn_cast<RecordType>(T)) { 2585 const RecordDecl *RD = RT->getDecl(); 2586 if (RD->isUnion()) { 2587 List.push_back(T); 2588 continue; 2589 } 2590 const CXXRecordDecl *CXXD = dyn_cast<CXXRecordDecl>(RD); 2591 2592 llvm::SmallVector<QualType, 16> FieldTypes; 2593 if (CXXD && CXXD->isStandardLayout()) 2594 RD = CXXD->getStandardLayoutBaseWithFields(); 2595 2596 for (const auto *FD : RD->fields()) 2597 FieldTypes.push_back(FD->getType()); 2598 // Reverse the newly added sub-range. 2599 std::reverse(FieldTypes.begin(), FieldTypes.end()); 2600 WorkList.insert(WorkList.end(), FieldTypes.begin(), FieldTypes.end()); 2601 2602 // If this wasn't a standard layout type we may also have some base 2603 // classes to deal with. 2604 if (CXXD && !CXXD->isStandardLayout()) { 2605 FieldTypes.clear(); 2606 for (const auto &Base : CXXD->bases()) 2607 FieldTypes.push_back(Base.getType()); 2608 std::reverse(FieldTypes.begin(), FieldTypes.end()); 2609 WorkList.insert(WorkList.end(), FieldTypes.begin(), FieldTypes.end()); 2610 } 2611 continue; 2612 } 2613 List.push_back(T); 2614 } 2615 } 2616 2617 bool SemaHLSL::IsTypedResourceElementCompatible(clang::QualType QT) { 2618 // null and array types are not allowed. 
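  // Illustrative examples of the checks below (assumed from this function's
  // logic): float, uint and float4 are accepted; bool, enums, records, and
  // anything wider than 16 bytes (e.g. double4) are rejected.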
2619 if (QT.isNull() || QT->isArrayType()) 2620 return false; 2621 2622 // UDT types are not allowed 2623 if (QT->isRecordType()) 2624 return false; 2625 2626 if (QT->isBooleanType() || QT->isEnumeralType()) 2627 return false; 2628 2629 // the only other valid builtin types are scalars or vectors 2630 if (QT->isArithmeticType()) { 2631 if (SemaRef.Context.getTypeSize(QT) / 8 > 16) 2632 return false; 2633 return true; 2634 } 2635 2636 if (const VectorType *VT = QT->getAs<VectorType>()) { 2637 int ArraySize = VT->getNumElements(); 2638 2639 if (ArraySize > 4) 2640 return false; 2641 2642 QualType ElTy = VT->getElementType(); 2643 if (ElTy->isBooleanType()) 2644 return false; 2645 2646 if (SemaRef.Context.getTypeSize(QT) / 8 > 16) 2647 return false; 2648 return true; 2649 } 2650 2651 return false; 2652 } 2653 2654 bool SemaHLSL::IsScalarizedLayoutCompatible(QualType T1, QualType T2) const { 2655 if (T1.isNull() || T2.isNull()) 2656 return false; 2657 2658 T1 = T1.getCanonicalType().getUnqualifiedType(); 2659 T2 = T2.getCanonicalType().getUnqualifiedType(); 2660 2661 // If both types are the same canonical type, they're obviously compatible. 2662 if (SemaRef.getASTContext().hasSameType(T1, T2)) 2663 return true; 2664 2665 llvm::SmallVector<QualType, 16> T1Types; 2666 BuildFlattenedTypeList(T1, T1Types); 2667 llvm::SmallVector<QualType, 16> T2Types; 2668 BuildFlattenedTypeList(T2, T2Types); 2669 2670 // Check the flattened type list 2671 return llvm::equal(T1Types, T2Types, 2672 [this](QualType LHS, QualType RHS) -> bool { 2673 return SemaRef.IsLayoutCompatible(LHS, RHS); 2674 }); 2675 } 2676 2677 bool SemaHLSL::CheckCompatibleParameterABI(FunctionDecl *New, 2678 FunctionDecl *Old) { 2679 if (New->getNumParams() != Old->getNumParams()) 2680 return true; 2681 2682 bool HadError = false; 2683 2684 for (unsigned i = 0, e = New->getNumParams(); i != e; ++i) { 2685 ParmVarDecl *NewParam = New->getParamDecl(i); 2686 ParmVarDecl *OldParam = Old->getParamDecl(i); 2687 2688 // HLSL parameter declarations for inout and out must match between 2689 // declarations. In HLSL inout and out are ambiguous at the call site, 2690 // but have different calling behavior, so you cannot overload a 2691 // method based on a difference between inout and out annotations. 2692 const auto *NDAttr = NewParam->getAttr<HLSLParamModifierAttr>(); 2693 unsigned NSpellingIdx = (NDAttr ? NDAttr->getSpellingListIndex() : 0); 2694 const auto *ODAttr = OldParam->getAttr<HLSLParamModifierAttr>(); 2695 unsigned OSpellingIdx = (ODAttr ? ODAttr->getSpellingListIndex() : 0); 2696 2697 if (NSpellingIdx != OSpellingIdx) { 2698 SemaRef.Diag(NewParam->getLocation(), 2699 diag::err_hlsl_param_qualifier_mismatch) 2700 << NDAttr << NewParam; 2701 SemaRef.Diag(OldParam->getLocation(), diag::note_previous_declaration_as) 2702 << ODAttr; 2703 HadError = true; 2704 } 2705 } 2706 return HadError; 2707 } 2708 2709 ExprResult SemaHLSL::ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg) { 2710 assert(Param->hasAttr<HLSLParamModifierAttr>() && 2711 "We should not get here without a parameter modifier expression"); 2712 const auto *Attr = Param->getAttr<HLSLParamModifierAttr>(); 2713 if (Attr->getABI() == ParameterABI::Ordinary) 2714 return ExprResult(Arg); 2715 2716 bool IsInOut = Attr->getABI() == ParameterABI::HLSLInOut; 2717 if (!Arg->isLValue()) { 2718 SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_lvalue) 2719 << Arg << (IsInOut ? 
 1 : 0);
    return ExprError();
  }

  ASTContext &Ctx = SemaRef.getASTContext();

  QualType Ty = Param->getType().getNonLValueExprType(Ctx);

  // HLSL allows implicit conversions from scalars to vectors, but not the
  // inverse, so we need to disallow `inout` with scalar->vector or
  // scalar->matrix conversions.
  if (Arg->getType()->isScalarType() != Ty->isScalarType()) {
    SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_scalar_extension)
        << Arg << (IsInOut ? 1 : 0);
    return ExprError();
  }

  auto *ArgOpV = new (Ctx) OpaqueValueExpr(Param->getBeginLoc(), Arg->getType(),
                                           VK_LValue, OK_Ordinary, Arg);

  // Parameters are initialized via copy initialization. This allows for
  // overload resolution of argument constructors.
  InitializedEntity Entity =
      InitializedEntity::InitializeParameter(Ctx, Ty, false);
  ExprResult Res =
      SemaRef.PerformCopyInitialization(Entity, Param->getBeginLoc(), ArgOpV);
  if (Res.isInvalid())
    return ExprError();
  Expr *Base = Res.get();
  // After the cast, drop the reference type when creating the exprs.
  Ty = Ty.getNonLValueExprType(Ctx);
  auto *OpV = new (Ctx)
      OpaqueValueExpr(Param->getBeginLoc(), Ty, VK_LValue, OK_Ordinary, Base);

  // Writebacks are performed with the `=` binary operator, which allows for
  // overload resolution on writeback result expressions.
  Res = SemaRef.ActOnBinOp(SemaRef.getCurScope(), Param->getBeginLoc(),
                           tok::equal, ArgOpV, OpV);

  if (Res.isInvalid())
    return ExprError();
  Expr *Writeback = Res.get();
  auto *OutExpr =
      HLSLOutArgExpr::Create(Ctx, Ty, ArgOpV, OpV, Writeback, IsInOut);

  return ExprResult(OutExpr);
}

QualType SemaHLSL::getInoutParameterType(QualType Ty) {
  // If HLSL gains support for references, all the sites that use this will
  // need to be updated with semantic checking to produce errors for
  // pointers/references.
  assert(!Ty->isReferenceType() &&
         "Pointer and reference types cannot be inout or out parameters");
  Ty = SemaRef.getASTContext().getLValueReferenceType(Ty);
  Ty.addRestrict();
  return Ty;
}

void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
  if (VD->hasGlobalStorage()) {
    // make sure the declaration has a complete type
    if (SemaRef.RequireCompleteType(
            VD->getLocation(),
            SemaRef.getASTContext().getBaseElementType(VD->getType()),
            diag::err_typecheck_decl_incomplete_type)) {
      VD->setInvalidDecl();
      return;
    }

    // find all resources on decl
    if (VD->getType()->isHLSLIntangibleType())
      collectResourcesOnVarDecl(VD);

    // process explicit bindings
    processExplicitBindingsOnDecl(VD);
  }
}

// Walks through the global variable declaration, collects all resource
// binding requirements, and adds them to Bindings
void SemaHLSL::collectResourcesOnVarDecl(VarDecl *VD) {
  assert(VD->hasGlobalStorage() && VD->getType()->isHLSLIntangibleType() &&
         "expected global variable that contains HLSL resource");

  // Cbuffers and Tbuffers are HLSLBufferDecl types
  if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) {
    Bindings.addDeclBindingInfo(VD, CBufferOrTBuffer->isCBuffer()
                                        ? ResourceClass::CBuffer
                                        : ResourceClass::SRV);
    return;
  }

  // Unwrap arrays
  // FIXME: Calculate array size while unwrapping
  const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
  while (Ty->isConstantArrayType()) {
    const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
    Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
  }

  // Resource (or array of resources)
  if (const HLSLAttributedResourceType *AttrResType =
          HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
    Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass);
    return;
  }

  // User defined record type
  if (const RecordType *RT = dyn_cast<RecordType>(Ty))
    collectResourcesOnUserRecordDecl(VD, RT);
}

// Walks through the explicit resource binding attributes on the declaration,
// makes sure there is a resource that matches each binding, and updates
// DeclBindingInfoLists
void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
  assert(VD->hasGlobalStorage() && "expected global variable");

  for (Attr *A : VD->attrs()) {
    HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
    if (!RBA)
      continue;

    RegisterType RT = RBA->getRegisterType();
    assert(RT != RegisterType::I && "invalid or obsolete register type should "
                                    "never have an attribute created");

    if (RT == RegisterType::C) {
      if (Bindings.hasBindingInfoForDecl(VD))
        SemaRef.Diag(VD->getLocation(),
                     diag::warn_hlsl_user_defined_type_missing_member)
            << static_cast<int>(RT);
      continue;
    }

    // Find DeclBindingInfo for this binding and update it, or report an error
    // if it does not exist (the user type does not contain resources with the
    // expected resource class).
    ResourceClass RC = getResourceClass(RT);
    if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) {
      // update binding info
      BI->setBindingAttribute(RBA, BindingType::Explicit);
    } else {
      SemaRef.Diag(VD->getLocation(),
                   diag::warn_hlsl_user_defined_type_missing_member)
          << static_cast<int>(RT);
    }
  }
}