[NVPTX] Implement NVPTXTargetLowering::getSqrtEstimate.
Summary: This lets us lower to sqrt.approx and rsqrt.approx under more circumstances. * Now we emit sqrt.approx and rsqrt.approx for calls to @llvm.sqrt.f32, when fast-math is enabled. Previously, we only would emit it for calls to @llvm.nvvm.sqrt.f. (With this patch we no longer emit sqrt.approx for calls to @llvm.nvvm.sqrt.f; we rely on intcombine to simplify llvm.nvvm.sqrt.f into llvm.sqrt.f32.) * Now we emit the ftz version of rsqrt.approx when ftz is enabled. Previously, we only emitted rsqrt.approx when ftz was disabled. Reviewers: hfinkel Subscribers: llvm-commits, tra, jholewinski Differential Revision: https://reviews.llvm.org/D28508 llvm-svn: 293605
This commit is contained in:
parent
93590e09d5
commit
1c9692a46f
|
@ -1043,6 +1043,50 @@ NVPTXTargetLowering::getPreferredVectorAction(EVT VT) const {
|
||||||
return TargetLoweringBase::getPreferredVectorAction(VT);
|
return TargetLoweringBase::getPreferredVectorAction(VT);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
|
||||||
|
int Enabled, int &ExtraSteps,
|
||||||
|
bool &UseOneConst,
|
||||||
|
bool Reciprocal) const {
|
||||||
|
if (!(Enabled == ReciprocalEstimate::Enabled ||
|
||||||
|
(Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32())))
|
||||||
|
return SDValue();
|
||||||
|
|
||||||
|
if (ExtraSteps == ReciprocalEstimate::Unspecified)
|
||||||
|
ExtraSteps = 0;
|
||||||
|
|
||||||
|
SDLoc DL(Operand);
|
||||||
|
EVT VT = Operand.getValueType();
|
||||||
|
bool Ftz = useF32FTZ(DAG.getMachineFunction());
|
||||||
|
|
||||||
|
auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
|
||||||
|
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
|
||||||
|
DAG.getConstant(IID, DL, MVT::i32), Operand);
|
||||||
|
};
|
||||||
|
|
||||||
|
// The sqrt and rsqrt refinement processes assume we always start out with an
|
||||||
|
// approximation of the rsqrt. Therefore, if we're going to do any refinement
|
||||||
|
// (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing
|
||||||
|
// any refinement, we must return a regular sqrt.
|
||||||
|
if (Reciprocal || ExtraSteps > 0) {
|
||||||
|
if (VT == MVT::f32)
|
||||||
|
return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
|
||||||
|
: Intrinsic::nvvm_rsqrt_approx_f);
|
||||||
|
else if (VT == MVT::f64)
|
||||||
|
return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
|
||||||
|
else
|
||||||
|
return SDValue();
|
||||||
|
} else {
|
||||||
|
if (VT == MVT::f32)
|
||||||
|
return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
|
||||||
|
: Intrinsic::nvvm_sqrt_approx_f);
|
||||||
|
else {
|
||||||
|
// There's no sqrt.approx.f64 instruction, so we emit x * rsqrt(x).
|
||||||
|
return DAG.getNode(ISD::FMUL, DL, VT, Operand,
|
||||||
|
MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
SDValue
|
SDValue
|
||||||
NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
|
NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
|
||||||
SDLoc dl(Op);
|
SDLoc dl(Op);
|
||||||
|
|
|
@ -526,6 +526,10 @@ public:
|
||||||
// to sign-preserving zero.
|
// to sign-preserving zero.
|
||||||
bool useF32FTZ(const MachineFunction &MF) const;
|
bool useF32FTZ(const MachineFunction &MF) const;
|
||||||
|
|
||||||
|
SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
|
||||||
|
int &ExtraSteps, bool &UseOneConst,
|
||||||
|
bool Reciprocal) const override;
|
||||||
|
|
||||||
bool allowFMA(MachineFunction &MF, CodeGenOpt::Level OptLevel) const;
|
bool allowFMA(MachineFunction &MF, CodeGenOpt::Level OptLevel) const;
|
||||||
bool allowUnsafeFPMath(MachineFunction &MF) const;
|
bool allowUnsafeFPMath(MachineFunction &MF) const;
|
||||||
|
|
||||||
|
|
|
@ -966,18 +966,9 @@ def FDIV32ri_prec :
|
||||||
Requires<[reqPTX20]>;
|
Requires<[reqPTX20]>;
|
||||||
|
|
||||||
//
|
//
|
||||||
// F32 rsqrt
|
// FMA
|
||||||
//
|
//
|
||||||
|
|
||||||
def RSQRTF32approx1r : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$b),
|
|
||||||
"rsqrt.approx.f32 \t$dst, $b;", []>;
|
|
||||||
|
|
||||||
// Convert 1.0f/sqrt(x) to rsqrt.approx.f32. (There is an rsqrt.approx.f64, but
|
|
||||||
// it's emulated in software.)
|
|
||||||
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$b)),
|
|
||||||
(RSQRTF32approx1r Float32Regs:$b)>,
|
|
||||||
Requires<[do_DIVF32_FULL, do_SQRTF32_APPROX, doNoF32FTZ]>;
|
|
||||||
|
|
||||||
multiclass FMA<string OpcStr, RegisterClass RC, Operand ImmCls, Predicate Pred> {
|
multiclass FMA<string OpcStr, RegisterClass RC, Operand ImmCls, Predicate Pred> {
|
||||||
def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
|
def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
|
||||||
!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
|
!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
|
||||||
|
|
|
@ -1,25 +1,91 @@
|
||||||
; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
|
; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
|
||||||
|
|
||||||
declare float @llvm.nvvm.sqrt.f(float)
|
declare float @llvm.sqrt.f32(float)
|
||||||
|
declare double @llvm.sqrt.f64(double)
|
||||||
|
|
||||||
; CHECK-LABEL: sqrt_div
|
; CHECK-LABEL: sqrt_div(
|
||||||
; CHECK: sqrt.rn.f32
|
; CHECK: sqrt.rn.f32
|
||||||
; CHECK: div.rn.f32
|
; CHECK: div.rn.f32
|
||||||
define float @sqrt_div(float %a, float %b) {
|
define float @sqrt_div(float %a, float %b) {
|
||||||
%t1 = tail call float @llvm.nvvm.sqrt.f(float %a)
|
%t1 = tail call float @llvm.sqrt.f32(float %a)
|
||||||
%t2 = fdiv float %t1, %b
|
%t2 = fdiv float %t1, %b
|
||||||
ret float %t2
|
ret float %t2
|
||||||
}
|
}
|
||||||
|
|
||||||
; CHECK-LABEL: sqrt_div_fast
|
; CHECK-LABEL: sqrt_div_fast(
|
||||||
; CHECK: sqrt.approx.f32
|
; CHECK: sqrt.approx.f32
|
||||||
; CHECK: div.approx.f32
|
; CHECK: div.approx.f32
|
||||||
define float @sqrt_div_fast(float %a, float %b) #0 {
|
define float @sqrt_div_fast(float %a, float %b) #0 {
|
||||||
%t1 = tail call float @llvm.nvvm.sqrt.f(float %a)
|
%t1 = tail call float @llvm.sqrt.f32(float %a)
|
||||||
%t2 = fdiv float %t1, %b
|
%t2 = fdiv float %t1, %b
|
||||||
ret float %t2
|
ret float %t2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
; CHECK-LABEL: sqrt_div_ftz(
|
||||||
|
; CHECK: sqrt.rn.ftz.f32
|
||||||
|
; CHECK: div.rn.ftz.f32
|
||||||
|
define float @sqrt_div_ftz(float %a, float %b) #1 {
|
||||||
|
%t1 = tail call float @llvm.sqrt.f32(float %a)
|
||||||
|
%t2 = fdiv float %t1, %b
|
||||||
|
ret float %t2
|
||||||
|
}
|
||||||
|
|
||||||
|
; CHECK-LABEL: sqrt_div_fast_ftz(
|
||||||
|
; CHECK: sqrt.approx.ftz.f32
|
||||||
|
; CHECK: div.approx.ftz.f32
|
||||||
|
define float @sqrt_div_fast_ftz(float %a, float %b) #0 #1 {
|
||||||
|
%t1 = tail call float @llvm.sqrt.f32(float %a)
|
||||||
|
%t2 = fdiv float %t1, %b
|
||||||
|
ret float %t2
|
||||||
|
}
|
||||||
|
|
||||||
|
; There are no fast-math or ftz versions of sqrt and div for f64. We use
|
||||||
|
; x * rsqrt(x) for sqrt(x), and emit a vanilla divide.
|
||||||
|
|
||||||
|
; CHECK-LABEL: sqrt_div_fast_ftz_f64(
|
||||||
|
; CHECK: rsqrt.approx.f64
|
||||||
|
; CHECK: mul.f64
|
||||||
|
; CHECK: div.rn.f64
|
||||||
|
define double @sqrt_div_fast_ftz_f64(double %a, double %b) #0 #1 {
|
||||||
|
%t1 = tail call double @llvm.sqrt.f64(double %a)
|
||||||
|
%t2 = fdiv double %t1, %b
|
||||||
|
ret double %t2
|
||||||
|
}
|
||||||
|
|
||||||
|
; CHECK-LABEL: rsqrt(
|
||||||
|
; CHECK-NOT: rsqrt.approx
|
||||||
|
; CHECK: sqrt.rn.f32
|
||||||
|
; CHECK-NOT: rsqrt.approx
|
||||||
|
define float @rsqrt(float %a) {
|
||||||
|
%b = tail call float @llvm.sqrt.f32(float %a)
|
||||||
|
%ret = fdiv float 1.0, %b
|
||||||
|
ret float %ret
|
||||||
|
}
|
||||||
|
|
||||||
|
; CHECK-LABEL: rsqrt_fast(
|
||||||
|
; CHECK-NOT: div.
|
||||||
|
; CHECK-NOT: sqrt.
|
||||||
|
; CHECK: rsqrt.approx.f32
|
||||||
|
; CHECK-NOT: div.
|
||||||
|
; CHECK-NOT: sqrt.
|
||||||
|
define float @rsqrt_fast(float %a) #0 {
|
||||||
|
%b = tail call float @llvm.sqrt.f32(float %a)
|
||||||
|
%ret = fdiv float 1.0, %b
|
||||||
|
ret float %ret
|
||||||
|
}
|
||||||
|
|
||||||
|
; CHECK-LABEL: rsqrt_fast_ftz(
|
||||||
|
; CHECK-NOT: div.
|
||||||
|
; CHECK-NOT: sqrt.
|
||||||
|
; CHECK: rsqrt.approx.ftz.f32
|
||||||
|
; CHECK-NOT: div.
|
||||||
|
; CHECK-NOT: sqrt.
|
||||||
|
define float @rsqrt_fast_ftz(float %a) #0 #1 {
|
||||||
|
%b = tail call float @llvm.sqrt.f32(float %a)
|
||||||
|
%ret = fdiv float 1.0, %b
|
||||||
|
ret float %ret
|
||||||
|
}
|
||||||
|
|
||||||
; CHECK-LABEL: fadd
|
; CHECK-LABEL: fadd
|
||||||
; CHECK: add.rn.f32
|
; CHECK: add.rn.f32
|
||||||
define float @fadd(float %a, float %b) {
|
define float @fadd(float %a, float %b) {
|
||||||
|
|
|
@ -1,13 +0,0 @@
|
||||||
; RUN: llc < %s -march=nvptx -mcpu=sm_20 -nvptx-prec-divf32=1 -nvptx-prec-sqrtf32=0 | FileCheck %s
|
|
||||||
|
|
||||||
target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
|
|
||||||
|
|
||||||
declare float @llvm.nvvm.sqrt.f(float)
|
|
||||||
|
|
||||||
define float @foo(float %a) {
|
|
||||||
; CHECK: rsqrt.approx.f32
|
|
||||||
%val = tail call float @llvm.nvvm.sqrt.f(float %a)
|
|
||||||
%ret = fdiv float 1.0, %val
|
|
||||||
ret float %ret
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,148 @@
|
||||||
|
; RUN: llc < %s -march=nvptx -mcpu=sm_20 -nvptx-prec-divf32=0 -nvptx-prec-sqrtf32=0 \
|
||||||
|
; RUN: | FileCheck %s
|
||||||
|
|
||||||
|
target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
|
||||||
|
|
||||||
|
declare float @llvm.sqrt.f32(float)
|
||||||
|
declare double @llvm.sqrt.f64(double)
|
||||||
|
|
||||||
|
; -- reciprocal sqrt --
|
||||||
|
|
||||||
|
; CHECK-LABEL test_rsqrt32
|
||||||
|
define float @test_rsqrt32(float %a) #0 {
|
||||||
|
; CHECK: rsqrt.approx.f32
|
||||||
|
%val = tail call float @llvm.sqrt.f32(float %a)
|
||||||
|
%ret = fdiv float 1.0, %val
|
||||||
|
ret float %ret
|
||||||
|
}
|
||||||
|
|
||||||
|
; CHECK-LABEL test_rsqrt_ftz
|
||||||
|
define float @test_rsqrt_ftz(float %a) #0 #1 {
|
||||||
|
; CHECK: rsqrt.approx.ftz.f32
|
||||||
|
%val = tail call float @llvm.sqrt.f32(float %a)
|
||||||
|
%ret = fdiv float 1.0, %val
|
||||||
|
ret float %ret
|
||||||
|
}
|
||||||
|
|
||||||
|
; CHECK-LABEL test_rsqrt64
|
||||||
|
define double @test_rsqrt64(double %a) #0 {
|
||||||
|
; CHECK: rsqrt.approx.f64
|
||||||
|
%val = tail call double @llvm.sqrt.f64(double %a)
|
||||||
|
%ret = fdiv double 1.0, %val
|
||||||
|
ret double %ret
|
||||||
|
}
|
||||||
|
|
||||||
|
; CHECK-LABEL test_rsqrt64_ftz
|
||||||
|
define double @test_rsqrt64_ftz(double %a) #0 #1 {
|
||||||
|
; There's no rsqrt.approx.ftz.f64 instruction; we just use the non-ftz version.
|
||||||
|
; CHECK: rsqrt.approx.f64
|
||||||
|
%val = tail call double @llvm.sqrt.f64(double %a)
|
||||||
|
%ret = fdiv double 1.0, %val
|
||||||
|
ret double %ret
|
||||||
|
}
|
||||||
|
|
||||||
|
; -- sqrt --
|
||||||
|
|
||||||
|
; CHECK-LABEL test_sqrt32
|
||||||
|
define float @test_sqrt32(float %a) #0 {
|
||||||
|
; CHECK: sqrt.approx.f32
|
||||||
|
%ret = tail call float @llvm.sqrt.f32(float %a)
|
||||||
|
ret float %ret
|
||||||
|
}
|
||||||
|
|
||||||
|
; CHECK-LABEL test_sqrt_ftz
|
||||||
|
define float @test_sqrt_ftz(float %a) #0 #1 {
|
||||||
|
; CHECK: sqrt.approx.ftz.f32
|
||||||
|
%ret = tail call float @llvm.sqrt.f32(float %a)
|
||||||
|
ret float %ret
|
||||||
|
}
|
||||||
|
|
||||||
|
; CHECK-LABEL test_sqrt64
|
||||||
|
define double @test_sqrt64(double %a) #0 {
|
||||||
|
; There's no sqrt.approx.f64 instruction; we emit x * rsqrt.approx.f64(x).
|
||||||
|
; CHECK: rsqrt.approx.f64
|
||||||
|
; CHECK: mul.f64
|
||||||
|
%ret = tail call double @llvm.sqrt.f64(double %a)
|
||||||
|
ret double %ret
|
||||||
|
}
|
||||||
|
|
||||||
|
; CHECK-LABEL test_sqrt64_ftz
|
||||||
|
define double @test_sqrt64_ftz(double %a) #0 #1 {
|
||||||
|
; There's no sqrt.approx.ftz.f64 instruction; we just use the non-ftz version.
|
||||||
|
; CHECK: rsqrt.approx.f64
|
||||||
|
; CHECK: mul.f64
|
||||||
|
%ret = tail call double @llvm.sqrt.f64(double %a)
|
||||||
|
ret double %ret
|
||||||
|
}
|
||||||
|
|
||||||
|
; -- refined sqrt and rsqrt --
|
||||||
|
;
|
||||||
|
; The sqrt and rsqrt refinement algorithms both emit an rsqrt.approx, followed
|
||||||
|
; by some math.
|
||||||
|
|
||||||
|
; CHECK-LABEL: test_rsqrt32_refined
|
||||||
|
define float @test_rsqrt32_refined(float %a) #0 #2 {
|
||||||
|
; CHECK: rsqrt.approx.f32
|
||||||
|
%val = tail call float @llvm.sqrt.f32(float %a)
|
||||||
|
%ret = fdiv float 1.0, %val
|
||||||
|
ret float %ret
|
||||||
|
}
|
||||||
|
|
||||||
|
; CHECK-LABEL: test_sqrt32_refined
|
||||||
|
define float @test_sqrt32_refined(float %a) #0 #2 {
|
||||||
|
; CHECK: rsqrt.approx.f32
|
||||||
|
%ret = tail call float @llvm.sqrt.f32(float %a)
|
||||||
|
ret float %ret
|
||||||
|
}
|
||||||
|
|
||||||
|
; CHECK-LABEL: test_rsqrt64_refined
|
||||||
|
define double @test_rsqrt64_refined(double %a) #0 #2 {
|
||||||
|
; CHECK: rsqrt.approx.f64
|
||||||
|
%val = tail call double @llvm.sqrt.f64(double %a)
|
||||||
|
%ret = fdiv double 1.0, %val
|
||||||
|
ret double %ret
|
||||||
|
}
|
||||||
|
|
||||||
|
; CHECK-LABEL: test_sqrt64_refined
|
||||||
|
define double @test_sqrt64_refined(double %a) #0 #2 {
|
||||||
|
; CHECK: rsqrt.approx.f64
|
||||||
|
%ret = tail call double @llvm.sqrt.f64(double %a)
|
||||||
|
ret double %ret
|
||||||
|
}
|
||||||
|
|
||||||
|
; -- refined sqrt and rsqrt with ftz enabled --
|
||||||
|
|
||||||
|
; CHECK-LABEL: test_rsqrt32_refined_ftz
|
||||||
|
define float @test_rsqrt32_refined_ftz(float %a) #0 #1 #2 {
|
||||||
|
; CHECK: rsqrt.approx.ftz.f32
|
||||||
|
%val = tail call float @llvm.sqrt.f32(float %a)
|
||||||
|
%ret = fdiv float 1.0, %val
|
||||||
|
ret float %ret
|
||||||
|
}
|
||||||
|
|
||||||
|
; CHECK-LABEL: test_sqrt32_refined_ftz
|
||||||
|
define float @test_sqrt32_refined_ftz(float %a) #0 #1 #2 {
|
||||||
|
; CHECK: rsqrt.approx.ftz.f32
|
||||||
|
%ret = tail call float @llvm.sqrt.f32(float %a)
|
||||||
|
ret float %ret
|
||||||
|
}
|
||||||
|
|
||||||
|
; CHECK-LABEL: test_rsqrt64_refined_ftz
|
||||||
|
define double @test_rsqrt64_refined_ftz(double %a) #0 #1 #2 {
|
||||||
|
; There's no rsqrt.approx.ftz.f64, so we just use the non-ftz version.
|
||||||
|
; CHECK: rsqrt.approx.f64
|
||||||
|
%val = tail call double @llvm.sqrt.f64(double %a)
|
||||||
|
%ret = fdiv double 1.0, %val
|
||||||
|
ret double %ret
|
||||||
|
}
|
||||||
|
|
||||||
|
; CHECK-LABEL: test_sqrt64_refined_ftz
|
||||||
|
define double @test_sqrt64_refined_ftz(double %a) #0 #1 #2 {
|
||||||
|
; CHECK: rsqrt.approx.f64
|
||||||
|
%ret = tail call double @llvm.sqrt.f64(double %a)
|
||||||
|
ret double %ret
|
||||||
|
}
|
||||||
|
|
||||||
|
attributes #0 = { "unsafe-fp-math" = "true" }
|
||||||
|
attributes #1 = { "nvptx-f32ftz" = "true" }
|
||||||
|
attributes #2 = { "reciprocal-estimates" = "rsqrtf:1,rsqrtd:1,sqrtf:1,sqrtd:1" }
|
Loading…
Reference in New Issue