[ValueTracking] Use assumptions in computeConstantRange.

This patch updates computeConstantRange to optionally take an assumption
cache as argument and use the available assumptions to limit the range
of the result.

Currently this is limited to assumptions that are comparisons.

Reviewers: reames, nikic, spatel, jdoerfert, lebedev.ri

Reviewed By: nikic

Differential Revision: https://reviews.llvm.org/D76193
This commit is contained in:
Florian Hahn 2020-05-22 19:16:15 +01:00
parent 2833c46f75
commit 8d04181198
3 changed files with 211 additions and 9 deletions

View File

@ -531,7 +531,10 @@ class Value;
/// Determine the possible constant range of an integer or vector of integer
/// value. This is intended as a cheap, non-recursive check.
ConstantRange computeConstantRange(const Value *V, bool UseInstrInfo = true);
ConstantRange computeConstantRange(const Value *V, bool UseInstrInfo = true,
AssumptionCache *AC = nullptr,
const Instruction *CtxI = nullptr,
unsigned Depth = 0);
/// Return true if this function can prove that the instruction I will
/// always transfer execution to one of its successors (including the next

View File

@ -6367,9 +6367,15 @@ static void setLimitsForSelectPattern(const SelectInst &SI, APInt &Lower,
}
}
ConstantRange llvm::computeConstantRange(const Value *V, bool UseInstrInfo) {
ConstantRange llvm::computeConstantRange(const Value *V, bool UseInstrInfo,
AssumptionCache *AC,
const Instruction *CtxI,
unsigned Depth) {
assert(V->getType()->isIntOrIntVectorTy() && "Expected integer instruction");
if (Depth == MaxDepth)
return ConstantRange::getFull(V->getType()->getScalarSizeInBits());
const APInt *C;
if (match(V, m_APInt(C)))
return ConstantRange(*C);
@ -6391,6 +6397,31 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool UseInstrInfo) {
if (auto *Range = IIQ.getMetadata(I, LLVMContext::MD_range))
CR = CR.intersectWith(getConstantRangeFromMetadata(*Range));
if (CtxI && AC) {
// Try to restrict the range based on information from assumptions.
for (auto &AssumeVH : AC->assumptionsFor(V)) {
if (!AssumeVH)
continue;
CallInst *I = cast<CallInst>(AssumeVH);
assert(I->getParent()->getParent() == CtxI->getParent()->getParent() &&
"Got assumption for the wrong function!");
assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume &&
"must be an assume intrinsic");
if (!isValidAssumeForContext(I, CtxI, nullptr))
continue;
Value *Arg = I->getArgOperand(0);
ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg);
// Currently we just use information from comparisons.
if (!Cmp || Cmp->getOperand(0) != V)
continue;
ConstantRange RHS = computeConstantRange(Cmp->getOperand(1), UseInstrInfo,
AC, I, Depth + 1);
CR = CR.intersectWith(
ConstantRange::makeSatisfyingICmpRegion(Cmp->getPredicate(), RHS));
}
}
return CR;
}

View File

@ -9,6 +9,7 @@
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
@ -23,6 +24,14 @@ using namespace llvm;
namespace {
static Instruction &findInstructionByName(Function *F, StringRef Name) {
for (Instruction &I : instructions(F))
if (I.getName() == Name)
return I;
llvm_unreachable("Expected value not found");
}
class ValueTrackingTest : public testing::Test {
protected:
std::unique_ptr<Module> parseModule(StringRef Assembly) {
@ -46,13 +55,7 @@ protected:
if (!F)
return;
A = nullptr;
for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
if (I->hasName()) {
if (I->getName() == "A")
A = &*I;
}
}
A = &findInstructionByName(F, "A");
ASSERT_TRUE(A) << "@test must have an instruction %A";
}
@ -1246,3 +1249,168 @@ TEST_P(IsBytewiseValueTest, IsBytewiseValue) {
S << *Actual;
EXPECT_EQ(GetParam().first, S.str());
}
TEST_F(ValueTrackingTest, ComputeConstantRange) {
{
// Assumptions:
// * stride >= 5
// * stride < 10
//
// stride = [5, 10)
auto M = parseModule(R"(
declare void @llvm.assume(i1)
define i32 @test(i32 %stride) {
%gt = icmp uge i32 %stride, 5
call void @llvm.assume(i1 %gt)
%lt = icmp ult i32 %stride, 10
call void @llvm.assume(i1 %lt)
%stride.plus.one = add nsw nuw i32 %stride, 1
ret i32 %stride.plus.one
})");
Function *F = M->getFunction("test");
AssumptionCache AC(*F);
Value *Stride = &*F->arg_begin();
ConstantRange CR1 = computeConstantRange(Stride, true, &AC, nullptr);
EXPECT_TRUE(CR1.isFullSet());
Instruction *I = &findInstructionByName(F, "stride.plus.one");
ConstantRange CR2 = computeConstantRange(Stride, true, &AC, I);
EXPECT_EQ(5, CR2.getLower());
EXPECT_EQ(10, CR2.getUpper());
}
{
// Assumptions:
// * stride >= 5
// * stride < 200
// * stride == 99
//
// stride = [99, 100)
auto M = parseModule(R"(
declare void @llvm.assume(i1)
define i32 @test(i32 %stride) {
%gt = icmp uge i32 %stride, 5
call void @llvm.assume(i1 %gt)
%lt = icmp ult i32 %stride, 200
call void @llvm.assume(i1 %lt)
%eq = icmp eq i32 %stride, 99
call void @llvm.assume(i1 %eq)
%stride.plus.one = add nsw nuw i32 %stride, 1
ret i32 %stride.plus.one
})");
Function *F = M->getFunction("test");
AssumptionCache AC(*F);
Value *Stride = &*F->arg_begin();
Instruction *I = &findInstructionByName(F, "stride.plus.one");
ConstantRange CR = computeConstantRange(Stride, true, &AC, I);
EXPECT_EQ(99, *CR.getSingleElement());
}
{
// Assumptions:
// * stride >= 5
// * stride >= 50
// * stride < 100
// * stride < 200
//
// stride = [50, 100)
auto M = parseModule(R"(
declare void @llvm.assume(i1)
define i32 @test(i32 %stride, i1 %cond) {
%gt = icmp uge i32 %stride, 5
call void @llvm.assume(i1 %gt)
%gt.2 = icmp uge i32 %stride, 50
call void @llvm.assume(i1 %gt.2)
br i1 %cond, label %bb1, label %bb2
bb1:
%lt = icmp ult i32 %stride, 200
call void @llvm.assume(i1 %lt)
%lt.2 = icmp ult i32 %stride, 100
call void @llvm.assume(i1 %lt.2)
%stride.plus.one = add nsw nuw i32 %stride, 1
ret i32 %stride.plus.one
bb2:
ret i32 0
})");
Function *F = M->getFunction("test");
AssumptionCache AC(*F);
Value *Stride = &*F->arg_begin();
Instruction *GT2 = &findInstructionByName(F, "gt.2");
ConstantRange CR = computeConstantRange(Stride, true, &AC, GT2);
EXPECT_EQ(5, CR.getLower());
EXPECT_EQ(0, CR.getUpper());
Instruction *I = &findInstructionByName(F, "stride.plus.one");
ConstantRange CR2 = computeConstantRange(Stride, true, &AC, I);
EXPECT_EQ(50, CR2.getLower());
EXPECT_EQ(100, CR2.getUpper());
}
{
// Assumptions:
// * stride > 5
// * stride < 5
//
// stride = empty range, as the assumptions contradict each other.
auto M = parseModule(R"(
declare void @llvm.assume(i1)
define i32 @test(i32 %stride, i1 %cond) {
%gt = icmp ugt i32 %stride, 5
call void @llvm.assume(i1 %gt)
%lt = icmp ult i32 %stride, 5
call void @llvm.assume(i1 %lt)
%stride.plus.one = add nsw nuw i32 %stride, 1
ret i32 %stride.plus.one
})");
Function *F = M->getFunction("test");
AssumptionCache AC(*F);
Value *Stride = &*F->arg_begin();
Instruction *I = &findInstructionByName(F, "stride.plus.one");
ConstantRange CR = computeConstantRange(Stride, true, &AC, I);
EXPECT_TRUE(CR.isEmptySet());
}
{
// Assumptions:
// * x.1 >= 5
// * x.2 < x.1
//
// stride = [0, 5)
auto M = parseModule(R"(
declare void @llvm.assume(i1)
define i32 @test(i32 %x.1, i32 %x.2) {
%gt = icmp uge i32 %x.1, 5
call void @llvm.assume(i1 %gt)
%lt = icmp ult i32 %x.2, %x.1
call void @llvm.assume(i1 %lt)
%stride.plus.one = add nsw nuw i32 %x.1, 1
ret i32 %stride.plus.one
})");
Function *F = M->getFunction("test");
AssumptionCache AC(*F);
Value *X2 = &*std::next(F->arg_begin());
Instruction *I = &findInstructionByName(F, "stride.plus.one");
ConstantRange CR1 = computeConstantRange(X2, true, &AC, I);
EXPECT_EQ(0, CR1.getLower());
EXPECT_EQ(5, CR1.getUpper());
// Check the depth cutoff results in a conservative result (full set) by
// passing Depth == MaxDepth == 6.
ConstantRange CR2 = computeConstantRange(X2, true, &AC, I, 6);
EXPECT_TRUE(CR2.isFullSet());
}
}