[BasicAA] Be more careful with modulo ops on VariableGEPIndex.

(V * Scale) % X may not produce the same result for any possible value
of V, e.g. if the multiplication overflows. This means we currently
incorrectly determine NoAlias in some cases.

This patch updates LinearExpression to track whether the expression
has NSW and uses that to adjust the scale used for alias checks.

Reviewed By: nikic

Differential Revision: https://reviews.llvm.org/D99424
This commit is contained in:
Florian Hahn 2021-06-29 08:56:50 +01:00
parent 51d969dc27
commit 91fa3565da
No known key found for this signature in database
GPG Key ID: 61D7554B5CECDC0D
3 changed files with 46 additions and 26 deletions

View File

@ -116,6 +116,9 @@ private:
// Context instruction to use when querying information about this index.
const Instruction *CxtI;
/// True if all operations in this expression are NSW.
bool IsNSW;
void dump() const {
print(dbgs());
dbgs() << "\n";

View File

@ -284,11 +284,14 @@ struct LinearExpression {
APInt Scale;
APInt Offset;
LinearExpression(const ExtendedValue &Val, const APInt &Scale,
const APInt &Offset)
: Val(Val), Scale(Scale), Offset(Offset) {}
/// True if all operations in this expression are NSW.
bool IsNSW;
LinearExpression(const ExtendedValue &Val) : Val(Val) {
LinearExpression(const ExtendedValue &Val, const APInt &Scale,
const APInt &Offset, bool IsNSW)
: Val(Val), Scale(Scale), Offset(Offset), IsNSW(IsNSW) {}
LinearExpression(const ExtendedValue &Val) : Val(Val), IsNSW(true) {
unsigned BitWidth = Val.getBitWidth();
Scale = APInt(BitWidth, 1);
Offset = APInt(BitWidth, 0);
@ -307,7 +310,7 @@ static LinearExpression GetLinearExpression(
if (const ConstantInt *Const = dyn_cast<ConstantInt>(Val.V))
return LinearExpression(Val, APInt(Val.getBitWidth(), 0),
Val.evaluateWith(Const->getValue()));
Val.evaluateWith(Const->getValue()), true);
if (const BinaryOperator *BOp = dyn_cast<BinaryOperator>(Val.V)) {
if (ConstantInt *RHSC = dyn_cast<ConstantInt>(BOp->getOperand(1))) {
@ -322,6 +325,7 @@ static LinearExpression GetLinearExpression(
if (!Val.canDistributeOver(NUW, NSW))
return Val;
LinearExpression E(Val);
switch (BOp->getOpcode()) {
default:
// We don't understand this instruction, so we can't decompose it any
@ -336,23 +340,26 @@ static LinearExpression GetLinearExpression(
LLVM_FALLTHROUGH;
case Instruction::Add: {
LinearExpression E = GetLinearExpression(
Val.withValue(BOp->getOperand(0)), DL, Depth + 1, AC, DT);
E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
Depth + 1, AC, DT);
E.Offset += RHS;
return E;
E.IsNSW &= NSW;
break;
}
case Instruction::Sub: {
LinearExpression E = GetLinearExpression(
Val.withValue(BOp->getOperand(0)), DL, Depth + 1, AC, DT);
E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
Depth + 1, AC, DT);
E.Offset -= RHS;
return E;
E.IsNSW &= NSW;
break;
}
case Instruction::Mul: {
LinearExpression E = GetLinearExpression(
Val.withValue(BOp->getOperand(0)), DL, Depth + 1, AC, DT);
E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
Depth + 1, AC, DT);
E.Offset *= RHS;
E.Scale *= RHS;
return E;
E.IsNSW &= NSW;
break;
}
case Instruction::Shl:
// We're trying to linearize an expression of the kind:
@ -363,12 +370,14 @@ static LinearExpression GetLinearExpression(
if (RHS.getLimitedValue() > Val.getBitWidth())
return Val;
LinearExpression E = GetLinearExpression(
Val.withValue(BOp->getOperand(0)), DL, Depth + 1, AC, DT);
E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
Depth + 1, AC, DT);
E.Offset <<= RHS.getLimitedValue();
E.Scale <<= RHS.getLimitedValue();
return E;
E.IsNSW &= NSW;
break;
}
return E;
}
}
@ -578,8 +587,8 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
Scale = adjustToPointerSize(Scale, PointerSize);
if (!!Scale) {
VariableGEPIndex Entry = {LE.Val.V, LE.Val.ZExtBits, LE.Val.SExtBits,
Scale, CxtI};
VariableGEPIndex Entry = {
LE.Val.V, LE.Val.ZExtBits, LE.Val.SExtBits, Scale, CxtI, LE.IsNSW};
Decomposed.VarIndices.push_back(Entry);
}
}
@ -1138,7 +1147,11 @@ AliasResult BasicAAResult::aliasGEP(
bool AllNonNegative = DecompGEP1.Offset.isNonNegative();
bool AllNonPositive = DecompGEP1.Offset.isNonPositive();
for (unsigned i = 0, e = DecompGEP1.VarIndices.size(); i != e; ++i) {
const APInt &Scale = DecompGEP1.VarIndices[i].Scale;
APInt Scale = DecompGEP1.VarIndices[i].Scale;
if (!DecompGEP1.VarIndices[i].IsNSW)
Scale = APInt::getOneBitSet(Scale.getBitWidth(),
Scale.countTrailingZeros());
if (i == 0)
GCD = Scale.abs();
else
@ -1701,9 +1714,10 @@ void BasicAAResult::GetIndexDifference(
// If we found it, subtract off Scale V's from the entry in Dest. If it
// goes to zero, remove the entry.
if (Dest[j].Scale != Scale)
if (Dest[j].Scale != Scale) {
Dest[j].Scale -= Scale;
else
Dest[j].IsNSW = false;
} else
Dest.erase(Dest.begin() + j);
Scale = 0;
break;
@ -1711,7 +1725,8 @@ void BasicAAResult::GetIndexDifference(
// If we didn't consume this entry, add it to the end of the Dest list.
if (!!Scale) {
VariableGEPIndex Entry = {V, ZExtBits, SExtBits, -Scale, Src[i].CxtI};
VariableGEPIndex Entry = {V, ZExtBits, SExtBits,
-Scale, Src[i].CxtI, Src[i].IsNSW};
Dest.push_back(Entry);
}
}

View File

@ -70,7 +70,7 @@ define void @may_overflow_mul_sub_i64([16 x i8]* %ptr, i64 %idx) {
; CHECK-LABEL: Function: may_overflow_mul_sub_i64: 3 pointers, 0 call sites
; CHECK-NEXT: MayAlias: [16 x i8]* %ptr, i8* %gep.idx
; CHECK-NEXT: PartialAlias (off 3): [16 x i8]* %ptr, i8* %gep.3
; CHECK-NEXT: NoAlias: i8* %gep.3, i8* %gep.idx
; CHECK-NEXT: MayAlias: i8* %gep.3, i8* %gep.idx
;
%mul = mul i64 %idx, 5
%sub = sub i64 %mul, 1
@ -115,7 +115,7 @@ define void @only_nuw_mul_sub_i64([16 x i8]* %ptr, i64 %idx) {
; CHECK-LABEL: Function: only_nuw_mul_sub_i64: 3 pointers, 0 call sites
; CHECK-NEXT: MayAlias: [16 x i8]* %ptr, i8* %gep.idx
; CHECK-NEXT: PartialAlias (off 3): [16 x i8]* %ptr, i8* %gep.3
; CHECK-NEXT: NoAlias: i8* %gep.3, i8* %gep.idx
; CHECK-NEXT: MayAlias: i8* %gep.3, i8* %gep.idx
;
%mul = mul nuw i64 %idx, 5
%sub = sub nuw i64 %mul, 1
@ -126,6 +126,8 @@ define void @only_nuw_mul_sub_i64([16 x i8]* %ptr, i64 %idx) {
ret void
}
; Even though the mul and sub may overflow %gep.idx and %gep.3 cannot alias
; because we multiply by a power-of-2.
define void @may_overflow_mul_pow2_sub_i64([16 x i8]* %ptr, i64 %idx) {
; CHECK-LABEL: Function: may_overflow_mul_pow2_sub_i64: 3 pointers, 0 call sites
; CHECK-NEXT: MayAlias: [16 x i8]* %ptr, i8* %gep.idx
@ -259,7 +261,7 @@ define void @may_overflow_pointer_diff([16 x i8]* %ptr, i64 %idx) {
; CHECK-LABEL: Function: may_overflow_pointer_diff: 3 pointers, 0 call sites
; CHECK-NEXT: MayAlias: [16 x i8]* %ptr, i8* %gep.mul.1
; CHECK-NEXT: MayAlias: [16 x i8]* %ptr, i8* %gep.sub.2
; CHECK-NEXT: NoAlias: i8* %gep.mul.1, i8* %gep.sub.2
; CHECK-NEXT: MayAlias: i8* %gep.mul.1, i8* %gep.sub.2
;
%mul.1 = mul i64 %idx, 6148914691236517207
%gep.mul.1 = getelementptr [16 x i8], [16 x i8]* %ptr, i32 0, i64 %mul.1