xref: /llvm-project/clang/unittests/Tooling/RefactoringCallbacksTest.cpp (revision 027899dab6ac31a34e17b0f43eeb3d00e310a361)
1 //===- unittest/Tooling/RefactoringCallbacksTest.cpp ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "RewriterTestContext.h"
10 #include "clang/ASTMatchers/ASTMatchFinder.h"
11 #include "clang/ASTMatchers/ASTMatchers.h"
12 #include "clang/Tooling/RefactoringCallbacks.h"
13 #include "gtest/gtest.h"
14 
15 namespace clang {
16 namespace tooling {
17 
18 using namespace ast_matchers;
19 
20 template <typename T>
expectRewritten(const std::string & Code,const std::string & Expected,const T & AMatcher,RefactoringCallback & Callback)21 void expectRewritten(const std::string &Code, const std::string &Expected,
22                      const T &AMatcher, RefactoringCallback &Callback) {
23   std::map<std::string, Replacements> FileToReplace;
24   ASTMatchRefactorer Finder(FileToReplace);
25   Finder.addMatcher(traverse(TK_AsIs, AMatcher), &Callback);
26   std::unique_ptr<tooling::FrontendActionFactory> Factory(
27       tooling::newFrontendActionFactory(&Finder));
28   ASSERT_TRUE(tooling::runToolOnCode(Factory->create(), Code))
29       << "Parsing error in \"" << Code << "\"";
30   RewriterTestContext Context;
31   FileID ID = Context.createInMemoryFile("input.cc", Code);
32   EXPECT_TRUE(tooling::applyAllReplacements(FileToReplace["input.cc"],
33                                             Context.Rewrite));
34   EXPECT_EQ(Expected, Context.getRewrittenText(ID));
35 }
36 
TEST(RefactoringCallbacksTest, ReplacesStmtsWithString) {
  // The matched declaration statement is replaced by the literal text ";".
  ReplaceStmtWithText Callback("id", ";");
  const std::string Input = "void f() { int i = 1; }";
  const std::string Output = "void f() { ; }";
  expectRewritten(Input, Output, declStmt().bind("id"), Callback);
}
43 
TEST(RefactoringCallbacksTest, ReplacesStmtsInCalledMacros) {
  // A is expanded on the second line; the expected output shows the statement
  // being rewritten inside the macro definition itself.
  ReplaceStmtWithText Callback("id", ";");
  const std::string Input = "#define A void f() { int i = 1; }\nA";
  const std::string Output = "#define A void f() { ; }\nA";
  expectRewritten(Input, Output, declStmt().bind("id"), Callback);
}
50 
TEST(RefactoringCallbacksTest, IgnoresStmtsInUncalledMacros) {
  // A is never expanded, so the source must come back unchanged.
  const std::string Unchanged = "#define A void f() { int i = 1; }";
  ReplaceStmtWithText Callback("id", ";");
  expectRewritten(Unchanged, Unchanged, declStmt().bind("id"), Callback);
}
57 
TEST(RefactoringCallbacksTest, ReplacesInteger) {
  // Only the integer literal initializer is swapped for the text "2".
  ReplaceStmtWithText Callback("id", "2");
  const std::string Input = "void f() { int i = 1; }";
  const std::string Output = "void f() { int i = 2; }";
  expectRewritten(Input, Output, expr(integerLiteral()).bind("id"), Callback);
}
64 
TEST(RefactoringCallbacksTest, ReplacesStmtWithStmt) {
  // A conditional operator whose condition is the literal `false` is replaced
  // by the source text of its false branch.
  const std::string Before = "void f() { int i = false ? 1 : i * 2; }";
  const std::string After = "void f() { int i = i * 2; }";
  const auto AlwaysFalseConditional =
      conditionalOperator(hasCondition(cxxBoolLiteral(equals(false))),
                          hasFalseExpression(expr().bind("should-be")))
          .bind("always-false");
  ReplaceStmtWithStmt Callback("always-false", "should-be");
  expectRewritten(Before, After, AlwaysFalseConditional, Callback);
}
76 
TEST(RefactoringCallbacksTest, ReplacesIfStmt) {
  // Matches `if (a)` where `a` is the global bool (read through an implicit
  // cast) and keeps only the then-branch, as the expected output shows.
  const auto IfOnA =
      ifStmt(hasCondition(implicitCastExpr(hasSourceExpression(
                 declRefExpr(to(varDecl(hasName("a"))))))))
          .bind("id");
  ReplaceIfStmtWithItsBody Callback("id", true);
  const std::string Before = "bool a; void f() { if (a) f(); else a = true; }";
  const std::string After = "bool a; void f() { f(); }";
  expectRewritten(Before, After, IfOnA, Callback);
}
87 
TEST(RefactoringCallbacksTest, RemovesEntireIfOnEmptyElse) {
  // Picking the (absent) else branch of `if (false)` removes the whole
  // statement, leaving only the surrounding whitespace.
  const auto IfOnFalse =
      ifStmt(hasCondition(cxxBoolLiteral(equals(false)))).bind("id");
  ReplaceIfStmtWithItsBody Callback("id", false);
  const std::string Before = "void f() { if (false) int i = 0; }";
  const std::string After = "void f() {  }";
  expectRewritten(Before, After, IfOnFalse, Callback);
}
96 
TEST(RefactoringCallbacksTest, TemplateJustText) {
  std::string Code = "void f() { int i = 1; }";
  std::string Expected = "void f() { FOO }";
  auto Callback = ReplaceNodeWithTemplate::create("id", "FOO");
  // Stop here if the template was rejected: dereferencing an Expected in the
  // error state is invalid. The streamed message is only evaluated on
  // failure, and toString() consumes the error so it is not flagged as
  // unchecked.
  ASSERT_TRUE(static_cast<bool>(Callback))
      << llvm::toString(Callback.takeError());
  expectRewritten(Code, Expected, declStmt().bind("id"), **Callback);
}
104 
TEST(RefactoringCallbacksTest, TemplateSimpleSubst) {
  std::string Code = "void f() { int i = 1; }";
  std::string Expected = "void f() { long x = 1; }";
  auto Callback = ReplaceNodeWithTemplate::create("decl", "long x = ${init}");
  // Stop here if the template was rejected: dereferencing an Expected in the
  // error state is invalid. The streamed message is only evaluated on
  // failure, and toString() consumes the error so it is not flagged as
  // unchecked.
  ASSERT_TRUE(static_cast<bool>(Callback))
      << llvm::toString(Callback.takeError());
  expectRewritten(Code, Expected,
                  varDecl(hasInitializer(expr().bind("init"))).bind("decl"),
                  **Callback);
}
114 
TEST(RefactoringCallbacksTest, TemplateLiteral) {
  std::string Code = "void f() { int i = 1; }";
  std::string Expected = "void f() { string x = \"$-1\"; }";
  // "$$" in the template produces a single literal '$' in the output.
  auto Callback = ReplaceNodeWithTemplate::create("decl",
                                                  "string x = \"$$-${init}\"");
  // Stop here if the template was rejected: dereferencing an Expected in the
  // error state is invalid. The streamed message is only evaluated on
  // failure, and toString() consumes the error so it is not flagged as
  // unchecked.
  ASSERT_TRUE(static_cast<bool>(Callback))
      << llvm::toString(Callback.takeError());
  expectRewritten(Code, Expected,
                  varDecl(hasInitializer(expr().bind("init"))).bind("decl"),
                  **Callback);
}
125 
ExpectStringError(const std::string & Expected,llvm::Error E)126 static void ExpectStringError(const std::string &Expected,
127                               llvm::Error E) {
128   std::string Found;
129   handleAllErrors(std::move(E), [&](const llvm::StringError &SE) {
130       llvm::raw_string_ostream Stream(Found);
131       SE.log(Stream);
132     });
133   EXPECT_EQ(Expected, Found);
134 }
135 
TEST(RefactoringCallbacksTest, TemplateUnterminated) {
  // The "${init" placeholder is never closed with '}', so creation must fail
  // with a message pointing at the unterminated placeholder.
  const char *const Template = "string x = \"$$-${init\"";
  auto Result = ReplaceNodeWithTemplate::create("decl", Template);
  ExpectStringError("Unterminated ${...} in replacement template near ${init\"",
                    Result.takeError());
}
142 
TEST(RefactoringCallbacksTest, TemplateUnknownDollar) {
  // '$' followed by '<' is neither "$$" nor "${...}", so creation must fail.
  const char *const Template = "string x = \"$<";
  auto Result = ReplaceNodeWithTemplate::create("decl", Template);
  ExpectStringError("Invalid $ in replacement template near $<",
                    Result.takeError());
}
149 
150 } // namespace tooling
151 } // end namespace clang
152