xref: /llvm-project/clang/unittests/AST/RecursiveASTVisitorTest.cpp (revision 805f7a4fa4ce97277c3b73d0c204fc3aa4b072e1)
1 //===- unittest/AST/RecursiveASTVisitorTest.cpp ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "clang/AST/RecursiveASTVisitor.h"
10 #include "clang/AST/ASTConsumer.h"
11 #include "clang/AST/ASTContext.h"
12 #include "clang/AST/Attr.h"
13 #include "clang/AST/Decl.h"
14 #include "clang/AST/TypeLoc.h"
15 #include "clang/Frontend/FrontendAction.h"
16 #include "clang/Tooling/Tooling.h"
17 #include "llvm/ADT/FunctionExtras.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "gmock/gmock.h"
20 #include "gtest/gtest.h"
21 #include <cassert>
22 
23 using namespace clang;
24 using ::testing::ElementsAre;
25 
26 namespace {
27 class ProcessASTAction : public clang::ASTFrontendAction {
28 public:
ProcessASTAction(llvm::unique_function<void (clang::ASTContext &)> Process)29   ProcessASTAction(llvm::unique_function<void(clang::ASTContext &)> Process)
30       : Process(std::move(Process)) {
31     assert(this->Process);
32   }
33 
CreateASTConsumer(CompilerInstance & CI,StringRef InFile)34   std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
35                                                  StringRef InFile) {
36     class Consumer : public ASTConsumer {
37     public:
38       Consumer(llvm::function_ref<void(ASTContext &CTx)> Process)
39           : Process(Process) {}
40 
41       void HandleTranslationUnit(ASTContext &Ctx) override { Process(Ctx); }
42 
43     private:
44       llvm::function_ref<void(ASTContext &CTx)> Process;
45     };
46 
47     return std::make_unique<Consumer>(Process);
48   }
49 
50 private:
51   llvm::unique_function<void(clang::ASTContext &)> Process;
52 };
53 
/// Markers recorded by the test visitor. Every traversal hook of interest
/// contributes one "Start" marker on entry and one "End" marker on exit.
enum class VisitEvent {
  // Function declarations.
  StartTraverseFunction,
  EndTraverseFunction,

  // Attributes.
  StartTraverseAttr,
  EndTraverseAttr,

  // Enum declarations.
  StartTraverseEnum,
  EndTraverseEnum,

  // Typedef type locations.
  StartTraverseTypedefType,
  EndTraverseTypedefType,

  // Objective-C interface declarations.
  StartTraverseObjCInterface,
  EndTraverseObjCInterface,

  // Objective-C protocol declarations.
  StartTraverseObjCProtocol,
  EndTraverseObjCProtocol,

  // Objective-C protocol locations.
  StartTraverseObjCProtocolLoc,
  EndTraverseObjCProtocolLoc,
};
70 
71 class CollectInterestingEvents
72     : public RecursiveASTVisitor<CollectInterestingEvents> {
73 public:
TraverseFunctionDecl(FunctionDecl * D)74   bool TraverseFunctionDecl(FunctionDecl *D) {
75     Events.push_back(VisitEvent::StartTraverseFunction);
76     bool Ret = RecursiveASTVisitor::TraverseFunctionDecl(D);
77     Events.push_back(VisitEvent::EndTraverseFunction);
78 
79     return Ret;
80   }
81 
TraverseAttr(Attr * A)82   bool TraverseAttr(Attr *A) {
83     Events.push_back(VisitEvent::StartTraverseAttr);
84     bool Ret = RecursiveASTVisitor::TraverseAttr(A);
85     Events.push_back(VisitEvent::EndTraverseAttr);
86 
87     return Ret;
88   }
89 
TraverseEnumDecl(EnumDecl * D)90   bool TraverseEnumDecl(EnumDecl *D) {
91     Events.push_back(VisitEvent::StartTraverseEnum);
92     bool Ret = RecursiveASTVisitor::TraverseEnumDecl(D);
93     Events.push_back(VisitEvent::EndTraverseEnum);
94 
95     return Ret;
96   }
97 
TraverseTypedefTypeLoc(TypedefTypeLoc TL)98   bool TraverseTypedefTypeLoc(TypedefTypeLoc TL) {
99     Events.push_back(VisitEvent::StartTraverseTypedefType);
100     bool Ret = RecursiveASTVisitor::TraverseTypedefTypeLoc(TL);
101     Events.push_back(VisitEvent::EndTraverseTypedefType);
102 
103     return Ret;
104   }
105 
TraverseObjCInterfaceDecl(ObjCInterfaceDecl * ID)106   bool TraverseObjCInterfaceDecl(ObjCInterfaceDecl *ID) {
107     Events.push_back(VisitEvent::StartTraverseObjCInterface);
108     bool Ret = RecursiveASTVisitor::TraverseObjCInterfaceDecl(ID);
109     Events.push_back(VisitEvent::EndTraverseObjCInterface);
110 
111     return Ret;
112   }
113 
TraverseObjCProtocolDecl(ObjCProtocolDecl * PD)114   bool TraverseObjCProtocolDecl(ObjCProtocolDecl *PD) {
115     Events.push_back(VisitEvent::StartTraverseObjCProtocol);
116     bool Ret = RecursiveASTVisitor::TraverseObjCProtocolDecl(PD);
117     Events.push_back(VisitEvent::EndTraverseObjCProtocol);
118 
119     return Ret;
120   }
121 
TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLoc)122   bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLoc) {
123     Events.push_back(VisitEvent::StartTraverseObjCProtocolLoc);
124     bool Ret = RecursiveASTVisitor::TraverseObjCProtocolLoc(ProtocolLoc);
125     Events.push_back(VisitEvent::EndTraverseObjCProtocolLoc);
126 
127     return Ret;
128   }
129 
takeEvents()130   std::vector<VisitEvent> takeEvents() && { return std::move(Events); }
131 
132 private:
133   std::vector<VisitEvent> Events;
134 };
135 
collectEvents(llvm::StringRef Code,const Twine & FileName="input.cc")136 std::vector<VisitEvent> collectEvents(llvm::StringRef Code,
137                                       const Twine &FileName = "input.cc") {
138   CollectInterestingEvents Visitor;
139   clang::tooling::runToolOnCode(
140       std::make_unique<ProcessASTAction>(
141           [&](clang::ASTContext &Ctx) { Visitor.TraverseAST(Ctx); }),
142       Code, FileName);
143   return std::move(Visitor).takeEvents();
144 }
145 } // namespace
146 
TEST(RecursiveASTVisitorTest, AttributesInsideDecls) {
  // The attribute events must be nested between the start and end events of
  // the function declaration that carries the attribute.
  llvm::StringRef Code = R"cpp(
__attribute__((annotate("something"))) int foo() { return 10; }
  )cpp";

  const std::vector<VisitEvent> Events = collectEvents(Code);

  EXPECT_THAT(Events, ElementsAre(VisitEvent::StartTraverseFunction,
                                  VisitEvent::StartTraverseAttr,
                                  VisitEvent::EndTraverseAttr,
                                  VisitEvent::EndTraverseFunction));
}
159 
TEST(RecursiveASTVisitorTest, EnumDeclWithBase) {
  // The typedef used as the enum's underlying type must be traversed while
  // the enum declaration itself is being traversed.
  llvm::StringRef Code = R"cpp(
  typedef int Foo;
  enum Bar : Foo;
  )cpp";

  const std::vector<VisitEvent> Events = collectEvents(Code);

  EXPECT_THAT(Events, ElementsAre(VisitEvent::StartTraverseEnum,
                                  VisitEvent::StartTraverseTypedefType,
                                  VisitEvent::EndTraverseTypedefType,
                                  VisitEvent::EndTraverseEnum));
}
173 
TEST(RecursiveASTVisitorTest, InterfaceDeclWithProtocols) {
  // Both protocol declarations are traversed first; then the interface is
  // traversed, with one protocol-location visit per protocol it lists.
  llvm::StringRef Code = R"cpp(
  @protocol Foo
  @end
  @protocol Bar
  @end

  @interface SomeObject <Foo, Bar>
  @end
  )cpp";

  // The ".m" file name is what differs from the default "input.cc".
  const std::vector<VisitEvent> Events = collectEvents(Code, "input.m");

  EXPECT_THAT(Events, ElementsAre(VisitEvent::StartTraverseObjCProtocol,
                                  VisitEvent::EndTraverseObjCProtocol,
                                  VisitEvent::StartTraverseObjCProtocol,
                                  VisitEvent::EndTraverseObjCProtocol,
                                  VisitEvent::StartTraverseObjCInterface,
                                  VisitEvent::StartTraverseObjCProtocolLoc,
                                  VisitEvent::EndTraverseObjCProtocolLoc,
                                  VisitEvent::StartTraverseObjCProtocolLoc,
                                  VisitEvent::EndTraverseObjCProtocolLoc,
                                  VisitEvent::EndTraverseObjCInterface));
}
198