diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h index 93363b171528..e32643933461 100644 --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -405,18 +405,14 @@ private: bool TraverseFunctionHelper(FunctionDecl *D); bool TraverseVarHelper(VarDecl *D); - bool Walk(Stmt *S); - struct EnqueueJob { Stmt *S; Stmt::child_iterator StmtIt; - EnqueueJob(Stmt *S) : S(S), StmtIt() { - if (Expr *E = dyn_cast_or_null(S)) - S = E->IgnoreParens(); - } + EnqueueJob(Stmt *S) : S(S), StmtIt() {} }; bool dataTraverse(Stmt *S); + bool dataTraverseNode(Stmt *S, bool &EnqueueChildren); }; template @@ -435,7 +431,12 @@ bool RecursiveASTVisitor::dataTraverse(Stmt *S) { if (getDerived().shouldUseDataRecursionFor(CurrS)) { if (job.StmtIt == Stmt::child_iterator()) { - if (!Walk(CurrS)) return false; + bool EnqueueChildren = true; + if (!dataTraverseNode(CurrS, EnqueueChildren)) return false; + if (!EnqueueChildren) { + Queue.pop_back(); + continue; + } job.StmtIt = CurrS->child_begin(); } else { ++job.StmtIt; @@ -456,10 +457,16 @@ bool RecursiveASTVisitor::dataTraverse(Stmt *S) { } template -bool RecursiveASTVisitor::Walk(Stmt *S) { +bool RecursiveASTVisitor::dataTraverseNode(Stmt *S, + bool &EnqueueChildren) { + // Dispatch to the corresponding WalkUpFrom* function only if the derived + // class didn't override Traverse* (and thus the traversal is trivial). #define DISPATCH_WALK(NAME, CLASS, VAR) \ - return getDerived().WalkUpFrom##NAME(static_cast(VAR)); + if (&RecursiveASTVisitor::Traverse##NAME == &Derived::Traverse##NAME) \ + return getDerived().WalkUpFrom##NAME(static_cast(VAR)); \ + EnqueueChildren = false; \ + return getDerived().Traverse##NAME(static_cast(VAR)); if (BinaryOperator *BinOp = dyn_cast(S)) { switch (BinOp->getOpcode()) { diff --git a/clang/unittests/Tooling/RecursiveASTVisitorTest.cpp b/clang/unittests/Tooling/RecursiveASTVisitorTest.cpp index d4fda73ccb8d..39803c35bc6e 100644 --- a/clang/unittests/Tooling/RecursiveASTVisitorTest.cpp +++ b/clang/unittests/Tooling/RecursiveASTVisitorTest.cpp @@ -165,6 +165,25 @@ public: } }; +class CXXOperatorCallExprTraverser + : public ExpectedLocationVisitor { +public: + // Use Traverse, not Visit, to check that data recursion optimization isn't + // bypassing the call of this function. + bool TraverseCXXOperatorCallExpr(CXXOperatorCallExpr *CE) { + Match(getOperatorSpelling(CE->getOperator()), CE->getExprLoc()); + return ExpectedLocationVisitor::TraverseCXXOperatorCallExpr(CE); + } +}; + +class ParenExprVisitor : public ExpectedLocationVisitor { +public: + bool VisitParenExpr(ParenExpr *Parens) { + Match("", Parens->getExprLoc()); + return true; + } +}; + TEST(RecursiveASTVisitor, VisitsBaseClassDeclarations) { TypeLocVisitor Visitor; Visitor.ExpectMatch("class X", 1, 30); @@ -345,4 +364,20 @@ TEST(RecursiveASTVisitor, NoRecursionInSelfFriend) { "vector_iterator it_int;\n")); } +TEST(RecursiveASTVisitor, TraversesOverloadedOperator) { + CXXOperatorCallExprTraverser Visitor; + Visitor.ExpectMatch("()", 4, 9); + EXPECT_TRUE(Visitor.runOver( + "struct A {\n" + " int operator()();\n" + "} a;\n" + "int k = a();\n")); +} + +TEST(RecursiveASTVisitor, VisitsParensDuringDataRecursion) { + ParenExprVisitor Visitor; + Visitor.ExpectMatch("", 1, 9); + EXPECT_TRUE(Visitor.runOver("int k = (4) + 9;\n")); +} + } // end namespace clang