xref: /llvm-project/clang/unittests/Tooling/RefactoringActionRulesTest.cpp (revision 6ad0788c332bb2043142954d300c49ac3e537f34)
1 //===- unittest/Tooling/RefactoringActionRulesTest.cpp --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "ReplacementTest.h"
10 #include "RewriterTestContext.h"
11 #include "clang/Tooling/Refactoring.h"
12 #include "clang/Tooling/Refactoring/Extract/Extract.h"
13 #include "clang/Tooling/Refactoring/RefactoringAction.h"
14 #include "clang/Tooling/Refactoring/RefactoringDiagnostic.h"
15 #include "clang/Tooling/Refactoring/Rename/SymbolName.h"
16 #include "clang/Tooling/Tooling.h"
17 #include "llvm/Support/Errc.h"
18 #include "gtest/gtest.h"
19 #include <optional>
20 
21 using namespace clang;
22 using namespace tooling;
23 
24 namespace {
25 
26 class RefactoringActionRulesTest : public ::testing::Test {
27 protected:
SetUp()28   void SetUp() override {
29     Context.Sources.setMainFileID(
30         Context.createInMemoryFile("input.cpp", DefaultCode));
31   }
32 
33   RewriterTestContext Context;
34   std::string DefaultCode = std::string(100, 'a');
35 };
36 
37 Expected<AtomicChanges>
createReplacements(const std::unique_ptr<RefactoringActionRule> & Rule,RefactoringRuleContext & Context)38 createReplacements(const std::unique_ptr<RefactoringActionRule> &Rule,
39                    RefactoringRuleContext &Context) {
40   class Consumer final : public RefactoringResultConsumer {
41     void handleError(llvm::Error Err) override { Result = std::move(Err); }
42 
43     void handle(AtomicChanges SourceReplacements) override {
44       Result = std::move(SourceReplacements);
45     }
46     void handle(SymbolOccurrences Occurrences) override {
47       RefactoringResultConsumer::handle(std::move(Occurrences));
48     }
49 
50   public:
51     std::optional<Expected<AtomicChanges>> Result;
52   };
53 
54   Consumer C;
55   Rule->invoke(C, Context);
56   return std::move(*C.Result);
57 }
58 
TEST_F(RefactoringActionRulesTest,MyFirstRefactoringRule)59 TEST_F(RefactoringActionRulesTest, MyFirstRefactoringRule) {
60   class ReplaceAWithB : public SourceChangeRefactoringRule {
61     std::pair<SourceRange, int> Selection;
62 
63   public:
64     ReplaceAWithB(std::pair<SourceRange, int> Selection)
65         : Selection(Selection) {}
66 
67     static Expected<ReplaceAWithB>
68     initiate(RefactoringRuleContext &Cotnext,
69              std::pair<SourceRange, int> Selection) {
70       return ReplaceAWithB(Selection);
71     }
72 
73     Expected<AtomicChanges>
74     createSourceReplacements(RefactoringRuleContext &Context) {
75       const SourceManager &SM = Context.getSources();
76       SourceLocation Loc =
77           Selection.first.getBegin().getLocWithOffset(Selection.second);
78       AtomicChange Change(SM, Loc);
79       llvm::Error E = Change.replace(SM, Loc, 1, "b");
80       if (E)
81         return std::move(E);
82       return AtomicChanges{Change};
83     }
84   };
85 
86   class SelectionRequirement : public SourceRangeSelectionRequirement {
87   public:
88     Expected<std::pair<SourceRange, int>>
89     evaluate(RefactoringRuleContext &Context) const {
90       Expected<SourceRange> R =
91           SourceRangeSelectionRequirement::evaluate(Context);
92       if (!R)
93         return R.takeError();
94       return std::make_pair(*R, 20);
95     }
96   };
97   auto Rule =
98       createRefactoringActionRule<ReplaceAWithB>(SelectionRequirement());
99 
100   // When the requirements are satisfied, the rule's function must be invoked.
101   {
102     RefactoringRuleContext RefContext(Context.Sources);
103     SourceLocation Cursor =
104         Context.Sources.getLocForStartOfFile(Context.Sources.getMainFileID())
105             .getLocWithOffset(10);
106     RefContext.setSelectionRange({Cursor, Cursor});
107 
108     Expected<AtomicChanges> ErrorOrResult =
109         createReplacements(Rule, RefContext);
110     ASSERT_FALSE(!ErrorOrResult);
111     AtomicChanges Result = std::move(*ErrorOrResult);
112     ASSERT_EQ(Result.size(), 1u);
113     std::string YAMLString =
114         const_cast<AtomicChange &>(Result[0]).toYAMLString();
115 
116     ASSERT_STREQ("---\n"
117                  "Key:             'input.cpp:30'\n"
118                  "FilePath:        input.cpp\n"
119                  "Error:           ''\n"
120                  "InsertedHeaders: []\n"
121                  "RemovedHeaders:  []\n"
122                  "Replacements:\n"
123                  "  - FilePath:        input.cpp\n"
124                  "    Offset:          30\n"
125                  "    Length:          1\n"
126                  "    ReplacementText: b\n"
127                  "...\n",
128                  YAMLString.c_str());
129   }
130 
131   // When one of the requirements is not satisfied, invoke should return a
132   // valid error.
133   {
134     RefactoringRuleContext RefContext(Context.Sources);
135     Expected<AtomicChanges> ErrorOrResult =
136         createReplacements(Rule, RefContext);
137 
138     ASSERT_TRUE(!ErrorOrResult);
139     unsigned DiagID;
140     llvm::handleAllErrors(ErrorOrResult.takeError(),
141                           [&](DiagnosticError &Error) {
142                             DiagID = Error.getDiagnostic().second.getDiagID();
143                           });
144     EXPECT_EQ(DiagID, diag::err_refactor_no_selection);
145   }
146 }
147 
TEST_F(RefactoringActionRulesTest, ReturnError) {
  // A rule whose createSourceReplacements always fails with a StringError
  // carrying the message "Error".
  class ErrorRule : public SourceChangeRefactoringRule {
  public:
    static Expected<ErrorRule> initiate(RefactoringRuleContext &,
                                        SourceRange R) {
      return ErrorRule(R);
    }

    // The selected range is not used by this rule.
    explicit ErrorRule(SourceRange) {}

    Expected<AtomicChanges>
    createSourceReplacements(RefactoringRuleContext &) override {
      return llvm::make_error<llvm::StringError>(
          "Error", llvm::make_error_code(llvm::errc::invalid_argument));
    }
  };

  auto Rule =
      createRefactoringActionRule<ErrorRule>(SourceRangeSelectionRequirement());
  RefactoringRuleContext RefContext(Context.Sources);
  SourceLocation Cursor =
      Context.Sources.getLocForStartOfFile(Context.Sources.getMainFileID());
  RefContext.setSelectionRange({Cursor, Cursor});
  Expected<AtomicChanges> Result = createReplacements(Rule, RefContext);

  ASSERT_TRUE(!Result);
  std::string Message;
  llvm::handleAllErrors(Result.takeError(), [&](llvm::StringError &Error) {
    Message = Error.getMessage();
  });
  EXPECT_EQ(Message, "Error");
}
178 
179 std::optional<SymbolOccurrences>
findOccurrences(RefactoringActionRule & Rule,RefactoringRuleContext & Context)180 findOccurrences(RefactoringActionRule &Rule, RefactoringRuleContext &Context) {
181   class Consumer final : public RefactoringResultConsumer {
182     void handleError(llvm::Error) override {}
183     void handle(SymbolOccurrences Occurrences) override {
184       Result = std::move(Occurrences);
185     }
186     void handle(AtomicChanges Changes) override {
187       RefactoringResultConsumer::handle(std::move(Changes));
188     }
189 
190   public:
191     std::optional<SymbolOccurrences> Result;
192   };
193 
194   Consumer C;
195   Rule.invoke(C, Context);
196   return std::move(C.Result);
197 }
198 
TEST_F(RefactoringActionRulesTest, ReturnSymbolOccurrences) {
  // A rule that reports a single "test" occurrence at the start of the
  // selection.
  class FindOccurrences : public FindSymbolOccurrencesRefactoringRule {
    SourceRange Selection;

  public:
    explicit FindOccurrences(SourceRange Selection) : Selection(Selection) {}

    static Expected<FindOccurrences> initiate(RefactoringRuleContext &,
                                              SourceRange Selection) {
      return FindOccurrences(Selection);
    }

    Expected<SymbolOccurrences>
    findSymbolOccurrences(RefactoringRuleContext &) override {
      SymbolOccurrences Occurrences;
      Occurrences.push_back(SymbolOccurrence(SymbolName("test"),
                                             SymbolOccurrence::MatchingSymbol,
                                             Selection.getBegin()));
      return std::move(Occurrences);
    }
  };

  auto Rule = createRefactoringActionRule<FindOccurrences>(
      SourceRangeSelectionRequirement());

  RefactoringRuleContext RefContext(Context.Sources);
  SourceLocation Cursor =
      Context.Sources.getLocForStartOfFile(Context.Sources.getMainFileID());
  RefContext.setSelectionRange({Cursor, Cursor});
  std::optional<SymbolOccurrences> Result = findOccurrences(*Rule, RefContext);

  ASSERT_FALSE(!Result);
  SymbolOccurrences Occurrences = std::move(*Result);
  // Size checks must be fatal: the [0] accesses below would otherwise run
  // out of bounds when a size check fails.
  ASSERT_EQ(Occurrences.size(), 1u);
  EXPECT_EQ(Occurrences[0].getKind(), SymbolOccurrence::MatchingSymbol);
  ASSERT_EQ(Occurrences[0].getNameRanges().size(), 1u);
  EXPECT_EQ(Occurrences[0].getNameRanges()[0],
            SourceRange(Cursor, Cursor.getLocWithOffset(strlen("test"))));
}
238 
TEST_F(RefactoringActionRulesTest, EditorCommandBinding) {
  // Checks the name, title and description reported by
  // ExtractFunction::describe().
  const RefactoringDescriptor &Descriptor = ExtractFunction::describe();
  EXPECT_EQ(Descriptor.Name, "extract-function");
  EXPECT_EQ(Descriptor.Title, "Extract Function");
  EXPECT_EQ(Descriptor.Description,
            "(WIP action; use with caution!) Extracts code into a new function");
}
247 
248 } // end anonymous namespace
249