1 //===--- unittests/Tooling/RecursiveASTVisitorTests/CallbacksCommon.h -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "CRTPTestVisitor.h" 10 11 using namespace clang; 12 13 namespace { 14 15 enum class ShouldTraversePostOrder : bool { 16 No = false, 17 Yes = true, 18 }; 19 20 /// Base class for tests for RecursiveASTVisitor tests that validate the 21 /// sequence of calls to user-defined callbacks like Traverse*(), WalkUp*(), 22 /// Visit*(). 23 template <typename Derived> 24 class RecordingVisitorBase : public CRTPTestVisitor<Derived> { 25 ShouldTraversePostOrder ShouldTraversePostOrderValue; 26 27 public: 28 RecordingVisitorBase(ShouldTraversePostOrder ShouldTraversePostOrderValue) 29 : ShouldTraversePostOrderValue(ShouldTraversePostOrderValue) {} 30 31 bool shouldTraversePostOrder() const { 32 return static_cast<bool>(ShouldTraversePostOrderValue); 33 } 34 35 // Callbacks received during traversal. 36 std::string CallbackLog; 37 unsigned CallbackLogIndent = 0; 38 39 std::string stmtToString(Stmt *S) { 40 StringRef ClassName = S->getStmtClassName(); 41 if (IntegerLiteral *IL = dyn_cast<IntegerLiteral>(S)) { 42 return (ClassName + "(" + toString(IL->getValue(), 10, false) + ")").str(); 43 } 44 if (UnaryOperator *UO = dyn_cast<UnaryOperator>(S)) { 45 return (ClassName + "(" + UnaryOperator::getOpcodeStr(UO->getOpcode()) + 46 ")") 47 .str(); 48 } 49 if (BinaryOperator *BO = dyn_cast<BinaryOperator>(S)) { 50 return (ClassName + "(" + BinaryOperator::getOpcodeStr(BO->getOpcode()) + 51 ")") 52 .str(); 53 } 54 if (CallExpr *CE = dyn_cast<CallExpr>(S)) { 55 if (FunctionDecl *Callee = CE->getDirectCallee()) { 56 if (Callee->getIdentifier()) { 57 return (ClassName + "(" + Callee->getName() + ")").str(); 58 } 59 } 60 } 61 if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(S)) { 62 if (NamedDecl *ND = DRE->getFoundDecl()) { 63 if (ND->getIdentifier()) { 64 return (ClassName + "(" + ND->getName() + ")").str(); 65 } 66 } 67 } 68 return ClassName.str(); 69 } 70 71 /// Record the fact that the user-defined callback member function 72 /// \p CallbackName was called with the argument \p S. Then, record the 73 /// effects of calling the default implementation \p CallDefaultFn. 74 template <typename CallDefault> 75 void recordCallback(StringRef CallbackName, Stmt *S, 76 CallDefault CallDefaultFn) { 77 for (unsigned i = 0; i != CallbackLogIndent; ++i) { 78 CallbackLog += " "; 79 } 80 CallbackLog += (CallbackName + " " + stmtToString(S) + "\n").str(); 81 ++CallbackLogIndent; 82 CallDefaultFn(); 83 --CallbackLogIndent; 84 } 85 }; 86 87 template <typename VisitorTy> 88 ::testing::AssertionResult visitorCallbackLogEqual(VisitorTy Visitor, 89 StringRef Code, 90 StringRef ExpectedLog) { 91 Visitor.runOver(Code); 92 // EXPECT_EQ shows the diff between the two strings if they are different. 93 EXPECT_EQ(ExpectedLog.trim().str(), 94 StringRef(Visitor.CallbackLog).trim().str()); 95 if (ExpectedLog.trim() != StringRef(Visitor.CallbackLog).trim()) { 96 return ::testing::AssertionFailure(); 97 } 98 return ::testing::AssertionSuccess(); 99 } 100 101 } // namespace 102