xref: /llvm-project/clang-tools-extra/clang-tidy/modernize/UseConstraintsCheck.cpp (revision e42b799bb28815431f2c5a95f7e13fde3f1b36a1)
1893d53d1SChris Cotter //===--- UseConstraintsCheck.cpp - clang-tidy -----------------------------===//
2893d53d1SChris Cotter //
3893d53d1SChris Cotter // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4893d53d1SChris Cotter // See https://llvm.org/LICENSE.txt for license information.
5893d53d1SChris Cotter // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6893d53d1SChris Cotter //
7893d53d1SChris Cotter //===----------------------------------------------------------------------===//
8893d53d1SChris Cotter 
9893d53d1SChris Cotter #include "UseConstraintsCheck.h"
10893d53d1SChris Cotter #include "clang/AST/ASTContext.h"
11893d53d1SChris Cotter #include "clang/ASTMatchers/ASTMatchFinder.h"
12893d53d1SChris Cotter #include "clang/Lex/Lexer.h"
13893d53d1SChris Cotter 
14893d53d1SChris Cotter #include "../utils/LexerUtils.h"
15893d53d1SChris Cotter 
16893d53d1SChris Cotter #include <optional>
17893d53d1SChris Cotter #include <utility>
18893d53d1SChris Cotter 
19893d53d1SChris Cotter using namespace clang::ast_matchers;
20893d53d1SChris Cotter 
21893d53d1SChris Cotter namespace clang::tidy::modernize {
22893d53d1SChris Cotter 
23893d53d1SChris Cotter struct EnableIfData {
24893d53d1SChris Cotter   TemplateSpecializationTypeLoc Loc;
25893d53d1SChris Cotter   TypeLoc Outer;
26893d53d1SChris Cotter };
27893d53d1SChris Cotter 
28893d53d1SChris Cotter namespace {
AST_MATCHER(FunctionDecl,hasOtherDeclarations)29893d53d1SChris Cotter AST_MATCHER(FunctionDecl, hasOtherDeclarations) {
30893d53d1SChris Cotter   auto It = Node.redecls_begin();
31893d53d1SChris Cotter   auto EndIt = Node.redecls_end();
32893d53d1SChris Cotter 
33893d53d1SChris Cotter   if (It == EndIt)
34893d53d1SChris Cotter     return false;
35893d53d1SChris Cotter 
36893d53d1SChris Cotter   ++It;
37893d53d1SChris Cotter   return It != EndIt;
38893d53d1SChris Cotter }
39893d53d1SChris Cotter } // namespace
40893d53d1SChris Cotter 
registerMatchers(MatchFinder * Finder)41893d53d1SChris Cotter void UseConstraintsCheck::registerMatchers(MatchFinder *Finder) {
42893d53d1SChris Cotter   Finder->addMatcher(
43893d53d1SChris Cotter       functionTemplateDecl(
44ba344760SPiotr Zegar           // Skip external libraries included as system headers
45ba344760SPiotr Zegar           unless(isExpansionInSystemHeader()),
46893d53d1SChris Cotter           has(functionDecl(unless(hasOtherDeclarations()), isDefinition(),
47893d53d1SChris Cotter                            hasReturnTypeLoc(typeLoc().bind("return")))
48893d53d1SChris Cotter                   .bind("function")))
49893d53d1SChris Cotter           .bind("functionTemplate"),
50893d53d1SChris Cotter       this);
51893d53d1SChris Cotter }
52893d53d1SChris Cotter 
53893d53d1SChris Cotter static std::optional<TemplateSpecializationTypeLoc>
matchEnableIfSpecializationImplTypename(TypeLoc TheType)54893d53d1SChris Cotter matchEnableIfSpecializationImplTypename(TypeLoc TheType) {
55893d53d1SChris Cotter   if (const auto Dep = TheType.getAs<DependentNameTypeLoc>()) {
56893d53d1SChris Cotter     const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier();
57893d53d1SChris Cotter     if (!Identifier || Identifier->getName() != "type" ||
584ad2ada5SVlad Serebrennikov         Dep.getTypePtr()->getKeyword() != ElaboratedTypeKeyword::Typename) {
59893d53d1SChris Cotter       return std::nullopt;
60893d53d1SChris Cotter     }
61893d53d1SChris Cotter     TheType = Dep.getQualifierLoc().getTypeLoc();
62ba344760SPiotr Zegar     if (TheType.isNull())
63ba344760SPiotr Zegar       return std::nullopt;
64893d53d1SChris Cotter   }
65893d53d1SChris Cotter 
66893d53d1SChris Cotter   if (const auto SpecializationLoc =
67893d53d1SChris Cotter           TheType.getAs<TemplateSpecializationTypeLoc>()) {
68893d53d1SChris Cotter 
69893d53d1SChris Cotter     const auto *Specialization =
70893d53d1SChris Cotter         dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
71893d53d1SChris Cotter     if (!Specialization)
72893d53d1SChris Cotter       return std::nullopt;
73893d53d1SChris Cotter 
74893d53d1SChris Cotter     const TemplateDecl *TD =
75893d53d1SChris Cotter         Specialization->getTemplateName().getAsTemplateDecl();
76893d53d1SChris Cotter     if (!TD || TD->getName() != "enable_if")
77893d53d1SChris Cotter       return std::nullopt;
78893d53d1SChris Cotter 
79893d53d1SChris Cotter     int NumArgs = SpecializationLoc.getNumArgs();
80893d53d1SChris Cotter     if (NumArgs != 1 && NumArgs != 2)
81893d53d1SChris Cotter       return std::nullopt;
82893d53d1SChris Cotter 
83893d53d1SChris Cotter     return SpecializationLoc;
84893d53d1SChris Cotter   }
85893d53d1SChris Cotter   return std::nullopt;
86893d53d1SChris Cotter }
87893d53d1SChris Cotter 
88893d53d1SChris Cotter static std::optional<TemplateSpecializationTypeLoc>
matchEnableIfSpecializationImplTrait(TypeLoc TheType)89893d53d1SChris Cotter matchEnableIfSpecializationImplTrait(TypeLoc TheType) {
90893d53d1SChris Cotter   if (const auto Elaborated = TheType.getAs<ElaboratedTypeLoc>())
91893d53d1SChris Cotter     TheType = Elaborated.getNamedTypeLoc();
92893d53d1SChris Cotter 
93893d53d1SChris Cotter   if (const auto SpecializationLoc =
94893d53d1SChris Cotter           TheType.getAs<TemplateSpecializationTypeLoc>()) {
95893d53d1SChris Cotter 
96893d53d1SChris Cotter     const auto *Specialization =
97893d53d1SChris Cotter         dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
98893d53d1SChris Cotter     if (!Specialization)
99893d53d1SChris Cotter       return std::nullopt;
100893d53d1SChris Cotter 
101893d53d1SChris Cotter     const TemplateDecl *TD =
102893d53d1SChris Cotter         Specialization->getTemplateName().getAsTemplateDecl();
103893d53d1SChris Cotter     if (!TD || TD->getName() != "enable_if_t")
104893d53d1SChris Cotter       return std::nullopt;
105893d53d1SChris Cotter 
106893d53d1SChris Cotter     if (!Specialization->isTypeAlias())
107893d53d1SChris Cotter       return std::nullopt;
108893d53d1SChris Cotter 
109893d53d1SChris Cotter     if (const auto *AliasedType =
110893d53d1SChris Cotter             dyn_cast<DependentNameType>(Specialization->getAliasedType())) {
111893d53d1SChris Cotter       if (AliasedType->getIdentifier()->getName() != "type" ||
1124ad2ada5SVlad Serebrennikov           AliasedType->getKeyword() != ElaboratedTypeKeyword::Typename) {
113893d53d1SChris Cotter         return std::nullopt;
114893d53d1SChris Cotter       }
115893d53d1SChris Cotter     } else {
116893d53d1SChris Cotter       return std::nullopt;
117893d53d1SChris Cotter     }
118893d53d1SChris Cotter     int NumArgs = SpecializationLoc.getNumArgs();
119893d53d1SChris Cotter     if (NumArgs != 1 && NumArgs != 2)
120893d53d1SChris Cotter       return std::nullopt;
121893d53d1SChris Cotter 
122893d53d1SChris Cotter     return SpecializationLoc;
123893d53d1SChris Cotter   }
124893d53d1SChris Cotter   return std::nullopt;
125893d53d1SChris Cotter }
126893d53d1SChris Cotter 
127893d53d1SChris Cotter static std::optional<TemplateSpecializationTypeLoc>
matchEnableIfSpecializationImpl(TypeLoc TheType)128893d53d1SChris Cotter matchEnableIfSpecializationImpl(TypeLoc TheType) {
129893d53d1SChris Cotter   if (auto EnableIf = matchEnableIfSpecializationImplTypename(TheType))
130893d53d1SChris Cotter     return EnableIf;
131893d53d1SChris Cotter   return matchEnableIfSpecializationImplTrait(TheType);
132893d53d1SChris Cotter }
133893d53d1SChris Cotter 
134893d53d1SChris Cotter static std::optional<EnableIfData>
matchEnableIfSpecialization(TypeLoc TheType)135893d53d1SChris Cotter matchEnableIfSpecialization(TypeLoc TheType) {
136893d53d1SChris Cotter   if (const auto Pointer = TheType.getAs<PointerTypeLoc>())
137893d53d1SChris Cotter     TheType = Pointer.getPointeeLoc();
138893d53d1SChris Cotter   else if (const auto Reference = TheType.getAs<ReferenceTypeLoc>())
139893d53d1SChris Cotter     TheType = Reference.getPointeeLoc();
140893d53d1SChris Cotter   if (const auto Qualified = TheType.getAs<QualifiedTypeLoc>())
141893d53d1SChris Cotter     TheType = Qualified.getUnqualifiedLoc();
142893d53d1SChris Cotter 
143893d53d1SChris Cotter   if (auto EnableIf = matchEnableIfSpecializationImpl(TheType))
144893d53d1SChris Cotter     return EnableIfData{std::move(*EnableIf), TheType};
145893d53d1SChris Cotter   return std::nullopt;
146893d53d1SChris Cotter }
147893d53d1SChris Cotter 
148893d53d1SChris Cotter static std::pair<std::optional<EnableIfData>, const Decl *>
matchTrailingTemplateParam(const FunctionTemplateDecl * FunctionTemplate)149893d53d1SChris Cotter matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate) {
150893d53d1SChris Cotter   // For non-type trailing param, match very specifically
151893d53d1SChris Cotter   // 'template <..., enable_if_type<Condition, Type> = Default>' where
152893d53d1SChris Cotter   // enable_if_type is 'enable_if' or 'enable_if_t'. E.g., 'template <typename
153893d53d1SChris Cotter   // T, enable_if_t<is_same_v<T, bool>, int*> = nullptr>
154893d53d1SChris Cotter   //
155893d53d1SChris Cotter   // Otherwise, match a trailing default type arg.
156893d53d1SChris Cotter   // E.g., 'template <typename T, typename = enable_if_t<is_same_v<T, bool>>>'
157893d53d1SChris Cotter 
158893d53d1SChris Cotter   const TemplateParameterList *TemplateParams =
159893d53d1SChris Cotter       FunctionTemplate->getTemplateParameters();
160893d53d1SChris Cotter   if (TemplateParams->size() == 0)
161893d53d1SChris Cotter     return {};
162893d53d1SChris Cotter 
163893d53d1SChris Cotter   const NamedDecl *LastParam =
164893d53d1SChris Cotter       TemplateParams->getParam(TemplateParams->size() - 1);
165893d53d1SChris Cotter   if (const auto *LastTemplateParam =
166893d53d1SChris Cotter           dyn_cast<NonTypeTemplateParmDecl>(LastParam)) {
167893d53d1SChris Cotter 
168893d53d1SChris Cotter     if (!LastTemplateParam->hasDefaultArgument() ||
169893d53d1SChris Cotter         !LastTemplateParam->getName().empty())
170893d53d1SChris Cotter       return {};
171893d53d1SChris Cotter 
172893d53d1SChris Cotter     return {matchEnableIfSpecialization(
173893d53d1SChris Cotter                 LastTemplateParam->getTypeSourceInfo()->getTypeLoc()),
174893d53d1SChris Cotter             LastTemplateParam};
17501c8bf6fSPiotr Zegar   }
17601c8bf6fSPiotr Zegar   if (const auto *LastTemplateParam =
177893d53d1SChris Cotter           dyn_cast<TemplateTypeParmDecl>(LastParam)) {
178893d53d1SChris Cotter     if (LastTemplateParam->hasDefaultArgument() &&
179893d53d1SChris Cotter         LastTemplateParam->getIdentifier() == nullptr) {
180*e42b799bSMatheus Izvekov       return {
181*e42b799bSMatheus Izvekov           matchEnableIfSpecialization(LastTemplateParam->getDefaultArgument()
182*e42b799bSMatheus Izvekov                                           .getTypeSourceInfo()
183*e42b799bSMatheus Izvekov                                           ->getTypeLoc()),
184893d53d1SChris Cotter           LastTemplateParam};
185893d53d1SChris Cotter     }
186893d53d1SChris Cotter   }
187893d53d1SChris Cotter   return {};
188893d53d1SChris Cotter }
189893d53d1SChris Cotter 
190893d53d1SChris Cotter template <typename T>
getRAngleFileLoc(const SourceManager & SM,const T & Element)191893d53d1SChris Cotter static SourceLocation getRAngleFileLoc(const SourceManager &SM,
192893d53d1SChris Cotter                                        const T &Element) {
193893d53d1SChris Cotter   // getFileLoc handles the case where the RAngle loc is part of a synthesized
194893d53d1SChris Cotter   // '>>', which ends up allocating a 'scratch space' buffer in the source
195893d53d1SChris Cotter   // manager.
196893d53d1SChris Cotter   return SM.getFileLoc(Element.getRAngleLoc());
197893d53d1SChris Cotter }
198893d53d1SChris Cotter 
199893d53d1SChris Cotter static SourceRange
getConditionRange(ASTContext & Context,const TemplateSpecializationTypeLoc & EnableIf)200893d53d1SChris Cotter getConditionRange(ASTContext &Context,
201893d53d1SChris Cotter                   const TemplateSpecializationTypeLoc &EnableIf) {
202893d53d1SChris Cotter   // TemplateArgumentLoc's SourceRange End is the location of the last token
203893d53d1SChris Cotter   // (per UnqualifiedId docs). E.g., in `enable_if<AAA && BBB>`, the End
204893d53d1SChris Cotter   // location will be the first 'B' in 'BBB'.
205893d53d1SChris Cotter   const LangOptions &LangOpts = Context.getLangOpts();
206893d53d1SChris Cotter   const SourceManager &SM = Context.getSourceManager();
207893d53d1SChris Cotter   if (EnableIf.getNumArgs() > 1) {
208893d53d1SChris Cotter     TemplateArgumentLoc NextArg = EnableIf.getArgLoc(1);
209ec5f4be4SPiotr Zegar     return {EnableIf.getLAngleLoc().getLocWithOffset(1),
210ec5f4be4SPiotr Zegar             utils::lexer::findPreviousTokenKind(
211ec5f4be4SPiotr Zegar                 NextArg.getSourceRange().getBegin(), SM, LangOpts, tok::comma)};
212893d53d1SChris Cotter   }
213893d53d1SChris Cotter 
214ec5f4be4SPiotr Zegar   return {EnableIf.getLAngleLoc().getLocWithOffset(1),
215ec5f4be4SPiotr Zegar           getRAngleFileLoc(SM, EnableIf)};
216893d53d1SChris Cotter }
217893d53d1SChris Cotter 
getTypeRange(ASTContext & Context,const TemplateSpecializationTypeLoc & EnableIf)218893d53d1SChris Cotter static SourceRange getTypeRange(ASTContext &Context,
219893d53d1SChris Cotter                                 const TemplateSpecializationTypeLoc &EnableIf) {
220893d53d1SChris Cotter   TemplateArgumentLoc Arg = EnableIf.getArgLoc(1);
221893d53d1SChris Cotter   const LangOptions &LangOpts = Context.getLangOpts();
222893d53d1SChris Cotter   const SourceManager &SM = Context.getSourceManager();
223ec5f4be4SPiotr Zegar   return {utils::lexer::findPreviousTokenKind(Arg.getSourceRange().getBegin(),
224ec5f4be4SPiotr Zegar                                               SM, LangOpts, tok::comma)
225893d53d1SChris Cotter               .getLocWithOffset(1),
226ec5f4be4SPiotr Zegar           getRAngleFileLoc(SM, EnableIf)};
227893d53d1SChris Cotter }
228893d53d1SChris Cotter 
229893d53d1SChris Cotter // Returns the original source text of the second argument of a call to
230893d53d1SChris Cotter // enable_if_t. E.g., in enable_if_t<Condition, TheType>, this function
231893d53d1SChris Cotter // returns 'TheType'.
232893d53d1SChris Cotter static std::optional<StringRef>
getTypeText(ASTContext & Context,const TemplateSpecializationTypeLoc & EnableIf)233893d53d1SChris Cotter getTypeText(ASTContext &Context,
234893d53d1SChris Cotter             const TemplateSpecializationTypeLoc &EnableIf) {
235893d53d1SChris Cotter   if (EnableIf.getNumArgs() > 1) {
236893d53d1SChris Cotter     const LangOptions &LangOpts = Context.getLangOpts();
237893d53d1SChris Cotter     const SourceManager &SM = Context.getSourceManager();
238893d53d1SChris Cotter     bool Invalid = false;
239893d53d1SChris Cotter     StringRef Text = Lexer::getSourceText(CharSourceRange::getCharRange(
240893d53d1SChris Cotter                                               getTypeRange(Context, EnableIf)),
241893d53d1SChris Cotter                                           SM, LangOpts, &Invalid)
242893d53d1SChris Cotter                          .trim();
243893d53d1SChris Cotter     if (Invalid)
244893d53d1SChris Cotter       return std::nullopt;
245893d53d1SChris Cotter 
246893d53d1SChris Cotter     return Text;
247893d53d1SChris Cotter   }
248893d53d1SChris Cotter 
249893d53d1SChris Cotter   return "void";
250893d53d1SChris Cotter }
251893d53d1SChris Cotter 
252893d53d1SChris Cotter static std::optional<SourceLocation>
findInsertionForConstraint(const FunctionDecl * Function,ASTContext & Context)253893d53d1SChris Cotter findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context) {
254893d53d1SChris Cotter   SourceManager &SM = Context.getSourceManager();
255893d53d1SChris Cotter   const LangOptions &LangOpts = Context.getLangOpts();
256893d53d1SChris Cotter 
257893d53d1SChris Cotter   if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(Function)) {
258893d53d1SChris Cotter     for (const CXXCtorInitializer *Init : Constructor->inits()) {
259893d53d1SChris Cotter       if (Init->getSourceOrder() == 0)
260893d53d1SChris Cotter         return utils::lexer::findPreviousTokenKind(Init->getSourceLocation(),
261893d53d1SChris Cotter                                                    SM, LangOpts, tok::colon);
262893d53d1SChris Cotter     }
2633b5a121aSJulian Schmidt     if (!Constructor->inits().empty())
264893d53d1SChris Cotter       return std::nullopt;
265893d53d1SChris Cotter   }
266893d53d1SChris Cotter   if (Function->isDeleted()) {
267893d53d1SChris Cotter     SourceLocation FunctionEnd = Function->getSourceRange().getEnd();
268893d53d1SChris Cotter     return utils::lexer::findNextAnyTokenKind(FunctionEnd, SM, LangOpts,
269893d53d1SChris Cotter                                               tok::equal, tok::equal);
270893d53d1SChris Cotter   }
271893d53d1SChris Cotter   const Stmt *Body = Function->getBody();
272893d53d1SChris Cotter   if (!Body)
273893d53d1SChris Cotter     return std::nullopt;
274893d53d1SChris Cotter 
275893d53d1SChris Cotter   return Body->getBeginLoc();
276893d53d1SChris Cotter }
277893d53d1SChris Cotter 
isPrimaryExpression(const Expr * Expression)278893d53d1SChris Cotter bool isPrimaryExpression(const Expr *Expression) {
279893d53d1SChris Cotter   // This function is an incomplete approximation of checking whether
280893d53d1SChris Cotter   // an Expr is a primary expression. In particular, if this function
281893d53d1SChris Cotter   // returns true, the expression is a primary expression. The converse
282893d53d1SChris Cotter   // is not necessarily true.
283893d53d1SChris Cotter 
284893d53d1SChris Cotter   if (const auto *Cast = dyn_cast<ImplicitCastExpr>(Expression))
285893d53d1SChris Cotter     Expression = Cast->getSubExprAsWritten();
286893d53d1SChris Cotter   if (isa<ParenExpr, DependentScopeDeclRefExpr>(Expression))
287893d53d1SChris Cotter     return true;
288893d53d1SChris Cotter 
289893d53d1SChris Cotter   return false;
290893d53d1SChris Cotter }
291893d53d1SChris Cotter 
292893d53d1SChris Cotter // Return the original source text of an enable_if_t condition, i.e., the
293893d53d1SChris Cotter // first template argument). For example, in
294893d53d1SChris Cotter // 'enable_if_t<FirstCondition || SecondCondition, AType>', the text
295893d53d1SChris Cotter // the text 'FirstCondition || SecondCondition' is returned.
getConditionText(const Expr * ConditionExpr,SourceRange ConditionRange,ASTContext & Context)296893d53d1SChris Cotter static std::optional<std::string> getConditionText(const Expr *ConditionExpr,
297893d53d1SChris Cotter                                                    SourceRange ConditionRange,
298893d53d1SChris Cotter                                                    ASTContext &Context) {
299893d53d1SChris Cotter   SourceManager &SM = Context.getSourceManager();
300893d53d1SChris Cotter   const LangOptions &LangOpts = Context.getLangOpts();
301893d53d1SChris Cotter 
302893d53d1SChris Cotter   SourceLocation PrevTokenLoc = ConditionRange.getEnd();
303893d53d1SChris Cotter   if (PrevTokenLoc.isInvalid())
304893d53d1SChris Cotter     return std::nullopt;
305893d53d1SChris Cotter 
306893d53d1SChris Cotter   const bool SkipComments = false;
307893d53d1SChris Cotter   Token PrevToken;
308893d53d1SChris Cotter   std::tie(PrevToken, PrevTokenLoc) = utils::lexer::getPreviousTokenAndStart(
309893d53d1SChris Cotter       PrevTokenLoc, SM, LangOpts, SkipComments);
310893d53d1SChris Cotter   bool EndsWithDoubleSlash =
311893d53d1SChris Cotter       PrevToken.is(tok::comment) &&
312893d53d1SChris Cotter       Lexer::getSourceText(CharSourceRange::getCharRange(
313893d53d1SChris Cotter                                PrevTokenLoc, PrevTokenLoc.getLocWithOffset(2)),
314893d53d1SChris Cotter                            SM, LangOpts) == "//";
315893d53d1SChris Cotter 
316893d53d1SChris Cotter   bool Invalid = false;
317893d53d1SChris Cotter   llvm::StringRef ConditionText = Lexer::getSourceText(
318893d53d1SChris Cotter       CharSourceRange::getCharRange(ConditionRange), SM, LangOpts, &Invalid);
319893d53d1SChris Cotter   if (Invalid)
320893d53d1SChris Cotter     return std::nullopt;
321893d53d1SChris Cotter 
322893d53d1SChris Cotter   auto AddParens = [&](llvm::StringRef Text) -> std::string {
323893d53d1SChris Cotter     if (isPrimaryExpression(ConditionExpr))
324893d53d1SChris Cotter       return Text.str();
325893d53d1SChris Cotter     return "(" + Text.str() + ")";
326893d53d1SChris Cotter   };
327893d53d1SChris Cotter 
328893d53d1SChris Cotter   if (EndsWithDoubleSlash)
329893d53d1SChris Cotter     return AddParens(ConditionText);
330893d53d1SChris Cotter   return AddParens(ConditionText.trim());
331893d53d1SChris Cotter }
332893d53d1SChris Cotter 
333893d53d1SChris Cotter // Handle functions that return enable_if_t, e.g.,
334893d53d1SChris Cotter //   template <...>
335893d53d1SChris Cotter //   enable_if_t<Condition, ReturnType> function();
336893d53d1SChris Cotter //
337893d53d1SChris Cotter // Return a vector of FixItHints if the code can be replaced with
338893d53d1SChris Cotter // a C++20 requires clause. In the example above, returns FixItHints
339893d53d1SChris Cotter // to result in
340893d53d1SChris Cotter //   template <...>
341893d53d1SChris Cotter //   ReturnType function() requires Condition {}
handleReturnType(const FunctionDecl * Function,const TypeLoc & ReturnType,const EnableIfData & EnableIf,ASTContext & Context)342893d53d1SChris Cotter static std::vector<FixItHint> handleReturnType(const FunctionDecl *Function,
343893d53d1SChris Cotter                                                const TypeLoc &ReturnType,
344893d53d1SChris Cotter                                                const EnableIfData &EnableIf,
345893d53d1SChris Cotter                                                ASTContext &Context) {
346893d53d1SChris Cotter   TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
347893d53d1SChris Cotter 
348893d53d1SChris Cotter   SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
349893d53d1SChris Cotter 
350893d53d1SChris Cotter   std::optional<std::string> ConditionText = getConditionText(
351893d53d1SChris Cotter       EnableCondition.getSourceExpression(), ConditionRange, Context);
352893d53d1SChris Cotter   if (!ConditionText)
353893d53d1SChris Cotter     return {};
354893d53d1SChris Cotter 
355893d53d1SChris Cotter   std::optional<StringRef> TypeText = getTypeText(Context, EnableIf.Loc);
356893d53d1SChris Cotter   if (!TypeText)
357893d53d1SChris Cotter     return {};
358893d53d1SChris Cotter 
359893d53d1SChris Cotter   SmallVector<const Expr *, 3> ExistingConstraints;
360893d53d1SChris Cotter   Function->getAssociatedConstraints(ExistingConstraints);
361c5a4f29eSPiotr Zegar   if (!ExistingConstraints.empty()) {
362893d53d1SChris Cotter     // FIXME - Support adding new constraints to existing ones. Do we need to
363893d53d1SChris Cotter     // consider subsumption?
364893d53d1SChris Cotter     return {};
365893d53d1SChris Cotter   }
366893d53d1SChris Cotter 
367893d53d1SChris Cotter   std::optional<SourceLocation> ConstraintInsertionLoc =
368893d53d1SChris Cotter       findInsertionForConstraint(Function, Context);
369893d53d1SChris Cotter   if (!ConstraintInsertionLoc)
370893d53d1SChris Cotter     return {};
371893d53d1SChris Cotter 
372893d53d1SChris Cotter   std::vector<FixItHint> FixIts;
373893d53d1SChris Cotter   FixIts.push_back(FixItHint::CreateReplacement(
374893d53d1SChris Cotter       CharSourceRange::getTokenRange(EnableIf.Outer.getSourceRange()),
375893d53d1SChris Cotter       *TypeText));
376893d53d1SChris Cotter   FixIts.push_back(FixItHint::CreateInsertion(
377893d53d1SChris Cotter       *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
378893d53d1SChris Cotter   return FixIts;
379893d53d1SChris Cotter }
380893d53d1SChris Cotter 
381893d53d1SChris Cotter // Handle enable_if_t in a trailing template parameter, e.g.,
382893d53d1SChris Cotter //   template <..., enable_if_t<Condition, Type> = Type{}>
383893d53d1SChris Cotter //   ReturnType function();
384893d53d1SChris Cotter //
385893d53d1SChris Cotter // Return a vector of FixItHints if the code can be replaced with
386893d53d1SChris Cotter // a C++20 requires clause. In the example above, returns FixItHints
387893d53d1SChris Cotter // to result in
388893d53d1SChris Cotter //   template <...>
389893d53d1SChris Cotter //   ReturnType function() requires Condition {}
390893d53d1SChris Cotter static std::vector<FixItHint>
handleTrailingTemplateType(const FunctionTemplateDecl * FunctionTemplate,const FunctionDecl * Function,const Decl * LastTemplateParam,const EnableIfData & EnableIf,ASTContext & Context)391893d53d1SChris Cotter handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate,
392893d53d1SChris Cotter                            const FunctionDecl *Function,
393893d53d1SChris Cotter                            const Decl *LastTemplateParam,
394893d53d1SChris Cotter                            const EnableIfData &EnableIf, ASTContext &Context) {
395893d53d1SChris Cotter   SourceManager &SM = Context.getSourceManager();
396893d53d1SChris Cotter   const LangOptions &LangOpts = Context.getLangOpts();
397893d53d1SChris Cotter 
398893d53d1SChris Cotter   TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
399893d53d1SChris Cotter 
400893d53d1SChris Cotter   SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
401893d53d1SChris Cotter 
402893d53d1SChris Cotter   std::optional<std::string> ConditionText = getConditionText(
403893d53d1SChris Cotter       EnableCondition.getSourceExpression(), ConditionRange, Context);
404893d53d1SChris Cotter   if (!ConditionText)
405893d53d1SChris Cotter     return {};
406893d53d1SChris Cotter 
407893d53d1SChris Cotter   SmallVector<const Expr *, 3> ExistingConstraints;
408893d53d1SChris Cotter   Function->getAssociatedConstraints(ExistingConstraints);
409c5a4f29eSPiotr Zegar   if (!ExistingConstraints.empty()) {
410893d53d1SChris Cotter     // FIXME - Support adding new constraints to existing ones. Do we need to
411893d53d1SChris Cotter     // consider subsumption?
412893d53d1SChris Cotter     return {};
413893d53d1SChris Cotter   }
414893d53d1SChris Cotter 
415893d53d1SChris Cotter   SourceRange RemovalRange;
416893d53d1SChris Cotter   const TemplateParameterList *TemplateParams =
417893d53d1SChris Cotter       FunctionTemplate->getTemplateParameters();
418893d53d1SChris Cotter   if (!TemplateParams || TemplateParams->size() == 0)
419893d53d1SChris Cotter     return {};
420893d53d1SChris Cotter 
421893d53d1SChris Cotter   if (TemplateParams->size() == 1) {
422893d53d1SChris Cotter     RemovalRange =
423893d53d1SChris Cotter         SourceRange(TemplateParams->getTemplateLoc(),
424893d53d1SChris Cotter                     getRAngleFileLoc(SM, *TemplateParams).getLocWithOffset(1));
425893d53d1SChris Cotter   } else {
426893d53d1SChris Cotter     RemovalRange =
427893d53d1SChris Cotter         SourceRange(utils::lexer::findPreviousTokenKind(
428893d53d1SChris Cotter                         LastTemplateParam->getSourceRange().getBegin(), SM,
429893d53d1SChris Cotter                         LangOpts, tok::comma),
430893d53d1SChris Cotter                     getRAngleFileLoc(SM, *TemplateParams));
431893d53d1SChris Cotter   }
432893d53d1SChris Cotter 
433893d53d1SChris Cotter   std::optional<SourceLocation> ConstraintInsertionLoc =
434893d53d1SChris Cotter       findInsertionForConstraint(Function, Context);
435893d53d1SChris Cotter   if (!ConstraintInsertionLoc)
436893d53d1SChris Cotter     return {};
437893d53d1SChris Cotter 
438893d53d1SChris Cotter   std::vector<FixItHint> FixIts;
439893d53d1SChris Cotter   FixIts.push_back(
440893d53d1SChris Cotter       FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange)));
441893d53d1SChris Cotter   FixIts.push_back(FixItHint::CreateInsertion(
442893d53d1SChris Cotter       *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
443893d53d1SChris Cotter   return FixIts;
444893d53d1SChris Cotter }
445893d53d1SChris Cotter 
check(const MatchFinder::MatchResult & Result)446893d53d1SChris Cotter void UseConstraintsCheck::check(const MatchFinder::MatchResult &Result) {
447893d53d1SChris Cotter   const auto *FunctionTemplate =
448893d53d1SChris Cotter       Result.Nodes.getNodeAs<FunctionTemplateDecl>("functionTemplate");
449893d53d1SChris Cotter   const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>("function");
450893d53d1SChris Cotter   const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>("return");
451893d53d1SChris Cotter   if (!FunctionTemplate || !Function || !ReturnType)
452893d53d1SChris Cotter     return;
453893d53d1SChris Cotter 
454893d53d1SChris Cotter   // Check for
455893d53d1SChris Cotter   //
456893d53d1SChris Cotter   //   Case 1. Return type of function
457893d53d1SChris Cotter   //
458893d53d1SChris Cotter   //     template <...>
459893d53d1SChris Cotter   //     enable_if_t<Condition, ReturnType>::type function() {}
460893d53d1SChris Cotter   //
461893d53d1SChris Cotter   //   Case 2. Trailing template parameter
462893d53d1SChris Cotter   //
463893d53d1SChris Cotter   //     template <..., enable_if_t<Condition, Type> = Type{}>
464893d53d1SChris Cotter   //     ReturnType function() {}
465893d53d1SChris Cotter   //
466893d53d1SChris Cotter   //     or
467893d53d1SChris Cotter   //
468893d53d1SChris Cotter   //     template <..., typename = enable_if_t<Condition, void>>
469893d53d1SChris Cotter   //     ReturnType function() {}
470893d53d1SChris Cotter   //
471893d53d1SChris Cotter 
472893d53d1SChris Cotter   // Case 1. Return type of function
473893d53d1SChris Cotter   if (auto EnableIf = matchEnableIfSpecialization(*ReturnType)) {
474893d53d1SChris Cotter     diag(ReturnType->getBeginLoc(),
475893d53d1SChris Cotter          "use C++20 requires constraints instead of enable_if")
476893d53d1SChris Cotter         << handleReturnType(Function, *ReturnType, *EnableIf, *Result.Context);
477893d53d1SChris Cotter     return;
478893d53d1SChris Cotter   }
479893d53d1SChris Cotter 
480893d53d1SChris Cotter   // Case 2. Trailing template parameter
481893d53d1SChris Cotter   if (auto [EnableIf, LastTemplateParam] =
482893d53d1SChris Cotter           matchTrailingTemplateParam(FunctionTemplate);
483893d53d1SChris Cotter       EnableIf && LastTemplateParam) {
484893d53d1SChris Cotter     diag(LastTemplateParam->getSourceRange().getBegin(),
485893d53d1SChris Cotter          "use C++20 requires constraints instead of enable_if")
486893d53d1SChris Cotter         << handleTrailingTemplateType(FunctionTemplate, Function,
487893d53d1SChris Cotter                                       LastTemplateParam, *EnableIf,
488893d53d1SChris Cotter                                       *Result.Context);
489893d53d1SChris Cotter     return;
490893d53d1SChris Cotter   }
491893d53d1SChris Cotter }
492893d53d1SChris Cotter 
493893d53d1SChris Cotter } // namespace clang::tidy::modernize
494