[NVPTX] Implement llvm.fabs.f32, llvm.max.f32, etc.

Summary:
Previously these only worked via NVPTX-specific intrinsics.

This change will allow us to convert these target-specific intrinsics
into the general LLVM versions, allowing existing LLVM passes to reason
about their behavior.

It also gets us some minor codegen improvements as-is, from situations
where we canonicalize code into one of these llvm intrinsics.

Reviewers: majnemer

Subscribers: llvm-commits, jholewinski, tra

Differential Revision: https://reviews.llvm.org/D24300

llvm-svn: 281092
This commit is contained in:
Justin Lebar 2016-09-09 21:07:26 +00:00
parent b9e51397bf
commit b5e884976b
4 changed files with 394 additions and 17 deletions

View File

@ -279,6 +279,28 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setTargetDAGCombine(ISD::SHL);
setTargetDAGCombine(ISD::SELECT);
// Library functions. These default to Expand, but we have instructions
// for them.
setOperationAction(ISD::FCEIL, MVT::f32, Legal);
setOperationAction(ISD::FCEIL, MVT::f64, Legal);
setOperationAction(ISD::FFLOOR, MVT::f32, Legal);
setOperationAction(ISD::FFLOOR, MVT::f64, Legal);
setOperationAction(ISD::FNEARBYINT, MVT::f32, Legal);
setOperationAction(ISD::FNEARBYINT, MVT::f64, Legal);
setOperationAction(ISD::FRINT, MVT::f32, Legal);
setOperationAction(ISD::FRINT, MVT::f64, Legal);
setOperationAction(ISD::FROUND, MVT::f32, Legal);
setOperationAction(ISD::FROUND, MVT::f64, Legal);
setOperationAction(ISD::FTRUNC, MVT::f32, Legal);
setOperationAction(ISD::FTRUNC, MVT::f64, Legal);
setOperationAction(ISD::FMINNUM, MVT::f32, Legal);
setOperationAction(ISD::FMINNUM, MVT::f64, Legal);
setOperationAction(ISD::FMAXNUM, MVT::f32, Legal);
setOperationAction(ISD::FMAXNUM, MVT::f64, Legal);
// No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
// No FPOW or FREM in PTX.
// Now deduce the information based on the above mentioned
// actions
computeRegisterProperties(STI.getRegisterInfo());

View File

@ -207,11 +207,59 @@ multiclass ADD_SUB_INT_32<string OpcStr, SDNode OpNode> {
}
// Template for instructions which take three fp64 or fp32 args. The
// instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
// instructions are named "<OpcStr>.f<Width>" (e.g. "min.f64").
//
// Also defines ftz (flush subnormal inputs and results to sign-preserving
// zero) variants for fp32 functions.
//
// This multiclass should be used for nodes that cannot be folded into FMAs.
// For nodes that can be folded into FMAs (i.e. adds and muls), use
// F3_fma_component.
multiclass F3<string OpcStr, SDNode OpNode> {
def f64rr :
NVPTXInst<(outs Float64Regs:$dst),
(ins Float64Regs:$a, Float64Regs:$b),
!strconcat(OpcStr, ".f64 \t$dst, $a, $b;"),
[(set Float64Regs:$dst, (OpNode Float64Regs:$a, Float64Regs:$b))]>;
def f64ri :
NVPTXInst<(outs Float64Regs:$dst),
(ins Float64Regs:$a, f64imm:$b),
!strconcat(OpcStr, ".f64 \t$dst, $a, $b;"),
[(set Float64Regs:$dst, (OpNode Float64Regs:$a, fpimm:$b))]>;
def f32rr_ftz :
NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, Float32Regs:$b),
!strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b;"),
[(set Float32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>,
Requires<[doF32FTZ]>;
def f32ri_ftz :
NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, f32imm:$b),
!strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b;"),
[(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>,
Requires<[doF32FTZ]>;
def f32rr :
NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, Float32Regs:$b),
!strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
[(set Float32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>;
def f32ri :
NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, f32imm:$b),
!strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
[(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>;
}
// Template for instructions which take three fp64 or fp32 args. The
// instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
//
// Also defines ftz (flush subnormal inputs and results to sign-preserving
// zero) variants for fp32 functions.
//
// This multiclass should be used for nodes that can be folded to make fma ops.
// In this case, we use the ".rn" variant when FMA is disabled, as this behaves
// just like the non ".rn" op, but prevents ptxas from creating FMAs.
multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
def f64rr :
NVPTXInst<(outs Float64Regs:$dst),
(ins Float64Regs:$a, Float64Regs:$b),
@ -248,41 +296,39 @@ multiclass F3<string OpcStr, SDNode OpNode> {
!strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
[(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>,
Requires<[allowFMA]>;
}
// Same as F3, but defines ".rn" variants (round to nearest even).
multiclass F3_rn<string OpcStr, SDNode OpNode> {
def f64rr :
// These have strange names so we don't perturb existing mir tests.
def _rnf64rr :
NVPTXInst<(outs Float64Regs:$dst),
(ins Float64Regs:$a, Float64Regs:$b),
!strconcat(OpcStr, ".rn.f64 \t$dst, $a, $b;"),
[(set Float64Regs:$dst, (OpNode Float64Regs:$a, Float64Regs:$b))]>,
Requires<[noFMA]>;
def f64ri :
def _rnf64ri :
NVPTXInst<(outs Float64Regs:$dst),
(ins Float64Regs:$a, f64imm:$b),
!strconcat(OpcStr, ".rn.f64 \t$dst, $a, $b;"),
[(set Float64Regs:$dst, (OpNode Float64Regs:$a, fpimm:$b))]>,
Requires<[noFMA]>;
def f32rr_ftz :
def _rnf32rr_ftz :
NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, Float32Regs:$b),
!strconcat(OpcStr, ".rn.ftz.f32 \t$dst, $a, $b;"),
[(set Float32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>,
Requires<[noFMA, doF32FTZ]>;
def f32ri_ftz :
def _rnf32ri_ftz :
NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, f32imm:$b),
!strconcat(OpcStr, ".rn.ftz.f32 \t$dst, $a, $b;"),
[(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>,
Requires<[noFMA, doF32FTZ]>;
def f32rr :
def _rnf32rr :
NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, Float32Regs:$b),
!strconcat(OpcStr, ".rn.f32 \t$dst, $a, $b;"),
[(set Float32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>,
Requires<[noFMA]>;
def f32ri :
def _rnf32ri :
NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, f32imm:$b),
!strconcat(OpcStr, ".rn.f32 \t$dst, $a, $b;"),
@ -713,13 +759,12 @@ def DoubleConst1 : PatLeaf<(fpimm), [{
N->getValueAPF().convertToDouble() == 1.0;
}]>;
defm FADD : F3<"add", fadd>;
defm FSUB : F3<"sub", fsub>;
defm FMUL : F3<"mul", fmul>;
defm FADD : F3_fma_component<"add", fadd>;
defm FSUB : F3_fma_component<"sub", fsub>;
defm FMUL : F3_fma_component<"mul", fmul>;
defm FADD_rn : F3_rn<"add", fadd>;
defm FSUB_rn : F3_rn<"sub", fsub>;
defm FMUL_rn : F3_rn<"mul", fmul>;
defm FMIN : F3<"min", fminnum>;
defm FMAX : F3<"max", fmaxnum>;
defm FABS : F2<"abs", fabs>;
defm FNEG : F2<"neg", fneg>;
@ -2628,6 +2673,55 @@ def : Pat<(f64 (fpextend Float32Regs:$a)),
def retflag : SDNode<"NVPTXISD::RET_FLAG", SDTNone,
[SDNPHasChain, SDNPOptInGlue]>;
// fceil, ffloor, fround, ftrunc.
def : Pat<(fceil Float32Regs:$a),
(CVT_f32_f32 Float32Regs:$a, CvtRPI_FTZ)>, Requires<[doF32FTZ]>;
def : Pat<(fceil Float32Regs:$a),
(CVT_f32_f32 Float32Regs:$a, CvtRPI)>, Requires<[doNoF32FTZ]>;
def : Pat<(fceil Float64Regs:$a),
(CVT_f64_f64 Float64Regs:$a, CvtRPI)>;
def : Pat<(ffloor Float32Regs:$a),
(CVT_f32_f32 Float32Regs:$a, CvtRMI_FTZ)>, Requires<[doF32FTZ]>;
def : Pat<(ffloor Float32Regs:$a),
(CVT_f32_f32 Float32Regs:$a, CvtRMI)>, Requires<[doNoF32FTZ]>;
def : Pat<(ffloor Float64Regs:$a),
(CVT_f64_f64 Float64Regs:$a, CvtRMI)>;
def : Pat<(fround Float32Regs:$a),
(CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
def : Pat<(f32 (fround Float32Regs:$a)),
(CVT_f32_f32 Float32Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>;
def : Pat<(f64 (fround Float64Regs:$a)),
(CVT_f64_f64 Float64Regs:$a, CvtRNI)>;
def : Pat<(ftrunc Float32Regs:$a),
(CVT_f32_f32 Float32Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>;
def : Pat<(ftrunc Float32Regs:$a),
(CVT_f32_f32 Float32Regs:$a, CvtRZI)>, Requires<[doNoF32FTZ]>;
def : Pat<(ftrunc Float64Regs:$a),
(CVT_f64_f64 Float64Regs:$a, CvtRZI)>;
// nearbyint and rint are implemented as rounding to nearest even. This isn't
// strictly correct, because it causes us to ignore the rounding mode. But it
// matches what CUDA's "libm" does.
def : Pat<(fnearbyint Float32Regs:$a),
(CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
def : Pat<(fnearbyint Float32Regs:$a),
(CVT_f32_f32 Float32Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>;
def : Pat<(fnearbyint Float64Regs:$a),
(CVT_f64_f64 Float64Regs:$a, CvtRNI)>;
def : Pat<(frint Float32Regs:$a),
(CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
def : Pat<(frint Float32Regs:$a),
(CVT_f32_f32 Float32Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>;
def : Pat<(frint Float64Regs:$a),
(CVT_f64_f64 Float64Regs:$a, CvtRNI)>;
//-----------------------------------
// Control-flow
//-----------------------------------

View File

@ -22,7 +22,7 @@ _ZL11compute_vecRK6float3jb.exit:
%8 = icmp eq i32 %7, 0
%9 = select i1 %8, float 0.000000e+00, float -1.000000e+00
store float %9, float* %ret_vec.sroa.8.i, align 4
; CHECK: setp.lt.f32 %p{{[0-9]+}}, %f{{[0-9]+}}, 0f00000000
; CHECK: max.f32 %f{{[0-9]+}}, %f{{[0-9]+}}, 0f00000000
%10 = fcmp olt float %9, 0.000000e+00
%ret_vec.sroa.8.i.val = load float, float* %ret_vec.sroa.8.i, align 4
%11 = select i1 %10, float 0.000000e+00, float %ret_vec.sroa.8.i.val

View File

@ -0,0 +1,261 @@
; RUN: llc < %s | FileCheck %s
target triple = "nvptx64-nvidia-cuda"
; Checks that llvm intrinsics for math functions are correctly lowered to PTX.
declare float @llvm.ceil.f32(float) #0
declare double @llvm.ceil.f64(double) #0
declare float @llvm.floor.f32(float) #0
declare double @llvm.floor.f64(double) #0
declare float @llvm.round.f32(float) #0
declare double @llvm.round.f64(double) #0
declare float @llvm.nearbyint.f32(float) #0
declare double @llvm.nearbyint.f64(double) #0
declare float @llvm.rint.f32(float) #0
declare double @llvm.rint.f64(double) #0
declare float @llvm.trunc.f32(float) #0
declare double @llvm.trunc.f64(double) #0
declare float @llvm.fabs.f32(float) #0
declare double @llvm.fabs.f64(double) #0
declare float @llvm.minnum.f32(float, float) #0
declare double @llvm.minnum.f64(double, double) #0
declare float @llvm.maxnum.f32(float, float) #0
declare double @llvm.maxnum.f64(double, double) #0
; ---- ceil ----
; CHECK-LABEL: ceil_float
define float @ceil_float(float %a) {
; CHECK: cvt.rpi.f32.f32
%b = call float @llvm.ceil.f32(float %a)
ret float %b
}
; CHECK-LABEL: ceil_float_ftz
define float @ceil_float_ftz(float %a) #1 {
; CHECK: cvt.rpi.ftz.f32.f32
%b = call float @llvm.ceil.f32(float %a)
ret float %b
}
; CHECK-LABEL: ceil_double
define double @ceil_double(double %a) {
; CHECK: cvt.rpi.f64.f64
%b = call double @llvm.ceil.f64(double %a)
ret double %b
}
; ---- floor ----
; CHECK-LABEL: floor_float
define float @floor_float(float %a) {
; CHECK: cvt.rmi.f32.f32
%b = call float @llvm.floor.f32(float %a)
ret float %b
}
; CHECK-LABEL: floor_float_ftz
define float @floor_float_ftz(float %a) #1 {
; CHECK: cvt.rmi.ftz.f32.f32
%b = call float @llvm.floor.f32(float %a)
ret float %b
}
; CHECK-LABEL: floor_double
define double @floor_double(double %a) {
; CHECK: cvt.rmi.f64.f64
%b = call double @llvm.floor.f64(double %a)
ret double %b
}
; ---- round ----
; CHECK-LABEL: round_float
define float @round_float(float %a) {
; CHECK: cvt.rni.f32.f32
%b = call float @llvm.round.f32(float %a)
ret float %b
}
; CHECK-LABEL: round_float_ftz
define float @round_float_ftz(float %a) #1 {
; CHECK: cvt.rni.ftz.f32.f32
%b = call float @llvm.round.f32(float %a)
ret float %b
}
; CHECK-LABEL: round_double
define double @round_double(double %a) {
; CHECK: cvt.rni.f64.f64
%b = call double @llvm.round.f64(double %a)
ret double %b
}
; ---- nearbyint ----
; CHECK-LABEL: nearbyint_float
define float @nearbyint_float(float %a) {
; CHECK: cvt.rni.f32.f32
%b = call float @llvm.nearbyint.f32(float %a)
ret float %b
}
; CHECK-LABEL: nearbyint_float_ftz
define float @nearbyint_float_ftz(float %a) #1 {
; CHECK: cvt.rni.ftz.f32.f32
%b = call float @llvm.nearbyint.f32(float %a)
ret float %b
}
; CHECK-LABEL: nearbyint_double
define double @nearbyint_double(double %a) {
; CHECK: cvt.rni.f64.f64
%b = call double @llvm.nearbyint.f64(double %a)
ret double %b
}
; ---- rint ----
; CHECK-LABEL: rint_float
define float @rint_float(float %a) {
; CHECK: cvt.rni.f32.f32
%b = call float @llvm.rint.f32(float %a)
ret float %b
}
; CHECK-LABEL: rint_float_ftz
define float @rint_float_ftz(float %a) #1 {
; CHECK: cvt.rni.ftz.f32.f32
%b = call float @llvm.rint.f32(float %a)
ret float %b
}
; CHECK-LABEL: rint_double
define double @rint_double(double %a) {
; CHECK: cvt.rni.f64.f64
%b = call double @llvm.rint.f64(double %a)
ret double %b
}
; ---- trunc ----
; CHECK-LABEL: trunc_float
define float @trunc_float(float %a) {
; CHECK: cvt.rzi.f32.f32
%b = call float @llvm.trunc.f32(float %a)
ret float %b
}
; CHECK-LABEL: trunc_float_ftz
define float @trunc_float_ftz(float %a) #1 {
; CHECK: cvt.rzi.ftz.f32.f32
%b = call float @llvm.trunc.f32(float %a)
ret float %b
}
; CHECK-LABEL: trunc_double
define double @trunc_double(double %a) {
; CHECK: cvt.rzi.f64.f64
%b = call double @llvm.trunc.f64(double %a)
ret double %b
}
; ---- abs ----
; CHECK-LABEL: abs_float
define float @abs_float(float %a) {
; CHECK: abs.f32
%b = call float @llvm.fabs.f32(float %a)
ret float %b
}
; CHECK-LABEL: abs_float_ftz
define float @abs_float_ftz(float %a) #1 {
; CHECK: abs.ftz.f32
%b = call float @llvm.fabs.f32(float %a)
ret float %b
}
; CHECK-LABEL: abs_double
define double @abs_double(double %a) {
; CHECK: abs.f64
%b = call double @llvm.fabs.f64(double %a)
ret double %b
}
; ---- min ----
; CHECK-LABEL: min_float
define float @min_float(float %a, float %b) {
; CHECK: min.f32
%x = call float @llvm.minnum.f32(float %a, float %b)
ret float %x
}
; CHECK-LABEL: min_imm1
define float @min_imm1(float %a) {
; CHECK: min.f32
%x = call float @llvm.minnum.f32(float %a, float 0.0)
ret float %x
}
; CHECK-LABEL: min_imm2
define float @min_imm2(float %a) {
; CHECK: min.f32
%x = call float @llvm.minnum.f32(float 0.0, float %a)
ret float %x
}
; CHECK-LABEL: min_float_ftz
define float @min_float_ftz(float %a, float %b) #1 {
; CHECK: min.ftz.f32
%x = call float @llvm.minnum.f32(float %a, float %b)
ret float %x
}
; CHECK-LABEL: min_double
define double @min_double(double %a, double %b) {
; CHECK: min.f64
%x = call double @llvm.minnum.f64(double %a, double %b)
ret double %x
}
; ---- max ----
; CHECK-LABEL: max_imm1
define float @max_imm1(float %a) {
; CHECK: max.f32
%x = call float @llvm.maxnum.f32(float %a, float 0.0)
ret float %x
}
; CHECK-LABEL: max_imm2
define float @max_imm2(float %a) {
; CHECK: max.f32
%x = call float @llvm.maxnum.f32(float 0.0, float %a)
ret float %x
}
; CHECK-LABEL: max_float
define float @max_float(float %a, float %b) {
; CHECK: max.f32
%x = call float @llvm.maxnum.f32(float %a, float %b)
ret float %x
}
; CHECK-LABEL: max_float_ftz
define float @max_float_ftz(float %a, float %b) #1 {
; CHECK: max.ftz.f32
%x = call float @llvm.maxnum.f32(float %a, float %b)
ret float %x
}
; CHECK-LABEL: max_double
define double @max_double(double %a, double %b) {
; CHECK: max.f64
%x = call double @llvm.maxnum.f64(double %a, double %b)
ret double %x
}
attributes #0 = { nounwind readnone }
attributes #1 = { "nvptx-f32ftz" = "true" }