diff --git a/clang/lib/CodeGen/CGCXX.cpp b/clang/lib/CodeGen/CGCXX.cpp index 6be2b82fa583..b7670dfb45f1 100644 --- a/clang/lib/CodeGen/CGCXX.cpp +++ b/clang/lib/CodeGen/CGCXX.cpp @@ -1211,10 +1211,10 @@ void CodeGenFunction::EmitClassMemberwiseCopy( const CXXRecordDecl *ClassDecl, const CXXRecordDecl *BaseClassDecl, QualType Ty) { if (ClassDecl) { - Dest = GetAddressCXXOfBaseClass(Dest, ClassDecl, BaseClassDecl, - /*NullCheckValue=*/false); - Src = GetAddressCXXOfBaseClass(Src, ClassDecl, BaseClassDecl, - /*NullCheckValue=*/false); + Dest = GetAddressOfBaseClass(Dest, ClassDecl, BaseClassDecl, + /*NullCheckValue=*/false); + Src = GetAddressOfBaseClass(Src, ClassDecl, BaseClassDecl, + /*NullCheckValue=*/false); } if (BaseClassDecl->hasTrivialCopyConstructor()) { EmitAggregateCopy(Dest, Src, Ty); @@ -1250,10 +1250,10 @@ void CodeGenFunction::EmitClassCopyAssignment( const CXXRecordDecl *BaseClassDecl, QualType Ty) { if (ClassDecl) { - Dest = GetAddressCXXOfBaseClass(Dest, ClassDecl, BaseClassDecl, - /*NullCheckValue=*/false); - Src = GetAddressCXXOfBaseClass(Src, ClassDecl, BaseClassDecl, - /*NullCheckValue=*/false); + Dest = GetAddressOfBaseClass(Dest, ClassDecl, BaseClassDecl, + /*NullCheckValue=*/false); + Src = GetAddressOfBaseClass(Src, ClassDecl, BaseClassDecl, + /*NullCheckValue=*/false); } if (BaseClassDecl->hasTrivialCopyAssignment()) { EmitAggregateCopy(Dest, Src, Ty); @@ -1493,9 +1493,9 @@ static void EmitBaseInitializer(CodeGenFunction &CGF, const Type *BaseType = BaseInit->getBaseClass(); CXXRecordDecl *BaseClassDecl = cast(BaseType->getAs()->getDecl()); - llvm::Value *V = CGF.GetAddressCXXOfBaseClass(ThisPtr, ClassDecl, - BaseClassDecl, - /*NullCheckValue=*/false); + llvm::Value *V = CGF.GetAddressOfBaseClass(ThisPtr, ClassDecl, + BaseClassDecl, + /*NullCheckValue=*/false); CGF.EmitCXXConstructorCall(BaseInit->getConstructor(), CtorType, V, BaseInit->const_arg_begin(), @@ -1710,9 +1710,9 @@ void CodeGenFunction::EmitDtorEpilogue(const CXXDestructorDecl *DD, if (BaseClassDecl->hasTrivialDestructor()) continue; - llvm::Value *V = GetAddressCXXOfBaseClass(LoadCXXThis(), - ClassDecl, BaseClassDecl, - /*NullCheckValue=*/false); + llvm::Value *V = GetAddressOfBaseClass(LoadCXXThis(), + ClassDecl, BaseClassDecl, + /*NullCheckValue=*/false); EmitCXXDestructorCall(BaseClassDecl->getDestructor(getContext()), Dtor_Base, V); } diff --git a/clang/lib/CodeGen/CGCXXClass.cpp b/clang/lib/CodeGen/CGCXXClass.cpp index 533aabc8616e..e122b95a7338 100644 --- a/clang/lib/CodeGen/CGCXXClass.cpp +++ b/clang/lib/CodeGen/CGCXXClass.cpp @@ -117,10 +117,10 @@ static llvm::Value *GetCXXBaseClassOffset(CodeGenFunction &CGF, } llvm::Value * -CodeGenFunction::GetAddressCXXOfBaseClass(llvm::Value *BaseValue, - const CXXRecordDecl *ClassDecl, - const CXXRecordDecl *BaseClassDecl, - bool NullCheckValue) { +CodeGenFunction::GetAddressOfBaseClass(llvm::Value *Value, + const CXXRecordDecl *ClassDecl, + const CXXRecordDecl *BaseClassDecl, + bool NullCheckValue) { QualType BTy = getContext().getCanonicalType( getContext().getTypeDeclType(const_cast(BaseClassDecl))); @@ -128,7 +128,7 @@ CodeGenFunction::GetAddressCXXOfBaseClass(llvm::Value *BaseValue, if (ClassDecl == BaseClassDecl) { // Just cast back. - return Builder.CreateBitCast(BaseValue, BasePtrTy); + return Builder.CreateBitCast(Value, BasePtrTy); } llvm::BasicBlock *CastNull = 0; @@ -141,8 +141,8 @@ CodeGenFunction::GetAddressCXXOfBaseClass(llvm::Value *BaseValue, CastEnd = createBasicBlock("cast.end"); llvm::Value *IsNull = - Builder.CreateICmpEQ(BaseValue, - llvm::Constant::getNullValue(BaseValue->getType())); + Builder.CreateICmpEQ(Value, + llvm::Constant::getNullValue(Value->getType())); Builder.CreateCondBr(IsNull, CastNull, CastNotNull); EmitBlock(CastNotNull); } @@ -150,16 +150,16 @@ CodeGenFunction::GetAddressCXXOfBaseClass(llvm::Value *BaseValue, const llvm::Type *Int8PtrTy = llvm::Type::getInt8PtrTy(VMContext); llvm::Value *Offset = - GetCXXBaseClassOffset(*this, BaseValue, ClassDecl, BaseClassDecl); + GetCXXBaseClassOffset(*this, Value, ClassDecl, BaseClassDecl); if (Offset) { // Apply the offset. - BaseValue = Builder.CreateBitCast(BaseValue, Int8PtrTy); - BaseValue = Builder.CreateGEP(BaseValue, Offset, "add.ptr"); + Value = Builder.CreateBitCast(Value, Int8PtrTy); + Value = Builder.CreateGEP(Value, Offset, "add.ptr"); } // Cast back. - BaseValue = Builder.CreateBitCast(BaseValue, BasePtrTy); + Value = Builder.CreateBitCast(Value, BasePtrTy); if (NullCheckValue) { Builder.CreateBr(CastEnd); @@ -167,13 +167,73 @@ CodeGenFunction::GetAddressCXXOfBaseClass(llvm::Value *BaseValue, Builder.CreateBr(CastEnd); EmitBlock(CastEnd); - llvm::PHINode *PHI = Builder.CreatePHI(BaseValue->getType()); + llvm::PHINode *PHI = Builder.CreatePHI(Value->getType()); PHI->reserveOperandSpace(2); - PHI->addIncoming(BaseValue, CastNotNull); - PHI->addIncoming(llvm::Constant::getNullValue(BaseValue->getType()), + PHI->addIncoming(Value, CastNotNull); + PHI->addIncoming(llvm::Constant::getNullValue(Value->getType()), CastNull); - BaseValue = PHI; + Value = PHI; } - return BaseValue; + return Value; +} + +llvm::Value * +CodeGenFunction::GetAddressOfDerivedClass(llvm::Value *Value, + const CXXRecordDecl *ClassDecl, + const CXXRecordDecl *DerivedClassDecl, + bool NullCheckValue) { + QualType DerivedTy = + getContext().getCanonicalType( + getContext().getTypeDeclType(const_cast(DerivedClassDecl))); + const llvm::Type *DerivedPtrTy = ConvertType(DerivedTy)->getPointerTo(); + + if (ClassDecl == DerivedClassDecl) { + // Just cast back. + return Builder.CreateBitCast(Value, DerivedPtrTy); + } + + llvm::BasicBlock *CastNull = 0; + llvm::BasicBlock *CastNotNull = 0; + llvm::BasicBlock *CastEnd = 0; + + if (NullCheckValue) { + CastNull = createBasicBlock("cast.null"); + CastNotNull = createBasicBlock("cast.notnull"); + CastEnd = createBasicBlock("cast.end"); + + llvm::Value *IsNull = + Builder.CreateICmpEQ(Value, + llvm::Constant::getNullValue(Value->getType())); + Builder.CreateCondBr(IsNull, CastNull, CastNotNull); + EmitBlock(CastNotNull); + } + + llvm::Value *Offset = GetCXXBaseClassOffset(*this, Value, DerivedClassDecl, + ClassDecl); + if (Offset) { + // Apply the offset. + Value = Builder.CreatePtrToInt(Value, Offset->getType()); + Value = Builder.CreateSub(Value, Offset); + Value = Builder.CreateIntToPtr(Value, DerivedPtrTy); + } else { + // Just cast. + Value = Builder.CreateBitCast(Value, DerivedPtrTy); + } + + if (NullCheckValue) { + Builder.CreateBr(CastEnd); + EmitBlock(CastNull); + Builder.CreateBr(CastEnd); + EmitBlock(CastEnd); + + llvm::PHINode *PHI = Builder.CreatePHI(Value->getType()); + PHI->reserveOperandSpace(2); + PHI->addIncoming(Value, CastNotNull); + PHI->addIncoming(llvm::Constant::getNullValue(Value->getType()), + CastNull); + Value = PHI; + } + + return Value; } diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index 2a544c560931..3a93473c5428 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -148,8 +148,8 @@ RValue CodeGenFunction::EmitReferenceBindingToExpr(const Expr* E, if (BaseClassDecl) { llvm::Value *Derived = Val.getAggregateAddr(); llvm::Value *Base = - GetAddressCXXOfBaseClass(Derived, DerivedClassDecl, BaseClassDecl, - /*NullCheckValue=*/false); + GetAddressOfBaseClass(Derived, DerivedClassDecl, BaseClassDecl, + /*NullCheckValue=*/false); return RValue::get(Base); } } @@ -1328,8 +1328,8 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) { // Perform the derived-to-base conversion llvm::Value *Base = - GetAddressCXXOfBaseClass(LV.getAddress(), DerivedClassDecl, - BaseClassDecl, /*NullCheckValue=*/false); + GetAddressOfBaseClass(LV.getAddress(), DerivedClassDecl, + BaseClassDecl, /*NullCheckValue=*/false); return LValue::MakeAddr(Base, MakeQualifiers(E->getType())); } @@ -1340,7 +1340,23 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) { return LValue::MakeAddr(Temp, MakeQualifiers(E->getType())); } case CastExpr::CK_BaseToDerived: { - return EmitUnsupportedLValue(E, "base-to-derived cast lvalue"); + const RecordType *BaseClassTy = + E->getSubExpr()->getType()->getAs(); + CXXRecordDecl *BaseClassDecl = + cast(BaseClassTy->getDecl()); + + const RecordType *DerivedClassTy = E->getType()->getAs(); + CXXRecordDecl *DerivedClassDecl = + cast(DerivedClassTy->getDecl()); + + LValue LV = EmitLValue(E->getSubExpr()); + + // Perform the base-to-derived conversion + llvm::Value *Derived = + GetAddressOfDerivedClass(LV.getAddress(), BaseClassDecl, + DerivedClassDecl, /*NullCheckValue=*/false); + + return LValue::MakeAddr(Derived, MakeQualifiers(E->getType())); } case CastExpr::CK_BitCast: { // This must be a reinterpret_cast (or c-style equivalent). diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp index ca8adffb5ea1..5c6657c6c7d6 100644 --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -748,6 +748,23 @@ Value *ScalarExprEmitter::VisitInitListExpr(InitListExpr *E) { return V; } +static bool ShouldNullCheckClassCastValue(const CastExpr *CE) { + const Expr *E = CE->getSubExpr(); + + if (isa(E)) { + // We always assume that 'this' is never null. + return false; + } + + if (const ImplicitCastExpr *ICE = dyn_cast(CE)) { + // And that lvalue casts are never null. + if (ICE->isLvalueCast()) + return false; + } + + return true; +} + // VisitCastExpr - Emit code for an explicit or implicit cast. Implicit casts // have to handle a more broad range of conversions than explicit casts, as they // handle things like function to ptr-to-function decay etc. @@ -775,6 +792,19 @@ Value *ScalarExprEmitter::EmitCastExpr(const CastExpr *CE) { case CastExpr::CK_NoOp: return Visit(const_cast(E)); + case CastExpr::CK_BaseToDerived: { + const CXXRecordDecl *BaseClassDecl = + E->getType()->getCXXRecordDeclForPointerType(); + const CXXRecordDecl *DerivedClassDecl = + DestTy->getCXXRecordDeclForPointerType(); + + Value *Src = Visit(const_cast(E)); + + bool NullCheckValue = ShouldNullCheckClassCastValue(CE); + return CGF.GetAddressOfDerivedClass(Src, BaseClassDecl, DerivedClassDecl, + NullCheckValue); + } + case CastExpr::CK_DerivedToBase: { const RecordType *DerivedClassTy = E->getType()->getAs()->getPointeeType()->getAs(); @@ -787,18 +817,9 @@ Value *ScalarExprEmitter::EmitCastExpr(const CastExpr *CE) { Value *Src = Visit(const_cast(E)); - bool NullCheckValue = true; - - if (isa(E)) { - // We always assume that 'this' is never null. - NullCheckValue = false; - } else if (const ImplicitCastExpr *ICE = dyn_cast(CE)) { - // And that lvalue casts are never null. - if (ICE->isLvalueCast()) - NullCheckValue = false; - } - return CGF.GetAddressCXXOfBaseClass(Src, DerivedClassDecl, BaseClassDecl, - NullCheckValue); + bool NullCheckValue = ShouldNullCheckClassCastValue(CE); + return CGF.GetAddressOfBaseClass(Src, DerivedClassDecl, BaseClassDecl, + NullCheckValue); } case CastExpr::CK_ToUnion: { assert(0 && "Should be unreachable!"); diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index 0be527c8d465..6d0d81ecfc96 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -598,15 +598,20 @@ public: /// generating code for an C++ member function. llvm::Value *LoadCXXThis(); - /// GetAddressCXXOfBaseClass - This function will add the necessary delta + /// GetAddressOfBaseClass - This function will add the necessary delta /// to the load of 'this' and returns address of the base class. // FIXME. This currently only does a derived to non-virtual base conversion. // Other kinds of conversions will come later. - llvm::Value *GetAddressCXXOfBaseClass(llvm::Value *BaseValue, - const CXXRecordDecl *ClassDecl, - const CXXRecordDecl *BaseClassDecl, - bool NullCheckValue); + llvm::Value *GetAddressOfBaseClass(llvm::Value *Value, + const CXXRecordDecl *ClassDecl, + const CXXRecordDecl *BaseClassDecl, + bool NullCheckValue); + llvm::Value *GetAddressOfDerivedClass(llvm::Value *Value, + const CXXRecordDecl *ClassDecl, + const CXXRecordDecl *DerivedClassDecl, + bool NullCheckValue); + llvm::Value * GetVirtualCXXBaseClassOffset(llvm::Value *This, const CXXRecordDecl *ClassDecl,