1 //===- unittest/AST/MatchVerifier.h - AST unit test support ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Provides MatchVerifier, a base class to implement gtest matchers that 10 // verify things that can be matched on the AST. 11 // 12 // Also implements matchers based on MatchVerifier: 13 // LocationVerifier and RangeVerifier to verify whether a matched node has 14 // the expected source location or source range. 15 // 16 //===----------------------------------------------------------------------===// 17 18 #ifndef LLVM_CLANG_UNITTESTS_AST_MATCHVERIFIER_H 19 #define LLVM_CLANG_UNITTESTS_AST_MATCHVERIFIER_H 20 21 #include "clang/AST/ASTContext.h" 22 #include "clang/ASTMatchers/ASTMatchFinder.h" 23 #include "clang/ASTMatchers/ASTMatchers.h" 24 #include "clang/Testing/CommandLineArgs.h" 25 #include "clang/Tooling/Tooling.h" 26 #include "gtest/gtest.h" 27 28 namespace clang { 29 namespace ast_matchers { 30 31 /// \brief Base class for verifying some property of nodes found by a matcher. 32 template <typename NodeType> 33 class MatchVerifier : public MatchFinder::MatchCallback { 34 public: 35 template <typename MatcherType> 36 testing::AssertionResult match(const std::string &Code, 37 const MatcherType &AMatcher) { 38 std::vector<std::string> Args; 39 return match(Code, AMatcher, Args, Lang_CXX03); 40 } 41 42 template <typename MatcherType> 43 testing::AssertionResult match(const std::string &Code, 44 const MatcherType &AMatcher, TestLanguage L) { 45 std::vector<std::string> Args; 46 return match(Code, AMatcher, Args, L); 47 } 48 49 template <typename MatcherType> 50 testing::AssertionResult 51 match(const std::string &Code, const MatcherType &AMatcher, 52 std::vector<std::string> &Args, TestLanguage L); 53 54 template <typename MatcherType> 55 testing::AssertionResult match(const Decl *D, const MatcherType &AMatcher); 56 57 protected: 58 void run(const MatchFinder::MatchResult &Result) override; 59 virtual void verify(const MatchFinder::MatchResult &Result, 60 const NodeType &Node) {} 61 62 void setFailure(const Twine &Result) { 63 Verified = false; 64 VerifyResult = Result.str(); 65 } 66 67 void setSuccess() { 68 Verified = true; 69 } 70 71 private: 72 bool Verified; 73 std::string VerifyResult; 74 }; 75 76 /// \brief Runs a matcher over some code, and returns the result of the 77 /// verifier for the matched node. 78 template <typename NodeType> 79 template <typename MatcherType> 80 testing::AssertionResult 81 MatchVerifier<NodeType>::match(const std::string &Code, 82 const MatcherType &AMatcher, 83 std::vector<std::string> &Args, TestLanguage L) { 84 MatchFinder Finder; 85 Finder.addMatcher(AMatcher.bind(""), this); 86 std::unique_ptr<tooling::FrontendActionFactory> Factory( 87 tooling::newFrontendActionFactory(&Finder)); 88 89 StringRef FileName; 90 switch (L) { 91 #define TESTLANGUAGE(lang, version, std_flag, version_index) \ 92 case Lang_##lang##version: \ 93 Args.push_back("-std=" #std_flag); \ 94 FileName = getFilenameForTesting(Lang_##lang##version); \ 95 break; 96 #include "clang/Testing/TestLanguage.def" 97 98 case Lang_OpenCL: 99 Args.push_back("-cl-no-stdinc"); 100 FileName = "input.cl"; 101 break; 102 case Lang_OBJC: 103 Args.push_back("-fobjc-nonfragile-abi"); 104 FileName = "input.m"; 105 break; 106 case Lang_OBJCXX: 107 FileName = "input.mm"; 108 break; 109 } 110 111 // Default to failure in case callback is never called 112 setFailure("Could not find match"); 113 if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName)) 114 return testing::AssertionFailure() << "Parsing error"; 115 if (!Verified) 116 return testing::AssertionFailure() << VerifyResult; 117 return testing::AssertionSuccess(); 118 } 119 120 /// \brief Runs a matcher over some AST, and returns the result of the 121 /// verifier for the matched node. 122 template <typename NodeType> template <typename MatcherType> 123 testing::AssertionResult MatchVerifier<NodeType>::match( 124 const Decl *D, const MatcherType &AMatcher) { 125 MatchFinder Finder; 126 Finder.addMatcher(AMatcher.bind(""), this); 127 128 setFailure("Could not find match"); 129 Finder.match(*D, D->getASTContext()); 130 131 if (!Verified) 132 return testing::AssertionFailure() << VerifyResult; 133 return testing::AssertionSuccess(); 134 } 135 136 template <typename NodeType> 137 void MatchVerifier<NodeType>::run(const MatchFinder::MatchResult &Result) { 138 const NodeType *Node = Result.Nodes.getNodeAs<NodeType>(""); 139 if (!Node) { 140 setFailure("Matched node has wrong type"); 141 } else { 142 // Callback has been called, default to success. 143 setSuccess(); 144 verify(Result, *Node); 145 } 146 } 147 148 template <> 149 inline void 150 MatchVerifier<DynTypedNode>::run(const MatchFinder::MatchResult &Result) { 151 BoundNodes::IDToNodeMap M = Result.Nodes.getMap(); 152 BoundNodes::IDToNodeMap::const_iterator I = M.find(""); 153 if (I == M.end()) { 154 setFailure("Node was not bound"); 155 } else { 156 // Callback has been called, default to success. 157 setSuccess(); 158 verify(Result, I->second); 159 } 160 } 161 162 /// \brief Verify whether a node has the correct source location. 163 /// 164 /// By default, Node.getSourceLocation() is checked. This can be changed 165 /// by overriding getLocation(). 166 template <typename NodeType> 167 class LocationVerifier : public MatchVerifier<NodeType> { 168 public: 169 void expectLocation(unsigned Line, unsigned Column) { 170 ExpectLine = Line; 171 ExpectColumn = Column; 172 } 173 174 protected: 175 void verify(const MatchFinder::MatchResult &Result, 176 const NodeType &Node) override { 177 SourceLocation Loc = getLocation(Node); 178 unsigned Line = Result.SourceManager->getSpellingLineNumber(Loc); 179 unsigned Column = Result.SourceManager->getSpellingColumnNumber(Loc); 180 if (Line != ExpectLine || Column != ExpectColumn) { 181 std::string MsgStr; 182 llvm::raw_string_ostream Msg(MsgStr); 183 Msg << "Expected location <" << ExpectLine << ":" << ExpectColumn 184 << ">, found <"; 185 Loc.print(Msg, *Result.SourceManager); 186 Msg << '>'; 187 this->setFailure(MsgStr); 188 } 189 } 190 191 virtual SourceLocation getLocation(const NodeType &Node) { 192 return Node.getLocation(); 193 } 194 195 private: 196 unsigned ExpectLine, ExpectColumn; 197 }; 198 199 /// \brief Verify whether a node has the correct source range. 200 /// 201 /// By default, Node.getSourceRange() is checked. This can be changed 202 /// by overriding getRange(). 203 template <typename NodeType> 204 class RangeVerifier : public MatchVerifier<NodeType> { 205 public: 206 void expectRange(unsigned BeginLine, unsigned BeginColumn, 207 unsigned EndLine, unsigned EndColumn) { 208 ExpectBeginLine = BeginLine; 209 ExpectBeginColumn = BeginColumn; 210 ExpectEndLine = EndLine; 211 ExpectEndColumn = EndColumn; 212 } 213 214 protected: 215 void verify(const MatchFinder::MatchResult &Result, 216 const NodeType &Node) override { 217 SourceRange R = getRange(Node); 218 SourceLocation Begin = R.getBegin(); 219 SourceLocation End = R.getEnd(); 220 unsigned BeginLine = Result.SourceManager->getSpellingLineNumber(Begin); 221 unsigned BeginColumn = Result.SourceManager->getSpellingColumnNumber(Begin); 222 unsigned EndLine = Result.SourceManager->getSpellingLineNumber(End); 223 unsigned EndColumn = Result.SourceManager->getSpellingColumnNumber(End); 224 if (BeginLine != ExpectBeginLine || BeginColumn != ExpectBeginColumn || 225 EndLine != ExpectEndLine || EndColumn != ExpectEndColumn) { 226 std::string MsgStr; 227 llvm::raw_string_ostream Msg(MsgStr); 228 Msg << "Expected range <" << ExpectBeginLine << ":" << ExpectBeginColumn 229 << '-' << ExpectEndLine << ":" << ExpectEndColumn << ">, found <"; 230 Begin.print(Msg, *Result.SourceManager); 231 Msg << '-'; 232 End.print(Msg, *Result.SourceManager); 233 Msg << '>'; 234 this->setFailure(MsgStr); 235 } 236 } 237 238 virtual SourceRange getRange(const NodeType &Node) { 239 return Node.getSourceRange(); 240 } 241 242 private: 243 unsigned ExpectBeginLine, ExpectBeginColumn, ExpectEndLine, ExpectEndColumn; 244 }; 245 246 /// \brief Verify whether a node's dump contains a given substring. 247 class DumpVerifier : public MatchVerifier<DynTypedNode> { 248 public: 249 void expectSubstring(const std::string &Str) { 250 ExpectSubstring = Str; 251 } 252 253 protected: 254 void verify(const MatchFinder::MatchResult &Result, 255 const DynTypedNode &Node) override { 256 std::string DumpStr; 257 llvm::raw_string_ostream Dump(DumpStr); 258 Node.dump(Dump, *Result.Context); 259 260 if (DumpStr.find(ExpectSubstring) == std::string::npos) { 261 std::string MsgStr; 262 llvm::raw_string_ostream Msg(MsgStr); 263 Msg << "Expected dump substring <" << ExpectSubstring << ">, found <" 264 << DumpStr << '>'; 265 this->setFailure(MsgStr); 266 } 267 } 268 269 private: 270 std::string ExpectSubstring; 271 }; 272 273 /// \brief Verify whether a node's pretty print matches a given string. 274 class PrintVerifier : public MatchVerifier<DynTypedNode> { 275 public: 276 void expectString(const std::string &Str) { 277 ExpectString = Str; 278 } 279 280 protected: 281 void verify(const MatchFinder::MatchResult &Result, 282 const DynTypedNode &Node) override { 283 std::string PrintStr; 284 llvm::raw_string_ostream Print(PrintStr); 285 Node.print(Print, Result.Context->getPrintingPolicy()); 286 287 if (PrintStr != ExpectString) { 288 std::string MsgStr; 289 llvm::raw_string_ostream Msg(MsgStr); 290 Msg << "Expected pretty print <" << ExpectString << ">, found <" 291 << PrintStr << '>'; 292 this->setFailure(MsgStr); 293 } 294 } 295 296 private: 297 std::string ExpectString; 298 }; 299 300 } // end namespace ast_matchers 301 } // end namespace clang 302 303 #endif 304