xref: /llvm-project/clang/unittests/Tooling/TestVisitor.h (revision 4e600751d2f7e8e7b85a71b7128b68444bdde91b)
1 //===--- TestVisitor.h ------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// \brief Defines utility templates for RecursiveASTVisitor related tests.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_CLANG_UNITTESTS_TOOLING_TESTVISITOR_H
15 #define LLVM_CLANG_UNITTESTS_TOOLING_TESTVISITOR_H
16 
17 #include "clang/AST/ASTConsumer.h"
18 #include "clang/AST/ASTContext.h"
19 #include "clang/AST/DynamicRecursiveASTVisitor.h"
20 #include "clang/Frontend/CompilerInstance.h"
21 #include "clang/Frontend/FrontendAction.h"
22 #include "clang/Tooling/Tooling.h"
23 #include "gtest/gtest.h"
24 #include <vector>
25 
26 namespace clang {
27 namespace detail {
28 // Use 'TestVisitor' or include 'CRTPTestVisitor.h' and use 'CRTPTestVisitor'
29 // instead of using this directly.
30 class TestVisitorHelper {
31 public:
32   enum Language {
33     Lang_C,
34     Lang_CXX98,
35     Lang_CXX11,
36     Lang_CXX14,
37     Lang_CXX17,
38     Lang_CXX2a,
39     Lang_OBJC,
40     Lang_OBJCXX11,
41     Lang_CXX = Lang_CXX98
42   };
43 
44   /// \brief Runs the current AST visitor over the given code.
45   bool runOver(StringRef Code, Language L = Lang_CXX) {
46     std::vector<std::string> Args;
47     switch (L) {
48     case Lang_C:
49       Args.push_back("-x");
50       Args.push_back("c");
51       break;
52     case Lang_CXX98:
53       Args.push_back("-std=c++98");
54       break;
55     case Lang_CXX11:
56       Args.push_back("-std=c++11");
57       break;
58     case Lang_CXX14:
59       Args.push_back("-std=c++14");
60       break;
61     case Lang_CXX17:
62       Args.push_back("-std=c++17");
63       break;
64     case Lang_CXX2a:
65       Args.push_back("-std=c++2a");
66       break;
67     case Lang_OBJC:
68       Args.push_back("-ObjC");
69       Args.push_back("-fobjc-runtime=macosx-10.12.0");
70       break;
71     case Lang_OBJCXX11:
72       Args.push_back("-ObjC++");
73       Args.push_back("-std=c++11");
74       Args.push_back("-fblocks");
75       break;
76     }
77     return tooling::runToolOnCodeWithArgs(CreateTestAction(), Code, Args);
78   }
79 
80 protected:
81   TestVisitorHelper() = default;
82   virtual ~TestVisitorHelper() = default;
83   virtual void InvokeTraverseDecl(TranslationUnitDecl *D) = 0;
84 
85   virtual std::unique_ptr<ASTFrontendAction> CreateTestAction() {
86     return std::make_unique<TestAction>(this);
87   }
88 
89   class FindConsumer : public ASTConsumer {
90   public:
91     FindConsumer(TestVisitorHelper *Visitor) : Visitor(Visitor) {}
92 
93     void HandleTranslationUnit(clang::ASTContext &Context) override {
94       Visitor->Context = &Context;
95       Visitor->InvokeTraverseDecl(Context.getTranslationUnitDecl());
96     }
97 
98   private:
99     TestVisitorHelper *Visitor;
100   };
101 
102   class TestAction : public ASTFrontendAction {
103   public:
104     TestAction(TestVisitorHelper *Visitor) : Visitor(Visitor) {}
105 
106     std::unique_ptr<clang::ASTConsumer>
107     CreateASTConsumer(CompilerInstance &, llvm::StringRef dummy) override {
108       /// TestConsumer will be deleted by the framework calling us.
109       return std::make_unique<FindConsumer>(Visitor);
110     }
111 
112   protected:
113     TestVisitorHelper *Visitor;
114   };
115 
116   ASTContext *Context;
117 };
118 
119 class ExpectedLocationVisitorHelper {
120 public:
121   /// \brief Expect 'Match' *not* to occur at the given 'Line' and 'Column'.
122   ///
123   /// Any number of matches can be disallowed.
124   void DisallowMatch(Twine Match, unsigned Line, unsigned Column) {
125     DisallowedMatches.push_back(MatchCandidate(Match, Line, Column));
126   }
127 
128   /// \brief Expect 'Match' to occur at the given 'Line' and 'Column'.
129   ///
130   /// Any number of expected matches can be set by calling this repeatedly.
131   /// Each is expected to be matched 'Times' number of times. (This is useful in
132   /// cases in which different AST nodes can match at the same source code
133   /// location.)
134   void ExpectMatch(Twine Match, unsigned Line, unsigned Column,
135                    unsigned Times = 1) {
136     ExpectedMatches.push_back(ExpectedMatch(Match, Line, Column, Times));
137   }
138 
139   /// \brief Checks that all expected matches have been found.
140   virtual ~ExpectedLocationVisitorHelper() {
141     // FIXME: Range-based for loop.
142     for (std::vector<ExpectedMatch>::const_iterator
143              It = ExpectedMatches.begin(),
144              End = ExpectedMatches.end();
145          It != End; ++It) {
146       It->ExpectFound();
147     }
148   }
149 
150 protected:
151   virtual ASTContext *getASTContext() = 0;
152 
153   /// \brief Checks an actual match against expected and disallowed matches.
154   ///
155   /// Implementations are required to call this with appropriate values
156   /// for 'Name' during visitation.
157   void Match(StringRef Name, SourceLocation Location) {
158     const FullSourceLoc FullLocation = getASTContext()->getFullLoc(Location);
159 
160     // FIXME: Range-based for loop.
161     for (std::vector<MatchCandidate>::const_iterator
162              It = DisallowedMatches.begin(),
163              End = DisallowedMatches.end();
164          It != End; ++It) {
165       EXPECT_FALSE(It->Matches(Name, FullLocation))
166           << "Matched disallowed " << *It;
167     }
168 
169     // FIXME: Range-based for loop.
170     for (std::vector<ExpectedMatch>::iterator It = ExpectedMatches.begin(),
171                                               End = ExpectedMatches.end();
172          It != End; ++It) {
173       It->UpdateFor(Name, FullLocation, getASTContext()->getSourceManager());
174     }
175   }
176 
177 private:
178   struct MatchCandidate {
179     std::string ExpectedName;
180     unsigned LineNumber;
181     unsigned ColumnNumber;
182 
183     MatchCandidate(Twine Name, unsigned LineNumber, unsigned ColumnNumber)
184       : ExpectedName(Name.str()), LineNumber(LineNumber),
185         ColumnNumber(ColumnNumber) {
186     }
187 
188     bool Matches(StringRef Name, FullSourceLoc const &Location) const {
189       return MatchesName(Name) && MatchesLocation(Location);
190     }
191 
192     bool PartiallyMatches(StringRef Name, FullSourceLoc const &Location) const {
193       return MatchesName(Name) || MatchesLocation(Location);
194     }
195 
196     bool MatchesName(StringRef Name) const {
197       return Name == ExpectedName;
198     }
199 
200     bool MatchesLocation(FullSourceLoc const &Location) const {
201       return Location.isValid() &&
202           Location.getSpellingLineNumber() == LineNumber &&
203           Location.getSpellingColumnNumber() == ColumnNumber;
204     }
205 
206     friend std::ostream &operator<<(std::ostream &Stream,
207                                     MatchCandidate const &Match) {
208       return Stream << Match.ExpectedName
209                     << " at " << Match.LineNumber << ":" << Match.ColumnNumber;
210     }
211   };
212 
213   struct ExpectedMatch {
214     ExpectedMatch(Twine Name, unsigned LineNumber, unsigned ColumnNumber,
215                   unsigned Times)
216         : Candidate(Name, LineNumber, ColumnNumber), TimesExpected(Times),
217           TimesSeen(0) {}
218 
219     void UpdateFor(StringRef Name, FullSourceLoc Location, SourceManager &SM) {
220       if (Candidate.Matches(Name, Location)) {
221         EXPECT_LT(TimesSeen, TimesExpected);
222         ++TimesSeen;
223       } else if (TimesSeen < TimesExpected &&
224                  Candidate.PartiallyMatches(Name, Location)) {
225         llvm::raw_string_ostream Stream(PartialMatches);
226         Stream << ", partial match: \"" << Name << "\" at ";
227         Location.print(Stream, SM);
228       }
229     }
230 
231     void ExpectFound() const {
232       EXPECT_EQ(TimesExpected, TimesSeen)
233           << "Expected \"" << Candidate.ExpectedName
234           << "\" at " << Candidate.LineNumber
235           << ":" << Candidate.ColumnNumber << PartialMatches;
236     }
237 
238     MatchCandidate Candidate;
239     std::string PartialMatches;
240     unsigned TimesExpected;
241     unsigned TimesSeen;
242   };
243 
244   std::vector<MatchCandidate> DisallowedMatches;
245   std::vector<ExpectedMatch> ExpectedMatches;
246 };
247 } // namespace detail
248 
249 /// \brief Base class for simple (Dynamic)RecursiveASTVisitor based tests.
250 ///
251 /// This is a drop-in replacement for DynamicRecursiveASTVisitor itself, with
252 /// the additional capability of running it over a snippet of code.
253 ///
254 /// Visits template instantiations and implicit code by default.
255 ///
256 /// For post-order traversal etc. use CTRPTestVisitor from
257 /// CTRPTestVisitor.h instead.
258 class TestVisitor : public DynamicRecursiveASTVisitor,
259                     public detail::TestVisitorHelper {
260 public:
261   TestVisitor() {
262     ShouldVisitTemplateInstantiations = true;
263     ShouldVisitImplicitCode = true;
264   }
265 
266   void InvokeTraverseDecl(TranslationUnitDecl *D) override { TraverseDecl(D); }
267 };
268 
269 /// \brief A RecursiveASTVisitor to check that certain matches are (or are
270 /// not) observed during visitation.
271 ///
272 /// This is a RecursiveASTVisitor for testing the RecursiveASTVisitor itself,
273 /// and allows simple creation of test visitors running matches on only a small
274 /// subset of the Visit* methods.
275 ///
276 /// For post-order traversal etc. use CTRPExpectedLocationVisitor from
277 /// CTRPTestVisitor.h instead.
278 class ExpectedLocationVisitor : public TestVisitor,
279                                 public detail::ExpectedLocationVisitorHelper {
280   ASTContext *getASTContext() override { return Context; }
281 };
282 } // namespace clang
283 
284 #endif
285