//===--- UseConstraintsCheck.cpp - clang-tidy -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8893d53d1SChris Cotter
9893d53d1SChris Cotter #include "UseConstraintsCheck.h"
10893d53d1SChris Cotter #include "clang/AST/ASTContext.h"
11893d53d1SChris Cotter #include "clang/ASTMatchers/ASTMatchFinder.h"
12893d53d1SChris Cotter #include "clang/Lex/Lexer.h"
13893d53d1SChris Cotter
14893d53d1SChris Cotter #include "../utils/LexerUtils.h"
15893d53d1SChris Cotter
16893d53d1SChris Cotter #include <optional>
17893d53d1SChris Cotter #include <utility>
18893d53d1SChris Cotter
19893d53d1SChris Cotter using namespace clang::ast_matchers;
20893d53d1SChris Cotter
21893d53d1SChris Cotter namespace clang::tidy::modernize {
22893d53d1SChris Cotter
23893d53d1SChris Cotter struct EnableIfData {
24893d53d1SChris Cotter TemplateSpecializationTypeLoc Loc;
25893d53d1SChris Cotter TypeLoc Outer;
26893d53d1SChris Cotter };
27893d53d1SChris Cotter
28893d53d1SChris Cotter namespace {
AST_MATCHER(FunctionDecl,hasOtherDeclarations)29893d53d1SChris Cotter AST_MATCHER(FunctionDecl, hasOtherDeclarations) {
30893d53d1SChris Cotter auto It = Node.redecls_begin();
31893d53d1SChris Cotter auto EndIt = Node.redecls_end();
32893d53d1SChris Cotter
33893d53d1SChris Cotter if (It == EndIt)
34893d53d1SChris Cotter return false;
35893d53d1SChris Cotter
36893d53d1SChris Cotter ++It;
37893d53d1SChris Cotter return It != EndIt;
38893d53d1SChris Cotter }
39893d53d1SChris Cotter } // namespace
40893d53d1SChris Cotter
registerMatchers(MatchFinder * Finder)41893d53d1SChris Cotter void UseConstraintsCheck::registerMatchers(MatchFinder *Finder) {
42893d53d1SChris Cotter Finder->addMatcher(
43893d53d1SChris Cotter functionTemplateDecl(
44ba344760SPiotr Zegar // Skip external libraries included as system headers
45ba344760SPiotr Zegar unless(isExpansionInSystemHeader()),
46893d53d1SChris Cotter has(functionDecl(unless(hasOtherDeclarations()), isDefinition(),
47893d53d1SChris Cotter hasReturnTypeLoc(typeLoc().bind("return")))
48893d53d1SChris Cotter .bind("function")))
49893d53d1SChris Cotter .bind("functionTemplate"),
50893d53d1SChris Cotter this);
51893d53d1SChris Cotter }
52893d53d1SChris Cotter
53893d53d1SChris Cotter static std::optional<TemplateSpecializationTypeLoc>
matchEnableIfSpecializationImplTypename(TypeLoc TheType)54893d53d1SChris Cotter matchEnableIfSpecializationImplTypename(TypeLoc TheType) {
55893d53d1SChris Cotter if (const auto Dep = TheType.getAs<DependentNameTypeLoc>()) {
56893d53d1SChris Cotter const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier();
57893d53d1SChris Cotter if (!Identifier || Identifier->getName() != "type" ||
584ad2ada5SVlad Serebrennikov Dep.getTypePtr()->getKeyword() != ElaboratedTypeKeyword::Typename) {
59893d53d1SChris Cotter return std::nullopt;
60893d53d1SChris Cotter }
61893d53d1SChris Cotter TheType = Dep.getQualifierLoc().getTypeLoc();
62ba344760SPiotr Zegar if (TheType.isNull())
63ba344760SPiotr Zegar return std::nullopt;
64893d53d1SChris Cotter }
65893d53d1SChris Cotter
66893d53d1SChris Cotter if (const auto SpecializationLoc =
67893d53d1SChris Cotter TheType.getAs<TemplateSpecializationTypeLoc>()) {
68893d53d1SChris Cotter
69893d53d1SChris Cotter const auto *Specialization =
70893d53d1SChris Cotter dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
71893d53d1SChris Cotter if (!Specialization)
72893d53d1SChris Cotter return std::nullopt;
73893d53d1SChris Cotter
74893d53d1SChris Cotter const TemplateDecl *TD =
75893d53d1SChris Cotter Specialization->getTemplateName().getAsTemplateDecl();
76893d53d1SChris Cotter if (!TD || TD->getName() != "enable_if")
77893d53d1SChris Cotter return std::nullopt;
78893d53d1SChris Cotter
79893d53d1SChris Cotter int NumArgs = SpecializationLoc.getNumArgs();
80893d53d1SChris Cotter if (NumArgs != 1 && NumArgs != 2)
81893d53d1SChris Cotter return std::nullopt;
82893d53d1SChris Cotter
83893d53d1SChris Cotter return SpecializationLoc;
84893d53d1SChris Cotter }
85893d53d1SChris Cotter return std::nullopt;
86893d53d1SChris Cotter }
87893d53d1SChris Cotter
88893d53d1SChris Cotter static std::optional<TemplateSpecializationTypeLoc>
matchEnableIfSpecializationImplTrait(TypeLoc TheType)89893d53d1SChris Cotter matchEnableIfSpecializationImplTrait(TypeLoc TheType) {
90893d53d1SChris Cotter if (const auto Elaborated = TheType.getAs<ElaboratedTypeLoc>())
91893d53d1SChris Cotter TheType = Elaborated.getNamedTypeLoc();
92893d53d1SChris Cotter
93893d53d1SChris Cotter if (const auto SpecializationLoc =
94893d53d1SChris Cotter TheType.getAs<TemplateSpecializationTypeLoc>()) {
95893d53d1SChris Cotter
96893d53d1SChris Cotter const auto *Specialization =
97893d53d1SChris Cotter dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
98893d53d1SChris Cotter if (!Specialization)
99893d53d1SChris Cotter return std::nullopt;
100893d53d1SChris Cotter
101893d53d1SChris Cotter const TemplateDecl *TD =
102893d53d1SChris Cotter Specialization->getTemplateName().getAsTemplateDecl();
103893d53d1SChris Cotter if (!TD || TD->getName() != "enable_if_t")
104893d53d1SChris Cotter return std::nullopt;
105893d53d1SChris Cotter
106893d53d1SChris Cotter if (!Specialization->isTypeAlias())
107893d53d1SChris Cotter return std::nullopt;
108893d53d1SChris Cotter
109893d53d1SChris Cotter if (const auto *AliasedType =
110893d53d1SChris Cotter dyn_cast<DependentNameType>(Specialization->getAliasedType())) {
111893d53d1SChris Cotter if (AliasedType->getIdentifier()->getName() != "type" ||
1124ad2ada5SVlad Serebrennikov AliasedType->getKeyword() != ElaboratedTypeKeyword::Typename) {
113893d53d1SChris Cotter return std::nullopt;
114893d53d1SChris Cotter }
115893d53d1SChris Cotter } else {
116893d53d1SChris Cotter return std::nullopt;
117893d53d1SChris Cotter }
118893d53d1SChris Cotter int NumArgs = SpecializationLoc.getNumArgs();
119893d53d1SChris Cotter if (NumArgs != 1 && NumArgs != 2)
120893d53d1SChris Cotter return std::nullopt;
121893d53d1SChris Cotter
122893d53d1SChris Cotter return SpecializationLoc;
123893d53d1SChris Cotter }
124893d53d1SChris Cotter return std::nullopt;
125893d53d1SChris Cotter }
126893d53d1SChris Cotter
127893d53d1SChris Cotter static std::optional<TemplateSpecializationTypeLoc>
matchEnableIfSpecializationImpl(TypeLoc TheType)128893d53d1SChris Cotter matchEnableIfSpecializationImpl(TypeLoc TheType) {
129893d53d1SChris Cotter if (auto EnableIf = matchEnableIfSpecializationImplTypename(TheType))
130893d53d1SChris Cotter return EnableIf;
131893d53d1SChris Cotter return matchEnableIfSpecializationImplTrait(TheType);
132893d53d1SChris Cotter }
133893d53d1SChris Cotter
134893d53d1SChris Cotter static std::optional<EnableIfData>
matchEnableIfSpecialization(TypeLoc TheType)135893d53d1SChris Cotter matchEnableIfSpecialization(TypeLoc TheType) {
136893d53d1SChris Cotter if (const auto Pointer = TheType.getAs<PointerTypeLoc>())
137893d53d1SChris Cotter TheType = Pointer.getPointeeLoc();
138893d53d1SChris Cotter else if (const auto Reference = TheType.getAs<ReferenceTypeLoc>())
139893d53d1SChris Cotter TheType = Reference.getPointeeLoc();
140893d53d1SChris Cotter if (const auto Qualified = TheType.getAs<QualifiedTypeLoc>())
141893d53d1SChris Cotter TheType = Qualified.getUnqualifiedLoc();
142893d53d1SChris Cotter
143893d53d1SChris Cotter if (auto EnableIf = matchEnableIfSpecializationImpl(TheType))
144893d53d1SChris Cotter return EnableIfData{std::move(*EnableIf), TheType};
145893d53d1SChris Cotter return std::nullopt;
146893d53d1SChris Cotter }
147893d53d1SChris Cotter
148893d53d1SChris Cotter static std::pair<std::optional<EnableIfData>, const Decl *>
matchTrailingTemplateParam(const FunctionTemplateDecl * FunctionTemplate)149893d53d1SChris Cotter matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate) {
150893d53d1SChris Cotter // For non-type trailing param, match very specifically
151893d53d1SChris Cotter // 'template <..., enable_if_type<Condition, Type> = Default>' where
152893d53d1SChris Cotter // enable_if_type is 'enable_if' or 'enable_if_t'. E.g., 'template <typename
153893d53d1SChris Cotter // T, enable_if_t<is_same_v<T, bool>, int*> = nullptr>
154893d53d1SChris Cotter //
155893d53d1SChris Cotter // Otherwise, match a trailing default type arg.
156893d53d1SChris Cotter // E.g., 'template <typename T, typename = enable_if_t<is_same_v<T, bool>>>'
157893d53d1SChris Cotter
158893d53d1SChris Cotter const TemplateParameterList *TemplateParams =
159893d53d1SChris Cotter FunctionTemplate->getTemplateParameters();
160893d53d1SChris Cotter if (TemplateParams->size() == 0)
161893d53d1SChris Cotter return {};
162893d53d1SChris Cotter
163893d53d1SChris Cotter const NamedDecl *LastParam =
164893d53d1SChris Cotter TemplateParams->getParam(TemplateParams->size() - 1);
165893d53d1SChris Cotter if (const auto *LastTemplateParam =
166893d53d1SChris Cotter dyn_cast<NonTypeTemplateParmDecl>(LastParam)) {
167893d53d1SChris Cotter
168893d53d1SChris Cotter if (!LastTemplateParam->hasDefaultArgument() ||
169893d53d1SChris Cotter !LastTemplateParam->getName().empty())
170893d53d1SChris Cotter return {};
171893d53d1SChris Cotter
172893d53d1SChris Cotter return {matchEnableIfSpecialization(
173893d53d1SChris Cotter LastTemplateParam->getTypeSourceInfo()->getTypeLoc()),
174893d53d1SChris Cotter LastTemplateParam};
17501c8bf6fSPiotr Zegar }
17601c8bf6fSPiotr Zegar if (const auto *LastTemplateParam =
177893d53d1SChris Cotter dyn_cast<TemplateTypeParmDecl>(LastParam)) {
178893d53d1SChris Cotter if (LastTemplateParam->hasDefaultArgument() &&
179893d53d1SChris Cotter LastTemplateParam->getIdentifier() == nullptr) {
180*e42b799bSMatheus Izvekov return {
181*e42b799bSMatheus Izvekov matchEnableIfSpecialization(LastTemplateParam->getDefaultArgument()
182*e42b799bSMatheus Izvekov .getTypeSourceInfo()
183*e42b799bSMatheus Izvekov ->getTypeLoc()),
184893d53d1SChris Cotter LastTemplateParam};
185893d53d1SChris Cotter }
186893d53d1SChris Cotter }
187893d53d1SChris Cotter return {};
188893d53d1SChris Cotter }
189893d53d1SChris Cotter
190893d53d1SChris Cotter template <typename T>
getRAngleFileLoc(const SourceManager & SM,const T & Element)191893d53d1SChris Cotter static SourceLocation getRAngleFileLoc(const SourceManager &SM,
192893d53d1SChris Cotter const T &Element) {
193893d53d1SChris Cotter // getFileLoc handles the case where the RAngle loc is part of a synthesized
194893d53d1SChris Cotter // '>>', which ends up allocating a 'scratch space' buffer in the source
195893d53d1SChris Cotter // manager.
196893d53d1SChris Cotter return SM.getFileLoc(Element.getRAngleLoc());
197893d53d1SChris Cotter }
198893d53d1SChris Cotter
199893d53d1SChris Cotter static SourceRange
getConditionRange(ASTContext & Context,const TemplateSpecializationTypeLoc & EnableIf)200893d53d1SChris Cotter getConditionRange(ASTContext &Context,
201893d53d1SChris Cotter const TemplateSpecializationTypeLoc &EnableIf) {
202893d53d1SChris Cotter // TemplateArgumentLoc's SourceRange End is the location of the last token
203893d53d1SChris Cotter // (per UnqualifiedId docs). E.g., in `enable_if<AAA && BBB>`, the End
204893d53d1SChris Cotter // location will be the first 'B' in 'BBB'.
205893d53d1SChris Cotter const LangOptions &LangOpts = Context.getLangOpts();
206893d53d1SChris Cotter const SourceManager &SM = Context.getSourceManager();
207893d53d1SChris Cotter if (EnableIf.getNumArgs() > 1) {
208893d53d1SChris Cotter TemplateArgumentLoc NextArg = EnableIf.getArgLoc(1);
209ec5f4be4SPiotr Zegar return {EnableIf.getLAngleLoc().getLocWithOffset(1),
210ec5f4be4SPiotr Zegar utils::lexer::findPreviousTokenKind(
211ec5f4be4SPiotr Zegar NextArg.getSourceRange().getBegin(), SM, LangOpts, tok::comma)};
212893d53d1SChris Cotter }
213893d53d1SChris Cotter
214ec5f4be4SPiotr Zegar return {EnableIf.getLAngleLoc().getLocWithOffset(1),
215ec5f4be4SPiotr Zegar getRAngleFileLoc(SM, EnableIf)};
216893d53d1SChris Cotter }
217893d53d1SChris Cotter
getTypeRange(ASTContext & Context,const TemplateSpecializationTypeLoc & EnableIf)218893d53d1SChris Cotter static SourceRange getTypeRange(ASTContext &Context,
219893d53d1SChris Cotter const TemplateSpecializationTypeLoc &EnableIf) {
220893d53d1SChris Cotter TemplateArgumentLoc Arg = EnableIf.getArgLoc(1);
221893d53d1SChris Cotter const LangOptions &LangOpts = Context.getLangOpts();
222893d53d1SChris Cotter const SourceManager &SM = Context.getSourceManager();
223ec5f4be4SPiotr Zegar return {utils::lexer::findPreviousTokenKind(Arg.getSourceRange().getBegin(),
224ec5f4be4SPiotr Zegar SM, LangOpts, tok::comma)
225893d53d1SChris Cotter .getLocWithOffset(1),
226ec5f4be4SPiotr Zegar getRAngleFileLoc(SM, EnableIf)};
227893d53d1SChris Cotter }
228893d53d1SChris Cotter
229893d53d1SChris Cotter // Returns the original source text of the second argument of a call to
230893d53d1SChris Cotter // enable_if_t. E.g., in enable_if_t<Condition, TheType>, this function
231893d53d1SChris Cotter // returns 'TheType'.
232893d53d1SChris Cotter static std::optional<StringRef>
getTypeText(ASTContext & Context,const TemplateSpecializationTypeLoc & EnableIf)233893d53d1SChris Cotter getTypeText(ASTContext &Context,
234893d53d1SChris Cotter const TemplateSpecializationTypeLoc &EnableIf) {
235893d53d1SChris Cotter if (EnableIf.getNumArgs() > 1) {
236893d53d1SChris Cotter const LangOptions &LangOpts = Context.getLangOpts();
237893d53d1SChris Cotter const SourceManager &SM = Context.getSourceManager();
238893d53d1SChris Cotter bool Invalid = false;
239893d53d1SChris Cotter StringRef Text = Lexer::getSourceText(CharSourceRange::getCharRange(
240893d53d1SChris Cotter getTypeRange(Context, EnableIf)),
241893d53d1SChris Cotter SM, LangOpts, &Invalid)
242893d53d1SChris Cotter .trim();
243893d53d1SChris Cotter if (Invalid)
244893d53d1SChris Cotter return std::nullopt;
245893d53d1SChris Cotter
246893d53d1SChris Cotter return Text;
247893d53d1SChris Cotter }
248893d53d1SChris Cotter
249893d53d1SChris Cotter return "void";
250893d53d1SChris Cotter }
251893d53d1SChris Cotter
252893d53d1SChris Cotter static std::optional<SourceLocation>
findInsertionForConstraint(const FunctionDecl * Function,ASTContext & Context)253893d53d1SChris Cotter findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context) {
254893d53d1SChris Cotter SourceManager &SM = Context.getSourceManager();
255893d53d1SChris Cotter const LangOptions &LangOpts = Context.getLangOpts();
256893d53d1SChris Cotter
257893d53d1SChris Cotter if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(Function)) {
258893d53d1SChris Cotter for (const CXXCtorInitializer *Init : Constructor->inits()) {
259893d53d1SChris Cotter if (Init->getSourceOrder() == 0)
260893d53d1SChris Cotter return utils::lexer::findPreviousTokenKind(Init->getSourceLocation(),
261893d53d1SChris Cotter SM, LangOpts, tok::colon);
262893d53d1SChris Cotter }
2633b5a121aSJulian Schmidt if (!Constructor->inits().empty())
264893d53d1SChris Cotter return std::nullopt;
265893d53d1SChris Cotter }
266893d53d1SChris Cotter if (Function->isDeleted()) {
267893d53d1SChris Cotter SourceLocation FunctionEnd = Function->getSourceRange().getEnd();
268893d53d1SChris Cotter return utils::lexer::findNextAnyTokenKind(FunctionEnd, SM, LangOpts,
269893d53d1SChris Cotter tok::equal, tok::equal);
270893d53d1SChris Cotter }
271893d53d1SChris Cotter const Stmt *Body = Function->getBody();
272893d53d1SChris Cotter if (!Body)
273893d53d1SChris Cotter return std::nullopt;
274893d53d1SChris Cotter
275893d53d1SChris Cotter return Body->getBeginLoc();
276893d53d1SChris Cotter }
277893d53d1SChris Cotter
isPrimaryExpression(const Expr * Expression)278893d53d1SChris Cotter bool isPrimaryExpression(const Expr *Expression) {
279893d53d1SChris Cotter // This function is an incomplete approximation of checking whether
280893d53d1SChris Cotter // an Expr is a primary expression. In particular, if this function
281893d53d1SChris Cotter // returns true, the expression is a primary expression. The converse
282893d53d1SChris Cotter // is not necessarily true.
283893d53d1SChris Cotter
284893d53d1SChris Cotter if (const auto *Cast = dyn_cast<ImplicitCastExpr>(Expression))
285893d53d1SChris Cotter Expression = Cast->getSubExprAsWritten();
286893d53d1SChris Cotter if (isa<ParenExpr, DependentScopeDeclRefExpr>(Expression))
287893d53d1SChris Cotter return true;
288893d53d1SChris Cotter
289893d53d1SChris Cotter return false;
290893d53d1SChris Cotter }
291893d53d1SChris Cotter
292893d53d1SChris Cotter // Return the original source text of an enable_if_t condition, i.e., the
293893d53d1SChris Cotter // first template argument). For example, in
294893d53d1SChris Cotter // 'enable_if_t<FirstCondition || SecondCondition, AType>', the text
295893d53d1SChris Cotter // the text 'FirstCondition || SecondCondition' is returned.
getConditionText(const Expr * ConditionExpr,SourceRange ConditionRange,ASTContext & Context)296893d53d1SChris Cotter static std::optional<std::string> getConditionText(const Expr *ConditionExpr,
297893d53d1SChris Cotter SourceRange ConditionRange,
298893d53d1SChris Cotter ASTContext &Context) {
299893d53d1SChris Cotter SourceManager &SM = Context.getSourceManager();
300893d53d1SChris Cotter const LangOptions &LangOpts = Context.getLangOpts();
301893d53d1SChris Cotter
302893d53d1SChris Cotter SourceLocation PrevTokenLoc = ConditionRange.getEnd();
303893d53d1SChris Cotter if (PrevTokenLoc.isInvalid())
304893d53d1SChris Cotter return std::nullopt;
305893d53d1SChris Cotter
306893d53d1SChris Cotter const bool SkipComments = false;
307893d53d1SChris Cotter Token PrevToken;
308893d53d1SChris Cotter std::tie(PrevToken, PrevTokenLoc) = utils::lexer::getPreviousTokenAndStart(
309893d53d1SChris Cotter PrevTokenLoc, SM, LangOpts, SkipComments);
310893d53d1SChris Cotter bool EndsWithDoubleSlash =
311893d53d1SChris Cotter PrevToken.is(tok::comment) &&
312893d53d1SChris Cotter Lexer::getSourceText(CharSourceRange::getCharRange(
313893d53d1SChris Cotter PrevTokenLoc, PrevTokenLoc.getLocWithOffset(2)),
314893d53d1SChris Cotter SM, LangOpts) == "//";
315893d53d1SChris Cotter
316893d53d1SChris Cotter bool Invalid = false;
317893d53d1SChris Cotter llvm::StringRef ConditionText = Lexer::getSourceText(
318893d53d1SChris Cotter CharSourceRange::getCharRange(ConditionRange), SM, LangOpts, &Invalid);
319893d53d1SChris Cotter if (Invalid)
320893d53d1SChris Cotter return std::nullopt;
321893d53d1SChris Cotter
322893d53d1SChris Cotter auto AddParens = [&](llvm::StringRef Text) -> std::string {
323893d53d1SChris Cotter if (isPrimaryExpression(ConditionExpr))
324893d53d1SChris Cotter return Text.str();
325893d53d1SChris Cotter return "(" + Text.str() + ")";
326893d53d1SChris Cotter };
327893d53d1SChris Cotter
328893d53d1SChris Cotter if (EndsWithDoubleSlash)
329893d53d1SChris Cotter return AddParens(ConditionText);
330893d53d1SChris Cotter return AddParens(ConditionText.trim());
331893d53d1SChris Cotter }
332893d53d1SChris Cotter
333893d53d1SChris Cotter // Handle functions that return enable_if_t, e.g.,
334893d53d1SChris Cotter // template <...>
335893d53d1SChris Cotter // enable_if_t<Condition, ReturnType> function();
336893d53d1SChris Cotter //
337893d53d1SChris Cotter // Return a vector of FixItHints if the code can be replaced with
338893d53d1SChris Cotter // a C++20 requires clause. In the example above, returns FixItHints
339893d53d1SChris Cotter // to result in
340893d53d1SChris Cotter // template <...>
341893d53d1SChris Cotter // ReturnType function() requires Condition {}
handleReturnType(const FunctionDecl * Function,const TypeLoc & ReturnType,const EnableIfData & EnableIf,ASTContext & Context)342893d53d1SChris Cotter static std::vector<FixItHint> handleReturnType(const FunctionDecl *Function,
343893d53d1SChris Cotter const TypeLoc &ReturnType,
344893d53d1SChris Cotter const EnableIfData &EnableIf,
345893d53d1SChris Cotter ASTContext &Context) {
346893d53d1SChris Cotter TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
347893d53d1SChris Cotter
348893d53d1SChris Cotter SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
349893d53d1SChris Cotter
350893d53d1SChris Cotter std::optional<std::string> ConditionText = getConditionText(
351893d53d1SChris Cotter EnableCondition.getSourceExpression(), ConditionRange, Context);
352893d53d1SChris Cotter if (!ConditionText)
353893d53d1SChris Cotter return {};
354893d53d1SChris Cotter
355893d53d1SChris Cotter std::optional<StringRef> TypeText = getTypeText(Context, EnableIf.Loc);
356893d53d1SChris Cotter if (!TypeText)
357893d53d1SChris Cotter return {};
358893d53d1SChris Cotter
359893d53d1SChris Cotter SmallVector<const Expr *, 3> ExistingConstraints;
360893d53d1SChris Cotter Function->getAssociatedConstraints(ExistingConstraints);
361c5a4f29eSPiotr Zegar if (!ExistingConstraints.empty()) {
362893d53d1SChris Cotter // FIXME - Support adding new constraints to existing ones. Do we need to
363893d53d1SChris Cotter // consider subsumption?
364893d53d1SChris Cotter return {};
365893d53d1SChris Cotter }
366893d53d1SChris Cotter
367893d53d1SChris Cotter std::optional<SourceLocation> ConstraintInsertionLoc =
368893d53d1SChris Cotter findInsertionForConstraint(Function, Context);
369893d53d1SChris Cotter if (!ConstraintInsertionLoc)
370893d53d1SChris Cotter return {};
371893d53d1SChris Cotter
372893d53d1SChris Cotter std::vector<FixItHint> FixIts;
373893d53d1SChris Cotter FixIts.push_back(FixItHint::CreateReplacement(
374893d53d1SChris Cotter CharSourceRange::getTokenRange(EnableIf.Outer.getSourceRange()),
375893d53d1SChris Cotter *TypeText));
376893d53d1SChris Cotter FixIts.push_back(FixItHint::CreateInsertion(
377893d53d1SChris Cotter *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
378893d53d1SChris Cotter return FixIts;
379893d53d1SChris Cotter }
380893d53d1SChris Cotter
381893d53d1SChris Cotter // Handle enable_if_t in a trailing template parameter, e.g.,
382893d53d1SChris Cotter // template <..., enable_if_t<Condition, Type> = Type{}>
383893d53d1SChris Cotter // ReturnType function();
384893d53d1SChris Cotter //
385893d53d1SChris Cotter // Return a vector of FixItHints if the code can be replaced with
386893d53d1SChris Cotter // a C++20 requires clause. In the example above, returns FixItHints
387893d53d1SChris Cotter // to result in
388893d53d1SChris Cotter // template <...>
389893d53d1SChris Cotter // ReturnType function() requires Condition {}
390893d53d1SChris Cotter static std::vector<FixItHint>
handleTrailingTemplateType(const FunctionTemplateDecl * FunctionTemplate,const FunctionDecl * Function,const Decl * LastTemplateParam,const EnableIfData & EnableIf,ASTContext & Context)391893d53d1SChris Cotter handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate,
392893d53d1SChris Cotter const FunctionDecl *Function,
393893d53d1SChris Cotter const Decl *LastTemplateParam,
394893d53d1SChris Cotter const EnableIfData &EnableIf, ASTContext &Context) {
395893d53d1SChris Cotter SourceManager &SM = Context.getSourceManager();
396893d53d1SChris Cotter const LangOptions &LangOpts = Context.getLangOpts();
397893d53d1SChris Cotter
398893d53d1SChris Cotter TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
399893d53d1SChris Cotter
400893d53d1SChris Cotter SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
401893d53d1SChris Cotter
402893d53d1SChris Cotter std::optional<std::string> ConditionText = getConditionText(
403893d53d1SChris Cotter EnableCondition.getSourceExpression(), ConditionRange, Context);
404893d53d1SChris Cotter if (!ConditionText)
405893d53d1SChris Cotter return {};
406893d53d1SChris Cotter
407893d53d1SChris Cotter SmallVector<const Expr *, 3> ExistingConstraints;
408893d53d1SChris Cotter Function->getAssociatedConstraints(ExistingConstraints);
409c5a4f29eSPiotr Zegar if (!ExistingConstraints.empty()) {
410893d53d1SChris Cotter // FIXME - Support adding new constraints to existing ones. Do we need to
411893d53d1SChris Cotter // consider subsumption?
412893d53d1SChris Cotter return {};
413893d53d1SChris Cotter }
414893d53d1SChris Cotter
415893d53d1SChris Cotter SourceRange RemovalRange;
416893d53d1SChris Cotter const TemplateParameterList *TemplateParams =
417893d53d1SChris Cotter FunctionTemplate->getTemplateParameters();
418893d53d1SChris Cotter if (!TemplateParams || TemplateParams->size() == 0)
419893d53d1SChris Cotter return {};
420893d53d1SChris Cotter
421893d53d1SChris Cotter if (TemplateParams->size() == 1) {
422893d53d1SChris Cotter RemovalRange =
423893d53d1SChris Cotter SourceRange(TemplateParams->getTemplateLoc(),
424893d53d1SChris Cotter getRAngleFileLoc(SM, *TemplateParams).getLocWithOffset(1));
425893d53d1SChris Cotter } else {
426893d53d1SChris Cotter RemovalRange =
427893d53d1SChris Cotter SourceRange(utils::lexer::findPreviousTokenKind(
428893d53d1SChris Cotter LastTemplateParam->getSourceRange().getBegin(), SM,
429893d53d1SChris Cotter LangOpts, tok::comma),
430893d53d1SChris Cotter getRAngleFileLoc(SM, *TemplateParams));
431893d53d1SChris Cotter }
432893d53d1SChris Cotter
433893d53d1SChris Cotter std::optional<SourceLocation> ConstraintInsertionLoc =
434893d53d1SChris Cotter findInsertionForConstraint(Function, Context);
435893d53d1SChris Cotter if (!ConstraintInsertionLoc)
436893d53d1SChris Cotter return {};
437893d53d1SChris Cotter
438893d53d1SChris Cotter std::vector<FixItHint> FixIts;
439893d53d1SChris Cotter FixIts.push_back(
440893d53d1SChris Cotter FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange)));
441893d53d1SChris Cotter FixIts.push_back(FixItHint::CreateInsertion(
442893d53d1SChris Cotter *ConstraintInsertionLoc, "requires " + *ConditionText + " "));
443893d53d1SChris Cotter return FixIts;
444893d53d1SChris Cotter }
445893d53d1SChris Cotter
check(const MatchFinder::MatchResult & Result)446893d53d1SChris Cotter void UseConstraintsCheck::check(const MatchFinder::MatchResult &Result) {
447893d53d1SChris Cotter const auto *FunctionTemplate =
448893d53d1SChris Cotter Result.Nodes.getNodeAs<FunctionTemplateDecl>("functionTemplate");
449893d53d1SChris Cotter const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>("function");
450893d53d1SChris Cotter const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>("return");
451893d53d1SChris Cotter if (!FunctionTemplate || !Function || !ReturnType)
452893d53d1SChris Cotter return;
453893d53d1SChris Cotter
454893d53d1SChris Cotter // Check for
455893d53d1SChris Cotter //
456893d53d1SChris Cotter // Case 1. Return type of function
457893d53d1SChris Cotter //
458893d53d1SChris Cotter // template <...>
459893d53d1SChris Cotter // enable_if_t<Condition, ReturnType>::type function() {}
460893d53d1SChris Cotter //
461893d53d1SChris Cotter // Case 2. Trailing template parameter
462893d53d1SChris Cotter //
463893d53d1SChris Cotter // template <..., enable_if_t<Condition, Type> = Type{}>
464893d53d1SChris Cotter // ReturnType function() {}
465893d53d1SChris Cotter //
466893d53d1SChris Cotter // or
467893d53d1SChris Cotter //
468893d53d1SChris Cotter // template <..., typename = enable_if_t<Condition, void>>
469893d53d1SChris Cotter // ReturnType function() {}
470893d53d1SChris Cotter //
471893d53d1SChris Cotter
472893d53d1SChris Cotter // Case 1. Return type of function
473893d53d1SChris Cotter if (auto EnableIf = matchEnableIfSpecialization(*ReturnType)) {
474893d53d1SChris Cotter diag(ReturnType->getBeginLoc(),
475893d53d1SChris Cotter "use C++20 requires constraints instead of enable_if")
476893d53d1SChris Cotter << handleReturnType(Function, *ReturnType, *EnableIf, *Result.Context);
477893d53d1SChris Cotter return;
478893d53d1SChris Cotter }
479893d53d1SChris Cotter
480893d53d1SChris Cotter // Case 2. Trailing template parameter
481893d53d1SChris Cotter if (auto [EnableIf, LastTemplateParam] =
482893d53d1SChris Cotter matchTrailingTemplateParam(FunctionTemplate);
483893d53d1SChris Cotter EnableIf && LastTemplateParam) {
484893d53d1SChris Cotter diag(LastTemplateParam->getSourceRange().getBegin(),
485893d53d1SChris Cotter "use C++20 requires constraints instead of enable_if")
486893d53d1SChris Cotter << handleTrailingTemplateType(FunctionTemplate, Function,
487893d53d1SChris Cotter LastTemplateParam, *EnableIf,
488893d53d1SChris Cotter *Result.Context);
489893d53d1SChris Cotter return;
490893d53d1SChris Cotter }
491893d53d1SChris Cotter }
492893d53d1SChris Cotter
493893d53d1SChris Cotter } // namespace clang::tidy::modernize
494