xref: /openbsd-src/gnu/llvm/clang/lib/Tooling/RefactoringCallbacks.cpp (revision ec727ea710c91afd8ce4f788c5aaa8482b7b69b2)
//===--- RefactoringCallbacks.cpp - AST matcher refactoring callbacks -----===//
2e5dd7070Spatrick //
3e5dd7070Spatrick // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e5dd7070Spatrick // See https://llvm.org/LICENSE.txt for license information.
5e5dd7070Spatrick // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e5dd7070Spatrick //
7e5dd7070Spatrick //===----------------------------------------------------------------------===//
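//
// This file implements RefactoringCallback, ASTMatchRefactorer and the stock
// callbacks that turn AST matcher results into tooling::Replacements.
//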
10e5dd7070Spatrick //===----------------------------------------------------------------------===//
11e5dd7070Spatrick #include "clang/Tooling/RefactoringCallbacks.h"
12e5dd7070Spatrick #include "clang/ASTMatchers/ASTMatchFinder.h"
13e5dd7070Spatrick #include "clang/Basic/SourceLocation.h"
14e5dd7070Spatrick #include "clang/Lex/Lexer.h"
15e5dd7070Spatrick 
16e5dd7070Spatrick using llvm::StringError;
17e5dd7070Spatrick using llvm::make_error;
18e5dd7070Spatrick 
19e5dd7070Spatrick namespace clang {
20e5dd7070Spatrick namespace tooling {
21e5dd7070Spatrick 
RefactoringCallback()22e5dd7070Spatrick RefactoringCallback::RefactoringCallback() {}
getReplacements()23e5dd7070Spatrick tooling::Replacements &RefactoringCallback::getReplacements() {
24e5dd7070Spatrick   return Replace;
25e5dd7070Spatrick }
26e5dd7070Spatrick 
ASTMatchRefactorer(std::map<std::string,Replacements> & FileToReplaces)27e5dd7070Spatrick ASTMatchRefactorer::ASTMatchRefactorer(
28e5dd7070Spatrick     std::map<std::string, Replacements> &FileToReplaces)
29e5dd7070Spatrick     : FileToReplaces(FileToReplaces) {}
30e5dd7070Spatrick 
addDynamicMatcher(const ast_matchers::internal::DynTypedMatcher & Matcher,RefactoringCallback * Callback)31e5dd7070Spatrick void ASTMatchRefactorer::addDynamicMatcher(
32e5dd7070Spatrick     const ast_matchers::internal::DynTypedMatcher &Matcher,
33e5dd7070Spatrick     RefactoringCallback *Callback) {
34e5dd7070Spatrick   MatchFinder.addDynamicMatcher(Matcher, Callback);
35e5dd7070Spatrick   Callbacks.push_back(Callback);
36e5dd7070Spatrick }
37e5dd7070Spatrick 
38e5dd7070Spatrick class RefactoringASTConsumer : public ASTConsumer {
39e5dd7070Spatrick public:
RefactoringASTConsumer(ASTMatchRefactorer & Refactoring)40e5dd7070Spatrick   explicit RefactoringASTConsumer(ASTMatchRefactorer &Refactoring)
41e5dd7070Spatrick       : Refactoring(Refactoring) {}
42e5dd7070Spatrick 
HandleTranslationUnit(ASTContext & Context)43e5dd7070Spatrick   void HandleTranslationUnit(ASTContext &Context) override {
44e5dd7070Spatrick     // The ASTMatchRefactorer is re-used between translation units.
45e5dd7070Spatrick     // Clear the matchers so that each Replacement is only emitted once.
46e5dd7070Spatrick     for (const auto &Callback : Refactoring.Callbacks) {
47e5dd7070Spatrick       Callback->getReplacements().clear();
48e5dd7070Spatrick     }
49e5dd7070Spatrick     Refactoring.MatchFinder.matchAST(Context);
50e5dd7070Spatrick     for (const auto &Callback : Refactoring.Callbacks) {
51e5dd7070Spatrick       for (const auto &Replacement : Callback->getReplacements()) {
52e5dd7070Spatrick         llvm::Error Err =
53*ec727ea7Spatrick             Refactoring.FileToReplaces[std::string(Replacement.getFilePath())]
54*ec727ea7Spatrick                 .add(Replacement);
55e5dd7070Spatrick         if (Err) {
56e5dd7070Spatrick           llvm::errs() << "Skipping replacement " << Replacement.toString()
57e5dd7070Spatrick                        << " due to this error:\n"
58e5dd7070Spatrick                        << toString(std::move(Err)) << "\n";
59e5dd7070Spatrick         }
60e5dd7070Spatrick       }
61e5dd7070Spatrick     }
62e5dd7070Spatrick   }
63e5dd7070Spatrick 
64e5dd7070Spatrick private:
65e5dd7070Spatrick   ASTMatchRefactorer &Refactoring;
66e5dd7070Spatrick };
67e5dd7070Spatrick 
newASTConsumer()68e5dd7070Spatrick std::unique_ptr<ASTConsumer> ASTMatchRefactorer::newASTConsumer() {
69e5dd7070Spatrick   return std::make_unique<RefactoringASTConsumer>(*this);
70e5dd7070Spatrick }
71e5dd7070Spatrick 
replaceStmtWithText(SourceManager & Sources,const Stmt & From,StringRef Text)72e5dd7070Spatrick static Replacement replaceStmtWithText(SourceManager &Sources, const Stmt &From,
73e5dd7070Spatrick                                        StringRef Text) {
74e5dd7070Spatrick   return tooling::Replacement(
75e5dd7070Spatrick       Sources, CharSourceRange::getTokenRange(From.getSourceRange()), Text);
76e5dd7070Spatrick }
replaceStmtWithStmt(SourceManager & Sources,const Stmt & From,const Stmt & To)77e5dd7070Spatrick static Replacement replaceStmtWithStmt(SourceManager &Sources, const Stmt &From,
78e5dd7070Spatrick                                        const Stmt &To) {
79e5dd7070Spatrick   return replaceStmtWithText(
80e5dd7070Spatrick       Sources, From,
81e5dd7070Spatrick       Lexer::getSourceText(CharSourceRange::getTokenRange(To.getSourceRange()),
82e5dd7070Spatrick                            Sources, LangOptions()));
83e5dd7070Spatrick }
84e5dd7070Spatrick 
ReplaceStmtWithText(StringRef FromId,StringRef ToText)85e5dd7070Spatrick ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId, StringRef ToText)
86*ec727ea7Spatrick     : FromId(std::string(FromId)), ToText(std::string(ToText)) {}
87e5dd7070Spatrick 
run(const ast_matchers::MatchFinder::MatchResult & Result)88e5dd7070Spatrick void ReplaceStmtWithText::run(
89e5dd7070Spatrick     const ast_matchers::MatchFinder::MatchResult &Result) {
90e5dd7070Spatrick   if (const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId)) {
91e5dd7070Spatrick     auto Err = Replace.add(tooling::Replacement(
92e5dd7070Spatrick         *Result.SourceManager,
93e5dd7070Spatrick         CharSourceRange::getTokenRange(FromMatch->getSourceRange()), ToText));
94e5dd7070Spatrick     // FIXME: better error handling. For now, just print error message in the
95e5dd7070Spatrick     // release version.
96e5dd7070Spatrick     if (Err) {
97e5dd7070Spatrick       llvm::errs() << llvm::toString(std::move(Err)) << "\n";
98e5dd7070Spatrick       assert(false);
99e5dd7070Spatrick     }
100e5dd7070Spatrick   }
101e5dd7070Spatrick }
102e5dd7070Spatrick 
ReplaceStmtWithStmt(StringRef FromId,StringRef ToId)103e5dd7070Spatrick ReplaceStmtWithStmt::ReplaceStmtWithStmt(StringRef FromId, StringRef ToId)
104*ec727ea7Spatrick     : FromId(std::string(FromId)), ToId(std::string(ToId)) {}
105e5dd7070Spatrick 
run(const ast_matchers::MatchFinder::MatchResult & Result)106e5dd7070Spatrick void ReplaceStmtWithStmt::run(
107e5dd7070Spatrick     const ast_matchers::MatchFinder::MatchResult &Result) {
108e5dd7070Spatrick   const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId);
109e5dd7070Spatrick   const Stmt *ToMatch = Result.Nodes.getNodeAs<Stmt>(ToId);
110e5dd7070Spatrick   if (FromMatch && ToMatch) {
111e5dd7070Spatrick     auto Err = Replace.add(
112e5dd7070Spatrick         replaceStmtWithStmt(*Result.SourceManager, *FromMatch, *ToMatch));
113e5dd7070Spatrick     // FIXME: better error handling. For now, just print error message in the
114e5dd7070Spatrick     // release version.
115e5dd7070Spatrick     if (Err) {
116e5dd7070Spatrick       llvm::errs() << llvm::toString(std::move(Err)) << "\n";
117e5dd7070Spatrick       assert(false);
118e5dd7070Spatrick     }
119e5dd7070Spatrick   }
120e5dd7070Spatrick }
121e5dd7070Spatrick 
ReplaceIfStmtWithItsBody(StringRef Id,bool PickTrueBranch)122e5dd7070Spatrick ReplaceIfStmtWithItsBody::ReplaceIfStmtWithItsBody(StringRef Id,
123e5dd7070Spatrick                                                    bool PickTrueBranch)
124*ec727ea7Spatrick     : Id(std::string(Id)), PickTrueBranch(PickTrueBranch) {}
125e5dd7070Spatrick 
run(const ast_matchers::MatchFinder::MatchResult & Result)126e5dd7070Spatrick void ReplaceIfStmtWithItsBody::run(
127e5dd7070Spatrick     const ast_matchers::MatchFinder::MatchResult &Result) {
128e5dd7070Spatrick   if (const IfStmt *Node = Result.Nodes.getNodeAs<IfStmt>(Id)) {
129e5dd7070Spatrick     const Stmt *Body = PickTrueBranch ? Node->getThen() : Node->getElse();
130e5dd7070Spatrick     if (Body) {
131e5dd7070Spatrick       auto Err =
132e5dd7070Spatrick           Replace.add(replaceStmtWithStmt(*Result.SourceManager, *Node, *Body));
133e5dd7070Spatrick       // FIXME: better error handling. For now, just print error message in the
134e5dd7070Spatrick       // release version.
135e5dd7070Spatrick       if (Err) {
136e5dd7070Spatrick         llvm::errs() << llvm::toString(std::move(Err)) << "\n";
137e5dd7070Spatrick         assert(false);
138e5dd7070Spatrick       }
139e5dd7070Spatrick     } else if (!PickTrueBranch) {
140e5dd7070Spatrick       // If we want to use the 'else'-branch, but it doesn't exist, delete
141e5dd7070Spatrick       // the whole 'if'.
142e5dd7070Spatrick       auto Err =
143e5dd7070Spatrick           Replace.add(replaceStmtWithText(*Result.SourceManager, *Node, ""));
144e5dd7070Spatrick       // FIXME: better error handling. For now, just print error message in the
145e5dd7070Spatrick       // release version.
146e5dd7070Spatrick       if (Err) {
147e5dd7070Spatrick         llvm::errs() << llvm::toString(std::move(Err)) << "\n";
148e5dd7070Spatrick         assert(false);
149e5dd7070Spatrick       }
150e5dd7070Spatrick     }
151e5dd7070Spatrick   }
152e5dd7070Spatrick }
153e5dd7070Spatrick 
ReplaceNodeWithTemplate(llvm::StringRef FromId,std::vector<TemplateElement> Template)154e5dd7070Spatrick ReplaceNodeWithTemplate::ReplaceNodeWithTemplate(
155e5dd7070Spatrick     llvm::StringRef FromId, std::vector<TemplateElement> Template)
156*ec727ea7Spatrick     : FromId(std::string(FromId)), Template(std::move(Template)) {}
157e5dd7070Spatrick 
158e5dd7070Spatrick llvm::Expected<std::unique_ptr<ReplaceNodeWithTemplate>>
create(StringRef FromId,StringRef ToTemplate)159e5dd7070Spatrick ReplaceNodeWithTemplate::create(StringRef FromId, StringRef ToTemplate) {
160e5dd7070Spatrick   std::vector<TemplateElement> ParsedTemplate;
161e5dd7070Spatrick   for (size_t Index = 0; Index < ToTemplate.size();) {
162e5dd7070Spatrick     if (ToTemplate[Index] == '$') {
163e5dd7070Spatrick       if (ToTemplate.substr(Index, 2) == "$$") {
164e5dd7070Spatrick         Index += 2;
165e5dd7070Spatrick         ParsedTemplate.push_back(
166e5dd7070Spatrick             TemplateElement{TemplateElement::Literal, "$"});
167e5dd7070Spatrick       } else if (ToTemplate.substr(Index, 2) == "${") {
168e5dd7070Spatrick         size_t EndOfIdentifier = ToTemplate.find("}", Index);
169e5dd7070Spatrick         if (EndOfIdentifier == std::string::npos) {
170e5dd7070Spatrick           return make_error<StringError>(
171e5dd7070Spatrick               "Unterminated ${...} in replacement template near " +
172e5dd7070Spatrick                   ToTemplate.substr(Index),
173e5dd7070Spatrick               llvm::inconvertibleErrorCode());
174e5dd7070Spatrick         }
175*ec727ea7Spatrick         std::string SourceNodeName = std::string(
176*ec727ea7Spatrick             ToTemplate.substr(Index + 2, EndOfIdentifier - Index - 2));
177e5dd7070Spatrick         ParsedTemplate.push_back(
178e5dd7070Spatrick             TemplateElement{TemplateElement::Identifier, SourceNodeName});
179e5dd7070Spatrick         Index = EndOfIdentifier + 1;
180e5dd7070Spatrick       } else {
181e5dd7070Spatrick         return make_error<StringError>(
182e5dd7070Spatrick             "Invalid $ in replacement template near " +
183e5dd7070Spatrick                 ToTemplate.substr(Index),
184e5dd7070Spatrick             llvm::inconvertibleErrorCode());
185e5dd7070Spatrick       }
186e5dd7070Spatrick     } else {
187e5dd7070Spatrick       size_t NextIndex = ToTemplate.find('$', Index + 1);
188*ec727ea7Spatrick       ParsedTemplate.push_back(TemplateElement{
189*ec727ea7Spatrick           TemplateElement::Literal,
190*ec727ea7Spatrick           std::string(ToTemplate.substr(Index, NextIndex - Index))});
191e5dd7070Spatrick       Index = NextIndex;
192e5dd7070Spatrick     }
193e5dd7070Spatrick   }
194e5dd7070Spatrick   return std::unique_ptr<ReplaceNodeWithTemplate>(
195e5dd7070Spatrick       new ReplaceNodeWithTemplate(FromId, std::move(ParsedTemplate)));
196e5dd7070Spatrick }
197e5dd7070Spatrick 
run(const ast_matchers::MatchFinder::MatchResult & Result)198e5dd7070Spatrick void ReplaceNodeWithTemplate::run(
199e5dd7070Spatrick     const ast_matchers::MatchFinder::MatchResult &Result) {
200e5dd7070Spatrick   const auto &NodeMap = Result.Nodes.getMap();
201e5dd7070Spatrick 
202e5dd7070Spatrick   std::string ToText;
203e5dd7070Spatrick   for (const auto &Element : Template) {
204e5dd7070Spatrick     switch (Element.Type) {
205e5dd7070Spatrick     case TemplateElement::Literal:
206e5dd7070Spatrick       ToText += Element.Value;
207e5dd7070Spatrick       break;
208e5dd7070Spatrick     case TemplateElement::Identifier: {
209e5dd7070Spatrick       auto NodeIter = NodeMap.find(Element.Value);
210e5dd7070Spatrick       if (NodeIter == NodeMap.end()) {
211e5dd7070Spatrick         llvm::errs() << "Node " << Element.Value
212e5dd7070Spatrick                      << " used in replacement template not bound in Matcher \n";
213e5dd7070Spatrick         llvm::report_fatal_error("Unbound node in replacement template.");
214e5dd7070Spatrick       }
215e5dd7070Spatrick       CharSourceRange Source =
216e5dd7070Spatrick           CharSourceRange::getTokenRange(NodeIter->second.getSourceRange());
217e5dd7070Spatrick       ToText += Lexer::getSourceText(Source, *Result.SourceManager,
218e5dd7070Spatrick                                      Result.Context->getLangOpts());
219e5dd7070Spatrick       break;
220e5dd7070Spatrick     }
221e5dd7070Spatrick     }
222e5dd7070Spatrick   }
223e5dd7070Spatrick   if (NodeMap.count(FromId) == 0) {
224e5dd7070Spatrick     llvm::errs() << "Node to be replaced " << FromId
225e5dd7070Spatrick                  << " not bound in query.\n";
226e5dd7070Spatrick     llvm::report_fatal_error("FromId node not bound in MatchResult");
227e5dd7070Spatrick   }
228e5dd7070Spatrick   auto Replacement =
229e5dd7070Spatrick       tooling::Replacement(*Result.SourceManager, &NodeMap.at(FromId), ToText,
230e5dd7070Spatrick                            Result.Context->getLangOpts());
231e5dd7070Spatrick   llvm::Error Err = Replace.add(Replacement);
232e5dd7070Spatrick   if (Err) {
233e5dd7070Spatrick     llvm::errs() << "Query and replace failed in " << Replacement.getFilePath()
234e5dd7070Spatrick                  << "! " << llvm::toString(std::move(Err)) << "\n";
235e5dd7070Spatrick     llvm::report_fatal_error("Replacement failed");
236e5dd7070Spatrick   }
237e5dd7070Spatrick }
238e5dd7070Spatrick 
239e5dd7070Spatrick } // end namespace tooling
240e5dd7070Spatrick } // end namespace clang
241