[Hexagon] Improve code generation for 32x32-bit multiplication

For multiplications of 64-bit values (giving 64-bit result), detect
cases where the arguments are sign-extended 32-bit values, on a per-
operand basis. This will allow few patterns to match a wider variety
of combinations in which extensions can occur.

llvm-svn: 304223
This commit is contained in:
Krzysztof Parzyszek 2017-05-30 17:47:51 +00:00
parent 591312c5c1
commit ef58017b35
4 changed files with 208 additions and 139 deletions

View File

@ -71,6 +71,9 @@ public:
return true;
}
bool ComplexPatternFuncMutatesDAG() const override {
return true;
}
void PreprocessISelDAG() override;
void EmitFunctionEntryCode() override;
@ -81,6 +84,7 @@ public:
inline bool SelectAddrGP(SDValue &N, SDValue &R);
bool SelectGlobalAddress(SDValue &N, SDValue &R, bool UseGP);
bool SelectAddrFI(SDValue &N, SDValue &R);
bool DetectUseSxtw(SDValue &N, SDValue &R);
StringRef getPassName() const override {
return "Hexagon DAG->DAG Pattern Instruction Selection";
@ -106,7 +110,6 @@ public:
void SelectIndexedStore(StoreSDNode *ST, const SDLoc &dl);
void SelectStore(SDNode *N);
void SelectSHL(SDNode *N);
void SelectMul(SDNode *N);
void SelectZeroExtend(SDNode *N);
void SelectIntrinsicWChain(SDNode *N);
void SelectIntrinsicWOChain(SDNode *N);
@ -118,7 +121,7 @@ public:
#include "HexagonGenDAGISel.inc"
private:
bool isValueExtension(const SDValue &Val, unsigned FromBits, SDValue &Src);
bool keepsLowBits(const SDValue &Val, unsigned NumBits, SDValue &Src);
bool isOrEquivalentToAdd(const SDNode *N) const;
bool isAlignedMemNode(const MemSDNode *N) const;
bool isPositiveHalfWord(const SDNode *N) const;
@ -597,90 +600,6 @@ void HexagonDAGToDAGISel::SelectStore(SDNode *N) {
SelectCode(ST);
}
void HexagonDAGToDAGISel::SelectMul(SDNode *N) {
SDLoc dl(N);
// %conv.i = sext i32 %tmp1 to i64
// %conv2.i = sext i32 %add to i64
// %mul.i = mul nsw i64 %conv2.i, %conv.i
//
// --- match with the following ---
//
// %mul.i = mpy (%tmp1, %add)
//
if (N->getValueType(0) == MVT::i64) {
// Shifting a i64 signed multiply.
SDValue MulOp0 = N->getOperand(0);
SDValue MulOp1 = N->getOperand(1);
SDValue OP0;
SDValue OP1;
// Handle sign_extend and sextload.
if (MulOp0.getOpcode() == ISD::SIGN_EXTEND) {
SDValue Sext0 = MulOp0.getOperand(0);
if (Sext0.getNode()->getValueType(0) != MVT::i32) {
SelectCode(N);
return;
}
OP0 = Sext0;
} else if (MulOp0.getOpcode() == ISD::LOAD) {
LoadSDNode *LD = cast<LoadSDNode>(MulOp0.getNode());
if (LD->getMemoryVT() != MVT::i32 ||
LD->getExtensionType() != ISD::SEXTLOAD ||
LD->getAddressingMode() != ISD::UNINDEXED) {
SelectCode(N);
return;
}
SDValue Chain = LD->getChain();
SDValue TargetConst0 = CurDAG->getTargetConstant(0, dl, MVT::i32);
OP0 = SDValue(CurDAG->getMachineNode(Hexagon::L2_loadri_io, dl, MVT::i32,
MVT::Other,
LD->getBasePtr(), TargetConst0,
Chain), 0);
} else {
SelectCode(N);
return;
}
// Same goes for the second operand.
if (MulOp1.getOpcode() == ISD::SIGN_EXTEND) {
SDValue Sext1 = MulOp1.getOperand(0);
if (Sext1.getNode()->getValueType(0) != MVT::i32) {
SelectCode(N);
return;
}
OP1 = Sext1;
} else if (MulOp1.getOpcode() == ISD::LOAD) {
LoadSDNode *LD = cast<LoadSDNode>(MulOp1.getNode());
if (LD->getMemoryVT() != MVT::i32 ||
LD->getExtensionType() != ISD::SEXTLOAD ||
LD->getAddressingMode() != ISD::UNINDEXED) {
SelectCode(N);
return;
}
SDValue Chain = LD->getChain();
SDValue TargetConst0 = CurDAG->getTargetConstant(0, dl, MVT::i32);
OP1 = SDValue(CurDAG->getMachineNode(Hexagon::L2_loadri_io, dl, MVT::i32,
MVT::Other,
LD->getBasePtr(), TargetConst0,
Chain), 0);
} else {
SelectCode(N);
return;
}
// Generate a mpy instruction.
SDNode *Result = CurDAG->getMachineNode(Hexagon::M2_dpmpyss_s0, dl,
MVT::i64, OP0, OP1);
ReplaceNode(N, Result);
return;
}
SelectCode(N);
}
void HexagonDAGToDAGISel::SelectSHL(SDNode *N) {
SDLoc dl(N);
SDValue Shl_0 = N->getOperand(0);
@ -843,7 +762,7 @@ void HexagonDAGToDAGISel::SelectIntrinsicWOChain(SDNode *N) {
SDValue V = N->getOperand(1);
SDValue U;
if (isValueExtension(V, Bits, U)) {
if (keepsLowBits(V, Bits, U)) {
SDValue R = CurDAG->getNode(N->getOpcode(), SDLoc(N), N->getValueType(0),
N->getOperand(0), U);
ReplaceNode(N, R.getNode());
@ -949,7 +868,6 @@ void HexagonDAGToDAGISel::Select(SDNode *N) {
case ISD::SHL: return SelectSHL(N);
case ISD::LOAD: return SelectLoad(N);
case ISD::STORE: return SelectStore(N);
case ISD::MUL: return SelectMul(N);
case ISD::ZERO_EXTEND: return SelectZeroExtend(N);
case ISD::INTRINSIC_W_CHAIN: return SelectIntrinsicWChain(N);
case ISD::INTRINSIC_WO_CHAIN: return SelectIntrinsicWOChain(N);
@ -1327,7 +1245,7 @@ void HexagonDAGToDAGISel::EmitFunctionEntryCode() {
}
// Match a frame index that can be used in an addressing mode.
bool HexagonDAGToDAGISel::SelectAddrFI(SDValue& N, SDValue &R) {
bool HexagonDAGToDAGISel::SelectAddrFI(SDValue &N, SDValue &R) {
if (N.getOpcode() != ISD::FrameIndex)
return false;
auto &HFI = *HST->getFrameLowering();
@ -1388,16 +1306,83 @@ bool HexagonDAGToDAGISel::SelectGlobalAddress(SDValue &N, SDValue &R,
return false;
}
bool HexagonDAGToDAGISel::isValueExtension(const SDValue &Val,
unsigned FromBits, SDValue &Src) {
bool HexagonDAGToDAGISel::DetectUseSxtw(SDValue &N, SDValue &R) {
// This (complex pattern) function is meant to detect a sign-extension
// i32->i64 on a per-operand basis. This would allow writing single
// patterns that would cover a number of combinations of different ways
// a sign-extensions could be written. For example:
// (mul (DetectUseSxtw x) (DetectUseSxtw y)) -> (M2_dpmpyss_s0 x y)
// could match either one of these:
// (mul (sext x) (sext_inreg y))
// (mul (sext-load *p) (sext_inreg y))
// (mul (sext_inreg x) (sext y))
// etc.
//
// The returned value will have type i64 and its low word will
// contain the value being extended. The high bits are not specified.
// The returned type is i64 because the original type of N was i64,
// but the users of this function should only use the low-word of the
// result, e.g.
// (mul sxtw:x, sxtw:y) -> (M2_dpmpyss_s0 (LoReg sxtw:x), (LoReg sxtw:y))
if (N.getValueType() != MVT::i64)
return false;
EVT SrcVT;
unsigned Opc = N.getOpcode();
switch (Opc) {
case ISD::SIGN_EXTEND:
case ISD::SIGN_EXTEND_INREG: {
// sext_inreg has the source type as a separate operand.
EVT T = Opc == ISD::SIGN_EXTEND
? N.getOperand(0).getValueType()
: cast<VTSDNode>(N.getOperand(1))->getVT();
if (T.getSizeInBits() != 32)
return false;
R = N.getOperand(0);
break;
}
case ISD::LOAD: {
LoadSDNode *L = cast<LoadSDNode>(N);
if (L->getExtensionType() != ISD::SEXTLOAD)
return false;
// All extending loads extend to i32, so even if the value in
// memory is shorter than 32 bits, it will be i32 after the load.
if (L->getMemoryVT().getSizeInBits() > 32)
return false;
R = N;
break;
}
default:
return false;
}
EVT RT = R.getValueType();
if (RT == MVT::i64)
return true;
assert(RT == MVT::i32);
// This is only to produce a value of type i64. Do not rely on the
// high bits produced by this.
const SDLoc &dl(N);
SDValue Ops[] = {
CurDAG->getTargetConstant(Hexagon::DoubleRegsRegClassID, dl, MVT::i32),
R, CurDAG->getTargetConstant(Hexagon::isub_hi, dl, MVT::i32),
R, CurDAG->getTargetConstant(Hexagon::isub_lo, dl, MVT::i32)
};
SDNode *T = CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE, dl,
MVT::i64, Ops);
R = SDValue(T, 0);
return true;
}
bool HexagonDAGToDAGISel::keepsLowBits(const SDValue &Val, unsigned NumBits,
SDValue &Src) {
unsigned Opc = Val.getOpcode();
switch (Opc) {
case ISD::SIGN_EXTEND:
case ISD::ZERO_EXTEND:
case ISD::ANY_EXTEND: {
SDValue const &Op0 = Val.getOperand(0);
const SDValue &Op0 = Val.getOperand(0);
EVT T = Op0.getValueType();
if (T.isInteger() && T.getSizeInBits() == FromBits) {
if (T.isInteger() && T.getSizeInBits() == NumBits) {
Src = Op0;
return true;
}
@ -1408,23 +1393,23 @@ bool HexagonDAGToDAGISel::isValueExtension(const SDValue &Val,
case ISD::AssertZext:
if (Val.getOperand(0).getValueType().isInteger()) {
VTSDNode *T = cast<VTSDNode>(Val.getOperand(1));
if (T->getVT().getSizeInBits() == FromBits) {
if (T->getVT().getSizeInBits() == NumBits) {
Src = Val.getOperand(0);
return true;
}
}
break;
case ISD::AND: {
// Check if this is an AND with "FromBits" of lower bits set to 1.
uint64_t FromMask = (1 << FromBits) - 1;
// Check if this is an AND with NumBits of lower bits set to 1.
uint64_t Mask = (1 << NumBits) - 1;
if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val.getOperand(0))) {
if (C->getZExtValue() == FromMask) {
if (C->getZExtValue() == Mask) {
Src = Val.getOperand(1);
return true;
}
}
if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val.getOperand(1))) {
if (C->getZExtValue() == FromMask) {
if (C->getZExtValue() == Mask) {
Src = Val.getOperand(0);
return true;
}
@ -1433,16 +1418,16 @@ bool HexagonDAGToDAGISel::isValueExtension(const SDValue &Val,
}
case ISD::OR:
case ISD::XOR: {
// OR/XOR with the lower "FromBits" bits set to 0.
uint64_t FromMask = (1 << FromBits) - 1;
// OR/XOR with the lower NumBits bits set to 0.
uint64_t Mask = (1 << NumBits) - 1;
if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val.getOperand(0))) {
if ((C->getZExtValue() & FromMask) == 0) {
if ((C->getZExtValue() & Mask) == 0) {
Src = Val.getOperand(1);
return true;
}
}
if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val.getOperand(1))) {
if ((C->getZExtValue() & FromMask) == 0) {
if ((C->getZExtValue() & Mask) == 0) {
Src = Val.getOperand(0);
return true;
}

View File

@ -1928,11 +1928,7 @@ HexagonTargetLowering::HexagonTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::BITREVERSE, MVT::i64, Legal);
setOperationAction(ISD::BSWAP, MVT::i32, Legal);
setOperationAction(ISD::BSWAP, MVT::i64, Legal);
// We custom lower i64 to i64 mul, so that it is not considered as a legal
// operation. There is a pattern that will match i64 mul and transform it
// to a series of instructions.
setOperationAction(ISD::MUL, MVT::i64, Expand);
setOperationAction(ISD::MUL, MVT::i64, Legal);
for (unsigned IntExpOp :
{ ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM,

View File

@ -382,48 +382,42 @@ def: T_MType_acc_pat3 <M4_or_andn, and, or>;
def: T_MType_acc_pat3 <M4_and_andn, and, and>;
def: T_MType_acc_pat3 <M4_xor_andn, and, xor>;
// This complex pattern is really only to detect various forms of
// sign-extension i32->i64. The selected value will be of type i64
// whose low word is the value being extended. The high word is
// unspecified.
def Usxtw : ComplexPattern<i64, 1, "DetectUseSxtw", [], []>;
def Aext64: PatFrag<(ops node:$Rs), (i64 (anyext node:$Rs))>;
def Sext64: PatFrag<(ops node:$Rs), (i64 (sext node:$Rs))>;
def Zext64: PatFrag<(ops node:$Rs), (i64 (zext node:$Rs))>;
def Sext64: PatLeaf<(i64 Usxtw:$Rs)>;
// Return true if for a 32 to 64-bit sign-extended load.
def Sext64Ld : PatLeaf<(i64 DoubleRegs:$src1), [{
LoadSDNode *LD = dyn_cast<LoadSDNode>(N);
if (!LD)
return false;
return LD->getExtensionType() == ISD::SEXTLOAD &&
LD->getMemoryVT().getScalarType() == MVT::i32;
}]>;
def: Pat<(mul (Aext64 I32:$Rs), (Aext64 I32:$Rt)),
(M2_dpmpyuu_s0 I32:$Rs, I32:$Rt)>;
def: Pat<(mul (Aext64 I32:$src1), (Aext64 I32:$src2)),
(M2_dpmpyuu_s0 IntRegs:$src1, IntRegs:$src2)>;
def: Pat<(mul (Sext64 I32:$src1), (Sext64 I32:$src2)),
(M2_dpmpyss_s0 IntRegs:$src1, IntRegs:$src2)>;
def: Pat<(mul Sext64Ld:$src1, Sext64Ld:$src2),
(M2_dpmpyss_s0 (LoReg DoubleRegs:$src1), (LoReg DoubleRegs:$src2))>;
def: Pat<(mul Sext64:$Rs, Sext64:$Rt),
(M2_dpmpyss_s0 (LoReg Sext64:$Rs), (LoReg Sext64:$Rt))>;
// Multiply and accumulate, use full result.
// Rxx[+-]=mpy(Rs,Rt)
def: Pat<(add I64:$src1, (mul (Sext64 I32:$src2), (Sext64 I32:$src3))),
(M2_dpmpyss_acc_s0 DoubleRegs:$src1, IntRegs:$src2, IntRegs:$src3)>;
def: Pat<(add I64:$Rx, (mul Sext64:$Rs, Sext64:$Rt)),
(M2_dpmpyss_acc_s0 I64:$Rx, (LoReg Sext64:$Rs), (LoReg Sext64:$Rt))>;
def: Pat<(sub I64:$src1, (mul (Sext64 I32:$src2), (Sext64 I32:$src3))),
(M2_dpmpyss_nac_s0 DoubleRegs:$src1, IntRegs:$src2, IntRegs:$src3)>;
def: Pat<(sub I64:$Rx, (mul Sext64:$Rs, Sext64:$Rt)),
(M2_dpmpyss_nac_s0 I64:$Rx, (LoReg Sext64:$Rs), (LoReg Sext64:$Rt))>;
def: Pat<(add I64:$src1, (mul (Aext64 I32:$src2), (Aext64 I32:$src3))),
(M2_dpmpyuu_acc_s0 DoubleRegs:$src1, IntRegs:$src2, IntRegs:$src3)>;
def: Pat<(add I64:$Rx, (mul (Aext64 I32:$Rs), (Aext64 I32:$Rt))),
(M2_dpmpyuu_acc_s0 I64:$Rx, I32:$Rs, I32:$Rt)>;
def: Pat<(add I64:$src1, (mul (Zext64 I32:$src2), (Zext64 I32:$src3))),
(M2_dpmpyuu_acc_s0 DoubleRegs:$src1, IntRegs:$src2, IntRegs:$src3)>;
def: Pat<(add I64:$Rx, (mul (Zext64 I32:$Rs), (Zext64 I32:$Rt))),
(M2_dpmpyuu_acc_s0 I64:$Rx, I32:$Rs, I32:$Rt)>;
def: Pat<(sub I64:$src1, (mul (Aext64 I32:$src2), (Aext64 I32:$src3))),
(M2_dpmpyuu_nac_s0 DoubleRegs:$src1, IntRegs:$src2, IntRegs:$src3)>;
def: Pat<(sub I64:$Rx, (mul (Aext64 I32:$Rs), (Aext64 I32:$Rt))),
(M2_dpmpyuu_nac_s0 I64:$Rx, I32:$Rs, I32:$Rt)>;
def: Pat<(sub I64:$src1, (mul (Zext64 I32:$src2), (Zext64 I32:$src3))),
(M2_dpmpyuu_nac_s0 DoubleRegs:$src1, IntRegs:$src2, IntRegs:$src3)>;
def: Pat<(sub I64:$Rx, (mul (Zext64 I32:$Rs), (Zext64 I32:$Rt))),
(M2_dpmpyuu_nac_s0 I64:$Rx, I32:$Rs, I32:$Rt)>;
class Storepi_pat<PatFrag Store, PatFrag Value, PatFrag Offset,
InstHexagon MI>
@ -545,7 +539,8 @@ def: Storexm_simple_pat<truncstorei8, I64, LoReg, S2_storerb_io>;
def: Storexm_simple_pat<truncstorei16, I64, LoReg, S2_storerh_io>;
def: Storexm_simple_pat<truncstorei32, I64, LoReg, S2_storeri_io>;
def: Pat <(Sext64 I32:$src), (A2_sxtw I32:$src)>;
def: Pat <(i64 (sext I32:$src)), (A2_sxtw I32:$src)>;
def: Pat <(i64 (sext_inreg I64:$src, i32)), (A2_sxtw (LoReg I64:$src))>;
def: Pat<(select (i1 (setlt I32:$src, 0)), (sub 0, I32:$src), I32:$src),
(A2_abs IntRegs:$src)>;
@ -1159,8 +1154,8 @@ multiclass MinMax_pats_p<PatFrag Op, InstHexagon Inst, InstHexagon SwapInst> {
defm: T_MinMax_pats<Op, I64, Inst, SwapInst>;
}
def: Pat<(add (Sext64 I32:$Rs), I64:$Rt),
(A2_addsp IntRegs:$Rs, DoubleRegs:$Rt)>;
def: Pat<(add Sext64:$Rs, I64:$Rt),
(A2_addsp (LoReg Sext64:$Rs), DoubleRegs:$Rt)>;
let AddedComplexity = 200 in {
defm: MinMax_pats_p<setge, A2_maxp, A2_minp>;

View File

@ -0,0 +1,93 @@
; RUN: llc -march=hexagon < %s | FileCheck %s
target triple = "hexagon-unknown--elf"
; CHECK-LABEL: mul_1
; CHECK: r1:0 = mpy(r2,r0)
define i64 @mul_1(i64 %a0, i64 %a1) #0 {
b2:
%v3 = shl i64 %a0, 32
%v4 = ashr exact i64 %v3, 32
%v5 = shl i64 %a1, 32
%v6 = ashr exact i64 %v5, 32
%v7 = mul nsw i64 %v6, %v4
ret i64 %v7
}
; CHECK-LABEL: mul_2
; CHECK: r0 = memb(r0+#0)
; CHECK: r1:0 = mpy(r2,r0)
; CHECK: jumpr r31
define i64 @mul_2(i8* %a0, i64 %a1) #0 {
b2:
%v3 = load i8, i8* %a0
%v4 = sext i8 %v3 to i64
%v5 = shl i64 %a1, 32
%v6 = ashr exact i64 %v5, 32
%v7 = mul nsw i64 %v6, %v4
ret i64 %v7
}
; CHECK-LABEL: mul_acc_1
; CHECK: r5:4 += mpy(r2,r0)
; CHECK: r1:0 = combine(r5,r4)
; CHECK: jumpr r31
define i64 @mul_acc_1(i64 %a0, i64 %a1, i64 %a2) #0 {
b3:
%v4 = shl i64 %a0, 32
%v5 = ashr exact i64 %v4, 32
%v6 = shl i64 %a1, 32
%v7 = ashr exact i64 %v6, 32
%v8 = mul nsw i64 %v7, %v5
%v9 = add i64 %a2, %v8
ret i64 %v9
}
; CHECK-LABEL: mul_acc_2
; CHECK: r2 = memw(r2+#0)
; CHECK: r5:4 += mpy(r2,r0)
; CHECK: r1:0 = combine(r5,r4)
; CHECK: jumpr r31
define i64 @mul_acc_2(i64 %a0, i32* %a1, i64 %a2) #0 {
b3:
%v4 = shl i64 %a0, 32
%v5 = ashr exact i64 %v4, 32
%v6 = load i32, i32* %a1
%v7 = sext i32 %v6 to i64
%v8 = mul nsw i64 %v7, %v5
%v9 = add i64 %a2, %v8
ret i64 %v9
}
; CHECK-LABEL: mul_nac_1
; CHECK: r5:4 -= mpy(r2,r0)
; CHECK: r1:0 = combine(r5,r4)
; CHECK: jumpr r31
define i64 @mul_nac_1(i64 %a0, i64 %a1, i64 %a2) #0 {
b3:
%v4 = shl i64 %a0, 32
%v5 = ashr exact i64 %v4, 32
%v6 = shl i64 %a1, 32
%v7 = ashr exact i64 %v6, 32
%v8 = mul nsw i64 %v7, %v5
%v9 = sub i64 %a2, %v8
ret i64 %v9
}
; CHECK-LABEL: mul_nac_2
; CHECK: r0 = memw(r0+#0)
; CHECK: r5:4 -= mpy(r2,r0)
; CHECK: r1:0 = combine(r5,r4)
; CHECK: jumpr r31
define i64 @mul_nac_2(i32* %a0, i64 %a1, i64 %a2) #0 {
b3:
%v4 = load i32, i32* %a0
%v5 = sext i32 %v4 to i64
%v6 = shl i64 %a1, 32
%v7 = ashr exact i64 %v6, 32
%v8 = mul nsw i64 %v7, %v5
%v9 = sub i64 %a2, %v8
ret i64 %v9
}
attributes #0 = { nounwind }