1 //===--- UseConstraintsCheck.cpp - clang-tidy -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "UseConstraintsCheck.h" 10 #include "clang/AST/ASTContext.h" 11 #include "clang/ASTMatchers/ASTMatchFinder.h" 12 #include "clang/Lex/Lexer.h" 13 14 #include "../utils/LexerUtils.h" 15 16 #include <optional> 17 #include <utility> 18 19 using namespace clang::ast_matchers; 20 21 namespace clang::tidy::modernize { 22 23 struct EnableIfData { 24 TemplateSpecializationTypeLoc Loc; 25 TypeLoc Outer; 26 }; 27 28 namespace { 29 AST_MATCHER(FunctionDecl, hasOtherDeclarations) { 30 auto It = Node.redecls_begin(); 31 auto EndIt = Node.redecls_end(); 32 33 if (It == EndIt) 34 return false; 35 36 ++It; 37 return It != EndIt; 38 } 39 } // namespace 40 41 void UseConstraintsCheck::registerMatchers(MatchFinder *Finder) { 42 Finder->addMatcher( 43 functionTemplateDecl( 44 has(functionDecl(unless(hasOtherDeclarations()), isDefinition(), 45 hasReturnTypeLoc(typeLoc().bind("return"))) 46 .bind("function"))) 47 .bind("functionTemplate"), 48 this); 49 } 50 51 static std::optional<TemplateSpecializationTypeLoc> 52 matchEnableIfSpecializationImplTypename(TypeLoc TheType) { 53 if (const auto Dep = TheType.getAs<DependentNameTypeLoc>()) { 54 const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier(); 55 if (!Identifier || Identifier->getName() != "type" || 56 Dep.getTypePtr()->getKeyword() != ETK_Typename) { 57 return std::nullopt; 58 } 59 TheType = Dep.getQualifierLoc().getTypeLoc(); 60 } 61 62 if (const auto SpecializationLoc = 63 TheType.getAs<TemplateSpecializationTypeLoc>()) { 64 65 const auto *Specialization = 66 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr()); 67 if (!Specialization) 68 return std::nullopt; 69 70 const TemplateDecl *TD = 71 Specialization->getTemplateName().getAsTemplateDecl(); 72 if (!TD || TD->getName() != "enable_if") 73 return std::nullopt; 74 75 int NumArgs = SpecializationLoc.getNumArgs(); 76 if (NumArgs != 1 && NumArgs != 2) 77 return std::nullopt; 78 79 return SpecializationLoc; 80 } 81 return std::nullopt; 82 } 83 84 static std::optional<TemplateSpecializationTypeLoc> 85 matchEnableIfSpecializationImplTrait(TypeLoc TheType) { 86 if (const auto Elaborated = TheType.getAs<ElaboratedTypeLoc>()) 87 TheType = Elaborated.getNamedTypeLoc(); 88 89 if (const auto SpecializationLoc = 90 TheType.getAs<TemplateSpecializationTypeLoc>()) { 91 92 const auto *Specialization = 93 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr()); 94 if (!Specialization) 95 return std::nullopt; 96 97 const TemplateDecl *TD = 98 Specialization->getTemplateName().getAsTemplateDecl(); 99 if (!TD || TD->getName() != "enable_if_t") 100 return std::nullopt; 101 102 if (!Specialization->isTypeAlias()) 103 return std::nullopt; 104 105 if (const auto *AliasedType = 106 dyn_cast<DependentNameType>(Specialization->getAliasedType())) { 107 if (AliasedType->getIdentifier()->getName() != "type" || 108 AliasedType->getKeyword() != ETK_Typename) { 109 return std::nullopt; 110 } 111 } else { 112 return std::nullopt; 113 } 114 int NumArgs = SpecializationLoc.getNumArgs(); 115 if (NumArgs != 1 && NumArgs != 2) 116 return std::nullopt; 117 118 return SpecializationLoc; 119 } 120 return std::nullopt; 121 } 122 123 static std::optional<TemplateSpecializationTypeLoc> 124 matchEnableIfSpecializationImpl(TypeLoc TheType) { 125 if (auto EnableIf = matchEnableIfSpecializationImplTypename(TheType)) 126 return EnableIf; 127 return matchEnableIfSpecializationImplTrait(TheType); 128 } 129 130 static std::optional<EnableIfData> 131 matchEnableIfSpecialization(TypeLoc TheType) { 132 if (const auto Pointer = TheType.getAs<PointerTypeLoc>()) 133 TheType = Pointer.getPointeeLoc(); 134 else if (const auto Reference = TheType.getAs<ReferenceTypeLoc>()) 135 TheType = Reference.getPointeeLoc(); 136 if (const auto Qualified = TheType.getAs<QualifiedTypeLoc>()) 137 TheType = Qualified.getUnqualifiedLoc(); 138 139 if (auto EnableIf = matchEnableIfSpecializationImpl(TheType)) 140 return EnableIfData{std::move(*EnableIf), TheType}; 141 return std::nullopt; 142 } 143 144 static std::pair<std::optional<EnableIfData>, const Decl *> 145 matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate) { 146 // For non-type trailing param, match very specifically 147 // 'template <..., enable_if_type<Condition, Type> = Default>' where 148 // enable_if_type is 'enable_if' or 'enable_if_t'. E.g., 'template <typename 149 // T, enable_if_t<is_same_v<T, bool>, int*> = nullptr> 150 // 151 // Otherwise, match a trailing default type arg. 152 // E.g., 'template <typename T, typename = enable_if_t<is_same_v<T, bool>>>' 153 154 const TemplateParameterList *TemplateParams = 155 FunctionTemplate->getTemplateParameters(); 156 if (TemplateParams->size() == 0) 157 return {}; 158 159 const NamedDecl *LastParam = 160 TemplateParams->getParam(TemplateParams->size() - 1); 161 if (const auto *LastTemplateParam = 162 dyn_cast<NonTypeTemplateParmDecl>(LastParam)) { 163 164 if (!LastTemplateParam->hasDefaultArgument() || 165 !LastTemplateParam->getName().empty()) 166 return {}; 167 168 return {matchEnableIfSpecialization( 169 LastTemplateParam->getTypeSourceInfo()->getTypeLoc()), 170 LastTemplateParam}; 171 } else if (const auto *LastTemplateParam = 172 dyn_cast<TemplateTypeParmDecl>(LastParam)) { 173 if (LastTemplateParam->hasDefaultArgument() && 174 LastTemplateParam->getIdentifier() == nullptr) { 175 return {matchEnableIfSpecialization( 176 LastTemplateParam->getDefaultArgumentInfo()->getTypeLoc()), 177 LastTemplateParam}; 178 } 179 } 180 return {}; 181 } 182 183 template <typename T> 184 static SourceLocation getRAngleFileLoc(const SourceManager &SM, 185 const T &Element) { 186 // getFileLoc handles the case where the RAngle loc is part of a synthesized 187 // '>>', which ends up allocating a 'scratch space' buffer in the source 188 // manager. 189 return SM.getFileLoc(Element.getRAngleLoc()); 190 } 191 192 static SourceRange 193 getConditionRange(ASTContext &Context, 194 const TemplateSpecializationTypeLoc &EnableIf) { 195 // TemplateArgumentLoc's SourceRange End is the location of the last token 196 // (per UnqualifiedId docs). E.g., in `enable_if<AAA && BBB>`, the End 197 // location will be the first 'B' in 'BBB'. 198 const LangOptions &LangOpts = Context.getLangOpts(); 199 const SourceManager &SM = Context.getSourceManager(); 200 if (EnableIf.getNumArgs() > 1) { 201 TemplateArgumentLoc NextArg = EnableIf.getArgLoc(1); 202 return {EnableIf.getLAngleLoc().getLocWithOffset(1), 203 utils::lexer::findPreviousTokenKind( 204 NextArg.getSourceRange().getBegin(), SM, LangOpts, tok::comma)}; 205 } 206 207 return {EnableIf.getLAngleLoc().getLocWithOffset(1), 208 getRAngleFileLoc(SM, EnableIf)}; 209 } 210 211 static SourceRange getTypeRange(ASTContext &Context, 212 const TemplateSpecializationTypeLoc &EnableIf) { 213 TemplateArgumentLoc Arg = EnableIf.getArgLoc(1); 214 const LangOptions &LangOpts = Context.getLangOpts(); 215 const SourceManager &SM = Context.getSourceManager(); 216 return {utils::lexer::findPreviousTokenKind(Arg.getSourceRange().getBegin(), 217 SM, LangOpts, tok::comma) 218 .getLocWithOffset(1), 219 getRAngleFileLoc(SM, EnableIf)}; 220 } 221 222 // Returns the original source text of the second argument of a call to 223 // enable_if_t. E.g., in enable_if_t<Condition, TheType>, this function 224 // returns 'TheType'. 225 static std::optional<StringRef> 226 getTypeText(ASTContext &Context, 227 const TemplateSpecializationTypeLoc &EnableIf) { 228 if (EnableIf.getNumArgs() > 1) { 229 const LangOptions &LangOpts = Context.getLangOpts(); 230 const SourceManager &SM = Context.getSourceManager(); 231 bool Invalid = false; 232 StringRef Text = Lexer::getSourceText(CharSourceRange::getCharRange( 233 getTypeRange(Context, EnableIf)), 234 SM, LangOpts, &Invalid) 235 .trim(); 236 if (Invalid) 237 return std::nullopt; 238 239 return Text; 240 } 241 242 return "void"; 243 } 244 245 static std::optional<SourceLocation> 246 findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context) { 247 SourceManager &SM = Context.getSourceManager(); 248 const LangOptions &LangOpts = Context.getLangOpts(); 249 250 if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(Function)) { 251 for (const CXXCtorInitializer *Init : Constructor->inits()) { 252 if (Init->getSourceOrder() == 0) 253 return utils::lexer::findPreviousTokenKind(Init->getSourceLocation(), 254 SM, LangOpts, tok::colon); 255 } 256 if (Constructor->init_begin() != Constructor->init_end()) 257 return std::nullopt; 258 } 259 if (Function->isDeleted()) { 260 SourceLocation FunctionEnd = Function->getSourceRange().getEnd(); 261 return utils::lexer::findNextAnyTokenKind(FunctionEnd, SM, LangOpts, 262 tok::equal, tok::equal); 263 } 264 const Stmt *Body = Function->getBody(); 265 if (!Body) 266 return std::nullopt; 267 268 return Body->getBeginLoc(); 269 } 270 271 bool isPrimaryExpression(const Expr *Expression) { 272 // This function is an incomplete approximation of checking whether 273 // an Expr is a primary expression. In particular, if this function 274 // returns true, the expression is a primary expression. The converse 275 // is not necessarily true. 276 277 if (const auto *Cast = dyn_cast<ImplicitCastExpr>(Expression)) 278 Expression = Cast->getSubExprAsWritten(); 279 if (isa<ParenExpr, DependentScopeDeclRefExpr>(Expression)) 280 return true; 281 282 return false; 283 } 284 285 // Return the original source text of an enable_if_t condition, i.e., the 286 // first template argument). For example, in 287 // 'enable_if_t<FirstCondition || SecondCondition, AType>', the text 288 // the text 'FirstCondition || SecondCondition' is returned. 289 static std::optional<std::string> getConditionText(const Expr *ConditionExpr, 290 SourceRange ConditionRange, 291 ASTContext &Context) { 292 SourceManager &SM = Context.getSourceManager(); 293 const LangOptions &LangOpts = Context.getLangOpts(); 294 295 SourceLocation PrevTokenLoc = ConditionRange.getEnd(); 296 if (PrevTokenLoc.isInvalid()) 297 return std::nullopt; 298 299 const bool SkipComments = false; 300 Token PrevToken; 301 std::tie(PrevToken, PrevTokenLoc) = utils::lexer::getPreviousTokenAndStart( 302 PrevTokenLoc, SM, LangOpts, SkipComments); 303 bool EndsWithDoubleSlash = 304 PrevToken.is(tok::comment) && 305 Lexer::getSourceText(CharSourceRange::getCharRange( 306 PrevTokenLoc, PrevTokenLoc.getLocWithOffset(2)), 307 SM, LangOpts) == "//"; 308 309 bool Invalid = false; 310 llvm::StringRef ConditionText = Lexer::getSourceText( 311 CharSourceRange::getCharRange(ConditionRange), SM, LangOpts, &Invalid); 312 if (Invalid) 313 return std::nullopt; 314 315 auto AddParens = [&](llvm::StringRef Text) -> std::string { 316 if (isPrimaryExpression(ConditionExpr)) 317 return Text.str(); 318 return "(" + Text.str() + ")"; 319 }; 320 321 if (EndsWithDoubleSlash) 322 return AddParens(ConditionText); 323 return AddParens(ConditionText.trim()); 324 } 325 326 // Handle functions that return enable_if_t, e.g., 327 // template <...> 328 // enable_if_t<Condition, ReturnType> function(); 329 // 330 // Return a vector of FixItHints if the code can be replaced with 331 // a C++20 requires clause. In the example above, returns FixItHints 332 // to result in 333 // template <...> 334 // ReturnType function() requires Condition {} 335 static std::vector<FixItHint> handleReturnType(const FunctionDecl *Function, 336 const TypeLoc &ReturnType, 337 const EnableIfData &EnableIf, 338 ASTContext &Context) { 339 TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0); 340 341 SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc); 342 343 std::optional<std::string> ConditionText = getConditionText( 344 EnableCondition.getSourceExpression(), ConditionRange, Context); 345 if (!ConditionText) 346 return {}; 347 348 std::optional<StringRef> TypeText = getTypeText(Context, EnableIf.Loc); 349 if (!TypeText) 350 return {}; 351 352 SmallVector<const Expr *, 3> ExistingConstraints; 353 Function->getAssociatedConstraints(ExistingConstraints); 354 if (!ExistingConstraints.empty()) { 355 // FIXME - Support adding new constraints to existing ones. Do we need to 356 // consider subsumption? 357 return {}; 358 } 359 360 std::optional<SourceLocation> ConstraintInsertionLoc = 361 findInsertionForConstraint(Function, Context); 362 if (!ConstraintInsertionLoc) 363 return {}; 364 365 std::vector<FixItHint> FixIts; 366 FixIts.push_back(FixItHint::CreateReplacement( 367 CharSourceRange::getTokenRange(EnableIf.Outer.getSourceRange()), 368 *TypeText)); 369 FixIts.push_back(FixItHint::CreateInsertion( 370 *ConstraintInsertionLoc, "requires " + *ConditionText + " ")); 371 return FixIts; 372 } 373 374 // Handle enable_if_t in a trailing template parameter, e.g., 375 // template <..., enable_if_t<Condition, Type> = Type{}> 376 // ReturnType function(); 377 // 378 // Return a vector of FixItHints if the code can be replaced with 379 // a C++20 requires clause. In the example above, returns FixItHints 380 // to result in 381 // template <...> 382 // ReturnType function() requires Condition {} 383 static std::vector<FixItHint> 384 handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate, 385 const FunctionDecl *Function, 386 const Decl *LastTemplateParam, 387 const EnableIfData &EnableIf, ASTContext &Context) { 388 SourceManager &SM = Context.getSourceManager(); 389 const LangOptions &LangOpts = Context.getLangOpts(); 390 391 TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0); 392 393 SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc); 394 395 std::optional<std::string> ConditionText = getConditionText( 396 EnableCondition.getSourceExpression(), ConditionRange, Context); 397 if (!ConditionText) 398 return {}; 399 400 SmallVector<const Expr *, 3> ExistingConstraints; 401 Function->getAssociatedConstraints(ExistingConstraints); 402 if (!ExistingConstraints.empty()) { 403 // FIXME - Support adding new constraints to existing ones. Do we need to 404 // consider subsumption? 405 return {}; 406 } 407 408 SourceRange RemovalRange; 409 const TemplateParameterList *TemplateParams = 410 FunctionTemplate->getTemplateParameters(); 411 if (!TemplateParams || TemplateParams->size() == 0) 412 return {}; 413 414 if (TemplateParams->size() == 1) { 415 RemovalRange = 416 SourceRange(TemplateParams->getTemplateLoc(), 417 getRAngleFileLoc(SM, *TemplateParams).getLocWithOffset(1)); 418 } else { 419 RemovalRange = 420 SourceRange(utils::lexer::findPreviousTokenKind( 421 LastTemplateParam->getSourceRange().getBegin(), SM, 422 LangOpts, tok::comma), 423 getRAngleFileLoc(SM, *TemplateParams)); 424 } 425 426 std::optional<SourceLocation> ConstraintInsertionLoc = 427 findInsertionForConstraint(Function, Context); 428 if (!ConstraintInsertionLoc) 429 return {}; 430 431 std::vector<FixItHint> FixIts; 432 FixIts.push_back( 433 FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange))); 434 FixIts.push_back(FixItHint::CreateInsertion( 435 *ConstraintInsertionLoc, "requires " + *ConditionText + " ")); 436 return FixIts; 437 } 438 439 void UseConstraintsCheck::check(const MatchFinder::MatchResult &Result) { 440 const auto *FunctionTemplate = 441 Result.Nodes.getNodeAs<FunctionTemplateDecl>("functionTemplate"); 442 const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>("function"); 443 const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>("return"); 444 if (!FunctionTemplate || !Function || !ReturnType) 445 return; 446 447 // Check for 448 // 449 // Case 1. Return type of function 450 // 451 // template <...> 452 // enable_if_t<Condition, ReturnType>::type function() {} 453 // 454 // Case 2. Trailing template parameter 455 // 456 // template <..., enable_if_t<Condition, Type> = Type{}> 457 // ReturnType function() {} 458 // 459 // or 460 // 461 // template <..., typename = enable_if_t<Condition, void>> 462 // ReturnType function() {} 463 // 464 465 // Case 1. Return type of function 466 if (auto EnableIf = matchEnableIfSpecialization(*ReturnType)) { 467 diag(ReturnType->getBeginLoc(), 468 "use C++20 requires constraints instead of enable_if") 469 << handleReturnType(Function, *ReturnType, *EnableIf, *Result.Context); 470 return; 471 } 472 473 // Case 2. Trailing template parameter 474 if (auto [EnableIf, LastTemplateParam] = 475 matchTrailingTemplateParam(FunctionTemplate); 476 EnableIf && LastTemplateParam) { 477 diag(LastTemplateParam->getSourceRange().getBegin(), 478 "use C++20 requires constraints instead of enable_if") 479 << handleTrailingTemplateType(FunctionTemplate, Function, 480 LastTemplateParam, *EnableIf, 481 *Result.Context); 482 return; 483 } 484 } 485 486 } // namespace clang::tidy::modernize 487