xref: /llvm-project/clang/unittests/AST/MatchVerifier.h (revision 7dfdca1961aadc75ca397818bfb9bd32f1879248)
1 //===- unittest/AST/MatchVerifier.h - AST unit test support ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //  Provides MatchVerifier, a base class to implement gtest matchers that
10 //  verify things that can be matched on the AST.
11 //
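//  Also implements matchers based on MatchVerifier:
//  LocationVerifier and RangeVerifier to verify whether a matched node has
//  the expected source location or source range, DumpVerifier to verify that
//  a node's AST dump contains a given substring, and PrintVerifier to verify
//  a node's pretty-printed form.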
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #ifndef LLVM_CLANG_UNITTESTS_AST_MATCHVERIFIER_H
19 #define LLVM_CLANG_UNITTESTS_AST_MATCHVERIFIER_H
20 
21 #include "clang/AST/ASTContext.h"
22 #include "clang/ASTMatchers/ASTMatchFinder.h"
23 #include "clang/ASTMatchers/ASTMatchers.h"
24 #include "clang/Testing/CommandLineArgs.h"
25 #include "clang/Tooling/Tooling.h"
26 #include "gtest/gtest.h"
27 
28 namespace clang {
29 namespace ast_matchers {
30 
31 /// \brief Base class for verifying some property of nodes found by a matcher.
32 template <typename NodeType>
33 class MatchVerifier : public MatchFinder::MatchCallback {
34 public:
35   template <typename MatcherType>
36   testing::AssertionResult match(const std::string &Code,
37                                  const MatcherType &AMatcher) {
38     std::vector<std::string> Args;
39     return match(Code, AMatcher, Args, Lang_CXX03);
40   }
41 
42   template <typename MatcherType>
43   testing::AssertionResult match(const std::string &Code,
44                                  const MatcherType &AMatcher, TestLanguage L) {
45     std::vector<std::string> Args;
46     return match(Code, AMatcher, Args, L);
47   }
48 
49   template <typename MatcherType>
50   testing::AssertionResult
51   match(const std::string &Code, const MatcherType &AMatcher,
52         std::vector<std::string> &Args, TestLanguage L);
53 
54   template <typename MatcherType>
55   testing::AssertionResult match(const Decl *D, const MatcherType &AMatcher);
56 
57 protected:
58   void run(const MatchFinder::MatchResult &Result) override;
59   virtual void verify(const MatchFinder::MatchResult &Result,
60                       const NodeType &Node) {}
61 
62   void setFailure(const Twine &Result) {
63     Verified = false;
64     VerifyResult = Result.str();
65   }
66 
67   void setSuccess() {
68     Verified = true;
69   }
70 
71 private:
72   bool Verified;
73   std::string VerifyResult;
74 };
75 
76 /// \brief Runs a matcher over some code, and returns the result of the
77 /// verifier for the matched node.
78 template <typename NodeType>
79 template <typename MatcherType>
80 testing::AssertionResult
81 MatchVerifier<NodeType>::match(const std::string &Code,
82                                const MatcherType &AMatcher,
83                                std::vector<std::string> &Args, TestLanguage L) {
84   MatchFinder Finder;
85   Finder.addMatcher(AMatcher.bind(""), this);
86   std::unique_ptr<tooling::FrontendActionFactory> Factory(
87       tooling::newFrontendActionFactory(&Finder));
88 
89   StringRef FileName;
90   switch (L) {
91 #define TESTLANGUAGE(lang, version, std_flag, version_index)                   \
92   case Lang_##lang##version:                                                   \
93     Args.push_back("-std=" #std_flag);                                         \
94     FileName = getFilenameForTesting(Lang_##lang##version);                    \
95     break;
96 #include "clang/Testing/TestLanguage.def"
97 
98   case Lang_OpenCL:
99     Args.push_back("-cl-no-stdinc");
100     FileName = "input.cl";
101     break;
102   case Lang_OBJC:
103     Args.push_back("-fobjc-nonfragile-abi");
104     FileName = "input.m";
105     break;
106   case Lang_OBJCXX:
107     FileName = "input.mm";
108     break;
109   }
110 
111   // Default to failure in case callback is never called
112   setFailure("Could not find match");
113   if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName))
114     return testing::AssertionFailure() << "Parsing error";
115   if (!Verified)
116     return testing::AssertionFailure() << VerifyResult;
117   return testing::AssertionSuccess();
118 }
119 
120 /// \brief Runs a matcher over some AST, and returns the result of the
121 /// verifier for the matched node.
122 template <typename NodeType> template <typename MatcherType>
123 testing::AssertionResult MatchVerifier<NodeType>::match(
124     const Decl *D, const MatcherType &AMatcher) {
125   MatchFinder Finder;
126   Finder.addMatcher(AMatcher.bind(""), this);
127 
128   setFailure("Could not find match");
129   Finder.match(*D, D->getASTContext());
130 
131   if (!Verified)
132     return testing::AssertionFailure() << VerifyResult;
133   return testing::AssertionSuccess();
134 }
135 
136 template <typename NodeType>
137 void MatchVerifier<NodeType>::run(const MatchFinder::MatchResult &Result) {
138   const NodeType *Node = Result.Nodes.getNodeAs<NodeType>("");
139   if (!Node) {
140     setFailure("Matched node has wrong type");
141   } else {
142     // Callback has been called, default to success.
143     setSuccess();
144     verify(Result, *Node);
145   }
146 }
147 
148 template <>
149 inline void
150 MatchVerifier<DynTypedNode>::run(const MatchFinder::MatchResult &Result) {
151   BoundNodes::IDToNodeMap M = Result.Nodes.getMap();
152   BoundNodes::IDToNodeMap::const_iterator I = M.find("");
153   if (I == M.end()) {
154     setFailure("Node was not bound");
155   } else {
156     // Callback has been called, default to success.
157     setSuccess();
158     verify(Result, I->second);
159   }
160 }
161 
162 /// \brief Verify whether a node has the correct source location.
163 ///
164 /// By default, Node.getSourceLocation() is checked. This can be changed
165 /// by overriding getLocation().
166 template <typename NodeType>
167 class LocationVerifier : public MatchVerifier<NodeType> {
168 public:
169   void expectLocation(unsigned Line, unsigned Column) {
170     ExpectLine = Line;
171     ExpectColumn = Column;
172   }
173 
174 protected:
175   void verify(const MatchFinder::MatchResult &Result,
176               const NodeType &Node) override {
177     SourceLocation Loc = getLocation(Node);
178     unsigned Line = Result.SourceManager->getSpellingLineNumber(Loc);
179     unsigned Column = Result.SourceManager->getSpellingColumnNumber(Loc);
180     if (Line != ExpectLine || Column != ExpectColumn) {
181       std::string MsgStr;
182       llvm::raw_string_ostream Msg(MsgStr);
183       Msg << "Expected location <" << ExpectLine << ":" << ExpectColumn
184           << ">, found <";
185       Loc.print(Msg, *Result.SourceManager);
186       Msg << '>';
187       this->setFailure(MsgStr);
188     }
189   }
190 
191   virtual SourceLocation getLocation(const NodeType &Node) {
192     return Node.getLocation();
193   }
194 
195 private:
196   unsigned ExpectLine, ExpectColumn;
197 };
198 
199 /// \brief Verify whether a node has the correct source range.
200 ///
201 /// By default, Node.getSourceRange() is checked. This can be changed
202 /// by overriding getRange().
203 template <typename NodeType>
204 class RangeVerifier : public MatchVerifier<NodeType> {
205 public:
206   void expectRange(unsigned BeginLine, unsigned BeginColumn,
207                    unsigned EndLine, unsigned EndColumn) {
208     ExpectBeginLine = BeginLine;
209     ExpectBeginColumn = BeginColumn;
210     ExpectEndLine = EndLine;
211     ExpectEndColumn = EndColumn;
212   }
213 
214 protected:
215   void verify(const MatchFinder::MatchResult &Result,
216               const NodeType &Node) override {
217     SourceRange R = getRange(Node);
218     SourceLocation Begin = R.getBegin();
219     SourceLocation End = R.getEnd();
220     unsigned BeginLine = Result.SourceManager->getSpellingLineNumber(Begin);
221     unsigned BeginColumn = Result.SourceManager->getSpellingColumnNumber(Begin);
222     unsigned EndLine = Result.SourceManager->getSpellingLineNumber(End);
223     unsigned EndColumn = Result.SourceManager->getSpellingColumnNumber(End);
224     if (BeginLine != ExpectBeginLine || BeginColumn != ExpectBeginColumn ||
225         EndLine != ExpectEndLine || EndColumn != ExpectEndColumn) {
226       std::string MsgStr;
227       llvm::raw_string_ostream Msg(MsgStr);
228       Msg << "Expected range <" << ExpectBeginLine << ":" << ExpectBeginColumn
229           << '-' << ExpectEndLine << ":" << ExpectEndColumn << ">, found <";
230       Begin.print(Msg, *Result.SourceManager);
231       Msg << '-';
232       End.print(Msg, *Result.SourceManager);
233       Msg << '>';
234       this->setFailure(MsgStr);
235     }
236   }
237 
238   virtual SourceRange getRange(const NodeType &Node) {
239     return Node.getSourceRange();
240   }
241 
242 private:
243   unsigned ExpectBeginLine, ExpectBeginColumn, ExpectEndLine, ExpectEndColumn;
244 };
245 
246 /// \brief Verify whether a node's dump contains a given substring.
247 class DumpVerifier : public MatchVerifier<DynTypedNode> {
248 public:
249   void expectSubstring(const std::string &Str) {
250     ExpectSubstring = Str;
251   }
252 
253 protected:
254   void verify(const MatchFinder::MatchResult &Result,
255               const DynTypedNode &Node) override {
256     std::string DumpStr;
257     llvm::raw_string_ostream Dump(DumpStr);
258     Node.dump(Dump, *Result.Context);
259 
260     if (DumpStr.find(ExpectSubstring) == std::string::npos) {
261       std::string MsgStr;
262       llvm::raw_string_ostream Msg(MsgStr);
263       Msg << "Expected dump substring <" << ExpectSubstring << ">, found <"
264           << DumpStr << '>';
265       this->setFailure(MsgStr);
266     }
267   }
268 
269 private:
270   std::string ExpectSubstring;
271 };
272 
273 /// \brief Verify whether a node's pretty print matches a given string.
274 class PrintVerifier : public MatchVerifier<DynTypedNode> {
275 public:
276   void expectString(const std::string &Str) {
277     ExpectString = Str;
278   }
279 
280 protected:
281   void verify(const MatchFinder::MatchResult &Result,
282               const DynTypedNode &Node) override {
283     std::string PrintStr;
284     llvm::raw_string_ostream Print(PrintStr);
285     Node.print(Print, Result.Context->getPrintingPolicy());
286 
287     if (PrintStr != ExpectString) {
288       std::string MsgStr;
289       llvm::raw_string_ostream Msg(MsgStr);
290       Msg << "Expected pretty print <" << ExpectString << ">, found <"
291           << PrintStr << '>';
292       this->setFailure(MsgStr);
293     }
294   }
295 
296 private:
297   std::string ExpectString;
298 };
299 
300 } // end namespace ast_matchers
301 } // end namespace clang
302 
303 #endif
304