xref: /llvm-project/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp (revision 9ae21b073ab48b376687ecd7fbae12e08b4ae86e)
1 //===--- ExtractFunction.cpp -------------------------------------*- C++-*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Extracts statements to a new function and replaces the statements with a
10 // call to the new function.
11 // Before:
12 //   void f(int a) {
13 //     [[if(a < 5)
14 //       a = 5;]]
15 //   }
16 // After:
17 //   void extracted(int &a) {
18 //     if(a < 5)
19 //       a = 5;
20 //   }
21 //   void f(int a) {
22 //     extracted(a);
23 //   }
24 //
25 // - Only extract statements
26 // - Extracts from non-templated free functions only.
27 // - Parameters are const only if the declaration was const
28 //   - Always passed by l-value reference
29 // - Void return type
30 // - Cannot extract declarations that will be needed in the original function
31 //   after extraction.
32 // - Checks for broken control flow (break/continue without loop/switch)
33 //
34 // 1. ExtractFunction is the tweak subclass
35 //    - Prepare does basic analysis of the selection and is therefore fast.
36 //      Successful prepare doesn't always mean we can apply the tweak.
37 //    - Apply does a more detailed analysis and can be slower. In case of
38 //      failure, we let the user know that we are unable to perform extraction.
// 2. ExtractionZone stores information about the range being extracted and the
//    enclosing function.
41 // 3. NewFunction stores properties of the extracted function and provides
42 //    methods for rendering it.
43 // 4. CapturedZoneInfo uses a RecursiveASTVisitor to capture information about
44 //    the extraction like declarations, existing return statements, etc.
45 // 5. getExtractedFunction is responsible for analyzing the CapturedZoneInfo and
46 //    creating a NewFunction.
47 //===----------------------------------------------------------------------===//
48 
49 #include "AST.h"
50 #include "FindTarget.h"
51 #include "ParsedAST.h"
52 #include "Selection.h"
53 #include "SourceCode.h"
54 #include "refactor/Tweak.h"
55 #include "support/Logger.h"
56 #include "clang/AST/ASTContext.h"
57 #include "clang/AST/Decl.h"
58 #include "clang/AST/DeclBase.h"
59 #include "clang/AST/ExprCXX.h"
60 #include "clang/AST/NestedNameSpecifier.h"
61 #include "clang/AST/RecursiveASTVisitor.h"
62 #include "clang/AST/Stmt.h"
63 #include "clang/Basic/LangOptions.h"
64 #include "clang/Basic/SourceLocation.h"
65 #include "clang/Basic/SourceManager.h"
66 #include "clang/Tooling/Core/Replacement.h"
67 #include "clang/Tooling/Refactoring/Extract/SourceExtraction.h"
68 #include "llvm/ADT/STLExtras.h"
69 #include "llvm/ADT/SmallSet.h"
70 #include "llvm/ADT/SmallVector.h"
71 #include "llvm/ADT/StringRef.h"
72 #include "llvm/Support/Casting.h"
73 #include "llvm/Support/Error.h"
74 #include <optional>
75 
76 namespace clang {
77 namespace clangd {
78 namespace {
79 
80 using Node = SelectionTree::Node;
81 
// Classifies a position relative to the extraction zone.
// The zone is the code being extracted; the enclosing function is the
// function/method whose body contains the zone. Together they divide the file
// into four regions.
enum class ZoneRelative {
  // Inside the enclosing function, ahead of the zone.
  Before,
  // Inside the zone itself.
  Inside,
  // Inside the enclosing function, past the zone.
  After,
  // Not part of the enclosing function at all.
  OutsideFunc,
};
91 
// The form in which the extracted function is rendered.
// This stays an unscoped enum because the enumerators are used unqualified
// elsewhere in this file.
enum FunctionDeclKind {
  // A definition (with body) rendered without the definition qualifier.
  InlineDefinition,
  // A declaration terminated by ';', with no body.
  ForwardDeclaration,
  // A definition (with body) whose name carries the definition qualifier and
  // which omits the `static` specifier.
  OutOfLineDefinition,
};
97 
98 // A RootStmt is a statement that's fully selected including all its children
99 // and its parent is unselected.
100 // Check if a node is a root statement.
101 bool isRootStmt(const Node *N) {
102   if (!N->ASTNode.get<Stmt>())
103     return false;
104   // Root statement cannot be partially selected.
105   if (N->Selected == SelectionTree::Partial)
106     return false;
107   // A DeclStmt can be an unselected RootStmt since VarDecls claim the entire
108   // selection range in selectionTree. Additionally, a CXXOperatorCallExpr of a
109   // binary operation can be unselected because its children claim the entire
110   // selection range in the selection tree (e.g. <<).
111   if (N->Selected == SelectionTree::Unselected && !N->ASTNode.get<DeclStmt>() &&
112       !N->ASTNode.get<CXXOperatorCallExpr>())
113     return false;
114   return true;
115 }
116 
117 // Returns the (unselected) parent of all RootStmts given the commonAncestor.
118 // Returns null if:
119 // 1. any node is partially selected
120 // 2. If all completely selected nodes don't have the same common parent
121 // 3. Any child of Parent isn't a RootStmt.
122 // Returns null if any child is not a RootStmt.
123 // We only support extraction of RootStmts since it allows us to extract without
124 // having to change the selection range. Also, this means that any scope that
125 // begins in selection range, ends in selection range and any scope that begins
126 // outside the selection range, ends outside as well.
127 const Node *getParentOfRootStmts(const Node *CommonAnc) {
128   if (!CommonAnc)
129     return nullptr;
130   const Node *Parent = nullptr;
131   switch (CommonAnc->Selected) {
132   case SelectionTree::Selection::Unselected:
133     // Typically a block, with the { and } unselected, could also be ForStmt etc
134     // Ensure all Children are RootStmts.
135     Parent = CommonAnc;
136     break;
137   case SelectionTree::Selection::Partial:
138     // Only a fully-selected single statement can be selected.
139     return nullptr;
140   case SelectionTree::Selection::Complete:
141     // If the Common Ancestor is completely selected, then it's a root statement
142     // and its parent will be unselected.
143     Parent = CommonAnc->Parent;
144     // If parent is a DeclStmt, even though it's unselected, we consider it a
145     // root statement and return its parent. This is done because the VarDecls
146     // claim the entire selection range of the Declaration and DeclStmt is
147     // always unselected.
148     if (Parent->ASTNode.get<DeclStmt>())
149       Parent = Parent->Parent;
150     break;
151   }
152   // Ensure all Children are RootStmts.
153   return llvm::all_of(Parent->Children, isRootStmt) ? Parent : nullptr;
154 }
155 
156 // The ExtractionZone class forms a view of the code wrt Zone.
157 struct ExtractionZone {
158   // Parent of RootStatements being extracted.
159   const Node *Parent = nullptr;
160   // The half-open file range of the code being extracted.
161   SourceRange ZoneRange;
162   // The function inside which our zone resides.
163   const FunctionDecl *EnclosingFunction = nullptr;
164   // The half-open file range of the enclosing function.
165   SourceRange EnclosingFuncRange;
166   // Set of statements that form the ExtractionZone.
167   llvm::DenseSet<const Stmt *> RootStmts;
168 
169   SourceLocation getInsertionPoint() const {
170     return EnclosingFuncRange.getBegin();
171   }
172   bool isRootStmt(const Stmt *S) const;
173   // The last root statement is important to decide where we need to insert a
174   // semicolon after the extraction.
175   const Node *getLastRootStmt() const { return Parent->Children.back(); }
176 
177   // Checks if declarations inside extraction zone are accessed afterwards.
178   //
179   // This performs a partial AST traversal proportional to the size of the
180   // enclosing function, so it is possibly expensive.
181   bool requiresHoisting(const SourceManager &SM,
182                         const HeuristicResolver *Resolver) const {
183     // First find all the declarations that happened inside extraction zone.
184     llvm::SmallSet<const Decl *, 1> DeclsInExtZone;
185     for (auto *RootStmt : RootStmts) {
186       findExplicitReferences(
187           RootStmt,
188           [&DeclsInExtZone](const ReferenceLoc &Loc) {
189             if (!Loc.IsDecl)
190               return;
191             DeclsInExtZone.insert(Loc.Targets.front());
192           },
193           Resolver);
194     }
195     // Early exit without performing expensive traversal below.
196     if (DeclsInExtZone.empty())
197       return false;
198     // Then make sure they are not used outside the zone.
199     for (const auto *S : EnclosingFunction->getBody()->children()) {
200       if (SM.isBeforeInTranslationUnit(S->getSourceRange().getEnd(),
201                                        ZoneRange.getEnd()))
202         continue;
203       bool HasPostUse = false;
204       findExplicitReferences(
205           S,
206           [&](const ReferenceLoc &Loc) {
207             if (HasPostUse ||
208                 SM.isBeforeInTranslationUnit(Loc.NameLoc, ZoneRange.getEnd()))
209               return;
210             HasPostUse = llvm::any_of(Loc.Targets,
211                                       [&DeclsInExtZone](const Decl *Target) {
212                                         return DeclsInExtZone.contains(Target);
213                                       });
214           },
215           Resolver);
216       if (HasPostUse)
217         return true;
218     }
219     return false;
220   }
221 };
222 
223 // Whether the code in the extraction zone is guaranteed to return, assuming
224 // no broken control flow (unbound break/continue).
225 // This is a very naive check (does it end with a return stmt).
226 // Doing some rudimentary control flow analysis would cover more cases.
227 bool alwaysReturns(const ExtractionZone &EZ) {
228   const Stmt *Last = EZ.getLastRootStmt()->ASTNode.get<Stmt>();
229   // Unwrap enclosing (unconditional) compound statement.
230   while (const auto *CS = llvm::dyn_cast<CompoundStmt>(Last)) {
231     if (CS->body_empty())
232       return false;
233     Last = CS->body_back();
234   }
235   return llvm::isa<ReturnStmt>(Last);
236 }
237 
238 bool ExtractionZone::isRootStmt(const Stmt *S) const {
239   return RootStmts.contains(S);
240 }
241 
242 // Finds the function in which the zone lies.
243 const FunctionDecl *findEnclosingFunction(const Node *CommonAnc) {
244   // Walk up the SelectionTree until we find a function Decl
245   for (const Node *CurNode = CommonAnc; CurNode; CurNode = CurNode->Parent) {
246     // Don't extract from lambdas
247     if (CurNode->ASTNode.get<LambdaExpr>())
248       return nullptr;
249     if (const FunctionDecl *Func = CurNode->ASTNode.get<FunctionDecl>()) {
250       // FIXME: Support extraction from templated functions.
251       if (Func->isTemplated())
252         return nullptr;
253       if (!Func->getBody())
254         return nullptr;
255       for (const auto *S : Func->getBody()->children()) {
256         // During apply phase, we perform semantic analysis (e.g. figure out
257         // what variables requires hoisting). We cannot perform those when the
258         // body has invalid statements, so fail up front.
259         if (!S)
260           return nullptr;
261       }
262       return Func;
263     }
264   }
265   return nullptr;
266 }
267 
268 // Zone Range is the union of SourceRanges of all child Nodes in Parent since
269 // all child Nodes are RootStmts
270 std::optional<SourceRange> findZoneRange(const Node *Parent,
271                                          const SourceManager &SM,
272                                          const LangOptions &LangOpts) {
273   SourceRange SR;
274   if (auto BeginFileRange = toHalfOpenFileRange(
275           SM, LangOpts, Parent->Children.front()->ASTNode.getSourceRange()))
276     SR.setBegin(BeginFileRange->getBegin());
277   else
278     return std::nullopt;
279   if (auto EndFileRange = toHalfOpenFileRange(
280           SM, LangOpts, Parent->Children.back()->ASTNode.getSourceRange()))
281     SR.setEnd(EndFileRange->getEnd());
282   else
283     return std::nullopt;
284   return SR;
285 }
286 
287 // Compute the range spanned by the enclosing function.
288 // FIXME: check if EnclosingFunction has any attributes as the AST doesn't
289 // always store the source range of the attributes and thus we end up extracting
290 // between the attributes and the EnclosingFunction.
291 std::optional<SourceRange>
292 computeEnclosingFuncRange(const FunctionDecl *EnclosingFunction,
293                           const SourceManager &SM,
294                           const LangOptions &LangOpts) {
295   return toHalfOpenFileRange(SM, LangOpts, EnclosingFunction->getSourceRange());
296 }
297 
298 // returns true if Child can be a single RootStmt being extracted from
299 // EnclosingFunc.
300 bool validSingleChild(const Node *Child, const FunctionDecl *EnclosingFunc) {
301   // Don't extract expressions.
302   // FIXME: We should extract expressions that are "statements" i.e. not
303   // subexpressions
304   if (Child->ASTNode.get<Expr>())
305     return false;
306   // Extracting the body of EnclosingFunc would remove it's definition.
307   assert(EnclosingFunc->hasBody() &&
308          "We should always be extracting from a function body.");
309   if (Child->ASTNode.get<Stmt>() == EnclosingFunc->getBody())
310     return false;
311   return true;
312 }
313 
314 // FIXME: Check we're not extracting from the initializer/condition of a control
315 // flow structure.
316 std::optional<ExtractionZone> findExtractionZone(const Node *CommonAnc,
317                                                  const SourceManager &SM,
318                                                  const LangOptions &LangOpts) {
319   ExtractionZone ExtZone;
320   ExtZone.Parent = getParentOfRootStmts(CommonAnc);
321   if (!ExtZone.Parent || ExtZone.Parent->Children.empty())
322     return std::nullopt;
323   ExtZone.EnclosingFunction = findEnclosingFunction(ExtZone.Parent);
324   if (!ExtZone.EnclosingFunction)
325     return std::nullopt;
326   // When there is a single RootStmt, we must check if it's valid for
327   // extraction.
328   if (ExtZone.Parent->Children.size() == 1 &&
329       !validSingleChild(ExtZone.getLastRootStmt(), ExtZone.EnclosingFunction))
330     return std::nullopt;
331   if (auto FuncRange =
332           computeEnclosingFuncRange(ExtZone.EnclosingFunction, SM, LangOpts))
333     ExtZone.EnclosingFuncRange = *FuncRange;
334   if (auto ZoneRange = findZoneRange(ExtZone.Parent, SM, LangOpts))
335     ExtZone.ZoneRange = *ZoneRange;
336   if (ExtZone.EnclosingFuncRange.isInvalid() || ExtZone.ZoneRange.isInvalid())
337     return std::nullopt;
338 
339   for (const Node *Child : ExtZone.Parent->Children)
340     ExtZone.RootStmts.insert(Child->ASTNode.get<Stmt>());
341 
342   return ExtZone;
343 }
344 
345 // Stores information about the extracted function and provides methods for
346 // rendering it.
347 struct NewFunction {
348   struct Parameter {
349     std::string Name;
350     QualType TypeInfo;
351     bool PassByReference;
352     unsigned OrderPriority; // Lower value parameters are preferred first.
353     std::string render(const DeclContext *Context) const;
354     bool operator<(const Parameter &Other) const {
355       return OrderPriority < Other.OrderPriority;
356     }
357   };
358   std::string Name = "extracted";
359   QualType ReturnType;
360   std::vector<Parameter> Parameters;
361   SourceRange BodyRange;
362   SourceLocation DefinitionPoint;
363   std::optional<SourceLocation> ForwardDeclarationPoint;
364   const CXXRecordDecl *EnclosingClass = nullptr;
365   const NestedNameSpecifier *DefinitionQualifier = nullptr;
366   const DeclContext *SemanticDC = nullptr;
367   const DeclContext *SyntacticDC = nullptr;
368   const DeclContext *ForwardDeclarationSyntacticDC = nullptr;
369   bool CallerReturnsValue = false;
370   bool Static = false;
371   ConstexprSpecKind Constexpr = ConstexprSpecKind::Unspecified;
372   bool Const = false;
373 
374   // Decides whether the extracted function body and the function call need a
375   // semicolon after extraction.
376   tooling::ExtractionSemicolonPolicy SemicolonPolicy;
377   const LangOptions *LangOpts;
378   NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy,
379               const LangOptions *LangOpts)
380       : SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) {}
381   // Render the call for this function.
382   std::string renderCall() const;
383   // Render the definition for this function.
384   std::string renderDeclaration(FunctionDeclKind K,
385                                 const DeclContext &SemanticDC,
386                                 const DeclContext &SyntacticDC,
387                                 const SourceManager &SM) const;
388 
389 private:
390   std::string
391   renderParametersForDeclaration(const DeclContext &Enclosing) const;
392   std::string renderParametersForCall() const;
393   std::string renderSpecifiers(FunctionDeclKind K) const;
394   std::string renderQualifiers() const;
395   std::string renderDeclarationName(FunctionDeclKind K) const;
396   // Generate the function body.
397   std::string getFuncBody(const SourceManager &SM) const;
398 };
399 
400 std::string NewFunction::renderParametersForDeclaration(
401     const DeclContext &Enclosing) const {
402   std::string Result;
403   bool NeedCommaBefore = false;
404   for (const Parameter &P : Parameters) {
405     if (NeedCommaBefore)
406       Result += ", ";
407     NeedCommaBefore = true;
408     Result += P.render(&Enclosing);
409   }
410   return Result;
411 }
412 
413 std::string NewFunction::renderParametersForCall() const {
414   std::string Result;
415   bool NeedCommaBefore = false;
416   for (const Parameter &P : Parameters) {
417     if (NeedCommaBefore)
418       Result += ", ";
419     NeedCommaBefore = true;
420     Result += P.Name;
421   }
422   return Result;
423 }
424 
425 std::string NewFunction::renderSpecifiers(FunctionDeclKind K) const {
426   std::string Attributes;
427 
428   if (Static && K != FunctionDeclKind::OutOfLineDefinition) {
429     Attributes += "static ";
430   }
431 
432   switch (Constexpr) {
433   case ConstexprSpecKind::Unspecified:
434   case ConstexprSpecKind::Constinit:
435     break;
436   case ConstexprSpecKind::Constexpr:
437     Attributes += "constexpr ";
438     break;
439   case ConstexprSpecKind::Consteval:
440     Attributes += "consteval ";
441     break;
442   }
443 
444   return Attributes;
445 }
446 
447 std::string NewFunction::renderQualifiers() const {
448   std::string Attributes;
449 
450   if (Const) {
451     Attributes += " const";
452   }
453 
454   return Attributes;
455 }
456 
457 std::string NewFunction::renderDeclarationName(FunctionDeclKind K) const {
458   if (DefinitionQualifier == nullptr || K != OutOfLineDefinition) {
459     return Name;
460   }
461 
462   std::string QualifierName;
463   llvm::raw_string_ostream Oss(QualifierName);
464   DefinitionQualifier->print(Oss, *LangOpts);
465   return llvm::formatv("{0}{1}", QualifierName, Name);
466 }
467 
468 std::string NewFunction::renderCall() const {
469   return std::string(
470       llvm::formatv("{0}{1}({2}){3}", CallerReturnsValue ? "return " : "", Name,
471                     renderParametersForCall(),
472                     (SemicolonPolicy.isNeededInOriginalFunction() ? ";" : "")));
473 }
474 
475 std::string NewFunction::renderDeclaration(FunctionDeclKind K,
476                                            const DeclContext &SemanticDC,
477                                            const DeclContext &SyntacticDC,
478                                            const SourceManager &SM) const {
479   std::string Declaration = std::string(llvm::formatv(
480       "{0}{1} {2}({3}){4}", renderSpecifiers(K),
481       printType(ReturnType, SyntacticDC), renderDeclarationName(K),
482       renderParametersForDeclaration(SemanticDC), renderQualifiers()));
483 
484   switch (K) {
485   case ForwardDeclaration:
486     return std::string(llvm::formatv("{0};\n", Declaration));
487   case OutOfLineDefinition:
488   case InlineDefinition:
489     return std::string(
490         llvm::formatv("{0} {\n{1}\n}\n", Declaration, getFuncBody(SM)));
491     break;
492   }
493   llvm_unreachable("Unsupported FunctionDeclKind enum");
494 }
495 
496 std::string NewFunction::getFuncBody(const SourceManager &SM) const {
497   // FIXME: Generate tooling::Replacements instead of std::string to
498   // - hoist decls
499   // - add return statement
500   // - Add semicolon
501   return toSourceCode(SM, BodyRange).str() +
502          (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : "");
503 }
504 
505 std::string NewFunction::Parameter::render(const DeclContext *Context) const {
506   return printType(TypeInfo, *Context) + (PassByReference ? " &" : " ") + Name;
507 }
508 
509 // Stores captured information about Extraction Zone.
510 struct CapturedZoneInfo {
511   struct DeclInformation {
512     const Decl *TheDecl;
513     ZoneRelative DeclaredIn;
514     // index of the declaration or first reference.
515     unsigned DeclIndex;
516     bool IsReferencedInZone = false;
517     bool IsReferencedInPostZone = false;
518     // FIXME: Capture mutation information
519     DeclInformation(const Decl *TheDecl, ZoneRelative DeclaredIn,
520                     unsigned DeclIndex)
521         : TheDecl(TheDecl), DeclaredIn(DeclaredIn), DeclIndex(DeclIndex){};
522     // Marks the occurence of a reference for this declaration
523     void markOccurence(ZoneRelative ReferenceLoc);
524   };
525   // Maps Decls to their DeclInfo
526   llvm::DenseMap<const Decl *, DeclInformation> DeclInfoMap;
527   bool HasReturnStmt = false; // Are there any return statements in the zone?
528   bool AlwaysReturns = false; // Does the zone always return?
529   // Control flow is broken if we are extracting a break/continue without a
530   // corresponding parent loop/switch
531   bool BrokenControlFlow = false;
532   // FIXME: capture TypeAliasDecl and UsingDirectiveDecl
533   // FIXME: Capture type information as well.
534   DeclInformation *createDeclInfo(const Decl *D, ZoneRelative RelativeLoc);
535   DeclInformation *getDeclInfoFor(const Decl *D);
536 };
537 
538 CapturedZoneInfo::DeclInformation *
539 CapturedZoneInfo::createDeclInfo(const Decl *D, ZoneRelative RelativeLoc) {
540   // The new Decl's index is the size of the map so far.
541   auto InsertionResult = DeclInfoMap.insert(
542       {D, DeclInformation(D, RelativeLoc, DeclInfoMap.size())});
543   // Return the newly created DeclInfo
544   return &InsertionResult.first->second;
545 }
546 
547 CapturedZoneInfo::DeclInformation *
548 CapturedZoneInfo::getDeclInfoFor(const Decl *D) {
549   // If the Decl doesn't exist, we
550   auto Iter = DeclInfoMap.find(D);
551   if (Iter == DeclInfoMap.end())
552     return nullptr;
553   return &Iter->second;
554 }
555 
556 void CapturedZoneInfo::DeclInformation::markOccurence(
557     ZoneRelative ReferenceLoc) {
558   switch (ReferenceLoc) {
559   case ZoneRelative::Inside:
560     IsReferencedInZone = true;
561     break;
562   case ZoneRelative::After:
563     IsReferencedInPostZone = true;
564     break;
565   default:
566     break;
567   }
568 }
569 
570 bool isLoop(const Stmt *S) {
571   return isa<ForStmt>(S) || isa<DoStmt>(S) || isa<WhileStmt>(S) ||
572          isa<CXXForRangeStmt>(S);
573 }
574 
575 // Captures information from Extraction Zone
576 CapturedZoneInfo captureZoneInfo(const ExtractionZone &ExtZone) {
577   // We use the ASTVisitor instead of using the selection tree since we need to
578   // find references in the PostZone as well.
579   // FIXME: Check which statements we don't allow to extract.
580   class ExtractionZoneVisitor
581       : public clang::RecursiveASTVisitor<ExtractionZoneVisitor> {
582   public:
583     ExtractionZoneVisitor(const ExtractionZone &ExtZone) : ExtZone(ExtZone) {
584       TraverseDecl(const_cast<FunctionDecl *>(ExtZone.EnclosingFunction));
585     }
586 
587     bool TraverseStmt(Stmt *S) {
588       if (!S)
589         return true;
590       bool IsRootStmt = ExtZone.isRootStmt(const_cast<const Stmt *>(S));
591       // If we are starting traversal of a RootStmt, we are somewhere inside
592       // ExtractionZone
593       if (IsRootStmt)
594         CurrentLocation = ZoneRelative::Inside;
595       addToLoopSwitchCounters(S, 1);
596       // Traverse using base class's TraverseStmt
597       RecursiveASTVisitor::TraverseStmt(S);
598       addToLoopSwitchCounters(S, -1);
599       // We set the current location as after since next stmt will either be a
600       // RootStmt (handled at the beginning) or after extractionZone
601       if (IsRootStmt)
602         CurrentLocation = ZoneRelative::After;
603       return true;
604     }
605 
606     // Add Increment to CurNumberOf{Loops,Switch} if statement is
607     // {Loop,Switch} and inside Extraction Zone.
608     void addToLoopSwitchCounters(Stmt *S, int Increment) {
609       if (CurrentLocation != ZoneRelative::Inside)
610         return;
611       if (isLoop(S))
612         CurNumberOfNestedLoops += Increment;
613       else if (isa<SwitchStmt>(S))
614         CurNumberOfSwitch += Increment;
615     }
616 
617     bool VisitDecl(Decl *D) {
618       Info.createDeclInfo(D, CurrentLocation);
619       return true;
620     }
621 
622     bool VisitDeclRefExpr(DeclRefExpr *DRE) {
623       // Find the corresponding Decl and mark it's occurrence.
624       const Decl *D = DRE->getDecl();
625       auto *DeclInfo = Info.getDeclInfoFor(D);
626       // If no Decl was found, the Decl must be outside the enclosingFunc.
627       if (!DeclInfo)
628         DeclInfo = Info.createDeclInfo(D, ZoneRelative::OutsideFunc);
629       DeclInfo->markOccurence(CurrentLocation);
630       // FIXME: check if reference mutates the Decl being referred.
631       return true;
632     }
633 
634     bool VisitReturnStmt(ReturnStmt *Return) {
635       if (CurrentLocation == ZoneRelative::Inside)
636         Info.HasReturnStmt = true;
637       return true;
638     }
639 
640     bool VisitBreakStmt(BreakStmt *Break) {
641       // Control flow is broken if break statement is selected without any
642       // parent loop or switch statement.
643       if (CurrentLocation == ZoneRelative::Inside &&
644           !(CurNumberOfNestedLoops || CurNumberOfSwitch))
645         Info.BrokenControlFlow = true;
646       return true;
647     }
648 
649     bool VisitContinueStmt(ContinueStmt *Continue) {
650       // Control flow is broken if Continue statement is selected without any
651       // parent loop
652       if (CurrentLocation == ZoneRelative::Inside && !CurNumberOfNestedLoops)
653         Info.BrokenControlFlow = true;
654       return true;
655     }
656     CapturedZoneInfo Info;
657     const ExtractionZone &ExtZone;
658     ZoneRelative CurrentLocation = ZoneRelative::Before;
659     // Number of {loop,switch} statements that are currently in the traversal
660     // stack inside Extraction Zone. Used to check for broken control flow.
661     unsigned CurNumberOfNestedLoops = 0;
662     unsigned CurNumberOfSwitch = 0;
663   };
664   ExtractionZoneVisitor Visitor(ExtZone);
665   CapturedZoneInfo Result = std::move(Visitor.Info);
666   Result.AlwaysReturns = alwaysReturns(ExtZone);
667   return Result;
668 }
669 
670 // Adds parameters to ExtractedFunc.
671 // Returns true if able to find the parameters successfully and no hoisting
672 // needed.
673 // FIXME: Check if the declaration has a local/anonymous type
674 bool createParameters(NewFunction &ExtractedFunc,
675                       const CapturedZoneInfo &CapturedInfo) {
676   for (const auto &KeyVal : CapturedInfo.DeclInfoMap) {
677     const auto &DeclInfo = KeyVal.second;
678     // If a Decl was Declared in zone and referenced in post zone, it
679     // needs to be hoisted (we bail out in that case).
680     // FIXME: Support Decl Hoisting.
681     if (DeclInfo.DeclaredIn == ZoneRelative::Inside &&
682         DeclInfo.IsReferencedInPostZone)
683       return false;
684     if (!DeclInfo.IsReferencedInZone)
685       continue; // no need to pass as parameter, not referenced
686     if (DeclInfo.DeclaredIn == ZoneRelative::Inside ||
687         DeclInfo.DeclaredIn == ZoneRelative::OutsideFunc)
688       continue; // no need to pass as parameter, still accessible.
689     // Parameter specific checks.
690     const ValueDecl *VD = dyn_cast_or_null<ValueDecl>(DeclInfo.TheDecl);
691     // Can't parameterise if the Decl isn't a ValueDecl or is a FunctionDecl
692     // (this includes the case of recursive call to EnclosingFunc in Zone).
693     if (!VD || isa<FunctionDecl>(DeclInfo.TheDecl))
694       return false;
695     // Parameter qualifiers are same as the Decl's qualifiers.
696     QualType TypeInfo = VD->getType().getNonReferenceType();
697     // FIXME: Need better qualifier checks: check mutated status for
698     // Decl(e.g. was it assigned, passed as nonconst argument, etc)
699     // FIXME: check if parameter will be a non l-value reference.
700     // FIXME: We don't want to always pass variables of types like int,
701     // pointers, etc by reference.
702     bool IsPassedByReference = true;
703     // We use the index of declaration as the ordering priority for parameters.
704     ExtractedFunc.Parameters.push_back({std::string(VD->getName()), TypeInfo,
705                                         IsPassedByReference,
706                                         DeclInfo.DeclIndex});
707   }
708   llvm::sort(ExtractedFunc.Parameters);
709   return true;
710 }
711 
712 // Clangd uses open ranges while ExtractionSemicolonPolicy (in Clang Tooling)
713 // uses closed ranges. Generates the semicolon policy for the extraction and
714 // extends the ZoneRange if necessary.
715 tooling::ExtractionSemicolonPolicy
716 getSemicolonPolicy(ExtractionZone &ExtZone, const SourceManager &SM,
717                    const LangOptions &LangOpts) {
718   // Get closed ZoneRange.
719   SourceRange FuncBodyRange = {ExtZone.ZoneRange.getBegin(),
720                                ExtZone.ZoneRange.getEnd().getLocWithOffset(-1)};
721   auto SemicolonPolicy = tooling::ExtractionSemicolonPolicy::compute(
722       ExtZone.getLastRootStmt()->ASTNode.get<Stmt>(), FuncBodyRange, SM,
723       LangOpts);
724   // Update ZoneRange.
725   ExtZone.ZoneRange.setEnd(FuncBodyRange.getEnd().getLocWithOffset(1));
726   return SemicolonPolicy;
727 }
728 
729 // Generate return type for ExtractedFunc. Return false if unable to do so.
730 bool generateReturnProperties(NewFunction &ExtractedFunc,
731                               const FunctionDecl &EnclosingFunc,
732                               const CapturedZoneInfo &CapturedInfo) {
733   // If the selected code always returns, we preserve those return statements.
734   // The return type should be the same as the enclosing function.
735   // (Others are possible if there are conversions, but this seems clearest).
736   if (CapturedInfo.HasReturnStmt) {
737     // If the return is conditional, neither replacing the code with
738     // `extracted()` nor `return extracted()` is correct.
739     if (!CapturedInfo.AlwaysReturns)
740       return false;
741     QualType Ret = EnclosingFunc.getReturnType();
742     // Once we support members, it'd be nice to support e.g. extracting a method
743     // of Foo<T> that returns T. But it's not clear when that's safe.
744     if (Ret->isDependentType())
745       return false;
746     ExtractedFunc.ReturnType = Ret;
747     return true;
748   }
749   // FIXME: Generate new return statement if needed.
750   ExtractedFunc.ReturnType = EnclosingFunc.getParentASTContext().VoidTy;
751   return true;
752 }
753 
754 void captureMethodInfo(NewFunction &ExtractedFunc,
755                        const CXXMethodDecl *Method) {
756   ExtractedFunc.Static = Method->isStatic();
757   ExtractedFunc.Const = Method->isConst();
758   ExtractedFunc.EnclosingClass = Method->getParent();
759 }
760 
761 // FIXME: add support for adding other function return types besides void.
762 // FIXME: assign the value returned by non void extracted function.
763 llvm::Expected<NewFunction> getExtractedFunction(ExtractionZone &ExtZone,
764                                                  const SourceManager &SM,
765                                                  const LangOptions &LangOpts) {
766   CapturedZoneInfo CapturedInfo = captureZoneInfo(ExtZone);
767   // Bail out if any break of continue exists
768   if (CapturedInfo.BrokenControlFlow)
769     return error("Cannot extract break/continue without corresponding "
770                  "loop/switch statement.");
771   NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts),
772                             &LangOpts);
773 
774   ExtractedFunc.SyntacticDC =
775       ExtZone.EnclosingFunction->getLexicalDeclContext();
776   ExtractedFunc.SemanticDC = ExtZone.EnclosingFunction->getDeclContext();
777   ExtractedFunc.DefinitionQualifier = ExtZone.EnclosingFunction->getQualifier();
778   ExtractedFunc.Constexpr = ExtZone.EnclosingFunction->getConstexprKind();
779 
780   if (const auto *Method =
781           llvm::dyn_cast<CXXMethodDecl>(ExtZone.EnclosingFunction))
782     captureMethodInfo(ExtractedFunc, Method);
783 
784   if (ExtZone.EnclosingFunction->isOutOfLine()) {
785     // FIXME: Put the extracted method in a private section if it's a class or
786     // maybe in an anonymous namespace
787     const auto *FirstOriginalDecl =
788         ExtZone.EnclosingFunction->getCanonicalDecl();
789     auto DeclPos =
790         toHalfOpenFileRange(SM, LangOpts, FirstOriginalDecl->getSourceRange());
791     if (!DeclPos)
792       return error("Declaration is inside a macro");
793     ExtractedFunc.ForwardDeclarationPoint = DeclPos->getBegin();
794     ExtractedFunc.ForwardDeclarationSyntacticDC = ExtractedFunc.SemanticDC;
795   }
796 
797   ExtractedFunc.BodyRange = ExtZone.ZoneRange;
798   ExtractedFunc.DefinitionPoint = ExtZone.getInsertionPoint();
799 
800   ExtractedFunc.CallerReturnsValue = CapturedInfo.AlwaysReturns;
801   if (!createParameters(ExtractedFunc, CapturedInfo) ||
802       !generateReturnProperties(ExtractedFunc, *ExtZone.EnclosingFunction,
803                                 CapturedInfo))
804     return error("Too complex to extract.");
805   return ExtractedFunc;
806 }
807 
808 class ExtractFunction : public Tweak {
809 public:
810   const char *id() const final;
811   bool prepare(const Selection &Inputs) override;
812   Expected<Effect> apply(const Selection &Inputs) override;
813   std::string title() const override { return "Extract to function"; }
814   llvm::StringLiteral kind() const override {
815     return CodeAction::REFACTOR_KIND;
816   }
817 
818 private:
819   ExtractionZone ExtZone;
820 };
821 
822 REGISTER_TWEAK(ExtractFunction)
823 tooling::Replacement replaceWithFuncCall(const NewFunction &ExtractedFunc,
824                                          const SourceManager &SM,
825                                          const LangOptions &LangOpts) {
826   std::string FuncCall = ExtractedFunc.renderCall();
827   return tooling::Replacement(
828       SM, CharSourceRange(ExtractedFunc.BodyRange, false), FuncCall, LangOpts);
829 }
830 
831 tooling::Replacement createFunctionDefinition(const NewFunction &ExtractedFunc,
832                                               const SourceManager &SM) {
833   FunctionDeclKind DeclKind = InlineDefinition;
834   if (ExtractedFunc.ForwardDeclarationPoint)
835     DeclKind = OutOfLineDefinition;
836   std::string FunctionDef = ExtractedFunc.renderDeclaration(
837       DeclKind, *ExtractedFunc.SemanticDC, *ExtractedFunc.SyntacticDC, SM);
838 
839   return tooling::Replacement(SM, ExtractedFunc.DefinitionPoint, 0,
840                               FunctionDef);
841 }
842 
843 tooling::Replacement createForwardDeclaration(const NewFunction &ExtractedFunc,
844                                               const SourceManager &SM) {
845   std::string FunctionDecl = ExtractedFunc.renderDeclaration(
846       ForwardDeclaration, *ExtractedFunc.SemanticDC,
847       *ExtractedFunc.ForwardDeclarationSyntacticDC, SM);
848   SourceLocation DeclPoint = *ExtractedFunc.ForwardDeclarationPoint;
849 
850   return tooling::Replacement(SM, DeclPoint, 0, FunctionDecl);
851 }
852 
853 // Returns true if ExtZone contains any ReturnStmts.
854 bool hasReturnStmt(const ExtractionZone &ExtZone) {
855   class ReturnStmtVisitor
856       : public clang::RecursiveASTVisitor<ReturnStmtVisitor> {
857   public:
858     bool VisitReturnStmt(ReturnStmt *Return) {
859       Found = true;
860       return false; // We found the answer, abort the scan.
861     }
862     bool Found = false;
863   };
864 
865   ReturnStmtVisitor V;
866   for (const Stmt *RootStmt : ExtZone.RootStmts) {
867     V.TraverseStmt(const_cast<Stmt *>(RootStmt));
868     if (V.Found)
869       break;
870   }
871   return V.Found;
872 }
873 
874 bool ExtractFunction::prepare(const Selection &Inputs) {
875   const LangOptions &LangOpts = Inputs.AST->getLangOpts();
876   if (!LangOpts.CPlusPlus)
877     return false;
878   const Node *CommonAnc = Inputs.ASTSelection.commonAncestor();
879   const SourceManager &SM = Inputs.AST->getSourceManager();
880   auto MaybeExtZone = findExtractionZone(CommonAnc, SM, LangOpts);
881   if (!MaybeExtZone ||
882       (hasReturnStmt(*MaybeExtZone) && !alwaysReturns(*MaybeExtZone)))
883     return false;
884 
885   // FIXME: Get rid of this check once we support hoisting.
886   if (MaybeExtZone->requiresHoisting(SM, Inputs.AST->getHeuristicResolver()))
887     return false;
888 
889   ExtZone = std::move(*MaybeExtZone);
890   return true;
891 }
892 
893 Expected<Tweak::Effect> ExtractFunction::apply(const Selection &Inputs) {
894   const SourceManager &SM = Inputs.AST->getSourceManager();
895   const LangOptions &LangOpts = Inputs.AST->getLangOpts();
896   auto ExtractedFunc = getExtractedFunction(ExtZone, SM, LangOpts);
897   // FIXME: Add more types of errors.
898   if (!ExtractedFunc)
899     return ExtractedFunc.takeError();
900   tooling::Replacements Edit;
901   if (auto Err = Edit.add(createFunctionDefinition(*ExtractedFunc, SM)))
902     return std::move(Err);
903   if (auto Err = Edit.add(replaceWithFuncCall(*ExtractedFunc, SM, LangOpts)))
904     return std::move(Err);
905 
906   if (auto FwdLoc = ExtractedFunc->ForwardDeclarationPoint) {
907     // If the fwd-declaration goes in the same file, merge into Replacements.
908     // Otherwise it needs to be a separate file edit.
909     if (SM.isWrittenInSameFile(ExtractedFunc->DefinitionPoint, *FwdLoc)) {
910       if (auto Err = Edit.add(createForwardDeclaration(*ExtractedFunc, SM)))
911         return std::move(Err);
912     } else {
913       auto MultiFileEffect = Effect::mainFileEdit(SM, std::move(Edit));
914       if (!MultiFileEffect)
915         return MultiFileEffect.takeError();
916 
917       tooling::Replacements OtherEdit(
918           createForwardDeclaration(*ExtractedFunc, SM));
919       if (auto PathAndEdit =
920               Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc), OtherEdit))
921         MultiFileEffect->ApplyEdits.try_emplace(PathAndEdit->first,
922                                                 PathAndEdit->second);
923       else
924         return PathAndEdit.takeError();
925       return MultiFileEffect;
926     }
927   }
928   return Effect::mainFileEdit(SM, std::move(Edit));
929 }
930 
931 } // namespace
932 } // namespace clangd
933 } // namespace clang
934