[X86] Combine (vXi1 (bitcast (-1)))) and (vXi1 (bitcast (0))) to all ones or all zeros vXi1 vector.

llvm-svn: 331847
This commit is contained in:
Craig Topper 2018-05-09 06:07:20 +00:00
parent 618437459c
commit b9a473d186
2 changed files with 48 additions and 0 deletions

View File

@ -31095,6 +31095,16 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
return combinevXi1ConstantToInteger(N0, DAG);
}
if (Subtarget.hasAVX512() && SrcVT.isScalarInteger() &&
VT.isVector() && VT.getVectorElementType() == MVT::i1 &&
isa<ConstantSDNode>(N0)) {
auto *C = cast<ConstantSDNode>(N0);
if (C->isAllOnesValue())
return DAG.getConstant(1, SDLoc(N0), VT);
if (C->isNullValue())
return DAG.getConstant(0, SDLoc(N0), VT);
}
// Try to remove bitcasts from input and output of mask arithmetic to
// remove GPR<->K-register crossings.
if (SDValue V = combineCastedMaskArithmetic(N, DAG, DCI, Subtarget))

View File

@ -3446,3 +3446,41 @@ entry:
store <4 x i1> <i1 1, i1 0, i1 1, i1 0>, <4 x i1>* %R
ret void
}
; Make sure we bring the -1 constant into the mask domain.
define void @mask_not_cast(i8*, <8 x i64>, <8 x i64>, <8 x i64>, <8 x i64>) {
; CHECK-LABEL: mask_not_cast:
; CHECK: ## %bb.0:
; CHECK-NEXT: vpcmpleud %zmm3, %zmm2, %k0
; CHECK-NEXT: knotw %k0, %k1
; CHECK-NEXT: vptestmd %zmm0, %zmm1, %k1 {%k1}
; CHECK-NEXT: vmovdqu32 %zmm0, (%rdi) {%k1}
; CHECK-NEXT: vzeroupper
; CHECK-NEXT: retq
;
; X86-LABEL: mask_not_cast:
; X86: ## %bb.0:
; X86-NEXT: movl {{[0-9]+}}(%esp), %eax
; X86-NEXT: vpcmpleud %zmm3, %zmm2, %k0
; X86-NEXT: knotw %k0, %k1
; X86-NEXT: vptestmd %zmm0, %zmm1, %k1 {%k1}
; X86-NEXT: vmovdqu32 %zmm0, (%eax) {%k1}
; X86-NEXT: vzeroupper
; X86-NEXT: retl
%6 = and <8 x i64> %2, %1
%7 = bitcast <8 x i64> %6 to <16 x i32>
%8 = icmp ne <16 x i32> %7, zeroinitializer
%9 = bitcast <16 x i1> %8 to i16
%10 = bitcast <8 x i64> %3 to <16 x i32>
%11 = bitcast <8 x i64> %4 to <16 x i32>
%12 = icmp ule <16 x i32> %10, %11
%13 = bitcast <16 x i1> %12 to i16
%14 = xor i16 %13, -1
%15 = and i16 %14, %9
%16 = bitcast <8 x i64> %1 to <16 x i32>
%17 = bitcast i8* %0 to <16 x i32>*
%18 = bitcast i16 %15 to <16 x i1>
tail call void @llvm.masked.store.v16i32.p0v16i32(<16 x i32> %16, <16 x i32>* %17, i32 1, <16 x i1> %18) #2
ret void
}
declare void @llvm.masked.store.v16i32.p0v16i32(<16 x i32>, <16 x i32>*, i32, <16 x i1>)