//===--- ExtractFunction.cpp -------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Extracts statements to a new function and replaces the statements with a
// call to the new function.
// Before:
//   void f(int a) {
//     [[if(a < 5)
//       a = 5;]]
//   }
// After:
//   void extracted(int &a) {
//     if(a < 5)
//       a = 5;
//   }
//   void f(int a) {
//     extracted(a);
//   }
//
// - Only extract statements
// - Extracts from non-templated functions and methods only (never from inside
//   lambdas).
// - Parameters keep the qualifiers of the original declaration (const only if
//   the declaration was const)
// - Always passed by l-value reference
// - Void return type, unless the extracted code always returns; then the
//   return type of the enclosing function is reused.
// - Cannot extract declarations that will be needed in the original function
//   after extraction.
// - Checks for broken control flow (break/continue without loop/switch)
//
// 1. ExtractFunction is the tweak subclass
//    - Prepare does a basic analysis of the selection, plus a hoisting check
//      whose cost is proportional to the size of the enclosing function.
//      Successful prepare doesn't always mean we can apply the tweak.
//    - Apply does a more detailed analysis and can be slower. In case of
//      failure, we let the user know that we are unable to perform extraction.
// 2. ExtractionZone stores information about the range being extracted and the
//    enclosing function.
// 3. NewFunction stores properties of the extracted function and provides
//    methods for rendering it.
// 4. CapturedZoneInfo uses a RecursiveASTVisitor to capture information about
//    the extraction like declarations, existing return statements, etc.
// 5. getExtractedFunction is responsible for analyzing the CapturedZoneInfo and
//    creating a NewFunction.
47 //===----------------------------------------------------------------------===// 48 49 #include "AST.h" 50 #include "FindTarget.h" 51 #include "ParsedAST.h" 52 #include "Selection.h" 53 #include "SourceCode.h" 54 #include "refactor/Tweak.h" 55 #include "support/Logger.h" 56 #include "clang/AST/ASTContext.h" 57 #include "clang/AST/Decl.h" 58 #include "clang/AST/DeclBase.h" 59 #include "clang/AST/ExprCXX.h" 60 #include "clang/AST/NestedNameSpecifier.h" 61 #include "clang/AST/RecursiveASTVisitor.h" 62 #include "clang/AST/Stmt.h" 63 #include "clang/Basic/LangOptions.h" 64 #include "clang/Basic/SourceLocation.h" 65 #include "clang/Basic/SourceManager.h" 66 #include "clang/Tooling/Core/Replacement.h" 67 #include "clang/Tooling/Refactoring/Extract/SourceExtraction.h" 68 #include "llvm/ADT/STLExtras.h" 69 #include "llvm/ADT/SmallSet.h" 70 #include "llvm/ADT/SmallVector.h" 71 #include "llvm/ADT/StringRef.h" 72 #include "llvm/Support/Casting.h" 73 #include "llvm/Support/Error.h" 74 #include <optional> 75 76 namespace clang { 77 namespace clangd { 78 namespace { 79 80 using Node = SelectionTree::Node; 81 82 // ExtractionZone is the part of code that is being extracted. 83 // EnclosingFunction is the function/method inside which the zone lies. 84 // We split the file into 4 parts relative to extraction zone. 85 enum class ZoneRelative { 86 Before, // Before Zone and inside EnclosingFunction. 87 Inside, // Inside Zone. 88 After, // After Zone and inside EnclosingFunction. 89 OutsideFunc // Outside EnclosingFunction. 90 }; 91 92 enum FunctionDeclKind { 93 InlineDefinition, 94 ForwardDeclaration, 95 OutOfLineDefinition 96 }; 97 98 // A RootStmt is a statement that's fully selected including all its children 99 // and its parent is unselected. 100 // Check if a node is a root statement. 101 bool isRootStmt(const Node *N) { 102 if (!N->ASTNode.get<Stmt>()) 103 return false; 104 // Root statement cannot be partially selected. 
105 if (N->Selected == SelectionTree::Partial) 106 return false; 107 // A DeclStmt can be an unselected RootStmt since VarDecls claim the entire 108 // selection range in selectionTree. Additionally, a CXXOperatorCallExpr of a 109 // binary operation can be unselected because its children claim the entire 110 // selection range in the selection tree (e.g. <<). 111 if (N->Selected == SelectionTree::Unselected && !N->ASTNode.get<DeclStmt>() && 112 !N->ASTNode.get<CXXOperatorCallExpr>()) 113 return false; 114 return true; 115 } 116 117 // Returns the (unselected) parent of all RootStmts given the commonAncestor. 118 // Returns null if: 119 // 1. any node is partially selected 120 // 2. If all completely selected nodes don't have the same common parent 121 // 3. Any child of Parent isn't a RootStmt. 122 // Returns null if any child is not a RootStmt. 123 // We only support extraction of RootStmts since it allows us to extract without 124 // having to change the selection range. Also, this means that any scope that 125 // begins in selection range, ends in selection range and any scope that begins 126 // outside the selection range, ends outside as well. 127 const Node *getParentOfRootStmts(const Node *CommonAnc) { 128 if (!CommonAnc) 129 return nullptr; 130 const Node *Parent = nullptr; 131 switch (CommonAnc->Selected) { 132 case SelectionTree::Selection::Unselected: 133 // Typically a block, with the { and } unselected, could also be ForStmt etc 134 // Ensure all Children are RootStmts. 135 Parent = CommonAnc; 136 break; 137 case SelectionTree::Selection::Partial: 138 // Only a fully-selected single statement can be selected. 139 return nullptr; 140 case SelectionTree::Selection::Complete: 141 // If the Common Ancestor is completely selected, then it's a root statement 142 // and its parent will be unselected. 143 Parent = CommonAnc->Parent; 144 // If parent is a DeclStmt, even though it's unselected, we consider it a 145 // root statement and return its parent. 
This is done because the VarDecls 146 // claim the entire selection range of the Declaration and DeclStmt is 147 // always unselected. 148 if (Parent->ASTNode.get<DeclStmt>()) 149 Parent = Parent->Parent; 150 break; 151 } 152 // Ensure all Children are RootStmts. 153 return llvm::all_of(Parent->Children, isRootStmt) ? Parent : nullptr; 154 } 155 156 // The ExtractionZone class forms a view of the code wrt Zone. 157 struct ExtractionZone { 158 // Parent of RootStatements being extracted. 159 const Node *Parent = nullptr; 160 // The half-open file range of the code being extracted. 161 SourceRange ZoneRange; 162 // The function inside which our zone resides. 163 const FunctionDecl *EnclosingFunction = nullptr; 164 // The half-open file range of the enclosing function. 165 SourceRange EnclosingFuncRange; 166 // Set of statements that form the ExtractionZone. 167 llvm::DenseSet<const Stmt *> RootStmts; 168 169 SourceLocation getInsertionPoint() const { 170 return EnclosingFuncRange.getBegin(); 171 } 172 bool isRootStmt(const Stmt *S) const; 173 // The last root statement is important to decide where we need to insert a 174 // semicolon after the extraction. 175 const Node *getLastRootStmt() const { return Parent->Children.back(); } 176 177 // Checks if declarations inside extraction zone are accessed afterwards. 178 // 179 // This performs a partial AST traversal proportional to the size of the 180 // enclosing function, so it is possibly expensive. 181 bool requiresHoisting(const SourceManager &SM, 182 const HeuristicResolver *Resolver) const { 183 // First find all the declarations that happened inside extraction zone. 184 llvm::SmallSet<const Decl *, 1> DeclsInExtZone; 185 for (auto *RootStmt : RootStmts) { 186 findExplicitReferences( 187 RootStmt, 188 [&DeclsInExtZone](const ReferenceLoc &Loc) { 189 if (!Loc.IsDecl) 190 return; 191 DeclsInExtZone.insert(Loc.Targets.front()); 192 }, 193 Resolver); 194 } 195 // Early exit without performing expensive traversal below. 
196 if (DeclsInExtZone.empty()) 197 return false; 198 // Then make sure they are not used outside the zone. 199 for (const auto *S : EnclosingFunction->getBody()->children()) { 200 if (SM.isBeforeInTranslationUnit(S->getSourceRange().getEnd(), 201 ZoneRange.getEnd())) 202 continue; 203 bool HasPostUse = false; 204 findExplicitReferences( 205 S, 206 [&](const ReferenceLoc &Loc) { 207 if (HasPostUse || 208 SM.isBeforeInTranslationUnit(Loc.NameLoc, ZoneRange.getEnd())) 209 return; 210 HasPostUse = llvm::any_of(Loc.Targets, 211 [&DeclsInExtZone](const Decl *Target) { 212 return DeclsInExtZone.contains(Target); 213 }); 214 }, 215 Resolver); 216 if (HasPostUse) 217 return true; 218 } 219 return false; 220 } 221 }; 222 223 // Whether the code in the extraction zone is guaranteed to return, assuming 224 // no broken control flow (unbound break/continue). 225 // This is a very naive check (does it end with a return stmt). 226 // Doing some rudimentary control flow analysis would cover more cases. 227 bool alwaysReturns(const ExtractionZone &EZ) { 228 const Stmt *Last = EZ.getLastRootStmt()->ASTNode.get<Stmt>(); 229 // Unwrap enclosing (unconditional) compound statement. 230 while (const auto *CS = llvm::dyn_cast<CompoundStmt>(Last)) { 231 if (CS->body_empty()) 232 return false; 233 Last = CS->body_back(); 234 } 235 return llvm::isa<ReturnStmt>(Last); 236 } 237 238 bool ExtractionZone::isRootStmt(const Stmt *S) const { 239 return RootStmts.contains(S); 240 } 241 242 // Finds the function in which the zone lies. 243 const FunctionDecl *findEnclosingFunction(const Node *CommonAnc) { 244 // Walk up the SelectionTree until we find a function Decl 245 for (const Node *CurNode = CommonAnc; CurNode; CurNode = CurNode->Parent) { 246 // Don't extract from lambdas 247 if (CurNode->ASTNode.get<LambdaExpr>()) 248 return nullptr; 249 if (const FunctionDecl *Func = CurNode->ASTNode.get<FunctionDecl>()) { 250 // FIXME: Support extraction from templated functions. 
251 if (Func->isTemplated()) 252 return nullptr; 253 if (!Func->getBody()) 254 return nullptr; 255 for (const auto *S : Func->getBody()->children()) { 256 // During apply phase, we perform semantic analysis (e.g. figure out 257 // what variables requires hoisting). We cannot perform those when the 258 // body has invalid statements, so fail up front. 259 if (!S) 260 return nullptr; 261 } 262 return Func; 263 } 264 } 265 return nullptr; 266 } 267 268 // Zone Range is the union of SourceRanges of all child Nodes in Parent since 269 // all child Nodes are RootStmts 270 std::optional<SourceRange> findZoneRange(const Node *Parent, 271 const SourceManager &SM, 272 const LangOptions &LangOpts) { 273 SourceRange SR; 274 if (auto BeginFileRange = toHalfOpenFileRange( 275 SM, LangOpts, Parent->Children.front()->ASTNode.getSourceRange())) 276 SR.setBegin(BeginFileRange->getBegin()); 277 else 278 return std::nullopt; 279 if (auto EndFileRange = toHalfOpenFileRange( 280 SM, LangOpts, Parent->Children.back()->ASTNode.getSourceRange())) 281 SR.setEnd(EndFileRange->getEnd()); 282 else 283 return std::nullopt; 284 return SR; 285 } 286 287 // Compute the range spanned by the enclosing function. 288 // FIXME: check if EnclosingFunction has any attributes as the AST doesn't 289 // always store the source range of the attributes and thus we end up extracting 290 // between the attributes and the EnclosingFunction. 291 std::optional<SourceRange> 292 computeEnclosingFuncRange(const FunctionDecl *EnclosingFunction, 293 const SourceManager &SM, 294 const LangOptions &LangOpts) { 295 return toHalfOpenFileRange(SM, LangOpts, EnclosingFunction->getSourceRange()); 296 } 297 298 // returns true if Child can be a single RootStmt being extracted from 299 // EnclosingFunc. 300 bool validSingleChild(const Node *Child, const FunctionDecl *EnclosingFunc) { 301 // Don't extract expressions. 302 // FIXME: We should extract expressions that are "statements" i.e. 
not 303 // subexpressions 304 if (Child->ASTNode.get<Expr>()) 305 return false; 306 // Extracting the body of EnclosingFunc would remove it's definition. 307 assert(EnclosingFunc->hasBody() && 308 "We should always be extracting from a function body."); 309 if (Child->ASTNode.get<Stmt>() == EnclosingFunc->getBody()) 310 return false; 311 return true; 312 } 313 314 // FIXME: Check we're not extracting from the initializer/condition of a control 315 // flow structure. 316 std::optional<ExtractionZone> findExtractionZone(const Node *CommonAnc, 317 const SourceManager &SM, 318 const LangOptions &LangOpts) { 319 ExtractionZone ExtZone; 320 ExtZone.Parent = getParentOfRootStmts(CommonAnc); 321 if (!ExtZone.Parent || ExtZone.Parent->Children.empty()) 322 return std::nullopt; 323 ExtZone.EnclosingFunction = findEnclosingFunction(ExtZone.Parent); 324 if (!ExtZone.EnclosingFunction) 325 return std::nullopt; 326 // When there is a single RootStmt, we must check if it's valid for 327 // extraction. 328 if (ExtZone.Parent->Children.size() == 1 && 329 !validSingleChild(ExtZone.getLastRootStmt(), ExtZone.EnclosingFunction)) 330 return std::nullopt; 331 if (auto FuncRange = 332 computeEnclosingFuncRange(ExtZone.EnclosingFunction, SM, LangOpts)) 333 ExtZone.EnclosingFuncRange = *FuncRange; 334 if (auto ZoneRange = findZoneRange(ExtZone.Parent, SM, LangOpts)) 335 ExtZone.ZoneRange = *ZoneRange; 336 if (ExtZone.EnclosingFuncRange.isInvalid() || ExtZone.ZoneRange.isInvalid()) 337 return std::nullopt; 338 339 for (const Node *Child : ExtZone.Parent->Children) 340 ExtZone.RootStmts.insert(Child->ASTNode.get<Stmt>()); 341 342 return ExtZone; 343 } 344 345 // Stores information about the extracted function and provides methods for 346 // rendering it. 347 struct NewFunction { 348 struct Parameter { 349 std::string Name; 350 QualType TypeInfo; 351 bool PassByReference; 352 unsigned OrderPriority; // Lower value parameters are preferred first. 
353 std::string render(const DeclContext *Context) const; 354 bool operator<(const Parameter &Other) const { 355 return OrderPriority < Other.OrderPriority; 356 } 357 }; 358 std::string Name = "extracted"; 359 QualType ReturnType; 360 std::vector<Parameter> Parameters; 361 SourceRange BodyRange; 362 SourceLocation DefinitionPoint; 363 std::optional<SourceLocation> ForwardDeclarationPoint; 364 const CXXRecordDecl *EnclosingClass = nullptr; 365 const NestedNameSpecifier *DefinitionQualifier = nullptr; 366 const DeclContext *SemanticDC = nullptr; 367 const DeclContext *SyntacticDC = nullptr; 368 const DeclContext *ForwardDeclarationSyntacticDC = nullptr; 369 bool CallerReturnsValue = false; 370 bool Static = false; 371 ConstexprSpecKind Constexpr = ConstexprSpecKind::Unspecified; 372 bool Const = false; 373 374 // Decides whether the extracted function body and the function call need a 375 // semicolon after extraction. 376 tooling::ExtractionSemicolonPolicy SemicolonPolicy; 377 const LangOptions *LangOpts; 378 NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy, 379 const LangOptions *LangOpts) 380 : SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) {} 381 // Render the call for this function. 382 std::string renderCall() const; 383 // Render the definition for this function. 384 std::string renderDeclaration(FunctionDeclKind K, 385 const DeclContext &SemanticDC, 386 const DeclContext &SyntacticDC, 387 const SourceManager &SM) const; 388 389 private: 390 std::string 391 renderParametersForDeclaration(const DeclContext &Enclosing) const; 392 std::string renderParametersForCall() const; 393 std::string renderSpecifiers(FunctionDeclKind K) const; 394 std::string renderQualifiers() const; 395 std::string renderDeclarationName(FunctionDeclKind K) const; 396 // Generate the function body. 
397 std::string getFuncBody(const SourceManager &SM) const; 398 }; 399 400 std::string NewFunction::renderParametersForDeclaration( 401 const DeclContext &Enclosing) const { 402 std::string Result; 403 bool NeedCommaBefore = false; 404 for (const Parameter &P : Parameters) { 405 if (NeedCommaBefore) 406 Result += ", "; 407 NeedCommaBefore = true; 408 Result += P.render(&Enclosing); 409 } 410 return Result; 411 } 412 413 std::string NewFunction::renderParametersForCall() const { 414 std::string Result; 415 bool NeedCommaBefore = false; 416 for (const Parameter &P : Parameters) { 417 if (NeedCommaBefore) 418 Result += ", "; 419 NeedCommaBefore = true; 420 Result += P.Name; 421 } 422 return Result; 423 } 424 425 std::string NewFunction::renderSpecifiers(FunctionDeclKind K) const { 426 std::string Attributes; 427 428 if (Static && K != FunctionDeclKind::OutOfLineDefinition) { 429 Attributes += "static "; 430 } 431 432 switch (Constexpr) { 433 case ConstexprSpecKind::Unspecified: 434 case ConstexprSpecKind::Constinit: 435 break; 436 case ConstexprSpecKind::Constexpr: 437 Attributes += "constexpr "; 438 break; 439 case ConstexprSpecKind::Consteval: 440 Attributes += "consteval "; 441 break; 442 } 443 444 return Attributes; 445 } 446 447 std::string NewFunction::renderQualifiers() const { 448 std::string Attributes; 449 450 if (Const) { 451 Attributes += " const"; 452 } 453 454 return Attributes; 455 } 456 457 std::string NewFunction::renderDeclarationName(FunctionDeclKind K) const { 458 if (DefinitionQualifier == nullptr || K != OutOfLineDefinition) { 459 return Name; 460 } 461 462 std::string QualifierName; 463 llvm::raw_string_ostream Oss(QualifierName); 464 DefinitionQualifier->print(Oss, *LangOpts); 465 return llvm::formatv("{0}{1}", QualifierName, Name); 466 } 467 468 std::string NewFunction::renderCall() const { 469 return std::string( 470 llvm::formatv("{0}{1}({2}){3}", CallerReturnsValue ? 
"return " : "", Name, 471 renderParametersForCall(), 472 (SemicolonPolicy.isNeededInOriginalFunction() ? ";" : ""))); 473 } 474 475 std::string NewFunction::renderDeclaration(FunctionDeclKind K, 476 const DeclContext &SemanticDC, 477 const DeclContext &SyntacticDC, 478 const SourceManager &SM) const { 479 std::string Declaration = std::string(llvm::formatv( 480 "{0}{1} {2}({3}){4}", renderSpecifiers(K), 481 printType(ReturnType, SyntacticDC), renderDeclarationName(K), 482 renderParametersForDeclaration(SemanticDC), renderQualifiers())); 483 484 switch (K) { 485 case ForwardDeclaration: 486 return std::string(llvm::formatv("{0};\n", Declaration)); 487 case OutOfLineDefinition: 488 case InlineDefinition: 489 return std::string( 490 llvm::formatv("{0} {\n{1}\n}\n", Declaration, getFuncBody(SM))); 491 break; 492 } 493 llvm_unreachable("Unsupported FunctionDeclKind enum"); 494 } 495 496 std::string NewFunction::getFuncBody(const SourceManager &SM) const { 497 // FIXME: Generate tooling::Replacements instead of std::string to 498 // - hoist decls 499 // - add return statement 500 // - Add semicolon 501 return toSourceCode(SM, BodyRange).str() + 502 (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : ""); 503 } 504 505 std::string NewFunction::Parameter::render(const DeclContext *Context) const { 506 return printType(TypeInfo, *Context) + (PassByReference ? " &" : " ") + Name; 507 } 508 509 // Stores captured information about Extraction Zone. 510 struct CapturedZoneInfo { 511 struct DeclInformation { 512 const Decl *TheDecl; 513 ZoneRelative DeclaredIn; 514 // index of the declaration or first reference. 
515 unsigned DeclIndex; 516 bool IsReferencedInZone = false; 517 bool IsReferencedInPostZone = false; 518 // FIXME: Capture mutation information 519 DeclInformation(const Decl *TheDecl, ZoneRelative DeclaredIn, 520 unsigned DeclIndex) 521 : TheDecl(TheDecl), DeclaredIn(DeclaredIn), DeclIndex(DeclIndex){}; 522 // Marks the occurence of a reference for this declaration 523 void markOccurence(ZoneRelative ReferenceLoc); 524 }; 525 // Maps Decls to their DeclInfo 526 llvm::DenseMap<const Decl *, DeclInformation> DeclInfoMap; 527 bool HasReturnStmt = false; // Are there any return statements in the zone? 528 bool AlwaysReturns = false; // Does the zone always return? 529 // Control flow is broken if we are extracting a break/continue without a 530 // corresponding parent loop/switch 531 bool BrokenControlFlow = false; 532 // FIXME: capture TypeAliasDecl and UsingDirectiveDecl 533 // FIXME: Capture type information as well. 534 DeclInformation *createDeclInfo(const Decl *D, ZoneRelative RelativeLoc); 535 DeclInformation *getDeclInfoFor(const Decl *D); 536 }; 537 538 CapturedZoneInfo::DeclInformation * 539 CapturedZoneInfo::createDeclInfo(const Decl *D, ZoneRelative RelativeLoc) { 540 // The new Decl's index is the size of the map so far. 
541 auto InsertionResult = DeclInfoMap.insert( 542 {D, DeclInformation(D, RelativeLoc, DeclInfoMap.size())}); 543 // Return the newly created DeclInfo 544 return &InsertionResult.first->second; 545 } 546 547 CapturedZoneInfo::DeclInformation * 548 CapturedZoneInfo::getDeclInfoFor(const Decl *D) { 549 // If the Decl doesn't exist, we 550 auto Iter = DeclInfoMap.find(D); 551 if (Iter == DeclInfoMap.end()) 552 return nullptr; 553 return &Iter->second; 554 } 555 556 void CapturedZoneInfo::DeclInformation::markOccurence( 557 ZoneRelative ReferenceLoc) { 558 switch (ReferenceLoc) { 559 case ZoneRelative::Inside: 560 IsReferencedInZone = true; 561 break; 562 case ZoneRelative::After: 563 IsReferencedInPostZone = true; 564 break; 565 default: 566 break; 567 } 568 } 569 570 bool isLoop(const Stmt *S) { 571 return isa<ForStmt>(S) || isa<DoStmt>(S) || isa<WhileStmt>(S) || 572 isa<CXXForRangeStmt>(S); 573 } 574 575 // Captures information from Extraction Zone 576 CapturedZoneInfo captureZoneInfo(const ExtractionZone &ExtZone) { 577 // We use the ASTVisitor instead of using the selection tree since we need to 578 // find references in the PostZone as well. 579 // FIXME: Check which statements we don't allow to extract. 
580 class ExtractionZoneVisitor 581 : public clang::RecursiveASTVisitor<ExtractionZoneVisitor> { 582 public: 583 ExtractionZoneVisitor(const ExtractionZone &ExtZone) : ExtZone(ExtZone) { 584 TraverseDecl(const_cast<FunctionDecl *>(ExtZone.EnclosingFunction)); 585 } 586 587 bool TraverseStmt(Stmt *S) { 588 if (!S) 589 return true; 590 bool IsRootStmt = ExtZone.isRootStmt(const_cast<const Stmt *>(S)); 591 // If we are starting traversal of a RootStmt, we are somewhere inside 592 // ExtractionZone 593 if (IsRootStmt) 594 CurrentLocation = ZoneRelative::Inside; 595 addToLoopSwitchCounters(S, 1); 596 // Traverse using base class's TraverseStmt 597 RecursiveASTVisitor::TraverseStmt(S); 598 addToLoopSwitchCounters(S, -1); 599 // We set the current location as after since next stmt will either be a 600 // RootStmt (handled at the beginning) or after extractionZone 601 if (IsRootStmt) 602 CurrentLocation = ZoneRelative::After; 603 return true; 604 } 605 606 // Add Increment to CurNumberOf{Loops,Switch} if statement is 607 // {Loop,Switch} and inside Extraction Zone. 608 void addToLoopSwitchCounters(Stmt *S, int Increment) { 609 if (CurrentLocation != ZoneRelative::Inside) 610 return; 611 if (isLoop(S)) 612 CurNumberOfNestedLoops += Increment; 613 else if (isa<SwitchStmt>(S)) 614 CurNumberOfSwitch += Increment; 615 } 616 617 bool VisitDecl(Decl *D) { 618 Info.createDeclInfo(D, CurrentLocation); 619 return true; 620 } 621 622 bool VisitDeclRefExpr(DeclRefExpr *DRE) { 623 // Find the corresponding Decl and mark it's occurrence. 624 const Decl *D = DRE->getDecl(); 625 auto *DeclInfo = Info.getDeclInfoFor(D); 626 // If no Decl was found, the Decl must be outside the enclosingFunc. 627 if (!DeclInfo) 628 DeclInfo = Info.createDeclInfo(D, ZoneRelative::OutsideFunc); 629 DeclInfo->markOccurence(CurrentLocation); 630 // FIXME: check if reference mutates the Decl being referred. 
631 return true; 632 } 633 634 bool VisitReturnStmt(ReturnStmt *Return) { 635 if (CurrentLocation == ZoneRelative::Inside) 636 Info.HasReturnStmt = true; 637 return true; 638 } 639 640 bool VisitBreakStmt(BreakStmt *Break) { 641 // Control flow is broken if break statement is selected without any 642 // parent loop or switch statement. 643 if (CurrentLocation == ZoneRelative::Inside && 644 !(CurNumberOfNestedLoops || CurNumberOfSwitch)) 645 Info.BrokenControlFlow = true; 646 return true; 647 } 648 649 bool VisitContinueStmt(ContinueStmt *Continue) { 650 // Control flow is broken if Continue statement is selected without any 651 // parent loop 652 if (CurrentLocation == ZoneRelative::Inside && !CurNumberOfNestedLoops) 653 Info.BrokenControlFlow = true; 654 return true; 655 } 656 CapturedZoneInfo Info; 657 const ExtractionZone &ExtZone; 658 ZoneRelative CurrentLocation = ZoneRelative::Before; 659 // Number of {loop,switch} statements that are currently in the traversal 660 // stack inside Extraction Zone. Used to check for broken control flow. 661 unsigned CurNumberOfNestedLoops = 0; 662 unsigned CurNumberOfSwitch = 0; 663 }; 664 ExtractionZoneVisitor Visitor(ExtZone); 665 CapturedZoneInfo Result = std::move(Visitor.Info); 666 Result.AlwaysReturns = alwaysReturns(ExtZone); 667 return Result; 668 } 669 670 // Adds parameters to ExtractedFunc. 671 // Returns true if able to find the parameters successfully and no hoisting 672 // needed. 673 // FIXME: Check if the declaration has a local/anonymous type 674 bool createParameters(NewFunction &ExtractedFunc, 675 const CapturedZoneInfo &CapturedInfo) { 676 for (const auto &KeyVal : CapturedInfo.DeclInfoMap) { 677 const auto &DeclInfo = KeyVal.second; 678 // If a Decl was Declared in zone and referenced in post zone, it 679 // needs to be hoisted (we bail out in that case). 680 // FIXME: Support Decl Hoisting. 
681 if (DeclInfo.DeclaredIn == ZoneRelative::Inside && 682 DeclInfo.IsReferencedInPostZone) 683 return false; 684 if (!DeclInfo.IsReferencedInZone) 685 continue; // no need to pass as parameter, not referenced 686 if (DeclInfo.DeclaredIn == ZoneRelative::Inside || 687 DeclInfo.DeclaredIn == ZoneRelative::OutsideFunc) 688 continue; // no need to pass as parameter, still accessible. 689 // Parameter specific checks. 690 const ValueDecl *VD = dyn_cast_or_null<ValueDecl>(DeclInfo.TheDecl); 691 // Can't parameterise if the Decl isn't a ValueDecl or is a FunctionDecl 692 // (this includes the case of recursive call to EnclosingFunc in Zone). 693 if (!VD || isa<FunctionDecl>(DeclInfo.TheDecl)) 694 return false; 695 // Parameter qualifiers are same as the Decl's qualifiers. 696 QualType TypeInfo = VD->getType().getNonReferenceType(); 697 // FIXME: Need better qualifier checks: check mutated status for 698 // Decl(e.g. was it assigned, passed as nonconst argument, etc) 699 // FIXME: check if parameter will be a non l-value reference. 700 // FIXME: We don't want to always pass variables of types like int, 701 // pointers, etc by reference. 702 bool IsPassedByReference = true; 703 // We use the index of declaration as the ordering priority for parameters. 704 ExtractedFunc.Parameters.push_back({std::string(VD->getName()), TypeInfo, 705 IsPassedByReference, 706 DeclInfo.DeclIndex}); 707 } 708 llvm::sort(ExtractedFunc.Parameters); 709 return true; 710 } 711 712 // Clangd uses open ranges while ExtractionSemicolonPolicy (in Clang Tooling) 713 // uses closed ranges. Generates the semicolon policy for the extraction and 714 // extends the ZoneRange if necessary. 715 tooling::ExtractionSemicolonPolicy 716 getSemicolonPolicy(ExtractionZone &ExtZone, const SourceManager &SM, 717 const LangOptions &LangOpts) { 718 // Get closed ZoneRange. 
719 SourceRange FuncBodyRange = {ExtZone.ZoneRange.getBegin(), 720 ExtZone.ZoneRange.getEnd().getLocWithOffset(-1)}; 721 auto SemicolonPolicy = tooling::ExtractionSemicolonPolicy::compute( 722 ExtZone.getLastRootStmt()->ASTNode.get<Stmt>(), FuncBodyRange, SM, 723 LangOpts); 724 // Update ZoneRange. 725 ExtZone.ZoneRange.setEnd(FuncBodyRange.getEnd().getLocWithOffset(1)); 726 return SemicolonPolicy; 727 } 728 729 // Generate return type for ExtractedFunc. Return false if unable to do so. 730 bool generateReturnProperties(NewFunction &ExtractedFunc, 731 const FunctionDecl &EnclosingFunc, 732 const CapturedZoneInfo &CapturedInfo) { 733 // If the selected code always returns, we preserve those return statements. 734 // The return type should be the same as the enclosing function. 735 // (Others are possible if there are conversions, but this seems clearest). 736 if (CapturedInfo.HasReturnStmt) { 737 // If the return is conditional, neither replacing the code with 738 // `extracted()` nor `return extracted()` is correct. 739 if (!CapturedInfo.AlwaysReturns) 740 return false; 741 QualType Ret = EnclosingFunc.getReturnType(); 742 // Once we support members, it'd be nice to support e.g. extracting a method 743 // of Foo<T> that returns T. But it's not clear when that's safe. 744 if (Ret->isDependentType()) 745 return false; 746 ExtractedFunc.ReturnType = Ret; 747 return true; 748 } 749 // FIXME: Generate new return statement if needed. 750 ExtractedFunc.ReturnType = EnclosingFunc.getParentASTContext().VoidTy; 751 return true; 752 } 753 754 void captureMethodInfo(NewFunction &ExtractedFunc, 755 const CXXMethodDecl *Method) { 756 ExtractedFunc.Static = Method->isStatic(); 757 ExtractedFunc.Const = Method->isConst(); 758 ExtractedFunc.EnclosingClass = Method->getParent(); 759 } 760 761 // FIXME: add support for adding other function return types besides void. 762 // FIXME: assign the value returned by non void extracted function. 
763 llvm::Expected<NewFunction> getExtractedFunction(ExtractionZone &ExtZone, 764 const SourceManager &SM, 765 const LangOptions &LangOpts) { 766 CapturedZoneInfo CapturedInfo = captureZoneInfo(ExtZone); 767 // Bail out if any break of continue exists 768 if (CapturedInfo.BrokenControlFlow) 769 return error("Cannot extract break/continue without corresponding " 770 "loop/switch statement."); 771 NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts), 772 &LangOpts); 773 774 ExtractedFunc.SyntacticDC = 775 ExtZone.EnclosingFunction->getLexicalDeclContext(); 776 ExtractedFunc.SemanticDC = ExtZone.EnclosingFunction->getDeclContext(); 777 ExtractedFunc.DefinitionQualifier = ExtZone.EnclosingFunction->getQualifier(); 778 ExtractedFunc.Constexpr = ExtZone.EnclosingFunction->getConstexprKind(); 779 780 if (const auto *Method = 781 llvm::dyn_cast<CXXMethodDecl>(ExtZone.EnclosingFunction)) 782 captureMethodInfo(ExtractedFunc, Method); 783 784 if (ExtZone.EnclosingFunction->isOutOfLine()) { 785 // FIXME: Put the extracted method in a private section if it's a class or 786 // maybe in an anonymous namespace 787 const auto *FirstOriginalDecl = 788 ExtZone.EnclosingFunction->getCanonicalDecl(); 789 auto DeclPos = 790 toHalfOpenFileRange(SM, LangOpts, FirstOriginalDecl->getSourceRange()); 791 if (!DeclPos) 792 return error("Declaration is inside a macro"); 793 ExtractedFunc.ForwardDeclarationPoint = DeclPos->getBegin(); 794 ExtractedFunc.ForwardDeclarationSyntacticDC = ExtractedFunc.SemanticDC; 795 } 796 797 ExtractedFunc.BodyRange = ExtZone.ZoneRange; 798 ExtractedFunc.DefinitionPoint = ExtZone.getInsertionPoint(); 799 800 ExtractedFunc.CallerReturnsValue = CapturedInfo.AlwaysReturns; 801 if (!createParameters(ExtractedFunc, CapturedInfo) || 802 !generateReturnProperties(ExtractedFunc, *ExtZone.EnclosingFunction, 803 CapturedInfo)) 804 return error("Too complex to extract."); 805 return ExtractedFunc; 806 } 807 808 class ExtractFunction : public Tweak { 809 public: 
810 const char *id() const final; 811 bool prepare(const Selection &Inputs) override; 812 Expected<Effect> apply(const Selection &Inputs) override; 813 std::string title() const override { return "Extract to function"; } 814 llvm::StringLiteral kind() const override { 815 return CodeAction::REFACTOR_KIND; 816 } 817 818 private: 819 ExtractionZone ExtZone; 820 }; 821 822 REGISTER_TWEAK(ExtractFunction) 823 tooling::Replacement replaceWithFuncCall(const NewFunction &ExtractedFunc, 824 const SourceManager &SM, 825 const LangOptions &LangOpts) { 826 std::string FuncCall = ExtractedFunc.renderCall(); 827 return tooling::Replacement( 828 SM, CharSourceRange(ExtractedFunc.BodyRange, false), FuncCall, LangOpts); 829 } 830 831 tooling::Replacement createFunctionDefinition(const NewFunction &ExtractedFunc, 832 const SourceManager &SM) { 833 FunctionDeclKind DeclKind = InlineDefinition; 834 if (ExtractedFunc.ForwardDeclarationPoint) 835 DeclKind = OutOfLineDefinition; 836 std::string FunctionDef = ExtractedFunc.renderDeclaration( 837 DeclKind, *ExtractedFunc.SemanticDC, *ExtractedFunc.SyntacticDC, SM); 838 839 return tooling::Replacement(SM, ExtractedFunc.DefinitionPoint, 0, 840 FunctionDef); 841 } 842 843 tooling::Replacement createForwardDeclaration(const NewFunction &ExtractedFunc, 844 const SourceManager &SM) { 845 std::string FunctionDecl = ExtractedFunc.renderDeclaration( 846 ForwardDeclaration, *ExtractedFunc.SemanticDC, 847 *ExtractedFunc.ForwardDeclarationSyntacticDC, SM); 848 SourceLocation DeclPoint = *ExtractedFunc.ForwardDeclarationPoint; 849 850 return tooling::Replacement(SM, DeclPoint, 0, FunctionDecl); 851 } 852 853 // Returns true if ExtZone contains any ReturnStmts. 854 bool hasReturnStmt(const ExtractionZone &ExtZone) { 855 class ReturnStmtVisitor 856 : public clang::RecursiveASTVisitor<ReturnStmtVisitor> { 857 public: 858 bool VisitReturnStmt(ReturnStmt *Return) { 859 Found = true; 860 return false; // We found the answer, abort the scan. 
861 } 862 bool Found = false; 863 }; 864 865 ReturnStmtVisitor V; 866 for (const Stmt *RootStmt : ExtZone.RootStmts) { 867 V.TraverseStmt(const_cast<Stmt *>(RootStmt)); 868 if (V.Found) 869 break; 870 } 871 return V.Found; 872 } 873 874 bool ExtractFunction::prepare(const Selection &Inputs) { 875 const LangOptions &LangOpts = Inputs.AST->getLangOpts(); 876 if (!LangOpts.CPlusPlus) 877 return false; 878 const Node *CommonAnc = Inputs.ASTSelection.commonAncestor(); 879 const SourceManager &SM = Inputs.AST->getSourceManager(); 880 auto MaybeExtZone = findExtractionZone(CommonAnc, SM, LangOpts); 881 if (!MaybeExtZone || 882 (hasReturnStmt(*MaybeExtZone) && !alwaysReturns(*MaybeExtZone))) 883 return false; 884 885 // FIXME: Get rid of this check once we support hoisting. 886 if (MaybeExtZone->requiresHoisting(SM, Inputs.AST->getHeuristicResolver())) 887 return false; 888 889 ExtZone = std::move(*MaybeExtZone); 890 return true; 891 } 892 893 Expected<Tweak::Effect> ExtractFunction::apply(const Selection &Inputs) { 894 const SourceManager &SM = Inputs.AST->getSourceManager(); 895 const LangOptions &LangOpts = Inputs.AST->getLangOpts(); 896 auto ExtractedFunc = getExtractedFunction(ExtZone, SM, LangOpts); 897 // FIXME: Add more types of errors. 898 if (!ExtractedFunc) 899 return ExtractedFunc.takeError(); 900 tooling::Replacements Edit; 901 if (auto Err = Edit.add(createFunctionDefinition(*ExtractedFunc, SM))) 902 return std::move(Err); 903 if (auto Err = Edit.add(replaceWithFuncCall(*ExtractedFunc, SM, LangOpts))) 904 return std::move(Err); 905 906 if (auto FwdLoc = ExtractedFunc->ForwardDeclarationPoint) { 907 // If the fwd-declaration goes in the same file, merge into Replacements. 908 // Otherwise it needs to be a separate file edit. 
909 if (SM.isWrittenInSameFile(ExtractedFunc->DefinitionPoint, *FwdLoc)) { 910 if (auto Err = Edit.add(createForwardDeclaration(*ExtractedFunc, SM))) 911 return std::move(Err); 912 } else { 913 auto MultiFileEffect = Effect::mainFileEdit(SM, std::move(Edit)); 914 if (!MultiFileEffect) 915 return MultiFileEffect.takeError(); 916 917 tooling::Replacements OtherEdit( 918 createForwardDeclaration(*ExtractedFunc, SM)); 919 if (auto PathAndEdit = 920 Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc), OtherEdit)) 921 MultiFileEffect->ApplyEdits.try_emplace(PathAndEdit->first, 922 PathAndEdit->second); 923 else 924 return PathAndEdit.takeError(); 925 return MultiFileEffect; 926 } 927 } 928 return Effect::mainFileEdit(SM, std::move(Edit)); 929 } 930 931 } // namespace 932 } // namespace clangd 933 } // namespace clang 934