Fix PR15296

- Move SRA/SRL/SHL lowering support from DAG combination to DAG lowering
  to support 256-bit integer vectors on targets with AVX but without AVX2.

llvm-svn: 177478
Michael Liao 2013-03-20 02:33:21 +00:00
parent 5a4e81d2e8
commit 0f4ea0c4a9
2 changed files with 245 additions and 119 deletions
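
Conceptually, what the moved lowering buys: on a target with AVX but not AVX2, a 256-bit integer shift by a uniform amount still cannot be emitted as a single instruction, so during lowering it is decomposed into two 128-bit shifts (the pairs of psrld/psrlq the new test expects). A minimal standalone C++ sketch of that split, using illustrative names only and no LLVM APIs:

#include <array>
#include <cstdint>
#include <cstdio>

// Illustrative stand-in for a 256-bit vector of eight 32-bit lanes.
using V8i32 = std::array<uint32_t, 8>;

// Shift one 128-bit half (four lanes) by a uniform amount, the way a
// single PSRLD would.
static void srl128(uint32_t *Lanes, unsigned Amt) {
  for (unsigned i = 0; i != 4; ++i)
    Lanes[i] >>= Amt;
}

// A uniform 256-bit logical right shift on AVX-only hardware:
// two 128-bit shifts, one per half.
static V8i32 srl256(V8i32 V, unsigned Amt) {
  srl128(V.data(), Amt);      // low half
  srl128(V.data() + 4, Amt);  // high half
  return V;
}

int main() {
  V8i32 In = {16, 32, 64, 128, 256, 512, 1024, 2048};
  V8i32 Out = srl256(In, 4);
  for (uint32_t L : Out)
    std::printf("%u ", L);
  std::printf("\n");  // prints: 1 2 4 8 16 32 64 128
  return 0;
}
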


@@ -11595,6 +11595,188 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
    }
  }

  // Special case in 32-bit mode, where i64 is expanded into high and low parts.
  if (!Subtarget->is64Bit() &&
      (VT == MVT::v2i64 || (Subtarget->hasInt256() && VT == MVT::v4i64)) &&
      Amt.getOpcode() == ISD::BITCAST &&
      Amt.getOperand(0).getOpcode() == ISD::BUILD_VECTOR) {
    Amt = Amt.getOperand(0);
    unsigned Ratio = Amt.getValueType().getVectorNumElements() /
                     VT.getVectorNumElements();
    unsigned RatioInLog2 = Log2_32_Ceil(Ratio);
    uint64_t ShiftAmt = 0;
    for (unsigned i = 0; i != Ratio; ++i) {
      ConstantSDNode *C = dyn_cast<ConstantSDNode>(Amt.getOperand(i));
      if (C == 0)
        return SDValue();
      // 6 == Log2(64)
      ShiftAmt |= C->getZExtValue() << (i * (1 << (6 - RatioInLog2)));
    }
    // Check remaining shift amounts.
    for (unsigned i = Ratio; i != Amt.getNumOperands(); i += Ratio) {
      uint64_t ShAmt = 0;
      for (unsigned j = 0; j != Ratio; ++j) {
        ConstantSDNode *C =
          dyn_cast<ConstantSDNode>(Amt.getOperand(i + j));
        if (C == 0)
          return SDValue();
        // 6 == Log2(64)
        ShAmt |= C->getZExtValue() << (j * (1 << (6 - RatioInLog2)));
      }
      if (ShAmt != ShiftAmt)
        return SDValue();
    }

    switch (Op.getOpcode()) {
    default:
      llvm_unreachable("Unknown shift opcode!");
    case ISD::SHL:
      return DAG.getNode(X86ISD::VSHLI, dl, VT, R,
                         DAG.getConstant(ShiftAmt, MVT::i32));
    case ISD::SRL:
      return DAG.getNode(X86ISD::VSRLI, dl, VT, R,
                         DAG.getConstant(ShiftAmt, MVT::i32));
    case ISD::SRA:
      return DAG.getNode(X86ISD::VSRAI, dl, VT, R,
                         DAG.getConstant(ShiftAmt, MVT::i32));
    }
  }

  return SDValue();
}

static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
                                        const X86Subtarget* Subtarget) {
  EVT VT = Op.getValueType();
  DebugLoc dl = Op.getDebugLoc();
  SDValue R = Op.getOperand(0);
  SDValue Amt = Op.getOperand(1);

  if ((VT == MVT::v2i64 && Op.getOpcode() != ISD::SRA) ||
      VT == MVT::v4i32 || VT == MVT::v8i16 ||
      (Subtarget->hasInt256() &&
       ((VT == MVT::v4i64 && Op.getOpcode() != ISD::SRA) ||
        VT == MVT::v8i32 || VT == MVT::v16i16))) {
    SDValue BaseShAmt;
    EVT EltVT = VT.getVectorElementType();

    if (Amt.getOpcode() == ISD::BUILD_VECTOR) {
      unsigned NumElts = VT.getVectorNumElements();
      unsigned i, j;
      for (i = 0; i != NumElts; ++i) {
        if (Amt.getOperand(i).getOpcode() == ISD::UNDEF)
          continue;
        break;
      }
      for (j = i; j != NumElts; ++j) {
        SDValue Arg = Amt.getOperand(j);
        if (Arg.getOpcode() == ISD::UNDEF) continue;
        if (Arg != Amt.getOperand(i))
          break;
      }
      if (i != NumElts && j == NumElts)
        BaseShAmt = Amt.getOperand(i);
    } else {
      if (Amt.getOpcode() == ISD::EXTRACT_SUBVECTOR)
        Amt = Amt.getOperand(0);
      if (Amt.getOpcode() == ISD::VECTOR_SHUFFLE &&
          cast<ShuffleVectorSDNode>(Amt)->isSplat()) {
        SDValue InVec = Amt.getOperand(0);
        if (InVec.getOpcode() == ISD::BUILD_VECTOR) {
          unsigned NumElts = InVec.getValueType().getVectorNumElements();
          unsigned i = 0;
          for (; i != NumElts; ++i) {
            SDValue Arg = InVec.getOperand(i);
            if (Arg.getOpcode() == ISD::UNDEF) continue;
            BaseShAmt = Arg;
            break;
          }
        } else if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT) {
          if (ConstantSDNode *C =
                dyn_cast<ConstantSDNode>(InVec.getOperand(2))) {
            unsigned SplatIdx =
              cast<ShuffleVectorSDNode>(Amt)->getSplatIndex();
            if (C->getZExtValue() == SplatIdx)
              BaseShAmt = InVec.getOperand(1);
          }
        }

        if (BaseShAmt.getNode() == 0)
          BaseShAmt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Amt,
                                  DAG.getIntPtrConstant(0));
      }
    }

    if (BaseShAmt.getNode()) {
      if (EltVT.bitsGT(MVT::i32))
        BaseShAmt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, BaseShAmt);
      else if (EltVT.bitsLT(MVT::i32))
        BaseShAmt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, BaseShAmt);

      switch (Op.getOpcode()) {
      default:
        llvm_unreachable("Unknown shift opcode!");
      case ISD::SHL:
        switch (VT.getSimpleVT().SimpleTy) {
        default: return SDValue();
        case MVT::v2i64:
        case MVT::v4i32:
        case MVT::v8i16:
        case MVT::v4i64:
        case MVT::v8i32:
        case MVT::v16i16:
          return getTargetVShiftNode(X86ISD::VSHLI, dl, VT, R, BaseShAmt, DAG);
        }
      case ISD::SRA:
        switch (VT.getSimpleVT().SimpleTy) {
        default: return SDValue();
        case MVT::v4i32:
        case MVT::v8i16:
        case MVT::v8i32:
        case MVT::v16i16:
          return getTargetVShiftNode(X86ISD::VSRAI, dl, VT, R, BaseShAmt, DAG);
        }
      case ISD::SRL:
        switch (VT.getSimpleVT().SimpleTy) {
        default: return SDValue();
        case MVT::v2i64:
        case MVT::v4i32:
        case MVT::v8i16:
        case MVT::v4i64:
        case MVT::v8i32:
        case MVT::v16i16:
          return getTargetVShiftNode(X86ISD::VSRLI, dl, VT, R, BaseShAmt, DAG);
        }
      }
    }
  }

  // Special case in 32-bit mode, where i64 is expanded into high and low parts.
  if (!Subtarget->is64Bit() &&
      (VT == MVT::v2i64 || (Subtarget->hasInt256() && VT == MVT::v4i64)) &&
      Amt.getOpcode() == ISD::BITCAST &&
      Amt.getOperand(0).getOpcode() == ISD::BUILD_VECTOR) {
    Amt = Amt.getOperand(0);
    unsigned Ratio = Amt.getValueType().getVectorNumElements() /
                     VT.getVectorNumElements();
    std::vector<SDValue> Vals(Ratio);
    for (unsigned i = 0; i != Ratio; ++i)
      Vals[i] = Amt.getOperand(i);

    for (unsigned i = Ratio; i != Amt.getNumOperands(); i += Ratio) {
      for (unsigned j = 0; j != Ratio; ++j)
        if (Vals[j] != Amt.getOperand(i + j))
          return SDValue();
    }

    switch (Op.getOpcode()) {
    default:
      llvm_unreachable("Unknown shift opcode!");
    case ISD::SHL:
      return DAG.getNode(X86ISD::VSHL, dl, VT, R, Op.getOperand(1));
    case ISD::SRL:
      return DAG.getNode(X86ISD::VSRL, dl, VT, R, Op.getOperand(1));
    case ISD::SRA:
      return DAG.getNode(X86ISD::VSRA, dl, VT, R, Op.getOperand(1));
    }
  }

  return SDValue();
}
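
Both 32-bit special cases above see an i64 shift amount as a bitcast of a BUILD_VECTOR of narrower constants, so they rebuild each 64-bit amount from Ratio elements and check that the pattern repeats with period Ratio. The following standalone sketch (plain C++, illustrative names, not part of the patch) mirrors that bookkeeping for the constant-amount case:

#include <cstdint>
#include <cstdio>
#include <vector>

// Rebuild one 64-bit shift amount from `Ratio` narrower BUILD_VECTOR
// constants and verify that every later group repeats the same value,
// mirroring the loops in LowerScalarImmediateShift above.
static bool getUniformShiftAmt(const std::vector<uint32_t> &Ops,
                               unsigned Ratio, uint64_t &ShiftAmt) {
  unsigned RatioInLog2 = 0;
  while ((1u << RatioInLog2) < Ratio)
    ++RatioInLog2;                       // Log2_32_Ceil(Ratio)

  ShiftAmt = 0;
  for (unsigned i = 0; i != Ratio; ++i)
    // 6 == Log2(64): each element contributes 64 / Ratio bits.
    ShiftAmt |= uint64_t(Ops[i]) << (i * (1u << (6 - RatioInLog2)));

  for (unsigned i = Ratio; i != Ops.size(); i += Ratio) {
    uint64_t ShAmt = 0;
    for (unsigned j = 0; j != Ratio; ++j)
      ShAmt |= uint64_t(Ops[i + j]) << (j * (1u << (6 - RatioInLog2)));
    if (ShAmt != ShiftAmt)
      return false;                      // not a splat of one i64 amount
  }
  return true;
}

int main() {
  // A v2i64 splat of 5, seen in 32-bit mode as a v4i32 BUILD_VECTOR {5,0,5,0}.
  std::vector<uint32_t> Ops = {5, 0, 5, 0};
  uint64_t Amt;
  if (getUniformShiftAmt(Ops, 2, Amt))
    std::printf("uniform shift amount = %llu\n", (unsigned long long)Amt);
  return 0;
}
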
@@ -11613,6 +11795,10 @@ SDValue X86TargetLowering::LowerShift(SDValue Op, SelectionDAG &DAG) const {
  if (V.getNode())
    return V;

  V = LowerScalarVariableShift(Op, DAG, Subtarget);
  if (V.getNode())
    return V;

  // AVX2 has VPSLLV/VPSRAV/VPSRLV.
  if (Subtarget->hasInt256()) {
    if (Op.getOpcode() == ISD::SRL &&
@@ -15951,124 +16137,12 @@ static SDValue PerformSHLCombine(SDNode *N, SelectionDAG &DAG) {
static SDValue PerformShiftCombine(SDNode* N, SelectionDAG &DAG,
                                   TargetLowering::DAGCombinerInfo &DCI,
                                   const X86Subtarget *Subtarget) {
  EVT VT = N->getValueType(0);
  if (N->getOpcode() == ISD::SHL) {
    SDValue V = PerformSHLCombine(N, DAG);
    if (V.getNode()) return V;
  }

  // On X86 with SSE2 support, we can transform this to a vector shift if
  // all elements are shifted by the same amount. We can't do this in legalize
  // because a constant vector is typically transformed to a constant pool
  // so we have no knowledge of the shift amount.
  if (!Subtarget->hasSSE2())
    return SDValue();

  if (VT != MVT::v2i64 && VT != MVT::v4i32 && VT != MVT::v8i16 &&
      (!Subtarget->hasInt256() ||
       (VT != MVT::v4i64 && VT != MVT::v8i32 && VT != MVT::v16i16)))
    return SDValue();

  SDValue ShAmtOp = N->getOperand(1);
  EVT EltVT = VT.getVectorElementType();
  DebugLoc DL = N->getDebugLoc();
  SDValue BaseShAmt = SDValue();
  if (ShAmtOp.getOpcode() == ISD::BUILD_VECTOR) {
    unsigned NumElts = VT.getVectorNumElements();
    unsigned i = 0;
    for (; i != NumElts; ++i) {
      SDValue Arg = ShAmtOp.getOperand(i);
      if (Arg.getOpcode() == ISD::UNDEF) continue;
      BaseShAmt = Arg;
      break;
    }
    // Handle the case where the build_vector is all undef
    // FIXME: Should DAG allow this?
    if (i == NumElts)
      return SDValue();

    for (; i != NumElts; ++i) {
      SDValue Arg = ShAmtOp.getOperand(i);
      if (Arg.getOpcode() == ISD::UNDEF) continue;
      if (Arg != BaseShAmt) {
        return SDValue();
      }
    }
  } else if (ShAmtOp.getOpcode() == ISD::VECTOR_SHUFFLE &&
             cast<ShuffleVectorSDNode>(ShAmtOp)->isSplat()) {
    SDValue InVec = ShAmtOp.getOperand(0);
    if (InVec.getOpcode() == ISD::BUILD_VECTOR) {
      unsigned NumElts = InVec.getValueType().getVectorNumElements();
      unsigned i = 0;
      for (; i != NumElts; ++i) {
        SDValue Arg = InVec.getOperand(i);
        if (Arg.getOpcode() == ISD::UNDEF) continue;
        BaseShAmt = Arg;
        break;
      }
    } else if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT) {
      if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(InVec.getOperand(2))) {
        unsigned SplatIdx = cast<ShuffleVectorSDNode>(ShAmtOp)->getSplatIndex();
        if (C->getZExtValue() == SplatIdx)
          BaseShAmt = InVec.getOperand(1);
      }
    }

    if (BaseShAmt.getNode() == 0) {
      // Don't create instructions with illegal types after legalize
      // types has run.
      if (!DAG.getTargetLoweringInfo().isTypeLegal(EltVT) &&
          !DCI.isBeforeLegalize())
        return SDValue();

      BaseShAmt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, ShAmtOp,
                              DAG.getIntPtrConstant(0));
    }
  } else
    return SDValue();

  // The shift amount is an i32.
  if (EltVT.bitsGT(MVT::i32))
    BaseShAmt = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, BaseShAmt);
  else if (EltVT.bitsLT(MVT::i32))
    BaseShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, BaseShAmt);

  // The shift amount is identical so we can do a vector shift.
  SDValue ValOp = N->getOperand(0);
  switch (N->getOpcode()) {
  default:
    llvm_unreachable("Unknown shift opcode!");
  case ISD::SHL:
    switch (VT.getSimpleVT().SimpleTy) {
    default: return SDValue();
    case MVT::v2i64:
    case MVT::v4i32:
    case MVT::v8i16:
    case MVT::v4i64:
    case MVT::v8i32:
    case MVT::v16i16:
      return getTargetVShiftNode(X86ISD::VSHLI, DL, VT, ValOp, BaseShAmt, DAG);
    }
  case ISD::SRA:
    switch (VT.getSimpleVT().SimpleTy) {
    default: return SDValue();
    case MVT::v4i32:
    case MVT::v8i16:
    case MVT::v8i32:
    case MVT::v16i16:
      return getTargetVShiftNode(X86ISD::VSRAI, DL, VT, ValOp, BaseShAmt, DAG);
    }
  case ISD::SRL:
    switch (VT.getSimpleVT().SimpleTy) {
    default: return SDValue();
    case MVT::v2i64:
    case MVT::v4i32:
    case MVT::v8i16:
    case MVT::v4i64:
    case MVT::v8i32:
    case MVT::v16i16:
      return getTargetVShiftNode(X86ISD::VSRLI, DL, VT, ValOp, BaseShAmt, DAG);
    }
  }
  return SDValue();
}
// CMPEQCombine - Recognize the distinctive (AND (setcc ...) (setcc ..))
@@ -16379,13 +16453,19 @@ static SDValue PerformOrCombine(SDNode *N, SelectionDAG &DAG,
    // Validate that the Mask operand is a vector sra node.
    // FIXME: what to do for bytes, since there is a psignb/pblendvb, but
    // there is no psrai.b
    if (Mask.getOpcode() != X86ISD::VSRAI)
      return SDValue();

    // Check that the SRA is all signbits.
    SDValue SraC = Mask.getOperand(1);
    unsigned SraAmt = cast<ConstantSDNode>(SraC)->getZExtValue();
    unsigned EltBits = MaskVT.getVectorElementType().getSizeInBits();
    unsigned SraAmt = ~0;
    if (Mask.getOpcode() == ISD::SRA) {
      SDValue Amt = Mask.getOperand(1);
      if (isSplatVector(Amt.getNode())) {
        SDValue SclrAmt = Amt->getOperand(0);
        if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(SclrAmt))
          SraAmt = C->getZExtValue();
      }
    } else if (Mask.getOpcode() == X86ISD::VSRAI) {
      SDValue SraC = Mask.getOperand(1);
      SraAmt = cast<ConstantSDNode>(SraC)->getZExtValue();
    }
    if ((SraAmt + 1) != EltBits)
      return SDValue();
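
Since SRA-by-splat is no longer rewritten into X86ISD::VSRAI by the combiner, the PerformOrCombine hunk above also has to recognize a generic ISD::SRA whose splatted constant amount is EltBits - 1. The point of the (SraAmt + 1) == EltBits check is that an arithmetic shift by EltBits - 1 smears the sign bit across the whole element, leaving an all-zeros or all-ones lane that can act as a blend mask. A tiny illustrative C++ check of that property (not LLVM code):

#include <cstdint>
#include <cstdio>

int main() {
  // Arithmetic shift right by EltBits - 1 (31 for i32) replicates the sign
  // bit, so each lane becomes either all zeros or all ones -- the mask shape
  // the (SraAmt + 1) == EltBits check above looks for. (>> on negative
  // signed values is arithmetic on mainstream compilers.)
  int32_t Lanes[4] = {42, -7, 0, INT32_MIN};
  for (int32_t L : Lanes)
    std::printf("%d >> 31 = 0x%08x\n", L, (uint32_t)(L >> 31));
  return 0;
}
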


@@ -0,0 +1,46 @@
; RUN: llc < %s -mtriple=i686-pc-linux -mcpu=corei7-avx | FileCheck %s

define <8 x i32> @shiftInput___vyuunu(<8 x i32> %input, i32 %shiftval, <8 x i32> %__mask) nounwind {
allocas:
  %smear.0 = insertelement <8 x i32> undef, i32 %shiftval, i32 0
  %smear.1 = insertelement <8 x i32> %smear.0, i32 %shiftval, i32 1
  %smear.2 = insertelement <8 x i32> %smear.1, i32 %shiftval, i32 2
  %smear.3 = insertelement <8 x i32> %smear.2, i32 %shiftval, i32 3
  %smear.4 = insertelement <8 x i32> %smear.3, i32 %shiftval, i32 4
  %smear.5 = insertelement <8 x i32> %smear.4, i32 %shiftval, i32 5
  %smear.6 = insertelement <8 x i32> %smear.5, i32 %shiftval, i32 6
  %smear.7 = insertelement <8 x i32> %smear.6, i32 %shiftval, i32 7
  %bitop = lshr <8 x i32> %input, %smear.7
  ret <8 x i32> %bitop
}
; CHECK: shiftInput___vyuunu
; CHECK: psrld
; CHECK: psrld
; CHECK: ret

define <8 x i32> @shiftInput___canonical(<8 x i32> %input, i32 %shiftval, <8 x i32> %__mask) nounwind {
allocas:
  %smear.0 = insertelement <8 x i32> undef, i32 %shiftval, i32 0
  %smear.7 = shufflevector <8 x i32> %smear.0, <8 x i32> undef, <8 x i32> zeroinitializer
  %bitop = lshr <8 x i32> %input, %smear.7
  ret <8 x i32> %bitop
}
; CHECK: shiftInput___canonical
; CHECK: psrld
; CHECK: psrld
; CHECK: ret

define <4 x i64> @shiftInput___64in32bitmode(<4 x i64> %input, i64 %shiftval, <4 x i64> %__mask) nounwind {
allocas:
  %smear.0 = insertelement <4 x i64> undef, i64 %shiftval, i32 0
  %smear.7 = shufflevector <4 x i64> %smear.0, <4 x i64> undef, <4 x i32> zeroinitializer
  %bitop = lshr <4 x i64> %input, %smear.7
  ret <4 x i64> %bitop
}
; CHECK: shiftInput___64in32bitmode
; CHECK: psrlq
; CHECK: psrlq
; CHECK: ret