1 //===- SemaSYCL.cpp - Semantic Analysis for SYCL constructs ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // This implements Semantic Analysis for SYCL constructs. 9 //===----------------------------------------------------------------------===// 10 11 #include "clang/Sema/SemaSYCL.h" 12 #include "TreeTransform.h" 13 #include "clang/AST/Mangle.h" 14 #include "clang/AST/SYCLKernelInfo.h" 15 #include "clang/AST/StmtSYCL.h" 16 #include "clang/AST/TypeOrdering.h" 17 #include "clang/Basic/Diagnostic.h" 18 #include "clang/Sema/Attr.h" 19 #include "clang/Sema/ParsedAttr.h" 20 #include "clang/Sema/Sema.h" 21 22 using namespace clang; 23 24 // ----------------------------------------------------------------------------- 25 // SYCL device specific diagnostics implementation 26 // ----------------------------------------------------------------------------- 27 28 SemaSYCL::SemaSYCL(Sema &S) : SemaBase(S) {} 29 30 Sema::SemaDiagnosticBuilder SemaSYCL::DiagIfDeviceCode(SourceLocation Loc, 31 unsigned DiagID) { 32 assert(getLangOpts().SYCLIsDevice && 33 "Should only be called during SYCL compilation"); 34 FunctionDecl *FD = dyn_cast<FunctionDecl>(SemaRef.getCurLexicalContext()); 35 SemaDiagnosticBuilder::Kind DiagKind = [this, FD] { 36 if (!FD) 37 return SemaDiagnosticBuilder::K_Nop; 38 if (SemaRef.getEmissionStatus(FD) == Sema::FunctionEmissionStatus::Emitted) 39 return SemaDiagnosticBuilder::K_ImmediateWithCallStack; 40 return SemaDiagnosticBuilder::K_Deferred; 41 }(); 42 return SemaDiagnosticBuilder(DiagKind, Loc, DiagID, FD, SemaRef); 43 } 44 45 static bool isZeroSizedArray(SemaSYCL &S, QualType Ty) { 46 if (const auto *CAT = S.getASTContext().getAsConstantArrayType(Ty)) 47 return CAT->isZeroSize(); 48 return false; 49 } 50 51 void SemaSYCL::deepTypeCheckForDevice(SourceLocation UsedAt, 52 llvm::DenseSet<QualType> Visited, 53 ValueDecl *DeclToCheck) { 54 assert(getLangOpts().SYCLIsDevice && 55 "Should only be called during SYCL compilation"); 56 // Emit notes only for the first discovered declaration of unsupported type 57 // to avoid mess of notes. This flag is to track that error already happened. 58 bool NeedToEmitNotes = true; 59 60 auto Check = [&](QualType TypeToCheck, const ValueDecl *D) { 61 bool ErrorFound = false; 62 if (isZeroSizedArray(*this, TypeToCheck)) { 63 DiagIfDeviceCode(UsedAt, diag::err_typecheck_zero_array_size) << 1; 64 ErrorFound = true; 65 } 66 // Checks for other types can also be done here. 67 if (ErrorFound) { 68 if (NeedToEmitNotes) { 69 if (auto *FD = dyn_cast<FieldDecl>(D)) 70 DiagIfDeviceCode(FD->getLocation(), 71 diag::note_illegal_field_declared_here) 72 << FD->getType()->isPointerType() << FD->getType(); 73 else 74 DiagIfDeviceCode(D->getLocation(), diag::note_declared_at); 75 } 76 } 77 78 return ErrorFound; 79 }; 80 81 // In case we have a Record used do the DFS for a bad field. 82 SmallVector<const ValueDecl *, 4> StackForRecursion; 83 StackForRecursion.push_back(DeclToCheck); 84 85 // While doing DFS save how we get there to emit a nice set of notes. 86 SmallVector<const FieldDecl *, 4> History; 87 History.push_back(nullptr); 88 89 do { 90 const ValueDecl *Next = StackForRecursion.pop_back_val(); 91 if (!Next) { 92 assert(!History.empty()); 93 // Found a marker, we have gone up a level. 94 History.pop_back(); 95 continue; 96 } 97 QualType NextTy = Next->getType(); 98 99 if (!Visited.insert(NextTy).second) 100 continue; 101 102 auto EmitHistory = [&]() { 103 // The first element is always nullptr. 104 for (uint64_t Index = 1; Index < History.size(); ++Index) { 105 DiagIfDeviceCode(History[Index]->getLocation(), 106 diag::note_within_field_of_type) 107 << History[Index]->getType(); 108 } 109 }; 110 111 if (Check(NextTy, Next)) { 112 if (NeedToEmitNotes) 113 EmitHistory(); 114 NeedToEmitNotes = false; 115 } 116 117 // In case pointer/array/reference type is met get pointee type, then 118 // proceed with that type. 119 while (NextTy->isAnyPointerType() || NextTy->isArrayType() || 120 NextTy->isReferenceType()) { 121 if (NextTy->isArrayType()) 122 NextTy = QualType{NextTy->getArrayElementTypeNoTypeQual(), 0}; 123 else 124 NextTy = NextTy->getPointeeType(); 125 if (Check(NextTy, Next)) { 126 if (NeedToEmitNotes) 127 EmitHistory(); 128 NeedToEmitNotes = false; 129 } 130 } 131 132 if (const auto *RecDecl = NextTy->getAsRecordDecl()) { 133 if (auto *NextFD = dyn_cast<FieldDecl>(Next)) 134 History.push_back(NextFD); 135 // When nullptr is discovered, this means we've gone back up a level, so 136 // the history should be cleaned. 137 StackForRecursion.push_back(nullptr); 138 llvm::copy(RecDecl->fields(), std::back_inserter(StackForRecursion)); 139 } 140 } while (!StackForRecursion.empty()); 141 } 142 143 ExprResult SemaSYCL::BuildUniqueStableNameExpr(SourceLocation OpLoc, 144 SourceLocation LParen, 145 SourceLocation RParen, 146 TypeSourceInfo *TSI) { 147 return SYCLUniqueStableNameExpr::Create(getASTContext(), OpLoc, LParen, 148 RParen, TSI); 149 } 150 151 ExprResult SemaSYCL::ActOnUniqueStableNameExpr(SourceLocation OpLoc, 152 SourceLocation LParen, 153 SourceLocation RParen, 154 ParsedType ParsedTy) { 155 TypeSourceInfo *TSI = nullptr; 156 QualType Ty = SemaRef.GetTypeFromParser(ParsedTy, &TSI); 157 158 if (Ty.isNull()) 159 return ExprError(); 160 if (!TSI) 161 TSI = getASTContext().getTrivialTypeSourceInfo(Ty, LParen); 162 163 return BuildUniqueStableNameExpr(OpLoc, LParen, RParen, TSI); 164 } 165 166 void SemaSYCL::handleKernelAttr(Decl *D, const ParsedAttr &AL) { 167 // The 'sycl_kernel' attribute applies only to function templates. 168 const auto *FD = cast<FunctionDecl>(D); 169 const FunctionTemplateDecl *FT = FD->getDescribedFunctionTemplate(); 170 assert(FT && "Function template is expected"); 171 172 // Function template must have at least two template parameters. 173 const TemplateParameterList *TL = FT->getTemplateParameters(); 174 if (TL->size() < 2) { 175 Diag(FT->getLocation(), diag::warn_sycl_kernel_num_of_template_params); 176 return; 177 } 178 179 // Template parameters must be typenames. 180 for (unsigned I = 0; I < 2; ++I) { 181 const NamedDecl *TParam = TL->getParam(I); 182 if (isa<NonTypeTemplateParmDecl>(TParam)) { 183 Diag(FT->getLocation(), 184 diag::warn_sycl_kernel_invalid_template_param_type); 185 return; 186 } 187 } 188 189 // Function must have at least one argument. 190 if (getFunctionOrMethodNumParams(D) != 1) { 191 Diag(FT->getLocation(), diag::warn_sycl_kernel_num_of_function_params); 192 return; 193 } 194 195 // Function must return void. 196 QualType RetTy = getFunctionOrMethodResultType(D); 197 if (!RetTy->isVoidType()) { 198 Diag(FT->getLocation(), diag::warn_sycl_kernel_return_type); 199 return; 200 } 201 202 handleSimpleAttribute<SYCLKernelAttr>(*this, D, AL); 203 } 204 205 void SemaSYCL::handleKernelEntryPointAttr(Decl *D, const ParsedAttr &AL) { 206 ParsedType PT = AL.getTypeArg(); 207 TypeSourceInfo *TSI = nullptr; 208 (void)SemaRef.GetTypeFromParser(PT, &TSI); 209 assert(TSI && "no type source info for attribute argument"); 210 D->addAttr(::new (SemaRef.Context) 211 SYCLKernelEntryPointAttr(SemaRef.Context, AL, TSI)); 212 } 213 214 // Given a potentially qualified type, SourceLocationForUserDeclaredType() 215 // returns the source location of the canonical declaration of the unqualified 216 // desugared user declared type, if any. For non-user declared types, an 217 // invalid source location is returned. The intended usage of this function 218 // is to identify an appropriate source location, if any, for a 219 // "entity declared here" diagnostic note. 220 static SourceLocation SourceLocationForUserDeclaredType(QualType QT) { 221 SourceLocation Loc; 222 const Type *T = QT->getUnqualifiedDesugaredType(); 223 if (const TagType *TT = dyn_cast<TagType>(T)) 224 Loc = TT->getDecl()->getLocation(); 225 else if (const ObjCInterfaceType *ObjCIT = dyn_cast<ObjCInterfaceType>(T)) 226 Loc = ObjCIT->getDecl()->getLocation(); 227 return Loc; 228 } 229 230 static bool CheckSYCLKernelName(Sema &S, SourceLocation Loc, 231 QualType KernelName) { 232 assert(!KernelName->isDependentType()); 233 234 if (!KernelName->isStructureOrClassType()) { 235 // SYCL 2020 section 5.2, "Naming of kernels", only requires that the 236 // kernel name be a C++ typename. However, the definition of "kernel name" 237 // in the glossary states that a kernel name is a class type. Neither 238 // section explicitly states whether the kernel name type can be 239 // cv-qualified. For now, kernel name types are required to be class types 240 // and that they may be cv-qualified. The following issue requests 241 // clarification from the SYCL WG. 242 // https://github.com/KhronosGroup/SYCL-Docs/issues/568 243 S.Diag(Loc, diag::warn_sycl_kernel_name_not_a_class_type) << KernelName; 244 SourceLocation DeclTypeLoc = SourceLocationForUserDeclaredType(KernelName); 245 if (DeclTypeLoc.isValid()) 246 S.Diag(DeclTypeLoc, diag::note_entity_declared_at) << KernelName; 247 return true; 248 } 249 250 return false; 251 } 252 253 void SemaSYCL::CheckSYCLEntryPointFunctionDecl(FunctionDecl *FD) { 254 // Ensure that all attributes present on the declaration are consistent 255 // and warn about any redundant ones. 256 SYCLKernelEntryPointAttr *SKEPAttr = nullptr; 257 for (auto *SAI : FD->specific_attrs<SYCLKernelEntryPointAttr>()) { 258 if (!SKEPAttr) { 259 SKEPAttr = SAI; 260 continue; 261 } 262 if (!getASTContext().hasSameType(SAI->getKernelName(), 263 SKEPAttr->getKernelName())) { 264 Diag(SAI->getLocation(), diag::err_sycl_entry_point_invalid_redeclaration) 265 << SAI->getKernelName() << SKEPAttr->getKernelName(); 266 Diag(SKEPAttr->getLocation(), diag::note_previous_attribute); 267 SAI->setInvalidAttr(); 268 } else { 269 Diag(SAI->getLocation(), 270 diag::warn_sycl_entry_point_redundant_declaration); 271 Diag(SKEPAttr->getLocation(), diag::note_previous_attribute); 272 } 273 } 274 assert(SKEPAttr && "Missing sycl_kernel_entry_point attribute"); 275 276 // Ensure the kernel name type is valid. 277 if (!SKEPAttr->getKernelName()->isDependentType() && 278 CheckSYCLKernelName(SemaRef, SKEPAttr->getLocation(), 279 SKEPAttr->getKernelName())) 280 SKEPAttr->setInvalidAttr(); 281 282 // Ensure that an attribute present on the previous declaration 283 // matches the one on this declaration. 284 FunctionDecl *PrevFD = FD->getPreviousDecl(); 285 if (PrevFD && !PrevFD->isInvalidDecl()) { 286 const auto *PrevSKEPAttr = PrevFD->getAttr<SYCLKernelEntryPointAttr>(); 287 if (PrevSKEPAttr && !PrevSKEPAttr->isInvalidAttr()) { 288 if (!getASTContext().hasSameType(SKEPAttr->getKernelName(), 289 PrevSKEPAttr->getKernelName())) { 290 Diag(SKEPAttr->getLocation(), 291 diag::err_sycl_entry_point_invalid_redeclaration) 292 << SKEPAttr->getKernelName() << PrevSKEPAttr->getKernelName(); 293 Diag(PrevSKEPAttr->getLocation(), diag::note_previous_decl) << PrevFD; 294 SKEPAttr->setInvalidAttr(); 295 } 296 } 297 } 298 299 if (const auto *MD = dyn_cast<CXXMethodDecl>(FD)) { 300 if (!MD->isStatic()) { 301 Diag(SKEPAttr->getLocation(), diag::err_sycl_entry_point_invalid) 302 << /*non-static member function*/ 0; 303 SKEPAttr->setInvalidAttr(); 304 } 305 } 306 307 if (FD->isVariadic()) { 308 Diag(SKEPAttr->getLocation(), diag::err_sycl_entry_point_invalid) 309 << /*variadic function*/ 1; 310 SKEPAttr->setInvalidAttr(); 311 } 312 313 if (FD->isDefaulted()) { 314 Diag(SKEPAttr->getLocation(), diag::err_sycl_entry_point_invalid) 315 << /*defaulted function*/ 3; 316 SKEPAttr->setInvalidAttr(); 317 } else if (FD->isDeleted()) { 318 Diag(SKEPAttr->getLocation(), diag::err_sycl_entry_point_invalid) 319 << /*deleted function*/ 2; 320 SKEPAttr->setInvalidAttr(); 321 } 322 323 if (FD->isConsteval()) { 324 Diag(SKEPAttr->getLocation(), diag::err_sycl_entry_point_invalid) 325 << /*consteval function*/ 5; 326 SKEPAttr->setInvalidAttr(); 327 } else if (FD->isConstexpr()) { 328 Diag(SKEPAttr->getLocation(), diag::err_sycl_entry_point_invalid) 329 << /*constexpr function*/ 4; 330 SKEPAttr->setInvalidAttr(); 331 } 332 333 if (FD->isNoReturn()) { 334 Diag(SKEPAttr->getLocation(), diag::err_sycl_entry_point_invalid) 335 << /*function declared with the 'noreturn' attribute*/ 6; 336 SKEPAttr->setInvalidAttr(); 337 } 338 339 if (FD->getReturnType()->isUndeducedType()) { 340 Diag(SKEPAttr->getLocation(), 341 diag::err_sycl_entry_point_deduced_return_type); 342 SKEPAttr->setInvalidAttr(); 343 } else if (!FD->getReturnType()->isDependentType() && 344 !FD->getReturnType()->isVoidType()) { 345 Diag(SKEPAttr->getLocation(), diag::err_sycl_entry_point_return_type); 346 SKEPAttr->setInvalidAttr(); 347 } 348 349 if (!FD->isInvalidDecl() && !FD->isTemplated() && 350 !SKEPAttr->isInvalidAttr()) { 351 const SYCLKernelInfo *SKI = 352 getASTContext().findSYCLKernelInfo(SKEPAttr->getKernelName()); 353 if (SKI) { 354 if (!declaresSameEntity(FD, SKI->getKernelEntryPointDecl())) { 355 // FIXME: This diagnostic should include the origin of the kernel 356 // FIXME: names; not just the locations of the conflicting declarations. 357 Diag(FD->getLocation(), diag::err_sycl_kernel_name_conflict); 358 Diag(SKI->getKernelEntryPointDecl()->getLocation(), 359 diag::note_previous_declaration); 360 SKEPAttr->setInvalidAttr(); 361 } 362 } else { 363 getASTContext().registerSYCLEntryPointFunction(FD); 364 } 365 } 366 } 367 368 namespace { 369 370 // The body of a function declared with the [[sycl_kernel_entry_point]] 371 // attribute is cloned and transformed to substitute references to the original 372 // function parameters with references to replacement variables that stand in 373 // for SYCL kernel parameters or local variables that reconstitute a decomposed 374 // SYCL kernel argument. 375 class OutlinedFunctionDeclBodyInstantiator 376 : public TreeTransform<OutlinedFunctionDeclBodyInstantiator> { 377 public: 378 using ParmDeclMap = llvm::DenseMap<ParmVarDecl *, VarDecl *>; 379 380 OutlinedFunctionDeclBodyInstantiator(Sema &S, ParmDeclMap &M) 381 : TreeTransform<OutlinedFunctionDeclBodyInstantiator>(S), SemaRef(S), 382 MapRef(M) {} 383 384 // A new set of AST nodes is always required. 385 bool AlwaysRebuild() { return true; } 386 387 // Transform ParmVarDecl references to the supplied replacement variables. 388 ExprResult TransformDeclRefExpr(DeclRefExpr *DRE) { 389 const ParmVarDecl *PVD = dyn_cast<ParmVarDecl>(DRE->getDecl()); 390 if (PVD) { 391 ParmDeclMap::iterator I = MapRef.find(PVD); 392 if (I != MapRef.end()) { 393 VarDecl *VD = I->second; 394 assert(SemaRef.getASTContext().hasSameUnqualifiedType(PVD->getType(), 395 VD->getType())); 396 assert(!VD->getType().isMoreQualifiedThan(PVD->getType(), 397 SemaRef.getASTContext())); 398 VD->setIsUsed(); 399 return DeclRefExpr::Create( 400 SemaRef.getASTContext(), DRE->getQualifierLoc(), 401 DRE->getTemplateKeywordLoc(), VD, false, DRE->getNameInfo(), 402 DRE->getType(), DRE->getValueKind()); 403 } 404 } 405 return DRE; 406 } 407 408 private: 409 Sema &SemaRef; 410 ParmDeclMap &MapRef; 411 }; 412 413 } // unnamed namespace 414 415 StmtResult SemaSYCL::BuildSYCLKernelCallStmt(FunctionDecl *FD, 416 CompoundStmt *Body) { 417 assert(!FD->isInvalidDecl()); 418 assert(!FD->isTemplated()); 419 assert(FD->hasPrototype()); 420 421 const auto *SKEPAttr = FD->getAttr<SYCLKernelEntryPointAttr>(); 422 assert(SKEPAttr && "Missing sycl_kernel_entry_point attribute"); 423 assert(!SKEPAttr->isInvalidAttr() && 424 "sycl_kernel_entry_point attribute is invalid"); 425 426 // Ensure that the kernel name was previously registered and that the 427 // stored declaration matches. 428 const SYCLKernelInfo &SKI = 429 getASTContext().getSYCLKernelInfo(SKEPAttr->getKernelName()); 430 assert(declaresSameEntity(SKI.getKernelEntryPointDecl(), FD) && 431 "SYCL kernel name conflict"); 432 (void)SKI; 433 434 using ParmDeclMap = OutlinedFunctionDeclBodyInstantiator::ParmDeclMap; 435 ParmDeclMap ParmMap; 436 437 assert(SemaRef.CurContext == FD); 438 OutlinedFunctionDecl *OFD = 439 OutlinedFunctionDecl::Create(getASTContext(), FD, FD->getNumParams()); 440 unsigned i = 0; 441 for (ParmVarDecl *PVD : FD->parameters()) { 442 ImplicitParamDecl *IPD = ImplicitParamDecl::Create( 443 getASTContext(), OFD, SourceLocation(), PVD->getIdentifier(), 444 PVD->getType(), ImplicitParamKind::Other); 445 OFD->setParam(i, IPD); 446 ParmMap[PVD] = IPD; 447 ++i; 448 } 449 450 OutlinedFunctionDeclBodyInstantiator OFDBodyInstantiator(SemaRef, ParmMap); 451 Stmt *OFDBody = OFDBodyInstantiator.TransformStmt(Body).get(); 452 OFD->setBody(OFDBody); 453 OFD->setNothrow(); 454 Stmt *NewBody = new (getASTContext()) SYCLKernelCallStmt(Body, OFD); 455 456 return NewBody; 457 } 458