xref: /llvm-project/clang-tools-extra/clangd/refactor/InsertionPoint.cpp (revision edd690b02e16e991393bf7f67631196942369aed)
1 //===--- InsertionPoint.cpp - Where should we add new code? ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "refactor/InsertionPoint.h"
10 #include "support/Logger.h"
11 #include "clang/AST/ASTContext.h"
12 #include "clang/AST/DeclCXX.h"
13 #include "clang/AST/DeclObjC.h"
14 #include "clang/AST/DeclTemplate.h"
15 #include "clang/Basic/SourceManager.h"
16 #include <optional>
17 
18 namespace clang {
19 namespace clangd {
20 namespace {
21 
22 // Choose the decl to insert before, according to an anchor.
23 // Nullptr means insert at end of DC.
24 // std::nullopt means no valid place to insert.
insertionDecl(const DeclContext & DC,const Anchor & A)25 std::optional<const Decl *> insertionDecl(const DeclContext &DC,
26                                           const Anchor &A) {
27   bool LastMatched = false;
28   bool ReturnNext = false;
29   for (const auto *D : DC.decls()) {
30     if (D->isImplicit())
31       continue;
32     if (ReturnNext)
33       return D;
34 
35     const Decl *NonTemplate = D;
36     if (auto *TD = llvm::dyn_cast<TemplateDecl>(D))
37       NonTemplate = TD->getTemplatedDecl();
38     bool Matches = A.Match(NonTemplate);
39     dlog("    {0} {1} {2}", Matches, D->getDeclKindName(), D);
40 
41     switch (A.Direction) {
42     case Anchor::Above:
43       if (Matches && !LastMatched) {
44         // Special case: if "above" matches an access specifier, we actually
45         // want to insert below it!
46         if (llvm::isa<AccessSpecDecl>(D)) {
47           ReturnNext = true;
48           continue;
49         }
50         return D;
51       }
52       break;
53     case Anchor::Below:
54       if (LastMatched && !Matches)
55         return D;
56       break;
57     }
58 
59     LastMatched = Matches;
60   }
61   if (ReturnNext || (LastMatched && A.Direction == Anchor::Below))
62     return nullptr;
63   return std::nullopt;
64 }
65 
beginLoc(const Decl & D)66 SourceLocation beginLoc(const Decl &D) {
67   auto Loc = D.getBeginLoc();
68   if (RawComment *Comment = D.getASTContext().getRawCommentForDeclNoCache(&D)) {
69     auto CommentLoc = Comment->getBeginLoc();
70     if (CommentLoc.isValid() && Loc.isValid() &&
71         D.getASTContext().getSourceManager().isBeforeInTranslationUnit(
72             CommentLoc, Loc))
73       Loc = CommentLoc;
74   }
75   return Loc;
76 }
77 
// Anchor predicate that accepts every declaration.
bool any(const Decl * /*D*/) { return true; }
79 
endLoc(const DeclContext & DC)80 SourceLocation endLoc(const DeclContext &DC) {
81   const Decl *D = llvm::cast<Decl>(&DC);
82   if (auto *OCD = llvm::dyn_cast<ObjCContainerDecl>(D))
83     return OCD->getAtEndRange().getBegin();
84   return D->getEndLoc();
85 }
86 
getAccessAtEnd(const CXXRecordDecl & C)87 AccessSpecifier getAccessAtEnd(const CXXRecordDecl &C) {
88   AccessSpecifier Spec =
89       (C.getTagKind() == TagTypeKind::Class ? AS_private : AS_public);
90   for (const auto *D : C.decls())
91     if (const auto *ASD = llvm::dyn_cast<AccessSpecDecl>(D))
92       Spec = ASD->getAccess();
93   return Spec;
94 }
95 
96 } // namespace
97 
insertionPoint(const DeclContext & DC,llvm::ArrayRef<Anchor> Anchors)98 SourceLocation insertionPoint(const DeclContext &DC,
99                               llvm::ArrayRef<Anchor> Anchors) {
100   dlog("Looking for insertion point in {0}", DC.getDeclKindName());
101   for (const auto &A : Anchors) {
102     dlog("  anchor ({0})", A.Direction == Anchor::Above ? "above" : "below");
103     if (auto D = insertionDecl(DC, A)) {
104       dlog("  anchor matched before {0}", *D);
105       return *D ? beginLoc(**D) : endLoc(DC);
106     }
107   }
108   dlog("no anchor matched");
109   return SourceLocation();
110 }
111 
112 llvm::Expected<tooling::Replacement>
insertDecl(llvm::StringRef Code,const DeclContext & DC,llvm::ArrayRef<Anchor> Anchors)113 insertDecl(llvm::StringRef Code, const DeclContext &DC,
114            llvm::ArrayRef<Anchor> Anchors) {
115   auto Loc = insertionPoint(DC, Anchors);
116   // Fallback: insert at the end.
117   if (Loc.isInvalid())
118     Loc = endLoc(DC);
119   const auto &SM = DC.getParentASTContext().getSourceManager();
120   if (!SM.isWrittenInSameFile(Loc, cast<Decl>(DC).getLocation()))
121     return error("{0} body in wrong file: {1}", DC.getDeclKindName(),
122                  Loc.printToString(SM));
123   return tooling::Replacement(SM, Loc, 0, Code);
124 }
125 
insertionPoint(const CXXRecordDecl & InClass,std::vector<Anchor> Anchors,AccessSpecifier Protection)126 SourceLocation insertionPoint(const CXXRecordDecl &InClass,
127                               std::vector<Anchor> Anchors,
128                               AccessSpecifier Protection) {
129   for (auto &A : Anchors)
130     A.Match = [Inner(std::move(A.Match)), Protection](const Decl *D) {
131       return D->getAccess() == Protection && Inner(D);
132     };
133   return insertionPoint(InClass, Anchors);
134 }
135 
insertDecl(llvm::StringRef Code,const CXXRecordDecl & InClass,std::vector<Anchor> Anchors,AccessSpecifier Protection)136 llvm::Expected<tooling::Replacement> insertDecl(llvm::StringRef Code,
137                                                 const CXXRecordDecl &InClass,
138                                                 std::vector<Anchor> Anchors,
139                                                 AccessSpecifier Protection) {
140   // Fallback: insert at the bottom of the relevant access section.
141   Anchors.push_back({any, Anchor::Below});
142   auto Loc = insertionPoint(InClass, std::move(Anchors), Protection);
143   std::string CodeBuffer;
144   auto &SM = InClass.getASTContext().getSourceManager();
145   // Fallback: insert at the end of the class. Check if protection matches!
146   if (Loc.isInvalid()) {
147     Loc = InClass.getBraceRange().getEnd();
148     if (Protection != getAccessAtEnd(InClass)) {
149       CodeBuffer = (getAccessSpelling(Protection) + ":\n" + Code).str();
150       Code = CodeBuffer;
151     }
152   }
153   if (!SM.isWrittenInSameFile(Loc, InClass.getLocation()))
154     return error("Class body in wrong file: {0}", Loc.printToString(SM));
155   return tooling::Replacement(SM, Loc, 0, Code);
156 }
157 
158 } // namespace clangd
159 } // namespace clang
160