[X86][SSE] LowerScalarImmediateShift - use getTargetConstantBitsFromNode to get immediate data

Don't just attempt to find a splat build vector.

First step towards getting rid of all the 32-bit special case code.

llvm-svn: 343383
Simon Pilgrim 2018-09-29 16:40:35 +00:00
parent fb1b80191e
commit ae34ae12ef
1 changed file with 71 additions and 61 deletions


@@ -23422,6 +23422,7 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
   SDLoc dl(Op);
   SDValue R = Op.getOperand(0);
   SDValue Amt = Op.getOperand(1);
+  unsigned EltSizeInBits = VT.getScalarSizeInBits();
   unsigned X86Opc = getTargetVShiftUniformOpcode(Op.getOpcode(), false);

   auto ArithmeticShiftRight64 = [&](uint64_t ShiftAmt) {
@@ -23465,74 +23466,83 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
   };

   // Optimize shl/srl/sra with constant shift amount.
-  if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) {
-    if (auto *ShiftConst = BVAmt->getConstantSplatNode()) {
-      uint64_t ShiftAmt = ShiftConst->getZExtValue();
+  APInt UndefElts;
+  SmallVector<APInt, 8> EltBits;
+  if (getTargetConstantBitsFromNode(Amt, EltSizeInBits, UndefElts, EltBits,
+                                    true, false)) {
+    int SplatIndex = -1;
+    for (int i = 0, e = VT.getVectorNumElements(); i != e; ++i) {
+      if (UndefElts[i])
+        continue;
+      if (0 <= SplatIndex && EltBits[i] != EltBits[SplatIndex])
+        return SDValue();
+      SplatIndex = i;
+    }
+    if (SplatIndex < 0)
+      return SDValue();

-      if (SupportedVectorShiftWithImm(VT, Subtarget, Op.getOpcode()))
-        return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG);
+    uint64_t ShiftAmt = EltBits[SplatIndex].getZExtValue();
+    if (SupportedVectorShiftWithImm(VT, Subtarget, Op.getOpcode()))
+      return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG);

-      // i64 SRA needs to be performed as partial shifts.
-      if (((!Subtarget.hasXOP() && VT == MVT::v2i64) ||
-           (Subtarget.hasInt256() && VT == MVT::v4i64)) &&
-          Op.getOpcode() == ISD::SRA)
-        return ArithmeticShiftRight64(ShiftAmt);
+    // i64 SRA needs to be performed as partial shifts.
+    if (((!Subtarget.hasXOP() && VT == MVT::v2i64) ||
+         (Subtarget.hasInt256() && VT == MVT::v4i64)) &&
+        Op.getOpcode() == ISD::SRA)
+      return ArithmeticShiftRight64(ShiftAmt);

-      if (VT == MVT::v16i8 ||
-          (Subtarget.hasInt256() && VT == MVT::v32i8) ||
-          VT == MVT::v64i8) {
-        unsigned NumElts = VT.getVectorNumElements();
-        MVT ShiftVT = MVT::getVectorVT(MVT::i16, NumElts / 2);
+    if (VT == MVT::v16i8 || (Subtarget.hasInt256() && VT == MVT::v32i8) ||
+        VT == MVT::v64i8) {
+      unsigned NumElts = VT.getVectorNumElements();
+      MVT ShiftVT = MVT::getVectorVT(MVT::i16, NumElts / 2);

-        // Simple i8 add case
-        if (Op.getOpcode() == ISD::SHL && ShiftAmt == 1)
-          return DAG.getNode(ISD::ADD, dl, VT, R, R);
+      // Simple i8 add case
+      if (Op.getOpcode() == ISD::SHL && ShiftAmt == 1)
+        return DAG.getNode(ISD::ADD, dl, VT, R, R);

-        // ashr(R, 7) === cmp_slt(R, 0)
-        if (Op.getOpcode() == ISD::SRA && ShiftAmt == 7) {
-          SDValue Zeros = getZeroVector(VT, Subtarget, DAG, dl);
-          if (VT.is512BitVector()) {
-            assert(VT == MVT::v64i8 && "Unexpected element type!");
-            SDValue CMP = DAG.getSetCC(dl, MVT::v64i1, Zeros, R,
-                                       ISD::SETGT);
-            return DAG.getNode(ISD::SIGN_EXTEND, dl, VT, CMP);
-          }
-          return DAG.getNode(X86ISD::PCMPGT, dl, VT, Zeros, R);
-        }
+      // ashr(R, 7) === cmp_slt(R, 0)
+      if (Op.getOpcode() == ISD::SRA && ShiftAmt == 7) {
+        SDValue Zeros = getZeroVector(VT, Subtarget, DAG, dl);
+        if (VT.is512BitVector()) {
+          assert(VT == MVT::v64i8 && "Unexpected element type!");
+          SDValue CMP = DAG.getSetCC(dl, MVT::v64i1, Zeros, R, ISD::SETGT);
+          return DAG.getNode(ISD::SIGN_EXTEND, dl, VT, CMP);
+        }
+        return DAG.getNode(X86ISD::PCMPGT, dl, VT, Zeros, R);
+      }

-        // XOP can shift v16i8 directly instead of as shift v8i16 + mask.
-        if (VT == MVT::v16i8 && Subtarget.hasXOP())
-          return SDValue();
+      // XOP can shift v16i8 directly instead of as shift v8i16 + mask.
+      if (VT == MVT::v16i8 && Subtarget.hasXOP())
+        return SDValue();

-        if (Op.getOpcode() == ISD::SHL) {
-          // Make a large shift.
-          SDValue SHL = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, ShiftVT,
-                                                   R, ShiftAmt, DAG);
-          SHL = DAG.getBitcast(VT, SHL);
-          // Zero out the rightmost bits.
-          return DAG.getNode(ISD::AND, dl, VT, SHL,
-                             DAG.getConstant(uint8_t(-1U << ShiftAmt), dl, VT));
-        }
-        if (Op.getOpcode() == ISD::SRL) {
-          // Make a large shift.
-          SDValue SRL = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ShiftVT,
-                                                   R, ShiftAmt, DAG);
-          SRL = DAG.getBitcast(VT, SRL);
-          // Zero out the leftmost bits.
-          return DAG.getNode(ISD::AND, dl, VT, SRL,
-                             DAG.getConstant(uint8_t(-1U) >> ShiftAmt, dl, VT));
-        }
-        if (Op.getOpcode() == ISD::SRA) {
-          // ashr(R, Amt) === sub(xor(lshr(R, Amt), Mask), Mask)
-          SDValue Res = DAG.getNode(ISD::SRL, dl, VT, R, Amt);
-          SDValue Mask = DAG.getConstant(128 >> ShiftAmt, dl, VT);
-          Res = DAG.getNode(ISD::XOR, dl, VT, Res, Mask);
-          Res = DAG.getNode(ISD::SUB, dl, VT, Res, Mask);
-          return Res;
-        }
-        llvm_unreachable("Unknown shift opcode.");
-      }
-    }
-  }
+      if (Op.getOpcode() == ISD::SHL) {
+        // Make a large shift.
+        SDValue SHL = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, ShiftVT, R,
+                                                 ShiftAmt, DAG);
+        SHL = DAG.getBitcast(VT, SHL);
+        // Zero out the rightmost bits.
+        return DAG.getNode(ISD::AND, dl, VT, SHL,
+                           DAG.getConstant(uint8_t(-1U << ShiftAmt), dl, VT));
+      }
+      if (Op.getOpcode() == ISD::SRL) {
+        // Make a large shift.
+        SDValue SRL = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ShiftVT, R,
+                                                 ShiftAmt, DAG);
+        SRL = DAG.getBitcast(VT, SRL);
+        // Zero out the leftmost bits.
+        return DAG.getNode(ISD::AND, dl, VT, SRL,
+                           DAG.getConstant(uint8_t(-1U) >> ShiftAmt, dl, VT));
+      }
+      if (Op.getOpcode() == ISD::SRA) {
+        // ashr(R, Amt) === sub(xor(lshr(R, Amt), Mask), Mask)
+        SDValue Res = DAG.getNode(ISD::SRL, dl, VT, R, Amt);
+        SDValue Mask = DAG.getConstant(128 >> ShiftAmt, dl, VT);
+        Res = DAG.getNode(ISD::XOR, dl, VT, Res, Mask);
+        Res = DAG.getNode(ISD::SUB, dl, VT, Res, Mask);
+        return Res;
+      }
+      llvm_unreachable("Unknown shift opcode.");
+    }
+  }
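
A note on the new splat handling: getTargetConstantBitsFromNode extracts per-element constant bits plus an undef mask even when the shift amount is not a plain splat BUILD_VECTOR, and the loop in the patch then insists that every defined lane holds the same value. Below is a minimal standalone sketch of that check, with std::optional standing in for the UndefElts mask; all names are illustrative and nothing here is LLVM API.

#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Each lane is either a defined 64-bit constant or undef (std::nullopt).
using ConstantLanes = std::vector<std::optional<uint64_t>>;

// Return the splatted value if every defined lane agrees, otherwise nothing.
std::optional<uint64_t> getSplatShiftAmount(const ConstantLanes &Lanes) {
  int SplatIndex = -1;
  for (int i = 0, e = int(Lanes.size()); i != e; ++i) {
    if (!Lanes[i])
      continue;                  // Undef lanes are ignored entirely.
    if (0 <= SplatIndex && *Lanes[i] != *Lanes[SplatIndex])
      return std::nullopt;       // Two defined lanes disagree: not a splat.
    SplatIndex = i;
  }
  if (SplatIndex < 0)
    return std::nullopt;         // Every lane was undef.
  return *Lanes[SplatIndex];
}

int main() {
  // An undef lane does not spoil the splat...
  assert(getSplatShiftAmount({3, std::nullopt, 3, 3}).value() == 3);
  // ...but mismatched defined lanes do.
  assert(!getSplatShiftAmount({3, 5, 3, 3}));
  return 0;
}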
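
The vXi8 SHL/SRL paths keep the pre-existing trick: SSE has no per-byte shift, so the vector is shifted as i16 lanes ("make a large shift") and the bits that leaked across a byte boundary are masked away ("zero out the rightmost bits" / "zero out the leftmost bits"). A plain scalar sketch of the SHL case follows; the function name and the fixed 8-lane width are made up for illustration, and the shift amount is assumed to be in range.

#include <cassert>
#include <cstdint>
#include <cstring>

// Shift eight i8 lanes left by shifting each byte pair as one i16 lane,
// then masking off the bits that crossed over from the neighbouring byte.
void shl_v8i8_via_i16(uint8_t Lanes[8], unsigned ShiftAmt) {
  assert(ShiftAmt < 8 && "shift amount must be in range for i8");
  uint8_t Mask = uint8_t(0xFFu << ShiftAmt); // a genuine i8 shl has zeros here
  for (int i = 0; i < 8; i += 2) {
    uint16_t Wide;
    std::memcpy(&Wide, &Lanes[i], sizeof(Wide)); // "bitcast" a byte pair to i16
    Wide = uint16_t(Wide << ShiftAmt);           // the "large" 16-bit shift
    std::memcpy(&Lanes[i], &Wide, sizeof(Wide));
    Lanes[i] &= Mask;                            // zero out the rightmost bits
    Lanes[i + 1] &= Mask;
  }
}

int main() {
  uint8_t Lanes[8] = {0x81, 0x7F, 0x01, 0xFF, 0x00, 0x10, 0xAA, 0x55};
  uint8_t Expected[8];
  for (int i = 0; i < 8; ++i)
    Expected[i] = uint8_t(Lanes[i] << 3); // the true per-byte shift
  shl_v8i8_via_i16(Lanes, 3);
  for (int i = 0; i < 8; ++i)
    assert(Lanes[i] == Expected[i]);
  return 0;
}

The SRL path is the mirror image, masking with uint8_t(0xFFu >> ShiftAmt) to clear the leftmost bits instead; with XOP available the v16i8 case bails out because the hardware can shift bytes directly.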
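
The ShiftAmt == 7 special case uses the fact that an arithmetic right shift by 7 turns each i8 lane into 0x00 or 0xFF depending on its sign, which is exactly the lane mask a signed less-than-zero compare (PCMPGT of zero against R) produces. A scalar, exhaustive check of that identity; nothing below is LLVM code.

#include <cassert>
#include <cstdint>

int main() {
  for (int v = -128; v <= 127; ++v) {
    int8_t x = int8_t(v);
    // Arithmetic shift by 7 (sign-propagating on mainstream compilers,
    // guaranteed arithmetic since C++20).
    uint8_t viaAshr = uint8_t(int8_t(x >> 7));
    // cmp_slt(x, 0): all-ones lane when negative, all-zeros otherwise.
    uint8_t viaCmp = (0 > x) ? 0xFF : 0x00;
    assert(viaAshr == viaCmp);
  }
  return 0;
}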
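
The generic SRA fallback relies on the identity quoted in its comment, ashr(R, Amt) === sub(xor(lshr(R, Amt), Mask), Mask) with Mask = 128 >> Amt: after the logical shift the original sign bit sits at position 7 - Amt, and the XOR/SUB pair sign-extends from that bit. An exhaustive scalar check for i8, again purely illustrative.

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned Amt = 0; Amt < 8; ++Amt) {
    uint8_t Mask = uint8_t(128u >> Amt); // the shifted-down sign bit
    for (int v = -128; v <= 127; ++v) {
      int8_t x = int8_t(v);
      uint8_t Lshr = uint8_t(uint8_t(x) >> Amt);         // logical shift
      uint8_t Emulated = uint8_t((Lshr ^ Mask) - Mask);  // xor/sub sign fixup
      // Reference result: true arithmetic shift (sign-propagating on
      // mainstream compilers, guaranteed since C++20).
      uint8_t Expected = uint8_t(int8_t(x >> Amt));
      assert(Emulated == Expected);
    }
  }
  return 0;
}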