diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 0b15304bed20..affeb740f272 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -54,47 +54,44 @@ static Value *createMinMax(InstCombiner::BuilderTy &Builder, return Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B); } -/// If one of the constants is zero (we know they can't both be) and we have an -/// icmp instruction with zero, and we have an 'and' with the non-constant value -/// and a power of two we can turn the select into a shift on the result of the -/// 'and'. /// This folds: -/// select (icmp eq (and X, C1)), C2, C3 -/// iff C1 is a power 2 and the difference between C2 and C3 is a power of 2. +/// select (icmp eq (and X, C1)), TC, FC +/// iff C1 is a power 2 and the difference between TC and FC is a power-of-2. /// To something like: -/// (shr (and (X, C1)), (log2(C1) - log2(C2-C3))) + C3 +/// (shr (and (X, C1)), (log2(C1) - log2(TC-FC))) + FC /// Or: -/// (shl (and (X, C1)), (log2(C2-C3) - log2(C1))) + C3 -/// With some variations depending if C3 is larger than C2, or the shift +/// (shl (and (X, C1)), (log2(TC-FC) - log2(C1))) + FC +/// With some variations depending if FC is larger than TC, or the shift /// isn't needed, or the bit widths don't match. -static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, - APInt TrueVal, APInt FalseVal, +static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, InstCombiner::BuilderTy &Builder) { - assert(SelType->isIntOrIntVectorTy() && "Not an integer select?"); + const APInt *SelTC, *SelFC; + if (!match(Sel.getTrueValue(), m_APInt(SelTC)) || + !match(Sel.getFalseValue(), m_APInt(SelFC))) + return nullptr; // If this is a vector select, we need a vector compare. - if (SelType->isVectorTy() != IC->getType()->isVectorTy()) + Type *SelType = Sel.getType(); + if (SelType->isVectorTy() != Cmp->getType()->isVectorTy()) return nullptr; Value *V; APInt AndMask; bool CreateAnd = false; - ICmpInst::Predicate Pred = IC->getPredicate(); + ICmpInst::Predicate Pred = Cmp->getPredicate(); if (ICmpInst::isEquality(Pred)) { - if (!match(IC->getOperand(1), m_Zero())) + if (!match(Cmp->getOperand(1), m_Zero())) return nullptr; - V = IC->getOperand(0); - + V = Cmp->getOperand(0); const APInt *AndRHS; if (!match(V, m_And(m_Value(), m_Power2(AndRHS)))) return nullptr; AndMask = *AndRHS; - } else if (decomposeBitTestICmp(IC->getOperand(0), IC->getOperand(1), + } else if (decomposeBitTestICmp(Cmp->getOperand(0), Cmp->getOperand(1), Pred, V, AndMask)) { assert(ICmpInst::isEquality(Pred) && "Not equality test?"); - if (!AndMask.isPowerOf2()) return nullptr; @@ -104,38 +101,39 @@ static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, } // If both select arms are non-zero see if we have a select of the form - // 'x ? 2^n + C : C'. Then we can offset both arms by C, use the logic + // 'x ? 2^n + TC : FC'. Then we can offset both arms by C, use the logic // for 'x ? 2^n : 0' and fix the thing up at the end. - APInt Offset(TrueVal.getBitWidth(), 0); - if (!TrueVal.isNullValue() && !FalseVal.isNullValue()) { - if ((TrueVal - FalseVal).isPowerOf2()) - Offset = FalseVal; - else if ((FalseVal - TrueVal).isPowerOf2()) - Offset = TrueVal; + APInt TC = *SelTC; + APInt FC = *SelFC; + APInt Offset(TC.getBitWidth(), 0); + if (!TC.isNullValue() && !FC.isNullValue()) { + if ((TC - FC).isPowerOf2()) + Offset = FC; + else if ((FC - TC).isPowerOf2()) + Offset = TC; else return nullptr; - // Adjust TrueVal and FalseVal to the offset. - TrueVal -= Offset; - FalseVal -= Offset; + // Adjust TC and FC by the offset. + TC -= Offset; + FC -= Offset; } - // Make sure one of the select arms is a power of 2. - if (!TrueVal.isPowerOf2() && !FalseVal.isPowerOf2()) + // Make sure one of the select arms is a power-of-2. + if (!TC.isPowerOf2() && !FC.isPowerOf2()) return nullptr; // Determine which shift is needed to transform result of the 'and' into the // desired result. - const APInt &ValC = !TrueVal.isNullValue() ? TrueVal : FalseVal; + const APInt &ValC = !TC.isNullValue() ? TC : FC; unsigned ValZeros = ValC.logBase2(); unsigned AndZeros = AndMask.logBase2(); - if (CreateAnd) { - // Insert the AND instruction on the input to the truncate. + // Insert the 'and' instruction on the input to the truncate. + if (CreateAnd) V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask)); - } - // If types don't match we can still convert the select by introducing a zext + // If types don't match, we can still convert the select by introducing a zext // or a trunc of the 'and'. if (ValZeros > AndZeros) { V = Builder.CreateZExtOrTrunc(V, SelType); @@ -143,12 +141,13 @@ static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, } else if (ValZeros < AndZeros) { V = Builder.CreateLShr(V, AndZeros - ValZeros); V = Builder.CreateZExtOrTrunc(V, SelType); - } else + } else { V = Builder.CreateZExtOrTrunc(V, SelType); + } // Okay, now we know that everything is set up, we just don't know whether we // have a icmp_ne or icmp_eq and whether the true or false val is the zero. - bool ShouldNotVal = !TrueVal.isNullValue(); + bool ShouldNotVal = !TC.isNullValue(); ShouldNotVal ^= Pred == ICmpInst::ICMP_NE; if (ShouldNotVal) V = Builder.CreateXor(V, ValC); @@ -831,14 +830,8 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, } } - { - const APInt *TrueValC, *FalseValC; - if (match(TrueVal, m_APInt(TrueValC)) && - match(FalseVal, m_APInt(FalseValC))) - if (Value *V = foldSelectICmpAnd(SI.getType(), ICI, *TrueValC, - *FalseValC, Builder)) - return replaceInstUsesWith(SI, V); - } + if (Value *V = foldSelectICmpAnd(SI, ICI, Builder)) + return replaceInstUsesWith(SI, V); // NOTE: if we wanted to, this is where to detect integer MIN/MAX