SimplifyCFG: preserve branch-weight metadata when creating a new switch from
a pair of switch/branch instructions where both depend on the value of the same variable and the default case of the first switch/branch goes to the second switch/branch. Code cleanup, plus fixes for a few issues: (1) handle the case where some cases of the 2nd switch are invalidated; (2) correctly calculate the weight for the 2nd switch when it is a conditional eq. The test case is modified from Alastair's original patch. llvm-svn: 163635
This commit is contained in:
parent
66d2e88799
commit
571d9e4b80
|
@ -752,38 +752,27 @@ static inline bool HasBranchWeights(const Instruction* I) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Tries to get a branch weight for the given instruction, returns NULL if it
|
/// Get Weights of a given TerminatorInst, the default weight is at the front
|
||||||
/// can't. Pos starts at 0.
|
/// of the vector. If TI is a conditional eq, we need to swap the branch-weight
|
||||||
static ConstantInt* GetWeight(Instruction* I, int Pos) {
|
/// metadata.
|
||||||
MDNode* ProfMD = I->getMetadata(LLVMContext::MD_prof);
|
static void GetBranchWeights(TerminatorInst *TI,
|
||||||
if (ProfMD && ProfMD->getOperand(0)) {
|
SmallVectorImpl<uint64_t> &Weights) {
|
||||||
if (MDString* MDS = dyn_cast<MDString>(ProfMD->getOperand(0))) {
|
MDNode* MD = TI->getMetadata(LLVMContext::MD_prof);
|
||||||
if (MDS->getString().equals("branch_weights")) {
|
assert(MD);
|
||||||
assert(ProfMD->getNumOperands() >= 3);
|
for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) {
|
||||||
return dyn_cast<ConstantInt>(ProfMD->getOperand(1 + Pos));
|
ConstantInt* CI = dyn_cast<ConstantInt>(MD->getOperand(i));
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Scale the given weights based on the successor TI's metadata. Scaling is
|
|
||||||
/// done by multiplying every weight by the sum of the successor's weights.
|
|
||||||
static void ScaleWeights(Instruction* STI, MutableArrayRef<uint64_t> Weights) {
|
|
||||||
// Sum the successor's weights
|
|
||||||
assert(HasBranchWeights(STI));
|
|
||||||
unsigned Scale = 0;
|
|
||||||
MDNode* ProfMD = STI->getMetadata(LLVMContext::MD_prof);
|
|
||||||
for (unsigned i = 1; i < ProfMD->getNumOperands(); ++i) {
|
|
||||||
ConstantInt* CI = dyn_cast<ConstantInt>(ProfMD->getOperand(i));
|
|
||||||
assert(CI);
|
assert(CI);
|
||||||
Scale += CI->getValue().getZExtValue();
|
Weights.push_back(CI->getValue().getZExtValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip default, as it's replaced during the folding
|
// If TI is a conditional eq, the default case is the false case,
|
||||||
for (unsigned i = 1; i < Weights.size(); ++i) {
|
// and the corresponding branch-weight data is at index 2. We swap the
|
||||||
Weights[i] *= Scale;
|
// default weight to be the first entry.
|
||||||
|
if (BranchInst* BI = dyn_cast<BranchInst>(TI)) {
|
||||||
|
assert(Weights.size() == 2);
|
||||||
|
ICmpInst *ICI = cast<ICmpInst>(BI->getCondition());
|
||||||
|
if (ICI->getPredicate() == ICmpInst::ICMP_EQ)
|
||||||
|
std::swap(Weights.front(), Weights.back());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -838,52 +827,22 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI,
|
||||||
|
|
||||||
// Update the branch weight metadata along the way
|
// Update the branch weight metadata along the way
|
||||||
SmallVector<uint64_t, 8> Weights;
|
SmallVector<uint64_t, 8> Weights;
|
||||||
uint64_t PredDefaultWeight = 0;
|
|
||||||
bool PredHasWeights = HasBranchWeights(PTI);
|
bool PredHasWeights = HasBranchWeights(PTI);
|
||||||
bool SuccHasWeights = HasBranchWeights(TI);
|
bool SuccHasWeights = HasBranchWeights(TI);
|
||||||
|
|
||||||
if (PredHasWeights) {
|
if (PredHasWeights)
|
||||||
MDNode* MD = PTI->getMetadata(LLVMContext::MD_prof);
|
GetBranchWeights(PTI, Weights);
|
||||||
assert(MD);
|
else if (SuccHasWeights)
|
||||||
for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) {
|
|
||||||
ConstantInt* CI = dyn_cast<ConstantInt>(MD->getOperand(i));
|
|
||||||
assert(CI);
|
|
||||||
Weights.push_back(CI->getValue().getZExtValue());
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the predecessor is a conditional eq, then swap the default weight
|
|
||||||
// to be the first entry.
|
|
||||||
if (BranchInst* BI = dyn_cast<BranchInst>(PTI)) {
|
|
||||||
assert(Weights.size() == 2);
|
|
||||||
ICmpInst *ICI = cast<ICmpInst>(BI->getCondition());
|
|
||||||
|
|
||||||
if (ICI->getPredicate() == ICmpInst::ICMP_EQ) {
|
|
||||||
std::swap(Weights.front(), Weights.back());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
PredDefaultWeight = Weights.front();
|
|
||||||
} else if (SuccHasWeights) {
|
|
||||||
// If there are no predecessor weights but there are successor weights,
|
// If there are no predecessor weights but there are successor weights,
|
||||||
// populate Weights with 1, which will later be scaled to the sum of
|
// populate Weights with 1, which will later be scaled to the sum of
|
||||||
// successor's weights
|
// successor's weights
|
||||||
Weights.assign(1 + PredCases.size(), 1);
|
Weights.assign(1 + PredCases.size(), 1);
|
||||||
PredDefaultWeight = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t SuccDefaultWeight = 0;
|
SmallVector<uint64_t, 8> SuccWeights;
|
||||||
if (SuccHasWeights) {
|
if (SuccHasWeights)
|
||||||
int Index = 0;
|
GetBranchWeights(TI, SuccWeights);
|
||||||
if (BranchInst* BI = dyn_cast<BranchInst>(TI)) {
|
else if (PredHasWeights)
|
||||||
ICmpInst* ICI = dyn_cast<ICmpInst>(BI->getCondition());
|
SuccWeights.assign(1 + BBCases.size(), 1);
|
||||||
assert(ICI);
|
|
||||||
|
|
||||||
if (ICI->getPredicate() == ICmpInst::ICMP_EQ)
|
|
||||||
Index = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
SuccDefaultWeight = GetWeight(TI, Index)->getValue().getZExtValue();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (PredDefault == BB) {
|
if (PredDefault == BB) {
|
||||||
// If this is the default destination from PTI, only the edges in TI
|
// If this is the default destination from PTI, only the edges in TI
|
||||||
|
@ -896,7 +855,9 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI,
|
||||||
// The default destination is BB, we don't need explicit targets.
|
// The default destination is BB, we don't need explicit targets.
|
||||||
std::swap(PredCases[i], PredCases.back());
|
std::swap(PredCases[i], PredCases.back());
|
||||||
|
|
||||||
if (PredHasWeights) {
|
if (PredHasWeights || SuccHasWeights) {
|
||||||
|
// Increase weight for the default case.
|
||||||
|
Weights[0] += Weights[i+1];
|
||||||
std::swap(Weights[i+1], Weights.back());
|
std::swap(Weights[i+1], Weights.back());
|
||||||
Weights.pop_back();
|
Weights.pop_back();
|
||||||
}
|
}
|
||||||
|
@ -912,27 +873,30 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI,
|
||||||
NewSuccessors.push_back(BBDefault);
|
NewSuccessors.push_back(BBDefault);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (SuccHasWeights) {
|
unsigned CasesFromPred = Weights.size();
|
||||||
ScaleWeights(TI, Weights);
|
uint64_t ValidTotalSuccWeight = 0;
|
||||||
Weights.front() *= SuccDefaultWeight;
|
|
||||||
} else if (PredHasWeights) {
|
|
||||||
Weights.front() /= (1 + BBCases.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0, e = BBCases.size(); i != e; ++i)
|
for (unsigned i = 0, e = BBCases.size(); i != e; ++i)
|
||||||
if (!PTIHandled.count(BBCases[i].Value) &&
|
if (!PTIHandled.count(BBCases[i].Value) &&
|
||||||
BBCases[i].Dest != BBDefault) {
|
BBCases[i].Dest != BBDefault) {
|
||||||
PredCases.push_back(BBCases[i]);
|
PredCases.push_back(BBCases[i]);
|
||||||
NewSuccessors.push_back(BBCases[i].Dest);
|
NewSuccessors.push_back(BBCases[i].Dest);
|
||||||
if (SuccHasWeights) {
|
if (SuccHasWeights || PredHasWeights) {
|
||||||
Weights.push_back(PredDefaultWeight *
|
// The default weight is at index 0, so weight for the ith case
|
||||||
GetWeight(TI, i)->getValue().getZExtValue());
|
// should be at index i+1. Scale the cases from successor by
|
||||||
} else if (PredHasWeights) {
|
// PredDefaultWeight (Weights[0]).
|
||||||
// Split the old default's weight amongst the children
|
Weights.push_back(Weights[0] * SuccWeights[i+1]);
|
||||||
Weights.push_back(PredDefaultWeight / (1 + BBCases.size()));
|
ValidTotalSuccWeight += SuccWeights[i+1];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (SuccHasWeights || PredHasWeights) {
|
||||||
|
ValidTotalSuccWeight += SuccWeights[0];
|
||||||
|
// Scale the cases from predecessor by ValidTotalSuccWeight.
|
||||||
|
for (unsigned i = 1; i < CasesFromPred; ++i)
|
||||||
|
Weights[i] *= ValidTotalSuccWeight;
|
||||||
|
// Scale the default weight by SuccDefaultWeight (SuccWeights[0]).
|
||||||
|
Weights[0] *= SuccWeights[0];
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// FIXME: preserve branch weight metadata, similarly to the 'then'
|
// FIXME: preserve branch weight metadata, similarly to the 'then'
|
||||||
// above. For now, drop it.
|
// above. For now, drop it.
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
; RUN: opt -simplifycfg -S -o - < %s | FileCheck %s
|
||||||
|
|
||||||
|
declare void @func2(i32)
|
||||||
|
declare void @func4(i32)
|
||||||
|
declare void @func6(i32)
|
||||||
|
declare void @func8(i32)
|
||||||
|
|
||||||
|
;; test1 - create a switch with case 2 and case 4 from two branches: N == 2
|
||||||
|
;; and N == 4.
|
||||||
|
define void @test1(i32 %N) nounwind uwtable {
|
||||||
|
entry:
|
||||||
|
%cmp = icmp eq i32 %N, 2
|
||||||
|
br i1 %cmp, label %if.then, label %if.else, !prof !0
|
||||||
|
; CHECK: test1
|
||||||
|
; CHECK: switch i32 %N
|
||||||
|
; CHECK: ], !prof !0
|
||||||
|
|
||||||
|
if.then:
|
||||||
|
call void @func2(i32 %N) nounwind
|
||||||
|
br label %if.end9
|
||||||
|
|
||||||
|
if.else:
|
||||||
|
%cmp2 = icmp eq i32 %N, 4
|
||||||
|
br i1 %cmp2, label %if.then7, label %if.else8, !prof !1
|
||||||
|
|
||||||
|
if.then7:
|
||||||
|
call void @func4(i32 %N) nounwind
|
||||||
|
br label %if.end
|
||||||
|
|
||||||
|
if.else8:
|
||||||
|
call void @func8(i32 %N) nounwind
|
||||||
|
br label %if.end
|
||||||
|
|
||||||
|
if.end:
|
||||||
|
br label %if.end9
|
||||||
|
|
||||||
|
if.end9:
|
||||||
|
ret void
|
||||||
|
}
|
||||||
|
|
||||||
|
;; test2 - Merge two switches where PredDefault == BB.
|
||||||
|
define void @test2(i32 %M, i32 %N) nounwind uwtable {
|
||||||
|
entry:
|
||||||
|
%cmp = icmp sgt i32 %M, 2
|
||||||
|
br i1 %cmp, label %sw1, label %sw2
|
||||||
|
|
||||||
|
sw1:
|
||||||
|
switch i32 %N, label %sw2 [
|
||||||
|
i32 2, label %sw.bb
|
||||||
|
i32 3, label %sw.bb1
|
||||||
|
], !prof !2
|
||||||
|
; CHECK: test2
|
||||||
|
; CHECK: switch i32 %N, label %sw.epilog
|
||||||
|
; CHECK: i32 2, label %sw.bb
|
||||||
|
; CHECK: i32 3, label %sw.bb1
|
||||||
|
; CHECK: i32 4, label %sw.bb5
|
||||||
|
; CHECK: ], !prof !1
|
||||||
|
|
||||||
|
sw.bb:
|
||||||
|
call void @func2(i32 %N) nounwind
|
||||||
|
br label %sw.epilog
|
||||||
|
|
||||||
|
sw.bb1:
|
||||||
|
call void @func4(i32 %N) nounwind
|
||||||
|
br label %sw.epilog
|
||||||
|
|
||||||
|
sw2:
|
||||||
|
;; Here "case 2" is invalidated if control is transferred through default case
|
||||||
|
;; of the first switch.
|
||||||
|
switch i32 %N, label %sw.epilog [
|
||||||
|
i32 2, label %sw.bb4
|
||||||
|
i32 4, label %sw.bb5
|
||||||
|
], !prof !3
|
||||||
|
|
||||||
|
sw.bb4:
|
||||||
|
call void @func6(i32 %N) nounwind
|
||||||
|
br label %sw.epilog
|
||||||
|
|
||||||
|
sw.bb5:
|
||||||
|
call void @func8(i32 %N) nounwind
|
||||||
|
br label %sw.epilog
|
||||||
|
|
||||||
|
sw.epilog:
|
||||||
|
ret void
|
||||||
|
}
|
||||||
|
|
||||||
|
!0 = metadata !{metadata !"branch_weights", i32 64, i32 4}
|
||||||
|
!1 = metadata !{metadata !"branch_weights", i32 4, i32 64}
|
||||||
|
; CHECK: !0 = metadata !{metadata !"branch_weights", i32 256, i32 4352, i32 16}
|
||||||
|
!2 = metadata !{metadata !"branch_weights", i32 4, i32 4, i32 8}
|
||||||
|
!3 = metadata !{metadata !"branch_weights", i32 8, i32 8, i32 4}
|
||||||
|
; CHECK: !1 = metadata !{metadata !"branch_weights", i32 32, i32 48, i32 96, i32 16}
|
Loading…
Reference in New Issue