diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index bb0ec10b957e..3011407eb32e 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -1557,14 +1557,13 @@ void X86TargetLowering::resetOperationActions() { EVT X86TargetLowering::getSetCCResultType(LLVMContext &, EVT VT) const { if (!VT.isVector()) - return MVT::i8; + return Subtarget->hasAVX512() ? MVT::i1: MVT::i8; - const TargetMachine &TM = getTargetMachine(); - if (!TM.Options.UseSoftFloat && Subtarget->hasAVX512()) + if (Subtarget->hasAVX512()) switch(VT.getVectorNumElements()) { case 8: return MVT::v8i1; case 16: return MVT::v16i1; - } + } return VT.changeVectorElementTypeToInteger(); } @@ -10199,7 +10198,7 @@ SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { if (VT.isVector()) return LowerVSETCC(Op, Subtarget, DAG); - assert((VT == MVT::i8 || (Subtarget->hasAVX512() && VT == MVT::i1)) + assert(((!Subtarget->hasAVX512() && VT == MVT::i8) || (VT == MVT::i1)) && "SetCC type must be 8-bit or 1-bit integer"); SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); @@ -10235,7 +10234,7 @@ SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { if (!Invert) return Op0; CCode = X86::GetOppositeBranchCondition(CCode); - return DAG.getNode(X86ISD::SETCC, dl, MVT::i8, + return DAG.getNode(X86ISD::SETCC, dl, VT, DAG.getConstant(CCode, MVT::i8), Op0.getOperand(1)); } } @@ -10247,8 +10246,7 @@ SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { SDValue EFLAGS = EmitCmp(Op0, Op1, X86CC, DAG); EFLAGS = ConvertCmpIfNecessary(EFLAGS, DAG); - MVT SetCCVT = Subtarget->hasAVX512() ? MVT::i1 : MVT::i8; - return DAG.getNode(X86ISD::SETCC, dl, SetCCVT, + return DAG.getNode(X86ISD::SETCC, dl, VT, DAG.getConstant(X86CC, MVT::i8), EFLAGS); } diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td index b64d0c307ccf..5e5b6fbbc5ad 100644 --- a/llvm/lib/Target/X86/X86InstrAVX512.td +++ b/llvm/lib/Target/X86/X86InstrAVX512.td @@ -1018,6 +1018,10 @@ def : Pat<(not VK1:$src), (COPY_TO_REGCLASS (VCMPSSZrr (f32 (IMPLICIT_DEF)), (f32 (IMPLICIT_DEF)), (i8 0)), VK16)), VK1)>; +def : Pat<(and VK1:$src1, VK1:$src2), + (COPY_TO_REGCLASS (KANDWrr (COPY_TO_REGCLASS VK1:$src1, VK16), + (COPY_TO_REGCLASS VK1:$src2, VK16)), VK1)>; + multiclass avx512_mask_binop_int { let Predicates = [HasAVX512] in def : Pat<(!cast("int_x86_avx512_"##IntName##"_w")