diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index 8c48597fc2e4..ab3768d737d1 100644 --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -1126,72 +1126,75 @@ Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) { if (!Pow->isFast()) return nullptr; - const APFloat *Arg1C; - if (!match(Pow->getArgOperand(1), m_APFloat(Arg1C))) - return nullptr; - if (!Arg1C->isExactlyValue(0.5) && !Arg1C->isExactlyValue(-0.5)) - return nullptr; - - // Fast-math flags from the pow() are propagated to all replacement ops. - IRBuilder<>::FastMathFlagGuard Guard(B); - B.setFastMathFlags(Pow->getFastMathFlags()); + Value *Sqrt, *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); Type *Ty = Pow->getType(); - Value *Sqrt; - if (Pow->hasFnAttr(Attribute::ReadNone)) { - // We know that errno is never set, so replace with an intrinsic: - // pow(x, 0.5) --> llvm.sqrt(x) - // llvm.pow(x, 0.5) --> llvm.sqrt(x) - auto *F = Intrinsic::getDeclaration(Pow->getModule(), Intrinsic::sqrt, Ty); - Sqrt = B.CreateCall(F, Pow->getArgOperand(0)); - } else if (hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, - LibFunc_sqrtl)) { - // Errno could be set, so we must use a sqrt libcall. - // TODO: We also should check that the target can in fact lower the sqrt - // libcall. We currently have no way to ask this question, so we ask - // whether the target has a sqrt libcall which is not exactly the same. - Sqrt = emitUnaryFloatFnCall(Pow->getArgOperand(0), - TLI->getName(LibFunc_sqrt), B, - Pow->getCalledFunction()->getAttributes()); - } else { - // We can't replace with an intrinsic or a libcall. - return nullptr; - } - // If this is pow(x, -0.5), get the reciprocal. - if (Arg1C->isExactlyValue(-0.5)) - Sqrt = B.CreateFDiv(ConstantFP::get(Ty, 1.0), Sqrt); + const APFloat *ExpoF; + if (!match(Expo, m_APFloat(ExpoF)) || + (!ExpoF->isExactlyValue(0.5) && !ExpoF->isExactlyValue(-0.5))) + return nullptr; + + // If errno is never set, then use the intrinsic for sqrt(). + if (Pow->hasFnAttr(Attribute::ReadNone)) { + + Function *SqrtFn = Intrinsic::getDeclaration(Pow->getModule(), + Intrinsic::sqrt, Ty); + Sqrt = B.CreateCall(SqrtFn, Base); + } + // Otherwise, use the libcall for sqrt(). + else if (hasUnaryFloatFn(TLI, Ty, + LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl)) { + // TODO: We also should check that the target can in fact lower the sqrt() + // libcall. We currently have no way to ask this question, so we ask if + // the target has a sqrt() libcall, which is not exactly the same. + Sqrt = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_sqrt), B, + Pow->getCalledFunction()->getAttributes()); + } else + return nullptr; + + // If this is pow(x, -0.5), then get the reciprocal. + if (ExpoF->isNegative()) + Sqrt = B.CreateFDiv(ConstantFP::get(Ty, 1.0), Sqrt, "reciprocal"); return Sqrt; } -Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - Value *Ret = nullptr; +Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) { + Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); + Function *Callee = Pow->getCalledFunction(); + AttributeList Attrs = Callee->getAttributes(); StringRef Name = Callee->getName(); - if (UnsafeFPShrink && Name == "pow" && hasFloatVersion(Name)) - Ret = optimizeUnaryDoubleFP(CI, B, true); + Module *Module = Pow->getModule(); + Type *Ty = Pow->getType(); + Value *Shrunk = nullptr; + bool Ignored; - Value *Op1 = CI->getArgOperand(0), *Op2 = CI->getArgOperand(1); + if (UnsafeFPShrink && + Name == TLI->getName(LibFunc_pow) && hasFloatVersion(Name)) + Shrunk = optimizeUnaryDoubleFP(Pow, B, true); + + // Propagate math semantics flags from the call to any created instructions. + IRBuilder<>::FastMathFlagGuard Guard(B); + B.setFastMathFlags(Pow->getFastMathFlags()); + + // Evaluate special cases related to the base. // pow(1.0, x) -> 1.0 - if (match(Op1, m_SpecificFP(1.0))) - return Op1; - // pow(2.0, x) -> llvm.exp2(x) - if (match(Op1, m_SpecificFP(2.0))) { - Value *Exp2 = Intrinsic::getDeclaration(CI->getModule(), Intrinsic::exp2, - CI->getType()); - return B.CreateCall(Exp2, Op2, "exp2"); + if (match(Base, m_SpecificFP(1.0))) + return Base; + + // pow(2.0, x) -> exp2(x) + if (match(Base, m_SpecificFP(2.0))) { + Value *Exp2 = Intrinsic::getDeclaration(Module, Intrinsic::exp2, Ty); + return B.CreateCall(Exp2, Expo, "exp2"); } - // There's no llvm.exp10 intrinsic yet, but, maybe, some day there will - // be one. - if (ConstantFP *Op1C = dyn_cast(Op1)) { + // There's no exp10 intrinsic yet, but, maybe, some day there shall be one. + if (ConstantFP *BaseC = dyn_cast(Base)) { // pow(10.0, x) -> exp10(x) - if (Op1C->isExactlyValue(10.0) && - hasUnaryFloatFn(TLI, Op1->getType(), LibFunc_exp10, LibFunc_exp10f, - LibFunc_exp10l)) - return emitUnaryFloatFnCall(Op2, TLI->getName(LibFunc_exp10), B, - Callee->getAttributes()); + if (BaseC->isExactlyValue(10.0) && + hasUnaryFloatFn(TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l)) + return emitUnaryFloatFnCall(Expo, TLI->getName(LibFunc_exp10), B, Attrs); } // pow(exp(x), y) -> exp(x * y) @@ -1200,91 +1203,91 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { // transformation changes overflow and underflow behavior quite dramatically. // Example: x = 1000, y = 0.001. // pow(exp(x), y) = pow(inf, 0.001) = inf, whereas exp(x*y) = exp(1). - auto *OpC = dyn_cast(Op1); - if (OpC && OpC->isFast() && CI->isFast()) { - LibFunc Func; - Function *OpCCallee = OpC->getCalledFunction(); - if (OpCCallee && TLI->getLibFunc(OpCCallee->getName(), Func) && - TLI->has(Func) && (Func == LibFunc_exp || Func == LibFunc_exp2)) { + auto *BaseFn = dyn_cast(Base); + if (BaseFn && BaseFn->isFast() && Pow->isFast()) { + LibFunc LibFn; + Function *CalleeFn = BaseFn->getCalledFunction(); + if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && + (LibFn == LibFunc_exp || LibFn == LibFunc_exp2) && TLI->has(LibFn)) { IRBuilder<>::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - Value *FMul = B.CreateFMul(OpC->getArgOperand(0), Op2, "mul"); - return emitUnaryFloatFnCall(FMul, OpCCallee->getName(), B, - OpCCallee->getAttributes()); + B.setFastMathFlags(Pow->getFastMathFlags()); + + Value *FMul = B.CreateFMul(BaseFn->getArgOperand(0), Expo, "mul"); + return emitUnaryFloatFnCall(FMul, CalleeFn->getName(), B, + CalleeFn->getAttributes()); } } - if (Value *Sqrt = replacePowWithSqrt(CI, B)) + // Evaluate special cases related to the exponent. + + if (Value *Sqrt = replacePowWithSqrt(Pow, B)) return Sqrt; - ConstantFP *Op2C = dyn_cast(Op2); - if (!Op2C) - return Ret; + ConstantFP *ExpoC = dyn_cast(Expo); + if (!ExpoC) + return Shrunk; - if (Op2C->getValueAPF().isZero()) // pow(x, 0.0) -> 1.0 - return ConstantFP::get(CI->getType(), 1.0); + // pow(x, -1.0) -> 1.0 / x + if (ExpoC->isExactlyValue(-1.0)) + return B.CreateFDiv(ConstantFP::get(Ty, 1.0), Base, "reciprocal"); + + // pow(x, 0.0) -> 1.0 + if (ExpoC->getValueAPF().isZero()) + return ConstantFP::get(Ty, 1.0); + + // pow(x, 1.0) -> x + if (ExpoC->isExactlyValue(1.0)) + return Base; + + // pow(x, 2.0) -> x * x + if (ExpoC->isExactlyValue(2.0)) + return B.CreateFMul(Base, Base, "square"); // FIXME: Correct the transforms and pull this into replacePowWithSqrt(). - if (Op2C->isExactlyValue(0.5) && - hasUnaryFloatFn(TLI, Op2->getType(), LibFunc_sqrt, LibFunc_sqrtf, - LibFunc_sqrtl)) { + if (ExpoC->isExactlyValue(0.5) && + hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl)) { // Expand pow(x, 0.5) to (x == -infinity ? +infinity : fabs(sqrt(x))). // This is faster than calling pow, and still handles negative zero // and negative infinity correctly. // TODO: In finite-only mode, this could be just fabs(sqrt(x)). - Value *Inf = ConstantFP::getInfinity(CI->getType()); - Value *NegInf = ConstantFP::getInfinity(CI->getType(), true); + Value *PosInf = ConstantFP::getInfinity(Ty); + Value *NegInf = ConstantFP::getInfinity(Ty, true); - // TODO: As above, we should lower to the sqrt intrinsic if the pow is an - // intrinsic, to match errno semantics. - Value *Sqrt = emitUnaryFloatFnCall(Op1, "sqrt", B, Callee->getAttributes()); + // TODO: As above, we should lower to the sqrt() intrinsic if the pow() is + // an intrinsic, to match errno semantics. + Value *Sqrt = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_sqrt), + B, Attrs); + Function *FabsFn = Intrinsic::getDeclaration(Module, Intrinsic::fabs, Ty); + Value *FAbs = B.CreateCall(FabsFn, Sqrt, "abs"); - Module *M = Callee->getParent(); - Function *FabsF = Intrinsic::getDeclaration(M, Intrinsic::fabs, - CI->getType()); - Value *FAbs = B.CreateCall(FabsF, Sqrt); - - Value *FCmp = B.CreateFCmpOEQ(Op1, NegInf); - Value *Sel = B.CreateSelect(FCmp, Inf, FAbs); + Value *FCmp = B.CreateFCmpOEQ(Base, NegInf, "isinf"); + Value *Sel = B.CreateSelect(FCmp, PosInf, FAbs); return Sel; } - // Propagate fast-math-flags from the call to any created instructions. - IRBuilder<>::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - // pow(x, 1.0) --> x - if (Op2C->isExactlyValue(1.0)) - return Op1; - // pow(x, 2.0) --> x * x - if (Op2C->isExactlyValue(2.0)) - return B.CreateFMul(Op1, Op1, "pow2"); - // pow(x, -1.0) --> 1.0 / x - if (Op2C->isExactlyValue(-1.0)) - return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), Op1, "powrecip"); - - // In -ffast-math, generate repeated fmul instead of generating pow(x, n). - if (CI->isFast()) { - APFloat V = abs(Op2C->getValueAPF()); - // We limit to a max of 7 fmul(s). Thus max exponent is 32. + // pow(x, n) -> x * x * x * .... + if (Pow->isFast()) { + APFloat ExpoA = abs(ExpoC->getValueAPF()); + // We limit to a max of 7 fmul(s). Thus the maximum exponent is 32. // This transformation applies to integer exponents only. - if (V.compare(APFloat(V.getSemantics(), 32.0)) == APFloat::cmpGreaterThan || - !V.isInteger()) + if (!ExpoA.isInteger() || + ExpoA.compare + (APFloat(ExpoA.getSemantics(), 32.0)) == APFloat::cmpGreaterThan) return nullptr; // We will memoize intermediate products of the Addition Chain. Value *InnerChain[33] = {nullptr}; - InnerChain[1] = Op1; - InnerChain[2] = B.CreateFMul(Op1, Op1); + InnerChain[1] = Base; + InnerChain[2] = B.CreateFMul(Base, Base, "square"); // We cannot readily convert a non-double type (like float) to a double. - // So we first convert V to something which could be converted to double. - bool Ignored; - V.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &Ignored); + // So we first convert ExpoA to something which could be converted to double. + ExpoA.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &Ignored); - Value *FMul = getPow(InnerChain, V.convertToDouble(), B); + Value *FMul = getPow(InnerChain, ExpoA.convertToDouble(), B); // For negative exponents simply compute the reciprocal. - if (Op2C->isNegative()) - FMul = B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), FMul); + if (ExpoC->isNegative()) + FMul = B.CreateFDiv(ConstantFP::get(Ty, 1.0), FMul, "reciprocal"); return FMul; } diff --git a/llvm/test/Transforms/InstCombine/pow-sqrt.ll b/llvm/test/Transforms/InstCombine/pow-sqrt.ll index c07a82ccedda..3b885ad5bdae 100644 --- a/llvm/test/Transforms/InstCombine/pow-sqrt.ll +++ b/llvm/test/Transforms/InstCombine/pow-sqrt.ll @@ -20,9 +20,9 @@ define <2 x double> @pow_intrinsic_half_approx(<2 x double> %x) { define double @pow_libcall_half_approx(double %x) { ; CHECK-LABEL: @pow_libcall_half_approx( -; CHECK-NEXT: [[SQRT:%.*]] = call double @sqrt(double %x) -; CHECK-NEXT: [[TMP1:%.*]] = call double @llvm.fabs.f64(double [[SQRT]]) -; CHECK-NEXT: [[TMP2:%.*]] = fcmp oeq double %x, 0xFFF0000000000000 +; CHECK-NEXT: [[SQRT:%.*]] = call afn double @sqrt(double %x) +; CHECK-NEXT: [[TMP1:%.*]] = call afn double @llvm.fabs.f64(double [[SQRT]]) +; CHECK-NEXT: [[TMP2:%.*]] = fcmp afn oeq double %x, 0xFFF0000000000000 ; CHECK-NEXT: [[TMP3:%.*]] = select i1 [[TMP2]], double 0x7FF0000000000000, double [[TMP1]] ; CHECK-NEXT: ret double [[TMP3]] ;