Reapply r216805 "[MachineCombiner][AArch64] Use the correct register class for MADD, SUB, and OR."

This reapplies r216805 with a fix for a copy-paste error that resulted in an
incorrect register class.

Original commit message:
Select the correct register class for the various instructions that are
generated when combining instructions and constrain the registers to the
appropriate register class.

This fixes rdar://problem/18183707.

llvm-svn: 217019
This commit is contained in:
Juergen Ributzka 2014-09-03 07:07:10 +00:00
parent cf05f91ab5
commit 31e5b7fb12
2 changed files with 148 additions and 79 deletions

View File

@ -2426,20 +2426,34 @@ bool AArch64InstrInfo::hasPattern(
static MachineInstr *genMadd(MachineFunction &MF, MachineRegisterInfo &MRI,
const TargetInstrInfo *TII, MachineInstr &Root,
SmallVectorImpl<MachineInstr *> &InsInstrs,
unsigned IdxMulOpd, unsigned MaddOpc) {
unsigned IdxMulOpd, unsigned MaddOpc,
const TargetRegisterClass *RC) {
assert(IdxMulOpd == 1 || IdxMulOpd == 2);
unsigned IdxOtherOpd = IdxMulOpd == 1 ? 2 : 1;
MachineInstr *MUL = MRI.getUniqueVRegDef(Root.getOperand(IdxMulOpd).getReg());
MachineOperand R = Root.getOperand(0);
MachineOperand A = MUL->getOperand(1);
MachineOperand B = MUL->getOperand(2);
MachineOperand C = Root.getOperand(IdxOtherOpd);
MachineInstrBuilder MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc))
.addOperand(R)
.addOperand(A)
.addOperand(B)
.addOperand(C);
unsigned ResultReg = Root.getOperand(0).getReg();
unsigned SrcReg0 = MUL->getOperand(1).getReg();
bool Src0IsKill = MUL->getOperand(1).isKill();
unsigned SrcReg1 = MUL->getOperand(2).getReg();
bool Src1IsKill = MUL->getOperand(2).isKill();
unsigned SrcReg2 = Root.getOperand(IdxOtherOpd).getReg();
bool Src2IsKill = Root.getOperand(IdxOtherOpd).isKill();
if (TargetRegisterInfo::isVirtualRegister(ResultReg))
MRI.constrainRegClass(ResultReg, RC);
if (TargetRegisterInfo::isVirtualRegister(SrcReg0))
MRI.constrainRegClass(SrcReg0, RC);
if (TargetRegisterInfo::isVirtualRegister(SrcReg1))
MRI.constrainRegClass(SrcReg1, RC);
if (TargetRegisterInfo::isVirtualRegister(SrcReg2))
MRI.constrainRegClass(SrcReg2, RC);
MachineInstrBuilder MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc),
ResultReg)
.addReg(SrcReg0, getKillRegState(Src0IsKill))
.addReg(SrcReg1, getKillRegState(Src1IsKill))
.addReg(SrcReg2, getKillRegState(Src2IsKill));
// Insert the MADD
InsInstrs.push_back(MIB);
return MUL;
@ -2464,22 +2478,35 @@ static MachineInstr *genMaddR(MachineFunction &MF, MachineRegisterInfo &MRI,
const TargetInstrInfo *TII, MachineInstr &Root,
SmallVectorImpl<MachineInstr *> &InsInstrs,
unsigned IdxMulOpd, unsigned MaddOpc,
unsigned VR) {
unsigned VR, const TargetRegisterClass *RC) {
assert(IdxMulOpd == 1 || IdxMulOpd == 2);
MachineInstr *MUL = MRI.getUniqueVRegDef(Root.getOperand(IdxMulOpd).getReg());
MachineOperand R = Root.getOperand(0);
MachineOperand A = MUL->getOperand(1);
MachineOperand B = MUL->getOperand(2);
MachineInstrBuilder MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc))
.addOperand(R)
.addOperand(A)
.addOperand(B)
unsigned ResultReg = Root.getOperand(0).getReg();
unsigned SrcReg0 = MUL->getOperand(1).getReg();
bool Src0IsKill = MUL->getOperand(1).isKill();
unsigned SrcReg1 = MUL->getOperand(2).getReg();
bool Src1IsKill = MUL->getOperand(2).isKill();
if (TargetRegisterInfo::isVirtualRegister(ResultReg))
MRI.constrainRegClass(ResultReg, RC);
if (TargetRegisterInfo::isVirtualRegister(SrcReg0))
MRI.constrainRegClass(SrcReg0, RC);
if (TargetRegisterInfo::isVirtualRegister(SrcReg1))
MRI.constrainRegClass(SrcReg1, RC);
if (TargetRegisterInfo::isVirtualRegister(VR))
MRI.constrainRegClass(VR, RC);
MachineInstrBuilder MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc),
ResultReg)
.addReg(SrcReg0, getKillRegState(Src0IsKill))
.addReg(SrcReg1, getKillRegState(Src1IsKill))
.addReg(VR);
// Insert the MADD
InsInstrs.push_back(MIB);
return MUL;
}
/// genAlternativeCodeSequence - when hasPattern() finds a pattern
/// this function generates the instructions that could replace the
/// original code sequence
@ -2494,6 +2521,7 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
const TargetInstrInfo *TII = MF.getTarget().getSubtargetImpl()->getInstrInfo();
MachineInstr *MUL;
const TargetRegisterClass *RC;
unsigned Opc;
switch (Pattern) {
default:
@ -2505,9 +2533,14 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
// ADD R,I,C
// ==> MADD R,A,B,C
// --- Create(MADD);
Opc = Pattern == MachineCombinerPattern::MC_MULADDW_OP1 ? AArch64::MADDWrrr
: AArch64::MADDXrrr;
MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 1, Opc);
if (Pattern == MachineCombinerPattern::MC_MULADDW_OP1) {
Opc = AArch64::MADDWrrr;
RC = &AArch64::GPR32RegClass;
} else {
Opc = AArch64::MADDXrrr;
RC = &AArch64::GPR64RegClass;
}
MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
break;
case MachineCombinerPattern::MC_MULADDW_OP2:
case MachineCombinerPattern::MC_MULADDX_OP2:
@ -2515,54 +2548,59 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
// ADD R,C,I
// ==> MADD R,A,B,C
// --- Create(MADD);
Opc = Pattern == MachineCombinerPattern::MC_MULADDW_OP2 ? AArch64::MADDWrrr
: AArch64::MADDXrrr;
MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 2, Opc);
if (Pattern == MachineCombinerPattern::MC_MULADDW_OP2) {
Opc = AArch64::MADDWrrr;
RC = &AArch64::GPR32RegClass;
} else {
Opc = AArch64::MADDXrrr;
RC = &AArch64::GPR64RegClass;
}
MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
break;
case MachineCombinerPattern::MC_MULADDWI_OP1:
case MachineCombinerPattern::MC_MULADDXI_OP1:
case MachineCombinerPattern::MC_MULADDXI_OP1: {
// MUL I=A,B,0
// ADD R,I,Imm
// ==> ORR V, ZR, Imm
// ==> MADD R,A,B,V
// --- Create(MADD);
{
const TargetRegisterClass *RC =
MRI.getRegClass(Root.getOperand(1).getReg());
unsigned NewVR = MRI.createVirtualRegister(RC);
const TargetRegisterClass *OrrRC;
unsigned BitSize, OrrOpc, ZeroReg;
if (Pattern == MachineCombinerPattern::MC_MULADDWI_OP1) {
BitSize = 32;
OrrOpc = AArch64::ORRWri;
OrrRC = &AArch64::GPR32spRegClass;
BitSize = 32;
ZeroReg = AArch64::WZR;
Opc = AArch64::MADDWrrr;
RC = &AArch64::GPR32RegClass;
} else {
OrrOpc = AArch64::ORRXri;
OrrRC = &AArch64::GPR64spRegClass;
BitSize = 64;
ZeroReg = AArch64::XZR;
Opc = AArch64::MADDXrrr;
RC = &AArch64::GPR64RegClass;
}
unsigned NewVR = MRI.createVirtualRegister(OrrRC);
uint64_t Imm = Root.getOperand(2).getImm();
if (Root.getOperand(3).isImm()) {
unsigned val = Root.getOperand(3).getImm();
Imm = Imm << val;
unsigned Val = Root.getOperand(3).getImm();
Imm = Imm << Val;
}
uint64_t UImm = Imm << (64 - BitSize) >> (64 - BitSize);
uint64_t Encoding;
if (AArch64_AM::processLogicalImmediate(UImm, BitSize, Encoding)) {
MachineInstrBuilder MIB1 =
BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc))
.addOperand(MachineOperand::CreateReg(NewVR, true))
BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc), NewVR)
.addReg(ZeroReg)
.addImm(Encoding);
InsInstrs.push_back(MIB1);
InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR);
}
MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC);
}
break;
}
case MachineCombinerPattern::MC_MULSUBW_OP1:
case MachineCombinerPattern::MC_MULSUBX_OP1: {
// MUL I=A,B,0
@ -2570,38 +2608,46 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
// ==> SUB V, 0, C
// ==> MADD R,A,B,V // = -C + A*B
// --- Create(MADD);
const TargetRegisterClass *RC =
MRI.getRegClass(Root.getOperand(1).getReg());
unsigned NewVR = MRI.createVirtualRegister(RC);
const TargetRegisterClass *SubRC;
unsigned SubOpc, ZeroReg;
if (Pattern == MachineCombinerPattern::MC_MULSUBW_OP1) {
SubOpc = AArch64::SUBWrr;
SubRC = &AArch64::GPR32spRegClass;
ZeroReg = AArch64::WZR;
Opc = AArch64::MADDWrrr;
RC = &AArch64::GPR32RegClass;
} else {
SubOpc = AArch64::SUBXrr;
SubRC = &AArch64::GPR64spRegClass;
ZeroReg = AArch64::XZR;
Opc = AArch64::MADDXrrr;
RC = &AArch64::GPR64RegClass;
}
unsigned NewVR = MRI.createVirtualRegister(SubRC);
// SUB NewVR, 0, C
MachineInstrBuilder MIB1 =
BuildMI(MF, Root.getDebugLoc(), TII->get(SubOpc))
.addOperand(MachineOperand::CreateReg(NewVR, true))
BuildMI(MF, Root.getDebugLoc(), TII->get(SubOpc), NewVR)
.addReg(ZeroReg)
.addOperand(Root.getOperand(2));
InsInstrs.push_back(MIB1);
InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR);
} break;
MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC);
break;
}
case MachineCombinerPattern::MC_MULSUBW_OP2:
case MachineCombinerPattern::MC_MULSUBX_OP2:
// MUL I=A,B,0
// SUB R,C,I
// ==> MSUB R,A,B,C (computes C - A*B)
// --- Create(MSUB);
Opc = Pattern == MachineCombinerPattern::MC_MULSUBW_OP2 ? AArch64::MSUBWrrr
: AArch64::MSUBXrrr;
MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 2, Opc);
if (Pattern == MachineCombinerPattern::MC_MULSUBW_OP2) {
Opc = AArch64::MSUBWrrr;
RC = &AArch64::GPR32RegClass;
} else {
Opc = AArch64::MSUBXrrr;
RC = &AArch64::GPR64RegClass;
}
MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
break;
case MachineCombinerPattern::MC_MULSUBWI_OP1:
case MachineCombinerPattern::MC_MULSUBXI_OP1: {
@ -2610,40 +2656,43 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
// ==> ORR V, ZR, -Imm
// ==> MADD R,A,B,V // = -Imm + A*B
// --- Create(MADD);
const TargetRegisterClass *RC =
MRI.getRegClass(Root.getOperand(1).getReg());
unsigned NewVR = MRI.createVirtualRegister(RC);
const TargetRegisterClass *OrrRC;
unsigned BitSize, OrrOpc, ZeroReg;
if (Pattern == MachineCombinerPattern::MC_MULSUBWI_OP1) {
BitSize = 32;
OrrOpc = AArch64::ORRWri;
OrrRC = &AArch64::GPR32spRegClass;
BitSize = 32;
ZeroReg = AArch64::WZR;
Opc = AArch64::MADDWrrr;
RC = &AArch64::GPR32RegClass;
} else {
OrrOpc = AArch64::ORRXri;
OrrRC = &AArch64::GPR64RegClass;
BitSize = 64;
ZeroReg = AArch64::XZR;
Opc = AArch64::MADDXrrr;
RC = &AArch64::GPR64RegClass;
}
unsigned NewVR = MRI.createVirtualRegister(OrrRC);
int Imm = Root.getOperand(2).getImm();
if (Root.getOperand(3).isImm()) {
unsigned val = Root.getOperand(3).getImm();
Imm = Imm << val;
unsigned Val = Root.getOperand(3).getImm();
Imm = Imm << Val;
}
uint64_t UImm = -Imm << (64 - BitSize) >> (64 - BitSize);
uint64_t Encoding;
if (AArch64_AM::processLogicalImmediate(UImm, BitSize, Encoding)) {
MachineInstrBuilder MIB1 =
BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc))
.addOperand(MachineOperand::CreateReg(NewVR, true))
BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc), NewVR)
.addReg(ZeroReg)
.addImm(Encoding);
InsInstrs.push_back(MIB1);
InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR);
MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC);
}
} break;
break;
}
} // end switch (Pattern)
// Record MUL and ADD/SUB for deletion
DelInstrs.push_back(MUL);
DelInstrs.push_back(&Root);

View File

@ -0,0 +1,20 @@
; RUN: llc -mtriple=aarch64-apple-darwin -verify-machineinstrs < %s | FileCheck %s
; Test that we use the correct register class.
define i32 @mul_add_imm(i32 %a, i32 %b) {
; CHECK-LABEL: mul_add_imm
; CHECK: orr [[REG:w[0-9]+]], wzr, #0x4
; CHECK-NEXT: madd {{w[0-9]+}}, w0, w1, [[REG]]
%1 = mul i32 %a, %b
%2 = add i32 %1, 4
ret i32 %2
}
define i32 @mul_sub_imm1(i32 %a, i32 %b) {
; CHECK-LABEL: mul_sub_imm1
; CHECK: orr [[REG:w[0-9]+]], wzr, #0x4
; CHECK-NEXT: msub {{w[0-9]+}}, w0, w1, [[REG]]
%1 = mul i32 %a, %b
%2 = sub i32 4, %1
ret i32 %2
}