* Add code to reduce multiplies by constant integers to shifts, adds and

subtracts. This is a very rough and nasty implementation of Lefevre's
  "pattern finding" algorithm. With a few small changes though, it should
  end up beating most other methods in common use, regardless of the size
  of the constant (currently, it's often one or two shifts worse)

  TODO: rewrite it so it's not hideously ugly (this is a translation from
        perl, which doesn't help ;)
        bypass most of it for multiplies by 2^n+1
	(eventually) teach it that some combinations of shift+add are
	cheaper than others (e.g. shladd on ia64, scaled adds on alpha)
	get it to try multiple booth encodings in search of the cheapest
	routine
	make it work for negative constants

  This is hacked up as a DAG->DAG transform, so once I clean it up I hope
  it'll be pulled out of here and put somewhere else. The only thing backends
  should really have to worry about for now is where to draw the line
  between using this code vs. going ahead and doing an integer multiply
  anyway.

llvm-svn: 21560
This commit is contained in:
Duraid Madina 2005-04-26 07:23:02 +00:00
parent 76dab9a523
commit 81ebb57771
1 changed files with 439 additions and 15 deletions

View File

@ -28,6 +28,7 @@
#include "llvm/Support/MathExtras.h"
#include "llvm/ADT/Statistic.h"
#include <set>
#include <map>
#include <algorithm>
using namespace llvm;
@ -412,6 +413,9 @@ namespace {
/// IA64Lowering - This object fully describes how to lower LLVM code to an
/// IA64-specific SelectionDAG.
IA64TargetLowering IA64Lowering;
SelectionDAG *ISelDAG; // Hack to support us having a dag->dag transform
// for sdiv and udiv until it is put into the future
// dag combiner
/// ExprMap - As shared expressions are codegen'd, we keep track of which
/// vreg the value is produced in, so we only emit one copy of each compiled
@ -420,8 +424,8 @@ namespace {
std::set<SDOperand> LoweredTokens;
public:
ISel(TargetMachine &TM) : SelectionDAGISel(IA64Lowering), IA64Lowering(TM) {
}
ISel(TargetMachine &TM) : SelectionDAGISel(IA64Lowering), IA64Lowering(TM),
ISelDAG(0) { }
/// InstructionSelectBasicBlock - This callback is invoked by
/// SelectionDAGISel when it has created a SelectionDAG for us to codegen.
@ -429,6 +433,9 @@ namespace {
unsigned SelectExpr(SDOperand N);
void Select(SDOperand N);
// a dag->dag to transform mul-by-constant-int to shifts+adds/subs
SDOperand BuildConstmulSequence(SDOperand N);
};
}
@ -437,11 +444,419 @@ namespace {
void ISel::InstructionSelectBasicBlock(SelectionDAG &DAG) {
// Codegen the basic block.
ISelDAG = &DAG;
Select(DAG.getRoot());
// Clear state used for selection.
ExprMap.clear();
LoweredTokens.clear();
ISelDAG = 0;
}
const char sign[2]={'+','-'};
// strip leading '0' characters from a string
void munchLeadingZeros(std::string& inString) {
while(inString.c_str()[0]=='0') {
inString.erase(0, 1);
}
}
// strip trailing '0' characters from a string
void munchTrailingZeros(std::string& inString) {
int curPos=inString.length()-1;
while(inString.c_str()[curPos]=='0') {
inString.erase(curPos, 1);
curPos--;
}
}
// return how many consecutive '0' characters are at the end of a string
unsigned int countTrailingZeros(std::string& inString) {
int curPos=inString.length()-1;
unsigned int zeroCount=0;
// assert goes here
while(inString.c_str()[curPos--]=='0') {
zeroCount++;
}
return zeroCount;
}
// booth encode a string of '1' and '0' characters (returns string of 'P' (+1)
// '0' and 'N' (-1) characters)
void boothEncode(std::string inString, std::string& boothEncodedString) {
int curpos=0;
int replacements=0;
int lim=inString.size();
while(curpos<lim) {
if(inString[curpos]=='1') { // if we see a '1', look for a run of them
int runlength=0;
std::string replaceString="N";
// find the run length
for(;inString[curpos+runlength]=='1';runlength++) ;
for(int i=0; i<runlength-1; i++)
replaceString+="0";
replaceString+="1";
if(runlength>1) {
inString.replace(curpos, runlength+1, replaceString);
curpos+=runlength-1;
} else
curpos++;
} else { // a zero, we just keep chugging along
curpos++;
}
}
// clean up (trim the string, reverse it and turn '1's into 'P's)
munchTrailingZeros(inString);
boothEncodedString="";
for(int i=inString.size()-1;i>=0;i--)
if(inString[i]=='1')
boothEncodedString+="P";
else
boothEncodedString+=inString[i];
}
struct shiftaddblob { // this encodes stuff like (x=) "A << B [+-] C << D"
unsigned firstVal; // A
unsigned firstShift; // B
unsigned secondVal; // C
unsigned secondShift; // D
bool isSub;
};
/* this implements Lefevre's "pattern-based" constant multiplication,
* see "Multiplication by an Integer Constant", INRIA report 1999-06
*
* TODO: implement a method to try rewriting P0N<->0PP / N0P<->0NN
* to get better booth encodings - this does help in practice
* TODO: weight shifts appropriately (most architectures can't
* fuse a shift and an add for arbitrary shift amounts) */
unsigned lefevre(const std::string inString,
std::vector<struct shiftaddblob> &ops) {
std::string retstring;
std::string s = inString;
munchTrailingZeros(s);
int length=s.length()-1;
if(length==0) {
return(0);
}
std::vector<int> p,n;
for(int i=0; i<=length; i++) {
if (s.c_str()[length-i]=='P') {
p.push_back(i);
} else if (s.c_str()[length-i]=='N') {
n.push_back(i);
}
}
std::string t, u;
int c,f;
std::map<const int, int> w;
for(int i=0; i<p.size(); i++) {
for(int j=0; j<i; j++) {
w[p[i]-p[j]]++;
}
}
for(int i=1; i<n.size(); i++) {
for(int j=0; j<i; j++) {
w[n[i]-n[j]]++;
}
}
for(int i=0; i<p.size(); i++) {
for(int j=0; j<n.size(); j++) {
w[-abs(p[i]-n[j])]++;
}
}
std::map<const int, int>::const_iterator ii;
std::vector<int> d;
std::multimap<int, int> sorted_by_value;
for(ii = w.begin(); ii!=w.end(); ii++)
sorted_by_value.insert(std::pair<int, int>((*ii).second,(*ii).first));
for (std::multimap<int, int>::iterator it = sorted_by_value.begin();
it != sorted_by_value.end(); ++it) {
d.push_back((*it).second);
}
int int_W=0;
int int_d;
while(d.size()>0 && (w[int_d=d.back()] > int_W)) {
d.pop_back();
retstring=s; // hmmm
int x=0;
int z=abs(int_d)-1;
if(int_d>0) {
for(int base=0; base<retstring.size(); base++) {
if( ((base+z+1) < retstring.size()) &&
retstring.c_str()[base]=='P' &&
retstring.c_str()[base+z+1]=='P')
{
// match
x++;
retstring.replace(base, 1, "0");
retstring.replace(base+z+1, 1, "p");
}
}
for(int base=0; base<retstring.size(); base++) {
if( ((base+z+1) < retstring.size()) &&
retstring.c_str()[base]=='N' &&
retstring.c_str()[base+z+1]=='N')
{
// match
x++;
retstring.replace(base, 1, "0");
retstring.replace(base+z+1, 1, "n");
}
}
} else {
for(int base=0; base<retstring.size(); base++) {
if( ((base+z+1) < retstring.size()) &&
((retstring.c_str()[base]=='P' &&
retstring.c_str()[base+z+1]=='N') ||
(retstring.c_str()[base]=='N' &&
retstring.c_str()[base+z+1]=='P')) ) {
// match
x++;
if(retstring.c_str()[base]=='P') {
retstring.replace(base, 1, "0");
retstring.replace(base+z+1, 1, "p");
} else { // retstring[base]=='N'
retstring.replace(base, 1, "0");
retstring.replace(base+z+1, 1, "n");
}
}
}
}
if(x>int_W) {
int_W = x;
t = retstring;
c = int_d; // tofix
}
} d.pop_back(); // hmm
u = t;
for(int i=0; i<t.length(); i++) {
if(t.c_str()[i]=='p' || t.c_str()[i]=='n')
t.replace(i, 1, "0");
}
/* and now for something completely different:
//\\\\\\ ` \`..(@)
: \\\\ (@)(@) / /(@)
\ ~L~ )\\\\ \ \ '\(__``',..
/\_~ / |||| \ , | ~~~--/
////| |//// // /
||||^ ~~~~~~--------~/ /
/ ( ( _____---~~~~~\| /
( )| / / / /
\^\ \____/ / \ /
\ \/ \ \ /
)) ) ' ~
| | ` ,
|______|
| || | ,
( || |
\ | | ,
\| /
/_^||
*/
for(int i=0; i<u.length(); i++) {
if(u.c_str()[i]=='P' || u.c_str()[i]=='N')
u.replace(i, 1, "0");
if(u.c_str()[i]=='p')
u.replace(i, 1, "P");
if(u.c_str()[i]=='n')
u.replace(i, 1, "N");
}
if( c<0 ) {
f=1;
c=-c;
} else
f=0;
bool hit=true;
for(int i=0; i<u.length()-1; i++) {
if(u.c_str()[i]!='0')
hit=false;
}
if(u.c_str()[u.length()-1]!='N')
hit=false;
int g=0;
if(hit) {
g=1;
for(int p=0; p<u.length(); p++) {
bool isP=(u.c_str()[p]=='P');
bool isN=(u.c_str()[p]=='N');
if(isP)
u.replace(p, 1, "N");
if(isN)
u.replace(p, 1, "P");
}
}
munchLeadingZeros(u);
int i = lefevre(u, ops);
shiftaddblob blob;
blob.firstVal=i; blob.firstShift=c;
blob.isSub=f;
blob.secondVal=i; blob.secondShift=0;
ops.push_back(blob);
i = ops.size();
munchLeadingZeros(t);
if(t.length()==0)
return i;
if(t.c_str()[0]!='P') {
g=2;
for(int p=0; p<t.length(); p++) {
bool isP=(t.c_str()[p]=='P');
bool isN=(t.c_str()[p]=='N');
if(isP)
t.replace(p, 1, "N");
if(isN)
t.replace(p, 1, "P");
}
}
int j = lefevre(t, ops);
int trail=countTrailingZeros(u);
blob.secondVal=i; blob.secondShift=trail;
trail=countTrailingZeros(t);
blob.firstVal=j; blob.firstShift=trail;
switch(g) {
case 0:
blob.isSub=false; // first + second
break;
case 1:
blob.isSub=true; // first - second
break;
case 2:
blob.isSub=true; // second - first
int tmpval, tmpshift;
tmpval=blob.firstVal;
tmpshift=blob.firstShift;
blob.firstVal=blob.secondVal;
blob.firstShift=blob.secondShift;
blob.secondVal=tmpval;
blob.secondShift=tmpshift;
break;
//assert
}
ops.push_back(blob);
return ops.size();
}
SDOperand ISel::BuildConstmulSequence(SDOperand N) {
//FIXME: we should shortcut this stuff for multiplies by 2^n+1
// in particular, *3 is nicer as *2+1, not *4-1
int64_t constant=cast<ConstantSDNode>(N.getOperand(1))->getValue();
bool flippedSign;
unsigned preliminaryShift=0;
assert(constant > 0 && "erk, don't multiply by zero or negative nums\n");
// first, we make the constant to multiply by positive
if(constant<0) {
constant=-constant;
flippedSign=true;
} else {
flippedSign=false;
}
// next, we make it odd.
for(; (constant%2==0); preliminaryShift++)
constant>>=1;
//OK, we have a positive, odd number of 64 bits or less. Convert it
//to a binary string, constantString[0] is the LSB
char constantString[65];
for(int i=0; i<64; i++)
constantString[i]='0'+((constant>>i)&0x1);
constantString[64]=0;
// now, Booth encode it
std::string boothEncodedString;
boothEncode(constantString, boothEncodedString);
std::vector<struct shiftaddblob> ops;
// do the transformation, filling out 'ops'
lefevre(boothEncodedString, ops);
SDOperand results[ops.size()]; // temporary results (of adds/subs of shifts)
// now turn 'ops' into DAG bits
for(int i=0; i<ops.size(); i++) {
SDOperand amt = ISelDAG->getConstant(ops[i].firstShift, MVT::i64);
SDOperand val = (ops[i].firstVal == 0) ? N.getOperand(0) :
results[ops[i].firstVal-1];
SDOperand left = ISelDAG->getNode(ISD::SHL, MVT::i64, val, amt);
amt = ISelDAG->getConstant(ops[i].secondShift, MVT::i64);
val = (ops[i].secondVal == 0) ? N.getOperand(0) :
results[ops[i].secondVal-1];
SDOperand right = ISelDAG->getNode(ISD::SHL, MVT::i64, val, amt);
if(ops[i].isSub)
results[i] = ISelDAG->getNode(ISD::SUB, MVT::i64, left, right);
else
results[i] = ISelDAG->getNode(ISD::ADD, MVT::i64, left, right);
}
// don't forget flippedSign and preliminaryShift!
SDOperand finalresult;
if(preliminaryShift) {
SDOperand finalshift = ISelDAG->getConstant(preliminaryShift, MVT::i64);
finalresult = ISelDAG->getNode(ISD::SHL, MVT::i64,
results[ops.size()-1], finalshift);
} else { // there was no preliminary divide-by-power-of-2 required
finalresult = results[ops.size()-1];
}
return finalresult;
}
/// ExactLog2 - This function solves for (Val == 1 << (N-1)) and returns N. It
@ -909,23 +1324,32 @@ assert(0 && "hmm, ISD::SIGN_EXTEND: shouldn't ever be reached. bad luck!\n");
}
case ISD::MUL: {
Tmp1 = SelectExpr(N.getOperand(0));
Tmp2 = SelectExpr(N.getOperand(1));
if(DestType != MVT::f64) { // TODO: speed!
// boring old integer multiply with xma
unsigned TempFR1=MakeReg(MVT::f64);
unsigned TempFR2=MakeReg(MVT::f64);
unsigned TempFR3=MakeReg(MVT::f64);
BuildMI(BB, IA64::SETFSIG, 1, TempFR1).addReg(Tmp1);
BuildMI(BB, IA64::SETFSIG, 1, TempFR2).addReg(Tmp2);
BuildMI(BB, IA64::XMAL, 1, TempFR3).addReg(TempFR1).addReg(TempFR2)
.addReg(IA64::F0);
BuildMI(BB, IA64::GETFSIG, 1, Result).addReg(TempFR3);
if(N.getOperand(1).getOpcode() != ISD::Constant) { // if not a const mul
// boring old integer multiply with xma
Tmp1 = SelectExpr(N.getOperand(0));
Tmp2 = SelectExpr(N.getOperand(1));
unsigned TempFR1=MakeReg(MVT::f64);
unsigned TempFR2=MakeReg(MVT::f64);
unsigned TempFR3=MakeReg(MVT::f64);
BuildMI(BB, IA64::SETFSIG, 1, TempFR1).addReg(Tmp1);
BuildMI(BB, IA64::SETFSIG, 1, TempFR2).addReg(Tmp2);
BuildMI(BB, IA64::XMAL, 1, TempFR3).addReg(TempFR1).addReg(TempFR2)
.addReg(IA64::F0);
BuildMI(BB, IA64::GETFSIG, 1, Result).addReg(TempFR3);
return Result; // early exit
} else { // we are multiplying by an integer constant! yay
return Reg = SelectExpr(BuildConstmulSequence(N)); // avert your eyes!
}
}
else // floating point multiply
else { // floating point multiply
Tmp1 = SelectExpr(N.getOperand(0));
Tmp2 = SelectExpr(N.getOperand(1));
BuildMI(BB, IA64::FMPY, 2, Result).addReg(Tmp1).addReg(Tmp2);
return Result;
return Result;
}
}
case ISD::SUB: {