X86 vector element shift-by-immediate instructions take i8 immediates. Make

the instruction defenitions and ISEL reflect this.

Prior to this patch these instructions took an i32i8imm, and the high bits were
dropped during encoding. This led to incorrect behavior for shifts by
immediates higher than 255. This patch fixes that issue by detecting large
immediate shifts and returning constant zero (for logical shifts) or capping
the shift amount at an encodable value (for arithmetic shifts).

Fixes <rdar://problem/14968098>

llvm-svn: 193096
This commit is contained in:
Lang Hames 2013-10-21 17:51:24 +00:00
parent 691281be2f
commit 2783993fca
5 changed files with 77 additions and 64 deletions

View File

@ -10952,6 +10952,26 @@ static SDValue LowerVACOPY(SDValue Op, const X86Subtarget *Subtarget,
MachinePointerInfo(DstSV), MachinePointerInfo(SrcSV));
}
// getTargetVShiftByConstNode - Handle vector element shifts where the shift
// amount is a constant. Takes immediate version of shift as input.
static SDValue getTargetVShiftByConstNode(unsigned Opc, SDLoc dl, EVT VT,
SDValue SrcOp, uint64_t ShiftAmt,
SelectionDAG &DAG) {
// Check for ShiftAmt >= element width
if (ShiftAmt >= VT.getVectorElementType().getSizeInBits()) {
if (Opc == X86ISD::VSRAI)
ShiftAmt = VT.getVectorElementType().getSizeInBits() - 1;
else
return DAG.getConstant(0, VT);
}
assert((Opc == X86ISD::VSHLI || Opc == X86ISD::VSRLI || Opc == X86ISD::VSRAI)
&& "Unknown target vector shift-by-constant node");
return DAG.getNode(Opc, dl, VT, SrcOp, DAG.getConstant(ShiftAmt, MVT::i8));
}
// getTargetVShiftNode - Handle vector element shifts where the shift amount
// may or may not be a constant. Takes immediate version of shift as input.
static SDValue getTargetVShiftNode(unsigned Opc, SDLoc dl, EVT VT,
@ -10959,18 +10979,10 @@ static SDValue getTargetVShiftNode(unsigned Opc, SDLoc dl, EVT VT,
SelectionDAG &DAG) {
assert(ShAmt.getValueType() == MVT::i32 && "ShAmt is not i32");
if (isa<ConstantSDNode>(ShAmt)) {
// Constant may be a TargetConstant. Use a regular constant.
uint32_t ShiftAmt = cast<ConstantSDNode>(ShAmt)->getZExtValue();
switch (Opc) {
default: llvm_unreachable("Unknown target vector shift node");
case X86ISD::VSHLI:
case X86ISD::VSRLI:
case X86ISD::VSRAI:
return DAG.getNode(Opc, dl, VT, SrcOp,
DAG.getConstant(ShiftAmt, MVT::i32));
}
}
// Catch shift-by-constant.
if (ConstantSDNode *CShAmt = dyn_cast<ConstantSDNode>(ShAmt))
return getTargetVShiftByConstNode(Opc, dl, VT, SrcOp,
CShAmt->getZExtValue(), DAG);
// Change opcode to non-immediate version
switch (Opc) {
@ -12416,10 +12428,8 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget *Subtarget,
// AhiBlo = psllqi(AhiBlo, 32);
// return AloBlo + AloBhi + AhiBlo;
SDValue ShAmt = DAG.getConstant(32, MVT::i32);
SDValue Ahi = DAG.getNode(X86ISD::VSRLI, dl, VT, A, ShAmt);
SDValue Bhi = DAG.getNode(X86ISD::VSRLI, dl, VT, B, ShAmt);
SDValue Ahi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, A, 32, DAG);
SDValue Bhi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, B, 32, DAG);
// Bit cast to 32-bit vectors for MULUDQ
EVT MulVT = (VT == MVT::v2i64) ? MVT::v4i32 :
@ -12433,8 +12443,8 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget *Subtarget,
SDValue AloBhi = DAG.getNode(X86ISD::PMULUDQ, dl, VT, A, Bhi);
SDValue AhiBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Ahi, B);
AloBhi = DAG.getNode(X86ISD::VSHLI, dl, VT, AloBhi, ShAmt);
AhiBlo = DAG.getNode(X86ISD::VSHLI, dl, VT, AhiBlo, ShAmt);
AloBhi = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, AloBhi, 32, DAG);
AhiBlo = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, AhiBlo, 32, DAG);
SDValue Res = DAG.getNode(ISD::ADD, dl, VT, AloBlo, AloBhi);
return DAG.getNode(ISD::ADD, dl, VT, Res, AhiBlo);
@ -12462,7 +12472,7 @@ static SDValue LowerSDIV(SDValue Op, SelectionDAG &DAG) {
if ((SplatValue != 0) &&
(SplatValue.isPowerOf2() || (-SplatValue).isPowerOf2())) {
unsigned lg2 = SplatValue.countTrailingZeros();
unsigned Lg2 = SplatValue.countTrailingZeros();
// Splat the sign bit.
SmallVector<SDValue, 16> Sz(NumElts,
DAG.getConstant(EltTy.getSizeInBits() - 1,
@ -12472,13 +12482,13 @@ static SDValue LowerSDIV(SDValue Op, SelectionDAG &DAG) {
NumElts));
// Add (N0 < 0) ? abs2 - 1 : 0;
SmallVector<SDValue, 16> Amt(NumElts,
DAG.getConstant(EltTy.getSizeInBits() - lg2,
DAG.getConstant(EltTy.getSizeInBits() - Lg2,
EltTy));
SDValue SRL = DAG.getNode(ISD::SRL, dl, VT, SGN,
DAG.getNode(ISD::BUILD_VECTOR, dl, VT, &Amt[0],
NumElts));
SDValue ADD = DAG.getNode(ISD::ADD, dl, VT, N0, SRL);
SmallVector<SDValue, 16> Lg2Amt(NumElts, DAG.getConstant(lg2, EltTy));
SmallVector<SDValue, 16> Lg2Amt(NumElts, DAG.getConstant(Lg2, EltTy));
SDValue SRA = DAG.getNode(ISD::SRA, dl, VT, ADD,
DAG.getNode(ISD::BUILD_VECTOR, dl, VT, &Lg2Amt[0],
NumElts));
@ -12514,21 +12524,22 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
(Subtarget->hasAVX512() &&
(VT == MVT::v8i64 || VT == MVT::v16i32))) {
if (Op.getOpcode() == ISD::SHL)
return DAG.getNode(X86ISD::VSHLI, dl, VT, R,
DAG.getConstant(ShiftAmt, MVT::i32));
return getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, R, ShiftAmt,
DAG);
if (Op.getOpcode() == ISD::SRL)
return DAG.getNode(X86ISD::VSRLI, dl, VT, R,
DAG.getConstant(ShiftAmt, MVT::i32));
return getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt,
DAG);
if (Op.getOpcode() == ISD::SRA && VT != MVT::v2i64 && VT != MVT::v4i64)
return DAG.getNode(X86ISD::VSRAI, dl, VT, R,
DAG.getConstant(ShiftAmt, MVT::i32));
return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, ShiftAmt,
DAG);
}
if (VT == MVT::v16i8) {
if (Op.getOpcode() == ISD::SHL) {
// Make a large shift.
SDValue SHL = DAG.getNode(X86ISD::VSHLI, dl, MVT::v8i16, R,
DAG.getConstant(ShiftAmt, MVT::i32));
SDValue SHL = getTargetVShiftByConstNode(X86ISD::VSHLI, dl,
MVT::v8i16, R, ShiftAmt,
DAG);
SHL = DAG.getNode(ISD::BITCAST, dl, VT, SHL);
// Zero out the rightmost bits.
SmallVector<SDValue, 16> V(16,
@ -12539,8 +12550,9 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
}
if (Op.getOpcode() == ISD::SRL) {
// Make a large shift.
SDValue SRL = DAG.getNode(X86ISD::VSRLI, dl, MVT::v8i16, R,
DAG.getConstant(ShiftAmt, MVT::i32));
SDValue SRL = getTargetVShiftByConstNode(X86ISD::VSRLI, dl,
MVT::v8i16, R, ShiftAmt,
DAG);
SRL = DAG.getNode(ISD::BITCAST, dl, VT, SRL);
// Zero out the leftmost bits.
SmallVector<SDValue, 16> V(16,
@ -12571,8 +12583,9 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
if (Subtarget->hasInt256() && VT == MVT::v32i8) {
if (Op.getOpcode() == ISD::SHL) {
// Make a large shift.
SDValue SHL = DAG.getNode(X86ISD::VSHLI, dl, MVT::v16i16, R,
DAG.getConstant(ShiftAmt, MVT::i32));
SDValue SHL = getTargetVShiftByConstNode(X86ISD::VSHLI, dl,
MVT::v16i16, R, ShiftAmt,
DAG);
SHL = DAG.getNode(ISD::BITCAST, dl, VT, SHL);
// Zero out the rightmost bits.
SmallVector<SDValue, 32> V(32,
@ -12583,8 +12596,9 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
}
if (Op.getOpcode() == ISD::SRL) {
// Make a large shift.
SDValue SRL = DAG.getNode(X86ISD::VSRLI, dl, MVT::v16i16, R,
DAG.getConstant(ShiftAmt, MVT::i32));
SDValue SRL = getTargetVShiftByConstNode(X86ISD::VSRLI, dl,
MVT::v16i16, R, ShiftAmt,
DAG);
SRL = DAG.getNode(ISD::BITCAST, dl, VT, SRL);
// Zero out the leftmost bits.
SmallVector<SDValue, 32> V(32,
@ -12649,14 +12663,14 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
default:
llvm_unreachable("Unknown shift opcode!");
case ISD::SHL:
return DAG.getNode(X86ISD::VSHLI, dl, VT, R,
DAG.getConstant(ShiftAmt, MVT::i32));
return getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, R, ShiftAmt,
DAG);
case ISD::SRL:
return DAG.getNode(X86ISD::VSRLI, dl, VT, R,
DAG.getConstant(ShiftAmt, MVT::i32));
return getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt,
DAG);
case ISD::SRA:
return DAG.getNode(X86ISD::VSRAI, dl, VT, R,
DAG.getConstant(ShiftAmt, MVT::i32));
return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, ShiftAmt,
DAG);
}
}
@ -12869,8 +12883,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget* Subtarget,
// r = VSELECT(r, psllw(r & (char16)15, 4), a);
SDValue M = DAG.getNode(ISD::AND, dl, VT, R, CM1);
M = getTargetVShiftNode(X86ISD::VSHLI, dl, MVT::v8i16, M,
DAG.getConstant(4, MVT::i32), DAG);
M = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, MVT::v8i16, M, 4, DAG);
M = DAG.getNode(ISD::BITCAST, dl, VT, M);
R = DAG.getNode(ISD::VSELECT, dl, VT, OpVSel, M, R);
@ -12881,8 +12894,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget* Subtarget,
// r = VSELECT(r, psllw(r & (char16)63, 2), a);
M = DAG.getNode(ISD::AND, dl, VT, R, CM2);
M = getTargetVShiftNode(X86ISD::VSHLI, dl, MVT::v8i16, M,
DAG.getConstant(2, MVT::i32), DAG);
M = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, MVT::v8i16, M, 2, DAG);
M = DAG.getNode(ISD::BITCAST, dl, VT, M);
R = DAG.getNode(ISD::VSELECT, dl, VT, OpVSel, M, R);
@ -13025,7 +13037,6 @@ SDValue X86TargetLowering::LowerSIGN_EXTEND_INREG(SDValue Op,
unsigned BitsDiff = VT.getScalarType().getSizeInBits() -
ExtraVT.getScalarType().getSizeInBits();
SDValue ShAmt = DAG.getConstant(BitsDiff, MVT::i32);
switch (VT.getSimpleVT().SimpleTy) {
default: return SDValue();
@ -13075,8 +13086,10 @@ SDValue X86TargetLowering::LowerSIGN_EXTEND_INREG(SDValue Op,
}
// If the above didn't work, then just use Shift-Left + Shift-Right.
Tmp1 = getTargetVShiftNode(X86ISD::VSHLI, dl, VT, Op0, ShAmt, DAG);
return getTargetVShiftNode(X86ISD::VSRAI, dl, VT, Tmp1, ShAmt, DAG);
Tmp1 = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, Op0, BitsDiff,
DAG);
return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, Tmp1, BitsDiff,
DAG);
}
}
}

View File

@ -1845,22 +1845,22 @@ multiclass avx512_shift_rmi<bits<8> opc, Format ImmFormR, Format ImmFormM,
ValueType vt, X86MemOperand x86memop, PatFrag mem_frag,
RegisterClass KRC> {
def ri : AVX512BIi8<opc, ImmFormR, (outs RC:$dst),
(ins RC:$src1, i32i8imm:$src2),
(ins RC:$src1, i8imm:$src2),
!strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
[(set RC:$dst, (vt (OpNode RC:$src1, (i32 imm:$src2))))],
[(set RC:$dst, (vt (OpNode RC:$src1, (i8 imm:$src2))))],
SSE_INTSHIFT_ITINS_P.rr>, EVEX_4V;
def rik : AVX512BIi8<opc, ImmFormR, (outs RC:$dst),
(ins KRC:$mask, RC:$src1, i32i8imm:$src2),
(ins KRC:$mask, RC:$src1, i8imm:$src2),
!strconcat(OpcodeStr,
"\t{$src2, $src1, $dst {${mask}}|$dst {${mask}}, $src1, $src2}"),
[], SSE_INTSHIFT_ITINS_P.rr>, EVEX_4V, EVEX_K;
def mi: AVX512BIi8<opc, ImmFormM, (outs RC:$dst),
(ins x86memop:$src1, i32i8imm:$src2),
(ins x86memop:$src1, i8imm:$src2),
!strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
[(set RC:$dst, (OpNode (mem_frag addr:$src1),
(i32 imm:$src2)))], SSE_INTSHIFT_ITINS_P.rm>, EVEX_4V;
(i8 imm:$src2)))], SSE_INTSHIFT_ITINS_P.rm>, EVEX_4V;
def mik: AVX512BIi8<opc, ImmFormM, (outs RC:$dst),
(ins KRC:$mask, x86memop:$src1, i32i8imm:$src2),
(ins KRC:$mask, x86memop:$src1, i8imm:$src2),
!strconcat(OpcodeStr,
"\t{$src2, $src1, $dst {${mask}}|$dst {${mask}}, $src1, $src2}"),
[], SSE_INTSHIFT_ITINS_P.rm>, EVEX_4V, EVEX_K;

View File

@ -3744,11 +3744,11 @@ multiclass PDI_binop_rmi<bits<8> opc, bits<8> opc2, Format ImmForm,
(bc_frag (memopv2i64 addr:$src2)))))], itins.rm>,
Sched<[WriteVecShiftLd, ReadAfterLd]>;
def ri : PDIi8<opc2, ImmForm, (outs RC:$dst),
(ins RC:$src1, i32i8imm:$src2),
(ins RC:$src1, i8imm:$src2),
!if(Is2Addr,
!strconcat(OpcodeStr, "\t{$src2, $dst|$dst, $src2}"),
!strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}")),
[(set RC:$dst, (DstVT (OpNode2 RC:$src1, (i32 imm:$src2))))], itins.ri>,
[(set RC:$dst, (DstVT (OpNode2 RC:$src1, (i8 imm:$src2))))], itins.ri>,
Sched<[WriteVecShift]>;
}
@ -5064,12 +5064,12 @@ multiclass SS3I_unop_rm_int_y<bits<8> opc, string OpcodeStr,
// Helper fragments to match sext vXi1 to vXiY.
def v16i1sextv16i8 : PatLeaf<(v16i8 (X86pcmpgt (bc_v16i8 (v4i32 immAllZerosV)),
VR128:$src))>;
def v8i1sextv8i16 : PatLeaf<(v8i16 (X86vsrai VR128:$src, (i32 15)))>;
def v4i1sextv4i32 : PatLeaf<(v4i32 (X86vsrai VR128:$src, (i32 31)))>;
def v8i1sextv8i16 : PatLeaf<(v8i16 (X86vsrai VR128:$src, (i8 15)))>;
def v4i1sextv4i32 : PatLeaf<(v4i32 (X86vsrai VR128:$src, (i8 31)))>;
def v32i1sextv32i8 : PatLeaf<(v32i8 (X86pcmpgt (bc_v32i8 (v8i32 immAllZerosV)),
VR256:$src))>;
def v16i1sextv16i16: PatLeaf<(v16i16 (X86vsrai VR256:$src, (i32 15)))>;
def v8i1sextv8i32 : PatLeaf<(v8i32 (X86vsrai VR256:$src, (i32 31)))>;
def v16i1sextv16i16: PatLeaf<(v16i16 (X86vsrai VR256:$src, (i8 15)))>;
def v8i1sextv8i32 : PatLeaf<(v8i32 (X86vsrai VR256:$src, (i8 31)))>;
let Predicates = [HasAVX] in {
defm VPABSB : SS3I_unop_rm_int<0x1C, "vpabsb",

View File

@ -121,7 +121,7 @@ entry:
}
; CHECK-LABEL: test_sraw_3:
; CHECK: vpsraw $16, %ymm0, %ymm0
; CHECK: vpsraw $15, %ymm0, %ymm0
; CHECK: ret
define <8 x i32> @test_srad_1(<8 x i32> %InVec) {
@ -151,7 +151,7 @@ entry:
}
; CHECK-LABEL: test_srad_3:
; CHECK: vpsrad $32, %ymm0, %ymm0
; CHECK: vpsrad $31, %ymm0, %ymm0
; CHECK: ret
; SSE Logical Shift Right

View File

@ -121,7 +121,7 @@ entry:
}
; CHECK-LABEL: test_sraw_3:
; CHECK: psraw $16, %xmm0
; CHECK: psraw $15, %xmm0
; CHECK-NEXT: ret
define <4 x i32> @test_srad_1(<4 x i32> %InVec) {
@ -151,7 +151,7 @@ entry:
}
; CHECK-LABEL: test_srad_3:
; CHECK: psrad $32, %xmm0
; CHECK: psrad $31, %xmm0
; CHECK-NEXT: ret
; SSE Logical Shift Right