From 1c9692a46fd5650c65da38cb371b8e62a0303cfa Mon Sep 17 00:00:00 2001
From: Justin Lebar
Date: Tue, 31 Jan 2017 05:58:22 +0000
Subject: [PATCH] [NVPTX] Implement NVPTXTargetLowering::getSqrtEstimate.

Summary:
This lets us lower to sqrt.approx and rsqrt.approx in more circumstances.

 * Now we emit sqrt.approx and rsqrt.approx for calls to @llvm.sqrt.f32
   when fast-math is enabled.  Previously, we would only emit it for calls
   to @llvm.nvvm.sqrt.f.  (With this patch we no longer emit sqrt.approx
   for calls to @llvm.nvvm.sqrt.f; we rely on instcombine to simplify
   llvm.nvvm.sqrt.f into llvm.sqrt.f32.)

 * Now we emit the ftz version of rsqrt.approx when ftz is enabled.
   Previously, we only emitted rsqrt.approx when ftz was disabled.

Reviewers: hfinkel

Subscribers: llvm-commits, tra, jholewinski

Differential Revision: https://reviews.llvm.org/D28508

llvm-svn: 293605
---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp |  44 ++++++
 llvm/lib/Target/NVPTX/NVPTXISelLowering.h   |   4 +
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td     |  11 +-
 llvm/test/CodeGen/NVPTX/fast-math.ll        |  76 +++++++++-
 llvm/test/CodeGen/NVPTX/rsqrt.ll            |  13 --
 llvm/test/CodeGen/NVPTX/sqrt-approx.ll      | 148 ++++++++++++++++++++
 6 files changed, 268 insertions(+), 28 deletions(-)
 delete mode 100644 llvm/test/CodeGen/NVPTX/rsqrt.ll
 create mode 100644 llvm/test/CodeGen/NVPTX/sqrt-approx.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 1fb42496d955..194e46b0448c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1043,6 +1043,50 @@ NVPTXTargetLowering::getPreferredVectorAction(EVT VT) const {
   return TargetLoweringBase::getPreferredVectorAction(VT);
 }
 
+SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
+                                             int Enabled, int &ExtraSteps,
+                                             bool &UseOneConst,
+                                             bool Reciprocal) const {
+  if (!(Enabled == ReciprocalEstimate::Enabled ||
+        (Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32())))
+    return SDValue();
+
+  if (ExtraSteps == ReciprocalEstimate::Unspecified)
+    ExtraSteps = 0;
+
+  SDLoc DL(Operand);
+  EVT VT = Operand.getValueType();
+  bool Ftz = useF32FTZ(DAG.getMachineFunction());
+
+  auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
+    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                       DAG.getConstant(IID, DL, MVT::i32), Operand);
+  };
+
+  // The sqrt and rsqrt refinement processes assume we always start out with an
+  // approximation of the rsqrt. Therefore, if we're going to do any refinement
+  // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing
+  // any refinement, we must return a regular sqrt.
+  if (Reciprocal || ExtraSteps > 0) {
+    if (VT == MVT::f32)
+      return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
+                                   : Intrinsic::nvvm_rsqrt_approx_f);
+    else if (VT == MVT::f64)
+      return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
+    else
+      return SDValue();
+  } else {
+    if (VT == MVT::f32)
+      return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
+                                   : Intrinsic::nvvm_sqrt_approx_f);
+    else {
+      // There's no sqrt.approx.f64 instruction, so we emit x * rsqrt(x).
+      return DAG.getNode(ISD::FMUL, DL, VT, Operand,
+                         MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
+    }
+  }
+}
+
 SDValue
 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
   SDLoc dl(Op);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 05c54018b739..f6494f6d37ef 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -526,6 +526,10 @@ public:
   // to sign-preserving zero.
   bool useF32FTZ(const MachineFunction &MF) const;
 
+  SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
+                          int &ExtraSteps, bool &UseOneConst,
+                          bool Reciprocal) const override;
+
   bool allowFMA(MachineFunction &MF, CodeGenOpt::Level OptLevel) const;
   bool allowUnsafeFPMath(MachineFunction &MF) const;
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 8b703bd196e7..3345ce8d3cb0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -966,18 +966,9 @@ def FDIV32ri_prec :
   Requires<[reqPTX20]>;
 
 //
-// F32 rsqrt
+// FMA
 //
 
-def RSQRTF32approx1r : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$b),
-                                 "rsqrt.approx.f32 \t$dst, $b;", []>;
-
-// Convert 1.0f/sqrt(x) to rsqrt.approx.f32. (There is an rsqrt.approx.f64, but
-// it's emulated in software.)
-def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$b)),
-         (RSQRTF32approx1r Float32Regs:$b)>,
-       Requires<[do_DIVF32_FULL, do_SQRTF32_APPROX, doNoF32FTZ]>;
-
 multiclass FMA<string OpcStr, RegisterClass RC, Operand ImmCls, Predicate Pred> {
   def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
                       !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
diff --git a/llvm/test/CodeGen/NVPTX/fast-math.ll b/llvm/test/CodeGen/NVPTX/fast-math.ll
index 08b435b993f5..528d2c02df5f 100644
--- a/llvm/test/CodeGen/NVPTX/fast-math.ll
+++ b/llvm/test/CodeGen/NVPTX/fast-math.ll
@@ -1,25 +1,91 @@
 ; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
 
-declare float @llvm.nvvm.sqrt.f(float)
+declare float @llvm.sqrt.f32(float)
+declare double @llvm.sqrt.f64(double)
 
-; CHECK-LABEL: sqrt_div
+; CHECK-LABEL: sqrt_div(
 ; CHECK: sqrt.rn.f32
 ; CHECK: div.rn.f32
 define float @sqrt_div(float %a, float %b) {
-  %t1 = tail call float @llvm.nvvm.sqrt.f(float %a)
+  %t1 = tail call float @llvm.sqrt.f32(float %a)
   %t2 = fdiv float %t1, %b
   ret float %t2
 }
 
-; CHECK-LABEL: sqrt_div_fast
+; CHECK-LABEL: sqrt_div_fast(
 ; CHECK: sqrt.approx.f32
 ; CHECK: div.approx.f32
 define float @sqrt_div_fast(float %a, float %b) #0 {
-  %t1 = tail call float @llvm.nvvm.sqrt.f(float %a)
+  %t1 = tail call float @llvm.sqrt.f32(float %a)
   %t2 = fdiv float %t1, %b
   ret float %t2
 }
 
+; CHECK-LABEL: sqrt_div_ftz(
+; CHECK: sqrt.rn.ftz.f32
+; CHECK: div.rn.ftz.f32
+define float @sqrt_div_ftz(float %a, float %b) #1 {
+  %t1 = tail call float @llvm.sqrt.f32(float %a)
+  %t2 = fdiv float %t1, %b
+  ret float %t2
+}
+
+; CHECK-LABEL: sqrt_div_fast_ftz(
+; CHECK: sqrt.approx.ftz.f32
+; CHECK: div.approx.ftz.f32
+define float @sqrt_div_fast_ftz(float %a, float %b) #0 #1 {
+  %t1 = tail call float @llvm.sqrt.f32(float %a)
+  %t2 = fdiv float %t1, %b
+  ret float %t2
+}
+
+; There are no fast-math or ftz versions of sqrt and div for f64. We use
+; x * rsqrt(x) for sqrt(x), and emit a vanilla divide.
+
+; CHECK-LABEL: sqrt_div_fast_ftz_f64(
+; CHECK: rsqrt.approx.f64
+; CHECK: mul.f64
+; CHECK: div.rn.f64
+define double @sqrt_div_fast_ftz_f64(double %a, double %b) #0 #1 {
+  %t1 = tail call double @llvm.sqrt.f64(double %a)
+  %t2 = fdiv double %t1, %b
+  ret double %t2
+}
+
+; CHECK-LABEL: rsqrt(
+; CHECK-NOT: rsqrt.approx
+; CHECK: sqrt.rn.f32
+; CHECK-NOT: rsqrt.approx
+define float @rsqrt(float %a) {
+  %b = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %b
+  ret float %ret
+}
+
+; CHECK-LABEL: rsqrt_fast(
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+; CHECK: rsqrt.approx.f32
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+define float @rsqrt_fast(float %a) #0 {
+  %b = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %b
+  ret float %ret
+}
+
+; CHECK-LABEL: rsqrt_fast_ftz(
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+; CHECK: rsqrt.approx.ftz.f32
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+define float @rsqrt_fast_ftz(float %a) #0 #1 {
+  %b = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %b
+  ret float %ret
+}
+
 ; CHECK-LABEL: fadd
 ; CHECK: add.rn.f32
 define float @fadd(float %a, float %b) {
diff --git a/llvm/test/CodeGen/NVPTX/rsqrt.ll b/llvm/test/CodeGen/NVPTX/rsqrt.ll
deleted file mode 100644
index 3a52a493abdd..000000000000
--- a/llvm/test/CodeGen/NVPTX/rsqrt.ll
+++ /dev/null
@@ -1,13 +0,0 @@
-; RUN: llc < %s -march=nvptx -mcpu=sm_20 -nvptx-prec-divf32=1 -nvptx-prec-sqrtf32=0 | FileCheck %s
-
-target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
-
-declare float @llvm.nvvm.sqrt.f(float)
-
-define float @foo(float %a) {
-; CHECK: rsqrt.approx.f32
-  %val = tail call float @llvm.nvvm.sqrt.f(float %a)
-  %ret = fdiv float 1.0, %val
-  ret float %ret
-}
-
diff --git a/llvm/test/CodeGen/NVPTX/sqrt-approx.ll b/llvm/test/CodeGen/NVPTX/sqrt-approx.ll
new file mode 100644
index 000000000000..5edf9e28a933
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/sqrt-approx.ll
@@ -0,0 +1,148 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 -nvptx-prec-divf32=0 -nvptx-prec-sqrtf32=0 \
+; RUN:   | FileCheck %s
+
+target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
+
+declare float @llvm.sqrt.f32(float)
+declare double @llvm.sqrt.f64(double)
+
+; -- reciprocal sqrt --
+
+; CHECK-LABEL: test_rsqrt32
+define float @test_rsqrt32(float %a) #0 {
+; CHECK: rsqrt.approx.f32
+  %val = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %val
+  ret float %ret
+}
+
+; CHECK-LABEL: test_rsqrt_ftz
+define float @test_rsqrt_ftz(float %a) #0 #1 {
+; CHECK: rsqrt.approx.ftz.f32
+  %val = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %val
+  ret float %ret
+}
+
+; CHECK-LABEL: test_rsqrt64
+define double @test_rsqrt64(double %a) #0 {
+; CHECK: rsqrt.approx.f64
+  %val = tail call double @llvm.sqrt.f64(double %a)
+  %ret = fdiv double 1.0, %val
+  ret double %ret
+}
+
+; CHECK-LABEL: test_rsqrt64_ftz
+define double @test_rsqrt64_ftz(double %a) #0 #1 {
+; There's no rsqrt.approx.ftz.f64 instruction; we just use the non-ftz version.
+; CHECK: rsqrt.approx.f64
+  %val = tail call double @llvm.sqrt.f64(double %a)
+  %ret = fdiv double 1.0, %val
+  ret double %ret
+}
+
+; -- sqrt --
+
+; CHECK-LABEL: test_sqrt32
+define float @test_sqrt32(float %a) #0 {
+; CHECK: sqrt.approx.f32
+  %ret = tail call float @llvm.sqrt.f32(float %a)
+  ret float %ret
+}
+
+; CHECK-LABEL: test_sqrt_ftz
+define float @test_sqrt_ftz(float %a) #0 #1 {
+; CHECK: sqrt.approx.ftz.f32
+  %ret = tail call float @llvm.sqrt.f32(float %a)
+  ret float %ret
+}
+
+; CHECK-LABEL: test_sqrt64
+define double @test_sqrt64(double %a) #0 {
+; There's no sqrt.approx.f64 instruction; we emit x * rsqrt.approx.f64(x).
+; CHECK: rsqrt.approx.f64
+; CHECK: mul.f64
+  %ret = tail call double @llvm.sqrt.f64(double %a)
+  ret double %ret
+}
+
+; CHECK-LABEL: test_sqrt64_ftz
+define double @test_sqrt64_ftz(double %a) #0 #1 {
+; There's no sqrt.approx.ftz.f64 instruction; we just use the non-ftz version.
+; CHECK: rsqrt.approx.f64
+; CHECK: mul.f64
+  %ret = tail call double @llvm.sqrt.f64(double %a)
+  ret double %ret
+}
+
+; -- refined sqrt and rsqrt --
+;
+; The sqrt and rsqrt refinement algorithms both emit an rsqrt.approx, followed
+; by some math.
+
+; CHECK-LABEL: test_rsqrt32_refined
+define float @test_rsqrt32_refined(float %a) #0 #2 {
+; CHECK: rsqrt.approx.f32
+  %val = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %val
+  ret float %ret
+}
+
+; CHECK-LABEL: test_sqrt32_refined
+define float @test_sqrt32_refined(float %a) #0 #2 {
+; CHECK: rsqrt.approx.f32
+  %ret = tail call float @llvm.sqrt.f32(float %a)
+  ret float %ret
+}
+
+; CHECK-LABEL: test_rsqrt64_refined
+define double @test_rsqrt64_refined(double %a) #0 #2 {
+; CHECK: rsqrt.approx.f64
+  %val = tail call double @llvm.sqrt.f64(double %a)
+  %ret = fdiv double 1.0, %val
+  ret double %ret
+}
+
+; CHECK-LABEL: test_sqrt64_refined
+define double @test_sqrt64_refined(double %a) #0 #2 {
+; CHECK: rsqrt.approx.f64
+  %ret = tail call double @llvm.sqrt.f64(double %a)
+  ret double %ret
+}
+
+; -- refined sqrt and rsqrt with ftz enabled --
+
+; CHECK-LABEL: test_rsqrt32_refined_ftz
+define float @test_rsqrt32_refined_ftz(float %a) #0 #1 #2 {
+; CHECK: rsqrt.approx.ftz.f32
+  %val = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %val
+  ret float %ret
+}
+
+; CHECK-LABEL: test_sqrt32_refined_ftz
+define float @test_sqrt32_refined_ftz(float %a) #0 #1 #2 {
+; CHECK: rsqrt.approx.ftz.f32
+  %ret = tail call float @llvm.sqrt.f32(float %a)
+  ret float %ret
+}
+
+; CHECK-LABEL: test_rsqrt64_refined_ftz
+define double @test_rsqrt64_refined_ftz(double %a) #0 #1 #2 {
+; There's no rsqrt.approx.ftz.f64, so we just use the non-ftz version.
+; CHECK: rsqrt.approx.f64
+  %val = tail call double @llvm.sqrt.f64(double %a)
+  %ret = fdiv double 1.0, %val
+  ret double %ret
+}
+
+; CHECK-LABEL: test_sqrt64_refined_ftz
+define double @test_sqrt64_refined_ftz(double %a) #0 #1 #2 {
+; CHECK: rsqrt.approx.f64
+  %ret = tail call double @llvm.sqrt.f64(double %a)
+  ret double %ret
+}
+
+attributes #0 = { "unsafe-fp-math" = "true" }
+attributes #1 = { "nvptx-f32ftz" = "true" }
+attributes #2 = { "reciprocal-estimates" = "rsqrtf:1,rsqrtd:1,sqrtf:1,sqrtd:1" }
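
Note on the refined tests: the "reciprocal-estimates" attribute with a ":1"
suffix asks the generic DAG combiner to wrap one refinement step around the
estimate returned by getSqrtEstimate, which is why those tests still check for
rsqrt.approx followed by extra arithmetic. The refinement is a Newton-Raphson
iteration on the rsqrt. The following standalone C++ sketch illustrates that
iteration; refineRsqrt and the injected estimate error are illustrative
stand-ins, not LLVM APIs or the combiner's actual code.

#include <cmath>
#include <cstdio>

// One Newton-Raphson step refining y ~= 1/sqrt(x):
//   y' = y * (1.5 - 0.5 * x * y * y)
// Each step roughly doubles the number of correct bits in the estimate.
static float refineRsqrt(float x, float y) {
  return y * (1.5f - 0.5f * x * y * y);
}

int main() {
  float x = 2.0f;
  // Stand-in for rsqrt.approx.f32: a deliberately sloppy estimate of
  // 1/sqrt(x), off by about 0.3%.
  float est = (1.0f / std::sqrt(x)) * 1.003f;
  float refined = refineRsqrt(x, est);
  std::printf("estimate %.8f  refined %.8f  exact %.8f\n",
              est, refined, 1.0f / std::sqrt(x));
  // A refined sqrt is recovered as x * rsqrt(x), mirroring both the f64
  // lowering above and the comment in getSqrtEstimate that refinement
  // always starts from an rsqrt approximation.
  std::printf("sqrt via x*rsqrt(x) %.8f  exact %.8f\n",
              x * refined, std::sqrt(x));
  return 0;
}

Because the iteration always starts from the rsqrt estimate, getSqrtEstimate
returns an rsqrt whenever ExtraSteps > 0, even when a plain sqrt was
requested; the combiner then multiplies by x to recover sqrt(x).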