1 //===--- UseConstraintsCheck.cpp - clang-tidy -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "UseConstraintsCheck.h" 10 #include "clang/AST/ASTContext.h" 11 #include "clang/ASTMatchers/ASTMatchFinder.h" 12 #include "clang/Lex/Lexer.h" 13 14 #include "../utils/LexerUtils.h" 15 16 #include <optional> 17 #include <utility> 18 19 using namespace clang::ast_matchers; 20 21 namespace clang::tidy::modernize { 22 23 struct EnableIfData { 24 TemplateSpecializationTypeLoc Loc; 25 TypeLoc Outer; 26 }; 27 28 namespace { 29 AST_MATCHER(FunctionDecl, hasOtherDeclarations) { 30 auto It = Node.redecls_begin(); 31 auto EndIt = Node.redecls_end(); 32 33 if (It == EndIt) 34 return false; 35 36 ++It; 37 return It != EndIt; 38 } 39 } // namespace 40 41 void UseConstraintsCheck::registerMatchers(MatchFinder *Finder) { 42 Finder->addMatcher( 43 functionTemplateDecl( 44 // Skip external libraries included as system headers 45 unless(isExpansionInSystemHeader()), 46 has(functionDecl(unless(hasOtherDeclarations()), isDefinition(), 47 hasReturnTypeLoc(typeLoc().bind("return"))) 48 .bind("function"))) 49 .bind("functionTemplate"), 50 this); 51 } 52 53 static std::optional<TemplateSpecializationTypeLoc> 54 matchEnableIfSpecializationImplTypename(TypeLoc TheType) { 55 if (const auto Dep = TheType.getAs<DependentNameTypeLoc>()) { 56 const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier(); 57 if (!Identifier || Identifier->getName() != "type" || 58 Dep.getTypePtr()->getKeyword() != ElaboratedTypeKeyword::Typename) { 59 return std::nullopt; 60 } 61 TheType = Dep.getQualifierLoc().getTypeLoc(); 62 if (TheType.isNull()) 63 return std::nullopt; 64 } 65 66 if (const auto SpecializationLoc = 67 TheType.getAs<TemplateSpecializationTypeLoc>()) { 68 69 const auto *Specialization = 70 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr()); 71 if (!Specialization) 72 return std::nullopt; 73 74 const TemplateDecl *TD = 75 Specialization->getTemplateName().getAsTemplateDecl(); 76 if (!TD || TD->getName() != "enable_if") 77 return std::nullopt; 78 79 int NumArgs = SpecializationLoc.getNumArgs(); 80 if (NumArgs != 1 && NumArgs != 2) 81 return std::nullopt; 82 83 return SpecializationLoc; 84 } 85 return std::nullopt; 86 } 87 88 static std::optional<TemplateSpecializationTypeLoc> 89 matchEnableIfSpecializationImplTrait(TypeLoc TheType) { 90 if (const auto Elaborated = TheType.getAs<ElaboratedTypeLoc>()) 91 TheType = Elaborated.getNamedTypeLoc(); 92 93 if (const auto SpecializationLoc = 94 TheType.getAs<TemplateSpecializationTypeLoc>()) { 95 96 const auto *Specialization = 97 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr()); 98 if (!Specialization) 99 return std::nullopt; 100 101 const TemplateDecl *TD = 102 Specialization->getTemplateName().getAsTemplateDecl(); 103 if (!TD || TD->getName() != "enable_if_t") 104 return std::nullopt; 105 106 if (!Specialization->isTypeAlias()) 107 return std::nullopt; 108 109 if (const auto *AliasedType = 110 dyn_cast<DependentNameType>(Specialization->getAliasedType())) { 111 if (AliasedType->getIdentifier()->getName() != "type" || 112 AliasedType->getKeyword() != ElaboratedTypeKeyword::Typename) { 113 return std::nullopt; 114 } 115 } else { 116 return std::nullopt; 117 } 118 int NumArgs = SpecializationLoc.getNumArgs(); 119 if (NumArgs != 1 && NumArgs != 2) 120 return std::nullopt; 121 122 return SpecializationLoc; 123 } 124 return std::nullopt; 125 } 126 127 static std::optional<TemplateSpecializationTypeLoc> 128 matchEnableIfSpecializationImpl(TypeLoc TheType) { 129 if (auto EnableIf = matchEnableIfSpecializationImplTypename(TheType)) 130 return EnableIf; 131 return matchEnableIfSpecializationImplTrait(TheType); 132 } 133 134 static std::optional<EnableIfData> 135 matchEnableIfSpecialization(TypeLoc TheType) { 136 if (const auto Pointer = TheType.getAs<PointerTypeLoc>()) 137 TheType = Pointer.getPointeeLoc(); 138 else if (const auto Reference = TheType.getAs<ReferenceTypeLoc>()) 139 TheType = Reference.getPointeeLoc(); 140 if (const auto Qualified = TheType.getAs<QualifiedTypeLoc>()) 141 TheType = Qualified.getUnqualifiedLoc(); 142 143 if (auto EnableIf = matchEnableIfSpecializationImpl(TheType)) 144 return EnableIfData{std::move(*EnableIf), TheType}; 145 return std::nullopt; 146 } 147 148 static std::pair<std::optional<EnableIfData>, const Decl *> 149 matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate) { 150 // For non-type trailing param, match very specifically 151 // 'template <..., enable_if_type<Condition, Type> = Default>' where 152 // enable_if_type is 'enable_if' or 'enable_if_t'. E.g., 'template <typename 153 // T, enable_if_t<is_same_v<T, bool>, int*> = nullptr> 154 // 155 // Otherwise, match a trailing default type arg. 156 // E.g., 'template <typename T, typename = enable_if_t<is_same_v<T, bool>>>' 157 158 const TemplateParameterList *TemplateParams = 159 FunctionTemplate->getTemplateParameters(); 160 if (TemplateParams->size() == 0) 161 return {}; 162 163 const NamedDecl *LastParam = 164 TemplateParams->getParam(TemplateParams->size() - 1); 165 if (const auto *LastTemplateParam = 166 dyn_cast<NonTypeTemplateParmDecl>(LastParam)) { 167 168 if (!LastTemplateParam->hasDefaultArgument() || 169 !LastTemplateParam->getName().empty()) 170 return {}; 171 172 return {matchEnableIfSpecialization( 173 LastTemplateParam->getTypeSourceInfo()->getTypeLoc()), 174 LastTemplateParam}; 175 } 176 if (const auto *LastTemplateParam = 177 dyn_cast<TemplateTypeParmDecl>(LastParam)) { 178 if (LastTemplateParam->hasDefaultArgument() && 179 LastTemplateParam->getIdentifier() == nullptr) { 180 return {matchEnableIfSpecialization( 181 LastTemplateParam->getDefaultArgumentInfo()->getTypeLoc()), 182 LastTemplateParam}; 183 } 184 } 185 return {}; 186 } 187 188 template <typename T> 189 static SourceLocation getRAngleFileLoc(const SourceManager &SM, 190 const T &Element) { 191 // getFileLoc handles the case where the RAngle loc is part of a synthesized 192 // '>>', which ends up allocating a 'scratch space' buffer in the source 193 // manager. 194 return SM.getFileLoc(Element.getRAngleLoc()); 195 } 196 197 static SourceRange 198 getConditionRange(ASTContext &Context, 199 const TemplateSpecializationTypeLoc &EnableIf) { 200 // TemplateArgumentLoc's SourceRange End is the location of the last token 201 // (per UnqualifiedId docs). E.g., in `enable_if<AAA && BBB>`, the End 202 // location will be the first 'B' in 'BBB'. 203 const LangOptions &LangOpts = Context.getLangOpts(); 204 const SourceManager &SM = Context.getSourceManager(); 205 if (EnableIf.getNumArgs() > 1) { 206 TemplateArgumentLoc NextArg = EnableIf.getArgLoc(1); 207 return {EnableIf.getLAngleLoc().getLocWithOffset(1), 208 utils::lexer::findPreviousTokenKind( 209 NextArg.getSourceRange().getBegin(), SM, LangOpts, tok::comma)}; 210 } 211 212 return {EnableIf.getLAngleLoc().getLocWithOffset(1), 213 getRAngleFileLoc(SM, EnableIf)}; 214 } 215 216 static SourceRange getTypeRange(ASTContext &Context, 217 const TemplateSpecializationTypeLoc &EnableIf) { 218 TemplateArgumentLoc Arg = EnableIf.getArgLoc(1); 219 const LangOptions &LangOpts = Context.getLangOpts(); 220 const SourceManager &SM = Context.getSourceManager(); 221 return {utils::lexer::findPreviousTokenKind(Arg.getSourceRange().getBegin(), 222 SM, LangOpts, tok::comma) 223 .getLocWithOffset(1), 224 getRAngleFileLoc(SM, EnableIf)}; 225 } 226 227 // Returns the original source text of the second argument of a call to 228 // enable_if_t. E.g., in enable_if_t<Condition, TheType>, this function 229 // returns 'TheType'. 230 static std::optional<StringRef> 231 getTypeText(ASTContext &Context, 232 const TemplateSpecializationTypeLoc &EnableIf) { 233 if (EnableIf.getNumArgs() > 1) { 234 const LangOptions &LangOpts = Context.getLangOpts(); 235 const SourceManager &SM = Context.getSourceManager(); 236 bool Invalid = false; 237 StringRef Text = Lexer::getSourceText(CharSourceRange::getCharRange( 238 getTypeRange(Context, EnableIf)), 239 SM, LangOpts, &Invalid) 240 .trim(); 241 if (Invalid) 242 return std::nullopt; 243 244 return Text; 245 } 246 247 return "void"; 248 } 249 250 static std::optional<SourceLocation> 251 findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context) { 252 SourceManager &SM = Context.getSourceManager(); 253 const LangOptions &LangOpts = Context.getLangOpts(); 254 255 if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(Function)) { 256 for (const CXXCtorInitializer *Init : Constructor->inits()) { 257 if (Init->getSourceOrder() == 0) 258 return utils::lexer::findPreviousTokenKind(Init->getSourceLocation(), 259 SM, LangOpts, tok::colon); 260 } 261 if (!Constructor->inits().empty()) 262 return std::nullopt; 263 } 264 if (Function->isDeleted()) { 265 SourceLocation FunctionEnd = Function->getSourceRange().getEnd(); 266 return utils::lexer::findNextAnyTokenKind(FunctionEnd, SM, LangOpts, 267 tok::equal, tok::equal); 268 } 269 const Stmt *Body = Function->getBody(); 270 if (!Body) 271 return std::nullopt; 272 273 return Body->getBeginLoc(); 274 } 275 276 bool isPrimaryExpression(const Expr *Expression) { 277 // This function is an incomplete approximation of checking whether 278 // an Expr is a primary expression. In particular, if this function 279 // returns true, the expression is a primary expression. The converse 280 // is not necessarily true. 281 282 if (const auto *Cast = dyn_cast<ImplicitCastExpr>(Expression)) 283 Expression = Cast->getSubExprAsWritten(); 284 if (isa<ParenExpr, DependentScopeDeclRefExpr>(Expression)) 285 return true; 286 287 return false; 288 } 289 290 // Return the original source text of an enable_if_t condition, i.e., the 291 // first template argument). For example, in 292 // 'enable_if_t<FirstCondition || SecondCondition, AType>', the text 293 // the text 'FirstCondition || SecondCondition' is returned. 294 static std::optional<std::string> getConditionText(const Expr *ConditionExpr, 295 SourceRange ConditionRange, 296 ASTContext &Context) { 297 SourceManager &SM = Context.getSourceManager(); 298 const LangOptions &LangOpts = Context.getLangOpts(); 299 300 SourceLocation PrevTokenLoc = ConditionRange.getEnd(); 301 if (PrevTokenLoc.isInvalid()) 302 return std::nullopt; 303 304 const bool SkipComments = false; 305 Token PrevToken; 306 std::tie(PrevToken, PrevTokenLoc) = utils::lexer::getPreviousTokenAndStart( 307 PrevTokenLoc, SM, LangOpts, SkipComments); 308 bool EndsWithDoubleSlash = 309 PrevToken.is(tok::comment) && 310 Lexer::getSourceText(CharSourceRange::getCharRange( 311 PrevTokenLoc, PrevTokenLoc.getLocWithOffset(2)), 312 SM, LangOpts) == "//"; 313 314 bool Invalid = false; 315 llvm::StringRef ConditionText = Lexer::getSourceText( 316 CharSourceRange::getCharRange(ConditionRange), SM, LangOpts, &Invalid); 317 if (Invalid) 318 return std::nullopt; 319 320 auto AddParens = [&](llvm::StringRef Text) -> std::string { 321 if (isPrimaryExpression(ConditionExpr)) 322 return Text.str(); 323 return "(" + Text.str() + ")"; 324 }; 325 326 if (EndsWithDoubleSlash) 327 return AddParens(ConditionText); 328 return AddParens(ConditionText.trim()); 329 } 330 331 // Handle functions that return enable_if_t, e.g., 332 // template <...> 333 // enable_if_t<Condition, ReturnType> function(); 334 // 335 // Return a vector of FixItHints if the code can be replaced with 336 // a C++20 requires clause. In the example above, returns FixItHints 337 // to result in 338 // template <...> 339 // ReturnType function() requires Condition {} 340 static std::vector<FixItHint> handleReturnType(const FunctionDecl *Function, 341 const TypeLoc &ReturnType, 342 const EnableIfData &EnableIf, 343 ASTContext &Context) { 344 TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0); 345 346 SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc); 347 348 std::optional<std::string> ConditionText = getConditionText( 349 EnableCondition.getSourceExpression(), ConditionRange, Context); 350 if (!ConditionText) 351 return {}; 352 353 std::optional<StringRef> TypeText = getTypeText(Context, EnableIf.Loc); 354 if (!TypeText) 355 return {}; 356 357 SmallVector<const Expr *, 3> ExistingConstraints; 358 Function->getAssociatedConstraints(ExistingConstraints); 359 if (!ExistingConstraints.empty()) { 360 // FIXME - Support adding new constraints to existing ones. Do we need to 361 // consider subsumption? 362 return {}; 363 } 364 365 std::optional<SourceLocation> ConstraintInsertionLoc = 366 findInsertionForConstraint(Function, Context); 367 if (!ConstraintInsertionLoc) 368 return {}; 369 370 std::vector<FixItHint> FixIts; 371 FixIts.push_back(FixItHint::CreateReplacement( 372 CharSourceRange::getTokenRange(EnableIf.Outer.getSourceRange()), 373 *TypeText)); 374 FixIts.push_back(FixItHint::CreateInsertion( 375 *ConstraintInsertionLoc, "requires " + *ConditionText + " ")); 376 return FixIts; 377 } 378 379 // Handle enable_if_t in a trailing template parameter, e.g., 380 // template <..., enable_if_t<Condition, Type> = Type{}> 381 // ReturnType function(); 382 // 383 // Return a vector of FixItHints if the code can be replaced with 384 // a C++20 requires clause. In the example above, returns FixItHints 385 // to result in 386 // template <...> 387 // ReturnType function() requires Condition {} 388 static std::vector<FixItHint> 389 handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate, 390 const FunctionDecl *Function, 391 const Decl *LastTemplateParam, 392 const EnableIfData &EnableIf, ASTContext &Context) { 393 SourceManager &SM = Context.getSourceManager(); 394 const LangOptions &LangOpts = Context.getLangOpts(); 395 396 TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0); 397 398 SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc); 399 400 std::optional<std::string> ConditionText = getConditionText( 401 EnableCondition.getSourceExpression(), ConditionRange, Context); 402 if (!ConditionText) 403 return {}; 404 405 SmallVector<const Expr *, 3> ExistingConstraints; 406 Function->getAssociatedConstraints(ExistingConstraints); 407 if (!ExistingConstraints.empty()) { 408 // FIXME - Support adding new constraints to existing ones. Do we need to 409 // consider subsumption? 410 return {}; 411 } 412 413 SourceRange RemovalRange; 414 const TemplateParameterList *TemplateParams = 415 FunctionTemplate->getTemplateParameters(); 416 if (!TemplateParams || TemplateParams->size() == 0) 417 return {}; 418 419 if (TemplateParams->size() == 1) { 420 RemovalRange = 421 SourceRange(TemplateParams->getTemplateLoc(), 422 getRAngleFileLoc(SM, *TemplateParams).getLocWithOffset(1)); 423 } else { 424 RemovalRange = 425 SourceRange(utils::lexer::findPreviousTokenKind( 426 LastTemplateParam->getSourceRange().getBegin(), SM, 427 LangOpts, tok::comma), 428 getRAngleFileLoc(SM, *TemplateParams)); 429 } 430 431 std::optional<SourceLocation> ConstraintInsertionLoc = 432 findInsertionForConstraint(Function, Context); 433 if (!ConstraintInsertionLoc) 434 return {}; 435 436 std::vector<FixItHint> FixIts; 437 FixIts.push_back( 438 FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange))); 439 FixIts.push_back(FixItHint::CreateInsertion( 440 *ConstraintInsertionLoc, "requires " + *ConditionText + " ")); 441 return FixIts; 442 } 443 444 void UseConstraintsCheck::check(const MatchFinder::MatchResult &Result) { 445 const auto *FunctionTemplate = 446 Result.Nodes.getNodeAs<FunctionTemplateDecl>("functionTemplate"); 447 const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>("function"); 448 const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>("return"); 449 if (!FunctionTemplate || !Function || !ReturnType) 450 return; 451 452 // Check for 453 // 454 // Case 1. Return type of function 455 // 456 // template <...> 457 // enable_if_t<Condition, ReturnType>::type function() {} 458 // 459 // Case 2. Trailing template parameter 460 // 461 // template <..., enable_if_t<Condition, Type> = Type{}> 462 // ReturnType function() {} 463 // 464 // or 465 // 466 // template <..., typename = enable_if_t<Condition, void>> 467 // ReturnType function() {} 468 // 469 470 // Case 1. Return type of function 471 if (auto EnableIf = matchEnableIfSpecialization(*ReturnType)) { 472 diag(ReturnType->getBeginLoc(), 473 "use C++20 requires constraints instead of enable_if") 474 << handleReturnType(Function, *ReturnType, *EnableIf, *Result.Context); 475 return; 476 } 477 478 // Case 2. Trailing template parameter 479 if (auto [EnableIf, LastTemplateParam] = 480 matchTrailingTemplateParam(FunctionTemplate); 481 EnableIf && LastTemplateParam) { 482 diag(LastTemplateParam->getSourceRange().getBegin(), 483 "use C++20 requires constraints instead of enable_if") 484 << handleTrailingTemplateType(FunctionTemplate, Function, 485 LastTemplateParam, *EnableIf, 486 *Result.Context); 487 return; 488 } 489 } 490 491 } // namespace clang::tidy::modernize 492