1 //===--- UseConstraintsCheck.cpp - clang-tidy -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "UseConstraintsCheck.h" 10 #include "clang/AST/ASTContext.h" 11 #include "clang/ASTMatchers/ASTMatchFinder.h" 12 #include "clang/Lex/Lexer.h" 13 14 #include "../utils/LexerUtils.h" 15 16 #include <optional> 17 #include <utility> 18 19 using namespace clang::ast_matchers; 20 21 namespace clang::tidy::modernize { 22 23 struct EnableIfData { 24 TemplateSpecializationTypeLoc Loc; 25 TypeLoc Outer; 26 }; 27 28 namespace { 29 AST_MATCHER(FunctionDecl, hasOtherDeclarations) { 30 auto It = Node.redecls_begin(); 31 auto EndIt = Node.redecls_end(); 32 33 if (It == EndIt) 34 return false; 35 36 ++It; 37 return It != EndIt; 38 } 39 } // namespace 40 41 void UseConstraintsCheck::registerMatchers(MatchFinder *Finder) { 42 Finder->addMatcher( 43 functionTemplateDecl( 44 has(functionDecl(unless(hasOtherDeclarations()), isDefinition(), 45 hasReturnTypeLoc(typeLoc().bind("return"))) 46 .bind("function"))) 47 .bind("functionTemplate"), 48 this); 49 } 50 51 static std::optional<TemplateSpecializationTypeLoc> 52 matchEnableIfSpecializationImplTypename(TypeLoc TheType) { 53 if (const auto Dep = TheType.getAs<DependentNameTypeLoc>()) { 54 const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier(); 55 if (!Identifier || Identifier->getName() != "type" || 56 Dep.getTypePtr()->getKeyword() != ElaboratedTypeKeyword::Typename) { 57 return std::nullopt; 58 } 59 TheType = Dep.getQualifierLoc().getTypeLoc(); 60 } 61 62 if (const auto SpecializationLoc = 63 TheType.getAs<TemplateSpecializationTypeLoc>()) { 64 65 const auto *Specialization = 66 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr()); 67 if (!Specialization) 68 return std::nullopt; 69 70 const TemplateDecl *TD = 71 Specialization->getTemplateName().getAsTemplateDecl(); 72 if (!TD || TD->getName() != "enable_if") 73 return std::nullopt; 74 75 int NumArgs = SpecializationLoc.getNumArgs(); 76 if (NumArgs != 1 && NumArgs != 2) 77 return std::nullopt; 78 79 return SpecializationLoc; 80 } 81 return std::nullopt; 82 } 83 84 static std::optional<TemplateSpecializationTypeLoc> 85 matchEnableIfSpecializationImplTrait(TypeLoc TheType) { 86 if (const auto Elaborated = TheType.getAs<ElaboratedTypeLoc>()) 87 TheType = Elaborated.getNamedTypeLoc(); 88 89 if (const auto SpecializationLoc = 90 TheType.getAs<TemplateSpecializationTypeLoc>()) { 91 92 const auto *Specialization = 93 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr()); 94 if (!Specialization) 95 return std::nullopt; 96 97 const TemplateDecl *TD = 98 Specialization->getTemplateName().getAsTemplateDecl(); 99 if (!TD || TD->getName() != "enable_if_t") 100 return std::nullopt; 101 102 if (!Specialization->isTypeAlias()) 103 return std::nullopt; 104 105 if (const auto *AliasedType = 106 dyn_cast<DependentNameType>(Specialization->getAliasedType())) { 107 if (AliasedType->getIdentifier()->getName() != "type" || 108 AliasedType->getKeyword() != ElaboratedTypeKeyword::Typename) { 109 return std::nullopt; 110 } 111 } else { 112 return std::nullopt; 113 } 114 int NumArgs = SpecializationLoc.getNumArgs(); 115 if (NumArgs != 1 && NumArgs != 2) 116 return std::nullopt; 117 118 return SpecializationLoc; 119 } 120 return std::nullopt; 121 } 122 123 static std::optional<TemplateSpecializationTypeLoc> 124 matchEnableIfSpecializationImpl(TypeLoc TheType) { 125 if (auto EnableIf = matchEnableIfSpecializationImplTypename(TheType)) 126 return EnableIf; 127 return matchEnableIfSpecializationImplTrait(TheType); 128 } 129 130 static std::optional<EnableIfData> 131 matchEnableIfSpecialization(TypeLoc TheType) { 132 if (const auto Pointer = TheType.getAs<PointerTypeLoc>()) 133 TheType = Pointer.getPointeeLoc(); 134 else if (const auto Reference = TheType.getAs<ReferenceTypeLoc>()) 135 TheType = Reference.getPointeeLoc(); 136 if (const auto Qualified = TheType.getAs<QualifiedTypeLoc>()) 137 TheType = Qualified.getUnqualifiedLoc(); 138 139 if (auto EnableIf = matchEnableIfSpecializationImpl(TheType)) 140 return EnableIfData{std::move(*EnableIf), TheType}; 141 return std::nullopt; 142 } 143 144 static std::pair<std::optional<EnableIfData>, const Decl *> 145 matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate) { 146 // For non-type trailing param, match very specifically 147 // 'template <..., enable_if_type<Condition, Type> = Default>' where 148 // enable_if_type is 'enable_if' or 'enable_if_t'. E.g., 'template <typename 149 // T, enable_if_t<is_same_v<T, bool>, int*> = nullptr> 150 // 151 // Otherwise, match a trailing default type arg. 152 // E.g., 'template <typename T, typename = enable_if_t<is_same_v<T, bool>>>' 153 154 const TemplateParameterList *TemplateParams = 155 FunctionTemplate->getTemplateParameters(); 156 if (TemplateParams->size() == 0) 157 return {}; 158 159 const NamedDecl *LastParam = 160 TemplateParams->getParam(TemplateParams->size() - 1); 161 if (const auto *LastTemplateParam = 162 dyn_cast<NonTypeTemplateParmDecl>(LastParam)) { 163 164 if (!LastTemplateParam->hasDefaultArgument() || 165 !LastTemplateParam->getName().empty()) 166 return {}; 167 168 return {matchEnableIfSpecialization( 169 LastTemplateParam->getTypeSourceInfo()->getTypeLoc()), 170 LastTemplateParam}; 171 } 172 if (const auto *LastTemplateParam = 173 dyn_cast<TemplateTypeParmDecl>(LastParam)) { 174 if (LastTemplateParam->hasDefaultArgument() && 175 LastTemplateParam->getIdentifier() == nullptr) { 176 return {matchEnableIfSpecialization( 177 LastTemplateParam->getDefaultArgumentInfo()->getTypeLoc()), 178 LastTemplateParam}; 179 } 180 } 181 return {}; 182 } 183 184 template <typename T> 185 static SourceLocation getRAngleFileLoc(const SourceManager &SM, 186 const T &Element) { 187 // getFileLoc handles the case where the RAngle loc is part of a synthesized 188 // '>>', which ends up allocating a 'scratch space' buffer in the source 189 // manager. 190 return SM.getFileLoc(Element.getRAngleLoc()); 191 } 192 193 static SourceRange 194 getConditionRange(ASTContext &Context, 195 const TemplateSpecializationTypeLoc &EnableIf) { 196 // TemplateArgumentLoc's SourceRange End is the location of the last token 197 // (per UnqualifiedId docs). E.g., in `enable_if<AAA && BBB>`, the End 198 // location will be the first 'B' in 'BBB'. 199 const LangOptions &LangOpts = Context.getLangOpts(); 200 const SourceManager &SM = Context.getSourceManager(); 201 if (EnableIf.getNumArgs() > 1) { 202 TemplateArgumentLoc NextArg = EnableIf.getArgLoc(1); 203 return {EnableIf.getLAngleLoc().getLocWithOffset(1), 204 utils::lexer::findPreviousTokenKind( 205 NextArg.getSourceRange().getBegin(), SM, LangOpts, tok::comma)}; 206 } 207 208 return {EnableIf.getLAngleLoc().getLocWithOffset(1), 209 getRAngleFileLoc(SM, EnableIf)}; 210 } 211 212 static SourceRange getTypeRange(ASTContext &Context, 213 const TemplateSpecializationTypeLoc &EnableIf) { 214 TemplateArgumentLoc Arg = EnableIf.getArgLoc(1); 215 const LangOptions &LangOpts = Context.getLangOpts(); 216 const SourceManager &SM = Context.getSourceManager(); 217 return {utils::lexer::findPreviousTokenKind(Arg.getSourceRange().getBegin(), 218 SM, LangOpts, tok::comma) 219 .getLocWithOffset(1), 220 getRAngleFileLoc(SM, EnableIf)}; 221 } 222 223 // Returns the original source text of the second argument of a call to 224 // enable_if_t. E.g., in enable_if_t<Condition, TheType>, this function 225 // returns 'TheType'. 226 static std::optional<StringRef> 227 getTypeText(ASTContext &Context, 228 const TemplateSpecializationTypeLoc &EnableIf) { 229 if (EnableIf.getNumArgs() > 1) { 230 const LangOptions &LangOpts = Context.getLangOpts(); 231 const SourceManager &SM = Context.getSourceManager(); 232 bool Invalid = false; 233 StringRef Text = Lexer::getSourceText(CharSourceRange::getCharRange( 234 getTypeRange(Context, EnableIf)), 235 SM, LangOpts, &Invalid) 236 .trim(); 237 if (Invalid) 238 return std::nullopt; 239 240 return Text; 241 } 242 243 return "void"; 244 } 245 246 static std::optional<SourceLocation> 247 findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context) { 248 SourceManager &SM = Context.getSourceManager(); 249 const LangOptions &LangOpts = Context.getLangOpts(); 250 251 if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(Function)) { 252 for (const CXXCtorInitializer *Init : Constructor->inits()) { 253 if (Init->getSourceOrder() == 0) 254 return utils::lexer::findPreviousTokenKind(Init->getSourceLocation(), 255 SM, LangOpts, tok::colon); 256 } 257 if (!Constructor->inits().empty()) 258 return std::nullopt; 259 } 260 if (Function->isDeleted()) { 261 SourceLocation FunctionEnd = Function->getSourceRange().getEnd(); 262 return utils::lexer::findNextAnyTokenKind(FunctionEnd, SM, LangOpts, 263 tok::equal, tok::equal); 264 } 265 const Stmt *Body = Function->getBody(); 266 if (!Body) 267 return std::nullopt; 268 269 return Body->getBeginLoc(); 270 } 271 272 bool isPrimaryExpression(const Expr *Expression) { 273 // This function is an incomplete approximation of checking whether 274 // an Expr is a primary expression. In particular, if this function 275 // returns true, the expression is a primary expression. The converse 276 // is not necessarily true. 277 278 if (const auto *Cast = dyn_cast<ImplicitCastExpr>(Expression)) 279 Expression = Cast->getSubExprAsWritten(); 280 if (isa<ParenExpr, DependentScopeDeclRefExpr>(Expression)) 281 return true; 282 283 return false; 284 } 285 286 // Return the original source text of an enable_if_t condition, i.e., the 287 // first template argument). For example, in 288 // 'enable_if_t<FirstCondition || SecondCondition, AType>', the text 289 // the text 'FirstCondition || SecondCondition' is returned. 290 static std::optional<std::string> getConditionText(const Expr *ConditionExpr, 291 SourceRange ConditionRange, 292 ASTContext &Context) { 293 SourceManager &SM = Context.getSourceManager(); 294 const LangOptions &LangOpts = Context.getLangOpts(); 295 296 SourceLocation PrevTokenLoc = ConditionRange.getEnd(); 297 if (PrevTokenLoc.isInvalid()) 298 return std::nullopt; 299 300 const bool SkipComments = false; 301 Token PrevToken; 302 std::tie(PrevToken, PrevTokenLoc) = utils::lexer::getPreviousTokenAndStart( 303 PrevTokenLoc, SM, LangOpts, SkipComments); 304 bool EndsWithDoubleSlash = 305 PrevToken.is(tok::comment) && 306 Lexer::getSourceText(CharSourceRange::getCharRange( 307 PrevTokenLoc, PrevTokenLoc.getLocWithOffset(2)), 308 SM, LangOpts) == "//"; 309 310 bool Invalid = false; 311 llvm::StringRef ConditionText = Lexer::getSourceText( 312 CharSourceRange::getCharRange(ConditionRange), SM, LangOpts, &Invalid); 313 if (Invalid) 314 return std::nullopt; 315 316 auto AddParens = [&](llvm::StringRef Text) -> std::string { 317 if (isPrimaryExpression(ConditionExpr)) 318 return Text.str(); 319 return "(" + Text.str() + ")"; 320 }; 321 322 if (EndsWithDoubleSlash) 323 return AddParens(ConditionText); 324 return AddParens(ConditionText.trim()); 325 } 326 327 // Handle functions that return enable_if_t, e.g., 328 // template <...> 329 // enable_if_t<Condition, ReturnType> function(); 330 // 331 // Return a vector of FixItHints if the code can be replaced with 332 // a C++20 requires clause. In the example above, returns FixItHints 333 // to result in 334 // template <...> 335 // ReturnType function() requires Condition {} 336 static std::vector<FixItHint> handleReturnType(const FunctionDecl *Function, 337 const TypeLoc &ReturnType, 338 const EnableIfData &EnableIf, 339 ASTContext &Context) { 340 TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0); 341 342 SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc); 343 344 std::optional<std::string> ConditionText = getConditionText( 345 EnableCondition.getSourceExpression(), ConditionRange, Context); 346 if (!ConditionText) 347 return {}; 348 349 std::optional<StringRef> TypeText = getTypeText(Context, EnableIf.Loc); 350 if (!TypeText) 351 return {}; 352 353 SmallVector<const Expr *, 3> ExistingConstraints; 354 Function->getAssociatedConstraints(ExistingConstraints); 355 if (!ExistingConstraints.empty()) { 356 // FIXME - Support adding new constraints to existing ones. Do we need to 357 // consider subsumption? 358 return {}; 359 } 360 361 std::optional<SourceLocation> ConstraintInsertionLoc = 362 findInsertionForConstraint(Function, Context); 363 if (!ConstraintInsertionLoc) 364 return {}; 365 366 std::vector<FixItHint> FixIts; 367 FixIts.push_back(FixItHint::CreateReplacement( 368 CharSourceRange::getTokenRange(EnableIf.Outer.getSourceRange()), 369 *TypeText)); 370 FixIts.push_back(FixItHint::CreateInsertion( 371 *ConstraintInsertionLoc, "requires " + *ConditionText + " ")); 372 return FixIts; 373 } 374 375 // Handle enable_if_t in a trailing template parameter, e.g., 376 // template <..., enable_if_t<Condition, Type> = Type{}> 377 // ReturnType function(); 378 // 379 // Return a vector of FixItHints if the code can be replaced with 380 // a C++20 requires clause. In the example above, returns FixItHints 381 // to result in 382 // template <...> 383 // ReturnType function() requires Condition {} 384 static std::vector<FixItHint> 385 handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate, 386 const FunctionDecl *Function, 387 const Decl *LastTemplateParam, 388 const EnableIfData &EnableIf, ASTContext &Context) { 389 SourceManager &SM = Context.getSourceManager(); 390 const LangOptions &LangOpts = Context.getLangOpts(); 391 392 TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0); 393 394 SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc); 395 396 std::optional<std::string> ConditionText = getConditionText( 397 EnableCondition.getSourceExpression(), ConditionRange, Context); 398 if (!ConditionText) 399 return {}; 400 401 SmallVector<const Expr *, 3> ExistingConstraints; 402 Function->getAssociatedConstraints(ExistingConstraints); 403 if (!ExistingConstraints.empty()) { 404 // FIXME - Support adding new constraints to existing ones. Do we need to 405 // consider subsumption? 406 return {}; 407 } 408 409 SourceRange RemovalRange; 410 const TemplateParameterList *TemplateParams = 411 FunctionTemplate->getTemplateParameters(); 412 if (!TemplateParams || TemplateParams->size() == 0) 413 return {}; 414 415 if (TemplateParams->size() == 1) { 416 RemovalRange = 417 SourceRange(TemplateParams->getTemplateLoc(), 418 getRAngleFileLoc(SM, *TemplateParams).getLocWithOffset(1)); 419 } else { 420 RemovalRange = 421 SourceRange(utils::lexer::findPreviousTokenKind( 422 LastTemplateParam->getSourceRange().getBegin(), SM, 423 LangOpts, tok::comma), 424 getRAngleFileLoc(SM, *TemplateParams)); 425 } 426 427 std::optional<SourceLocation> ConstraintInsertionLoc = 428 findInsertionForConstraint(Function, Context); 429 if (!ConstraintInsertionLoc) 430 return {}; 431 432 std::vector<FixItHint> FixIts; 433 FixIts.push_back( 434 FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange))); 435 FixIts.push_back(FixItHint::CreateInsertion( 436 *ConstraintInsertionLoc, "requires " + *ConditionText + " ")); 437 return FixIts; 438 } 439 440 void UseConstraintsCheck::check(const MatchFinder::MatchResult &Result) { 441 const auto *FunctionTemplate = 442 Result.Nodes.getNodeAs<FunctionTemplateDecl>("functionTemplate"); 443 const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>("function"); 444 const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>("return"); 445 if (!FunctionTemplate || !Function || !ReturnType) 446 return; 447 448 // Check for 449 // 450 // Case 1. Return type of function 451 // 452 // template <...> 453 // enable_if_t<Condition, ReturnType>::type function() {} 454 // 455 // Case 2. Trailing template parameter 456 // 457 // template <..., enable_if_t<Condition, Type> = Type{}> 458 // ReturnType function() {} 459 // 460 // or 461 // 462 // template <..., typename = enable_if_t<Condition, void>> 463 // ReturnType function() {} 464 // 465 466 // Case 1. Return type of function 467 if (auto EnableIf = matchEnableIfSpecialization(*ReturnType)) { 468 diag(ReturnType->getBeginLoc(), 469 "use C++20 requires constraints instead of enable_if") 470 << handleReturnType(Function, *ReturnType, *EnableIf, *Result.Context); 471 return; 472 } 473 474 // Case 2. Trailing template parameter 475 if (auto [EnableIf, LastTemplateParam] = 476 matchTrailingTemplateParam(FunctionTemplate); 477 EnableIf && LastTemplateParam) { 478 diag(LastTemplateParam->getSourceRange().getBegin(), 479 "use C++20 requires constraints instead of enable_if") 480 << handleTrailingTemplateType(FunctionTemplate, Function, 481 LastTemplateParam, *EnableIf, 482 *Result.Context); 483 return; 484 } 485 } 486 487 } // namespace clang::tidy::modernize 488