1e5dd7070Spatrick //===--- RefactoringCallbacks.cpp - Structural query framework ------------===//
2e5dd7070Spatrick //
3e5dd7070Spatrick // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e5dd7070Spatrick // See https://llvm.org/LICENSE.txt for license information.
5e5dd7070Spatrick // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e5dd7070Spatrick //
7e5dd7070Spatrick //===----------------------------------------------------------------------===//
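//
// This file implements RefactoringCallback, ASTMatchRefactorer and the
// ready-made callbacks that turn AST matcher results into
// tooling::Replacements (ReplaceStmtWithText, ReplaceStmtWithStmt,
// ReplaceIfStmtWithItsBody and ReplaceNodeWithTemplate).
//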
10e5dd7070Spatrick //===----------------------------------------------------------------------===//
11e5dd7070Spatrick #include "clang/Tooling/RefactoringCallbacks.h"
12e5dd7070Spatrick #include "clang/ASTMatchers/ASTMatchFinder.h"
13e5dd7070Spatrick #include "clang/Basic/SourceLocation.h"
14e5dd7070Spatrick #include "clang/Lex/Lexer.h"
15e5dd7070Spatrick
16e5dd7070Spatrick using llvm::StringError;
17e5dd7070Spatrick using llvm::make_error;
18e5dd7070Spatrick
19e5dd7070Spatrick namespace clang {
20e5dd7070Spatrick namespace tooling {
21e5dd7070Spatrick
RefactoringCallback()22e5dd7070Spatrick RefactoringCallback::RefactoringCallback() {}
getReplacements()23e5dd7070Spatrick tooling::Replacements &RefactoringCallback::getReplacements() {
24e5dd7070Spatrick return Replace;
25e5dd7070Spatrick }
26e5dd7070Spatrick
ASTMatchRefactorer(std::map<std::string,Replacements> & FileToReplaces)27e5dd7070Spatrick ASTMatchRefactorer::ASTMatchRefactorer(
28e5dd7070Spatrick std::map<std::string, Replacements> &FileToReplaces)
29e5dd7070Spatrick : FileToReplaces(FileToReplaces) {}
30e5dd7070Spatrick
addDynamicMatcher(const ast_matchers::internal::DynTypedMatcher & Matcher,RefactoringCallback * Callback)31e5dd7070Spatrick void ASTMatchRefactorer::addDynamicMatcher(
32e5dd7070Spatrick const ast_matchers::internal::DynTypedMatcher &Matcher,
33e5dd7070Spatrick RefactoringCallback *Callback) {
34e5dd7070Spatrick MatchFinder.addDynamicMatcher(Matcher, Callback);
35e5dd7070Spatrick Callbacks.push_back(Callback);
36e5dd7070Spatrick }
37e5dd7070Spatrick
38e5dd7070Spatrick class RefactoringASTConsumer : public ASTConsumer {
39e5dd7070Spatrick public:
RefactoringASTConsumer(ASTMatchRefactorer & Refactoring)40e5dd7070Spatrick explicit RefactoringASTConsumer(ASTMatchRefactorer &Refactoring)
41e5dd7070Spatrick : Refactoring(Refactoring) {}
42e5dd7070Spatrick
HandleTranslationUnit(ASTContext & Context)43e5dd7070Spatrick void HandleTranslationUnit(ASTContext &Context) override {
44e5dd7070Spatrick // The ASTMatchRefactorer is re-used between translation units.
45e5dd7070Spatrick // Clear the matchers so that each Replacement is only emitted once.
46e5dd7070Spatrick for (const auto &Callback : Refactoring.Callbacks) {
47e5dd7070Spatrick Callback->getReplacements().clear();
48e5dd7070Spatrick }
49e5dd7070Spatrick Refactoring.MatchFinder.matchAST(Context);
50e5dd7070Spatrick for (const auto &Callback : Refactoring.Callbacks) {
51e5dd7070Spatrick for (const auto &Replacement : Callback->getReplacements()) {
52e5dd7070Spatrick llvm::Error Err =
53*ec727ea7Spatrick Refactoring.FileToReplaces[std::string(Replacement.getFilePath())]
54*ec727ea7Spatrick .add(Replacement);
55e5dd7070Spatrick if (Err) {
56e5dd7070Spatrick llvm::errs() << "Skipping replacement " << Replacement.toString()
57e5dd7070Spatrick << " due to this error:\n"
58e5dd7070Spatrick << toString(std::move(Err)) << "\n";
59e5dd7070Spatrick }
60e5dd7070Spatrick }
61e5dd7070Spatrick }
62e5dd7070Spatrick }
63e5dd7070Spatrick
64e5dd7070Spatrick private:
65e5dd7070Spatrick ASTMatchRefactorer &Refactoring;
66e5dd7070Spatrick };
67e5dd7070Spatrick
newASTConsumer()68e5dd7070Spatrick std::unique_ptr<ASTConsumer> ASTMatchRefactorer::newASTConsumer() {
69e5dd7070Spatrick return std::make_unique<RefactoringASTConsumer>(*this);
70e5dd7070Spatrick }
71e5dd7070Spatrick
replaceStmtWithText(SourceManager & Sources,const Stmt & From,StringRef Text)72e5dd7070Spatrick static Replacement replaceStmtWithText(SourceManager &Sources, const Stmt &From,
73e5dd7070Spatrick StringRef Text) {
74e5dd7070Spatrick return tooling::Replacement(
75e5dd7070Spatrick Sources, CharSourceRange::getTokenRange(From.getSourceRange()), Text);
76e5dd7070Spatrick }
replaceStmtWithStmt(SourceManager & Sources,const Stmt & From,const Stmt & To)77e5dd7070Spatrick static Replacement replaceStmtWithStmt(SourceManager &Sources, const Stmt &From,
78e5dd7070Spatrick const Stmt &To) {
79e5dd7070Spatrick return replaceStmtWithText(
80e5dd7070Spatrick Sources, From,
81e5dd7070Spatrick Lexer::getSourceText(CharSourceRange::getTokenRange(To.getSourceRange()),
82e5dd7070Spatrick Sources, LangOptions()));
83e5dd7070Spatrick }
84e5dd7070Spatrick
ReplaceStmtWithText(StringRef FromId,StringRef ToText)85e5dd7070Spatrick ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId, StringRef ToText)
86*ec727ea7Spatrick : FromId(std::string(FromId)), ToText(std::string(ToText)) {}
87e5dd7070Spatrick
run(const ast_matchers::MatchFinder::MatchResult & Result)88e5dd7070Spatrick void ReplaceStmtWithText::run(
89e5dd7070Spatrick const ast_matchers::MatchFinder::MatchResult &Result) {
90e5dd7070Spatrick if (const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId)) {
91e5dd7070Spatrick auto Err = Replace.add(tooling::Replacement(
92e5dd7070Spatrick *Result.SourceManager,
93e5dd7070Spatrick CharSourceRange::getTokenRange(FromMatch->getSourceRange()), ToText));
94e5dd7070Spatrick // FIXME: better error handling. For now, just print error message in the
95e5dd7070Spatrick // release version.
96e5dd7070Spatrick if (Err) {
97e5dd7070Spatrick llvm::errs() << llvm::toString(std::move(Err)) << "\n";
98e5dd7070Spatrick assert(false);
99e5dd7070Spatrick }
100e5dd7070Spatrick }
101e5dd7070Spatrick }
102e5dd7070Spatrick
ReplaceStmtWithStmt(StringRef FromId,StringRef ToId)103e5dd7070Spatrick ReplaceStmtWithStmt::ReplaceStmtWithStmt(StringRef FromId, StringRef ToId)
104*ec727ea7Spatrick : FromId(std::string(FromId)), ToId(std::string(ToId)) {}
105e5dd7070Spatrick
run(const ast_matchers::MatchFinder::MatchResult & Result)106e5dd7070Spatrick void ReplaceStmtWithStmt::run(
107e5dd7070Spatrick const ast_matchers::MatchFinder::MatchResult &Result) {
108e5dd7070Spatrick const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId);
109e5dd7070Spatrick const Stmt *ToMatch = Result.Nodes.getNodeAs<Stmt>(ToId);
110e5dd7070Spatrick if (FromMatch && ToMatch) {
111e5dd7070Spatrick auto Err = Replace.add(
112e5dd7070Spatrick replaceStmtWithStmt(*Result.SourceManager, *FromMatch, *ToMatch));
113e5dd7070Spatrick // FIXME: better error handling. For now, just print error message in the
114e5dd7070Spatrick // release version.
115e5dd7070Spatrick if (Err) {
116e5dd7070Spatrick llvm::errs() << llvm::toString(std::move(Err)) << "\n";
117e5dd7070Spatrick assert(false);
118e5dd7070Spatrick }
119e5dd7070Spatrick }
120e5dd7070Spatrick }
121e5dd7070Spatrick
ReplaceIfStmtWithItsBody(StringRef Id,bool PickTrueBranch)122e5dd7070Spatrick ReplaceIfStmtWithItsBody::ReplaceIfStmtWithItsBody(StringRef Id,
123e5dd7070Spatrick bool PickTrueBranch)
124*ec727ea7Spatrick : Id(std::string(Id)), PickTrueBranch(PickTrueBranch) {}
125e5dd7070Spatrick
run(const ast_matchers::MatchFinder::MatchResult & Result)126e5dd7070Spatrick void ReplaceIfStmtWithItsBody::run(
127e5dd7070Spatrick const ast_matchers::MatchFinder::MatchResult &Result) {
128e5dd7070Spatrick if (const IfStmt *Node = Result.Nodes.getNodeAs<IfStmt>(Id)) {
129e5dd7070Spatrick const Stmt *Body = PickTrueBranch ? Node->getThen() : Node->getElse();
130e5dd7070Spatrick if (Body) {
131e5dd7070Spatrick auto Err =
132e5dd7070Spatrick Replace.add(replaceStmtWithStmt(*Result.SourceManager, *Node, *Body));
133e5dd7070Spatrick // FIXME: better error handling. For now, just print error message in the
134e5dd7070Spatrick // release version.
135e5dd7070Spatrick if (Err) {
136e5dd7070Spatrick llvm::errs() << llvm::toString(std::move(Err)) << "\n";
137e5dd7070Spatrick assert(false);
138e5dd7070Spatrick }
139e5dd7070Spatrick } else if (!PickTrueBranch) {
140e5dd7070Spatrick // If we want to use the 'else'-branch, but it doesn't exist, delete
141e5dd7070Spatrick // the whole 'if'.
142e5dd7070Spatrick auto Err =
143e5dd7070Spatrick Replace.add(replaceStmtWithText(*Result.SourceManager, *Node, ""));
144e5dd7070Spatrick // FIXME: better error handling. For now, just print error message in the
145e5dd7070Spatrick // release version.
146e5dd7070Spatrick if (Err) {
147e5dd7070Spatrick llvm::errs() << llvm::toString(std::move(Err)) << "\n";
148e5dd7070Spatrick assert(false);
149e5dd7070Spatrick }
150e5dd7070Spatrick }
151e5dd7070Spatrick }
152e5dd7070Spatrick }
153e5dd7070Spatrick
ReplaceNodeWithTemplate(llvm::StringRef FromId,std::vector<TemplateElement> Template)154e5dd7070Spatrick ReplaceNodeWithTemplate::ReplaceNodeWithTemplate(
155e5dd7070Spatrick llvm::StringRef FromId, std::vector<TemplateElement> Template)
156*ec727ea7Spatrick : FromId(std::string(FromId)), Template(std::move(Template)) {}
157e5dd7070Spatrick
158e5dd7070Spatrick llvm::Expected<std::unique_ptr<ReplaceNodeWithTemplate>>
create(StringRef FromId,StringRef ToTemplate)159e5dd7070Spatrick ReplaceNodeWithTemplate::create(StringRef FromId, StringRef ToTemplate) {
160e5dd7070Spatrick std::vector<TemplateElement> ParsedTemplate;
161e5dd7070Spatrick for (size_t Index = 0; Index < ToTemplate.size();) {
162e5dd7070Spatrick if (ToTemplate[Index] == '$') {
163e5dd7070Spatrick if (ToTemplate.substr(Index, 2) == "$$") {
164e5dd7070Spatrick Index += 2;
165e5dd7070Spatrick ParsedTemplate.push_back(
166e5dd7070Spatrick TemplateElement{TemplateElement::Literal, "$"});
167e5dd7070Spatrick } else if (ToTemplate.substr(Index, 2) == "${") {
168e5dd7070Spatrick size_t EndOfIdentifier = ToTemplate.find("}", Index);
169e5dd7070Spatrick if (EndOfIdentifier == std::string::npos) {
170e5dd7070Spatrick return make_error<StringError>(
171e5dd7070Spatrick "Unterminated ${...} in replacement template near " +
172e5dd7070Spatrick ToTemplate.substr(Index),
173e5dd7070Spatrick llvm::inconvertibleErrorCode());
174e5dd7070Spatrick }
175*ec727ea7Spatrick std::string SourceNodeName = std::string(
176*ec727ea7Spatrick ToTemplate.substr(Index + 2, EndOfIdentifier - Index - 2));
177e5dd7070Spatrick ParsedTemplate.push_back(
178e5dd7070Spatrick TemplateElement{TemplateElement::Identifier, SourceNodeName});
179e5dd7070Spatrick Index = EndOfIdentifier + 1;
180e5dd7070Spatrick } else {
181e5dd7070Spatrick return make_error<StringError>(
182e5dd7070Spatrick "Invalid $ in replacement template near " +
183e5dd7070Spatrick ToTemplate.substr(Index),
184e5dd7070Spatrick llvm::inconvertibleErrorCode());
185e5dd7070Spatrick }
186e5dd7070Spatrick } else {
187e5dd7070Spatrick size_t NextIndex = ToTemplate.find('$', Index + 1);
188*ec727ea7Spatrick ParsedTemplate.push_back(TemplateElement{
189*ec727ea7Spatrick TemplateElement::Literal,
190*ec727ea7Spatrick std::string(ToTemplate.substr(Index, NextIndex - Index))});
191e5dd7070Spatrick Index = NextIndex;
192e5dd7070Spatrick }
193e5dd7070Spatrick }
194e5dd7070Spatrick return std::unique_ptr<ReplaceNodeWithTemplate>(
195e5dd7070Spatrick new ReplaceNodeWithTemplate(FromId, std::move(ParsedTemplate)));
196e5dd7070Spatrick }
197e5dd7070Spatrick
run(const ast_matchers::MatchFinder::MatchResult & Result)198e5dd7070Spatrick void ReplaceNodeWithTemplate::run(
199e5dd7070Spatrick const ast_matchers::MatchFinder::MatchResult &Result) {
200e5dd7070Spatrick const auto &NodeMap = Result.Nodes.getMap();
201e5dd7070Spatrick
202e5dd7070Spatrick std::string ToText;
203e5dd7070Spatrick for (const auto &Element : Template) {
204e5dd7070Spatrick switch (Element.Type) {
205e5dd7070Spatrick case TemplateElement::Literal:
206e5dd7070Spatrick ToText += Element.Value;
207e5dd7070Spatrick break;
208e5dd7070Spatrick case TemplateElement::Identifier: {
209e5dd7070Spatrick auto NodeIter = NodeMap.find(Element.Value);
210e5dd7070Spatrick if (NodeIter == NodeMap.end()) {
211e5dd7070Spatrick llvm::errs() << "Node " << Element.Value
212e5dd7070Spatrick << " used in replacement template not bound in Matcher \n";
213e5dd7070Spatrick llvm::report_fatal_error("Unbound node in replacement template.");
214e5dd7070Spatrick }
215e5dd7070Spatrick CharSourceRange Source =
216e5dd7070Spatrick CharSourceRange::getTokenRange(NodeIter->second.getSourceRange());
217e5dd7070Spatrick ToText += Lexer::getSourceText(Source, *Result.SourceManager,
218e5dd7070Spatrick Result.Context->getLangOpts());
219e5dd7070Spatrick break;
220e5dd7070Spatrick }
221e5dd7070Spatrick }
222e5dd7070Spatrick }
223e5dd7070Spatrick if (NodeMap.count(FromId) == 0) {
224e5dd7070Spatrick llvm::errs() << "Node to be replaced " << FromId
225e5dd7070Spatrick << " not bound in query.\n";
226e5dd7070Spatrick llvm::report_fatal_error("FromId node not bound in MatchResult");
227e5dd7070Spatrick }
228e5dd7070Spatrick auto Replacement =
229e5dd7070Spatrick tooling::Replacement(*Result.SourceManager, &NodeMap.at(FromId), ToText,
230e5dd7070Spatrick Result.Context->getLangOpts());
231e5dd7070Spatrick llvm::Error Err = Replace.add(Replacement);
232e5dd7070Spatrick if (Err) {
233e5dd7070Spatrick llvm::errs() << "Query and replace failed in " << Replacement.getFilePath()
234e5dd7070Spatrick << "! " << llvm::toString(std::move(Err)) << "\n";
235e5dd7070Spatrick llvm::report_fatal_error("Replacement failed");
236e5dd7070Spatrick }
237e5dd7070Spatrick }
238e5dd7070Spatrick
239e5dd7070Spatrick } // end namespace tooling
240e5dd7070Spatrick } // end namespace clang
241