diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp index 5b11579f66bc..a30602889d5d 100644 --- a/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/llvm/lib/Transforms/Scalar/SROA.cpp @@ -2678,6 +2678,24 @@ private: return V; } + /// \brief Compute a vector splat for a given element value. + Value *getVectorSplat(IRBuilder<> &IRB, Value *V, unsigned NumElements) { + assert(NumElements > 0 && "Cannot splat to an empty vector."); + + // First insert it into a one-element vector so we can shuffle it. It is + // really silly that LLVM's IR requires this in order to form a splat. + Value *Undef = UndefValue::get(VectorType::get(V->getType(), 1)); + V = IRB.CreateInsertElement(Undef, V, IRB.getInt32(0), + getName(".splatinsert")); + + // Shuffle the value across the desired number of elements. + SmallVector Mask(NumElements, IRB.getInt32(0)); + V = IRB.CreateShuffleVector(V, Undef, ConstantVector::get(Mask), + getName(".splat")); + DEBUG(dbgs() << " splat: " << *V << "\n"); + return V; + } + bool visitMemSetInst(MemSetInst &II) { DEBUG(dbgs() << " original: " << II << "\n"); IRBuilder<> IRB(&II); @@ -2706,7 +2724,8 @@ private: (BeginOffset != NewAllocaBeginOffset || EndOffset != NewAllocaEndOffset || !AllocaTy->isSingleValueType() || - !TD.isLegalInteger(TD.getTypeSizeInBits(ScalarTy)))) { + !TD.isLegalInteger(TD.getTypeSizeInBits(ScalarTy)) || + TD.getTypeSizeInBits(ScalarTy)%8 != 0)) { Type *SizeTy = II.getLength()->getType(); Constant *Size = ConstantInt::get(SizeTy, EndOffset - BeginOffset); CallInst *New @@ -2722,45 +2741,60 @@ private: // If we can represent this as a simple value, we have to build the actual // value to store, which requires expanding the byte present in memset to // a sensible representation for the alloca type. This is essentially - // splatting the byte to a sufficiently wide integer, bitcasting to the - // desired scalar type, and splatting it across any desired vector type. + // splatting the byte to a sufficiently wide integer, splatting it across + // any desired vector width, and bitcasting to the final type. uint64_t Size = EndOffset - BeginOffset; Value *V = getIntegerSplat(IRB, II.getValue(), Size); - // If this is an element-wide memset of a vectorizable alloca, insert it. - if (VecTy && (BeginOffset > NewAllocaBeginOffset || - EndOffset < NewAllocaEndOffset)) { - if (V->getType() != ScalarTy) - V = convertValue(TD, IRB, V, ScalarTy); - StoreInst *Store = IRB.CreateAlignedStore( - IRB.CreateInsertElement(IRB.CreateAlignedLoad(&NewAI, - NewAI.getAlignment(), - getName(".load")), - V, IRB.getInt32(getIndex(BeginOffset)), - getName(".insert")), - &NewAI, NewAI.getAlignment()); - (void)Store; - DEBUG(dbgs() << " to: " << *Store << "\n"); - return true; - } + if (VecTy) { + // If this is a memset of a vectorized alloca, insert it. + assert(ElementTy == ScalarTy); - // If this is a memset on an alloca where we can widen stores, insert the - // set integer. - if (IntTy && (BeginOffset > NewAllocaBeginOffset || - EndOffset < NewAllocaEndOffset)) { + unsigned BeginIndex = getIndex(BeginOffset); + unsigned EndIndex = getIndex(EndOffset); + assert(EndIndex > BeginIndex && "Empty vector!"); + unsigned NumElements = EndIndex - BeginIndex; + assert(NumElements <= VecTy->getNumElements() && "Too many elements!"); + + Value *Splat = getIntegerSplat(IRB, II.getValue(), + TD.getTypeSizeInBits(ElementTy)/8); + if (NumElements > 1) + Splat = getVectorSplat(IRB, Splat, NumElements); + + V = insertVector(IRB, Splat, BeginIndex, EndIndex); + } else if (IntTy) { + // If this is a memset on an alloca where we can widen stores, insert the + // set integer. assert(!II.isVolatile()); - Value *Old = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), - getName(".oldload")); - Old = convertValue(TD, IRB, Old, IntTy); - assert(BeginOffset >= NewAllocaBeginOffset && "Out of bounds offset"); - uint64_t Offset = BeginOffset - NewAllocaBeginOffset; - V = insertInteger(TD, IRB, Old, V, Offset, getName(".insert")); + + V = getIntegerSplat(IRB, II.getValue(), Size); + + if (IntTy && (BeginOffset != NewAllocaBeginOffset || + EndOffset != NewAllocaBeginOffset)) { + Value *Old = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), + getName(".oldload")); + Old = convertValue(TD, IRB, Old, IntTy); + assert(BeginOffset >= NewAllocaBeginOffset && "Out of bounds offset"); + uint64_t Offset = BeginOffset - NewAllocaBeginOffset; + V = insertInteger(TD, IRB, Old, V, Offset, getName(".insert")); + } else { + assert(V->getType() == IntTy && + "Wrong type for an alloca wide integer!"); + } + } else { + // Established these invariants above. + assert(BeginOffset == NewAllocaBeginOffset); + assert(EndOffset == NewAllocaEndOffset); + + V = getIntegerSplat(IRB, II.getValue(), + TD.getTypeSizeInBits(ScalarTy)/8); + + if (VectorType *AllocaVecTy = dyn_cast(AllocaTy)) + V = getVectorSplat(IRB, V, AllocaVecTy->getNumElements()); } - if (V->getType() != AllocaTy) - V = convertValue(TD, IRB, V, AllocaTy); - - Value *New = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment(), + Value *New = IRB.CreateAlignedStore(convertValue(TD, IRB, V, AllocaTy), + &NewAI, NewAI.getAlignment(), II.isVolatile()); (void)New; DEBUG(dbgs() << " to: " << *New << "\n"); diff --git a/llvm/test/Transforms/SROA/vector-promotion.ll b/llvm/test/Transforms/SROA/vector-promotion.ll index bb34e3f084ac..f957fef6dd07 100644 --- a/llvm/test/Transforms/SROA/vector-promotion.ll +++ b/llvm/test/Transforms/SROA/vector-promotion.ll @@ -279,6 +279,41 @@ entry: ; CHECK-NEXT: ret <4 x i32> %[[ret]] } +declare void @llvm.memset.p0i32.i32(i32* nocapture, i32, i32, i32, i1) nounwind + +define <4 x i32> @test_subvec_memset() { +; CHECK: @test_subvec_memset +entry: + %a = alloca <4 x i32> +; CHECK-NOT: alloca + + %a.gep0 = getelementptr <4 x i32>* %a, i32 0, i32 0 + %a.cast0 = bitcast i32* %a.gep0 to i8* + call void @llvm.memset.p0i8.i32(i8* %a.cast0, i8 0, i32 8, i32 0, i1 false) +; CHECK-NOT: store +; CHECK: %[[insert1:.*]] = shufflevector <4 x i32> , <4 x i32> undef, <4 x i32> + + %a.gep1 = getelementptr <4 x i32>* %a, i32 0, i32 1 + %a.cast1 = bitcast i32* %a.gep1 to i8* + call void @llvm.memset.p0i8.i32(i8* %a.cast1, i8 1, i32 8, i32 0, i1 false) +; CHECK-NEXT: %[[insert2:.*]] = shufflevector <4 x i32> , <4 x i32> %[[insert1]], <4 x i32> + + %a.gep2 = getelementptr <4 x i32>* %a, i32 0, i32 2 + %a.cast2 = bitcast i32* %a.gep2 to i8* + call void @llvm.memset.p0i8.i32(i8* %a.cast2, i8 3, i32 8, i32 0, i1 false) +; CHECK-NEXT: %[[insert3:.*]] = shufflevector <4 x i32> , <4 x i32> %[[insert2]], <4 x i32> + + %a.gep3 = getelementptr <4 x i32>* %a, i32 0, i32 3 + %a.cast3 = bitcast i32* %a.gep3 to i8* + call void @llvm.memset.p0i8.i32(i8* %a.cast3, i8 7, i32 4, i32 0, i1 false) +; CHECK-NEXT: %[[insert4:.*]] = insertelement <4 x i32> %[[insert3]], i32 117901063, i32 3 + + %ret = load <4 x i32>* %a + + ret <4 x i32> %ret +; CHECK-NEXT: ret <4 x i32> %[[insert4]] +} + define i32 @PR14212() { ; CHECK: @PR14212 ; This caused a crash when "splitting" the load of the i32 in order to promote