xref: /llvm-project/clang-tools-extra/clang-tidy/modernize/UseConstraintsCheck.cpp (revision ec5f4be4521c3b28d6bb14f34f39a87c36fe8c00)
1 //===--- UseConstraintsCheck.cpp - clang-tidy -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "UseConstraintsCheck.h"
10 #include "clang/AST/ASTContext.h"
11 #include "clang/ASTMatchers/ASTMatchFinder.h"
12 #include "clang/Lex/Lexer.h"
13 
14 #include "../utils/LexerUtils.h"
15 
16 #include <optional>
17 #include <utility>
18 
19 using namespace clang::ast_matchers;
20 
21 namespace clang::tidy::modernize {
22 
23 struct EnableIfData {
24   TemplateSpecializationTypeLoc Loc;
25   TypeLoc Outer;
26 };
27 
28 namespace {
29 AST_MATCHER(FunctionDecl, hasOtherDeclarations) {
30   auto It = Node.redecls_begin();
31   auto EndIt = Node.redecls_end();
32 
33   if (It == EndIt)
34     return false;
35 
36   ++It;
37   return It != EndIt;
38 }
39 } // namespace
40 
41 void UseConstraintsCheck::registerMatchers(MatchFinder *Finder) {
42   Finder->addMatcher(
43       functionTemplateDecl(
44           has(functionDecl(unless(hasOtherDeclarations()), isDefinition(),
45                            hasReturnTypeLoc(typeLoc().bind("return")))
46                   .bind("function")))
47           .bind("functionTemplate"),
48       this);
49 }
50 
51 static std::optional<TemplateSpecializationTypeLoc>
52 matchEnableIfSpecializationImplTypename(TypeLoc TheType) {
53   if (const auto Dep = TheType.getAs<DependentNameTypeLoc>()) {
54     const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier();
55     if (!Identifier || Identifier->getName() != "type" ||
56         Dep.getTypePtr()->getKeyword() != ETK_Typename) {
57       return std::nullopt;
58     }
59     TheType = Dep.getQualifierLoc().getTypeLoc();
60   }
61 
62   if (const auto SpecializationLoc =
63           TheType.getAs<TemplateSpecializationTypeLoc>()) {
64 
65     const auto *Specialization =
66         dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
67     if (!Specialization)
68       return std::nullopt;
69 
70     const TemplateDecl *TD =
71         Specialization->getTemplateName().getAsTemplateDecl();
72     if (!TD || TD->getName() != "enable_if")
73       return std::nullopt;
74 
75     int NumArgs = SpecializationLoc.getNumArgs();
76     if (NumArgs != 1 && NumArgs != 2)
77       return std::nullopt;
78 
79     return SpecializationLoc;
80   }
81   return std::nullopt;
82 }
83 
84 static std::optional<TemplateSpecializationTypeLoc>
85 matchEnableIfSpecializationImplTrait(TypeLoc TheType) {
86   if (const auto Elaborated = TheType.getAs<ElaboratedTypeLoc>())
87     TheType = Elaborated.getNamedTypeLoc();
88 
89   if (const auto SpecializationLoc =
90           TheType.getAs<TemplateSpecializationTypeLoc>()) {
91 
92     const auto *Specialization =
93         dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
94     if (!Specialization)
95       return std::nullopt;
96 
97     const TemplateDecl *TD =
98         Specialization->getTemplateName().getAsTemplateDecl();
99     if (!TD || TD->getName() != "enable_if_t")
100       return std::nullopt;
101 
102     if (!Specialization->isTypeAlias())
103       return std::nullopt;
104 
105     if (const auto *AliasedType =
106             dyn_cast<DependentNameType>(Specialization->getAliasedType())) {
107       if (AliasedType->getIdentifier()->getName() != "type" ||
108           AliasedType->getKeyword() != ETK_Typename) {
109         return std::nullopt;
110       }
111     } else {
112       return std::nullopt;
113     }
114     int NumArgs = SpecializationLoc.getNumArgs();
115     if (NumArgs != 1 && NumArgs != 2)
116       return std::nullopt;
117 
118     return SpecializationLoc;
119   }
120   return std::nullopt;
121 }
122 
123 static std::optional<TemplateSpecializationTypeLoc>
124 matchEnableIfSpecializationImpl(TypeLoc TheType) {
125   if (auto EnableIf = matchEnableIfSpecializationImplTypename(TheType))
126     return EnableIf;
127   return matchEnableIfSpecializationImplTrait(TheType);
128 }
129 
130 static std::optional<EnableIfData>
131 matchEnableIfSpecialization(TypeLoc TheType) {
132   if (const auto Pointer = TheType.getAs<PointerTypeLoc>())
133     TheType = Pointer.getPointeeLoc();
134   else if (const auto Reference = TheType.getAs<ReferenceTypeLoc>())
135     TheType = Reference.getPointeeLoc();
136   if (const auto Qualified = TheType.getAs<QualifiedTypeLoc>())
137     TheType = Qualified.getUnqualifiedLoc();
138 
139   if (auto EnableIf = matchEnableIfSpecializationImpl(TheType))
140     return EnableIfData{std::move(*EnableIf), TheType};
141   return std::nullopt;
142 }
143 
144 static std::pair<std::optional<EnableIfData>, const Decl *>
145 matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate) {
146   // For non-type trailing param, match very specifically
147   // 'template <..., enable_if_type<Condition, Type> = Default>' where
148   // enable_if_type is 'enable_if' or 'enable_if_t'. E.g., 'template <typename
149   // T, enable_if_t<is_same_v<T, bool>, int*> = nullptr>
150   //
151   // Otherwise, match a trailing default type arg.
152   // E.g., 'template <typename T, typename = enable_if_t<is_same_v<T, bool>>>'
153 
154   const TemplateParameterList *TemplateParams =
155       FunctionTemplate->getTemplateParameters();
156   if (TemplateParams->size() == 0)
157     return {};
158 
159   const NamedDecl *LastParam =
160       TemplateParams->getParam(TemplateParams->size() - 1);
161   if (const auto *LastTemplateParam =
162           dyn_cast<NonTypeTemplateParmDecl>(LastParam)) {
163 
164     if (!LastTemplateParam->hasDefaultArgument() ||
165         !LastTemplateParam->getName().empty())
166       return {};
167 
168     return {matchEnableIfSpecialization(
169                 LastTemplateParam->getTypeSourceInfo()->getTypeLoc()),
170             LastTemplateParam};
171   } else if (const auto *LastTemplateParam =
172                  dyn_cast<TemplateTypeParmDecl>(LastParam)) {
173     if (LastTemplateParam->hasDefaultArgument() &&
174         LastTemplateParam->getIdentifier() == nullptr) {
175       return {matchEnableIfSpecialization(
176                   LastTemplateParam->getDefaultArgumentInfo()->getTypeLoc()),
177               LastTemplateParam};
178     }
179   }
180   return {};
181 }
182 
183 template <typename T>
184 static SourceLocation getRAngleFileLoc(const SourceManager &SM,
185                                        const T &Element) {
186   // getFileLoc handles the case where the RAngle loc is part of a synthesized
187   // '>>', which ends up allocating a 'scratch space' buffer in the source
188   // manager.
189   return SM.getFileLoc(Element.getRAngleLoc());
190 }
191 
192 static SourceRange
193 getConditionRange(ASTContext &Context,
194                   const TemplateSpecializationTypeLoc &EnableIf) {
195   // TemplateArgumentLoc's SourceRange End is the location of the last token
196   // (per UnqualifiedId docs). E.g., in `enable_if<AAA && BBB>`, the End
197   // location will be the first 'B' in 'BBB'.
198   const LangOptions &LangOpts = Context.getLangOpts();
199   const SourceManager &SM = Context.getSourceManager();
200   if (EnableIf.getNumArgs() > 1) {
201     TemplateArgumentLoc NextArg = EnableIf.getArgLoc(1);
202     return {EnableIf.getLAngleLoc().getLocWithOffset(1),
203             utils::lexer::findPreviousTokenKind(
204                 NextArg.getSourceRange().getBegin(), SM, LangOpts, tok::comma)};
205   }
206 
207   return {EnableIf.getLAngleLoc().getLocWithOffset(1),
208           getRAngleFileLoc(SM, EnableIf)};
209 }
210 
211 static SourceRange getTypeRange(ASTContext &Context,
212                                 const TemplateSpecializationTypeLoc &EnableIf) {
213   TemplateArgumentLoc Arg = EnableIf.getArgLoc(1);
214   const LangOptions &LangOpts = Context.getLangOpts();
215   const SourceManager &SM = Context.getSourceManager();
216   return {utils::lexer::findPreviousTokenKind(Arg.getSourceRange().getBegin(),
217                                               SM, LangOpts, tok::comma)
218               .getLocWithOffset(1),
219           getRAngleFileLoc(SM, EnableIf)};
220 }
221 
222 // Returns the original source text of the second argument of a call to
223 // enable_if_t. E.g., in enable_if_t<Condition, TheType>, this function
224 // returns 'TheType'.
225 static std::optional<StringRef>
226 getTypeText(ASTContext &Context,
227             const TemplateSpecializationTypeLoc &EnableIf) {
228   if (EnableIf.getNumArgs() > 1) {
229     const LangOptions &LangOpts = Context.getLangOpts();
230     const SourceManager &SM = Context.getSourceManager();
231     bool Invalid = false;
232     StringRef Text = Lexer::getSourceText(CharSourceRange::getCharRange(
233                                               getTypeRange(Context, EnableIf)),
234                                           SM, LangOpts, &Invalid)
235                          .trim();
236     if (Invalid)
237       return std::nullopt;
238 
239     return Text;
240   }
241 
242   return "void";
243 }
244 
245 static std::optional<SourceLocation>
246 findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context) {
247   SourceManager &SM = Context.getSourceManager();
248   const LangOptions &LangOpts = Context.getLangOpts();
249 
250   if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(Function)) {
251     for (const CXXCtorInitializer *Init : Constructor->inits()) {
252       if (Init->getSourceOrder() == 0)
253         return utils::lexer::findPreviousTokenKind(Init->getSourceLocation(),
254                                                    SM, LangOpts, tok::colon);
255     }
256     if (Constructor->init_begin() != Constructor->init_end())
257       return std::nullopt;
258   }
259   if (Function->isDeleted()) {
260     SourceLocation FunctionEnd = Function->getSourceRange().getEnd();
261     return utils::lexer::findNextAnyTokenKind(FunctionEnd, SM, LangOpts,
262                                               tok::equal, tok::equal);
263   }
264   const Stmt *Body = Function->getBody();
265   if (!Body)
266     return std::nullopt;
267 
268   return Body->getBeginLoc();
269 }
270 
271 bool isPrimaryExpression(const Expr *Expression) {
272   // This function is an incomplete approximation of checking whether
273   // an Expr is a primary expression. In particular, if this function
274   // returns true, the expression is a primary expression. The converse
275   // is not necessarily true.
276 
277   if (const auto *Cast = dyn_cast<ImplicitCastExpr>(Expression))
278     Expression = Cast->getSubExprAsWritten();
279   if (isa<ParenExpr, DependentScopeDeclRefExpr>(Expression))
280     return true;
281 
282   return false;
283 }
284 
285 // Return the original source text of an enable_if_t condition, i.e., the
286 // first template argument). For example, in
287 // 'enable_if_t<FirstCondition || SecondCondition, AType>', the text
288 // the text 'FirstCondition || SecondCondition' is returned.
289 static std::optional<std::string> getConditionText(const Expr *ConditionExpr,
290                                                    SourceRange ConditionRange,
291                                                    ASTContext &Context) {
292   SourceManager &SM = Context.getSourceManager();
293   const LangOptions &LangOpts = Context.getLangOpts();
294 
295   SourceLocation PrevTokenLoc = ConditionRange.getEnd();
296   if (PrevTokenLoc.isInvalid())
297     return std::nullopt;
298 
299   const bool SkipComments = false;
300   Token PrevToken;
301   std::tie(PrevToken, PrevTokenLoc) = utils::lexer::getPreviousTokenAndStart(
302       PrevTokenLoc, SM, LangOpts, SkipComments);
303   bool EndsWithDoubleSlash =
304       PrevToken.is(tok::comment) &&
305       Lexer::getSourceText(CharSourceRange::getCharRange(
306                                PrevTokenLoc, PrevTokenLoc.getLocWithOffset(2)),
307                            SM, LangOpts) == "//";
308 
309   bool Invalid = false;
310   llvm::StringRef ConditionText = Lexer::getSourceText(
311       CharSourceRange::getCharRange(ConditionRange), SM, LangOpts, &Invalid);
312   if (Invalid)
313     return std::nullopt;
314 
315   auto AddParens = [&](llvm::StringRef Text) -> std::string {
316     if (isPrimaryExpression(ConditionExpr))
317       return Text.str();
318     return "(" + Text.str() + ")";
319   };
320 
321   if (EndsWithDoubleSlash)
322     return AddParens(ConditionText);
323   return AddParens(ConditionText.trim());
324 }
325 
326 // Handle functions that return enable_if_t, e.g.,
327 //   template <...>
328 //   enable_if_t<Condition, ReturnType> function();
329 //
330 // Return a vector of FixItHints if the code can be replaced with
331 // a C++20 requires clause. In the example above, returns FixItHints
332 // to result in
333 //   template <...>
334 //   ReturnType function() requires Condition {}
335 static std::vector<FixItHint> handleReturnType(const FunctionDecl *Function,
336                                                const TypeLoc &ReturnType,
337                                                const EnableIfData &EnableIf,
338                                                ASTContext &Context) {
339   TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
340 
341   SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
342 
343   std::optional<std::string> ConditionText = getConditionText(
344       EnableCondition.getSourceExpression(), ConditionRange, Context);
345   if (!ConditionText)
346     return {};
347 
348   std::optional<StringRef> TypeText = getTypeText(Context, EnableIf.Loc);
349   if (!TypeText)
350     return {};
351 
352   SmallVector<const Expr *, 3> ExistingConstraints;
353   Function->getAssociatedConstraints(ExistingConstraints);
354   if (!ExistingConstraints.empty()) {
355     // FIXME - Support adding new constraints to existing ones. Do we need to
356     // consider subsumption?
357     return {};
358   }
359 
360   std::optional<SourceLocation> ConstraintInsertionLoc =
361       findInsertionForConstraint(Function, Context);
362   if (!ConstraintInsertionLoc)
363     return {};
364 
365   std::vector<FixItHint> FixIts;
366   FixIts.push_back(FixItHint::CreateReplacement(
367       CharSourceRange::getTokenRange(EnableIf.Outer.getSourceRange()),
368       *TypeText));
369   FixIts.push_back(FixItHint::CreateInsertion(
370       *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
371   return FixIts;
372 }
373 
374 // Handle enable_if_t in a trailing template parameter, e.g.,
375 //   template <..., enable_if_t<Condition, Type> = Type{}>
376 //   ReturnType function();
377 //
378 // Return a vector of FixItHints if the code can be replaced with
379 // a C++20 requires clause. In the example above, returns FixItHints
380 // to result in
381 //   template <...>
382 //   ReturnType function() requires Condition {}
383 static std::vector<FixItHint>
384 handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate,
385                            const FunctionDecl *Function,
386                            const Decl *LastTemplateParam,
387                            const EnableIfData &EnableIf, ASTContext &Context) {
388   SourceManager &SM = Context.getSourceManager();
389   const LangOptions &LangOpts = Context.getLangOpts();
390 
391   TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
392 
393   SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
394 
395   std::optional<std::string> ConditionText = getConditionText(
396       EnableCondition.getSourceExpression(), ConditionRange, Context);
397   if (!ConditionText)
398     return {};
399 
400   SmallVector<const Expr *, 3> ExistingConstraints;
401   Function->getAssociatedConstraints(ExistingConstraints);
402   if (!ExistingConstraints.empty()) {
403     // FIXME - Support adding new constraints to existing ones. Do we need to
404     // consider subsumption?
405     return {};
406   }
407 
408   SourceRange RemovalRange;
409   const TemplateParameterList *TemplateParams =
410       FunctionTemplate->getTemplateParameters();
411   if (!TemplateParams || TemplateParams->size() == 0)
412     return {};
413 
414   if (TemplateParams->size() == 1) {
415     RemovalRange =
416         SourceRange(TemplateParams->getTemplateLoc(),
417                     getRAngleFileLoc(SM, *TemplateParams).getLocWithOffset(1));
418   } else {
419     RemovalRange =
420         SourceRange(utils::lexer::findPreviousTokenKind(
421                         LastTemplateParam->getSourceRange().getBegin(), SM,
422                         LangOpts, tok::comma),
423                     getRAngleFileLoc(SM, *TemplateParams));
424   }
425 
426   std::optional<SourceLocation> ConstraintInsertionLoc =
427       findInsertionForConstraint(Function, Context);
428   if (!ConstraintInsertionLoc)
429     return {};
430 
431   std::vector<FixItHint> FixIts;
432   FixIts.push_back(
433       FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange)));
434   FixIts.push_back(FixItHint::CreateInsertion(
435       *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
436   return FixIts;
437 }
438 
439 void UseConstraintsCheck::check(const MatchFinder::MatchResult &Result) {
440   const auto *FunctionTemplate =
441       Result.Nodes.getNodeAs<FunctionTemplateDecl>("functionTemplate");
442   const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>("function");
443   const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>("return");
444   if (!FunctionTemplate || !Function || !ReturnType)
445     return;
446 
447   // Check for
448   //
449   //   Case 1. Return type of function
450   //
451   //     template <...>
452   //     enable_if_t<Condition, ReturnType>::type function() {}
453   //
454   //   Case 2. Trailing template parameter
455   //
456   //     template <..., enable_if_t<Condition, Type> = Type{}>
457   //     ReturnType function() {}
458   //
459   //     or
460   //
461   //     template <..., typename = enable_if_t<Condition, void>>
462   //     ReturnType function() {}
463   //
464 
465   // Case 1. Return type of function
466   if (auto EnableIf = matchEnableIfSpecialization(*ReturnType)) {
467     diag(ReturnType->getBeginLoc(),
468          "use C++20 requires constraints instead of enable_if")
469         << handleReturnType(Function, *ReturnType, *EnableIf, *Result.Context);
470     return;
471   }
472 
473   // Case 2. Trailing template parameter
474   if (auto [EnableIf, LastTemplateParam] =
475           matchTrailingTemplateParam(FunctionTemplate);
476       EnableIf && LastTemplateParam) {
477     diag(LastTemplateParam->getSourceRange().getBegin(),
478          "use C++20 requires constraints instead of enable_if")
479         << handleTrailingTemplateType(FunctionTemplate, Function,
480                                       LastTemplateParam, *EnableIf,
481                                       *Result.Context);
482     return;
483   }
484 }
485 
486 } // namespace clang::tidy::modernize
487