xref: /llvm-project/clang/lib/Tooling/Refactoring/Extract/SourceExtraction.cpp (revision 52d0cfc91e075658546e5f6f5804855de68a89a3)
1 //===--- SourceExtraction.cpp - Clang refactoring library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "clang/Tooling/Refactoring/Extract/SourceExtraction.h"
10 #include "clang/AST/Stmt.h"
11 #include "clang/AST/StmtCXX.h"
12 #include "clang/AST/StmtObjC.h"
13 #include "clang/Basic/SourceManager.h"
14 #include "clang/Lex/Lexer.h"
15 
16 using namespace clang;
17 
18 namespace {
19 
20 /// Returns true if the token at the given location is a semicolon.
21 bool isSemicolonAtLocation(SourceLocation TokenLoc, const SourceManager &SM,
22                            const LangOptions &LangOpts) {
23   return Lexer::getSourceText(
24              CharSourceRange::getTokenRange(TokenLoc, TokenLoc), SM,
25              LangOpts) == ";";
26 }
27 
28 /// Returns true if there should be a semicolon after the given statement.
29 bool isSemicolonRequiredAfter(const Stmt *S) {
30   if (isa<CompoundStmt>(S))
31     return false;
32   if (const auto *If = dyn_cast<IfStmt>(S))
33     return isSemicolonRequiredAfter(If->getElse() ? If->getElse()
34                                                   : If->getThen());
35   if (const auto *While = dyn_cast<WhileStmt>(S))
36     return isSemicolonRequiredAfter(While->getBody());
37   if (const auto *For = dyn_cast<ForStmt>(S))
38     return isSemicolonRequiredAfter(For->getBody());
39   if (const auto *CXXFor = dyn_cast<CXXForRangeStmt>(S))
40     return isSemicolonRequiredAfter(CXXFor->getBody());
41   if (const auto *ObjCFor = dyn_cast<ObjCForCollectionStmt>(S))
42     return isSemicolonRequiredAfter(ObjCFor->getBody());
43   if(const auto *Switch = dyn_cast<SwitchStmt>(S))
44     return isSemicolonRequiredAfter(Switch->getBody());
45   if(const auto *Case = dyn_cast<SwitchCase>(S))
46     return isSemicolonRequiredAfter(Case->getSubStmt());
47   switch (S->getStmtClass()) {
48   case Stmt::CXXTryStmtClass:
49   case Stmt::ObjCAtSynchronizedStmtClass:
50   case Stmt::ObjCAutoreleasePoolStmtClass:
51   case Stmt::ObjCAtTryStmtClass:
52     return false;
53   default:
54     return true;
55   }
56 }
57 
58 /// Returns true if the two source locations are on the same line.
59 bool areOnSameLine(SourceLocation Loc1, SourceLocation Loc2,
60                    const SourceManager &SM) {
61   return !Loc1.isMacroID() && !Loc2.isMacroID() &&
62          SM.getSpellingLineNumber(Loc1) == SM.getSpellingLineNumber(Loc2);
63 }
64 
65 } // end anonymous namespace
66 
67 namespace clang {
68 namespace tooling {
69 
70 ExtractionSemicolonPolicy
71 ExtractionSemicolonPolicy::compute(const Stmt *S, SourceRange &ExtractedRange,
72                                    const SourceManager &SM,
73                                    const LangOptions &LangOpts) {
74   auto neededInExtractedFunction = []() {
75     return ExtractionSemicolonPolicy(true, false);
76   };
77   auto neededInOriginalFunction = []() {
78     return ExtractionSemicolonPolicy(false, true);
79   };
80 
81   /// The extracted expression should be terminated with a ';'. The call to
82   /// the extracted function will replace this expression, so it won't need
83   /// a terminating ';'.
84   if (isa<Expr>(S))
85     return neededInExtractedFunction();
86 
87   /// Some statements don't need to be terminated with ';'. The call to the
88   /// extracted function will be a standalone statement, so it should be
89   /// terminated with a ';'.
90   bool NeedsSemi = isSemicolonRequiredAfter(S);
91   if (!NeedsSemi)
92     return neededInOriginalFunction();
93 
94   /// Some statements might end at ';'. The extraction will move that ';', so
95   /// the call to the extracted function should be terminated with a ';'.
96   SourceLocation End = ExtractedRange.getEnd();
97   if (isSemicolonAtLocation(End, SM, LangOpts))
98     return neededInOriginalFunction();
99 
100   /// Other statements should generally have a trailing ';'. We can try to find
101   /// it and move it together it with the extracted code.
102   Optional<Token> NextToken = Lexer::findNextToken(End, SM, LangOpts);
103   if (NextToken && NextToken->is(tok::semi) &&
104       areOnSameLine(NextToken->getLocation(), End, SM)) {
105     ExtractedRange.setEnd(NextToken->getLocation());
106     return neededInOriginalFunction();
107   }
108 
109   /// Otherwise insert semicolons in both places.
110   return ExtractionSemicolonPolicy(true, true);
111 }
112 
113 } // end namespace tooling
114 } // end namespace clang
115