diff --git a/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp index b130689081d5..524b0104f26f 100644 --- a/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -69,7 +69,8 @@ namespace { void VersionLoop(Value *LIC, Loop *L, Loop *&Out1, Loop *&Out2); BasicBlock *SplitBlock(BasicBlock *BB, bool SplitAtTop); void RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, bool Val); - void UnswitchTrivialCondition(Loop *L, Value *Cond, ConstantBool *LoopCond); + void UnswitchTrivialCondition(Loop *L, Value *Cond, bool EntersLoopOnCond, + BasicBlock *ExitBlock); }; RegisterOpt X("loop-unswitch", "Unswitch loops"); } @@ -123,38 +124,41 @@ static bool LoopValuesUsedOutsideLoop(Loop *L) { /// If this is a trivial condition, return ConstantBool::True if the loop body /// runs when the condition is true, False if the loop body executes when the /// condition is false. Otherwise, return null to indicate a complex condition. -static ConstantBool *IsTrivialUnswitchCondition(Loop *L, Value *Cond) { +static bool IsTrivialUnswitchCondition(Loop *L, Value *Cond, + bool *CondEntersLoop = 0, + BasicBlock **LoopExit = 0) { BasicBlock *Header = L->getHeader(); BranchInst *HeaderTerm = dyn_cast(Header->getTerminator()); - ConstantBool *RetVal = 0; // If the header block doesn't end with a conditional branch on Cond, we can't // handle it. if (!HeaderTerm || !HeaderTerm->isConditional() || HeaderTerm->getCondition() != Cond) - return 0; + return false; // Check to see if the conditional branch goes to the latch block. If not, // it's not trivial. This also determines the value of Cond that will execute // the loop. BasicBlock *Latch = L->getLoopLatch(); - if (HeaderTerm->getSuccessor(1) == Latch) - RetVal = ConstantBool::True; - else if (HeaderTerm->getSuccessor(0) == Latch) - RetVal = ConstantBool::False; + if (HeaderTerm->getSuccessor(1) == Latch) { + if (CondEntersLoop) *CondEntersLoop = true; + } else if (HeaderTerm->getSuccessor(0) == Latch) + if (CondEntersLoop) *CondEntersLoop = false; else - return 0; // Doesn't branch to latch block. + return false; // Doesn't branch to latch block. // The latch block must end with a conditional branch where one edge goes to // the header (this much we know) and one edge goes OUT of the loop. BranchInst *LatchBranch = dyn_cast(Latch->getTerminator()); - if (!LatchBranch || !LatchBranch->isConditional()) return 0; + if (!LatchBranch || !LatchBranch->isConditional()) return false; if (LatchBranch->getSuccessor(0) == Header) { - if (L->contains(LatchBranch->getSuccessor(1))) return 0; + if (L->contains(LatchBranch->getSuccessor(1))) return false; + if (LoopExit) *LoopExit = LatchBranch->getSuccessor(1); } else { assert(LatchBranch->getSuccessor(1) == Header); - if (L->contains(LatchBranch->getSuccessor(0))) return 0; + if (L->contains(LatchBranch->getSuccessor(0))) return false; + if (LoopExit) *LoopExit = LatchBranch->getSuccessor(0); } // We already know that nothing uses any scalar values defined inside of this @@ -163,11 +167,11 @@ static ConstantBool *IsTrivialUnswitchCondition(Loop *L, Value *Cond) { // part of the loop that the code *would* execute. for (BasicBlock::iterator I = Header->begin(), E = Header->end(); I != E; ++I) if (I->mayWriteToMemory()) - return 0; + return false; for (BasicBlock::iterator I = Latch->begin(), E = Latch->end(); I != E; ++I) if (I->mayWriteToMemory()) - return 0; - return RetVal; + return false; + return true; } /// getLoopUnswitchCost - Return the cost (code size growth) that will happen if @@ -257,8 +261,12 @@ bool LoopUnswitch::visitLoop(Loop *L) { // If this is a trivial condition to unswitch (which results in no code // duplication), do it now. - if (ConstantBool *V = IsTrivialUnswitchCondition(L, BI->getCondition())) { - UnswitchTrivialCondition(L, BI->getCondition(), V); + bool EntersLoopOnCond; + BasicBlock *ExitBlock; + if (IsTrivialUnswitchCondition(L, BI->getCondition(), &EntersLoopOnCond, + &ExitBlock)) { + UnswitchTrivialCondition(L, BI->getCondition(), + EntersLoopOnCond, ExitBlock); NewLoop1 = L; } else { VersionLoop(BI->getCondition(), L, NewLoop1, NewLoop2); @@ -345,7 +353,8 @@ static Loop *CloneLoop(Loop *L, Loop *PL, std::map &VM, /// side-effects), unswitch it. This doesn't involve any code duplication, just /// moving the conditional branch outside of the loop and updating loop info. void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, - ConstantBool *LoopCond) { + bool EnterOnCond, + BasicBlock *ExitBlock) { DEBUG(std::cerr << "loop-unswitch: Trivial-Unswitch loop %" << L->getHeader()->getName() << " [" << L->getBlocks().size() << " blocks] in Function " << L->getHeader()->getParent()->getName() @@ -358,26 +367,23 @@ void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, BasicBlock *NewPH = SplitBlock(OrigPH, false); // Now that we have a place to insert the conditional branch, create a place - // to branch to: this is the non-header successor of the latch block. - BranchInst *LatchBranch =cast(L->getLoopLatch()->getTerminator()); - BasicBlock *ExitBlock = - LatchBranch->getSuccessor(LatchBranch->getSuccessor(0) == L->getHeader()); - assert(!L->contains(ExitBlock) && "Exit block is in the loop?"); + // to branch to: this is the exit block out of the loop that we should + // short-circuit to. // Split this block now, so that the loop maintains its exit block. + assert(!L->contains(ExitBlock) && "Exit block is in the loop?"); BasicBlock *NewExit = SplitBlock(ExitBlock, true); // Okay, now we have a position to branch from and a position to branch to, // insert the new conditional branch. - bool EnterOnTrue = LoopCond->getValue(); - new BranchInst(EnterOnTrue ? NewPH : NewExit, EnterOnTrue ? NewExit : NewPH, + new BranchInst(EnterOnCond ? NewPH : NewExit, EnterOnCond ? NewExit : NewPH, Cond, OrigPH->getTerminator()); OrigPH->getTerminator()->eraseFromParent(); // Now that we know that the loop is never entered when this condition is a // particular value, rewrite the loop with this info. We know that this will // at least eliminate the old branch. - RewriteLoopBodyWithConditionConstant(L, Cond, EnterOnTrue); + RewriteLoopBodyWithConditionConstant(L, Cond, EnterOnCond); ++NumUnswitched; }