diff --git a/clang/Driver/RewriteBlocks.cpp b/clang/Driver/RewriteBlocks.cpp index 05cbb704d125..2e7ce0559f76 100644 --- a/clang/Driver/RewriteBlocks.cpp +++ b/clang/Driver/RewriteBlocks.cpp @@ -140,8 +140,12 @@ public: void RewriteCategoryDecl(ObjCCategoryDecl *CatDecl); void RewriteProtocolDecl(ObjCProtocolDecl *PDecl); void RewriteMethodDecl(ObjCMethodDecl *MDecl); + + void RewriteFunctionTypeProto(QualType funcType, NamedDecl *D); + void CheckFunctionPointerDecl(QualType dType, NamedDecl *ND); + void RewriteCastExpr(CastExpr *CE); - bool BlockPointerTypeTakesAnyBlockArguments(QualType QT); + bool PointerTypeTakesAnyBlockArguments(QualType QT); void GetExtentOfArgList(const char *Name, const char *&LParen, const char *&RParen); }; @@ -738,6 +742,28 @@ void RewriteBlocks::RewriteBlockDeclRefExpr(BlockDeclRefExpr *BDRE) { InsertText(BDRE->getLocStart(), "*", 1); } +void RewriteBlocks::RewriteCastExpr(CastExpr *CE) { + SourceLocation LocStart = CE->getLocStart(); + SourceLocation LocEnd = CE->getLocEnd(); + + const char *startBuf = SM->getCharacterData(LocStart); + const char *endBuf = SM->getCharacterData(LocEnd); + + // advance the location to startArgList. + const char *argPtr = startBuf; + + while (*argPtr++ && (argPtr < endBuf)) { + switch (*argPtr) { + case '^': + // Replace the '^' with '*'. + LocStart = LocStart.getFileLocWithOffset(argPtr-startBuf); + ReplaceText(LocStart, 1, "*", 1); + break; + } + } + return; +} + void RewriteBlocks::RewriteBlockPointerFunctionArgs(FunctionDecl *FD) { SourceLocation DeclLoc = FD->getLocation(); unsigned parenCount = 0; @@ -773,10 +799,17 @@ void RewriteBlocks::RewriteBlockPointerFunctionArgs(FunctionDecl *FD) { return; } -bool RewriteBlocks::BlockPointerTypeTakesAnyBlockArguments(QualType QT) { - const BlockPointerType *BPT = QT->getAsBlockPointerType(); - assert(BPT && "BlockPointerTypeTakeAnyBlockArguments(): not a block pointer type"); - const FunctionTypeProto *FTP = BPT->getPointeeType()->getAsFunctionTypeProto(); +bool RewriteBlocks::PointerTypeTakesAnyBlockArguments(QualType QT) { + const FunctionTypeProto *FTP; + const PointerType *PT = QT->getAsPointerType(); + if (PT) { + FTP = PT->getPointeeType()->getAsFunctionTypeProto(); + assert(FTP && "BlockPointerTypeTakeAnyBlockArguments(): not a function pointer type"); + } else { + const BlockPointerType *BPT = QT->getAsBlockPointerType(); + assert(BPT && "BlockPointerTypeTakeAnyBlockArguments(): not a block pointer type"); + FTP = BPT->getPointeeType()->getAsFunctionTypeProto(); + } if (FTP) { for (FunctionTypeProto::arg_type_iterator I = FTP->arg_type_begin(), E = FTP->arg_type_end(); I != E; ++I) @@ -829,13 +862,15 @@ void RewriteBlocks::RewriteBlockPointerDecl(NamedDecl *ND) { // scan backward (from the decl location) for the end of the previous decl. while (*startBuf != '^' && *startBuf != ';' && startBuf != MainFileStart) startBuf--; - assert((*startBuf == '^') && - "RewriteBlockPointerDecl() scan error: no caret"); - // Replace the '^' with '*', computing a negative offset. - DeclLoc = DeclLoc.getFileLocWithOffset(startBuf-endBuf); - ReplaceText(DeclLoc, 1, "*", 1); - - if (BlockPointerTypeTakesAnyBlockArguments(DeclT)) { + + // *startBuf != '^' if we are dealing with a pointer to function that + // may take block argument types (which will be handled below). + if (*startBuf == '^') { + // Replace the '^' with '*', computing a negative offset. + DeclLoc = DeclLoc.getFileLocWithOffset(startBuf-endBuf); + ReplaceText(DeclLoc, 1, "*", 1); + } + if (PointerTypeTakesAnyBlockArguments(DeclT)) { // Replace the '^' with '*' for arguments. DeclLoc = ND->getLocation(); startBuf = SM->getCharacterData(DeclLoc); @@ -981,6 +1016,9 @@ Stmt *RewriteBlocks::RewriteFunctionBody(Stmt *S) { if (CE->getCallee()->getType()->isBlockPointerType()) RewriteBlockCall(CE); } + if (CastExpr *CE = dyn_cast(S)) { + RewriteCastExpr(CE); + } if (DeclStmt *DS = dyn_cast(S)) { for (DeclStmt::decl_iterator DI = DS->decl_begin(), DE = DS->decl_end(); DI != DE; ++DI) { @@ -989,10 +1027,14 @@ Stmt *RewriteBlocks::RewriteFunctionBody(Stmt *S) { if (ValueDecl *ND = dyn_cast(SD)) { if (isBlockPointerType(ND->getType())) RewriteBlockPointerDecl(ND); + else if (ND->getType()->isFunctionPointerType()) + CheckFunctionPointerDecl(ND->getType(), ND); } if (TypedefDecl *TD = dyn_cast(SD)) { if (isBlockPointerType(TD->getUnderlyingType())) RewriteBlockPointerDecl(TD); + else if (TD->getUnderlyingType()->isFunctionPointerType()) + CheckFunctionPointerDecl(TD->getUnderlyingType(), TD); } } } @@ -1005,25 +1047,33 @@ Stmt *RewriteBlocks::RewriteFunctionBody(Stmt *S) { return S; } +void RewriteBlocks::RewriteFunctionTypeProto(QualType funcType, NamedDecl *D) { + if (FunctionTypeProto *fproto = dyn_cast(funcType)) { + for (FunctionTypeProto::arg_type_iterator I = fproto->arg_type_begin(), + E = fproto->arg_type_end(); I && (I != E); ++I) + if (isBlockPointerType(*I)) { + // All the args are checked/rewritten. Don't call twice! + RewriteBlockPointerDecl(D); + break; + } + } +} + +void RewriteBlocks::CheckFunctionPointerDecl(QualType funcType, NamedDecl *ND) { + const PointerType *PT = funcType->getAsPointerType(); + if (PT && PointerTypeTakesAnyBlockArguments(funcType)) + RewriteFunctionTypeProto(PT->getPointeeType(), ND); +} + /// HandleDeclInMainFile - This is called for each top-level decl defined in the /// main file of the input. void RewriteBlocks::HandleDeclInMainFile(Decl *D) { if (FunctionDecl *FD = dyn_cast(D)) { - // Since function prototypes don't have ParmDecl's, we check the function // prototype. This enables us to rewrite function declarations and // definitions using the same code. - QualType funcType = FD->getType(); + RewriteFunctionTypeProto(FD->getType(), FD); - if (FunctionTypeProto *fproto = dyn_cast(funcType)) { - for (FunctionTypeProto::arg_type_iterator I = fproto->arg_type_begin(), - E = fproto->arg_type_end(); I && (I != E); ++I) - if (isBlockPointerType(*I)) { - // All the args are checked/rewritten. Don't call twice! - RewriteBlockPointerDecl(FD); - break; - } - } if (Stmt *Body = FD->getBody()) { CurFunctionDef = FD; FD->setBody(RewriteFunctionBody(Body)); @@ -1058,6 +1108,15 @@ void RewriteBlocks::HandleDeclInMainFile(Decl *D) { // Do the rewrite, using S.size() which contains the rewritten size. ReplaceText(CBE->getLocStart(), S.size(), Init.c_str(), Init.size()); SynthesizeBlockLiterals(VD->getTypeSpecStartLoc(), VD->getName()); + } else if (CastExpr *CE = dyn_cast(VD->getInit())) { + RewriteCastExpr(CE); + } + } + } else if (VD->getType()->isFunctionPointerType()) { + CheckFunctionPointerDecl(VD->getType(), VD); + if (VD->getInit()) { + if (CastExpr *CE = dyn_cast(VD->getInit())) { + RewriteCastExpr(CE); } } } @@ -1066,6 +1125,8 @@ void RewriteBlocks::HandleDeclInMainFile(Decl *D) { if (TypedefDecl *TD = dyn_cast(D)) { if (isBlockPointerType(TD->getUnderlyingType())) RewriteBlockPointerDecl(TD); + else if (TD->getUnderlyingType()->isFunctionPointerType()) + CheckFunctionPointerDecl(TD->getUnderlyingType(), TD); return; } if (RecordDecl *RD = dyn_cast(D)) { diff --git a/clang/test/Rewriter/block-test.c b/clang/test/Rewriter/block-test.c index 0a6bde0886ad..82b63a09f9d1 100644 --- a/clang/test/Rewriter/block-test.c +++ b/clang/test/Rewriter/block-test.c @@ -1,5 +1,16 @@ // RUN: clang -rewrite-blocks %s -o - +static int (^block)(const void *, const void *) = (int (^)(const void *, const void *))0; +static int (*func)(int (^block)(void *, void *)) = (int (*)(int (^block)(void *, void *)))0; + +typedef int (^block_T)(const void *, const void *); +typedef int (*func_T)(int (^block)(void *, void *)); + +void foo(const void *a, const void *b, void *c) { + int (^block)(const void *, const void *) = (int (^)(const void *, const void *))c; + int (*func)(int (^block)(void *, void *)) = (int (*)(int (^block)(void *, void *)))c; +} + typedef void (^test_block_t)(); int main(int argc, char **argv) {