diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 9e5bb00e6bb1..7650c73f454c 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -672,6 +672,13 @@ namespace llvm {
                                       const SCEV *MaxBECount,
                                       unsigned BitWidth);
 
+    /// Try to compute a range for the affine SCEVAddRecExpr {\p Start,+,\p
+    /// Step} by "factoring out" a ternary expression from the add recurrence.
+    /// Helper called by \c getRange.
+    ConstantRange getRangeViaFactoring(const SCEV *Start, const SCEV *Step,
+                                       const SCEV *MaxBECount,
+                                       unsigned BitWidth);
+
     /// We know that there is no SCEV for the specified value.  Analyze the
     /// expression.
     const SCEV *createSCEV(Value *V);
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 3ba790c46e44..d64f8367b608 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -4400,6 +4400,13 @@ ScalarEvolution::getRange(const SCEV *S,
         if (!RangeFromAffine.isFullSet())
           ConservativeResult =
               ConservativeResult.intersectWith(RangeFromAffine);
+
+        auto RangeFromFactoring = getRangeViaFactoring(
+            AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
+            BitWidth);
+        if (!RangeFromFactoring.isFullSet())
+          ConservativeResult =
+              ConservativeResult.intersectWith(RangeFromFactoring);
       }
     }
 
@@ -4504,6 +4511,84 @@ ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
   return Result;
 }
 
+ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
+                                                    const SCEV *Step,
+                                                    const SCEV *MaxBECount,
+                                                    unsigned BitWidth) {
+  APInt Offset(BitWidth, 0);
+
+  if (auto *SA = dyn_cast<SCEVAddExpr>(Start)) {
+    // Peel off a constant offset, if possible.  In the future we could consider
+    // being smarter here and handle {Start+Step,+,Step} too.
+    if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
+      return ConstantRange(BitWidth, /* isFullSet = */ true);
+    Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
+    Start = SA->getOperand(1);
+  }
+
+  if (!isa<SCEVUnknown>(Start) || !isa<SCEVUnknown>(Step))
+    // We don't have anything new to contribute in this case.
+    return ConstantRange(BitWidth, /* isFullSet = */ true);
+
+  //    RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
+  // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
+
+  // Recognizes a select whose true and false operands both match m_APInt and
+  // records its three parts; on any other value all fields stay null.
+  struct SelectPattern {
+    Value *Condition = nullptr;
+    const APInt *TrueValue = nullptr;
+    const APInt *FalseValue = nullptr;
+
+    explicit SelectPattern(const SCEVUnknown *SU) {
+      using namespace llvm::PatternMatch;
+
+      if (!match(SU->getValue(),
+                 m_Select(m_Value(Condition), m_APInt(TrueValue),
+                          m_APInt(FalseValue)))) {
+        Condition = nullptr;
+        TrueValue = FalseValue = nullptr;
+      }
+    }
+
+    bool isRecognized() const {
+      assert(((Condition && TrueValue && FalseValue) ||
+              (!Condition && !TrueValue && !FalseValue)) &&
+             "Invariant: either all three are non-null or all three are null");
+      return TrueValue != nullptr;
+    }
+  };
+
+  SelectPattern StartPattern(cast<SCEVUnknown>(Start));
+  if (!StartPattern.isRecognized())
+    return ConstantRange(BitWidth, /* isFullSet = */ true);
+
+  SelectPattern StepPattern(cast<SCEVUnknown>(Step));
+  if (!StepPattern.isRecognized())
+    return ConstantRange(BitWidth, /* isFullSet = */ true);
+
+  if (StartPattern.Condition != StepPattern.Condition) {
+    // We don't handle this case today; but we could, by considering four
+    // possibilities below instead of two. I'm not sure if there are cases where
+    // that will help over what getRange already does, though.
+    return ConstantRange(BitWidth, /* isFullSet = */ true);
+  }
+
+  // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
+  // construct arbitrary general SCEV expressions here.  This function is called
+  // from deep in the call stack, and calling getSCEV (on a sext instruction,
+  // say) can end up caching a suboptimal value.
+
+  auto TrueRange = getRangeForAffineAR(
+      getConstant(*StartPattern.TrueValue + Offset),
+      getConstant(*StepPattern.TrueValue), MaxBECount, BitWidth);
+  auto FalseRange = getRangeForAffineAR(
+      getConstant(*StartPattern.FalseValue + Offset),
+      getConstant(*StepPattern.FalseValue), MaxBECount, BitWidth);
+
+  return TrueRange.unionWith(FalseRange);
+}
+
 SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
   if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
   const BinaryOperator *BinOp = cast<BinaryOperator>(V);
diff --git a/llvm/test/Analysis/ScalarEvolution/increasing-and-decreasing-range.ll b/llvm/test/Analysis/ScalarEvolution/increasing-and-decreasing-range.ll
new file mode 100644
index 000000000000..5a83306cc063
--- /dev/null
+++ b/llvm/test/Analysis/ScalarEvolution/increasing-and-decreasing-range.ll
@@ -0,0 +1,88 @@
+; RUN: opt -analyze -scalar-evolution < %s | FileCheck %s
+
+define void @f0(i1 %c) {
+; CHECK-LABEL: Classifying expressions for: @f0
+entry:
+  %start = select i1 %c, i32 127, i32 0
+  %step = select i1 %c, i32 -1, i32 1
+  br label %loop
+
+loop:
+  %loop.iv = phi i32 [ 0, %entry ], [ %loop.iv.inc, %loop ]
+  %iv = phi i32 [ %start, %entry ], [ %iv.next, %loop ]
+; CHECK: %iv = phi i32 [ %start, %entry ], [ %iv.next, %loop ]
+; CHECK-NEXT:  -->  {%start,+,%step}<%loop> U: [0,128) S: [0,128)
+  %iv.next = add i32 %iv, %step
+  %loop.iv.inc = add i32 %loop.iv, 1
+  %be.cond = icmp ne i32 %loop.iv.inc, 128
+  br i1 %be.cond, label %loop, label %leave
+
+leave:
+  ret void
+}
+
+define void @f1(i1 %c) {
+; CHECK-LABEL: Classifying expressions for: @f1
+entry:
+  %start = select i1 %c, i32 120, i32 0
+  %step = select i1 %c, i32 -8, i32 8
+  br label %loop
+
+loop:
+  %loop.iv = phi i32 [ 0, %entry ], [ %loop.iv.inc, %loop ]
+  %iv = phi i32 [ %start, %entry ], [ %iv.next, %loop ]
+
+; CHECK: %iv.1 = add i32 %iv, 1
+; CHECK-NEXT:  -->  {(1 + %start),+,%step}<%loop> U: [1,122) S: [1,122)
+; CHECK: %iv.2 = add i32 %iv, 2
+; CHECK-NEXT:  -->  {(2 + %start),+,%step}<%loop> U: [2,123) S: [2,123)
+; CHECK: %iv.3 = add i32 %iv, 3
+; CHECK-NEXT:  -->  {(3 + %start),+,%step}<%loop> U: [3,124) S: [3,124)
+; CHECK: %iv.4 = add i32 %iv, 4
+; CHECK-NEXT:  -->  {(4 + %start),+,%step}<%loop> U: [4,125) S: [4,125)
+; CHECK: %iv.5 = add i32 %iv, 5
+; CHECK-NEXT:  -->  {(5 + %start),+,%step}<%loop> U: [5,126) S: [5,126)
+; CHECK: %iv.6 = add i32 %iv, 6
+; CHECK-NEXT:  -->  {(6 + %start),+,%step}<%loop> U: [6,127) S: [6,127)
+; CHECK: %iv.7 = add i32 %iv, 7
+; CHECK-NEXT:  -->  {(7 + %start),+,%step}<%loop> U: [7,128) S: [7,128)
+
+  %iv.1 = add i32 %iv, 1
+  %iv.2 = add i32 %iv, 2
+  %iv.3 = add i32 %iv, 3
+  %iv.4 = add i32 %iv, 4
+  %iv.5 = add i32 %iv, 5
+  %iv.6 = add i32 %iv, 6
+  %iv.7 = add i32 %iv, 7
+
+; CHECK: %iv.m1 = sub i32 %iv, 1
+; CHECK-NEXT:  -->  {(-1 + %start),+,%step}<%loop> U: [-1,120) S: [-1,120)
+; CHECK: %iv.m2 = sub i32 %iv, 2
+; CHECK-NEXT:  -->  {(-2 + %start),+,%step}<%loop> U: [-2,119) S: [-2,119)
+; CHECK: %iv.m3 = sub i32 %iv, 3
+; CHECK-NEXT:  -->  {(-3 + %start),+,%step}<%loop> U: [-3,118) S: [-3,118)
+; CHECK: %iv.m4 = sub i32 %iv, 4
+; CHECK-NEXT:  -->  {(-4 + %start),+,%step}<%loop> U: [-4,117) S: [-4,117)
+; CHECK: %iv.m5 = sub i32 %iv, 5
+; CHECK-NEXT:  -->  {(-5 + %start),+,%step}<%loop> U: [-5,116) S: [-5,116)
+; CHECK: %iv.m6 = sub i32 %iv, 6
+; CHECK-NEXT:  -->  {(-6 + %start),+,%step}<%loop> U: [-6,115) S: [-6,115)
+; CHECK: %iv.m7 = sub i32 %iv, 7
+; CHECK-NEXT:  -->  {(-7 + %start),+,%step}<%loop> U: [-7,114) S: [-7,114)
+
+  %iv.m1 = sub i32 %iv, 1
+  %iv.m2 = sub i32 %iv, 2
+  %iv.m3 = sub i32 %iv, 3
+  %iv.m4 = sub i32 %iv, 4
+  %iv.m5 = sub i32 %iv, 5
+  %iv.m6 = sub i32 %iv, 6
+  %iv.m7 = sub i32 %iv, 7
+
+  %iv.next = add i32 %iv, %step
+  %loop.iv.inc = add i32 %loop.iv, 1
+  %be.cond = icmp sgt i32 %loop.iv, 14
+  br i1 %be.cond, label %leave, label %loop
+
+leave:
+  ret void
+}