Fix problems in the sprintf optimizer

llvm-svn: 35754
This commit is contained in:
Chris Lattner 2007-04-07 21:17:51 +00:00
parent bed184cbcf
commit 08c0b8b3c8
1 changed files with 60 additions and 81 deletions

View File

@ -1276,8 +1276,7 @@ public:
if (CI->getNumOperands() != 3 && CI->getNumOperands() != 4) if (CI->getNumOperands() != 3 && CI->getNumOperands() != 4)
return false; return false;
// All the optimizations depend on the length of the second argument and the // All the optimizations depend on the format string.
// fact that it is a constant string array. Check that now
uint64_t FormatLen, FormatStartIdx; uint64_t FormatLen, FormatStartIdx;
ConstantArray *CA = 0; ConstantArray *CA = 0;
if (!GetConstantStringInfo(CI->getOperand(2), CA, FormatLen,FormatStartIdx)) if (!GetConstantStringInfo(CI->getOperand(2), CA, FormatLen,FormatStartIdx))
@ -1368,108 +1367,88 @@ public:
SPrintFOptimization() : LibCallOptimization("sprintf", SPrintFOptimization() : LibCallOptimization("sprintf",
"Number of 'sprintf' calls simplified") {} "Number of 'sprintf' calls simplified") {}
/// @brief Make sure that the "fprintf" function has the right prototype /// @brief Make sure that the "sprintf" function has the right prototype
virtual bool ValidateCalledFunction(const Function *f, SimplifyLibCalls &SLC){ virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){
// Just make sure this has at least 2 arguments const FunctionType *FT = F->getFunctionType();
return (f->getReturnType() == Type::Int32Ty && f->arg_size() >= 2); return FT->getNumParams() == 2 && // two fixed arguments.
FT->getParamType(1) == PointerType::get(Type::Int8Ty) &&
FT->getParamType(0) == FT->getParamType(1) &&
isa<IntegerType>(FT->getReturnType());
} }
/// @brief Perform the sprintf optimization. /// @brief Perform the sprintf optimization.
virtual bool OptimizeCall(CallInst *ci, SimplifyLibCalls &SLC) { virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) {
// If the call has more than 3 operands, we can't optimize it // If the call has more than 3 operands, we can't optimize it
if (ci->getNumOperands() > 4 || ci->getNumOperands() < 3) if (CI->getNumOperands() != 3 && CI->getNumOperands() != 4)
return false; return false;
// All the optimizations depend on the length of the second argument and the uint64_t FormatLen, FormatStartIdx;
// fact that it is a constant string array. Check that now ConstantArray *CA = 0;
uint64_t len, StartIdx; if (!GetConstantStringInfo(CI->getOperand(2), CA, FormatLen,FormatStartIdx))
ConstantArray* CA = 0;
if (!GetConstantStringInfo(ci->getOperand(2), CA, len, StartIdx))
return false; return false;
if (ci->getNumOperands() == 3) { if (CI->getNumOperands() == 3) {
if (len == 0) { if (!CA->isCString()) return false;
// If the length is 0, we just need to store a null byte
new StoreInst(ConstantInt::get(Type::Int8Ty,0),ci->getOperand(1),ci);
return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty,0));
}
// Make sure there's no % in the constant array // Make sure there's no % in the constant array
for (unsigned i = 0; i < len; ++i) { std::string S = CA->getAsString();
if (ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(i))) { for (unsigned i = FormatStartIdx, e = S.size(); i != e; ++i)
// Check for the null terminator if (S[i] == '%')
if (CI->getZExtValue() == '%') return false; // we found a format specifier
return false; // we found a %, can't optimize
} else {
return false; // initializer is not constant int, can't optimize
}
}
// Increment length because we want to copy the null byte too
len++;
// sprintf(str,fmt) -> llvm.memcpy(str,fmt,strlen(fmt),1) // sprintf(str,fmt) -> llvm.memcpy(str,fmt,strlen(fmt),1)
Value *args[4] = { Value *MemCpyArgs[] = {
ci->getOperand(1), CI->getOperand(1), CI->getOperand(2),
ci->getOperand(2), ConstantInt::get(SLC.getIntPtrType(), FormatLen+1), // Copy the nul byte
ConstantInt::get(SLC.getIntPtrType(),len),
ConstantInt::get(Type::Int32Ty, 1) ConstantInt::get(Type::Int32Ty, 1)
}; };
new CallInst(SLC.get_memcpy(), args, 4, "", ci); new CallInst(SLC.get_memcpy(), MemCpyArgs, 4, "", CI);
return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty,len)); return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), FormatLen));
} }
// The remaining optimizations require the format string to be length 2 // The remaining optimizations require the format string to be "%s" or "%c".
// "%s" or "%c". if (FormatLen != 2 ||
if (len != 2) cast<ConstantInt>(CA->getOperand(FormatStartIdx))->getZExtValue() !='%')
return false; return false;
// The first character has to be a %
if (ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(0)))
if (CI->getZExtValue() != '%')
return false;
// Get the second character and switch on its value // Get the second character and switch on its value
ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(1)); switch (cast<ConstantInt>(CA->getOperand(1))->getZExtValue()) {
switch (CI->getZExtValue()) { case 'c': {
// sprintf(dest,"%c",chr) -> store chr, dest
Value *V = CastInst::createTruncOrBitCast(CI->getOperand(3),
Type::Int8Ty, "char", CI);
new StoreInst(V, CI->getOperand(1), CI);
Value *Ptr = new GetElementPtrInst(CI->getOperand(1),
ConstantInt::get(Type::Int32Ty, 1),
CI->getOperand(1)->getName()+".end",
CI);
new StoreInst(ConstantInt::get(Type::Int8Ty,0), Ptr, CI);
return ReplaceCallWith(CI, ConstantInt::get(Type::Int32Ty, 1));
}
case 's': { case 's': {
// sprintf(dest,"%s",str) -> llvm.memcpy(dest, str, strlen(str)+1, 1) // sprintf(dest,"%s",str) -> llvm.memcpy(dest, str, strlen(str)+1, 1)
Value *Len = new CallInst(SLC.get_strlen(), Value *Len = new CallInst(SLC.get_strlen(),
CastToCStr(ci->getOperand(3), ci), CastToCStr(CI->getOperand(3), CI),
ci->getOperand(3)->getName()+".len", ci); CI->getOperand(3)->getName()+".len", CI);
Value *Len1 = BinaryOperator::createAdd(Len, Value *UnincLen = Len;
ConstantInt::get(Len->getType(), 1), Len = BinaryOperator::createAdd(Len, ConstantInt::get(Len->getType(), 1),
Len->getName()+"1", ci); Len->getName()+"1", CI);
if (Len1->getType() != SLC.getIntPtrType()) Value *MemcpyArgs[4] = {
Len1 = CastInst::createIntegerCast(Len1, SLC.getIntPtrType(), false, CI->getOperand(1),
Len1->getName(), ci); CastToCStr(CI->getOperand(3), CI),
Value *args[4] = { Len,
CastToCStr(ci->getOperand(1), ci), ConstantInt::get(Type::Int32Ty, 1)
CastToCStr(ci->getOperand(3), ci),
Len1,
ConstantInt::get(Type::Int32Ty,1)
}; };
new CallInst(SLC.get_memcpy(), args, 4, "", ci); new CallInst(SLC.get_memcpy(), MemcpyArgs, 4, "", CI);
// The strlen result is the unincremented number of bytes in the string. // The strlen result is the unincremented number of bytes in the string.
if (!ci->use_empty()) { if (!CI->use_empty()) {
if (Len->getType() != ci->getType()) if (UnincLen->getType() != CI->getType())
Len = CastInst::createIntegerCast(Len, ci->getType(), false, UnincLen = CastInst::createIntegerCast(UnincLen, CI->getType(), false,
Len->getName(), ci); Len->getName(), CI);
ci->replaceAllUsesWith(Len); CI->replaceAllUsesWith(UnincLen);
} }
return ReplaceCallWith(ci, 0); return ReplaceCallWith(CI, 0);
}
case 'c': {
// sprintf(dest,"%c",chr) -> store chr, dest
CastInst* cast = CastInst::createTruncOrBitCast(
ci->getOperand(3), Type::Int8Ty, "char", ci);
new StoreInst(cast, ci->getOperand(1), ci);
GetElementPtrInst* gep = new GetElementPtrInst(ci->getOperand(1),
ConstantInt::get(Type::Int32Ty,1),ci->getOperand(1)->getName()+".end",
ci);
new StoreInst(ConstantInt::get(Type::Int8Ty,0),gep,ci);
return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty, 1));
} }
} }
return false; return false;