diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 7d2bfd421e45..ad2d4b55bdaa 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -1310,8 +1310,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
       setOperationAction(ISD::UINT_TO_FP, MVT::v8i64, Legal);
       setOperationAction(ISD::FP_TO_SINT, MVT::v8i64, Legal);
       setOperationAction(ISD::FP_TO_UINT, MVT::v8i64, Legal);
-
-      setOperationAction(ISD::MUL, MVT::v8i64, Legal);
     }
 
     if (Subtarget.hasCDI()) {
@@ -1388,8 +1386,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
         setOperationAction(ISD::UINT_TO_FP, VT, Legal);
         setOperationAction(ISD::FP_TO_SINT, VT, Legal);
         setOperationAction(ISD::FP_TO_UINT, VT, Legal);
-
-        setOperationAction(ISD::MUL, VT, Legal);
       }
     }
 
@@ -22140,6 +22136,11 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,
   bool AHiIsZero = DAG.MaskedValueIsZero(A, UpperBitsMask);
   bool BHiIsZero = DAG.MaskedValueIsZero(B, UpperBitsMask);
 
+  // If DQI is supported we can use MULLQ, but MULUDQ is still better if the
+  // high bits are known to be zero.
+  if (Subtarget.hasDQI() && (!AHiIsZero || !BHiIsZero))
+    return Op;
+
   // Bit cast to 32-bit vectors for MULUDQ.
   SDValue Alo = DAG.getBitcast(MulVT, A);
   SDValue Blo = DAG.getBitcast(MulVT, B);
@@ -32423,41 +32424,6 @@ static SDValue combineMulSpecial(uint64_t MulAmt, SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
-static SDValue combineVMUL(SDNode *N, SelectionDAG &DAG,
-                           const X86Subtarget &Subtarget) {
-  EVT VT = N->getValueType(0);
-  SDLoc dl(N);
-
-  if (VT.getScalarType() != MVT::i64)
-    return SDValue();
-
-  // Don't try to lower 256 bit integer vectors on AVX1 targets.
-  if (!Subtarget.hasAVX2() && VT.getVectorNumElements() > 2)
-    return SDValue();
-
-  MVT MulVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() * 2);
-
-  SDValue LHS = N->getOperand(0);
-  SDValue RHS = N->getOperand(1);
-
-  // MULDQ returns the 64-bit result of the signed multiplication of the lower
-  // 32-bits. We can lower with this if the sign bits stretch that far.
-  if (Subtarget.hasSSE41() && DAG.ComputeNumSignBits(LHS) > 32 &&
-      DAG.ComputeNumSignBits(RHS) > 32) {
-    return DAG.getNode(X86ISD::PMULDQ, dl, VT, DAG.getBitcast(MulVT, LHS),
-                       DAG.getBitcast(MulVT, RHS));
-  }
-
-  // If the upper bits are zero we can use a single pmuludq.
-  APInt Mask = APInt::getHighBitsSet(64, 32);
-  if (DAG.MaskedValueIsZero(LHS, Mask) && DAG.MaskedValueIsZero(RHS, Mask)) {
-    return DAG.getNode(X86ISD::PMULUDQ, dl, VT, DAG.getBitcast(MulVT, LHS),
-                       DAG.getBitcast(MulVT, RHS));
-  }
-
-  return SDValue();
-}
-
 /// Optimize a single multiply with constant into two operations in order to
 /// implement it with two cheaper instructions, e.g. LEA + SHL, LEA + LEA.
 static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
@@ -32467,9 +32433,6 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
   if (DCI.isBeforeLegalize() && VT.isVector())
     return reduceVMULWidth(N, DAG, Subtarget);
 
-  if (!DCI.isBeforeLegalize() && VT.isVector())
-    return combineVMUL(N, DAG, Subtarget);
-
   if (!MulConstantOptimization)
     return SDValue();
   // An imul is usually smaller than the alternative sequence.
@@ -34911,7 +34874,7 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
     // X86 is rubbish at scalar and vector i64 multiplies (until AVX512DQ) - its
     // better to truncate if we have the chance.
     if (SrcVT.getScalarType() == MVT::i64 && TLI.isOperationLegal(Opcode, VT) &&
-        !TLI.isOperationLegal(Opcode, SrcVT))
+        !Subtarget.hasDQI())
       return TruncateArithmetic(Src.getOperand(0), Src.getOperand(1));
     LLVM_FALLTHROUGH;
   case ISD::ADD: {
diff --git a/llvm/test/CodeGen/X86/combine-pmuldq.ll b/llvm/test/CodeGen/X86/combine-pmuldq.ll
index 0c7b8d6f4c50..ebfe0d56358e 100644
--- a/llvm/test/CodeGen/X86/combine-pmuldq.ll
+++ b/llvm/test/CodeGen/X86/combine-pmuldq.ll
@@ -90,7 +90,7 @@ define <2 x i64> @combine_shuffle_zero_pmuludq(<4 x i32> %a0, <4 x i32> %a1) {
 ; AVX512DQVL-NEXT: vpxor %xmm2, %xmm2, %xmm2
 ; AVX512DQVL-NEXT: vpblendd {{.*#+}} xmm0 = xmm0[0],xmm2[1],xmm0[2],xmm2[3]
 ; AVX512DQVL-NEXT: vpblendd {{.*#+}} xmm1 = xmm1[0],xmm2[1],xmm1[2],xmm2[3]
-; AVX512DQVL-NEXT: vpmullq %xmm1, %xmm0, %xmm0
+; AVX512DQVL-NEXT: vpmuludq %xmm1, %xmm0, %xmm0
 ; AVX512DQVL-NEXT: retq
   %1 = shufflevector <4 x i32> %a0, <4 x i32> zeroinitializer, <4 x i32> <i32 0, i32 5, i32 2, i32 7>
   %2 = shufflevector <4 x i32> %a1, <4 x i32> zeroinitializer, <4 x i32> <i32 0, i32 5, i32 2, i32 7>
@@ -134,7 +134,7 @@ define <4 x i64> @combine_shuffle_zero_pmuludq_256(<8 x i32> %a0, <8 x i32> %a1) {
 ; AVX512DQVL-NEXT: vpxor %xmm2, %xmm2, %xmm2
 ; AVX512DQVL-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0],ymm2[1],ymm0[2],ymm2[3],ymm0[4],ymm2[5],ymm0[6],ymm2[7]
 ; AVX512DQVL-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0],ymm2[1],ymm1[2],ymm2[3],ymm1[4],ymm2[5],ymm1[6],ymm2[7]
-; AVX512DQVL-NEXT: vpmullq %ymm1, %ymm0, %ymm0
+; AVX512DQVL-NEXT: vpmuludq %ymm1, %ymm0, %ymm0
 ; AVX512DQVL-NEXT: retq
   %1 = shufflevector <8 x i32> %a0, <8 x i32> zeroinitializer, <8 x i32> <i32 0, i32 9, i32 2, i32 11, i32 4, i32 13, i32 6, i32 15>
   %2 = shufflevector <8 x i32> %a1, <8 x i32> zeroinitializer, <8 x i32> <i32 0, i32 9, i32 2, i32 11, i32 4, i32 13, i32 6, i32 15>
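
Illustration (not part of the patch; function names and the llc invocation below are assumptions): a minimal IR sketch of the two cases the new LowerMUL check distinguishes on an AVX512DQ+VL target. A general v2i64 multiply can now stay as ISD::MUL and is expected to select a single vpmullq, while a multiply whose operands have their upper 32 bits known zero should still take the cheaper vpmuludq path.

; Hypothetical examples (not from the patch), e.g. run with:
;   llc -mtriple=x86_64-unknown-unknown -mattr=+avx512dq,+avx512vl

; General 64-bit elements: with DQI the multiply is left alone by LowerMUL,
; so it should be matched as a single vpmullq.
define <2 x i64> @mul_v2i64(<2 x i64> %a, <2 x i64> %b) {
  %r = mul <2 x i64> %a, %b
  ret <2 x i64> %r
}

; Upper 32 bits of both operands are known zero (zero-extended inputs), so
; MaskedValueIsZero succeeds and a single vpmuludq should remain the lowering.
define <2 x i64> @mul_v2i64_zext(<2 x i32> %a, <2 x i32> %b) {
  %za = zext <2 x i32> %a to <2 x i64>
  %zb = zext <2 x i32> %b to <2 x i64>
  %r = mul <2 x i64> %za, %zb
  ret <2 x i64> %r
}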