xref: /llvm-project/clang/lib/Sema/SemaSYCL.cpp (revision eaaac050588ec67afcdbb347df5597458a9b10d1)
1 //===- SemaSYCL.cpp - Semantic Analysis for SYCL constructs ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This implements Semantic Analysis for SYCL constructs.
9 //===----------------------------------------------------------------------===//
10 
11 #include "clang/Sema/SemaSYCL.h"
12 #include "TreeTransform.h"
13 #include "clang/AST/Mangle.h"
14 #include "clang/AST/SYCLKernelInfo.h"
15 #include "clang/AST/StmtSYCL.h"
16 #include "clang/AST/TypeOrdering.h"
17 #include "clang/Basic/Diagnostic.h"
18 #include "clang/Sema/Attr.h"
19 #include "clang/Sema/ParsedAttr.h"
20 #include "clang/Sema/Sema.h"
21 
22 using namespace clang;
23 
24 // -----------------------------------------------------------------------------
25 // SYCL device specific diagnostics implementation
26 // -----------------------------------------------------------------------------
27 
28 SemaSYCL::SemaSYCL(Sema &S) : SemaBase(S) {}
29 
30 Sema::SemaDiagnosticBuilder SemaSYCL::DiagIfDeviceCode(SourceLocation Loc,
31                                                        unsigned DiagID) {
32   assert(getLangOpts().SYCLIsDevice &&
33          "Should only be called during SYCL compilation");
34   FunctionDecl *FD = dyn_cast<FunctionDecl>(SemaRef.getCurLexicalContext());
35   SemaDiagnosticBuilder::Kind DiagKind = [this, FD] {
36     if (!FD)
37       return SemaDiagnosticBuilder::K_Nop;
38     if (SemaRef.getEmissionStatus(FD) == Sema::FunctionEmissionStatus::Emitted)
39       return SemaDiagnosticBuilder::K_ImmediateWithCallStack;
40     return SemaDiagnosticBuilder::K_Deferred;
41   }();
42   return SemaDiagnosticBuilder(DiagKind, Loc, DiagID, FD, SemaRef);
43 }
44 
45 static bool isZeroSizedArray(SemaSYCL &S, QualType Ty) {
46   if (const auto *CAT = S.getASTContext().getAsConstantArrayType(Ty))
47     return CAT->isZeroSize();
48   return false;
49 }
50 
51 void SemaSYCL::deepTypeCheckForDevice(SourceLocation UsedAt,
52                                       llvm::DenseSet<QualType> Visited,
53                                       ValueDecl *DeclToCheck) {
54   assert(getLangOpts().SYCLIsDevice &&
55          "Should only be called during SYCL compilation");
56   // Emit notes only for the first discovered declaration of unsupported type
57   // to avoid mess of notes. This flag is to track that error already happened.
58   bool NeedToEmitNotes = true;
59 
60   auto Check = [&](QualType TypeToCheck, const ValueDecl *D) {
61     bool ErrorFound = false;
62     if (isZeroSizedArray(*this, TypeToCheck)) {
63       DiagIfDeviceCode(UsedAt, diag::err_typecheck_zero_array_size) << 1;
64       ErrorFound = true;
65     }
66     // Checks for other types can also be done here.
67     if (ErrorFound) {
68       if (NeedToEmitNotes) {
69         if (auto *FD = dyn_cast<FieldDecl>(D))
70           DiagIfDeviceCode(FD->getLocation(),
71                            diag::note_illegal_field_declared_here)
72               << FD->getType()->isPointerType() << FD->getType();
73         else
74           DiagIfDeviceCode(D->getLocation(), diag::note_declared_at);
75       }
76     }
77 
78     return ErrorFound;
79   };
80 
81   // In case we have a Record used do the DFS for a bad field.
82   SmallVector<const ValueDecl *, 4> StackForRecursion;
83   StackForRecursion.push_back(DeclToCheck);
84 
85   // While doing DFS save how we get there to emit a nice set of notes.
86   SmallVector<const FieldDecl *, 4> History;
87   History.push_back(nullptr);
88 
89   do {
90     const ValueDecl *Next = StackForRecursion.pop_back_val();
91     if (!Next) {
92       assert(!History.empty());
93       // Found a marker, we have gone up a level.
94       History.pop_back();
95       continue;
96     }
97     QualType NextTy = Next->getType();
98 
99     if (!Visited.insert(NextTy).second)
100       continue;
101 
102     auto EmitHistory = [&]() {
103       // The first element is always nullptr.
104       for (uint64_t Index = 1; Index < History.size(); ++Index) {
105         DiagIfDeviceCode(History[Index]->getLocation(),
106                          diag::note_within_field_of_type)
107             << History[Index]->getType();
108       }
109     };
110 
111     if (Check(NextTy, Next)) {
112       if (NeedToEmitNotes)
113         EmitHistory();
114       NeedToEmitNotes = false;
115     }
116 
117     // In case pointer/array/reference type is met get pointee type, then
118     // proceed with that type.
119     while (NextTy->isAnyPointerType() || NextTy->isArrayType() ||
120            NextTy->isReferenceType()) {
121       if (NextTy->isArrayType())
122         NextTy = QualType{NextTy->getArrayElementTypeNoTypeQual(), 0};
123       else
124         NextTy = NextTy->getPointeeType();
125       if (Check(NextTy, Next)) {
126         if (NeedToEmitNotes)
127           EmitHistory();
128         NeedToEmitNotes = false;
129       }
130     }
131 
132     if (const auto *RecDecl = NextTy->getAsRecordDecl()) {
133       if (auto *NextFD = dyn_cast<FieldDecl>(Next))
134         History.push_back(NextFD);
135       // When nullptr is discovered, this means we've gone back up a level, so
136       // the history should be cleaned.
137       StackForRecursion.push_back(nullptr);
138       llvm::copy(RecDecl->fields(), std::back_inserter(StackForRecursion));
139     }
140   } while (!StackForRecursion.empty());
141 }
142 
143 ExprResult SemaSYCL::BuildUniqueStableNameExpr(SourceLocation OpLoc,
144                                                SourceLocation LParen,
145                                                SourceLocation RParen,
146                                                TypeSourceInfo *TSI) {
147   return SYCLUniqueStableNameExpr::Create(getASTContext(), OpLoc, LParen,
148                                           RParen, TSI);
149 }
150 
151 ExprResult SemaSYCL::ActOnUniqueStableNameExpr(SourceLocation OpLoc,
152                                                SourceLocation LParen,
153                                                SourceLocation RParen,
154                                                ParsedType ParsedTy) {
155   TypeSourceInfo *TSI = nullptr;
156   QualType Ty = SemaRef.GetTypeFromParser(ParsedTy, &TSI);
157 
158   if (Ty.isNull())
159     return ExprError();
160   if (!TSI)
161     TSI = getASTContext().getTrivialTypeSourceInfo(Ty, LParen);
162 
163   return BuildUniqueStableNameExpr(OpLoc, LParen, RParen, TSI);
164 }
165 
166 void SemaSYCL::handleKernelAttr(Decl *D, const ParsedAttr &AL) {
167   // The 'sycl_kernel' attribute applies only to function templates.
168   const auto *FD = cast<FunctionDecl>(D);
169   const FunctionTemplateDecl *FT = FD->getDescribedFunctionTemplate();
170   assert(FT && "Function template is expected");
171 
172   // Function template must have at least two template parameters.
173   const TemplateParameterList *TL = FT->getTemplateParameters();
174   if (TL->size() < 2) {
175     Diag(FT->getLocation(), diag::warn_sycl_kernel_num_of_template_params);
176     return;
177   }
178 
179   // Template parameters must be typenames.
180   for (unsigned I = 0; I < 2; ++I) {
181     const NamedDecl *TParam = TL->getParam(I);
182     if (isa<NonTypeTemplateParmDecl>(TParam)) {
183       Diag(FT->getLocation(),
184            diag::warn_sycl_kernel_invalid_template_param_type);
185       return;
186     }
187   }
188 
189   // Function must have at least one argument.
190   if (getFunctionOrMethodNumParams(D) != 1) {
191     Diag(FT->getLocation(), diag::warn_sycl_kernel_num_of_function_params);
192     return;
193   }
194 
195   // Function must return void.
196   QualType RetTy = getFunctionOrMethodResultType(D);
197   if (!RetTy->isVoidType()) {
198     Diag(FT->getLocation(), diag::warn_sycl_kernel_return_type);
199     return;
200   }
201 
202   handleSimpleAttribute<SYCLKernelAttr>(*this, D, AL);
203 }
204 
205 void SemaSYCL::handleKernelEntryPointAttr(Decl *D, const ParsedAttr &AL) {
206   ParsedType PT = AL.getTypeArg();
207   TypeSourceInfo *TSI = nullptr;
208   (void)SemaRef.GetTypeFromParser(PT, &TSI);
209   assert(TSI && "no type source info for attribute argument");
210   D->addAttr(::new (SemaRef.Context)
211                  SYCLKernelEntryPointAttr(SemaRef.Context, AL, TSI));
212 }
213 
214 // Given a potentially qualified type, SourceLocationForUserDeclaredType()
215 // returns the source location of the canonical declaration of the unqualified
216 // desugared user declared type, if any. For non-user declared types, an
217 // invalid source location is returned. The intended usage of this function
218 // is to identify an appropriate source location, if any, for a
219 // "entity declared here" diagnostic note.
220 static SourceLocation SourceLocationForUserDeclaredType(QualType QT) {
221   SourceLocation Loc;
222   const Type *T = QT->getUnqualifiedDesugaredType();
223   if (const TagType *TT = dyn_cast<TagType>(T))
224     Loc = TT->getDecl()->getLocation();
225   else if (const ObjCInterfaceType *ObjCIT = dyn_cast<ObjCInterfaceType>(T))
226     Loc = ObjCIT->getDecl()->getLocation();
227   return Loc;
228 }
229 
230 static bool CheckSYCLKernelName(Sema &S, SourceLocation Loc,
231                                 QualType KernelName) {
232   assert(!KernelName->isDependentType());
233 
234   if (!KernelName->isStructureOrClassType()) {
235     // SYCL 2020 section 5.2, "Naming of kernels", only requires that the
236     // kernel name be a C++ typename. However, the definition of "kernel name"
237     // in the glossary states that a kernel name is a class type. Neither
238     // section explicitly states whether the kernel name type can be
239     // cv-qualified. For now, kernel name types are required to be class types
240     // and that they may be cv-qualified. The following issue requests
241     // clarification from the SYCL WG.
242     //   https://github.com/KhronosGroup/SYCL-Docs/issues/568
243     S.Diag(Loc, diag::warn_sycl_kernel_name_not_a_class_type) << KernelName;
244     SourceLocation DeclTypeLoc = SourceLocationForUserDeclaredType(KernelName);
245     if (DeclTypeLoc.isValid())
246       S.Diag(DeclTypeLoc, diag::note_entity_declared_at) << KernelName;
247     return true;
248   }
249 
250   return false;
251 }
252 
253 void SemaSYCL::CheckSYCLEntryPointFunctionDecl(FunctionDecl *FD) {
254   // Ensure that all attributes present on the declaration are consistent
255   // and warn about any redundant ones.
256   SYCLKernelEntryPointAttr *SKEPAttr = nullptr;
257   for (auto *SAI : FD->specific_attrs<SYCLKernelEntryPointAttr>()) {
258     if (!SKEPAttr) {
259       SKEPAttr = SAI;
260       continue;
261     }
262     if (!getASTContext().hasSameType(SAI->getKernelName(),
263                                      SKEPAttr->getKernelName())) {
264       Diag(SAI->getLocation(), diag::err_sycl_entry_point_invalid_redeclaration)
265           << SAI->getKernelName() << SKEPAttr->getKernelName();
266       Diag(SKEPAttr->getLocation(), diag::note_previous_attribute);
267       SAI->setInvalidAttr();
268     } else {
269       Diag(SAI->getLocation(),
270            diag::warn_sycl_entry_point_redundant_declaration);
271       Diag(SKEPAttr->getLocation(), diag::note_previous_attribute);
272     }
273   }
274   assert(SKEPAttr && "Missing sycl_kernel_entry_point attribute");
275 
276   // Ensure the kernel name type is valid.
277   if (!SKEPAttr->getKernelName()->isDependentType() &&
278       CheckSYCLKernelName(SemaRef, SKEPAttr->getLocation(),
279                           SKEPAttr->getKernelName()))
280     SKEPAttr->setInvalidAttr();
281 
282   // Ensure that an attribute present on the previous declaration
283   // matches the one on this declaration.
284   FunctionDecl *PrevFD = FD->getPreviousDecl();
285   if (PrevFD && !PrevFD->isInvalidDecl()) {
286     const auto *PrevSKEPAttr = PrevFD->getAttr<SYCLKernelEntryPointAttr>();
287     if (PrevSKEPAttr && !PrevSKEPAttr->isInvalidAttr()) {
288       if (!getASTContext().hasSameType(SKEPAttr->getKernelName(),
289                                        PrevSKEPAttr->getKernelName())) {
290         Diag(SKEPAttr->getLocation(),
291              diag::err_sycl_entry_point_invalid_redeclaration)
292             << SKEPAttr->getKernelName() << PrevSKEPAttr->getKernelName();
293         Diag(PrevSKEPAttr->getLocation(), diag::note_previous_decl) << PrevFD;
294         SKEPAttr->setInvalidAttr();
295       }
296     }
297   }
298 
299   if (const auto *MD = dyn_cast<CXXMethodDecl>(FD)) {
300     if (!MD->isStatic()) {
301       Diag(SKEPAttr->getLocation(), diag::err_sycl_entry_point_invalid)
302           << /*non-static member function*/ 0;
303       SKEPAttr->setInvalidAttr();
304     }
305   }
306 
307   if (FD->isVariadic()) {
308     Diag(SKEPAttr->getLocation(), diag::err_sycl_entry_point_invalid)
309         << /*variadic function*/ 1;
310     SKEPAttr->setInvalidAttr();
311   }
312 
313   if (FD->isDefaulted()) {
314     Diag(SKEPAttr->getLocation(), diag::err_sycl_entry_point_invalid)
315         << /*defaulted function*/ 3;
316     SKEPAttr->setInvalidAttr();
317   } else if (FD->isDeleted()) {
318     Diag(SKEPAttr->getLocation(), diag::err_sycl_entry_point_invalid)
319         << /*deleted function*/ 2;
320     SKEPAttr->setInvalidAttr();
321   }
322 
323   if (FD->isConsteval()) {
324     Diag(SKEPAttr->getLocation(), diag::err_sycl_entry_point_invalid)
325         << /*consteval function*/ 5;
326     SKEPAttr->setInvalidAttr();
327   } else if (FD->isConstexpr()) {
328     Diag(SKEPAttr->getLocation(), diag::err_sycl_entry_point_invalid)
329         << /*constexpr function*/ 4;
330     SKEPAttr->setInvalidAttr();
331   }
332 
333   if (FD->isNoReturn()) {
334     Diag(SKEPAttr->getLocation(), diag::err_sycl_entry_point_invalid)
335         << /*function declared with the 'noreturn' attribute*/ 6;
336     SKEPAttr->setInvalidAttr();
337   }
338 
339   if (FD->getReturnType()->isUndeducedType()) {
340     Diag(SKEPAttr->getLocation(),
341          diag::err_sycl_entry_point_deduced_return_type);
342     SKEPAttr->setInvalidAttr();
343   } else if (!FD->getReturnType()->isDependentType() &&
344              !FD->getReturnType()->isVoidType()) {
345     Diag(SKEPAttr->getLocation(), diag::err_sycl_entry_point_return_type);
346     SKEPAttr->setInvalidAttr();
347   }
348 
349   if (!FD->isInvalidDecl() && !FD->isTemplated() &&
350       !SKEPAttr->isInvalidAttr()) {
351     const SYCLKernelInfo *SKI =
352         getASTContext().findSYCLKernelInfo(SKEPAttr->getKernelName());
353     if (SKI) {
354       if (!declaresSameEntity(FD, SKI->getKernelEntryPointDecl())) {
355         // FIXME: This diagnostic should include the origin of the kernel
356         // FIXME: names; not just the locations of the conflicting declarations.
357         Diag(FD->getLocation(), diag::err_sycl_kernel_name_conflict);
358         Diag(SKI->getKernelEntryPointDecl()->getLocation(),
359              diag::note_previous_declaration);
360         SKEPAttr->setInvalidAttr();
361       }
362     } else {
363       getASTContext().registerSYCLEntryPointFunction(FD);
364     }
365   }
366 }
367 
368 namespace {
369 
370 // The body of a function declared with the [[sycl_kernel_entry_point]]
371 // attribute is cloned and transformed to substitute references to the original
372 // function parameters with references to replacement variables that stand in
373 // for SYCL kernel parameters or local variables that reconstitute a decomposed
374 // SYCL kernel argument.
375 class OutlinedFunctionDeclBodyInstantiator
376     : public TreeTransform<OutlinedFunctionDeclBodyInstantiator> {
377 public:
378   using ParmDeclMap = llvm::DenseMap<ParmVarDecl *, VarDecl *>;
379 
380   OutlinedFunctionDeclBodyInstantiator(Sema &S, ParmDeclMap &M)
381       : TreeTransform<OutlinedFunctionDeclBodyInstantiator>(S), SemaRef(S),
382         MapRef(M) {}
383 
384   // A new set of AST nodes is always required.
385   bool AlwaysRebuild() { return true; }
386 
387   // Transform ParmVarDecl references to the supplied replacement variables.
388   ExprResult TransformDeclRefExpr(DeclRefExpr *DRE) {
389     const ParmVarDecl *PVD = dyn_cast<ParmVarDecl>(DRE->getDecl());
390     if (PVD) {
391       ParmDeclMap::iterator I = MapRef.find(PVD);
392       if (I != MapRef.end()) {
393         VarDecl *VD = I->second;
394         assert(SemaRef.getASTContext().hasSameUnqualifiedType(PVD->getType(),
395                                                               VD->getType()));
396         assert(!VD->getType().isMoreQualifiedThan(PVD->getType(),
397                                                   SemaRef.getASTContext()));
398         VD->setIsUsed();
399         return DeclRefExpr::Create(
400             SemaRef.getASTContext(), DRE->getQualifierLoc(),
401             DRE->getTemplateKeywordLoc(), VD, false, DRE->getNameInfo(),
402             DRE->getType(), DRE->getValueKind());
403       }
404     }
405     return DRE;
406   }
407 
408 private:
409   Sema &SemaRef;
410   ParmDeclMap &MapRef;
411 };
412 
413 } // unnamed namespace
414 
415 StmtResult SemaSYCL::BuildSYCLKernelCallStmt(FunctionDecl *FD,
416                                              CompoundStmt *Body) {
417   assert(!FD->isInvalidDecl());
418   assert(!FD->isTemplated());
419   assert(FD->hasPrototype());
420 
421   const auto *SKEPAttr = FD->getAttr<SYCLKernelEntryPointAttr>();
422   assert(SKEPAttr && "Missing sycl_kernel_entry_point attribute");
423   assert(!SKEPAttr->isInvalidAttr() &&
424          "sycl_kernel_entry_point attribute is invalid");
425 
426   // Ensure that the kernel name was previously registered and that the
427   // stored declaration matches.
428   const SYCLKernelInfo &SKI =
429       getASTContext().getSYCLKernelInfo(SKEPAttr->getKernelName());
430   assert(declaresSameEntity(SKI.getKernelEntryPointDecl(), FD) &&
431          "SYCL kernel name conflict");
432   (void)SKI;
433 
434   using ParmDeclMap = OutlinedFunctionDeclBodyInstantiator::ParmDeclMap;
435   ParmDeclMap ParmMap;
436 
437   assert(SemaRef.CurContext == FD);
438   OutlinedFunctionDecl *OFD =
439       OutlinedFunctionDecl::Create(getASTContext(), FD, FD->getNumParams());
440   unsigned i = 0;
441   for (ParmVarDecl *PVD : FD->parameters()) {
442     ImplicitParamDecl *IPD = ImplicitParamDecl::Create(
443         getASTContext(), OFD, SourceLocation(), PVD->getIdentifier(),
444         PVD->getType(), ImplicitParamKind::Other);
445     OFD->setParam(i, IPD);
446     ParmMap[PVD] = IPD;
447     ++i;
448   }
449 
450   OutlinedFunctionDeclBodyInstantiator OFDBodyInstantiator(SemaRef, ParmMap);
451   Stmt *OFDBody = OFDBodyInstantiator.TransformStmt(Body).get();
452   OFD->setBody(OFDBody);
453   OFD->setNothrow();
454   Stmt *NewBody = new (getASTContext()) SYCLKernelCallStmt(Body, OFD);
455 
456   return NewBody;
457 }
458