From 06c57b594cd1ba36febe07e526b53d3e3382ee56 Mon Sep 17 00:00:00 2001 From: Johannes Doerfert Date: Sun, 20 Sep 2015 15:00:20 +0000 Subject: [PATCH] Allow loops with multiple back edges In order to allow multiple back edges we: - compute the conditions under which each back edge is taken - build the union over all these conditions, thus the condition that any back edge is taken - apply the same logic to the union we applied to a single back edge llvm-svn: 248120 --- polly/lib/Analysis/ScopDetection.cpp | 4 -- polly/lib/Analysis/ScopInfo.cpp | 74 +++++++++++--------- polly/test/ScopInfo/multiple_latch_blocks.ll | 47 +++++++++++++ 3 files changed, 88 insertions(+), 37 deletions(-) create mode 100644 polly/test/ScopInfo/multiple_latch_blocks.ll diff --git a/polly/lib/Analysis/ScopDetection.cpp b/polly/lib/Analysis/ScopDetection.cpp index b3b17310fbb7..e32f267ec274 100644 --- a/polly/lib/Analysis/ScopDetection.cpp +++ b/polly/lib/Analysis/ScopDetection.cpp @@ -699,10 +699,6 @@ bool ScopDetection::isValidInstruction(Instruction &Inst, bool ScopDetection::canUseISLTripCount(Loop *L, DetectionContext &Context) const { - // Ensure the loop has a single back edge. - if (L->getNumBackEdges() != 1) - return false; - // Ensure the loop has valid exiting blocks, otherwise we need to // overapproximate it as a boxed loop. 
SmallVector ExitingBlocks; diff --git a/polly/lib/Analysis/ScopInfo.cpp b/polly/lib/Analysis/ScopInfo.cpp index 0bb742cfb96e..a8a695e6146b 100644 --- a/polly/lib/Analysis/ScopInfo.cpp +++ b/polly/lib/Analysis/ScopInfo.cpp @@ -1913,28 +1913,7 @@ void Scop::addLoopBoundsToHeaderDomains(LoopInfo &LI, ScopDetection &SD, int LoopDepth = getRelativeLoopDepth(L); assert(LoopDepth >= 0 && "Loop in region should have at least depth one"); - BasicBlock *LatchBB = L->getLoopLatch(); - assert(LatchBB && "TODO implement multiple exit loop handling"); - - isl_set *LatchBBDom = DomainMap[LatchBB]; - isl_set *BackedgeCondition = nullptr; - BasicBlock *HeaderBB = L->getHeader(); - - BranchInst *BI = cast(LatchBB->getTerminator()); - if (BI->isUnconditional()) - BackedgeCondition = isl_set_copy(LatchBBDom); - else { - SmallVector ConditionSets; - int idx = BI->getSuccessor(0) != HeaderBB; - buildConditionSets(*this, BI, L, LatchBBDom, ConditionSets); - - // Free the non back edge condition set as we do not need it. 
- isl_set_free(ConditionSets[1 - idx]); - - BackedgeCondition = ConditionSets[idx]; - } - isl_set *&HeaderBBDom = DomainMap[HeaderBB]; isl_set *FirstIteration = createFirstIterationDomain(isl_set_get_space(HeaderBBDom), LoopDepth); @@ -1942,23 +1921,52 @@ void Scop::addLoopBoundsToHeaderDomains(LoopInfo &LI, ScopDetection &SD, isl_map *NextIterationMap = createNextIterationMap(isl_set_get_space(HeaderBBDom), LoopDepth); - int LatchLoopDepth = getRelativeLoopDepth(LI.getLoopFor(LatchBB)); - assert(LatchLoopDepth >= LoopDepth); - BackedgeCondition = - isl_set_project_out(BackedgeCondition, isl_dim_set, LoopDepth + 1, - LatchLoopDepth - LoopDepth); + isl_set *UnionBackedgeCondition = + isl_set_empty(isl_set_get_space(HeaderBBDom)); + + SmallVector LatchBlocks; + L->getLoopLatches(LatchBlocks); + + for (BasicBlock *LatchBB : LatchBlocks) { + assert(DomainMap.count(LatchBB)); + isl_set *LatchBBDom = DomainMap[LatchBB]; + isl_set *BackedgeCondition = nullptr; + + BranchInst *BI = cast(LatchBB->getTerminator()); + if (BI->isUnconditional()) + BackedgeCondition = isl_set_copy(LatchBBDom); + else { + SmallVector ConditionSets; + int idx = BI->getSuccessor(0) != HeaderBB; + buildConditionSets(*this, BI, L, LatchBBDom, ConditionSets); + + // Free the non back edge condition set as we do not need it. 
+ isl_set_free(ConditionSets[1 - idx]); + + BackedgeCondition = ConditionSets[idx]; + } + + int LatchLoopDepth = getRelativeLoopDepth(LI.getLoopFor(LatchBB)); + assert(LatchLoopDepth >= LoopDepth); + BackedgeCondition = + isl_set_project_out(BackedgeCondition, isl_dim_set, LoopDepth + 1, + LatchLoopDepth - LoopDepth); + UnionBackedgeCondition = + isl_set_union(UnionBackedgeCondition, BackedgeCondition); + } isl_map *ForwardMap = isl_map_lex_le(isl_set_get_space(HeaderBBDom)); for (int i = 0; i < LoopDepth; i++) ForwardMap = isl_map_equate(ForwardMap, isl_dim_in, i, isl_dim_out, i); - isl_set *BackedgeConditionComplement = - isl_set_complement(BackedgeCondition); - BackedgeConditionComplement = isl_set_lower_bound_si( - BackedgeConditionComplement, isl_dim_set, LoopDepth, 0); - BackedgeConditionComplement = - isl_set_apply(BackedgeConditionComplement, ForwardMap); - HeaderBBDom = isl_set_subtract(HeaderBBDom, BackedgeConditionComplement); + isl_set *UnionBackedgeConditionComplement = + isl_set_complement(UnionBackedgeCondition); + UnionBackedgeConditionComplement = isl_set_lower_bound_si( + UnionBackedgeConditionComplement, isl_dim_set, LoopDepth, 0); + UnionBackedgeConditionComplement = + isl_set_apply(UnionBackedgeConditionComplement, ForwardMap); + HeaderBBDom = + isl_set_subtract(HeaderBBDom, UnionBackedgeConditionComplement); auto Parts = partitionSetParts(HeaderBBDom, LoopDepth); diff --git a/polly/test/ScopInfo/multiple_latch_blocks.ll b/polly/test/ScopInfo/multiple_latch_blocks.ll new file mode 100644 index 000000000000..e0594167267c --- /dev/null +++ b/polly/test/ScopInfo/multiple_latch_blocks.ll @@ -0,0 +1,47 @@ +; RUN: opt %loadPolly -analyze -polly-scops -polly-detect-unprofitable < %s | FileCheck %s +; +; CHECK: Domain := +; CHECK: [N, P] -> { Stmt_if_end[i0] : (i0 >= 1 + P and i0 >= 0 and i0 <= -1 + N) or (i0 >= 0 and i0 <= -1 + P and i0 <= -1 + N) }; +; +; void f(int *A, int N, int P, int Q) { +; for (int i = 0; i < N; i++) { +; if (i == P) +; 
continue; +; A[i]++; +; } +; } +; +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" + +define void @f(i32* %A, i32 %N, i32 %P, i32 %Q) { +entry: + %tmp = sext i32 %N to i64 + br label %for.cond + +for.cond: ; preds = %if.then, %for.inc, %entry + %indvars.iv = phi i64 [ %indvars.iv.next, %if.then ], [ %indvars.iv.next, %for.inc ], [ 0, %entry ] + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1 + %cmp = icmp slt i64 %indvars.iv, %tmp + br i1 %cmp, label %for.body, label %for.end + +for.body: ; preds = %for.cond + %tmp1 = trunc i64 %indvars.iv to i32 + %cmp1 = icmp eq i32 %tmp1, %P + br i1 %cmp1, label %if.then, label %if.end + +if.then: ; preds = %for.body + br label %for.cond + +if.end: ; preds = %for.body + %arrayidx = getelementptr inbounds i32, i32* %A, i64 %indvars.iv + %tmp2 = load i32, i32* %arrayidx, align 4 + %inc = add nsw i32 %tmp2, 1 + store i32 %inc, i32* %arrayidx, align 4 + br label %for.inc + +for.inc: ; preds = %if.end + br label %for.cond + +for.end: ; preds = %for.cond + ret void +}