[SCEV] Improve the run-time checking of the NoWrap predicate

Summary:
This implements a new method of run-time checking the NoWrap
SCEV predicates, which should be easier to optimize and nicer
for targets that don't correctly handle multiplication/addition
of large integer types (like i128).

If the AddRec is {a,+,b} and the backedge taken count is c,
the idea is to check that |b| * c doesn't have unsigned overflow,
and depending on the sign of b, that:

   a + |b| * c >= a (b >= 0) or
   a - |b| * c <= a (b <= 0)

where the comparisons above are signed or unsigned, depending on
the flag that we're checking.

The advantage of doing this is that we avoid extending to a larger
type and we avoid the multiplication of large types (multiplying
i128 can be expensive).

Reviewers: sanjoy

Subscribers: llvm-commits, mzolotukhin

Differential Revision: http://reviews.llvm.org/D19266

llvm-svn: 267389
This commit is contained in:
Silviu Baranga 2016-04-25 09:27:16 +00:00
parent a44d44cb2e
commit 795c629ec9
2 changed files with 193 additions and 34 deletions

View File

@ -2007,37 +2007,81 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR,
SCEVUnionPredicate Pred;
const SCEV *ExitCount =
SE.getPredicatedBackedgeTakenCount(AR->getLoop(), Pred);
const SCEV *Step = AR->getStepRecurrence(SE);
const SCEV *Start = AR->getStart();
unsigned DstBits = SE.getTypeSizeInBits(AR->getType());
unsigned SrcBits = SE.getTypeSizeInBits(ExitCount->getType());
unsigned MaxBits = 2 * std::max(DstBits, SrcBits);
auto *TripCount = SE.getTruncateOrZeroExtend(ExitCount, AR->getType());
IntegerType *MaxTy = IntegerType::get(Loc->getContext(), MaxBits);
assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count");
const auto *ExtendedTripCount = SE.getZeroExtendExpr(ExitCount, MaxTy);
const auto *ExtendedStep = SE.getSignExtendExpr(Step, MaxTy);
const auto *ExtendedStart = Signed ? SE.getSignExtendExpr(Start, MaxTy)
: SE.getZeroExtendExpr(Start, MaxTy);
const SCEV *Step = AR->getStepRecurrence(SE);
const SCEV *Start = AR->getStart();
const SCEV *End = SE.getAddExpr(Start, SE.getMulExpr(TripCount, Step));
const SCEV *RHS = Signed ? SE.getSignExtendExpr(End, MaxTy)
: SE.getZeroExtendExpr(End, MaxTy);
unsigned SrcBits = SE.getTypeSizeInBits(ExitCount->getType());
unsigned DstBits = SE.getTypeSizeInBits(AR->getType());
const SCEV *LHS = SE.getAddExpr(
ExtendedStart, SE.getMulExpr(ExtendedTripCount, ExtendedStep));
// The expression {Start,+,Step} has nusw/nssw if
// Step < 0, Start - |Step| * Backedge <= Start
// Step >= 0, Start + |Step| * Backedge > Start
// and |Step| * Backedge doesn't unsigned overflow.
// Do all SCEV expansions now.
Value *LHSVal = expandCodeFor(LHS, MaxTy, Loc);
Value *RHSVal = expandCodeFor(RHS, MaxTy, Loc);
IntegerType *CountTy = IntegerType::get(Loc->getContext(), SrcBits);
Builder.SetInsertPoint(Loc);
Value *TripCountVal = expandCodeFor(ExitCount, CountTy, Loc);
IntegerType *Ty =
IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(AR->getType()));
Value *StepValue = expandCodeFor(Step, Ty, Loc);
Value *NegStepValue = expandCodeFor(SE.getNegativeSCEV(Step), Ty, Loc);
Value *StartValue = expandCodeFor(Start, Ty, Loc);
ConstantInt *Zero =
ConstantInt::get(Loc->getContext(), APInt::getNullValue(DstBits));
Builder.SetInsertPoint(Loc);
// Compute |Step|
Value *StepCompare = Builder.CreateICmp(ICmpInst::ICMP_SLT, StepValue, Zero);
Value *AbsStep = Builder.CreateSelect(StepCompare, NegStepValue, StepValue);
return Builder.CreateICmp(ICmpInst::ICMP_NE, RHSVal, LHSVal);
// Get the backedge taken count and truncate or extended to the AR type.
Value *TruncTripCount = Builder.CreateZExtOrTrunc(TripCountVal, Ty);
auto *MulF = Intrinsic::getDeclaration(Loc->getModule(),
Intrinsic::umul_with_overflow, Ty);
// Compute |Step| * Backedge
CallInst *Mul = Builder.CreateCall(MulF, {AbsStep, TruncTripCount}, "mul");
Value *MulV = Builder.CreateExtractValue(Mul, 0, "mul.result");
Value *OfMul = Builder.CreateExtractValue(Mul, 1, "mul.overflow");
// Compute:
// Start + |Step| * Backedge < Start
// Start - |Step| * Backedge > Start
Value *Add = Builder.CreateAdd(StartValue, MulV);
Value *Sub = Builder.CreateSub(StartValue, MulV);
Value *EndCompareGT = Builder.CreateICmp(
Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT, Sub, StartValue);
Value *EndCompareLT = Builder.CreateICmp(
Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, Add, StartValue);
// Select the answer based on the sign of Step.
Value *EndCheck =
Builder.CreateSelect(StepCompare, EndCompareGT, EndCompareLT);
// If the backedge taken count type is larger than the AR type,
// check that we don't drop any bits by truncating it. If we are
// droping bits, then we have overflow (unless the step is zero).
if (SE.getTypeSizeInBits(CountTy) > SE.getTypeSizeInBits(Ty)) {
auto MaxVal = APInt::getMaxValue(DstBits).zext(SrcBits);
auto *BackedgeCheck =
Builder.CreateICmp(ICmpInst::ICMP_UGT, TripCountVal,
ConstantInt::get(Loc->getContext(), MaxVal));
BackedgeCheck = Builder.CreateAnd(
BackedgeCheck, Builder.CreateICmp(ICmpInst::ICMP_NE, StepValue, Zero));
EndCheck = Builder.CreateOr(EndCheck, BackedgeCheck);
}
EndCheck = Builder.CreateOr(EndCheck, OfMul);
return EndCheck;
}
Value *SCEVExpander::expandWrapPredicate(const SCEVWrapPredicate *Pred,

View File

@ -38,9 +38,32 @@ target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
; LV-LABEL: f1
; LV-LABEL: for.body.lver.check
; LV: [[PredCheck0:%[^ ]*]] = icmp ne i128
; LV: [[Or0:%[^ ]*]] = or i1 false, [[PredCheck0]]
; LV: [[PredCheck1:%[^ ]*]] = icmp ne i128
; LV: [[BETrunc:%[^ ]*]] = trunc i64 [[BE:%[^ ]*]] to i32
; LV-NEXT: [[OFMul:%[^ ]*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 2, i32 [[BETrunc]])
; LV-NEXT: [[OFMulResult:%[^ ]*]] = extractvalue { i32, i1 } [[OFMul]], 0
; LV-NEXT: [[OFMulOverflow:%[^ ]*]] = extractvalue { i32, i1 } [[OFMul]], 1
; LV-NEXT: [[AddEnd:%[^ ]*]] = add i32 0, [[OFMulResult]]
; LV-NEXT: [[SubEnd:%[^ ]*]] = sub i32 0, [[OFMulResult]]
; LV-NEXT: [[CmpNeg:%[^ ]*]] = icmp ugt i32 [[SubEnd]], 0
; LV-NEXT: [[CmpPos:%[^ ]*]] = icmp ult i32 [[AddEnd]], 0
; LV-NEXT: [[Cmp:%[^ ]*]] = select i1 false, i1 [[CmpNeg]], i1 [[CmpPos]]
; LV-NEXT: [[BECheck:%[^ ]*]] = icmp ugt i64 [[BE]], 4294967295
; LV-NEXT: [[CheckOr0:%[^ ]*]] = or i1 [[Cmp]], [[BECheck]]
; LV-NEXT: [[PredCheck0:%[^ ]*]] = or i1 [[CheckOr0]], [[OFMulOverflow]]
; LV-NEXT: [[Or0:%[^ ]*]] = or i1 false, [[PredCheck0]]
; LV-NEXT: [[OFMul1:%[^ ]*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 4, i64 [[BE]])
; LV-NEXT: [[OFMulResult1:%[^ ]*]] = extractvalue { i64, i1 } [[OFMul1]], 0
; LV-NEXT: [[OFMulOverflow1:%[^ ]*]] = extractvalue { i64, i1 } [[OFMul1]], 1
; LV-NEXT: [[AddEnd1:%[^ ]*]] = add i64 %a2, [[OFMulResult1]]
; LV-NEXT: [[SubEnd1:%[^ ]*]] = sub i64 %a2, [[OFMulResult1]]
; LV-NEXT: [[CmpNeg1:%[^ ]*]] = icmp ugt i64 [[SubEnd1]], %a2
; LV-NEXT: [[CmpPos1:%[^ ]*]] = icmp ult i64 [[AddEnd1]], %a2
; LV-NEXT: [[Cmp:%[^ ]*]] = select i1 false, i1 [[CmpNeg1]], i1 [[CmpPos1]]
; LV-NEXT: [[PredCheck1:%[^ ]*]] = or i1 [[Cmp]], [[OFMulOverflow1]]
; LV: [[FinalCheck:%[^ ]*]] = or i1 [[Or0]], [[PredCheck1]]
; LV: br i1 [[FinalCheck]], label %for.body.ph.lver.orig, label %for.body.ph
define void @f1(i16* noalias %a,
@ -111,9 +134,31 @@ for.end: ; preds = %for.body
; LV-LABEL: f2
; LV-LABEL: for.body.lver.check
; LV: [[PredCheck0:%[^ ]*]] = icmp ne i128
; LV: [[Or0:%[^ ]*]] = or i1 false, [[PredCheck0]]
; LV: [[PredCheck1:%[^ ]*]] = icmp ne i128
; LV: [[OFMul:%[^ ]*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 2, i32 [[BETrunc:%[^ ]*]])
; LV-NEXT: [[OFMulResult:%[^ ]*]] = extractvalue { i32, i1 } [[OFMul]], 0
; LV-NEXT: [[OFMulOverflow:%[^ ]*]] = extractvalue { i32, i1 } [[OFMul]], 1
; LV-NEXT: [[AddEnd:%[^ ]*]] = add i32 [[Start:%[^ ]*]], [[OFMulResult]]
; LV-NEXT: [[SubEnd:%[^ ]*]] = sub i32 [[Start]], [[OFMulResult]]
; LV-NEXT: [[CmpNeg:%[^ ]*]] = icmp ugt i32 [[SubEnd]], [[Start]]
; LV-NEXT: [[CmpPos:%[^ ]*]] = icmp ult i32 [[AddEnd]], [[Start]]
; LV-NEXT: [[Cmp:%[^ ]*]] = select i1 true, i1 [[CmpNeg]], i1 [[CmpPos]]
; LV-NEXT: [[BECheck:%[^ ]*]] = icmp ugt i64 [[BE]], 4294967295
; LV-NEXT: [[CheckOr0:%[^ ]*]] = or i1 [[Cmp]], [[BECheck]]
; LV-NEXT: [[PredCheck0:%[^ ]*]] = or i1 [[CheckOr0]], [[OFMulOverflow]]
; LV-NEXT: [[Or0:%[^ ]*]] = or i1 false, [[PredCheck0]]
; LV: [[OFMul1:%[^ ]*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 4, i64 [[BE]])
; LV-NEXT: [[OFMulResult1:%[^ ]*]] = extractvalue { i64, i1 } [[OFMul1]], 0
; LV-NEXT: [[OFMulOverflow1:%[^ ]*]] = extractvalue { i64, i1 } [[OFMul1]], 1
; LV-NEXT: [[AddEnd1:%[^ ]*]] = add i64 [[Start:%[^ ]*]], [[OFMulResult1]]
; LV-NEXT: [[SubEnd1:%[^ ]*]] = sub i64 [[Start]], [[OFMulResult1]]
; LV-NEXT: [[CmpNeg1:%[^ ]*]] = icmp ugt i64 [[SubEnd1]], [[Start]]
; LV-NEXT: [[CmpPos1:%[^ ]*]] = icmp ult i64 [[AddEnd1]], [[Start]]
; LV-NEXT: [[Cmp:%[^ ]*]] = select i1 true, i1 [[CmpNeg1]], i1 [[CmpPos1]]
; LV-NEXT: [[PredCheck1:%[^ ]*]] = or i1 [[Cmp]], [[OFMulOverflow1]]
; LV: [[FinalCheck:%[^ ]*]] = or i1 [[Or0]], [[PredCheck1]]
; LV: br i1 [[FinalCheck]], label %for.body.ph.lver.orig, label %for.body.ph
define void @f2(i16* noalias %a,
@ -169,9 +214,31 @@ for.end: ; preds = %for.body
; LV-LABEL: f3
; LV-LABEL: for.body.lver.check
; LV: [[PredCheck0:%[^ ]*]] = icmp ne i128
; LV: [[Or0:%[^ ]*]] = or i1 false, [[PredCheck0]]
; LV: [[PredCheck1:%[^ ]*]] = icmp ne i128
; LV: [[OFMul:%[^ ]*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 2, i32 [[BETrunc:%[^ ]*]])
; LV-NEXT: [[OFMulResult:%[^ ]*]] = extractvalue { i32, i1 } [[OFMul]], 0
; LV-NEXT: [[OFMulOverflow:%[^ ]*]] = extractvalue { i32, i1 } [[OFMul]], 1
; LV-NEXT: [[AddEnd:%[^ ]*]] = add i32 0, [[OFMulResult]]
; LV-NEXT: [[SubEnd:%[^ ]*]] = sub i32 0, [[OFMulResult]]
; LV-NEXT: [[CmpNeg:%[^ ]*]] = icmp sgt i32 [[SubEnd]], 0
; LV-NEXT: [[CmpPos:%[^ ]*]] = icmp slt i32 [[AddEnd]], 0
; LV-NEXT: [[Cmp:%[^ ]*]] = select i1 false, i1 [[CmpNeg]], i1 [[CmpPos]]
; LV-NEXT: [[BECheck:%[^ ]*]] = icmp ugt i64 [[BE]], 4294967295
; LV-NEXT: [[CheckOr0:%[^ ]*]] = or i1 [[Cmp]], [[BECheck]]
; LV-NEXT: [[PredCheck0:%[^ ]*]] = or i1 [[CheckOr0]], [[OFMulOverflow]]
; LV-NEXT: [[Or0:%[^ ]*]] = or i1 false, [[PredCheck0]]
; LV: [[OFMul1:%[^ ]*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 4, i64 [[BE:%[^ ]*]])
; LV-NEXT: [[OFMulResult1:%[^ ]*]] = extractvalue { i64, i1 } [[OFMul1]], 0
; LV-NEXT: [[OFMulOverflow1:%[^ ]*]] = extractvalue { i64, i1 } [[OFMul1]], 1
; LV-NEXT: [[AddEnd1:%[^ ]*]] = add i64 %a2, [[OFMulResult1]]
; LV-NEXT: [[SubEnd1:%[^ ]*]] = sub i64 %a2, [[OFMulResult1]]
; LV-NEXT: [[CmpNeg1:%[^ ]*]] = icmp ugt i64 [[SubEnd1]], %a2
; LV-NEXT: [[CmpPos1:%[^ ]*]] = icmp ult i64 [[AddEnd1]], %a2
; LV-NEXT: [[Cmp:%[^ ]*]] = select i1 false, i1 [[CmpNeg1]], i1 [[CmpPos1]]
; LV-NEXT: [[PredCheck1:%[^ ]*]] = or i1 [[Cmp]], [[OFMulOverflow1]]
; LV: [[FinalCheck:%[^ ]*]] = or i1 [[Or0]], [[PredCheck1]]
; LV: br i1 [[FinalCheck]], label %for.body.ph.lver.orig, label %for.body.ph
define void @f3(i16* noalias %a,
@ -223,9 +290,31 @@ for.end: ; preds = %for.body
; LV-LABEL: f4
; LV-LABEL: for.body.lver.check
; LV: [[PredCheck0:%[^ ]*]] = icmp ne i128
; LV: [[Or0:%[^ ]*]] = or i1 false, [[PredCheck0]]
; LV: [[PredCheck1:%[^ ]*]] = icmp ne i128
; LV: [[OFMul:%[^ ]*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 2, i32 [[BETrunc:%[^ ]*]])
; LV-NEXT: [[OFMulResult:%[^ ]*]] = extractvalue { i32, i1 } [[OFMul]], 0
; LV-NEXT: [[OFMulOverflow:%[^ ]*]] = extractvalue { i32, i1 } [[OFMul]], 1
; LV-NEXT: [[AddEnd:%[^ ]*]] = add i32 [[Start:%[^ ]*]], [[OFMulResult]]
; LV-NEXT: [[SubEnd:%[^ ]*]] = sub i32 [[Start]], [[OFMulResult]]
; LV-NEXT: [[CmpNeg:%[^ ]*]] = icmp sgt i32 [[SubEnd]], [[Start]]
; LV-NEXT: [[CmpPos:%[^ ]*]] = icmp slt i32 [[AddEnd]], [[Start]]
; LV-NEXT: [[Cmp:%[^ ]*]] = select i1 true, i1 [[CmpNeg]], i1 [[CmpPos]]
; LV-NEXT: [[BECheck:%[^ ]*]] = icmp ugt i64 [[BE]], 4294967295
; LV-NEXT: [[CheckOr0:%[^ ]*]] = or i1 [[Cmp]], [[BECheck]]
; LV-NEXT: [[PredCheck0:%[^ ]*]] = or i1 [[CheckOr0]], [[OFMulOverflow]]
; LV-NEXT: [[Or0:%[^ ]*]] = or i1 false, [[PredCheck0]]
; LV: [[OFMul1:%[^ ]*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 4, i64 [[BE:%[^ ]*]])
; LV-NEXT: [[OFMulResult1:%[^ ]*]] = extractvalue { i64, i1 } [[OFMul1]], 0
; LV-NEXT: [[OFMulOverflow1:%[^ ]*]] = extractvalue { i64, i1 } [[OFMul1]], 1
; LV-NEXT: [[AddEnd1:%[^ ]*]] = add i64 [[Start:%[^ ]*]], [[OFMulResult1]]
; LV-NEXT: [[SubEnd1:%[^ ]*]] = sub i64 [[Start]], [[OFMulResult1]]
; LV-NEXT: [[CmpNeg1:%[^ ]*]] = icmp ugt i64 [[SubEnd1]], [[Start]]
; LV-NEXT: [[CmpPos1:%[^ ]*]] = icmp ult i64 [[AddEnd1]], [[Start]]
; LV-NEXT: [[Cmp:%[^ ]*]] = select i1 true, i1 [[CmpNeg1]], i1 [[CmpPos1]]
; LV-NEXT: [[PredCheck1:%[^ ]*]] = or i1 [[Cmp]], [[OFMulOverflow1]]
; LV: [[FinalCheck:%[^ ]*]] = or i1 [[Or0]], [[PredCheck1]]
; LV: br i1 [[FinalCheck]], label %for.body.ph.lver.orig, label %for.body.ph
define void @f4(i16* noalias %a,
@ -280,6 +369,32 @@ for.end: ; preds = %for.body
; LV-LABEL: f5
; LV-LABEL: for.body.lver.check
; LV: [[OFMul:%[^ ]*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 2, i32 [[BETrunc:%[^ ]*]])
; LV-NEXT: [[OFMulResult:%[^ ]*]] = extractvalue { i32, i1 } [[OFMul]], 0
; LV-NEXT: [[OFMulOverflow:%[^ ]*]] = extractvalue { i32, i1 } [[OFMul]], 1
; LV-NEXT: [[AddEnd:%[^ ]*]] = add i32 [[Start:%[^ ]*]], [[OFMulResult]]
; LV-NEXT: [[SubEnd:%[^ ]*]] = sub i32 [[Start]], [[OFMulResult]]
; LV-NEXT: [[CmpNeg:%[^ ]*]] = icmp sgt i32 [[SubEnd]], [[Start]]
; LV-NEXT: [[CmpPos:%[^ ]*]] = icmp slt i32 [[AddEnd]], [[Start]]
; LV-NEXT: [[Cmp:%[^ ]*]] = select i1 true, i1 [[CmpNeg]], i1 [[CmpPos]]
; LV-NEXT: [[BECheck:%[^ ]*]] = icmp ugt i64 [[BE]], 4294967295
; LV-NEXT: [[CheckOr0:%[^ ]*]] = or i1 [[Cmp]], [[BECheck]]
; LV-NEXT: [[PredCheck0:%[^ ]*]] = or i1 [[CheckOr0]], [[OFMulOverflow]]
; LV-NEXT: [[Or0:%[^ ]*]] = or i1 false, [[PredCheck0]]
; LV: [[OFMul1:%[^ ]*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 4, i64 [[BE:%[^ ]*]])
; LV-NEXT: [[OFMulResult1:%[^ ]*]] = extractvalue { i64, i1 } [[OFMul1]], 0
; LV-NEXT: [[OFMulOverflow1:%[^ ]*]] = extractvalue { i64, i1 } [[OFMul1]], 1
; LV-NEXT: [[AddEnd1:%[^ ]*]] = add i64 [[Start:%[^ ]*]], [[OFMulResult1]]
; LV-NEXT: [[SubEnd1:%[^ ]*]] = sub i64 [[Start]], [[OFMulResult1]]
; LV-NEXT: [[CmpNeg1:%[^ ]*]] = icmp ugt i64 [[SubEnd1]], [[Start]]
; LV-NEXT: [[CmpPos1:%[^ ]*]] = icmp ult i64 [[AddEnd1]], [[Start]]
; LV-NEXT: [[Cmp:%[^ ]*]] = select i1 true, i1 [[CmpNeg1]], i1 [[CmpPos1]]
; LV-NEXT: [[PredCheck1:%[^ ]*]] = or i1 [[Cmp]], [[OFMulOverflow1]]
; LV: [[FinalCheck:%[^ ]*]] = or i1 [[Or0]], [[PredCheck1]]
; LV: br i1 [[FinalCheck]], label %for.body.ph.lver.orig, label %for.body.ph
define void @f5(i16* noalias %a,
i16* noalias %b, i64 %N) {
entry: