diff --git a/llvm/lib/Transforms/Scalar/InstructionCombining.cpp b/llvm/lib/Transforms/Scalar/InstructionCombining.cpp index 5f9c54be4e78..f5a85751eaae 100644 --- a/llvm/lib/Transforms/Scalar/InstructionCombining.cpp +++ b/llvm/lib/Transforms/Scalar/InstructionCombining.cpp @@ -1342,6 +1342,11 @@ bool InstCombiner::SimplifyDemandedBits(Value *V, APInt DemandedMask, InsertNewInstBefore(cast(NewVal), *I); return UpdateValueUsesWith(I, NewVal); } + + // If the sign bit is the only bit demanded by this ashr, then there is no + // need to do it, the shift doesn't change the high bit. + if (DemandedMask.isSignBit()) + return UpdateValueUsesWith(I, I->getOperand(0)); if (ConstantInt *SA = dyn_cast(I->getOperand(1))) { uint32_t ShiftAmt = SA->getLimitedValue(BitWidth); @@ -4841,22 +4846,29 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // already been handled above, this requires little checking. // switch (I.getPredicate()) { - default: break; - case ICmpInst::ICMP_ULE: - return new ICmpInst(ICmpInst::ICMP_ULT, Op0, AddOne(CI)); - case ICmpInst::ICMP_SLE: - return new ICmpInst(ICmpInst::ICMP_SLT, Op0, AddOne(CI)); - case ICmpInst::ICMP_UGE: - return new ICmpInst( ICmpInst::ICMP_UGT, Op0, SubOne(CI)); - case ICmpInst::ICMP_SGE: - return new ICmpInst(ICmpInst::ICMP_SGT, Op0, SubOne(CI)); + default: break; + case ICmpInst::ICMP_ULE: + return new ICmpInst(ICmpInst::ICMP_ULT, Op0, AddOne(CI)); + case ICmpInst::ICMP_SLE: + return new ICmpInst(ICmpInst::ICMP_SLT, Op0, AddOne(CI)); + case ICmpInst::ICMP_UGE: + return new ICmpInst( ICmpInst::ICMP_UGT, Op0, SubOne(CI)); + case ICmpInst::ICMP_SGE: + return new ICmpInst(ICmpInst::ICMP_SGT, Op0, SubOne(CI)); } // See if we can fold the comparison based on bits known to be zero or one - // in the input. + // in the input. If this comparison is a normal comparison, it demands all + // bits, if it is a sign bit comparison, it only demands the sign bit. + + bool UnusedBit; + bool isSignBit = isSignBitCheck(I.getPredicate(), CI, UnusedBit); + uint32_t BitWidth = cast(Ty)->getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - if (SimplifyDemandedBits(Op0, APInt::getAllOnesValue(BitWidth), + if (SimplifyDemandedBits(Op0, + isSignBit ? APInt::getSignBit(BitWidth) + : APInt::getAllOnesValue(BitWidth), KnownZero, KnownOne, 0)) return &I; diff --git a/llvm/test/Transforms/InstCombine/shift-simplify.ll b/llvm/test/Transforms/InstCombine/shift-simplify.ll index 4c846127482c..e02838583950 100644 --- a/llvm/test/Transforms/InstCombine/shift-simplify.ll +++ b/llvm/test/Transforms/InstCombine/shift-simplify.ll @@ -28,3 +28,15 @@ define i1 @test3(i32 %X) { ret i1 %tmp2 } +define i1 @test4(i32 %X) { + %tmp1 = lshr i32 %X, 7 + %tmp2 = icmp slt i32 %tmp1, 0 + ret i1 %tmp2 +} + +define i1 @test5(i32 %X) { + %tmp1 = ashr i32 %X, 7 + %tmp2 = icmp slt i32 %tmp1, 0 + ret i1 %tmp2 +} +