xref: /llvm-project/clang-tools-extra/clang-tidy/modernize/UseConstraintsCheck.cpp (revision 3b5a121a2478e586f59e3277d04d17fb63be5d76)
1 //===--- UseConstraintsCheck.cpp - clang-tidy -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "UseConstraintsCheck.h"
10 #include "clang/AST/ASTContext.h"
11 #include "clang/ASTMatchers/ASTMatchFinder.h"
12 #include "clang/Lex/Lexer.h"
13 
14 #include "../utils/LexerUtils.h"
15 
16 #include <optional>
17 #include <utility>
18 
19 using namespace clang::ast_matchers;
20 
21 namespace clang::tidy::modernize {
22 
// Result of recognizing an enable_if / enable_if_t specialization in a type.
struct EnableIfData {
  // The enable_if<...> / enable_if_t<...> specialization itself. Its first
  // template argument is the condition; the optional second one is the type.
  TemplateSpecializationTypeLoc Loc;
  // The type location that a fix-it replaces: the matched type after one
  // pointer/reference level and top-level qualifiers were stripped. For the
  // 'typename enable_if<...>::type' spelling this covers the whole
  // dependent-name type, not just the inner specialization.
  TypeLoc Outer;
};
27 
28 namespace {
29 AST_MATCHER(FunctionDecl, hasOtherDeclarations) {
30   auto It = Node.redecls_begin();
31   auto EndIt = Node.redecls_end();
32 
33   if (It == EndIt)
34     return false;
35 
36   ++It;
37   return It != EndIt;
38 }
39 } // namespace
40 
41 void UseConstraintsCheck::registerMatchers(MatchFinder *Finder) {
42   Finder->addMatcher(
43       functionTemplateDecl(
44           has(functionDecl(unless(hasOtherDeclarations()), isDefinition(),
45                            hasReturnTypeLoc(typeLoc().bind("return")))
46                   .bind("function")))
47           .bind("functionTemplate"),
48       this);
49 }
50 
51 static std::optional<TemplateSpecializationTypeLoc>
52 matchEnableIfSpecializationImplTypename(TypeLoc TheType) {
53   if (const auto Dep = TheType.getAs<DependentNameTypeLoc>()) {
54     const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier();
55     if (!Identifier || Identifier->getName() != "type" ||
56         Dep.getTypePtr()->getKeyword() != ElaboratedTypeKeyword::Typename) {
57       return std::nullopt;
58     }
59     TheType = Dep.getQualifierLoc().getTypeLoc();
60   }
61 
62   if (const auto SpecializationLoc =
63           TheType.getAs<TemplateSpecializationTypeLoc>()) {
64 
65     const auto *Specialization =
66         dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
67     if (!Specialization)
68       return std::nullopt;
69 
70     const TemplateDecl *TD =
71         Specialization->getTemplateName().getAsTemplateDecl();
72     if (!TD || TD->getName() != "enable_if")
73       return std::nullopt;
74 
75     int NumArgs = SpecializationLoc.getNumArgs();
76     if (NumArgs != 1 && NumArgs != 2)
77       return std::nullopt;
78 
79     return SpecializationLoc;
80   }
81   return std::nullopt;
82 }
83 
84 static std::optional<TemplateSpecializationTypeLoc>
85 matchEnableIfSpecializationImplTrait(TypeLoc TheType) {
86   if (const auto Elaborated = TheType.getAs<ElaboratedTypeLoc>())
87     TheType = Elaborated.getNamedTypeLoc();
88 
89   if (const auto SpecializationLoc =
90           TheType.getAs<TemplateSpecializationTypeLoc>()) {
91 
92     const auto *Specialization =
93         dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
94     if (!Specialization)
95       return std::nullopt;
96 
97     const TemplateDecl *TD =
98         Specialization->getTemplateName().getAsTemplateDecl();
99     if (!TD || TD->getName() != "enable_if_t")
100       return std::nullopt;
101 
102     if (!Specialization->isTypeAlias())
103       return std::nullopt;
104 
105     if (const auto *AliasedType =
106             dyn_cast<DependentNameType>(Specialization->getAliasedType())) {
107       if (AliasedType->getIdentifier()->getName() != "type" ||
108           AliasedType->getKeyword() != ElaboratedTypeKeyword::Typename) {
109         return std::nullopt;
110       }
111     } else {
112       return std::nullopt;
113     }
114     int NumArgs = SpecializationLoc.getNumArgs();
115     if (NumArgs != 1 && NumArgs != 2)
116       return std::nullopt;
117 
118     return SpecializationLoc;
119   }
120   return std::nullopt;
121 }
122 
123 static std::optional<TemplateSpecializationTypeLoc>
124 matchEnableIfSpecializationImpl(TypeLoc TheType) {
125   if (auto EnableIf = matchEnableIfSpecializationImplTypename(TheType))
126     return EnableIf;
127   return matchEnableIfSpecializationImplTrait(TheType);
128 }
129 
130 static std::optional<EnableIfData>
131 matchEnableIfSpecialization(TypeLoc TheType) {
132   if (const auto Pointer = TheType.getAs<PointerTypeLoc>())
133     TheType = Pointer.getPointeeLoc();
134   else if (const auto Reference = TheType.getAs<ReferenceTypeLoc>())
135     TheType = Reference.getPointeeLoc();
136   if (const auto Qualified = TheType.getAs<QualifiedTypeLoc>())
137     TheType = Qualified.getUnqualifiedLoc();
138 
139   if (auto EnableIf = matchEnableIfSpecializationImpl(TheType))
140     return EnableIfData{std::move(*EnableIf), TheType};
141   return std::nullopt;
142 }
143 
144 static std::pair<std::optional<EnableIfData>, const Decl *>
145 matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate) {
146   // For non-type trailing param, match very specifically
147   // 'template <..., enable_if_type<Condition, Type> = Default>' where
148   // enable_if_type is 'enable_if' or 'enable_if_t'. E.g., 'template <typename
149   // T, enable_if_t<is_same_v<T, bool>, int*> = nullptr>
150   //
151   // Otherwise, match a trailing default type arg.
152   // E.g., 'template <typename T, typename = enable_if_t<is_same_v<T, bool>>>'
153 
154   const TemplateParameterList *TemplateParams =
155       FunctionTemplate->getTemplateParameters();
156   if (TemplateParams->size() == 0)
157     return {};
158 
159   const NamedDecl *LastParam =
160       TemplateParams->getParam(TemplateParams->size() - 1);
161   if (const auto *LastTemplateParam =
162           dyn_cast<NonTypeTemplateParmDecl>(LastParam)) {
163 
164     if (!LastTemplateParam->hasDefaultArgument() ||
165         !LastTemplateParam->getName().empty())
166       return {};
167 
168     return {matchEnableIfSpecialization(
169                 LastTemplateParam->getTypeSourceInfo()->getTypeLoc()),
170             LastTemplateParam};
171   }
172   if (const auto *LastTemplateParam =
173           dyn_cast<TemplateTypeParmDecl>(LastParam)) {
174     if (LastTemplateParam->hasDefaultArgument() &&
175         LastTemplateParam->getIdentifier() == nullptr) {
176       return {matchEnableIfSpecialization(
177                   LastTemplateParam->getDefaultArgumentInfo()->getTypeLoc()),
178               LastTemplateParam};
179     }
180   }
181   return {};
182 }
183 
184 template <typename T>
185 static SourceLocation getRAngleFileLoc(const SourceManager &SM,
186                                        const T &Element) {
187   // getFileLoc handles the case where the RAngle loc is part of a synthesized
188   // '>>', which ends up allocating a 'scratch space' buffer in the source
189   // manager.
190   return SM.getFileLoc(Element.getRAngleLoc());
191 }
192 
193 static SourceRange
194 getConditionRange(ASTContext &Context,
195                   const TemplateSpecializationTypeLoc &EnableIf) {
196   // TemplateArgumentLoc's SourceRange End is the location of the last token
197   // (per UnqualifiedId docs). E.g., in `enable_if<AAA && BBB>`, the End
198   // location will be the first 'B' in 'BBB'.
199   const LangOptions &LangOpts = Context.getLangOpts();
200   const SourceManager &SM = Context.getSourceManager();
201   if (EnableIf.getNumArgs() > 1) {
202     TemplateArgumentLoc NextArg = EnableIf.getArgLoc(1);
203     return {EnableIf.getLAngleLoc().getLocWithOffset(1),
204             utils::lexer::findPreviousTokenKind(
205                 NextArg.getSourceRange().getBegin(), SM, LangOpts, tok::comma)};
206   }
207 
208   return {EnableIf.getLAngleLoc().getLocWithOffset(1),
209           getRAngleFileLoc(SM, EnableIf)};
210 }
211 
212 static SourceRange getTypeRange(ASTContext &Context,
213                                 const TemplateSpecializationTypeLoc &EnableIf) {
214   TemplateArgumentLoc Arg = EnableIf.getArgLoc(1);
215   const LangOptions &LangOpts = Context.getLangOpts();
216   const SourceManager &SM = Context.getSourceManager();
217   return {utils::lexer::findPreviousTokenKind(Arg.getSourceRange().getBegin(),
218                                               SM, LangOpts, tok::comma)
219               .getLocWithOffset(1),
220           getRAngleFileLoc(SM, EnableIf)};
221 }
222 
223 // Returns the original source text of the second argument of a call to
224 // enable_if_t. E.g., in enable_if_t<Condition, TheType>, this function
225 // returns 'TheType'.
226 static std::optional<StringRef>
227 getTypeText(ASTContext &Context,
228             const TemplateSpecializationTypeLoc &EnableIf) {
229   if (EnableIf.getNumArgs() > 1) {
230     const LangOptions &LangOpts = Context.getLangOpts();
231     const SourceManager &SM = Context.getSourceManager();
232     bool Invalid = false;
233     StringRef Text = Lexer::getSourceText(CharSourceRange::getCharRange(
234                                               getTypeRange(Context, EnableIf)),
235                                           SM, LangOpts, &Invalid)
236                          .trim();
237     if (Invalid)
238       return std::nullopt;
239 
240     return Text;
241   }
242 
243   return "void";
244 }
245 
246 static std::optional<SourceLocation>
247 findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context) {
248   SourceManager &SM = Context.getSourceManager();
249   const LangOptions &LangOpts = Context.getLangOpts();
250 
251   if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(Function)) {
252     for (const CXXCtorInitializer *Init : Constructor->inits()) {
253       if (Init->getSourceOrder() == 0)
254         return utils::lexer::findPreviousTokenKind(Init->getSourceLocation(),
255                                                    SM, LangOpts, tok::colon);
256     }
257     if (!Constructor->inits().empty())
258       return std::nullopt;
259   }
260   if (Function->isDeleted()) {
261     SourceLocation FunctionEnd = Function->getSourceRange().getEnd();
262     return utils::lexer::findNextAnyTokenKind(FunctionEnd, SM, LangOpts,
263                                               tok::equal, tok::equal);
264   }
265   const Stmt *Body = Function->getBody();
266   if (!Body)
267     return std::nullopt;
268 
269   return Body->getBeginLoc();
270 }
271 
272 bool isPrimaryExpression(const Expr *Expression) {
273   // This function is an incomplete approximation of checking whether
274   // an Expr is a primary expression. In particular, if this function
275   // returns true, the expression is a primary expression. The converse
276   // is not necessarily true.
277 
278   if (const auto *Cast = dyn_cast<ImplicitCastExpr>(Expression))
279     Expression = Cast->getSubExprAsWritten();
280   if (isa<ParenExpr, DependentScopeDeclRefExpr>(Expression))
281     return true;
282 
283   return false;
284 }
285 
286 // Return the original source text of an enable_if_t condition, i.e., the
287 // first template argument). For example, in
288 // 'enable_if_t<FirstCondition || SecondCondition, AType>', the text
289 // the text 'FirstCondition || SecondCondition' is returned.
290 static std::optional<std::string> getConditionText(const Expr *ConditionExpr,
291                                                    SourceRange ConditionRange,
292                                                    ASTContext &Context) {
293   SourceManager &SM = Context.getSourceManager();
294   const LangOptions &LangOpts = Context.getLangOpts();
295 
296   SourceLocation PrevTokenLoc = ConditionRange.getEnd();
297   if (PrevTokenLoc.isInvalid())
298     return std::nullopt;
299 
300   const bool SkipComments = false;
301   Token PrevToken;
302   std::tie(PrevToken, PrevTokenLoc) = utils::lexer::getPreviousTokenAndStart(
303       PrevTokenLoc, SM, LangOpts, SkipComments);
304   bool EndsWithDoubleSlash =
305       PrevToken.is(tok::comment) &&
306       Lexer::getSourceText(CharSourceRange::getCharRange(
307                                PrevTokenLoc, PrevTokenLoc.getLocWithOffset(2)),
308                            SM, LangOpts) == "//";
309 
310   bool Invalid = false;
311   llvm::StringRef ConditionText = Lexer::getSourceText(
312       CharSourceRange::getCharRange(ConditionRange), SM, LangOpts, &Invalid);
313   if (Invalid)
314     return std::nullopt;
315 
316   auto AddParens = [&](llvm::StringRef Text) -> std::string {
317     if (isPrimaryExpression(ConditionExpr))
318       return Text.str();
319     return "(" + Text.str() + ")";
320   };
321 
322   if (EndsWithDoubleSlash)
323     return AddParens(ConditionText);
324   return AddParens(ConditionText.trim());
325 }
326 
327 // Handle functions that return enable_if_t, e.g.,
328 //   template <...>
329 //   enable_if_t<Condition, ReturnType> function();
330 //
331 // Return a vector of FixItHints if the code can be replaced with
332 // a C++20 requires clause. In the example above, returns FixItHints
333 // to result in
334 //   template <...>
335 //   ReturnType function() requires Condition {}
336 static std::vector<FixItHint> handleReturnType(const FunctionDecl *Function,
337                                                const TypeLoc &ReturnType,
338                                                const EnableIfData &EnableIf,
339                                                ASTContext &Context) {
340   TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
341 
342   SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
343 
344   std::optional<std::string> ConditionText = getConditionText(
345       EnableCondition.getSourceExpression(), ConditionRange, Context);
346   if (!ConditionText)
347     return {};
348 
349   std::optional<StringRef> TypeText = getTypeText(Context, EnableIf.Loc);
350   if (!TypeText)
351     return {};
352 
353   SmallVector<const Expr *, 3> ExistingConstraints;
354   Function->getAssociatedConstraints(ExistingConstraints);
355   if (!ExistingConstraints.empty()) {
356     // FIXME - Support adding new constraints to existing ones. Do we need to
357     // consider subsumption?
358     return {};
359   }
360 
361   std::optional<SourceLocation> ConstraintInsertionLoc =
362       findInsertionForConstraint(Function, Context);
363   if (!ConstraintInsertionLoc)
364     return {};
365 
366   std::vector<FixItHint> FixIts;
367   FixIts.push_back(FixItHint::CreateReplacement(
368       CharSourceRange::getTokenRange(EnableIf.Outer.getSourceRange()),
369       *TypeText));
370   FixIts.push_back(FixItHint::CreateInsertion(
371       *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
372   return FixIts;
373 }
374 
375 // Handle enable_if_t in a trailing template parameter, e.g.,
376 //   template <..., enable_if_t<Condition, Type> = Type{}>
377 //   ReturnType function();
378 //
379 // Return a vector of FixItHints if the code can be replaced with
380 // a C++20 requires clause. In the example above, returns FixItHints
381 // to result in
382 //   template <...>
383 //   ReturnType function() requires Condition {}
384 static std::vector<FixItHint>
385 handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate,
386                            const FunctionDecl *Function,
387                            const Decl *LastTemplateParam,
388                            const EnableIfData &EnableIf, ASTContext &Context) {
389   SourceManager &SM = Context.getSourceManager();
390   const LangOptions &LangOpts = Context.getLangOpts();
391 
392   TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
393 
394   SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
395 
396   std::optional<std::string> ConditionText = getConditionText(
397       EnableCondition.getSourceExpression(), ConditionRange, Context);
398   if (!ConditionText)
399     return {};
400 
401   SmallVector<const Expr *, 3> ExistingConstraints;
402   Function->getAssociatedConstraints(ExistingConstraints);
403   if (!ExistingConstraints.empty()) {
404     // FIXME - Support adding new constraints to existing ones. Do we need to
405     // consider subsumption?
406     return {};
407   }
408 
409   SourceRange RemovalRange;
410   const TemplateParameterList *TemplateParams =
411       FunctionTemplate->getTemplateParameters();
412   if (!TemplateParams || TemplateParams->size() == 0)
413     return {};
414 
415   if (TemplateParams->size() == 1) {
416     RemovalRange =
417         SourceRange(TemplateParams->getTemplateLoc(),
418                     getRAngleFileLoc(SM, *TemplateParams).getLocWithOffset(1));
419   } else {
420     RemovalRange =
421         SourceRange(utils::lexer::findPreviousTokenKind(
422                         LastTemplateParam->getSourceRange().getBegin(), SM,
423                         LangOpts, tok::comma),
424                     getRAngleFileLoc(SM, *TemplateParams));
425   }
426 
427   std::optional<SourceLocation> ConstraintInsertionLoc =
428       findInsertionForConstraint(Function, Context);
429   if (!ConstraintInsertionLoc)
430     return {};
431 
432   std::vector<FixItHint> FixIts;
433   FixIts.push_back(
434       FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange)));
435   FixIts.push_back(FixItHint::CreateInsertion(
436       *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
437   return FixIts;
438 }
439 
440 void UseConstraintsCheck::check(const MatchFinder::MatchResult &Result) {
441   const auto *FunctionTemplate =
442       Result.Nodes.getNodeAs<FunctionTemplateDecl>("functionTemplate");
443   const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>("function");
444   const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>("return");
445   if (!FunctionTemplate || !Function || !ReturnType)
446     return;
447 
448   // Check for
449   //
450   //   Case 1. Return type of function
451   //
452   //     template <...>
453   //     enable_if_t<Condition, ReturnType>::type function() {}
454   //
455   //   Case 2. Trailing template parameter
456   //
457   //     template <..., enable_if_t<Condition, Type> = Type{}>
458   //     ReturnType function() {}
459   //
460   //     or
461   //
462   //     template <..., typename = enable_if_t<Condition, void>>
463   //     ReturnType function() {}
464   //
465 
466   // Case 1. Return type of function
467   if (auto EnableIf = matchEnableIfSpecialization(*ReturnType)) {
468     diag(ReturnType->getBeginLoc(),
469          "use C++20 requires constraints instead of enable_if")
470         << handleReturnType(Function, *ReturnType, *EnableIf, *Result.Context);
471     return;
472   }
473 
474   // Case 2. Trailing template parameter
475   if (auto [EnableIf, LastTemplateParam] =
476           matchTrailingTemplateParam(FunctionTemplate);
477       EnableIf && LastTemplateParam) {
478     diag(LastTemplateParam->getSourceRange().getBegin(),
479          "use C++20 requires constraints instead of enable_if")
480         << handleTrailingTemplateType(FunctionTemplate, Function,
481                                       LastTemplateParam, *EnableIf,
482                                       *Result.Context);
483     return;
484   }
485 }
486 
487 } // namespace clang::tidy::modernize
488