diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp index 1ad610d80af1..68cf7310c759 100644 --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -26,16 +26,11 @@ #include "Support/Debug.h" #include "Support/StringExtras.h" #include -#include -#include +#include using namespace llvm; namespace { - inline bool contains(const std::vector &V, const BasicBlock *BB){ - return std::find(V.begin(), V.end(), BB) != V.end(); - } - /// getFunctionArg - Return a pointer to F's ARGNOth argument. /// Argument *getFunctionArg(Function *F, unsigned argno) { @@ -49,19 +44,16 @@ namespace { typedef std::vector > PhiValChangesTy; typedef std::map PhiVal2ArgTy; PhiVal2ArgTy PhiVal2Arg; - + std::set BlocksToExtract; public: Function *ExtractCodeRegion(const std::vector &code); private: - void findInputsOutputs(const std::vector &code, - Values &inputs, - Values &outputs, + void findInputsOutputs(Values &inputs, Values &outputs, BasicBlock *newHeader, BasicBlock *newRootNode); void processPhiNodeInputs(PHINode *Phi, - const std::vector &code, Values &inputs, BasicBlock *newHeader, BasicBlock *newRootNode); @@ -71,15 +63,12 @@ namespace { Function *constructFunction(const Values &inputs, const Values &outputs, BasicBlock *newRootNode, BasicBlock *newHeader, - const std::vector &code, Function *oldFunction, Module *M); - void moveCodeToFunction(const std::vector &code, - Function *newFunction); + void moveCodeToFunction(Function *newFunction); void emitCallAndSwitchStatement(Function *newFunction, BasicBlock *newHeader, - const std::vector &code, Values &inputs, Values &outputs); @@ -87,7 +76,6 @@ namespace { } void CodeExtractor::processPhiNodeInputs(PHINode *Phi, - const std::vector &code, Values &inputs, BasicBlock *codeReplacer, BasicBlock *newFuncRoot) @@ -102,11 +90,11 @@ void CodeExtractor::processPhiNodeInputs(PHINode *Phi, for (unsigned i = 0, e = Phi->getNumIncomingValues(); i != e; ++i) { Value *phiVal = Phi->getIncomingValue(i); if (Instruction *Inst = dyn_cast(phiVal)) { - if (contains(code, Inst->getParent())) { - if (!contains(code, Phi->getIncomingBlock(i))) + if (BlocksToExtract.count(Inst->getParent())) { + if (!BlocksToExtract.count(Phi->getIncomingBlock(i))) IValEBB.push_back(i); } else { - if (contains(code, Phi->getIncomingBlock(i))) + if (BlocksToExtract.count(Phi->getIncomingBlock(i))) EValIBB.push_back(i); else EValEBB.push_back(i); @@ -114,11 +102,11 @@ void CodeExtractor::processPhiNodeInputs(PHINode *Phi, } else if (Constant *Const = dyn_cast(phiVal)) { // Constants are internal, but considered `external' if they are coming // from an external block. - if (!contains(code, Phi->getIncomingBlock(i))) + if (!BlocksToExtract.count(Phi->getIncomingBlock(i))) EValEBB.push_back(i); } else if (Argument *Arg = dyn_cast(phiVal)) { // arguments are external - if (contains(code, Phi->getIncomingBlock(i))) + if (BlocksToExtract.count(Phi->getIncomingBlock(i))) EValIBB.push_back(i); else EValEBB.push_back(i); @@ -184,14 +172,13 @@ void CodeExtractor::processPhiNodeInputs(PHINode *Phi, } -void CodeExtractor::findInputsOutputs(const std::vector &code, - Values &inputs, +void CodeExtractor::findInputsOutputs(Values &inputs, Values &outputs, BasicBlock *newHeader, BasicBlock *newRootNode) { - for (std::vector::const_iterator ci = code.begin(), - ce = code.end(); ci != ce; ++ci) { + for (std::set::const_iterator ci = BlocksToExtract.begin(), + ce = BlocksToExtract.end(); ci != ce; ++ci) { BasicBlock *BB = *ci; for (BasicBlock::iterator BBi = BB->begin(), BBe = BB->end(); BBi != BBe; ++BBi) { @@ -200,7 +187,7 @@ void CodeExtractor::findInputsOutputs(const std::vector &code, if (Instruction *I = dyn_cast(&*BBi)) { // If it's a phi node if (PHINode *Phi = dyn_cast(I)) { - processPhiNodeInputs(Phi, code, inputs, newHeader, newRootNode); + processPhiNodeInputs(Phi, inputs, newHeader, newRootNode); } else { // All other instructions go through the generic input finder // Loop over the operands of each instruction (inputs) @@ -208,7 +195,7 @@ void CodeExtractor::findInputsOutputs(const std::vector &code, op != opE; ++op) { if (Instruction *opI = dyn_cast(op->get())) { // Check if definition of this operand is within the loop - if (!contains(code, opI->getParent())) { + if (!BlocksToExtract.count(opI->getParent())) { // add this operand to the inputs inputs.push_back(opI); } @@ -220,7 +207,7 @@ void CodeExtractor::findInputsOutputs(const std::vector &code, for (Value::use_iterator use = I->use_begin(), useE = I->use_end(); use != useE; ++use) { if (Instruction* inst = dyn_cast(*use)) { - if (!contains(code, inst->getParent())) { + if (!BlocksToExtract.count(inst->getParent())) { // add this op to the outputs outputs.push_back(I); } @@ -276,11 +263,10 @@ Function *CodeExtractor::constructFunction(const Values &inputs, const Values &outputs, BasicBlock *newRootNode, BasicBlock *newHeader, - const std::vector &code, Function *oldFunction, Module *M) { DEBUG(std::cerr << "inputs: " << inputs.size() << "\n"); DEBUG(std::cerr << "outputs: " << outputs.size() << "\n"); - BasicBlock *header = code[0]; + BasicBlock *header = *BlocksToExtract.begin(); // This function returns unsigned, outputs will go back by reference. Type *retTy = Type::UShortTy; @@ -327,7 +313,7 @@ Function *CodeExtractor::constructFunction(const Values &inputs, for (std::vector::iterator use = Users.begin(), useE = Users.end(); use != useE; ++use) if (Instruction* inst = dyn_cast(*use)) - if (contains(code, inst->getParent())) + if (BlocksToExtract.count(inst->getParent())) inst->replaceUsesOfWith(inputs[i], getFunctionArg(newFunction, i)); } @@ -339,7 +325,7 @@ Function *CodeExtractor::constructFunction(const Values &inputs, i != e; ++i) { if (BranchInst *inst = dyn_cast(*i)) { BasicBlock *BB = inst->getParent(); - if (!contains(code, BB) && BB->getParent() == oldFunction) { + if (!BlocksToExtract.count(BB) && BB->getParent() == oldFunction) { // The BasicBlock which contains the branch is not in the region // modify the branch target to a new block inst->replaceUsesOfWith(header, newHeader); @@ -350,29 +336,25 @@ Function *CodeExtractor::constructFunction(const Values &inputs, return newFunction; } -void CodeExtractor::moveCodeToFunction(const std::vector &code, - Function *newFunction) +void CodeExtractor::moveCodeToFunction(Function *newFunction) { - Function *oldFunc = code[0]->getParent(); + Function *oldFunc = (*BlocksToExtract.begin())->getParent(); Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList(); Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList(); - for (std::vector::const_iterator i = code.begin(), e =code.end(); - i != e; ++i) { - BasicBlock *BB = *i; - + for (std::set::const_iterator i = BlocksToExtract.begin(), + e = BlocksToExtract.end(); i != e; ++i) { // Delete the basic block from the old function, and the list of blocks - oldBlocks.remove(BB); + oldBlocks.remove(*i); // Insert this basic block into the new function - newBlocks.push_back(BB); + newBlocks.push_back(*i); } } void CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, - const std::vector &code, Values &inputs, Values &outputs) { @@ -399,7 +381,7 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, for (std::vector::iterator use = Users.begin(), useE =Users.end(); use != useE; ++use) { if (Instruction* inst = dyn_cast(*use)) { - if (!contains(code, inst->getParent())) { + if (!BlocksToExtract.count(inst->getParent())) { inst->replaceUsesOfWith(*i, load); } } @@ -425,8 +407,8 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, // Since there may be multiple exits from the original region, make the new // function return an unsigned, switch on that number unsigned switchVal = 0; - for (std::vector::const_iterator i =code.begin(), e = code.end(); - i != e; ++i) { + for (std::set::const_iterator i = BlocksToExtract.begin(), + e = BlocksToExtract.end(); i != e; ++i) { BasicBlock *BB = *i; // rewrite the terminator of the original BasicBlock @@ -436,16 +418,14 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, // Restore values just before we exit // FIXME: Use a GetElementPtr to bunch the outputs in a struct for (unsigned outIdx = 0, outE = outputs.size(); outIdx != outE; ++outIdx) - { new StoreInst(outputs[outIdx], getFunctionArg(newFunction, outIdx), brInst); - } // Rewrite branches into exits which return a value based on which // exit we take from this function if (brInst->isUnconditional()) { - if (!contains(code, brInst->getSuccessor(0))) { + if (!BlocksToExtract.count(brInst->getSuccessor(0))) { ConstantUInt *brVal = ConstantUInt::get(Type::UShortTy, switchVal); ReturnInst *newRet = new ReturnInst(brVal); // add a new target to the switch @@ -461,7 +441,7 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, // to two new blocks, each of which returns a different code. for (unsigned idx = 0; idx < 2; ++idx) { BasicBlock *oldTarget = brInst->getSuccessor(idx); - if (!contains(code, oldTarget)) { + if (!BlocksToExtract.count(oldTarget)) { // add a new basic block which returns the appropriate value BasicBlock *newTarget = new BasicBlock("newTarget", newFunction); ConstantUInt *brVal = ConstantUInt::get(Type::UShortTy, switchVal); @@ -475,13 +455,15 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, } } } + } else if (SwitchInst *swTerm = dyn_cast(term)) { + + assert(0 && "Cannot handle switch instructions just yet."); + } else if (ReturnInst *retTerm = dyn_cast(term)) { assert(0 && "Cannot handle return instructions just yet."); // FIXME: what if the terminator is a return!??! // Need to rewrite: add new basic block, move the return there // treat the original as an unconditional branch to that basicblock - } else if (SwitchInst *swTerm = dyn_cast(term)) { - assert(0 && "Cannot handle switch instructions just yet."); } else if (InvokeInst *invInst = dyn_cast(term)) { assert(0 && "Cannot handle invoke instructions just yet."); } else { @@ -514,7 +496,8 @@ Function *CodeExtractor::ExtractCodeRegion(const std::vector &code) // * Add allocas for defs, pass as args by reference // * Pass in uses as args // 3) Move code region, add call instr to func - // + // + BlocksToExtract.insert(code.begin(), code.end()); Values inputs, outputs; @@ -548,19 +531,18 @@ Function *CodeExtractor::ExtractCodeRegion(const std::vector &code) // blocks moving to a new function. // SOLUTION: move Phi nodes out of the loop header into the codeReplacer, pass // the values as parameters to the function - findInputsOutputs(code, inputs, outputs, codeReplacer, newFuncRoot); + findInputsOutputs(inputs, outputs, codeReplacer, newFuncRoot); // Step 2: Construct new function based on inputs/outputs, // Add allocas for all defs Function *newFunction = constructFunction(inputs, outputs, newFuncRoot, - codeReplacer, code, - oldFunction, module); + codeReplacer, oldFunction, module); rewritePhiNodes(newFunction, newFuncRoot); - emitCallAndSwitchStatement(newFunction, codeReplacer, code, inputs, outputs); + emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs); - moveCodeToFunction(code, newFunction); + moveCodeToFunction(newFunction); DEBUG(if (verifyFunction(*newFunction)) abort()); return newFunction;