xref: /llvm-project/clang-tools-extra/clang-tidy/modernize/UseConstraintsCheck.cpp (revision ba3447601c435bb2b24ad9e3c8d146c578f00568)
1 //===--- UseConstraintsCheck.cpp - clang-tidy -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "UseConstraintsCheck.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/StmtCXX.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/Lex/Lexer.h"

#include "../utils/LexerUtils.h"

#include <optional>
#include <utility>
18 
19 using namespace clang::ast_matchers;
20 
21 namespace clang::tidy::modernize {
22 
// Describes an enable_if construct found in a written type.
struct EnableIfData {
  // The 'enable_if<...>' / 'enable_if_t<...>' template specialization; its
  // template arguments are the condition and, optionally, the type.
  TemplateSpecializationTypeLoc Loc;
  // The inspected type after matchEnableIfSpecialization stripped one level of
  // pointer/reference and top-level qualifiers. The return-type fix-it
  // replaces this TypeLoc's source range with the enable_if type argument.
  TypeLoc Outer;
};
27 
28 namespace {
29 AST_MATCHER(FunctionDecl, hasOtherDeclarations) {
30   auto It = Node.redecls_begin();
31   auto EndIt = Node.redecls_end();
32 
33   if (It == EndIt)
34     return false;
35 
36   ++It;
37   return It != EndIt;
38 }
39 } // namespace
40 
41 void UseConstraintsCheck::registerMatchers(MatchFinder *Finder) {
42   Finder->addMatcher(
43       functionTemplateDecl(
44           // Skip external libraries included as system headers
45           unless(isExpansionInSystemHeader()),
46           has(functionDecl(unless(hasOtherDeclarations()), isDefinition(),
47                            hasReturnTypeLoc(typeLoc().bind("return")))
48                   .bind("function")))
49           .bind("functionTemplate"),
50       this);
51 }
52 
53 static std::optional<TemplateSpecializationTypeLoc>
54 matchEnableIfSpecializationImplTypename(TypeLoc TheType) {
55   if (const auto Dep = TheType.getAs<DependentNameTypeLoc>()) {
56     const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier();
57     if (!Identifier || Identifier->getName() != "type" ||
58         Dep.getTypePtr()->getKeyword() != ElaboratedTypeKeyword::Typename) {
59       return std::nullopt;
60     }
61     TheType = Dep.getQualifierLoc().getTypeLoc();
62     if (TheType.isNull())
63       return std::nullopt;
64   }
65 
66   if (const auto SpecializationLoc =
67           TheType.getAs<TemplateSpecializationTypeLoc>()) {
68 
69     const auto *Specialization =
70         dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
71     if (!Specialization)
72       return std::nullopt;
73 
74     const TemplateDecl *TD =
75         Specialization->getTemplateName().getAsTemplateDecl();
76     if (!TD || TD->getName() != "enable_if")
77       return std::nullopt;
78 
79     int NumArgs = SpecializationLoc.getNumArgs();
80     if (NumArgs != 1 && NumArgs != 2)
81       return std::nullopt;
82 
83     return SpecializationLoc;
84   }
85   return std::nullopt;
86 }
87 
88 static std::optional<TemplateSpecializationTypeLoc>
89 matchEnableIfSpecializationImplTrait(TypeLoc TheType) {
90   if (const auto Elaborated = TheType.getAs<ElaboratedTypeLoc>())
91     TheType = Elaborated.getNamedTypeLoc();
92 
93   if (const auto SpecializationLoc =
94           TheType.getAs<TemplateSpecializationTypeLoc>()) {
95 
96     const auto *Specialization =
97         dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
98     if (!Specialization)
99       return std::nullopt;
100 
101     const TemplateDecl *TD =
102         Specialization->getTemplateName().getAsTemplateDecl();
103     if (!TD || TD->getName() != "enable_if_t")
104       return std::nullopt;
105 
106     if (!Specialization->isTypeAlias())
107       return std::nullopt;
108 
109     if (const auto *AliasedType =
110             dyn_cast<DependentNameType>(Specialization->getAliasedType())) {
111       if (AliasedType->getIdentifier()->getName() != "type" ||
112           AliasedType->getKeyword() != ElaboratedTypeKeyword::Typename) {
113         return std::nullopt;
114       }
115     } else {
116       return std::nullopt;
117     }
118     int NumArgs = SpecializationLoc.getNumArgs();
119     if (NumArgs != 1 && NumArgs != 2)
120       return std::nullopt;
121 
122     return SpecializationLoc;
123   }
124   return std::nullopt;
125 }
126 
127 static std::optional<TemplateSpecializationTypeLoc>
128 matchEnableIfSpecializationImpl(TypeLoc TheType) {
129   if (auto EnableIf = matchEnableIfSpecializationImplTypename(TheType))
130     return EnableIf;
131   return matchEnableIfSpecializationImplTrait(TheType);
132 }
133 
134 static std::optional<EnableIfData>
135 matchEnableIfSpecialization(TypeLoc TheType) {
136   if (const auto Pointer = TheType.getAs<PointerTypeLoc>())
137     TheType = Pointer.getPointeeLoc();
138   else if (const auto Reference = TheType.getAs<ReferenceTypeLoc>())
139     TheType = Reference.getPointeeLoc();
140   if (const auto Qualified = TheType.getAs<QualifiedTypeLoc>())
141     TheType = Qualified.getUnqualifiedLoc();
142 
143   if (auto EnableIf = matchEnableIfSpecializationImpl(TheType))
144     return EnableIfData{std::move(*EnableIf), TheType};
145   return std::nullopt;
146 }
147 
148 static std::pair<std::optional<EnableIfData>, const Decl *>
149 matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate) {
150   // For non-type trailing param, match very specifically
151   // 'template <..., enable_if_type<Condition, Type> = Default>' where
152   // enable_if_type is 'enable_if' or 'enable_if_t'. E.g., 'template <typename
153   // T, enable_if_t<is_same_v<T, bool>, int*> = nullptr>
154   //
155   // Otherwise, match a trailing default type arg.
156   // E.g., 'template <typename T, typename = enable_if_t<is_same_v<T, bool>>>'
157 
158   const TemplateParameterList *TemplateParams =
159       FunctionTemplate->getTemplateParameters();
160   if (TemplateParams->size() == 0)
161     return {};
162 
163   const NamedDecl *LastParam =
164       TemplateParams->getParam(TemplateParams->size() - 1);
165   if (const auto *LastTemplateParam =
166           dyn_cast<NonTypeTemplateParmDecl>(LastParam)) {
167 
168     if (!LastTemplateParam->hasDefaultArgument() ||
169         !LastTemplateParam->getName().empty())
170       return {};
171 
172     return {matchEnableIfSpecialization(
173                 LastTemplateParam->getTypeSourceInfo()->getTypeLoc()),
174             LastTemplateParam};
175   }
176   if (const auto *LastTemplateParam =
177           dyn_cast<TemplateTypeParmDecl>(LastParam)) {
178     if (LastTemplateParam->hasDefaultArgument() &&
179         LastTemplateParam->getIdentifier() == nullptr) {
180       return {matchEnableIfSpecialization(
181                   LastTemplateParam->getDefaultArgumentInfo()->getTypeLoc()),
182               LastTemplateParam};
183     }
184   }
185   return {};
186 }
187 
188 template <typename T>
189 static SourceLocation getRAngleFileLoc(const SourceManager &SM,
190                                        const T &Element) {
191   // getFileLoc handles the case where the RAngle loc is part of a synthesized
192   // '>>', which ends up allocating a 'scratch space' buffer in the source
193   // manager.
194   return SM.getFileLoc(Element.getRAngleLoc());
195 }
196 
197 static SourceRange
198 getConditionRange(ASTContext &Context,
199                   const TemplateSpecializationTypeLoc &EnableIf) {
200   // TemplateArgumentLoc's SourceRange End is the location of the last token
201   // (per UnqualifiedId docs). E.g., in `enable_if<AAA && BBB>`, the End
202   // location will be the first 'B' in 'BBB'.
203   const LangOptions &LangOpts = Context.getLangOpts();
204   const SourceManager &SM = Context.getSourceManager();
205   if (EnableIf.getNumArgs() > 1) {
206     TemplateArgumentLoc NextArg = EnableIf.getArgLoc(1);
207     return {EnableIf.getLAngleLoc().getLocWithOffset(1),
208             utils::lexer::findPreviousTokenKind(
209                 NextArg.getSourceRange().getBegin(), SM, LangOpts, tok::comma)};
210   }
211 
212   return {EnableIf.getLAngleLoc().getLocWithOffset(1),
213           getRAngleFileLoc(SM, EnableIf)};
214 }
215 
216 static SourceRange getTypeRange(ASTContext &Context,
217                                 const TemplateSpecializationTypeLoc &EnableIf) {
218   TemplateArgumentLoc Arg = EnableIf.getArgLoc(1);
219   const LangOptions &LangOpts = Context.getLangOpts();
220   const SourceManager &SM = Context.getSourceManager();
221   return {utils::lexer::findPreviousTokenKind(Arg.getSourceRange().getBegin(),
222                                               SM, LangOpts, tok::comma)
223               .getLocWithOffset(1),
224           getRAngleFileLoc(SM, EnableIf)};
225 }
226 
227 // Returns the original source text of the second argument of a call to
228 // enable_if_t. E.g., in enable_if_t<Condition, TheType>, this function
229 // returns 'TheType'.
230 static std::optional<StringRef>
231 getTypeText(ASTContext &Context,
232             const TemplateSpecializationTypeLoc &EnableIf) {
233   if (EnableIf.getNumArgs() > 1) {
234     const LangOptions &LangOpts = Context.getLangOpts();
235     const SourceManager &SM = Context.getSourceManager();
236     bool Invalid = false;
237     StringRef Text = Lexer::getSourceText(CharSourceRange::getCharRange(
238                                               getTypeRange(Context, EnableIf)),
239                                           SM, LangOpts, &Invalid)
240                          .trim();
241     if (Invalid)
242       return std::nullopt;
243 
244     return Text;
245   }
246 
247   return "void";
248 }
249 
250 static std::optional<SourceLocation>
251 findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context) {
252   SourceManager &SM = Context.getSourceManager();
253   const LangOptions &LangOpts = Context.getLangOpts();
254 
255   if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(Function)) {
256     for (const CXXCtorInitializer *Init : Constructor->inits()) {
257       if (Init->getSourceOrder() == 0)
258         return utils::lexer::findPreviousTokenKind(Init->getSourceLocation(),
259                                                    SM, LangOpts, tok::colon);
260     }
261     if (!Constructor->inits().empty())
262       return std::nullopt;
263   }
264   if (Function->isDeleted()) {
265     SourceLocation FunctionEnd = Function->getSourceRange().getEnd();
266     return utils::lexer::findNextAnyTokenKind(FunctionEnd, SM, LangOpts,
267                                               tok::equal, tok::equal);
268   }
269   const Stmt *Body = Function->getBody();
270   if (!Body)
271     return std::nullopt;
272 
273   return Body->getBeginLoc();
274 }
275 
276 bool isPrimaryExpression(const Expr *Expression) {
277   // This function is an incomplete approximation of checking whether
278   // an Expr is a primary expression. In particular, if this function
279   // returns true, the expression is a primary expression. The converse
280   // is not necessarily true.
281 
282   if (const auto *Cast = dyn_cast<ImplicitCastExpr>(Expression))
283     Expression = Cast->getSubExprAsWritten();
284   if (isa<ParenExpr, DependentScopeDeclRefExpr>(Expression))
285     return true;
286 
287   return false;
288 }
289 
290 // Return the original source text of an enable_if_t condition, i.e., the
291 // first template argument). For example, in
292 // 'enable_if_t<FirstCondition || SecondCondition, AType>', the text
293 // the text 'FirstCondition || SecondCondition' is returned.
294 static std::optional<std::string> getConditionText(const Expr *ConditionExpr,
295                                                    SourceRange ConditionRange,
296                                                    ASTContext &Context) {
297   SourceManager &SM = Context.getSourceManager();
298   const LangOptions &LangOpts = Context.getLangOpts();
299 
300   SourceLocation PrevTokenLoc = ConditionRange.getEnd();
301   if (PrevTokenLoc.isInvalid())
302     return std::nullopt;
303 
304   const bool SkipComments = false;
305   Token PrevToken;
306   std::tie(PrevToken, PrevTokenLoc) = utils::lexer::getPreviousTokenAndStart(
307       PrevTokenLoc, SM, LangOpts, SkipComments);
308   bool EndsWithDoubleSlash =
309       PrevToken.is(tok::comment) &&
310       Lexer::getSourceText(CharSourceRange::getCharRange(
311                                PrevTokenLoc, PrevTokenLoc.getLocWithOffset(2)),
312                            SM, LangOpts) == "//";
313 
314   bool Invalid = false;
315   llvm::StringRef ConditionText = Lexer::getSourceText(
316       CharSourceRange::getCharRange(ConditionRange), SM, LangOpts, &Invalid);
317   if (Invalid)
318     return std::nullopt;
319 
320   auto AddParens = [&](llvm::StringRef Text) -> std::string {
321     if (isPrimaryExpression(ConditionExpr))
322       return Text.str();
323     return "(" + Text.str() + ")";
324   };
325 
326   if (EndsWithDoubleSlash)
327     return AddParens(ConditionText);
328   return AddParens(ConditionText.trim());
329 }
330 
331 // Handle functions that return enable_if_t, e.g.,
332 //   template <...>
333 //   enable_if_t<Condition, ReturnType> function();
334 //
335 // Return a vector of FixItHints if the code can be replaced with
336 // a C++20 requires clause. In the example above, returns FixItHints
337 // to result in
338 //   template <...>
339 //   ReturnType function() requires Condition {}
340 static std::vector<FixItHint> handleReturnType(const FunctionDecl *Function,
341                                                const TypeLoc &ReturnType,
342                                                const EnableIfData &EnableIf,
343                                                ASTContext &Context) {
344   TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
345 
346   SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
347 
348   std::optional<std::string> ConditionText = getConditionText(
349       EnableCondition.getSourceExpression(), ConditionRange, Context);
350   if (!ConditionText)
351     return {};
352 
353   std::optional<StringRef> TypeText = getTypeText(Context, EnableIf.Loc);
354   if (!TypeText)
355     return {};
356 
357   SmallVector<const Expr *, 3> ExistingConstraints;
358   Function->getAssociatedConstraints(ExistingConstraints);
359   if (!ExistingConstraints.empty()) {
360     // FIXME - Support adding new constraints to existing ones. Do we need to
361     // consider subsumption?
362     return {};
363   }
364 
365   std::optional<SourceLocation> ConstraintInsertionLoc =
366       findInsertionForConstraint(Function, Context);
367   if (!ConstraintInsertionLoc)
368     return {};
369 
370   std::vector<FixItHint> FixIts;
371   FixIts.push_back(FixItHint::CreateReplacement(
372       CharSourceRange::getTokenRange(EnableIf.Outer.getSourceRange()),
373       *TypeText));
374   FixIts.push_back(FixItHint::CreateInsertion(
375       *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
376   return FixIts;
377 }
378 
379 // Handle enable_if_t in a trailing template parameter, e.g.,
380 //   template <..., enable_if_t<Condition, Type> = Type{}>
381 //   ReturnType function();
382 //
383 // Return a vector of FixItHints if the code can be replaced with
384 // a C++20 requires clause. In the example above, returns FixItHints
385 // to result in
386 //   template <...>
387 //   ReturnType function() requires Condition {}
388 static std::vector<FixItHint>
389 handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate,
390                            const FunctionDecl *Function,
391                            const Decl *LastTemplateParam,
392                            const EnableIfData &EnableIf, ASTContext &Context) {
393   SourceManager &SM = Context.getSourceManager();
394   const LangOptions &LangOpts = Context.getLangOpts();
395 
396   TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
397 
398   SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
399 
400   std::optional<std::string> ConditionText = getConditionText(
401       EnableCondition.getSourceExpression(), ConditionRange, Context);
402   if (!ConditionText)
403     return {};
404 
405   SmallVector<const Expr *, 3> ExistingConstraints;
406   Function->getAssociatedConstraints(ExistingConstraints);
407   if (!ExistingConstraints.empty()) {
408     // FIXME - Support adding new constraints to existing ones. Do we need to
409     // consider subsumption?
410     return {};
411   }
412 
413   SourceRange RemovalRange;
414   const TemplateParameterList *TemplateParams =
415       FunctionTemplate->getTemplateParameters();
416   if (!TemplateParams || TemplateParams->size() == 0)
417     return {};
418 
419   if (TemplateParams->size() == 1) {
420     RemovalRange =
421         SourceRange(TemplateParams->getTemplateLoc(),
422                     getRAngleFileLoc(SM, *TemplateParams).getLocWithOffset(1));
423   } else {
424     RemovalRange =
425         SourceRange(utils::lexer::findPreviousTokenKind(
426                         LastTemplateParam->getSourceRange().getBegin(), SM,
427                         LangOpts, tok::comma),
428                     getRAngleFileLoc(SM, *TemplateParams));
429   }
430 
431   std::optional<SourceLocation> ConstraintInsertionLoc =
432       findInsertionForConstraint(Function, Context);
433   if (!ConstraintInsertionLoc)
434     return {};
435 
436   std::vector<FixItHint> FixIts;
437   FixIts.push_back(
438       FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange)));
439   FixIts.push_back(FixItHint::CreateInsertion(
440       *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
441   return FixIts;
442 }
443 
444 void UseConstraintsCheck::check(const MatchFinder::MatchResult &Result) {
445   const auto *FunctionTemplate =
446       Result.Nodes.getNodeAs<FunctionTemplateDecl>("functionTemplate");
447   const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>("function");
448   const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>("return");
449   if (!FunctionTemplate || !Function || !ReturnType)
450     return;
451 
452   // Check for
453   //
454   //   Case 1. Return type of function
455   //
456   //     template <...>
457   //     enable_if_t<Condition, ReturnType>::type function() {}
458   //
459   //   Case 2. Trailing template parameter
460   //
461   //     template <..., enable_if_t<Condition, Type> = Type{}>
462   //     ReturnType function() {}
463   //
464   //     or
465   //
466   //     template <..., typename = enable_if_t<Condition, void>>
467   //     ReturnType function() {}
468   //
469 
470   // Case 1. Return type of function
471   if (auto EnableIf = matchEnableIfSpecialization(*ReturnType)) {
472     diag(ReturnType->getBeginLoc(),
473          "use C++20 requires constraints instead of enable_if")
474         << handleReturnType(Function, *ReturnType, *EnableIf, *Result.Context);
475     return;
476   }
477 
478   // Case 2. Trailing template parameter
479   if (auto [EnableIf, LastTemplateParam] =
480           matchTrailingTemplateParam(FunctionTemplate);
481       EnableIf && LastTemplateParam) {
482     diag(LastTemplateParam->getSourceRange().getBegin(),
483          "use C++20 requires constraints instead of enable_if")
484         << handleTrailingTemplateType(FunctionTemplate, Function,
485                                       LastTemplateParam, *EnableIf,
486                                       *Result.Context);
487     return;
488   }
489 }
490 
491 } // namespace clang::tidy::modernize
492