PTX: Fix predicate logic bug
Code such as:

    %vreg100 = setcc %vreg10, -1, SETNE
    brcond %vreg10, %tgt

was being incorrectly morphed into:

    %vreg100 = and %vreg10, 1
    brcond %vreg10, %tgt

where the 'and' instruction could be eliminated, since such logic is on 1-bit types in the PTX back-end, leaving us with just:

    brcond %vreg10, %tgt

which essentially gives us inverted branch conditions.

llvm-svn: 153364
This commit is contained in:
parent
86027e954c
commit
a84577dcff
|
@ -97,7 +97,8 @@ PTXTargetLowering::PTXTargetLowering(TargetMachine &TM)
|
||||||
|
|
||||||
// customise setcc to use bitwise logic if possible
|
// customise setcc to use bitwise logic if possible
|
||||||
|
|
||||||
setOperationAction(ISD::SETCC, MVT::i1, Custom);
|
//setOperationAction(ISD::SETCC, MVT::i1, Custom);
|
||||||
|
setOperationAction(ISD::SETCC, MVT::i1, Legal);
|
||||||
|
|
||||||
// customize translation of memory addresses
|
// customize translation of memory addresses
|
||||||
|
|
||||||
|
@ -156,18 +157,27 @@ SDValue PTXTargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
|
||||||
SDValue Op1 = Op.getOperand(1);
|
SDValue Op1 = Op.getOperand(1);
|
||||||
SDValue Op2 = Op.getOperand(2);
|
SDValue Op2 = Op.getOperand(2);
|
||||||
DebugLoc dl = Op.getDebugLoc();
|
DebugLoc dl = Op.getDebugLoc();
|
||||||
ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
|
//ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
|
||||||
|
|
||||||
// Look for X == 0, X == 1, X != 0, or X != 1
|
// Look for X == 0, X == 1, X != 0, or X != 1
|
||||||
// We can simplify these to bitwise logic
|
// We can simplify these to bitwise logic
|
||||||
|
|
||||||
if (Op1.getOpcode() == ISD::Constant &&
|
//if (Op1.getOpcode() == ISD::Constant &&
|
||||||
(cast<ConstantSDNode>(Op1)->getZExtValue() == 1 ||
|
// (cast<ConstantSDNode>(Op1)->getZExtValue() == 1 ||
|
||||||
cast<ConstantSDNode>(Op1)->isNullValue()) &&
|
// cast<ConstantSDNode>(Op1)->isNullValue()) &&
|
||||||
(CC == ISD::SETEQ || CC == ISD::SETNE)) {
|
// (CC == ISD::SETEQ || CC == ISD::SETNE)) {
|
||||||
|
//
|
||||||
|
// return DAG.getNode(ISD::AND, dl, MVT::i1, Op0, Op1);
|
||||||
|
//}
|
||||||
|
|
||||||
return DAG.getNode(ISD::AND, dl, MVT::i1, Op0, Op1);
|
//ConstantSDNode* COp1 = cast<ConstantSDNode>(Op1);
|
||||||
}
|
//if(COp1 && COp1->getZExtValue() == 1) {
|
||||||
|
// if(CC == ISD::SETNE) {
|
||||||
|
// return DAG.getNode(PTX::XORripreds, dl, MVT::i1, Op0);
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
|
||||||
|
llvm_unreachable("setcc was not matched by a pattern!");
|
||||||
|
|
||||||
return DAG.getNode(ISD::SETCC, dl, MVT::i1, Op0, Op1, Op2);
|
return DAG.getNode(ISD::SETCC, dl, MVT::i1, Op0, Op1, Op2);
|
||||||
}
|
}
|
||||||
|
@ -384,22 +394,22 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
|
||||||
PTXMachineFunctionInfo *PTXMFI = MF.getInfo<PTXMachineFunctionInfo>();
|
PTXMachineFunctionInfo *PTXMFI = MF.getInfo<PTXMachineFunctionInfo>();
|
||||||
PTXParamManager &PM = PTXMFI->getParamManager();
|
PTXParamManager &PM = PTXMFI->getParamManager();
|
||||||
MachineFrameInfo *MFI = MF.getFrameInfo();
|
MachineFrameInfo *MFI = MF.getFrameInfo();
|
||||||
|
|
||||||
assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() &&
|
assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() &&
|
||||||
"Calls are not handled for the target device");
|
"Calls are not handled for the target device");
|
||||||
|
|
||||||
// Identify the callee function
|
// Identify the callee function
|
||||||
const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal();
|
const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal();
|
||||||
const Function *function = cast<Function>(GV);
|
const Function *function = cast<Function>(GV);
|
||||||
|
|
||||||
// allow non-device calls only for printf
|
// allow non-device calls only for printf
|
||||||
bool isPrintf = function->getName() == "printf" || function->getName() == "puts";
|
bool isPrintf = function->getName() == "printf" || function->getName() == "puts";
|
||||||
|
|
||||||
assert((isPrintf || function->getCallingConv() == CallingConv::PTX_Device) &&
|
assert((isPrintf || function->getCallingConv() == CallingConv::PTX_Device) &&
|
||||||
"PTX function calls must be to PTX device functions");
|
"PTX function calls must be to PTX device functions");
|
||||||
|
|
||||||
unsigned outSize = isPrintf ? 2 : Outs.size();
|
unsigned outSize = isPrintf ? 2 : Outs.size();
|
||||||
|
|
||||||
std::vector<SDValue> Ops;
|
std::vector<SDValue> Ops;
|
||||||
// The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs]
|
// The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs]
|
||||||
Ops.resize(outSize + Ins.size() + 4);
|
Ops.resize(outSize + Ins.size() + 4);
|
||||||
|
@ -412,7 +422,7 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
|
||||||
|
|
||||||
// #Outs
|
// #Outs
|
||||||
Ops[Ins.size()+3] = DAG.getTargetConstant(outSize, MVT::i32);
|
Ops[Ins.size()+3] = DAG.getTargetConstant(outSize, MVT::i32);
|
||||||
|
|
||||||
if (isPrintf) {
|
if (isPrintf) {
|
||||||
// first argument is the address of the global string variable in memory
|
// first argument is the address of the global string variable in memory
|
||||||
unsigned Param0 = PM.addLocalParam(getPointerTy().getSizeInBits());
|
unsigned Param0 = PM.addLocalParam(getPointerTy().getSizeInBits());
|
||||||
|
@ -421,29 +431,29 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
|
||||||
Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
|
Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
|
||||||
ParamValue0, OutVals[0]);
|
ParamValue0, OutVals[0]);
|
||||||
Ops[Ins.size()+4] = ParamValue0;
|
Ops[Ins.size()+4] = ParamValue0;
|
||||||
|
|
||||||
// alignment is the maximum size of all the arguments
|
// alignment is the maximum size of all the arguments
|
||||||
unsigned alignment = 0;
|
unsigned alignment = 0;
|
||||||
for (unsigned i = 1; i < OutVals.size(); ++i) {
|
for (unsigned i = 1; i < OutVals.size(); ++i) {
|
||||||
alignment = std::max(alignment,
|
alignment = std::max(alignment,
|
||||||
OutVals[i].getValueType().getSizeInBits());
|
OutVals[i].getValueType().getSizeInBits());
|
||||||
}
|
}
|
||||||
|
|
||||||
// size is the alignment multiplied by the number of arguments
|
// size is the alignment multiplied by the number of arguments
|
||||||
unsigned size = alignment * (OutVals.size() - 1);
|
unsigned size = alignment * (OutVals.size() - 1);
|
||||||
|
|
||||||
// second argument is the address of the stack object (unless no arguments)
|
// second argument is the address of the stack object (unless no arguments)
|
||||||
unsigned Param1 = PM.addLocalParam(getPointerTy().getSizeInBits());
|
unsigned Param1 = PM.addLocalParam(getPointerTy().getSizeInBits());
|
||||||
SDValue ParamValue1 = DAG.getTargetExternalSymbol(PM.getParamName(Param1).c_str(),
|
SDValue ParamValue1 = DAG.getTargetExternalSymbol(PM.getParamName(Param1).c_str(),
|
||||||
MVT::Other);
|
MVT::Other);
|
||||||
Ops[Ins.size()+5] = ParamValue1;
|
Ops[Ins.size()+5] = ParamValue1;
|
||||||
|
|
||||||
if (size > 0)
|
if (size > 0)
|
||||||
{
|
{
|
||||||
// create a local stack object to store the arguments
|
// create a local stack object to store the arguments
|
||||||
unsigned StackObject = MFI->CreateStackObject(size / 8, alignment / 8, false);
|
unsigned StackObject = MFI->CreateStackObject(size / 8, alignment / 8, false);
|
||||||
SDValue FrameIndex = DAG.getFrameIndex(StackObject, getPointerTy());
|
SDValue FrameIndex = DAG.getFrameIndex(StackObject, getPointerTy());
|
||||||
|
|
||||||
// store each of the arguments to the stack in turn
|
// store each of the arguments to the stack in turn
|
||||||
for (unsigned int i = 1; i != OutVals.size(); i++) {
|
for (unsigned int i = 1; i != OutVals.size(); i++) {
|
||||||
SDValue FrameAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), FrameIndex, DAG.getTargetConstant((i - 1) * 8, getPointerTy()));
|
SDValue FrameAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), FrameIndex, DAG.getTargetConstant((i - 1) * 8, getPointerTy()));
|
||||||
|
@ -475,7 +485,7 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
|
||||||
Ops[i+Ins.size()+4] = ParamValue;
|
Ops[i+Ins.size()+4] = ParamValue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<SDValue> InParams;
|
std::vector<SDValue> InParams;
|
||||||
|
|
||||||
// Generate list of .param variables to hold the return value(s).
|
// Generate list of .param variables to hold the return value(s).
|
||||||
|
|
|
@ -808,6 +808,8 @@ let isBranch = 1, isTerminator = 1, isBarrier = 1 in {
|
||||||
let isBranch = 1, isTerminator = 1 in {
|
let isBranch = 1, isTerminator = 1 in {
|
||||||
// FIXME: The pattern part is blank because I cannot (or do not yet know
|
// FIXME: The pattern part is blank because I cannot (or do not yet know
|
||||||
// how to) use the first operand of PredicateOperand (a RegPred register) here
|
// how to) use the first operand of PredicateOperand (a RegPred register) here
|
||||||
|
// When this is revisited, make sure to also look at LowerSETCC and try to
|
||||||
|
// fold it into negated predicates, if possible.
|
||||||
def BRAdp
|
def BRAdp
|
||||||
: InstPTX<(outs), (ins brtarget:$d), "bra\t$d",
|
: InstPTX<(outs), (ins brtarget:$d), "bra\t$d",
|
||||||
[/*(brcond pred:$_p, bb:$d)*/]>;
|
[/*(brcond pred:$_p, bb:$d)*/]>;
|
||||||
|
@ -1017,6 +1019,9 @@ def : Pat<(f64 (sint_to_fp RegI64:$a)), (CVTf64s64 RndDefault, RegI64:$a)>;
|
||||||
def : Pat<(f64 (fextend RegF32:$a)), (CVTf64f32 RegF32:$a)>;
|
def : Pat<(f64 (fextend RegF32:$a)), (CVTf64f32 RegF32:$a)>;
|
||||||
def : Pat<(f64 (bitconvert RegI64:$a)), (MOVf64i64 RegI64:$a)>;
|
def : Pat<(f64 (bitconvert RegI64:$a)), (MOVf64i64 RegI64:$a)>;
|
||||||
|
|
||||||
|
// setcc - predicate inversion for branch conditions
|
||||||
|
def : Pat<(i1 (setcc RegPred:$a, imm:$b, SETNE)),
|
||||||
|
(XORripreds RegPred:$a, imm:$b)>;
|
||||||
|
|
||||||
///===- Intrinsic Instructions --------------------------------------------===//
|
///===- Intrinsic Instructions --------------------------------------------===//
|
||||||
include "PTXIntrinsicInstrInfo.td"
|
include "PTXIntrinsicInstrInfo.td"
|
||||||
|
|
Loading…
Reference in New Issue