xref: /llvm-project/clang-tools-extra/clang-tidy/modernize/UseConstraintsCheck.cpp (revision e42b799bb28815431f2c5a95f7e13fde3f1b36a1)
1 //===--- UseConstraintsCheck.cpp - clang-tidy -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "UseConstraintsCheck.h"
10 #include "clang/AST/ASTContext.h"
11 #include "clang/ASTMatchers/ASTMatchFinder.h"
12 #include "clang/Lex/Lexer.h"
13 
14 #include "../utils/LexerUtils.h"
15 
16 #include <optional>
17 #include <utility>
18 
19 using namespace clang::ast_matchers;
20 
21 namespace clang::tidy::modernize {
22 
23 struct EnableIfData {
24   TemplateSpecializationTypeLoc Loc;
25   TypeLoc Outer;
26 };
27 
28 namespace {
AST_MATCHER(FunctionDecl,hasOtherDeclarations)29 AST_MATCHER(FunctionDecl, hasOtherDeclarations) {
30   auto It = Node.redecls_begin();
31   auto EndIt = Node.redecls_end();
32 
33   if (It == EndIt)
34     return false;
35 
36   ++It;
37   return It != EndIt;
38 }
39 } // namespace
40 
registerMatchers(MatchFinder * Finder)41 void UseConstraintsCheck::registerMatchers(MatchFinder *Finder) {
42   Finder->addMatcher(
43       functionTemplateDecl(
44           // Skip external libraries included as system headers
45           unless(isExpansionInSystemHeader()),
46           has(functionDecl(unless(hasOtherDeclarations()), isDefinition(),
47                            hasReturnTypeLoc(typeLoc().bind("return")))
48                   .bind("function")))
49           .bind("functionTemplate"),
50       this);
51 }
52 
53 static std::optional<TemplateSpecializationTypeLoc>
matchEnableIfSpecializationImplTypename(TypeLoc TheType)54 matchEnableIfSpecializationImplTypename(TypeLoc TheType) {
55   if (const auto Dep = TheType.getAs<DependentNameTypeLoc>()) {
56     const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier();
57     if (!Identifier || Identifier->getName() != "type" ||
58         Dep.getTypePtr()->getKeyword() != ElaboratedTypeKeyword::Typename) {
59       return std::nullopt;
60     }
61     TheType = Dep.getQualifierLoc().getTypeLoc();
62     if (TheType.isNull())
63       return std::nullopt;
64   }
65 
66   if (const auto SpecializationLoc =
67           TheType.getAs<TemplateSpecializationTypeLoc>()) {
68 
69     const auto *Specialization =
70         dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
71     if (!Specialization)
72       return std::nullopt;
73 
74     const TemplateDecl *TD =
75         Specialization->getTemplateName().getAsTemplateDecl();
76     if (!TD || TD->getName() != "enable_if")
77       return std::nullopt;
78 
79     int NumArgs = SpecializationLoc.getNumArgs();
80     if (NumArgs != 1 && NumArgs != 2)
81       return std::nullopt;
82 
83     return SpecializationLoc;
84   }
85   return std::nullopt;
86 }
87 
88 static std::optional<TemplateSpecializationTypeLoc>
matchEnableIfSpecializationImplTrait(TypeLoc TheType)89 matchEnableIfSpecializationImplTrait(TypeLoc TheType) {
90   if (const auto Elaborated = TheType.getAs<ElaboratedTypeLoc>())
91     TheType = Elaborated.getNamedTypeLoc();
92 
93   if (const auto SpecializationLoc =
94           TheType.getAs<TemplateSpecializationTypeLoc>()) {
95 
96     const auto *Specialization =
97         dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
98     if (!Specialization)
99       return std::nullopt;
100 
101     const TemplateDecl *TD =
102         Specialization->getTemplateName().getAsTemplateDecl();
103     if (!TD || TD->getName() != "enable_if_t")
104       return std::nullopt;
105 
106     if (!Specialization->isTypeAlias())
107       return std::nullopt;
108 
109     if (const auto *AliasedType =
110             dyn_cast<DependentNameType>(Specialization->getAliasedType())) {
111       if (AliasedType->getIdentifier()->getName() != "type" ||
112           AliasedType->getKeyword() != ElaboratedTypeKeyword::Typename) {
113         return std::nullopt;
114       }
115     } else {
116       return std::nullopt;
117     }
118     int NumArgs = SpecializationLoc.getNumArgs();
119     if (NumArgs != 1 && NumArgs != 2)
120       return std::nullopt;
121 
122     return SpecializationLoc;
123   }
124   return std::nullopt;
125 }
126 
127 static std::optional<TemplateSpecializationTypeLoc>
matchEnableIfSpecializationImpl(TypeLoc TheType)128 matchEnableIfSpecializationImpl(TypeLoc TheType) {
129   if (auto EnableIf = matchEnableIfSpecializationImplTypename(TheType))
130     return EnableIf;
131   return matchEnableIfSpecializationImplTrait(TheType);
132 }
133 
134 static std::optional<EnableIfData>
matchEnableIfSpecialization(TypeLoc TheType)135 matchEnableIfSpecialization(TypeLoc TheType) {
136   if (const auto Pointer = TheType.getAs<PointerTypeLoc>())
137     TheType = Pointer.getPointeeLoc();
138   else if (const auto Reference = TheType.getAs<ReferenceTypeLoc>())
139     TheType = Reference.getPointeeLoc();
140   if (const auto Qualified = TheType.getAs<QualifiedTypeLoc>())
141     TheType = Qualified.getUnqualifiedLoc();
142 
143   if (auto EnableIf = matchEnableIfSpecializationImpl(TheType))
144     return EnableIfData{std::move(*EnableIf), TheType};
145   return std::nullopt;
146 }
147 
148 static std::pair<std::optional<EnableIfData>, const Decl *>
matchTrailingTemplateParam(const FunctionTemplateDecl * FunctionTemplate)149 matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate) {
150   // For non-type trailing param, match very specifically
151   // 'template <..., enable_if_type<Condition, Type> = Default>' where
152   // enable_if_type is 'enable_if' or 'enable_if_t'. E.g., 'template <typename
153   // T, enable_if_t<is_same_v<T, bool>, int*> = nullptr>
154   //
155   // Otherwise, match a trailing default type arg.
156   // E.g., 'template <typename T, typename = enable_if_t<is_same_v<T, bool>>>'
157 
158   const TemplateParameterList *TemplateParams =
159       FunctionTemplate->getTemplateParameters();
160   if (TemplateParams->size() == 0)
161     return {};
162 
163   const NamedDecl *LastParam =
164       TemplateParams->getParam(TemplateParams->size() - 1);
165   if (const auto *LastTemplateParam =
166           dyn_cast<NonTypeTemplateParmDecl>(LastParam)) {
167 
168     if (!LastTemplateParam->hasDefaultArgument() ||
169         !LastTemplateParam->getName().empty())
170       return {};
171 
172     return {matchEnableIfSpecialization(
173                 LastTemplateParam->getTypeSourceInfo()->getTypeLoc()),
174             LastTemplateParam};
175   }
176   if (const auto *LastTemplateParam =
177           dyn_cast<TemplateTypeParmDecl>(LastParam)) {
178     if (LastTemplateParam->hasDefaultArgument() &&
179         LastTemplateParam->getIdentifier() == nullptr) {
180       return {
181           matchEnableIfSpecialization(LastTemplateParam->getDefaultArgument()
182                                           .getTypeSourceInfo()
183                                           ->getTypeLoc()),
184           LastTemplateParam};
185     }
186   }
187   return {};
188 }
189 
190 template <typename T>
getRAngleFileLoc(const SourceManager & SM,const T & Element)191 static SourceLocation getRAngleFileLoc(const SourceManager &SM,
192                                        const T &Element) {
193   // getFileLoc handles the case where the RAngle loc is part of a synthesized
194   // '>>', which ends up allocating a 'scratch space' buffer in the source
195   // manager.
196   return SM.getFileLoc(Element.getRAngleLoc());
197 }
198 
199 static SourceRange
getConditionRange(ASTContext & Context,const TemplateSpecializationTypeLoc & EnableIf)200 getConditionRange(ASTContext &Context,
201                   const TemplateSpecializationTypeLoc &EnableIf) {
202   // TemplateArgumentLoc's SourceRange End is the location of the last token
203   // (per UnqualifiedId docs). E.g., in `enable_if<AAA && BBB>`, the End
204   // location will be the first 'B' in 'BBB'.
205   const LangOptions &LangOpts = Context.getLangOpts();
206   const SourceManager &SM = Context.getSourceManager();
207   if (EnableIf.getNumArgs() > 1) {
208     TemplateArgumentLoc NextArg = EnableIf.getArgLoc(1);
209     return {EnableIf.getLAngleLoc().getLocWithOffset(1),
210             utils::lexer::findPreviousTokenKind(
211                 NextArg.getSourceRange().getBegin(), SM, LangOpts, tok::comma)};
212   }
213 
214   return {EnableIf.getLAngleLoc().getLocWithOffset(1),
215           getRAngleFileLoc(SM, EnableIf)};
216 }
217 
getTypeRange(ASTContext & Context,const TemplateSpecializationTypeLoc & EnableIf)218 static SourceRange getTypeRange(ASTContext &Context,
219                                 const TemplateSpecializationTypeLoc &EnableIf) {
220   TemplateArgumentLoc Arg = EnableIf.getArgLoc(1);
221   const LangOptions &LangOpts = Context.getLangOpts();
222   const SourceManager &SM = Context.getSourceManager();
223   return {utils::lexer::findPreviousTokenKind(Arg.getSourceRange().getBegin(),
224                                               SM, LangOpts, tok::comma)
225               .getLocWithOffset(1),
226           getRAngleFileLoc(SM, EnableIf)};
227 }
228 
229 // Returns the original source text of the second argument of a call to
230 // enable_if_t. E.g., in enable_if_t<Condition, TheType>, this function
231 // returns 'TheType'.
232 static std::optional<StringRef>
getTypeText(ASTContext & Context,const TemplateSpecializationTypeLoc & EnableIf)233 getTypeText(ASTContext &Context,
234             const TemplateSpecializationTypeLoc &EnableIf) {
235   if (EnableIf.getNumArgs() > 1) {
236     const LangOptions &LangOpts = Context.getLangOpts();
237     const SourceManager &SM = Context.getSourceManager();
238     bool Invalid = false;
239     StringRef Text = Lexer::getSourceText(CharSourceRange::getCharRange(
240                                               getTypeRange(Context, EnableIf)),
241                                           SM, LangOpts, &Invalid)
242                          .trim();
243     if (Invalid)
244       return std::nullopt;
245 
246     return Text;
247   }
248 
249   return "void";
250 }
251 
252 static std::optional<SourceLocation>
findInsertionForConstraint(const FunctionDecl * Function,ASTContext & Context)253 findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context) {
254   SourceManager &SM = Context.getSourceManager();
255   const LangOptions &LangOpts = Context.getLangOpts();
256 
257   if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(Function)) {
258     for (const CXXCtorInitializer *Init : Constructor->inits()) {
259       if (Init->getSourceOrder() == 0)
260         return utils::lexer::findPreviousTokenKind(Init->getSourceLocation(),
261                                                    SM, LangOpts, tok::colon);
262     }
263     if (!Constructor->inits().empty())
264       return std::nullopt;
265   }
266   if (Function->isDeleted()) {
267     SourceLocation FunctionEnd = Function->getSourceRange().getEnd();
268     return utils::lexer::findNextAnyTokenKind(FunctionEnd, SM, LangOpts,
269                                               tok::equal, tok::equal);
270   }
271   const Stmt *Body = Function->getBody();
272   if (!Body)
273     return std::nullopt;
274 
275   return Body->getBeginLoc();
276 }
277 
isPrimaryExpression(const Expr * Expression)278 bool isPrimaryExpression(const Expr *Expression) {
279   // This function is an incomplete approximation of checking whether
280   // an Expr is a primary expression. In particular, if this function
281   // returns true, the expression is a primary expression. The converse
282   // is not necessarily true.
283 
284   if (const auto *Cast = dyn_cast<ImplicitCastExpr>(Expression))
285     Expression = Cast->getSubExprAsWritten();
286   if (isa<ParenExpr, DependentScopeDeclRefExpr>(Expression))
287     return true;
288 
289   return false;
290 }
291 
292 // Return the original source text of an enable_if_t condition, i.e., the
293 // first template argument). For example, in
294 // 'enable_if_t<FirstCondition || SecondCondition, AType>', the text
295 // the text 'FirstCondition || SecondCondition' is returned.
getConditionText(const Expr * ConditionExpr,SourceRange ConditionRange,ASTContext & Context)296 static std::optional<std::string> getConditionText(const Expr *ConditionExpr,
297                                                    SourceRange ConditionRange,
298                                                    ASTContext &Context) {
299   SourceManager &SM = Context.getSourceManager();
300   const LangOptions &LangOpts = Context.getLangOpts();
301 
302   SourceLocation PrevTokenLoc = ConditionRange.getEnd();
303   if (PrevTokenLoc.isInvalid())
304     return std::nullopt;
305 
306   const bool SkipComments = false;
307   Token PrevToken;
308   std::tie(PrevToken, PrevTokenLoc) = utils::lexer::getPreviousTokenAndStart(
309       PrevTokenLoc, SM, LangOpts, SkipComments);
310   bool EndsWithDoubleSlash =
311       PrevToken.is(tok::comment) &&
312       Lexer::getSourceText(CharSourceRange::getCharRange(
313                                PrevTokenLoc, PrevTokenLoc.getLocWithOffset(2)),
314                            SM, LangOpts) == "//";
315 
316   bool Invalid = false;
317   llvm::StringRef ConditionText = Lexer::getSourceText(
318       CharSourceRange::getCharRange(ConditionRange), SM, LangOpts, &Invalid);
319   if (Invalid)
320     return std::nullopt;
321 
322   auto AddParens = [&](llvm::StringRef Text) -> std::string {
323     if (isPrimaryExpression(ConditionExpr))
324       return Text.str();
325     return "(" + Text.str() + ")";
326   };
327 
328   if (EndsWithDoubleSlash)
329     return AddParens(ConditionText);
330   return AddParens(ConditionText.trim());
331 }
332 
333 // Handle functions that return enable_if_t, e.g.,
334 //   template <...>
335 //   enable_if_t<Condition, ReturnType> function();
336 //
337 // Return a vector of FixItHints if the code can be replaced with
338 // a C++20 requires clause. In the example above, returns FixItHints
339 // to result in
340 //   template <...>
341 //   ReturnType function() requires Condition {}
handleReturnType(const FunctionDecl * Function,const TypeLoc & ReturnType,const EnableIfData & EnableIf,ASTContext & Context)342 static std::vector<FixItHint> handleReturnType(const FunctionDecl *Function,
343                                                const TypeLoc &ReturnType,
344                                                const EnableIfData &EnableIf,
345                                                ASTContext &Context) {
346   TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
347 
348   SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
349 
350   std::optional<std::string> ConditionText = getConditionText(
351       EnableCondition.getSourceExpression(), ConditionRange, Context);
352   if (!ConditionText)
353     return {};
354 
355   std::optional<StringRef> TypeText = getTypeText(Context, EnableIf.Loc);
356   if (!TypeText)
357     return {};
358 
359   SmallVector<const Expr *, 3> ExistingConstraints;
360   Function->getAssociatedConstraints(ExistingConstraints);
361   if (!ExistingConstraints.empty()) {
362     // FIXME - Support adding new constraints to existing ones. Do we need to
363     // consider subsumption?
364     return {};
365   }
366 
367   std::optional<SourceLocation> ConstraintInsertionLoc =
368       findInsertionForConstraint(Function, Context);
369   if (!ConstraintInsertionLoc)
370     return {};
371 
372   std::vector<FixItHint> FixIts;
373   FixIts.push_back(FixItHint::CreateReplacement(
374       CharSourceRange::getTokenRange(EnableIf.Outer.getSourceRange()),
375       *TypeText));
376   FixIts.push_back(FixItHint::CreateInsertion(
377       *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
378   return FixIts;
379 }
380 
381 // Handle enable_if_t in a trailing template parameter, e.g.,
382 //   template <..., enable_if_t<Condition, Type> = Type{}>
383 //   ReturnType function();
384 //
385 // Return a vector of FixItHints if the code can be replaced with
386 // a C++20 requires clause. In the example above, returns FixItHints
387 // to result in
388 //   template <...>
389 //   ReturnType function() requires Condition {}
390 static std::vector<FixItHint>
handleTrailingTemplateType(const FunctionTemplateDecl * FunctionTemplate,const FunctionDecl * Function,const Decl * LastTemplateParam,const EnableIfData & EnableIf,ASTContext & Context)391 handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate,
392                            const FunctionDecl *Function,
393                            const Decl *LastTemplateParam,
394                            const EnableIfData &EnableIf, ASTContext &Context) {
395   SourceManager &SM = Context.getSourceManager();
396   const LangOptions &LangOpts = Context.getLangOpts();
397 
398   TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
399 
400   SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
401 
402   std::optional<std::string> ConditionText = getConditionText(
403       EnableCondition.getSourceExpression(), ConditionRange, Context);
404   if (!ConditionText)
405     return {};
406 
407   SmallVector<const Expr *, 3> ExistingConstraints;
408   Function->getAssociatedConstraints(ExistingConstraints);
409   if (!ExistingConstraints.empty()) {
410     // FIXME - Support adding new constraints to existing ones. Do we need to
411     // consider subsumption?
412     return {};
413   }
414 
415   SourceRange RemovalRange;
416   const TemplateParameterList *TemplateParams =
417       FunctionTemplate->getTemplateParameters();
418   if (!TemplateParams || TemplateParams->size() == 0)
419     return {};
420 
421   if (TemplateParams->size() == 1) {
422     RemovalRange =
423         SourceRange(TemplateParams->getTemplateLoc(),
424                     getRAngleFileLoc(SM, *TemplateParams).getLocWithOffset(1));
425   } else {
426     RemovalRange =
427         SourceRange(utils::lexer::findPreviousTokenKind(
428                         LastTemplateParam->getSourceRange().getBegin(), SM,
429                         LangOpts, tok::comma),
430                     getRAngleFileLoc(SM, *TemplateParams));
431   }
432 
433   std::optional<SourceLocation> ConstraintInsertionLoc =
434       findInsertionForConstraint(Function, Context);
435   if (!ConstraintInsertionLoc)
436     return {};
437 
438   std::vector<FixItHint> FixIts;
439   FixIts.push_back(
440       FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange)));
441   FixIts.push_back(FixItHint::CreateInsertion(
442       *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
443   return FixIts;
444 }
445 
check(const MatchFinder::MatchResult & Result)446 void UseConstraintsCheck::check(const MatchFinder::MatchResult &Result) {
447   const auto *FunctionTemplate =
448       Result.Nodes.getNodeAs<FunctionTemplateDecl>("functionTemplate");
449   const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>("function");
450   const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>("return");
451   if (!FunctionTemplate || !Function || !ReturnType)
452     return;
453 
454   // Check for
455   //
456   //   Case 1. Return type of function
457   //
458   //     template <...>
459   //     enable_if_t<Condition, ReturnType>::type function() {}
460   //
461   //   Case 2. Trailing template parameter
462   //
463   //     template <..., enable_if_t<Condition, Type> = Type{}>
464   //     ReturnType function() {}
465   //
466   //     or
467   //
468   //     template <..., typename = enable_if_t<Condition, void>>
469   //     ReturnType function() {}
470   //
471 
472   // Case 1. Return type of function
473   if (auto EnableIf = matchEnableIfSpecialization(*ReturnType)) {
474     diag(ReturnType->getBeginLoc(),
475          "use C++20 requires constraints instead of enable_if")
476         << handleReturnType(Function, *ReturnType, *EnableIf, *Result.Context);
477     return;
478   }
479 
480   // Case 2. Trailing template parameter
481   if (auto [EnableIf, LastTemplateParam] =
482           matchTrailingTemplateParam(FunctionTemplate);
483       EnableIf && LastTemplateParam) {
484     diag(LastTemplateParam->getSourceRange().getBegin(),
485          "use C++20 requires constraints instead of enable_if")
486         << handleTrailingTemplateType(FunctionTemplate, Function,
487                                       LastTemplateParam, *EnableIf,
488                                       *Result.Context);
489     return;
490   }
491 }
492 
493 } // namespace clang::tidy::modernize
494