diff --git a/clang/AST/StmtIterator.cpp b/clang/AST/StmtIterator.cpp index 2d198c029d94..961ca50f2445 100644 --- a/clang/AST/StmtIterator.cpp +++ b/clang/AST/StmtIterator.cpp @@ -17,11 +17,23 @@ using namespace clang; +static inline bool declHasExpr(ScopedDecl *decl) { + if (VarDecl* D = dyn_cast(decl)) + if (D->getInit()) + return true; + + if (EnumConstantDecl* D = dyn_cast(decl)) + if (D->getInitExpr()) + return true; + + return false; +} + void StmtIteratorBase::NextDecl() { assert (FirstDecl && Ptr.D); do Ptr.D = Ptr.D->getNextDeclarator(); - while (Ptr.D != NULL && !isa(Ptr.D)); + while (Ptr.D != NULL && !declHasExpr(Ptr.D)); if (Ptr.D == NULL) FirstDecl = NULL; } @@ -29,12 +41,8 @@ void StmtIteratorBase::NextDecl() { StmtIteratorBase::StmtIteratorBase(ScopedDecl* d) { assert (d); - while (d != NULL) { - if (VarDecl* V = dyn_cast(d)) - if (V->getInit()) break; - + while (d != NULL && !declHasExpr(d)) d = d->getNextDeclarator(); - } FirstDecl = d; Ptr.D = d; @@ -61,6 +69,11 @@ void StmtIteratorBase::PrevDecl() { Ptr.D = lastVD; } -Stmt*& StmtIteratorBase::GetInitializer() const { - return reinterpret_cast(cast(Ptr.D)->Init); +Stmt*& StmtIteratorBase::GetDeclExpr() const { + if (VarDecl* D = dyn_cast(Ptr.D)) + return reinterpret_cast(D->Init); + else { + EnumConstantDecl* D = cast(Ptr.D); + return reinterpret_cast(D->Init); + } } diff --git a/clang/include/clang/AST/Decl.h b/clang/include/clang/AST/Decl.h index 5788dc40f55b..b8214a7120ca 100644 --- a/clang/include/clang/AST/Decl.h +++ b/clang/include/clang/AST/Decl.h @@ -430,6 +430,8 @@ public: // Implement isa/cast/dyncast/etc. static bool classof(const Decl *D) { return D->getKind() == EnumConstant; } static bool classof(const EnumConstantDecl *D) { return true; } + + friend class StmtIteratorBase; }; diff --git a/clang/include/clang/AST/StmtIterator.h b/clang/include/clang/AST/StmtIterator.h index 3752793c9fd1..14c6aed3d499 100644 --- a/clang/include/clang/AST/StmtIterator.h +++ b/clang/include/clang/AST/StmtIterator.h @@ -28,7 +28,7 @@ protected: void NextDecl(); void PrevDecl(); - Stmt*& GetInitializer() const; + Stmt*& GetDeclExpr() const; StmtIteratorBase(Stmt** s) : FirstDecl(NULL) { Ptr.S = s; } StmtIteratorBase(ScopedDecl* d); @@ -84,7 +84,7 @@ public: } REFERENCE operator*() const { - return (REFERENCE) (FirstDecl ? GetInitializer() : *Ptr.S); + return (REFERENCE) (FirstDecl ? GetDeclExpr() : *Ptr.S); } REFERENCE operator->() const { return operator*(); }